Spark - Sort DStream by Key and limit to 5 values

I've started learning Spark and wrote a PySpark Streaming program that reads stock data (symbol, volume) from port 3333.

Sample data streamed on port 3333:

"AAC",111113
"ABT",7451020
"ABBV",7325429
"ADPT",318617
"AET",1839122
"ALR",372777
"AGN",4170581
"ABC",3001798
"ANTM",1968246

I want to display the top 5 symbols by volume. So I used map to read each line, split it on the comma, and sort the two fields in reverse order so that the volume comes first:

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "NetworkWordCount")
ssc = StreamingContext(sc, 5)

lines = ssc.socketTextStream("localhost", 3333)
stocks = lines.map(lambda line: sorted(line.split(','), reverse=True))
stocks.pprint()

Here is the output of stocks.pprint():

[u'111113', u'"AAC"']
[u'7451020', u'"ABT"']
[u'7325429', u'"ABBV"']
[u'318617', u'"ADPT"']
[u'1839122', u'"AET"']
[u'372777', u'"ALR"']
[u'4170581', u'"AGN"']
[u'3001798', u'"ABC"']
[u'1968246', u'"ANTM"']
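The map above only puts the volume first because of character ordering: sorted compares the fields as strings, and the double quote (code point 34) sorts below every digit (48 to 57). A small plain-Python check, using a hypothetical helper reverse_sorted_fields that mirrors the question's lambda, suggests the trick depends on the quotes being present:

```python
def reverse_sorted_fields(line):
    # Mirrors the question's map: split on the comma, sort the fields
    # as strings in descending order
    return sorted(line.split(','), reverse=True)

# '"' (34) sorts below any digit, so the volume ends up first
quoted = reverse_sorted_fields('"AAC",111113')
# Without quotes, letters (65 and up) sort above digits and the order flips
unquoted = reverse_sorted_fields('AAC,111113')

print(quoted)    # ['111113', '"AAC"']
print(unquoted)  # ['AAC', '111113']
```

An unquoted line would come out symbol first, so unpacking the fields by position is the more robust choice.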

I have the following function in mind to display the stock symbols, but I'm not sure how to sort the stocks by key (volume) and then limit the output to only the first 5 values.

def processStocks(stock):
    # stock is the RDD holding one batch; collect() brings it to the driver
    for st in stock.collect():
        print(st[1])

stocks.foreachRDD(processStocks)
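For reference, here is a plain-Python sketch of what processStocks could do with the collected batch, assuming each element has the [volume, symbol] shape printed above. top_symbols is a hypothetical helper, not part of Spark. The volume has to be converted to int before sorting, because string comparison ranks '372777' above '3001798':

```python
def top_symbols(batch, n=5):
    """Return the n symbols with the largest volume in one collected batch.

    batch is assumed to be a list of [volume, symbol] string pairs, the shape
    the question's map produces (for example [u'111113', u'"AAC"']).
    """
    # Sort numerically by volume, largest first
    ranked = sorted(batch, key=lambda st: int(st[0]), reverse=True)
    # Keep only the symbol of the first n entries
    return [st[1] for st in ranked[:n]]

batch = [['111113', '"AAC"'], ['7451020', '"ABT"'], ['7325429', '"ABBV"'],
         ['318617', '"ADPT"'], ['1839122', '"AET"'], ['372777', '"ALR"'],
         ['4170581', '"AGN"'], ['3001798', '"ABC"'], ['1968246', '"ANTM"']]
print(top_symbols(batch))  # ['"ABT"', '"ABBV"', '"AGN"', '"ABC"', '"ANTM"']
```

Inside processStocks this would be called on stock.collect(). Collecting pulls the entire batch to the driver, which is fine for a small feed like this one but not for large batches.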
asked Jan 05 '23 06:01 by DhiwaTdG


1 Answer

Since a stream represents an infinite sequence, all you can do is sort each batch. First, you have to parse the data correctly (to keep the example reproducible, queueStream stands in for the socket source here):

lines = ssc.queueStream([sc.parallelize([
    "AAC,111113", "ABT,7451020", "ABBV,7325429","ADPT,318617",
    "AET,1839122", "ALR,372777", "AGN,4170581", "ABC,3001798", 
    "ANTM,1968246"
])])

def parse(line):
    try:
        k, v = line.split(",")
        yield (k, int(v))
    except ValueError:
        pass 

parsed = lines.flatMap(parse)
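The sample data in the queueStream has no quotes, but the lines arriving on port 3333 do (for example "AAC",111113), so parse would keep the quotes inside each key. A possible variant, sketched here as a hypothetical parse_quoted with the same generator shape so it could still be passed to flatMap, strips them. The list comprehension at the end only mimics flatMap locally; it does not run Spark:

```python
def parse_quoted(line):
    """Yield (symbol, volume) for a line like '"AAC",111113'; skip bad lines.

    Hypothetical variant of parse that also strips surrounding whitespace
    and the double quotes around each symbol.
    """
    try:
        k, v = line.strip().split(",")
        yield (k.strip('"'), int(v))
    except ValueError:
        # Wrong number of fields or a non-numeric volume: drop the line
        pass

lines = ['"AAC",111113', 'ABT,7451020', 'garbage', '"AET",not-a-number']
# Flatten the generators, as flatMap would
pairs = [pair for line in lines for pair in parse_quoted(line)]
print(pairs)  # [('AAC', 111113), ('ABT', 7451020)]
```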

Next, sort each batch:

sorted_ = parsed.transform(
    lambda rdd: rdd.sortBy(lambda x: x[1], ascending=False))
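transform applies its function to each batch RDD on its own, so the ordering never spans batches. A local sketch with two hypothetical batches (plain lists standing in for RDDs, and sorted standing in for sortBy) makes that visible:

```python
def sort_batch(batch):
    # Local stand-in for rdd.sortBy(lambda x: x[1], ascending=False)
    return sorted(batch, key=lambda x: x[1], reverse=True)

# Two hypothetical batches; each one is sorted independently
batches = [
    [("AAC", 111113), ("ABT", 7451020), ("ADPT", 318617)],
    [("AET", 1839122), ("AGN", 4170581)],
]
sorted_batches = [sort_batch(b) for b in batches]
for sb in sorted_batches:
    print(sb)
```

The second batch is ordered without regard to the first, even though ABT from the first batch has a larger volume than anything in the second.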

Finally, you can pprint the top elements:

sorted_.pprint(5)

If all went well, you should get output like this:

-------------------------------------------                         
Time: 2016-10-02 14:52:30
-------------------------------------------
('ABT', 7451020)
('ABBV', 7325429)
('AGN', 4170581)
('ABC', 3001798)
('ANTM', 1968246)
...

Depending on the size of a batch, a full sort can be prohibitively expensive. In that case you can take the top elements with top and parallelize them:

sorted_ = parsed.transform(
    lambda rdd: rdd.context.parallelize(rdd.top(5, key=lambda x: x[1])))
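RDD.top(num, key=None) returns the largest elements in descending order and, without a key, compares the records themselves. For (symbol, volume) tuples that means ranking by symbol, not by volume, so the key has to select the volume. The sketch below uses heapq.nlargest as a local stand-in with the same comparison rules; it is an illustration only and does not run Spark:

```python
import heapq

records = [("AAC", 111113), ("ABT", 7451020), ("ABBV", 7325429),
           ("ADPT", 318617), ("AET", 1839122), ("ALR", 372777),
           ("AGN", 4170581), ("ABC", 3001798), ("ANTM", 1968246)]

# No key: tuples compare element by element, so the symbol decides
by_symbol = heapq.nlargest(5, records)
# With a key selecting the volume: the actual "top 5 by volume"
by_volume = heapq.nlargest(5, records, key=lambda x: x[1])

print(by_symbol)  # starts with ('ANTM', 1968246), ('ALR', 372777), ...
print(by_volume)  # starts with ('ABT', 7451020), ('ABBV', 7325429), ...
```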

or even use combineByKey to keep only a running top 5:

from operator import itemgetter
import heapq

key = itemgetter(1)  # rank (symbol, volume) records by volume

def create_combiner(x):
    # The first record seen for a key starts a one-element list
    return [x]

def merge_value(n=5, key=lambda x: x):
    def _(acc, x):
        # Add a record, then keep only the n largest
        acc.append(x)
        return heapq.nlargest(n, acc, key=key) if len(acc) > n else acc
    return _

def merge_combiners(n=5, key=lambda x: x):
    def _(acc1, acc2):
        # Merge two partial top-n lists, keeping only the n largest
        return heapq.nlargest(n, acc1 + acc2, key=key)
    return _

top5 = (parsed
    .map(lambda x: (None, x))  # one dummy key, so all records land in the same group
    .combineByKey(
        create_combiner, merge_value(key=key), merge_combiners(key=key))
    # Unpack the surviving records, largest volume first
    .flatMap(lambda kv: sorted(kv[1], key=key, reverse=True)))

top5.pprint()
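To see why the combineByKey approach avoids a full sort, here is a plain-Python simulation of the same create/merge steps over a hypothetical split of one batch into three partitions. simulate_top_n is an illustrative helper, not a Spark API, and the real shuffle and scheduling are not modeled:

```python
import heapq
from functools import reduce

def simulate_top_n(partitions, n=5, key=lambda x: x[1]):
    """Mimic a combineByKey top-n locally: combine within each partition,
    then merge the partial results. partitions is a list of record lists."""
    def merge_value(acc, x):
        # mergeValue: add one record, keep at most n
        return heapq.nlargest(n, acc + [x], key=key)

    def merge_combiners(acc1, acc2):
        # mergeCombiners: merge two partial lists, keep at most n
        return heapq.nlargest(n, acc1 + acc2, key=key)

    partials = []
    for part in partitions:
        if not part:
            continue  # an empty partition produces no combiner
        acc = [part[0]]  # createCombiner on the partition's first record
        for x in part[1:]:
            acc = merge_value(acc, x)
        partials.append(acc)
    # Each partial holds at most n records, so each merge sees at most 2 * n
    return reduce(merge_combiners, partials, [])

records = [("AAC", 111113), ("ABT", 7451020), ("ABBV", 7325429),
           ("ADPT", 318617), ("AET", 1839122), ("ALR", 372777),
           ("AGN", 4170581), ("ABC", 3001798), ("ANTM", 1968246)]
# Hypothetical split of one batch into three partitions
partitions = [records[:4], records[4:7], records[7:]]
print(simulate_top_n(partitions))
```

This works because any record in the overall top 5 must also be in the top 5 of its own partition, so discarding everything else early loses nothing.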
answered Feb 19 '23 13:02 by zero323