当前位置:  开发笔记 > 编程语言 > 正文

矢量化的基数排序与numpy - 它可以击败np.sort?

如何解决《矢量化的基数排序与numpy-它可以击败np.sort?》经验,为你挑选了0个好方法。

NumPy的没有尚未有一个基数排序,所以我想知道是否有可能使用一个预先存在numpy的功能来写.到目前为止,我有以下,它确实有效,但比numpy的快速排序慢约10倍.

line profiler输出

测试和基准测试:

a = np.random.randint(0, 1e8, 1e6)
assert(np.all(radix_sort(a) == np.sort(a))) 
%timeit np.sort(a)
%timeit radix_sort(a)

mask_b循环可以至少部分地被矢量化,从掩码中广播&cumsumaxisarg一起使用,但是这最终是一种悲观,可能是由于增加的存储器占用.

如果有人能够看到一种方法来改进我所拥有的东西,我会有兴趣听到,即使它仍然比np.sort... 慢......这更像是一种对知识的好奇心和对numpy技巧的兴趣.

请注意,您可以轻松地实现快速计数排序,但这仅与小整数数据相关.

编辑1:np.arange(n)圈外的帮助一点,但不是很exiciting.

编辑2:cumsum实际上是多余的(哎呀!),但这个简单的版本仅具有性能稍微帮助..

def radix_sort(a):
    bit_len = np.max(a).bit_length()
    n = len(a)
    cached_arange = arange(n)
    idx = np.empty(n, dtype=int) # fully overwritten each iteration
    for mask_b in xrange(bit_len):
        is_one = (a & 2**mask_b).astype(bool)
        n_ones = np.sum(is_one)      
        n_zeros = n-n_ones
        idx[~is_one] = cached_arange[:n_zeros]
        idx[is_one] = cached_arange[:n_ones] + n_zeros
        # next three lines just do: a[idx] = a, but correctly
        new_a = np.empty(n, dtype=a.dtype)
        new_a[idx] = a
        a = new_a
    return a

编辑3:如果您在多个步骤中构造idx,则可以一次循环两个或更多个,而不是循环使用单个位.使用2位有点帮助,我没有尝试过更多:

idx[is_zero] = np.arange(n_zeros)
idx[is_one] = np.arange(n_ones)
idx[is_two] = np.arange(n_twos)
idx[is_three] = np.arange(n_threes)

编辑4和5:对于我正在测试的输入,4位似乎是最好的.此外,你可以idx完全摆脱这一步.现在只有5倍,而不是10倍,慢于np.sort(作为gist提供的源代码):

在此输入图像描述

编辑6:这是上面的一个整理版本,但它也有点.80%的时间花在repeatextract- 如果只有一种方式广播extract:( ...

def radix_sort(a, batch_m_bits=3):
    bit_len = np.max(a).bit_length()
    batch_m = 2**batch_m_bits
    mask = 2**batch_m_bits - 1
    val_set = np.arange(batch_m, dtype=a.dtype)[:, nax] # nax = np.newaxis
    for _ in range((bit_len-1)//batch_m_bits + 1): # ceil-division
        a = np.extract((a & mask)[nax, :] == val_set,
                        np.repeat(a[nax, :], batch_m, axis=0))
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a

编辑7和8:实际上,您可以使用as_stridedfrom 来广播提取numpy.lib.stride_tricks,但它似乎没有太大的性能帮助:

在此输入图像描述

最初这对我来说是有意义的,因为它extract会在整个数组batch_m时间内进行迭代,因此CPU请求的高速缓存行总数将与之前相同(只是在它请求每个请求的过程结束时)缓存行batch_m时间).然而,实际情况是,extract不足以巧妙地迭代任意阶梯数组,并且必须在开始之前扩展数组,即无论如何最终都会重复执行.事实上,在查看源代码之后extract,我现在看到我们用这种方法做的最好的事情是:

a = a[np.flatnonzero((a & mask)[nax, :] == val_set) % len(a)]

这比一点慢extract.然而,如果len(a)是两个电源可以代替昂贵的MOD与操作& (len(a) - 1),这并最终被略高于更快的extract版本(目前约4.9x np.sorta=randint(0, 1e8, 2**20).我想我们可以通过零填充使这两个长度的非幂次工作,然后在排序结束时裁剪额外的零...但是这将是一个悲观,除非长度已经接近于二.

推荐阅读
刘美娥94662
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有