I'm currently working with DNA sequence data and I have run into a bit of a performance roadblock.
I have two lookup dictionaries/hashes (as RDDs) with DNA "words" (short sequences) as keys and a list of index positions as the value. One is for a shorter query sequence and the other for a database sequence. Creating the tables is pretty fast even with very, very large sequences.
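For concreteness, one of these tables looks roughly like this (shown here as a plain dict rather than an RDD, with a made-up word size of 4):

# illustrative only: the lookup table for the toy sequence "ACGTACGT"
lookup = {
    "ACGT": [0, 4],
    "CGTA": [1],
    "GTAC": [2],
    "TACG": [3],
}
# i.e. each word maps to every position at which it occurs in the sequence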
For the next step, I need to pair these up and find "hits" (pairs of index positions for each common word).
I first join the lookup dictionaries, which is reasonably fast. However, I now need the pairs, so I have to flatMap twice: once to expand the list of indices from the query, and a second time to expand the list of indices from the database. This isn't ideal, but I don't see another way to do it. At least it performs OK.
The output at this point is (diagonal_offset, (query_index, word_length)), where the diagonal offset is the database sequence index minus the query sequence index.
However, I now need to find pairs of indices with the same diagonal offset (db_index - query_index) that are reasonably close together, and join them (so that I extend the length of the word), but only as pairs (i.e. once I join one index with another, I don't want anything else to merge with it).
I do this with an aggregateByKey operation using a special object called SeedList().
PARALLELISM = 16  # I have 4 cores with hyperthreading

def generateHsps(query_lookup_table_rdd, database_lookup_table_rdd):
    global broadcastSequences

    def mergeValueOp(seedlist, (query_index, seed_length)):
        return seedlist.addSeed((query_index, seed_length))

    def mergeSeedListsOp(seedlist1, seedlist2):
        return seedlist1.mergeSeedListIntoSelf(seedlist2)

    hits_rdd = (query_lookup_table_rdd.join(database_lookup_table_rdd)
                .flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
                .flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)
                .aggregateByKey(SeedList(), mergeValueOp, mergeSeedListsOp, PARALLELISM)
                .map(lambda (diagonal, seedlist): (diagonal, seedlist.mergedSeedList))
                .flatMap(lambda (diagonal, seedlist): [(query_index, seed_length, diagonal) for query_index, seed_length in seedlist])
                )
    return hits_rdd
And here is the SeedList() class:
class SeedList():
    def __init__(self):
        self.unmergedSeedList = []
        self.mergedSeedList = []

    #Try to find a more efficient way to do this
    def addSeed(self, (query_index1, seed_length1)):
        for i in range(0, len(self.unmergedSeedList)):
            (query_index2, seed_length2) = self.unmergedSeedList[i]
            #print "Checking ({0}, {1})".format(query_index2, seed_length2)
            if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1 + seed_length1, query_index2 + seed_length2) - min(query_index1, query_index2)))
                self.unmergedSeedList.pop(i)
                return self
        self.unmergedSeedList.append((query_index1, seed_length1))
        return self

    def mergeSeedListIntoSelf(self, seedlist2):
        print "merging seed"
        for (query_index2, seed_length2) in seedlist2.unmergedSeedList:
            wasmerged = False
            for i in range(0, len(self.unmergedSeedList)):
                (query_index1, seed_length1) = self.unmergedSeedList[i]
                if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                    self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1 + seed_length1, query_index2 + seed_length2) - min(query_index1, query_index2)))
                    self.unmergedSeedList.pop(i)
                    wasmerged = True
                    break
            if not wasmerged:
                self.unmergedSeedList.append((query_index2, seed_length2))
        return self
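To make the aggregation semantics concrete, here is a small, purely illustrative exercise of the class above (hypothetical values, assuming WORD_SIZE = 4 and WINDOW_SIZE = 10):

sl = SeedList()
sl.addSeed((3, 4))          # nothing to merge with yet -> goes to unmergedSeedList
sl.addSeed((7, 4))          # within WINDOW_SIZE of (3, 4) -> mergedSeedList gets (3, 8)
sl.addSeed((40, 4))         # unmergedSeedList is empty again -> just gets appended
print sl.mergedSeedList     # [(3, 8)]
print sl.unmergedSeedList   # [(40, 4)]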
This is where the performance really breaks down, even for sequences of moderate length.
Is there any better way to do this aggregation? My gut feeling says yes, but I can't come up with it.
I know this is a very long-winded and technical question, and I would really appreciate any insight, even if there is no easy solution.
Edit: Here is how I am making the lookup tables:
def createLookupTable(sequence_rdd, sequence_name, word_length):
    global broadcastSequences
    blank_list = []

    def addItemToList(lst, val):
        lst.append(val)
        return lst

    def mergeLists(lst1, lst2):
        #print "Merging"
        return lst1 + lst2

    return (sequence_rdd
            .flatMap(lambda seq_len: range(0, seq_len - word_length + 1))
            .repartition(PARALLELISM)
            #.partitionBy(PARALLELISM)
            .map(lambda index: (str(broadcastSequences.value[sequence_name][index:index + word_length]), index), preservesPartitioning=True)
            .aggregateByKey(blank_list, addItemToList, mergeLists, PARALLELISM))
            #.map(lambda (word, indices): (word, sorted(indices))))
And here is the function that runs the whole operation:
def run(query_seq, database_sequence, translate_query=False):
    global broadcastSequences
    scoring_matrix = 'nucleotide' if isinstance(query_seq.alphabet, DNAAlphabet) else 'blosum62'
    sequences = {'query': query_seq,
                 'database': database_sequence}
    broadcastSequences = sc.broadcast(sequences)
    query_rdd = sc.parallelize([len(query_seq)])
    query_rdd.cache()
    database_rdd = sc.parallelize([len(database_sequence)])
    database_rdd.cache()
    query_lookup_table_rdd = createLookupTable(query_rdd, 'query', WORD_SIZE)
    query_lookup_table_rdd.cache()
    database_lookup_table_rdd = createLookupTable(database_rdd, 'database', WORD_SIZE)
    seeds_rdd = generateHsps(query_lookup_table_rdd, database_lookup_table_rdd)
    return seeds_rdd
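For reference, this is roughly how I call it. A minimal sketch only: I'm assuming Biopython Seq objects (the DNAAlphabet check above comes from Biopython), an existing SparkContext named sc, and placeholder WORD_SIZE/WINDOW_SIZE values:

from Bio.Seq import Seq
from Bio.Alphabet import generic_dna

WORD_SIZE = 11    # placeholder values, not my real parameters
WINDOW_SIZE = 40

query = Seq("ACGTACGTACGTACGTACGT", generic_dna)
database = Seq("TTACGTACGTACGTACGTACGTAA", generic_dna)

seeds_rdd = run(query, database)
print seeds_rdd.take(10)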
Edit 2: I have tweaked things a bit and slightly improved performance by replacing:
.flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
.flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)
in hits_rdd with:
.flatMap(lambda (word, (query_indices, db_indices)): itertools.product(query_indices, db_indices))
.map(lambda (query_index, db_index): (db_index - query_index, (query_index, WORD_SIZE) ))
At least now I'm not burning up storage with intermediate data structures as much.
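Just to spell out what that cross product produces (hypothetical indices, WORD_SIZE = 4):

import itertools
query_indices, db_indices = [3, 7], [10, 14]
pairs = list(itertools.product(query_indices, db_indices))
# [(3, 10), (3, 14), (7, 10), (7, 14)]
hits = [(db_index - query_index, (query_index, 4)) for query_index, db_index in pairs]
# [(7, (3, 4)), (11, (3, 4)), (3, (7, 4)), (7, (7, 4))]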
Let's forget about the technical details of what you're doing and think "functionally" about the steps involved, forgetting about the details of the implementation. Functional thinking like this is an important part of parallel data analysis; ideally, if we can break the problem up like this, we can reason more clearly about the steps involved and end up with clearer and often more concise code. Thinking in terms of a tabular data model, I would consider your problem to consist of the following steps:

1. Join your two datasets on the sequence column.
2. Calculate a new column, delta, containing the difference between the indices.
3. Sort by delta (and index).
4. Reduce by delta and concatenate the strings in the sequence column, to obtain the full matches between your datasets.

For the first 3 steps, I think it makes sense to use DataFrames, since this data model matches the kind of processing that we're doing. (Actually I might use DataFrames for step 4 as well, except pyspark doesn't currently support custom aggregate functions for DataFrames, although Scala does.)
For the fourth step (which, if I understand correctly, is what you're really asking about in your question), it's a little tricky to think about how to do this functionally. However, I think an elegant and efficient solution is to use a reduce (also known as a fold); this pattern can be applied to any problem that you can phrase in terms of iteratively applying an associative binary function, that is, a function where the "grouping" of any 3 arguments doesn't matter (although the order certainly may matter). Symbolically, this is a function (x, y) -> f(x, y) where f(x, f(y, z)) = f(f(x, y), z). String (or more generally list) concatenation is just such a function.
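For instance, a quick sanity check of the associativity of concatenation (plain Python, outside Spark, using the same toy strings as below):

parts = ['abcd', 'adab', 'dbab']
print reduce(lambda x, y: x + y, parts)                          # abcdadabdbab
print ('abcd' + 'adab') + 'dbab' == 'abcd' + ('adab' + 'dbab')   # True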
Here's an example of how this might look in pyspark; hopefully you can adapt this to the details of your data:
#setup some sample data
query = [['abcd', 30] ,['adab', 34] ,['dbab',38]]
reference = [['dbab', 20], ['ccdd', 24], ['abcd', 50], ['adab',54], ['dbab',58], ['dbab', 62]]
#create data frames
query_df = sqlContext.createDataFrame(query, schema = ['sequence1', 'index1'])
reference_df = sqlContext.createDataFrame(reference, schema = ['sequence2', 'index2'])
#step 1: join
matches = query_df.join(reference_df, query_df.sequence1 == reference_df.sequence2)
#step 2: calculate delta column
matches_delta = matches.withColumn('delta', matches.index2 - matches.index1)
#step 3: sort by delta and index
matches_sorted = matches_delta.sort('delta', 'index2')
#step 4: convert back to rdd and reduce
#note that + is just string concatenation for strings
r = matches_sorted['delta', 'sequence1'].rdd
r.reduceByKey(lambda x, y : x + y).collect()
#expected output:
#[(24, u'dbab'), (-18, u'dbab'), (20, u'abcdadabdbab')]
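If you'd rather end up with (query_index, seed_length) seeds like your SeedList produces instead of concatenated strings, the same reduce pattern should work with an interval-style merge. This is just a sketch reusing the merge formula from your addSeed, and it ignores the WINDOW_SIZE / pairwise-only constraints, which would still need separate handling:

#hypothetical adaptation: merge (index, length) pairs instead of strings
def merge_seeds((i1, l1), (i2, l2)):
    start = min(i1, i2)
    return (start, max(i1 + l1, i2 + l2) - start)

seeds = matches_sorted.rdd.map(lambda row: (row.delta, (row.index1, 4)))  # 4 = word length in this toy example
print seeds.reduceByKey(merge_seeds).collect()
#e.g. [(24, (38, 4)), (-18, (38, 4)), (20, (30, 12))] (ordering may vary)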