Numba - @jit basic usage

https://numba.readthedocs.io/en/stable/user/5minguide.html

LAZY COMPILATION

The recommended way to use the @jit decorator is to let Numba decide when and how to optimize: compilation is deferred until the first call, when Numba infers the argument types.


from numba import jit

@jit(nopython=True)
def f(x, y):
    return x + y

print(f(1,2))   # return 3
print(f(1.3,2))   # return 3.3
print(f(1j,2))   # return (2+1j)
# Each call with new argument types triggers a fresh compilation,
# so these three calls produce three separate specializations.
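
To confirm this, the compiled dispatcher keeps one entry per specialization in its signatures attribute (the exact types shown depend on the platform):

print(f.signatures)
# e.g. [(int64, int64), (float64, int64), (complex128, int64)]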

EAGER COMPILATION

We can tell Numba which function signature to expect; compilation then happens at decoration time, and no other specialization is allowed.


from numba import jit, int64

@jit(int64(int64, int64), nopython=True)
def f(x, y):
    return x + y

print(f(1,2))   # return 3
print(f(1.3,2))   # return 3 (the float argument is converted to int64)
#print(f(1j,2))   # TypeError (complex cannot be converted to int64)
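
Eager compilation also accepts a list of signatures when several specializations should be compiled up front. A minimal sketch (string signatures are equivalent to the type-object form; list-of-signatures support is assumed as documented for @jit):

from numba import jit

# Both specializations are compiled at decoration time.
@jit(["int64(int64, int64)", "float64(float64, float64)"], nopython=True)
def g(x, y):
    return x + y

print(g(1, 2))       # 3   (int64 path)
print(g(1.5, 2.0))   # 3.5 (float64 path)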

COMPILATION OPTIONS

A number of keyword-only arguments can be passed to the @jit decorator; they can be combined, as the sketch after the list shows.


nopython=True   # force 'nopython mode' (fully compiled, no Python C API)
forceobj=True   # force 'object mode' (values handled as Python objects)
fastmath=True   # relax some numerical rigour (IEEE 754 strictness) for speed
parallel=True   # enable automatic parallelization (no GIL!); only works in 'nopython mode'
nogil=True   # release the GIL while the compiled function runs
cache=True   # write compiled functions to a file-based cache to skip recompilation between runs
debug=True   # include debug info in the compiled code
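
A minimal sketch combining a few of these options: nopython mode with relaxed floating-point semantics and an on-disk cache (the function itself is illustrative):

import numpy as np
from numba import jit

# fastmath trades strict IEEE semantics for speed; cache=True stores
# the compiled code on disk so later runs skip compilation.
@jit(nopython=True, fastmath=True, cache=True)
def norm(v):
    total = 0.0
    for i in range(v.shape[0]):
        total += v[i] * v[i]
    return total ** 0.5

print(norm(np.arange(10, dtype=np.float64)))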

PARALLEL RANGE

https://numba.pydata.org/numba-doc/0.11/prange.html

Numba implements the ability to run loops in parallel. The loop body is scheduled across separate threads, and the iterations execute in a nopython Numba context. prange automatically takes care of data privatization and reductions. Note that prange only parallelizes when the function is compiled with parallel=True; otherwise it behaves like a plain range.


# numba3.py
import time
import numpy as np
from numba import njit, float64, prange

x = np.arange(1000000, dtype=np.float64)

#@njit(parallel=True)   # with parallelization; required for prange to actually run in parallel
#@njit('float64(float64[:])')   # eager compilation
@njit   # lazy compilation; prange degrades to a plain range without parallel=True
def parallel_sum(arr):
    result = 0.0
    for i in prange(arr.shape[0]):
        result += arr[i]
    return result

start = time.time()
parallel_sum(x)
end = time.time()
print("Elapsed (with compilation) = {}".format(end - start))

start = time.time()
parallel_sum(x)
end = time.time()
print("Elapsed (after compilation) = {}".format(end - start))

# Results.
# Without numba = 0.11021208763122559
# Elapsed (with compilation) = 0.05432915687561035
# Elapsed (after compilation) = 0.0010497570037841797
# With parallelization (12 cores).
# Elapsed (with compilation) = 0.1597743034362793
# Elapsed (after compilation) = 0.0004372596740722656
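
When compiled with parallel=True, the dispatcher can report what the auto-parallelizer did. A quick check, assuming the decorator above is switched to @njit(parallel=True):

# Prints a report of the loops Numba found, fused, and parallelized.
parallel_sum.parallel_diagnostics(level=1)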