目前在numba中处理高阶函数的最佳方法是什么?
我实现了割线方法:
def secant_method_curried (f): def inner (x_minus1, x_0, consecutive_tolerance): x_new = x_0 x_old = x_minus1 x_oldest = None while abs(x_new - x_old) > consecutive_tolerance: x_oldest = x_old x_old = x_new x_new = x_old - f(x_old)*((x_old-x_oldest)/(f(x_old)-f(x_oldest))) return x_new return numba.jit(nopython=False)(inner)
问题是,有没有办法告诉numba那f
是doube(double)
,所以上面的代码与突破nopython=True
:
TypingError: Failed at nopython (nopython frontend) Untyped global name 'f'
看起来在以前的版本中有一个FunctionType,但被删除/重命名:http://numba.pydata.org/numba-doc/0.8/types.html#functions
在这个页面上,他们提到了一个名为numba.addressof()的东西,这看起来很有帮助,但又可以追溯到4年前.
经过一些实验,我可以重现你的错误.在这种情况下,jit
传递给你的函数就足够了secant_method_curried
:
>>> from numba import njit >>> def func(x): # an example function ... return x >>> p = secant_method_curried(njit(func)) # jitted the function >>> p(1,2,3) 2.0
您也可以在传递njit(func)
或时声明签名jit(func)
.
在文档中还有一个关于numba闭包的好例子,并且还提到:
[...]如果从另一个jitted函数调用它,你应该JIT编译该函数.