pytorch subclassing

How do I force a PyTorch tensor to always satisfy some (possibly arbitrary) property?


I would like to subclass torch.Tensor to make tensors that will always satisfy some user-defined property. For example, I might want my subclassed tensor to represent a categorical probability distribution, so I always want the last dim to sum to one.

I can define my subclass as:

import torch

class ValidatedArray(torch.Tensor):
    def __init__(self, array: torch.Tensor):
        # Check the property once at construction.
        self.validate_array()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        # Re-check after every indexed assignment.
        self.validate_array()

    def validate_array(self):
        assert torch.allclose(self.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'

This catches the most likely cases when the tensor might violate my property: at instantiation and while trying to set certain values.

Validation works at instantiation:

>>> array = torch.ones(3,4)
>>> va1 = ValidatedArray(array)
AssertionError: The last dim represents a categorical distribution. It must sum to one.

Validation works when trying to set an invalid value:

>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1[0] = 1
AssertionError: The last dim represents a categorical distribution. It must sum to one.

Validation is bypassed in these cases, so I end up with a ValidatedArray that wouldn't pass validation:

>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va2 = va1 + 2
>>> va2
ValidatedArray([[2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500],
                [2.2500, 2.2500, 2.2500, 2.2500]])
>>> array = torch.nn.functional.softmax(torch.ones(3,4), -1)
>>> va1 = ValidatedArray(array)
>>> va1.fill_(2.)
>>> va1
ValidatedArray([[2., 2., 2., 2.],
                [2., 2., 2., 2.],
                [2., 2., 2., 2.]])

Is there a way I can ensure an instance of ValidatedArray will always pass my validation method? Is there some inherited method from torch.Tensor that runs anytime the underlying data is changed? I imagine I could extend that method with my validation.

Note: I don't want to make a container object with setter and getter methods for a tensor attribute. I would like my subclass to otherwise be a drop-in replacement for a normal torch.Tensor so I can use all the normal torch operations.


Solution

  • I think I can see what you're trying to do, and it makes sense as something that could be useful: you get a guarantee that the tensor obeys some property every time you interact with it. I'd probably advise against doing something like this, though, for these reasons:

    If I were in your position, I'd probably write a validate function like you have:

    def validate_array(tensor: torch.Tensor) -> bool:
        # do the check here; return True when the property holds
        ...
    

    and then, wherever the property matters, I'd add an assert:

    assert validate_array(tensor), "tensor should obey property"
    

    I know this isn't a super exciting way to do this, but I've not come across a use case that needs something more heavyweight.
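
    As a rough sketch of the function-plus-assert approach, here is one way the check could be filled in for the categorical-distribution case from the question. The `atol` parameter and the `probs` and `shifted` names are illustrative assumptions, not part of the original code.

```python
import torch


def validate_array(tensor: torch.Tensor, atol: float = 1e-6) -> bool:
    """Return True if every slice along the last dim sums to one."""
    sums = tensor.sum(-1)
    return torch.allclose(sums, torch.ones_like(sums), atol=atol)


# A valid categorical distribution along the last dim.
probs = torch.softmax(torch.ones(3, 4), dim=-1)
assert validate_array(probs), "tensor should obey property"

# Ordinary arithmetic breaks the property; the function reports it
# without raising, so the caller decides where to assert.
shifted = probs + 2
is_valid = validate_array(shifted)
```

    Returning a bool rather than asserting inside the function keeps the check reusable: the same function can back an assert, a test, or a conditional renormalisation.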

    One nice thing you could do with the container class is:

    class ValidatedArray:
        def __init__(self, array: torch.Tensor):
            # Store the tensor first so the check has something to inspect.
            self._array = array
            self._validate_array()

        def _validate_array(self):
            assert torch.allclose(self._array.sum(-1), torch.ones(1)), 'The last dim represents a categorical distribution. It must sum to one.'

        @property
        def array(self):
            # Re-run the check on every access.
            self._validate_array()
            return self._array
        
    

    That way, the check runs every time the array is accessed. However, for the reasons above, I probably wouldn't do this. (Also note how I've removed the subclass ;))
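
    To illustrate how a container like this behaves in practice, here is a minimal, self-contained sketch. The class is restated so the block runs on its own, and the `va`, `doubled`, and `caught` names are only for illustration. Note that the check runs on access, so an in-place edit is only noticed the next time `array` is read.

```python
import torch


class ValidatedArray:
    def __init__(self, array: torch.Tensor):
        self._array = array
        self._validate_array()

    def _validate_array(self):
        sums = self._array.sum(-1)
        assert torch.allclose(sums, torch.ones_like(sums)), (
            "The last dim represents a categorical distribution. "
            "It must sum to one."
        )

    @property
    def array(self) -> torch.Tensor:
        # Re-run the check on every access.
        self._validate_array()
        return self._array


va = ValidatedArray(torch.softmax(torch.ones(3, 4), dim=-1))

# Results of ordinary operations are plain tensors, not validated ones.
doubled = va.array * 2

# This in-place edit goes through: the check ran before fill_ was called.
va.array.fill_(2.0)

# The violation is only detected on the next access.
try:
    va.array
    caught = False
except AssertionError:
    caught = True
```

    This shows the limit of the container approach: it validates what comes out of the property, not what is done to the tensor afterwards.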