为什么用Python的numba库速度不升反降?

看了这篇文章 https://zhuanlan.zhihu.com/p/24168485 试了一下里面的 ma_numba 函数

import time

@numba.jit

def ma_numba(data, ma_length):

ma = []
data_window = data[:ma_length]
test_data = data[ma_length:]

for new_tick in test_data: data_window.pop(0) data_window.append(new_tick) sum_tick = 0 for tick in data_window: sum_tick += tick ma.append(sum_tick/ma_length)

a = np.arange(10000) t1 = time.time() b = list(a) bb = ma_numba(b, 5) t2 = time.time() print(t2 - t1)

不用 numba,大概耗时 0.03-0.04 秒,用了 numba,耗时 0.7-0.8 秒…奇了怪了,难道是我的姿势不对?


为什么用Python的numba库速度不升反降?

16 回复

我遇到numba速度变慢的情况,通常有这几个原因:

  1. 首次运行开销:numba第一次执行时会编译函数,这次运行会包含编译时间。用@jit(nopython=True, cache=True)可以缓存编译结果。

  2. 回退到object模式:如果numba无法编译为机器码,会回退到慢速的Python模式。确保函数参数和变量都有明确的类型。

  3. 频繁调用小函数:对于简单的操作,numba的调用开销可能超过加速收益。这种情况更适合用numpy的向量化操作。

看个例子,这个函数用numba反而更慢:

import numba
import numpy as np

@numba.jit
def slow_func(arr):
    result = 0
    for i in range(len(arr)):
        result += arr[i] * 2  # 太简单的操作
    return result

# 改成这样会好很多
@numba.jit(nopython=True)
def better_func(arr):
    return arr * 2  # 直接用numpy向量化

检查你的代码:用@jit(nopython=True)强制使用nopython模式,如果报错说明有类型问题;对于简单计算直接用numpy;确保在循环外调用numba函数。

numba不是万能药,得用在合适的地方。


第一,a = np.arange(10000) 这一句是排除在耗时计算之外的.
第二,b = list(a) 这一句是都被计入耗时之内的.所以对比是不存在这个问题的

对比时就是简单地把 .jit 这一句注释掉和不注释掉

你这代码本来就不科学啊。data_window.pop 你这是想干嘛啊?还有 sum_tick 有你这种写法嘛?好好的 O(n) 算法你给写成 O(n*k) ?

In [1]: import numpy as np

In [2]: import numba

In [3]: def moving_average(data, k):
…: partial_sum = sum(data[:k])
…: ret = [partial_sum / k]
…: for old_d, new_d in zip(data[:-k], data[k:]):
…: partial_sum = partial_sum - old_d + new_d
…: ret.append(partial_sum / k)
…: return ret
…:

In [4]: numba_moving_average = numba.jit(moving_average)

In [5]: arr = np.arange(10000)


In [6]: arr_list = list(arr)

In [7]: %timeit moving_average(arr_liset)

In [8]: %timeit moving_average(arr_list, 5)
3.8 ms ± 9.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]: %timeit numba_moving_average(arr_list, 5)
722 µs ± 35.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

不对啊,我这结果还是 numba 耗时长啊

t1 = time.time()
b=numba_moving_average(a,5)
t2 = time.time()
c=moving_average(a,5)
t3 = time.time()
print(t2-t1)
print(t3 - t2)

结果:

0.7720441818237305
0.008000373840332031

我也发现了这个情况。numba 和 numpy、cython 混用时耗时不降反升。推测是多种格式数据通过解释器互转效率低下。



In [8]: arr_list = list(np.arange(100000))

In [10]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

In [11]: (t2 - t1, t3 - t2)
Out[11]: (0.0019309520721435547, 0.23806500434875488)

In [12]: t1 = time.time(); moving_average(arr_list, 5); t2 = time.time(); numba_moving_average(arr_list, 5); t3 = time.time()

In [13]: (t2 - t1, t3 - t2)
Out[13]: (0.0016407966613769531, 0.005582094192504883)

In [14]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(100)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(100)]; t3 = time.time()

In [15]: (t2 - t1, t3 - t2)
Out[15]: (0.18658995628356934, 0.12822914123535156)

In [16]: t1 = time.time(); [moving_average(arr_list, 5) for i in range(1000)]; t2 = time.time(); [numba_moving_average(arr_list, 5) for i in range(1000)]; t3 = time.time()

In [17]: (t2 - t1, t3 - t2)
Out[17]: (1.3983790874481201, 1.3098900318145752)

你这个结果也不乐观.看来还是混用不行. 后面再去折腾一下 cython 看看

我觉得 7# 说很清楚了吧,一般没有用 time.time() - start 来测试的,除非你程序大概跑在分钟级,data 大个一百万倍再说吧,timeit 是比较合适的测时间的工具。

还有,我想吐槽这个专栏,4# (同一人哎)说得更清楚,这个专栏是来逗比的么……写个移动平均当例子把 o(n) 弄成 o(n*k),这蛋疼的 pop(0)

更吐槽的是,还说第一反应上 NumPy,还 numpy_right ……
为啥不用 np.convolve(data, np.ones(500)/500,mode=‘valid’) 试试?
20.6 ms ± 93.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 是渣渣 i7-3687u 的结果,一样的 data size (100000), 他的 cython 版本说单次时间最快也就 0.0098s 也就是 9.8ms ,这货认真的? numpy 对他来说用法仅限于 a.mean() 和方便的索引了是么……

仔细看一下连 cython 里都还有 pop(0)……这个大哥仗着自己是 i7-6700k 就日了天了么……

你的代码跑 10000 遍使用了 numba.jit 是 6.175 秒,不使用 numba.jit 是 72.809 秒。numba 的 jit 技术还是起到作用了。

大佬都是从哪里知道这些偏僻的 numpy 函数的,系统地看文档吗?

顺便再扯一嘴,用 convole 还是一般带窗口的,像这个方窗的情况
<br>def maa(data, n):<br> ret = np.cumsum(data)<br> ret[n:] = ret[n:] - ret[:-n]<br> return ret[n-1:]/n<br>
渣渣本上也只要 990 微秒。这些不好好看 NumPy 的同学弄得好像 python 咋折腾都很低效……

我不是大佬……而且这个问题不能算生僻吧,移动平均,尤其是带有窗口函数的移动平均,遇到得应该还是很多的。我其实看到“移动平均”第一反应是“这其实是个卷积的问题”,当然这么想问题也会复杂化,卷积是 o(n*k),当然一些大窗口体系还能用更快的 fftconvole ……扯远了,知道 NumPy 里都有啥好玩儿的需要一定的数学基础吧,我感觉,把遇到的问题能比较“数学地”进行描述,NumPy/SciPy 总会有惊喜。一般来说都是 Google 一下 问题+scipy 就会看到好玩儿的函数在下面贴着。

不知道在哪看的了, 说是 jit 启动需要花费一点时间. 可能你这段代码的计算规模还是低了点~ 试试把规模再翻几十倍看看如何.

算了,运算量整太大了,脱离实际需求也没有意义了.反正优化方案里面已经 pass 掉 numba 了

回到顶部