
Specify np.ndarray and torch.tensor as two options for a dtype


I would like to specify two options for a datatype for a variable. Something like this:

import numpy as np
import torch

class MyDataClass:

    MyData: np.ndarray | torch.float

    def __init__(self, as_torch):
        self.MyData = np.arange(1, 10)
        if as_torch:
            self.MyData = torch.from_numpy(self.MyData)

... but this throws an error: TypeError: unsupported operand type(s) for |: 'type' and 'torch.dtype'

I know I could simply leave MyData without a type annotation, but I was wondering if there's a better solution. What is the best practice to work around or solve this error? Thanks in advance!

I also tried np.ndarray | torch.tensor, but that fails as well.


Solution

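  • A type hint needs a type (a class), not torch.float. torch.float is an instance of torch.dtype rather than a type itself, which is why the | operator fails. Annotating with torch.dtype would work, since it is a class, but that describes a dtype object rather than a tensor, which is presumably not what you want. Since you want either a NumPy array or a torch tensor, use torch.Tensor (note the capital T), i.e.: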

    MyData: np.ndarray | torch.Tensor
    

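    FWIW, torch.tensor is a function, not a type; all torch tensors are of type torch.Tensor. Also note that the X | Y union syntax is evaluated at class-definition time and requires Python 3.10 or later; on older versions, use typing.Union[np.ndarray, torch.Tensor] instead.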