Tags: python, numpy, jit

Numba fails to compile a `jitclass` whose constructor accepts NumPy array arguments


The following implementations of `evaluate` compile correctly:

import numpy as np
import numpy.typing as npt
from numba import njit
from numba.experimental import jitclass

Standalone Function

@njit
def evaluate(x: npt.NDArray[np.float64], m: float, b: float) -> npt.NDArray[np.float64]:
    """Return the line ``m * x + b`` evaluated element-wise over ``x``."""
    scaled = m * x
    return scaled + b


x = np.linspace(0, 100, num=50)
y = evaluate(x, 2, 3)

@jitclass with Arguments Passed to the Method

@jitclass
class LineEvaluator:
    """Stateless jitclass; the line parameters are passed to ``evaluate``."""

    def __init__(self):
        pass

    def evaluate(self, x: npt.NDArray[np.float64], m: float, b: float) -> npt.NDArray[np.float64]:
        """Return ``m * x + b`` evaluated element-wise over ``x``."""
        scaled = m * x
        return scaled + b


x = np.linspace(0, 100, num=50)
y = LineEvaluator().evaluate(x, 2, 3)

However, the following implementation fails to compile with this error:

@jitclass with Arguments in Constructor

@jitclass
class LineEvaluator:
    # NOTE: this version fails to compile. The class gives no spec and no
    # class-level annotations, so the jitclass declares no attributes and
    # numba cannot resolve `self.x = x` (the "Cannot resolve setattr" error
    # shown below).
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self) -> npt.NDArray[np.float64]:
        return self.m * self.x + self.b

x = np.linspace(0, 100)
y = LineEvaluator(x, 2, 3).evaluate()
Failed in nopython mode pipeline (step: nopython frontend)
Cannot resolve setattr: (instance.jitclass.LineEvaluator#117caf610<>).x = array(float64, 1d, C)

File "test.py", line 9:
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        ^

During: typing of set attribute 'x' at /private/tmp/test.py (9)

File "test.py", line 9:
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        ^

During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

@jitclass with Member Type Annotations (also fails)

@jitclass
class LineEvaluator:
    # NOTE: this version fails as soon as the decorator runs. jitclass turns
    # each class annotation into a numba type with `as_numba_type`, which
    # cannot infer a numba type for the numpy.typing alias
    # `npt.NDArray[np.float64]` (the TypingError in the traceback below).
    x : npt.NDArray[np.float64]
    m : float
    b : float

    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self) -> npt.NDArray[np.float64]:
        return self.m * self.x + self.b

x = np.linspace(0, 100)
y = LineEvaluator(x, 2, 3).evaluate()
Traceback (most recent call last):
  File "/private/tmp/test.py", line 7, in <module>
    class LineEvaluator:
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 88, in jitclass
    return wrap(cls_or_spec)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 77, in wrap
    cls_jitted = register_class_type(cls, spec, types.ClassType,
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/base.py", line 180, in register_class_type
    spec[attr] = as_numba_type(py_type)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 121, in __call__
    return self.infer(py_type)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 115, in infer
    raise errors.TypingError(
numba.core.errors.TypingError: Cannot infer numba type of python type numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]

The error messages are pretty opaque. Why does compilation fail in these specific cases?

Thank you!


Solution

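  • A `jitclass` needs the numba type of every attribute. In the first failing
    version there is no spec and no annotation, so the class has no declared
    attributes and `self.x = x` cannot be resolved ("Cannot resolve setattr").
    In the second version, jitclass converts the annotations with
    `as_numba_type`, which does not understand `numpy.typing` aliases such as
    `npt.NDArray[np.float64]` ("Cannot infer numba type"). Declare the
    attribute types with numba types via `spec=` in your `@jitclass` instead: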

    import numba as nb
    import numpy as np
    from numba.experimental import jitclass
    
    
    @jitclass(spec=[("x", nb.float64[:]), ("m", nb.float64), ("b", nb.float64)])
    class LineEvaluator:
        """Evaluate the line ``m * x + b`` over a stored float64 array ``x``.

        The ``spec`` given to ``@jitclass`` declares the numba type of every
        attribute, which is what allows ``__init__`` to store an array in
        ``self.x``.
        """

        def __init__(self, x, m, b):
            # Attribute types are those declared in the spec above.
            self.x = x
            self.m = m
            self.b = b

        def evaluate(self):
            """Return ``m * x + b`` as a new array."""
            slope_term = self.m * self.x
            return slope_term + self.b


    x = np.linspace(0, 100)
    y = LineEvaluator(x, 2, 3).evaluate()
    print(y)
    

    Prints:

    [  3.           7.08163265  11.16326531  15.24489796  19.32653061
      23.40816327  27.48979592  31.57142857  35.65306122  39.73469388
      43.81632653  47.89795918  51.97959184  56.06122449  60.14285714
      64.2244898   68.30612245  72.3877551   76.46938776  80.55102041
      84.63265306  88.71428571  92.79591837  96.87755102 100.95918367
     105.04081633 109.12244898 113.20408163 117.28571429 121.36734694
     125.44897959 129.53061224 133.6122449  137.69387755 141.7755102
     145.85714286 149.93877551 154.02040816 158.10204082 162.18367347
     166.26530612 170.34693878 174.42857143 178.51020408 182.59183673
     186.67346939 190.75510204 194.83673469 198.91836735 203.        ]