The following implementations of evaluate compile correctly:
import numpy as np
import numpy.typing as npt
from numba import njit
from numba.experimental import jitclass
Standalone Function
@njit
def evaluate(x : npt.NDArray[np.float64], m : float, b : float) -> npt.NDArray[np.float64]:
    """Evaluate the line m * x + b element-wise over the array x.

    Compiled in nopython mode by Numba's @njit.
    """
    slope_term = m * x
    return slope_term + b
# 50 evenly spaced samples on [0, 100] (np.linspace's default count).
x = np.linspace(0, 100, num=50)
y = evaluate(x, 2, 3)
@jitclass with Standalone Function
@jitclass
class LineEvaluator:
    # Stateless evaluator: no members are declared, so every value the
    # computation needs is passed straight to evaluate().

    def __init__(self):
        """Construct an evaluator; there is no state to initialize."""

    def evaluate(self, x : npt.NDArray[np.float64], m : float, b : float) -> npt.NDArray[np.float64]:
        """Return m * x + b computed element-wise over x."""
        slope_term = m * x
        return slope_term + b
x = np.linspace(0, 100, num=50)  # 50 points: np.linspace's default sample count
y = LineEvaluator().evaluate(x, 2, 3)
However, the following implementation fails to compile with the error shown below:
@jitclass with Arguments in Constructor
# NOTE: this variant fails to compile (error shown below). No spec= is given
# and the class has no annotations, so the jitclass is created with no
# members (note the empty "<>" after the class name in the error). Numba
# therefore cannot type the assignment self.x = x in __init__
# ("Cannot resolve setattr").
@jitclass
class LineEvaluator:
def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
self.x = x
self.m = m
self.b = b
def evaluate(self) -> npt.NDArray[np.float64]:
return self.m * self.x + self.b
x = np.linspace(0, 100, num=50)  # 50 samples on [0, 100] (np.linspace's default count)
y = LineEvaluator(x, 2, 3).evaluate()
Failed in nopython mode pipeline (step: nopython frontend)
Cannot resolve setattr: (instance.jitclass.LineEvaluator#117caf610<>).x = array(float64, 1d, C)
File "test.py", line 9:
def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
self.x = x
^
During: typing of set attribute 'x' at /private/tmp/test.py (9)
File "test.py", line 9:
def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
self.x = x
^
During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)
During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)
File "<string>", line 3:
<source missing, REPL/exec in use?>
@jitclass with Member Type Annotations
# NOTE: this variant fails while the class is being decorated (traceback
# below). @jitclass turns each class annotation into a spec entry via
# as_numba_type, which cannot infer a Numba type for the typing alias
# npt.NDArray[np.float64].
@jitclass
class LineEvaluator:
x : npt.NDArray[np.float64]
m : float
b : float
def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
self.x = x
self.m = m
self.b = b
def evaluate(self) -> npt.NDArray[np.float64]:
return self.m * self.x + self.b
x = np.linspace(0, 100, num=50)  # 50 samples on [0, 100] (np.linspace's default count)
y = LineEvaluator(x, 2, 3).evaluate()
Traceback (most recent call last):
File "/private/tmp/test.py", line 7, in <module>
class LineEvaluator:
File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 88, in jitclass
return wrap(cls_or_spec)
File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 77, in wrap
cls_jitted = register_class_type(cls, spec, types.ClassType,
File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/base.py", line 180, in register_class_type
spec[attr] = as_numba_type(py_type)
File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 121, in __call__
return self.infer(py_type)
File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 115, in infer
raise errors.TypingError(
numba.core.errors.TypingError: Cannot infer numba type of python type numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]
The error messages are pretty opaque. Why does compilation fail in this specific case?
Thank you!
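A jitclass must know the Numba type of every member before it can compile its methods. Without spec= (or annotations Numba can translate), LineEvaluator has no declared members, so self.x = x cannot be typed ("Cannot resolve setattr"). Annotating x as npt.NDArray[np.float64] does not help either, because Numba cannot infer a Numba type from that typing alias ("Cannot infer numba type of python type numpy.ndarray[...]"). Instead, declare the member types explicitly with spec= in your @jitclass: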
import numba as nb
import numpy as np
from numba.experimental import jitclass
@jitclass(
    spec=[
        ("x", nb.float64[:]),  # 1-D float64 array of samples
        ("m", nb.float64),     # slope (multiplies x)
        ("b", nb.float64),     # intercept (added to m * x)
    ]
)
class LineEvaluator:
    # Member types are declared up front through spec, so the attribute
    # assignments in __init__ can be typed by Numba.

    def __init__(self, x, m, b):
        """Store the samples x, the slope m and the intercept b."""
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self):
        """Return self.m * self.x + self.b, element-wise over self.x."""
        slope_term = self.m * self.x
        return slope_term + self.b
x = np.linspace(0, 100, num=50)  # 50 samples: np.linspace's default count
y = LineEvaluator(x, 2, 3).evaluate()
print(y)
Prints:
[ 3. 7.08163265 11.16326531 15.24489796 19.32653061
23.40816327 27.48979592 31.57142857 35.65306122 39.73469388
43.81632653 47.89795918 51.97959184 56.06122449 60.14285714
64.2244898 68.30612245 72.3877551 76.46938776 80.55102041
84.63265306 88.71428571 92.79591837 96.87755102 100.95918367
105.04081633 109.12244898 113.20408163 117.28571429 121.36734694
125.44897959 129.53061224 133.6122449 137.69387755 141.7755102
145.85714286 149.93877551 154.02040816 158.10204082 162.18367347
166.26530612 170.34693878 174.42857143 178.51020408 182.59183673
186.67346939 190.75510204 194.83673469 198.91836735 203. ]