
How do I use fvcore with a model whose forward() takes multiple arguments?


I have a model whose forward method takes more than one argument. I'm trying to get some information about the model (FLOP and parameter counts) with fvcore in Python, but I can't find any documentation on how to pass multiple forward arguments.

I edited my code to wrap the forward call in a function that supplies the extra arguments:

from fvcore.nn import FlopCountAnalysis, parameter_count_table

def modelCount(model, input_tensor, *args, **kwargs):
    def _wrapped_forward(x):
        return model(x, *args, **kwargs)
    flops = FlopCountAnalysis(_wrapped_forward, input_tensor).total()
    params = parameter_count_table(model)
    return flops, params
    

but it doesn't help; I still get an error.


Solution

  • Pass all of the forward arguments as a single tuple. FlopCountAnalysis accepts a tuple of inputs and feeds its elements to the model as positional arguments.


    Example

    import torch
    from fvcore.nn import FlopCountAnalysis
    
    # Your model and input tensors
    # model = MyModel()
    # tensor1 = torch.randn(1, 20)
    # tensor2 = torch.randn(1, 5)
    
    # Package inputs into a tuple
    inputs = (tensor1, tensor2)
    
    # Pass the model and the input tuple
    flops = FlopCountAnalysis(model, inputs).total()
    

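    FlopCountAnalysis traces the model by calling model(tensor1, tensor2), which runs forward(tensor1, tensor2), so you don't need a wrapper function. The first argument must be the nn.Module itself; passing a plain function such as _wrapped_forward is most likely what triggers the error in the question.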