当前位置:  开发笔记 > 编程语言 > 正文

Rcpp函数在给定值的矢量及其频率的情况下找到中值

如何解决《Rcpp函数在给定值的矢量及其频率的情况下找到中值》经验,为你挑选了1个好方法。

我正在编写一个函数来查找一组值的中位数.数据表示为唯一值的向量(称为'值')和它们的频率向量('freqs').频率通常非常高,因此将它们粘贴出来会占用大量内存.我有一个缓慢的R实现,它是我的代码中的主要瓶颈,所以我正在编写一个自定义Rcpp函数用于R/Bioconductor包.Bioconductor的网站建议不要使用C++ 11,这对我来说是一个问题.

我的问题在于尝试根据值的顺序将两个向量排序在一起.在R中,我们可以使用order()函数.尽管遵循了关于这个问题的建议:C++排序和跟踪索引,我似乎无法使其工作

以下几行是问题所在:

   // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

这是完整的功能,为了任何人的利益.任何进一步的提示将不胜感激:

    #include 
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
    int len = freqs.size();
    if (any(freqs!=0)){
        int med = 0;
        return med;
    }
    // filter out the zeros pre-sorting
    IntegerVector non_zeros;
    for (int i = 0; i < len; i++){
        if(freqs[i] != 0){
            non_zeros.push_back(i);
        }
    }
    freqs = freqs[non_zeros];
    values = values[non_zeros];
    // find the order of values
    // create integer vector of indices
    IntegerVector idx(len);
    for (int i = 0; i < len; ++i) idx[i] = i;

    // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

    //apply to freqs and values
    freqs = freqs[idx_ord];
    values=values[idx_ord];
    IntegerVector cum_freqs(len);
    cum_freqs[0] = freqs[0];
    for (int i = 1; i < len; ++i) cum_freqs[i] = freqs[i] + cum_freqs[i-1];
    int total_freqs = cum_freqs[len-1];
    // split into odd and even frequencies and calculate the median
    if (total_freqs % 2 == 1) {
        int med_ind = (total_freqs + 1)/2 - 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind){
            i++;
        }
        double ret = values[i];
        return ret;
    } else {
        int med_ind_1 = total_freqs/2 - 1; // C++ indexes from 0
        int med_ind_2 = med_ind_1 + 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_1){
            i++;
        }
        double ret_1 = values[i];
        i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_2){
            i++;
        }
        double ret_2 = values[i];
        double ret = (ret_1 + ret_2)/2;
        return ret;
    }
}

对于使用RUnit测试框架的任何人,这里有一些基本的单元测试:

test_median_freq <- function(){
    checkEquals(median_freq(1:10,1:10),7)
    checkEquals(median_freq(1:10,rep(1,10)),5.5)
    checkEquals(median_freq(2:6,c(1,2,1,45,2)),5)
}

谢谢!



1> josliber..:

我实际上将值和频率组合成a std::pair然后只是用它们排序std::sort; 通过这种方式,您始终可以将值和频率保持在一起.这使您可以编写更清晰的代码,因为没有额外的索引集浮动:

#include 
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  std::vector > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    allDat.push_back(std::pair(values[i], freqs[i]));
    freqSum += freqs[i];
  }
  std::sort(allDat.begin(), allDat.end());
  int accum = 0;
  for (int i=0; i < len; ++i) {
    accum += allDat[i].second;
    if (freqSum % 2 == 0) {
      if (accum > freqSum / 2) {
        return allDat[i].first;
      } else if (accum == freqSum / 2) {
        return (allDat[i].first + allDat[i+1].first) / 2;
      }
    } else {
      if (accum >= (freqSum+1)/2) {
        return allDat[i].first;
      }
    }
  }
  return NA_REAL;  // Should not be reached
}

在R中尝试一下:

median_freq(1:10, 1:10)
# [1] 7
median_freq(1:10,rep(1,10))
# [1] 5.5
median_freq(2:6,c(1,2,1,45,2))
# [1] 5

我们还可以编写一个简单的R实现来确定我们使用Rcpp获得的效率增益:

med.freq.r <- function(values, freqs) {
  ord <- order(values)
  values <- values[ord]
  freqs <- freqs[ord]
  s <- sum(freqs)
  cs <- cumsum(freqs)
  idx <- min(which(cs >= s/2))
  if (s %% 2 == 0 && cs[idx] == s/2) {
    (values[idx] + values[idx+1]) / 2
  } else {
    values[idx]
  }
}
med.freq.r(1:10, 1:10)
# [1] 7
med.freq.r(1:10,rep(1,10))
# [1] 5.5
med.freq.r(2:6,c(1,2,1,45,2))
# [1] 5

要进行基准测试,让我们看一组非常大的值:

set.seed(144)
values <- rnorm(1000000)
freqs <- sample(1:100, 1000000, replace=TRUE)
all.equal(median_freq(values, freqs), med.freq.r(values, freqs))
# [1] TRUE
library(microbenchmark)
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs))
# Unit: milliseconds
#                        expr      min       lq     mean   median       uq      max neval
#  median_freq(values, freqs) 128.5322 131.6095 146.8360 145.6389 159.6117 165.0306    10
#   med.freq.r(values, freqs) 715.2187 744.5709 776.0539 765.9178 817.7157 855.1898    10

对于100万个条目,Rcpp解决方案比R解决方案快约5倍; 考虑到编译开销,如果你正在处理非常大的向量或者这是一个经常重复的选项,那么这个性能才有吸引力.

线性时间方法

一般来说,我们知道如何在不排序的情况下计算中位数(详见http://www.cc.gatech.edu/~mihail/medianCMU.pdf).虽然算法比排序和迭代更精细,但它可以产生显着的加速:

double fast_median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  std::vector > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    allDat.push_back(std::pair(values[i], freqs[i]));
    freqSum += freqs[i];
  }

  int target = freqSum / 2;
  int low = 0;
  int high = len-1;
  while (true) {
    // Random pivot; move to the end
    int rnd = low + (rand() % (high-low+1));
    std::swap(allDat[rnd], allDat[high]);

    // In-place pivot
    int highPos = low;  // Start of values higher than pivot
    int lowSum = 0;  // Sum of frequencies of elements below pivot
    for (int pos=low; pos < high; ++pos) {
      if (allDat[pos].first <= allDat[high].first) {
        lowSum += allDat[pos].second;
        std::swap(allDat[highPos], allDat[pos]);
        ++highPos;
      }
    }
    std::swap(allDat[highPos], allDat[high]);  // Move pivot to "highPos"

    // If we found the element then return; o/w recurse on proper side
    if (lowSum >= target) {
      // Recurse on lower elements
      high = highPos - 1;
    } else if (lowSum + allDat[highPos].second >= target) {
      // Return
      if (target < lowSum + allDat[highPos].second || freqSum % 2 == 1) {
        return allDat[highPos].first;
      } else {
        double nextHighest = std::min_element(allDat.begin() + highPos+1, allDat.begin() + len-1)->first;
        return (allDat[highPos].first + nextHighest) / 2;
      }
    } else {
      // Recurse on higher elements
      low = highPos + 1;
      target -= (lowSum + allDat[highPos].second);
    }
  }
}

标杆:

all.equal(median_freq(values, freqs), fast_median_freq(values, freqs))
[1] TRUE
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs), fast_median_freq(values, freqs), times=10)
# Unit: milliseconds
#                             expr       min        lq      mean    median        uq       max neval
#       median_freq(values, freqs) 119.57989 122.48622 130.47841 130.48811 132.75421 146.36136    10
#        med.freq.r(values, freqs) 665.72803 690.15016 708.05729 702.65885 731.83936 749.36834    10
#  fast_median_freq(values, freqs)  24.37572  29.39641  31.86144  31.77459  34.88418  36.81606    10

线性方法提供了比排序迭代Rcpp解决方案快4倍的速度和比基本R解决方案高20倍的速度.

推荐阅读
路人甲
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有