In the context of exporting PyTorch code to ONNX, I get this warning:
TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
Here is the offending line:
text_lengths = torch.tensor([text_inputs.shape[1]]).long().to(text_inputs.device)
text_inputs is a torch.Tensor of shape torch.Size([1, 81]).
The warning is spot on and cannot be ignored, because the shape of text_inputs is supposed to be dynamic.
I need text_lengths to be a torch.Tensor that contains the number 81 coming from the shape of text_inputs. The "offending line" above does that, but it makes a round trip from PyTorch to a Python int and back to PyTorch, because the elements of a torch.Size object are Python ints. This is (1) somewhat weird, (2) probably inefficient in terms of GPU -> CPU -> GPU transfers, and (3), as stated above, an actual problem in the ONNX export context.
Is there another way to use a tensor's shape in torch computations without "leaving" the torch world?
EDIT: As was pointed out in the comments, the original answer was incorrect.
Using torch.tensor will lead to a constant value in the exported graph. When tracing, torch.Tensor.shape and torch.Tensor.size should return tensors (instead of Python integers), so the code above should export as intended with just
text_lengths = text_inputs.shape[1]
returning a LongTensor with the correct value.
It seems the code doesn't behave as intended because of the brackets in the constructor. The shape inside [text_inputs.shape[1]] gets interpreted as an integer inside a Python list, and is consequently saved as a constant during the trace.
Dropping the extra brackets when creating the tensor will export correctly as a scalar defined at runtime by the shape of text_inputs.
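# Caution: outside of tracing, size(1) is a Python int, and torch.LongTensor(81)
# allocates an uninitialized tensor with 81 elements, not a tensor holding the value 81.
text_lengths = torch.LongTensor(text_inputs.size(1))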
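The constant-vs-dynamic difference can be checked with torch.jit.trace. This is a minimal sketch, assuming the classic tracing-based torch.onnx.export records these operations the same way torch.jit.trace does. The function names lengths_constant and lengths_dynamic are hypothetical, and the multiply-by-ones expression in lengths_dynamic is an illustrative variant, not the answer's LongTensor line: it is used here because it also returns tensor([81]) in plain eager mode.

```python
import warnings

import torch


def lengths_constant(x):
    # Original pattern: the size goes through a Python list, so the tracer
    # bakes the value seen at trace time into the graph as a constant.
    return torch.tensor([x.shape[1]]).long()


def lengths_dynamic(x):
    # Hypothetical variant: in eager mode x.shape[1] is an int; while tracing it
    # is recorded as a size lookup on x, so the result follows the input's shape.
    return x.shape[1] * torch.ones(1, dtype=torch.long)


example = torch.zeros(1, 81)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning from the question
    traced_constant = torch.jit.trace(lengths_constant, example)
    traced_dynamic = torch.jit.trace(lengths_dynamic, example)

# Feed a shorter sequence than the one used for tracing.
other = torch.zeros(1, 5)
print(traced_constant(other))  # still reports the trace-time length
print(traced_dynamic(other))   # reports the new length
print(lengths_dynamic(other))  # eager mode, no tracing involved
```

For the actual export, the sequence axis of text_inputs would also need to be declared dynamic (for example through the dynamic_axes argument of torch.onnx.export); otherwise the exported input keeps the fixed shape of the example input.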