numba
npx skills add https://github.com/tondevrel/scientific-agent-skills --skill numba
Numba – High-Performance Python with JIT
Numba is a just-in-time (JIT) compiler that translates numerically focused Python and NumPy code into fast machine code. You mark functions for compilation with decorators such as @njit, and Numba compiles them on first call. It is particularly effective for code that involves heavy numerical loops and NumPy array manipulations.
When to Use
- When NumPy’s built-in vectorization isn’t enough for your specific algorithm.
- When you have complex nested loops that are slow in standard Python.
- When you need to write custom “ufuncs” (universal functions) that operate element-wise on arrays.
- High-performance physical simulations (Monte Carlo, N-body, Grid-based solvers).
- Accelerating code for execution on NVIDIA GPUs (CUDA).
- Creating parallelized code that utilizes all CPU cores without the overhead of multiprocessing.
Reference Documentation
Official docs: https://numba.pydata.org/numba-doc/latest/index.html
User Guide: https://numba.pydata.org/numba-doc/latest/user/index.html
Search patterns: @njit, @vectorize, prange, cuda.jit, numba.typed
Core Principles
nopython Mode (@njit)
This is the “gold standard” for Numba. In this mode, Numba compiles the code without using the Python C-API, resulting in maximum speed. If it can’t compile (e.g., because of unsupported Python objects), it throws an error.
Just-In-Time (JIT) Compilation
Compilation happens the first time you call the function. The machine code is then cached for subsequent calls.
Array-Oriented
Numba is designed to work with NumPy arrays. It understands their memory layout and can generate highly optimized loops over them.
Quick Reference
Installation
pip install numba
Standard Imports
import numpy as np
from numba import njit, prange, vectorize, guvectorize, cuda
Basic Pattern – Accelerating a Loop
import numpy as np
from numba import njit
# 1. Apply the @njit decorator (alias for @jit(nopython=True))
@njit
def sum_array(arr):
    res = 0.0
    # A standard Python loop that would normally be slow now runs at C speed
    for i in range(arr.shape[0]):
        res += arr[i]
    return res

# 2. Execute
data = np.random.random(1_000_000)
result = sum_array(data)  # First call compiles, then runs
Critical Rules
✅ DO
- Prefer @njit – Always use nopython=True (or its alias @njit). It ensures your code is actually running at machine speed.
- Use NumPy Arrays – Numba is optimized for NumPy. Avoid standard Python lists inside jitted functions.
- Enable Parallelism – Use @njit(parallel=True) and prange instead of range for automatic multi-threading.
- Cache Compiled Code – Use @njit(cache=True) to avoid recompilation every time you restart your script.
- Warm up – Remember that the first call is slow due to compilation. In timing benchmarks, always run the function once before measuring.
- Type Specifying (Optional) – You can provide signatures (e.g., (float64[:],)) to speed up the very first call, but Numba usually infers them well.
❌ DON’T
- Don’t use Python Objects – Strings, dictionaries, and custom classes are slow or unsupported in nopython mode. Use numba.typed for specialized containers if needed.
- Don’t JIT small functions – The overhead of calling a jitted function from Python can outweigh the gains for trivial operations.
- Don’t use unsupported libraries – You cannot use pandas, matplotlib, or requests inside an @njit function.
- Don’t modify global state – Jitted functions should be “pure” as much as possible for stability.
Anti-Patterns (NEVER)
from numba import njit
import pandas as pd
# ❌ BAD: Using pandas inside @njit (unsupported)
@njit
def bad_func(df):
    return df['col'].sum()  # Raises a TypingError when compilation fails

# ✅ GOOD: Pass NumPy arrays instead
@njit
def good_func(arr):
    return arr.sum()

# ❌ BAD: Using @jit without nopython=True
from numba import jit
@jit
def slow_func(x):  # In older Numba versions this could silently fall back to "object mode" (slow)
    return x + 1

# ✅ GOOD: Always ensure nopython mode
@njit
def fast_func(x):
    return x + 1

# ❌ BAD: A Python-level loop calling a JIT function element by element
# for i in range(1000):
#     process_element(arr[i])  # Pays the JIT call overhead 1000 times

# ✅ GOOD: Move the loop INSIDE the @njit function
@njit
def process_all(arr):
    for i in range(arr.shape[0]):
        process_element(arr[i])
Parallelism and Vectorization
Automatic Multi-threading
from numba import njit, prange

@njit(parallel=True)
def parallel_sum(A):
    # Use prange for the loop that should be parallelized;
    # Numba recognizes s += ... as a reduction
    s = 0.0
    for i in prange(A.shape[0]):
        s += A[i]
    return s
Creating Fast ufuncs (@vectorize)
from numba import vectorize

# This creates a NumPy ufunc that supports broadcasting
@vectorize(['float64(float64, float64)'], target='parallel')
def fast_add(x, y):
    return x + y

# Now you can use it on massive arrays
res = fast_add(arr1, arr2)
Working with Structs and Types
numba.typed for Non-Array Data
from numba import njit
from numba.typed import List, Dict

@njit
def use_typed_list():
    l = List()
    l.append(1.0)
    return l
GPU Acceleration (numba.cuda)
Writing CUDA Kernels
import numpy as np
from numba import cuda

@cuda.jit
def my_kernel(io_array):
    # Compute this thread's absolute position in the 1-D grid
    pos = cuda.grid(1)
    if pos < io_array.size:
        io_array[pos] *= 2

# Usage
data = np.ones(256)
threadsperblock = 32
blockspergrid = (data.size + (threadsperblock - 1)) // threadsperblock
my_kernel[blockspergrid, threadsperblock](data)
Practical Workflows
1. Fast Monte Carlo Simulation
import random
from numba import njit, prange

@njit(parallel=True)
def monte_carlo_pi(nsamples):
    acc = 0
    for i in prange(nsamples):
        x = random.random()
        y = random.random()
        if (x**2 + y**2) < 1.0:
            acc += 1
    return 4.0 * acc / nsamples
2. Custom Image Filter (Stencil)
import numpy as np
from numba import njit

@njit
def apply_threshold(image, threshold):
    M, N = image.shape
    result = np.zeros_like(image)
    for i in range(M):
        for j in range(N):
            if image[i, j] > threshold:
                result[i, j] = 255
    return result
3. Solving a Physics Grid (Laplace Equation)
@njit
def solve_laplace(u, niters):
    # In-place Gauss-Seidel-style sweeps; boundary rows and columns stay fixed
    M, N = u.shape
    for n in range(niters):
        for i in range(1, M-1):
            for j in range(1, N-1):
                u[i, j] = 0.25 * (u[i+1, j] + u[i-1, j] + u[i, j+1] + u[i, j-1])
    return u
Performance Optimization
The inspect_types() method
Use this to see if Numba had to fall back to expensive Python objects or if it managed to optimize everything to native types.
fast_func.inspect_types()  # Prints the source annotated with the Numba types inferred for each line
Avoid Array Allocation in Loops
Pre-allocate arrays outside the @njit function or pass them as arguments to avoid memory management overhead.
# ✅ GOOD:
@njit
def compute_into(out_arr, in_arr):
    for i in range(in_arr.shape[0]):
        out_arr[i] = in_arr[i] * 2
Common Pitfalls and Solutions
The “Global Variable” problem
Numba captures the value of global variables at the time of compilation.
# ❌ Problem: Changing a global variable won't affect the jitted function
K = 10

@njit
def f(x):
    return x + K

K = 20
f(1)  # Result is still 11!

# ✅ Solution: Pass constants as arguments
Object Mode Fallback
If Numba says “Object mode is enabled”, your code will be slow.
# ✅ Solution: Force nopython mode
@njit  # If this throws an error, fix the code instead of removing @njit
Random Seed in Parallel
With parallel=True, Numba gives each thread its own random state, so calling random.random() or np.random.random() inside a prange loop is safe. Reproducible seeding across threads takes extra care, however: np.random.seed() called inside a jitted function only affects the state of the thread that executes it.
Best Practices
- Always use @njit – Never use @jit without nopython=True
- Pre-allocate arrays – Avoid creating arrays inside hot loops
- Use prange for parallelism – Enable automatic multi-threading with parallel=True and prange
- Cache compiled functions – Use cache=True to avoid recompilation
- Warm up functions – Call jitted functions once before benchmarking
- Pass NumPy arrays – Convert Python lists to NumPy arrays before calling jitted functions
- Avoid Python objects – Use numba.typed.List and numba.typed.Dict if you need containers
- Check compilation mode – Use inspect_types() to verify nopython mode
- Handle first-call overhead – Remember the first call compiles the function
- Use appropriate signatures – Optional but can speed up first compilation
Numba is the bridge that allows Python to compete with C++ and Fortran in the high-performance computing arena. It removes the “Python tax” from your loops, enabling rapid prototyping without sacrificing execution speed.