0%

Cython的优化指南

很有用的资源
官方: https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
B站: https://www.bilibili.com/video/BV1NF411i74S/

基本使用

Cython可以在python中实现近似于C++的速度,基本方法是用类似于python的语法写.pyx 文件,然后再编译成python可以直接调用的library

编译需要一个setup.py

1
2
3
4
5
6
7
8
9
10
11
# setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
from Cython.Build import cythonize

setup(
ext_modules = cythonize("cython_module.pyx", annotate=True)
)

然后敲入入下代码就可以编译了

1
python setup.py build_ext --inplace

Cython还提供了很方便的visualization,可以看implementation里面有哪些是C++ 哪些是python

1
cython -a target.pyx

例子:

其中黄色越深,代表调用的Python API 越多,which means the efficiency is lower

怎么优化

最常用的一般是下面几个要点

  • 在function前加上@cython.wraparound(False) @cython.boundscheck(False) @cython.cdivision(True),避免python里面耗时间的boundscheck and so on (当然这样的话bound error就不会提示了,所以建议最后没bug了再加
  • 标注function input和return 的type
  • 标注所有要用到的variable 的type(包括numpy array 和 index,例如i j等) (就和c一样)
  • 不要直接用numpy的api,要用就自己写

效果如下,可以看到最后loop全部都白啦!

其他小技巧 OpenMP 多线程加速

.pyx 中,加入如下prefix

1
2
# distutils: extra_compile_args=-fopenmp
# distutils: extra_link_args=-fopenmp

然后在想要并行计算的for loop里用prange, 就可以啦:

1
2
3
4
5
6
7
# We use prange here.
for x in prange(x_max, nogil=True):
for y in range(y_max):

tmp = clip(array_1[x, y], 2, 10)
tmp = tmp * a + array_2[x, y] * b
result_view[x, y] = tmp + c

总结

如果有些实现起来很简单,但是python因为type determine, memory allocation, bound check等 各种原因导致效率很慢,可以试着用cython来加速。效果很显著,最主要的是很容易就可以integrate into python, which is a must-have for DL work.