
Unused inputs, or inputs only used for control flow, in an ONNX model


I have an ONNX model with some (ideally) boolean inputs that are only used for control flow within the model.

Here is some minimal code for what I'm trying to do:

import onnx
import onnxruntime
import torch.onnx


class SumModule(torch.nn.Module):
  def forward(self, x1, x2):
    if x2 is not None:
      x1 *= 1
    return torch.sum(x1)


torch_model = SumModule()
torch_model.eval()
model_inputs = {'x1': torch.tensor([1, 2]), 'x2': torch.tensor([1, 2])}

torch_out = torch_model(**model_inputs)
torch.onnx.export(torch_model,
                  tuple(model_inputs.values()),
                  'model.onnx',
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names=list(model_inputs.keys()),
                  output_names=['output'],
                  dynamic_axes={'x1': {0: 'batch_size'}, })

onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession('model.onnx')


def to_numpy(tensor):
  if isinstance(tensor, torch.Tensor):
    return tensor.detach().cpu().numpy()
  return tensor


model_inputs_np = {k: to_numpy(v) for k, v in model_inputs.items()}
ort_outs = ort_session.run(None, input_feed=model_inputs_np)

While the ONNX export goes through, I cannot run inference on the model without getting the error

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:x2

I think I misunderstand something fundamental about ONNX here. The argument x2 obviously exists, so why does ONNX discard it? The same happens if I don't use x2 at all (but keep it as an input argument), which I also find strange.

In my actual code, the control flow I want is the following: I have 3 inputs that are optional, so ideally they would be Optional[torch.Tensor]. However, ONNX seems unable to deal with None. So instead I wanted to have the 3 inputs plus 3 boolean flags (torch.tensor(True), or if need be torch.tensor([True]), or with True replaced by 1 or 1.0; the same issue occurs with all of those). Then, within the code, I do different things based on those flags. Why does ONNX not allow this? I found out that having those variables is sometimes fine if I include them in some computations, but I can't figure out the rule behind all of this.


Solution

  • Your problem is related to how torch.onnx.export works.

    When generating the ONNX model, torch executes (traces) the module once with the given inputs while keeping track of all performed tensor computations, then maps them to the corresponding ONNX operators, and finally simplifies the graph. In your case, the noteworthy detail is that all control flow is evaluated only once, during tracing, and Python built-in values (such as the result of x2 is not None) are recorded as constants. So the code

    if x2 is not None:
        x1 *= 1
    return torch.sum(x1)
    

    is saved as

    if True:
        x1 *= 1
    return torch.sum(x1)
    

    and when torch.onnx.export simplifies the graph, it removes all unused inputs, including x2, hence your error.
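To see this "evaluated once" behaviour without going through ONNX at all, here is a minimal sketch using torch.jit.trace directly. BranchModule is a hypothetical module (not from the question) whose branch depends on the value of x2: after tracing with a positive x2, the traced module keeps taking the recorded branch even when x2 changes. torch.jit.trace emits a TracerWarning for the frozen branch, which the sketch silences.

```python
import warnings
import torch

# Hypothetical module: the branch depends on the *value* of x2
class BranchModule(torch.nn.Module):
    def forward(self, x1, x2):
        if x2.sum() > 0:      # Python `if` on a tensor: evaluated once while tracing
            return x1.sum()
        return -x1.sum()

model = BranchModule().eval()
x1 = torch.tensor([1, 2])

with warnings.catch_warnings():
    # torch.jit.trace warns here because the branch gets frozen into the trace
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, (x1, torch.tensor([1, 2])))

neg_x2 = torch.tensor([-1, -2])
print(model(x1, neg_x2))   # eager: takes the else-branch -> tensor(-3)
print(traced(x1, neg_x2))  # traced: still the branch recorded at trace time -> tensor(3)
```

The same freezing happens inside torch.onnx.export when it traces the module, which is why the exported graph ends up not depending on x2.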

    If you want to preserve control flow in your exported model, you need torch to compile your model with torch.jit.script instead of tracing it with torch.jit.trace. As you've already pointed out, ONNX expects a fixed number of tensors as inputs and does not accept "optional" arguments. Exporting the model with scripting is done like this:

    scripted_model = torch.jit.script(torch_model)
    torch.onnx.export(scripted_model, ...)
    

    However, even with this your model will still not work. Notice that the if statement in your forward pass is a Python-level check (x2 is not None) and doesn't operate on the tensor itself, so x2 will still be discarded during simplification. Changing SumModule to

    class SumModule(torch.nn.Module):
      def forward(self, x1, x2):
        if torch.any(x2):
          x1 *= 1
        return torch.sum(x1)
    

    will yield the correct graph, since now x2 is actually operated on. With this, you could use x2 as a boolean flag for control flow.
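A quick way to confirm that scripting keeps the branch is to inspect the TorchScript graph before exporting. This minimal sketch uses a hypothetical FlagModule (a variant of SumModule whose two branches return different values, so the preserved control flow is observable). The prim::If node in the scripted graph is what the exporter should then turn into an ONNX If node.

```python
import torch

# Hypothetical variant of SumModule with distinguishable branches
class FlagModule(torch.nn.Module):
    def forward(self, x1: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
        if torch.any(flag):          # tensor condition: kept as prim::If when scripted
            return torch.sum(x1 * 2)
        return torch.sum(x1)

scripted = torch.jit.script(FlagModule())
x1 = torch.tensor([1, 2])

print("prim::If" in str(scripted.graph))   # True: the branch is part of the graph
print(scripted(x1, torch.tensor(True)))    # tensor(6)
print(scripted(x1, torch.tensor(False)))   # tensor(3)
```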

    I highly recommend reading the torch.onnx documentation, as it explains a lot of common mistakes related to exporting.

    EDIT

    For completeness, I should add that the approach above should generally be avoided. Much hardware acceleration is not designed for conditionals, and running ONNX models that contain a lot of control flow with, for example, CUDA often leads to large parts of the graph falling back to the CPU. In a situation like the one described in this question, I would recommend considering:

    1. Whether you can pass the "optional" tensors as zeros without affecting the result
    2. Whether it makes sense to split the model into several separate models and run each one for the corresponding inputs

    rather than using the solution presented above.
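Both options can be handled on the caller's side. The following is a minimal sketch under stated assumptions: X2_SHAPE, make_feed, run_split and the "with_x2"/"without_x2" keys are hypothetical names, and option 1 assumes an all-zeros x2 does not change the result. The sessions only need an onnxruntime-style run(output_names, input_feed) method, as onnxruntime.InferenceSession provides.

```python
import numpy as np

# Assumption: x2 has a fixed shape (here (2,)) whenever it is provided
X2_SHAPE = (2,)

def make_feed(x1, x2=None):
    """Option 1: always feed both inputs, substituting zeros for a missing x2."""
    if x2 is None:
        x2 = np.zeros(X2_SHAPE, dtype=x1.dtype)
    return {"x1": x1, "x2": x2}

def run_split(sessions, x1, x2=None):
    """Option 2: one exported model per input combination (hypothetical keys).

    `sessions` maps a case name to an object with a run(output_names, input_feed)
    method, e.g. an onnxruntime.InferenceSession per exported model.
    """
    if x2 is None:
        return sessions["without_x2"].run(None, {"x1": x1})
    return sessions["with_x2"].run(None, {"x1": x1, "x2": x2})
```

For example, make_feed(np.array([1, 2])) returns a feed whose "x2" entry is array([0, 0]), so the exported graph always receives both inputs with a static shape.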

    EDIT

    Adding an example that avoids if-else control flow for better hardware acceleration:

    import torch

    # If x2 has a static shape whenever it exists (say [1, 2]),
    # and x2 is never all zeros when it exists (you need to find a special value that x2 can never take),
    # then you can try the following:
    # pass x2 as torch.zeros(1, 2) when x2 is None.
    # This guarantees that x2 is always passed into the function
    # and always has a static shape.

    # Initialize x2 as all zeros and overwrite it if x2 exists
    # (this happens outside the ONNX model)
    x2 = torch.zeros(1, 2)

    class SumModule(torch.nn.Module):
        def forward(self, x1, x2):
            # Build the condition with tensor ops only: 1 if x2 was provided (not all zeros), else 0.
            # A Python bool (e.g. from torch.equal) would be frozen into the graph as a constant during tracing.
            condition = torch.any(x2 != 0).to(x1.dtype)
            # Blend both branches arithmetically instead of using control flow
            x1 = condition * (x1 * 1) + (1 - condition) * x1
            return torch.sum(x1)
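As a sanity check that a branchless formulation survives tracing (unlike a Python if), here is a minimal sketch with a hypothetical BlendModule whose two "branches" give different results. It swaps in torch.where for the arithmetic blend above; both keep the selection inside the graph. Zeros stand in for a missing x2, following the assumption in the comments above.

```python
import torch

class BlendModule(torch.nn.Module):
    def forward(self, x1, x2):
        # 0-d bool tensor computed with tensor ops, so tracing records it
        present = torch.any(x2 != 0)
        with_x2 = torch.sum(x1 * 2)      # hypothetical result when x2 is given
        without_x2 = torch.sum(x1)       # hypothetical result when x2 is missing
        # Select between the two precomputed results without Python control flow
        return torch.where(present, with_x2, without_x2)

model = BlendModule().eval()
x1 = torch.tensor([1, 2])
missing = torch.zeros(2, dtype=torch.long)   # zeros stand in for a missing x2

# Trace with the "missing" case; the selection still reacts to x2 afterwards
traced = torch.jit.trace(model, (x1, missing))
print(traced(x1, missing))               # tensor(3)
print(traced(x1, torch.tensor([1, 2])))  # tensor(6)
```

Both branches are always computed, which is the price of avoiding control flow; whether that is cheaper than a CPU fallback depends on the model.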