python  pytorch  onnx

How can I get a pytorch Tensor containing some other Tensor's size (or shape) without conversion to Python int?


In the context of exporting pytorch code to ONNX, I get this warning:

TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

Here is the offending line:

text_lengths = torch.tensor([text_inputs.shape[1]]).long().to(text_inputs.device)

text_inputs is a torch.Tensor of shape torch.Size([1, 81])

And the warning is spot on and cannot be ignored, because the shape of text_inputs is supposed to be dynamic.

I need text_lengths to be a torch.Tensor that contains the number 81 coming from the shape of text_inputs. The "offending line" from above succeeds in doing that, but it actually makes a round trip from pytorch to a Python int and back to pytorch, because the elements of a torch.Size object are Python ints. This is (1) somewhat weird, (2) probably inefficient in terms of GPU -> CPU -> GPU transfers and, (3) as stated above, an actual problem in the ONNX exporting context.
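
For reference, a quick check in a plain (non-tracing) Python session shows the round trip described above; the 1 x 81 tensor here is just the example shape from the question:

import torch

text_inputs = torch.zeros(1, 81)
print(type(text_inputs.shape))     # <class 'torch.Size'>
print(type(text_inputs.shape[1]))  # <class 'int'> -- a plain Python int, not a Tensor

# the "offending line" goes Tensor -> Python int -> Tensor again:
text_lengths = torch.tensor([text_inputs.shape[1]]).long().to(text_inputs.device)
print(text_lengths)                # tensor([81])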

Is there some other way I can use a tensor's shape in torch computations, without "leaving" the torch world?


Solution

  • EDIT: As was pointed out in the comments, the original answer was incorrect.

    Using torch.tensor will lead to a constant value in the exported graph. When tracing, outputs of torch.Tensor.shape and torch.Tensor.size() should return tensors (instead of Python integers), so the code above should export as intended with just

    text_lengths = text_inputs.shape[1]
    

    returning a LongTensor with the correct value.

    It seems the code doesn't behave as intended because of the brackets in the constructor: the shape value inside of [text_inputs.shape[1]] is a Python integer inside of a Python list, and is consequently saved as a constant during the trace.

    Dropping the torch.tensor constructor (and with it the extra brackets) and using the shape value directly will export correctly as a scalar defined at runtime by the shape of text_inputs (a minimal export sketch follows below):

    text_lengths = text_inputs.size(1)
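
    To illustrate, here is a minimal, self-contained sketch of an export along these lines. The module, the example input, the file name and the axis name are placeholders rather than code from the question, and the final multiplication is only there so the output is a tensor in eager mode as well:

    import torch

    class TextModel(torch.nn.Module):
        def forward(self, text_inputs):
            # Per the explanation above, this shape access is recorded in the
            # trace rather than baked in as a constant.
            text_lengths = text_inputs.size(1)
            # Multiply by a ones tensor so the output is a Tensor in eager mode too.
            return text_lengths * torch.ones(1, dtype=torch.long, device=text_inputs.device)

    torch.onnx.export(
        TextModel(),
        torch.zeros(1, 81, dtype=torch.long),           # example input with the shape from the question
        "text_model.onnx",
        input_names=["text_inputs"],
        output_names=["text_lengths"],
        dynamic_axes={"text_inputs": {1: "text_len"}},  # mark dimension 1 as dynamic
    )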