我正在修改Brieman的随机森林程序(我不知道C/C++),所以我在R中从头开始编写自己的RF变体.我的程序和标准程序之间的区别基本上只是在如何计算终端节点中的分割点和值 - 一旦我在森林中有一棵树,它可以被认为与典型的RF算法中的树非常相似.
我的问题是它的预测速度很慢,而且我很难想办法让它更快.
测试树对象链接在这里,和一些测试数据被链接在这里.您可以直接下载,也可以在下面加载它(如果已repmis
安装).他们被称为testtree
和sampx
.
library(repmis) testtree <- source_DropboxData(file = "testtree", key = "sfbmojc394cnae8") sampx <- source_DropboxData(file = "sampx", key = "r9imf317hpflpsx")
编辑:不知怎的,我还没有真正学习如何使用github.我上传所需的文件到存储库在这里 -我无法弄清楚如何在此刻得到一个永久的歉意......
它看起来像这样(使用我编写的绘图函数):
这里有一些关于对象的结构:
1> summary(testtree) Length Class Mode nodes 7 -none- list minsplit 1 -none- numeric X 29 data.frame list y 6719 -none- numeric weights 6719 -none- numeric oob 2158 -none- numeric 1> summary(testtree$nodes) Length Class Mode [1,] 4 -none- list [2,] 8 -none- list [3,] 8 -none- list [4,] 7 -none- list [5,] 7 -none- list [6,] 7 -none- list [7,] 7 -none- list 1> summary(testtree$nodes[[1]]) Length Class Mode y 6719 -none- numeric output 1 -none- numeric Terminal 1 -none- logical children 2 -none- numeric 1> testtree$nodes[[1]][2:4] $output [1] 40.66925 $Terminal [1] FALSE $children [1] 2 3 1> summary(testtree$nodes[[2]]) Length Class Mode y 2182 -none- numeric parent 1 -none- numeric splitvar 1 -none- character splitpoint 1 -none- numeric handedness 1 -none- character children 2 -none- numeric output 1 -none- numeric Terminal 1 -none- logical 1> testtree$nodes[[2]][2:8] $parent [1] 1 $splitvar [1] "bizrev_allHH" $splitpoint 25% 788.875 $handedness [1] "Left" $children [1] 4 5 $output [1] 287.0085 $Terminal [1] FALSE
output
是该节点的返回值 - 我希望其他一切都是不言自明的.
我写的预测函数有效,但速度太慢了.基本上它"走下树",通过观察观察:
predict.NT = function(tree.obj, newdata=NULL){ if (is.null(newdata)){X = tree.obj$X} else {X = newdata} tree = tree.obj$nodes if (length(tree)==1){#Return the mean for a stump return(rep(tree[[1]]$output,length(X))) } pred = apply(X = newdata, 1, godowntree, nn=1, tree=tree) return(pred) } godowntree = function(x, tree, nn = 1){ while (tree[[nn]]$Terminal == FALSE){ fb = tree[[nn]]$children[1] sv = tree[[fb]]$splitvar sp = tree[[fb]]$splitpoint if (class(sp)=='factor'){ if (as.character(x[names(x) == sv]) == sp){ nn<-fb } else{ nn<-fb+1 } } else { if (as.character(x[names(x) == sv]) < sp){ nn<-fb } else{ nn<-fb+1 } } } return(tree[[nn]]$output) }
问题是它真的很慢(当你考虑非样本树更大,我需要做很多次)时,即使对于一个简单的树:
library(microbenchmark) microbenchmark(predict.NT(testtree,sampx)) Unit: milliseconds expr min lq mean median uq predict.NT(testtree, sampx) 16.19845 16.36351 17.37022 16.54396 17.07274 max neval 40.4691 100
我从今天的某个人那里得到了一个想法,我可以编写一个函数工厂类型的函数(即:生成闭包,我正在学习的那个)将我的树分解成一堆嵌套的if/else语句.然后我可以通过它发送数据,这可能比一遍又一遍地从树中提取数据更快.我还没有编写函数函数生成函数,但我亲自编写了我从中得到的那种输出,并测试了:
predictif = function(x){ if (x[names(x) == 'bizrev_allHH'] < 788.875){ if (x[names(x) == 'male_head'] <.872){ return(548) } else { return(165) } } else { if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){ return(-283) }else{ return(-11.4) } } } predictif.NT = function(tree.obj, newdata=NULL){ if (is.null(newdata)){X = tree.obj$X} else {X = newdata} tree = tree.obj$nodes if (length(tree)==1){#Return the mean for a stump return(rep(tree[[1]]$output,length(X))) } pred = apply(X = newdata, 1, predictif) return(pred) } microbenchmark(predictif.NT(testtree,sampx)) Unit: milliseconds expr min lq mean median uq predictif.CT(testtree, sampx) 12.77701 12.97551 14.21417 13.18939 13.67667 max neval 30.48373 100
快一点,但不多!
我真的很感激任何提高速度的想法!或者,如果答案是"如果不将其转换为C/C++,那么你真的无法获得更快的速度",这也是有价值的信息(特别是如果你给我一些关于为什么会这样的信息).
虽然我当然很欣赏R中的答案,但伪代码的答案也会非常有用.
谢谢!
加速你的功能的秘诀是矢量化.不是单独对每一行执行所有操作,而是一次在所有行上执行它们.
让我们重新考虑你的predictif
功能
predictif = function(x){ if (x[names(x) == 'bizrev_allHH'] < 788.875){ if (x[names(x) == 'male_head'] <.872){ return(548) } else { return(165) } } else { if (x[names(x) == 'nondurable_exp_mo'] < 4190.965){ return(-283) }else{ return(-11.4) } } }
这是一种缓慢的方法,因为它在每个单独的实例上应用所有这些操作.函数调用,if语句,尤其是像names(x) == 'bizrev_allHH'
all这样的操作都会产生一些开销,当你为每个实例执行操作时会增加这些开销.
相比之下,简单地比较两个数字非常快!所以,编写上面的矢量化版本.
predictif_fast <- function(newdata) { n1 <- newdata$bizrev_allHH < 788.875 n2 <- newdata$male_head < .872 n3 <- newdata$nondurable_exp_mo < 4190.965 ifelse(n1, ifelse(n2, 548.55893, 165.15537), ifelse(n3, -283.35145, -11.40185)) }
注意,这非常重要,这个函数没有被传递给一个实例.它意味着通过你的整个新数据.这是有效的,因为<
和ifelse
操作都是矢量化的:当给定矢量时,它们返回一个矢量.
让我们比较你的功能和这个新功能:
> microbenchmark(predictif.NT(testtree, sampx), predictif_fast(sampx)) Unit: microseconds expr min lq mean median uq predictif.NT(testtree, sampx) 12106.419 13144.2390 14684.46 13719.406 14593.1565 predictif_fast(sampx) 189.093 213.6505 263.74 246.192 260.7895 max neval cld 79136.335 100 b 2344.059 100 a
请注意,我们通过矢量化获得了50倍的加速.
顺便说一下,有可能加快这一速度(ifelse
如果你通过索引获得聪明的话,有更快的替代方案),但是从"在每一行上执行一个函数"到"在整个向量上执行操作"的整体切换可以获得最大的加速.
这并不能完全解决您的问题,因为您需要在常规树上执行这些矢量化操作,而不仅仅是在这个特定的树上.我不会为您解决一般版本,但考虑到您可以重写您的godowntree
函数,以便它占用整个数据框并在完整的数据框上执行其操作,而不仅仅是一个.然后,不要使用if
分支,而是保留每个实例当前所处的子项的向量.