NumPy的没有尚未有一个基数排序,所以我想知道是否有可能使用一个预先存在numpy的功能来写.到目前为止,我有以下,它确实有效,但比numpy的快速排序慢约10倍.
测试和基准测试:
a = np.random.randint(0, 1e8, 1e6) assert(np.all(radix_sort(a) == np.sort(a))) %timeit np.sort(a) %timeit radix_sort(a)
该mask_b
循环可以至少部分地被矢量化,从掩码中广播&
并cumsum
与axis
arg一起使用,但是这最终是一种悲观,可能是由于增加的存储器占用.
如果有人能够看到一种方法来改进我所拥有的东西,我会有兴趣听到,即使它仍然比np.sort
... 慢......这更像是一种对知识的好奇心和对numpy技巧的兴趣.
请注意,您可以轻松地实现快速计数排序,但这仅与小整数数据相关.
编辑1:以np.arange(n)
圈外的帮助一点,但不是很exiciting.
编辑2:在cumsum
实际上是多余的(哎呀!),但这个简单的版本仅具有性能稍微帮助..
def radix_sort(a): bit_len = np.max(a).bit_length() n = len(a) cached_arange = arange(n) idx = np.empty(n, dtype=int) # fully overwritten each iteration for mask_b in xrange(bit_len): is_one = (a & 2**mask_b).astype(bool) n_ones = np.sum(is_one) n_zeros = n-n_ones idx[~is_one] = cached_arange[:n_zeros] idx[is_one] = cached_arange[:n_ones] + n_zeros # next three lines just do: a[idx] = a, but correctly new_a = np.empty(n, dtype=a.dtype) new_a[idx] = a a = new_a return a
编辑3:如果您在多个步骤中构造idx,则可以一次循环两个或更多个,而不是循环使用单个位.使用2位有点帮助,我没有尝试过更多:
idx[is_zero] = np.arange(n_zeros) idx[is_one] = np.arange(n_ones) idx[is_two] = np.arange(n_twos) idx[is_three] = np.arange(n_threes)
编辑4和5:对于我正在测试的输入,4位似乎是最好的.此外,你可以idx
完全摆脱这一步.现在只有5倍,而不是10倍,慢于np.sort
(作为gist提供的源代码):
编辑6:这是上面的一个整理版本,但它也有点慢.80%的时间花在repeat
和extract
- 如果只有一种方式广播extract
:( ...
def radix_sort(a, batch_m_bits=3): bit_len = np.max(a).bit_length() batch_m = 2**batch_m_bits mask = 2**batch_m_bits - 1 val_set = np.arange(batch_m, dtype=a.dtype)[:, nax] # nax = np.newaxis for _ in range((bit_len-1)//batch_m_bits + 1): # ceil-division a = np.extract((a & mask)[nax, :] == val_set, np.repeat(a[nax, :], batch_m, axis=0)) val_set <<= batch_m_bits mask <<= batch_m_bits return a
编辑7和8:实际上,您可以使用as_strided
from 来广播提取numpy.lib.stride_tricks
,但它似乎没有太大的性能帮助:
最初这对我来说是有意义的,因为它extract
会在整个数组batch_m
时间内进行迭代,因此CPU请求的高速缓存行总数将与之前相同(只是在它请求每个请求的过程结束时)缓存行batch_m
时间).然而,实际情况是,extract
不足以巧妙地迭代任意阶梯数组,并且必须在开始之前扩展数组,即无论如何最终都会重复执行.事实上,在查看源代码之后extract
,我现在看到我们用这种方法做的最好的事情是:
a = a[np.flatnonzero((a & mask)[nax, :] == val_set) % len(a)]
这比一点慢extract
.然而,如果len(a)
是两个电源可以代替昂贵的MOD与操作& (len(a) - 1)
,这并最终被略高于更快的extract
版本(目前约4.9x np.sort
的a=randint(0, 1e8, 2**20
).我想我们可以通过零填充使这两个长度的非幂次工作,然后在排序结束时裁剪额外的零...但是这将是一个悲观,除非长度已经接近于二.