I'm currently working with DNA sequence data, and I've run into a performance bottleneck.
I have two lookup dictionaries/hashes (as RDDs) with DNA "words" (short sequences) as keys and lists of index positions as values. One is for a shorter query sequence and the other is for a database sequence. Creating the tables is very fast, even for very, very large sequences.
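For concreteness, collected out of the RDDs the tables look roughly like this (a toy example with 3-letter words; my real word size is larger):

    # Toy illustration only -- the real tables are RDDs of (word, [indices]) pairs,
    # built by createLookupTable() further down.
    query_lookup = {
        "ACG": [0, 7],        # "ACG" occurs at positions 0 and 7 of the query
        "CGT": [1],
    }
    database_lookup = {
        "ACG": [3, 12, 40],   # the same word at positions 3, 12 and 40 of the database
        "TTA": [5],
    }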
For the next step, I need to pair these up and find "hits" (pairs of index positions for each common word).
I first join the lookup dictionaries, which is reasonably fast. However, I then need the individual pairs, so I have to flatMap twice, once to expand the list of indices from the query and a second time to expand the list of indices from the database. This isn't ideal, but I don't see another way to do it. At least it performs OK.
The output at this point is (query_index, (word_length, diagonal_offset)), where the diagonal offset is the database sequence index minus the query sequence index.
However, I now need to find pairs of indices that lie on the same diagonal offset (db_index - query_index) and are reasonably close together, and join them (thereby extending the length of the word), but only as pairs (i.e. once I've joined one index with another, I don't want anything else merged with it).
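To make that criterion concrete, here is a tiny worked example (the WORD_SIZE and WINDOW_SIZE values are made up for illustration, not my real parameters):

    # Illustrative values only
    WORD_SIZE = 3
    WINDOW_SIZE = 4

    # Two hits for the same word at (query_index, db_index) = (10, 25) and (14, 29).
    # Both lie on the same diagonal, since 25 - 10 == 29 - 14 == 15.
    hits = [(10, 25), (14, 29)]
    diagonals = [db_index - query_index for query_index, db_index in hits]   # [15, 15]

    # They are also close: the end of the first word (10 + 3 = 13) is within
    # WINDOW_SIZE of the start of the second (14), so I want to join them into a
    # single seed covering query positions 10..17, i.e. (10, 7), and then leave
    # that pair alone rather than merge anything else into it.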
I do this with an aggregateByKey operation using a special object called SeedList().
    PARALLELISM = 16  # I have 4 cores with hyperthreading

    def generateHsps(query_lookup_table_rdd, database_lookup_table_rdd):
        global broadcastSequences

        def mergeValueOp(seedlist, (query_index, seed_length)):
            return seedlist.addSeed((query_index, seed_length))

        def mergeSeedListsOp(seedlist1, seedlist2):
            return seedlist1.mergeSeedListIntoSelf(seedlist2)

        hits_rdd = (query_lookup_table_rdd.join(database_lookup_table_rdd)
                    .flatMap(lambda (word, (query_indices, db_indices)):
                             [(query_index, db_indices) for query_index in query_indices],
                             preservesPartitioning=True)
                    .flatMap(lambda (query_index, db_indices):
                             [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices],
                             preservesPartitioning=True)
                    .aggregateByKey(SeedList(), mergeValueOp, mergeSeedListsOp, PARALLELISM)
                    .map(lambda (diagonal, seedlist): (diagonal, seedlist.mergedSeedList))
                    .flatMap(lambda (diagonal, seedlist):
                             [(query_index, seed_length, diagonal) for query_index, seed_length in seedlist])
                    )
        return hits_rdd
And the SeedList() class:
    class SeedList():
        def __init__(self):
            self.unmergedSeedList = []
            self.mergedSeedList = []

        # Try to find a more efficient way to do this
        def addSeed(self, (query_index1, seed_length1)):
            for i in range(0, len(self.unmergedSeedList)):
                (query_index2, seed_length2) = self.unmergedSeedList[i]
                # print "Checking ({0}, {1})".format(query_index2, seed_length2)
                if min(abs(query_index2 + seed_length2 - query_index1),
                       abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                    self.mergedSeedList.append((min(query_index1, query_index2),
                                                max(query_index1 + seed_length1, query_index2 + seed_length2)
                                                - min(query_index1, query_index2)))
                    self.unmergedSeedList.pop(i)
                    return self
            self.unmergedSeedList.append((query_index1, seed_length1))
            return self

        def mergeSeedListIntoSelf(self, seedlist2):
            print "merging seed"
            for (query_index2, seed_length2) in seedlist2.unmergedSeedList:
                wasmerged = False
                for i in range(0, len(self.unmergedSeedList)):
                    (query_index1, seed_length1) = self.unmergedSeedList[i]
                    if min(abs(query_index2 + seed_length2 - query_index1),
                           abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                        self.mergedSeedList.append((min(query_index1, query_index2),
                                                    max(query_index1 + seed_length1, query_index2 + seed_length2)
                                                    - min(query_index1, query_index2)))
                        self.unmergedSeedList.pop(i)
                        wasmerged = True
                        break
                if not wasmerged:
                    self.unmergedSeedList.append((query_index2, seed_length2))
            return self
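Run locally outside Spark, the aggregation for a single diagonal key behaves like this (WINDOW_SIZE = 4 is an arbitrary value just for the example):

    # Local sketch of what aggregateByKey does for one diagonal key.
    WINDOW_SIZE = 4   # arbitrary, for illustration only

    seeds_on_diagonal = [(10, 3), (14, 3), (30, 3)]   # (query_index, seed_length) values

    acc = SeedList()                  # the zero value
    for seed in seeds_on_diagonal:
        acc = acc.addSeed(seed)       # mergeValueOp

    print acc.mergedSeedList          # [(10, 7)]  -> the first two seeds were joined
    print acc.unmergedSeedList        # [(30, 3)]  -> nothing close enough to merge with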
This aggregation is where performance really falls apart, even for sequences of moderate length.
Is there a better way to do this aggregation? My gut says yes, but I haven't been able to come up with it.
I know this is a very long and technical question, and I would really appreciate any insight even if there is no easy solution.
Edit: Here is how I am making the lookup tables:
    def createLookupTable(sequence_rdd, sequence_name, word_length):
        global broadcastSequences
        blank_list = []

        def addItemToList(lst, val):
            lst.append(val)
            return lst

        def mergeLists(lst1, lst2):
            # print "Merging"
            return lst1 + lst2

        return (sequence_rdd
                .flatMap(lambda seq_len: range(0, seq_len - word_length + 1))
                .repartition(PARALLELISM)
                # .partitionBy(PARALLELISM)
                .map(lambda index: (str(broadcastSequences.value[sequence_name][index:index + word_length]), index),
                     preservesPartitioning=True)
                .aggregateByKey(blank_list, addItemToList, mergeLists, PARALLELISM))
                # .map(lambda (word, indices): (word, sorted(indices))))
And here is the function that runs the whole operation:
    def run(query_seq, database_sequence, translate_query=False):
        global broadcastSequences
        scoring_matrix = 'nucleotide' if isinstance(query_seq.alphabet, DNAAlphabet) else 'blosum62'
        sequences = {'query': query_seq,
                     'database': database_sequence}
        broadcastSequences = sc.broadcast(sequences)
        query_rdd = sc.parallelize([len(query_seq)])
        query_rdd.cache()
        database_rdd = sc.parallelize([len(database_sequence)])
        database_rdd.cache()
        query_lookup_table_rdd = createLookupTable(query_rdd, 'query', WORD_SIZE)
        query_lookup_table_rdd.cache()
        database_lookup_table_rdd = createLookupTable(database_rdd, 'database', WORD_SIZE)
        seeds_rdd = generateHsps(query_lookup_table_rdd, database_lookup_table_rdd)
        return seeds_rdd
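In case it matters, the driver is essentially this (simplified sketch; the constants, the Spark context setup and the older-Biopython Seq/alphabet objects here are stand-ins for my actual setup):

    # Simplified, hypothetical driver -- values are placeholders.
    from pyspark import SparkContext
    from Bio.Seq import Seq
    from Bio.Alphabet import DNAAlphabet   # pre-1.78 Biopython, where Seq still carries an alphabet

    WORD_SIZE = 11      # placeholder
    WINDOW_SIZE = 40    # placeholder

    sc = SparkContext(appName="spark-seed-finder")

    query_seq = Seq("ACGTACGTAGCTAGCTAGGA", DNAAlphabet())
    database_sequence = Seq("TTACGTACGTAGCTAGCTAGGACCA", DNAAlphabet())

    seeds_rdd = run(query_seq, database_sequence)
    print seeds_rdd.take(10)    # [(query_index, seed_length, diagonal), ...]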
Edit 2: I've tinkered a bit and slightly improved performance by replacing:
    .flatMap(lambda (word, (query_indices, db_indices)):
             [(query_index, db_indices) for query_index in query_indices],
             preservesPartitioning=True)
    .flatMap(lambda (query_index, db_indices):
             [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices],
             preservesPartitioning=True)
in hits_rdd with:
    .flatMap(lambda (word, (query_indices, db_indices)): itertools.product(query_indices, db_indices))
    .map(lambda (query_index, db_index): (db_index - query_index, (query_index, WORD_SIZE)))
At least now I'm not burning up memory with intermediate data structures.
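Just to illustrate why: itertools.product yields the same (query_index, db_index) pairs the two flatMaps produced, but lazily, without materialising the intermediate (query_index, db_indices) records:

    import itertools

    # Toy index lists; the real ones come out of the join.
    query_indices = [0, 7]
    db_indices = [3, 12]

    # Old two-step expansion (builds an intermediate list per query index):
    pairs_old = [(q, db) for q in query_indices for db in db_indices]

    # New single step, evaluated lazily by Spark as it consumes the iterator:
    pairs_new = list(itertools.product(query_indices, db_indices))

    assert pairs_old == pairs_new    # [(0, 3), (0, 12), (7, 3), (7, 12)]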