I would like to subclass torch.Tensor to make tensors that will always satisfy some user-defined property. For example, I might want my subclassed tensor to represent a categorical probability distribution, so I always want the last dim to sum to one.
I can define my subclass as:
import torch

class ValidatedArray(torch.Tensor):
    def __init__(self, array: torch.Tensor):
        self.validate_array()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.validate_array()

    def validate_array(self):
        assert torch.allclose(self.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'
This catches the most likely ways the tensor could violate my property: at instantiation and when setting values through indexing.
Validation works at instantiation:
>>> array = torch.ones(3,4)
>>> va1 = ValidatedArray(array)
AssertionError: The last dim represents a categorical distribution. It must sum to one.
Validation works when trying to set an invalid value:
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1[0] = 1
AssertionError: The last dim represents a categorical distribution. It must sum to one.
Validation "fails" in the following cases, where I end up with a ValidatedArray that wouldn't pass validation:
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va2 = va1 + 2
>>> va2
ValidatedArray([[2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500]])
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1.fill_(2.)
>>> va1
ValidatedArray([[2., 2., 2., 2.],
                [2., 2., 2., 2.],
                [2., 2., 2., 2.]])
Is there a way I can ensure that an instance of ValidatedArray will always pass my validation method? Is there some method inherited from torch.Tensor that runs any time the underlying data is changed? I imagine I could extend that method with my validation.
Note: I don't want to make a container object with setter and getter methods for a tensor attribute. I would like my subclass to otherwise be a drop-in replacement for a normal torch.Tensor, so I can use all the normal torch operations.
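One mechanism that comes close to "a method that runs on every operation" is the __torch_function__ classmethod: PyTorch calls it on a tensor subclass for (almost) every torch function and Tensor method applied to an instance, including __setitem__ and in-place methods such as fill_. The following is a minimal sketch of that idea, not a complete or official solution, and it rests on several assumptions: the checks run after each operation has executed (so an invalid in-place op has already modified the data when the error is raised); the _checking re-entrancy flag and _is_valid helper are hypothetical names and the flag is not thread-safe; in-place ops are detected by assuming they return the same Python object they were called on; and writes through out= or through a view are only checked against the tensor actually written to. Out-of-place results that break the property are downcast to a plain torch.Tensor instead of raising.

```python
import torch


class ValidatedArray(torch.Tensor):
    """Tensor whose last dim must sum to one (a categorical distribution).

    Sketch only: checks run inside __torch_function__, i.e. after an op has
    executed, so a rejected in-place op has already changed the data.
    """

    _checking = False  # hypothetical re-entrancy guard (not thread-safe)

    def __new__(cls, data):
        # as_subclass returns an alias: the result shares memory with `data`.
        obj = torch.as_tensor(data).as_subclass(cls)
        if not cls._is_valid(obj):
            raise ValueError("The last dim must sum to one.")
        return obj

    @classmethod
    def _is_valid(cls, t):
        # Ops used here re-enter __torch_function__; the guard skips their checks.
        cls._checking = True
        try:
            if t.dim() == 0:
                return False
            s = t.sum(-1)
            return bool(torch.allclose(s, torch.ones_like(s)))
        finally:
            cls._checking = False

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)
        if cls._checking:
            return result
        name = getattr(func, "__name__", "")
        if name == "__setitem__":
            # va[key] = value returns None; check the tensor that was written to.
            target = args[0]
            if isinstance(target, cls) and not cls._is_valid(target):
                raise ValueError("Assignment broke the sum-to-one property.")
            return result
        if isinstance(result, cls) and any(result is a for a in args):
            # In-place ops (fill_, add_, +=, ...) are assumed to return self.
            if not cls._is_valid(result):
                raise ValueError(f"{name} broke the sum-to-one property.")
        elif isinstance(result, cls) and not cls._is_valid(result):
            # Out-of-place results that break the property become plain tensors.
            return result.as_subclass(torch.Tensor)
        return result
```

With this sketch, va1 + 2 from the examples above comes back as a plain torch.Tensor rather than a ValidatedArray, while va1.fill_(2.) and va1[0] = 1 raise ValueError (after the write has happened). Downcasting instead of raising keeps ordinary arithmetic and reductions such as va1.sum(-1) usable; whether that trade-off fits depends on the use case.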
I think I can see what you're trying to do, and it makes sense as something that could be useful: you get a nice guarantee that every time you interact with the tensor, it obeys some property. I'd probably advise against doing something like this, though, for these reasons:
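If I were in your position, I'd probably write a validate function like the one you have:
import torch

def validate_array(tensor: torch.Tensor) -> bool:
    # do the check here, e.g. that the last dim sums to one
    sums = tensor.sum(-1)
    return torch.allclose(sums, torch.ones_like(sums))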
and then, wherever the property actually matters, I'd add an assert:
assert validate_array(tensor), "tensor should obey property"
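To make the assert-where-it-matters pattern concrete, here is a small sketch. sample_categories is a hypothetical consumer invented for illustration; the idea is simply to check the property at the boundary where a non-normalized tensor would cause a real problem, rather than on every tensor operation.

```python
import torch


def validate_array(tensor: torch.Tensor) -> bool:
    """True if the last dim sums to one (a categorical distribution)."""
    sums = tensor.sum(-1)
    return torch.allclose(sums, torch.ones_like(sums))


def sample_categories(probs: torch.Tensor) -> torch.Tensor:
    # Hypothetical consumer: validate only where the property matters.
    assert validate_array(probs), "tensor should obey property"
    # One category index per row of `probs`.
    return torch.multinomial(probs, num_samples=1)


probs = torch.softmax(torch.ones(3, 4), -1)
samples = sample_categories(probs)  # shape (3, 1)
```

Note that assert statements are skipped when Python runs with -O; if the check must always run, raising a ValueError explicitly is the safer choice.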
I know this isn't a very exciting way to do it, but I haven't come across a use case that needs something more heavyweight.
One nice thing you could do with a container class is:
import torch

class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        # store the tensor first, then validate it
        self._array = array
        self._validate_array()

    def _validate_array(self):
        assert torch.allclose(self._array.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'

    @property
    def array(self):
        # re-check on every access
        self._validate_array()
        return self._array
That way, the check is performed every time the array is accessed. However, for the reasons above I probably wouldn't do this. (Also note how I've removed the subclass ;))