我正在尝试过滤元组的RDD,以根据键值返回最大的N元组.我需要返回格式为RDD.
所以RDD:
[(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]
过滤掉最大的3个键应该返回RDD:
[(6,'p'), (12,'e'), (49,'y')]
执行a sortByKey()
然后take(N)
返回值并且不会导致RDD,因此不起作用.
我可以返回所有键,对它们进行排序,找到第N个最大值,然后过滤RDD以获得大于该值的键值,但这似乎非常低效.
最好的方法是什么?
同 RDD
一个快速但不是特别有效的解决方案是遵循sortByKey
使用zipWithIndex
和filter
:
n = 3 rdd = sc.parallelize([(4, 'a'), (12, 'e'), (2, 'u'), (49, 'y'), (6, 'p')]) rdd.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果n与RDD大小相比相对较小,则更有效的方法是避免完全排序:
import heapq def key(kv): return kv[0] top_per_partition = rdd.mapPartitions(lambda iter: heapq.nlargest(n, iter, key)) top_per_partition.sortByKey().zipWithIndex().filter(lambda xi: xi[1] < n).keys()
如果键比值小得多,并且最终输出的顺序无关紧要那么filter
方法可以正常工作:
keys = rdd.keys() identity = lambda x: x offset = (keys .mapPartitions(lambda iter: heapq.nlargest(n, iter)) .sortBy(identity) .zipWithIndex() .filter(lambda xi: xi[1] < n) .keys() .max()) rdd.filter(lambda kv: kv[0] <= offset)
在关系的情况下,它也不会保持精确的n值.
同 DataFrames
你可以orderBy
和limit
:
from pyspark.sql.functions import col rdd.toDF().orderBy(col("_1").desc()).limit(n)