I have a model whose forward takes more than one argument. I'm trying to get some information about my model (FLOPs and parameter counts) with the fvcore module in Python, but I can't find any documentation for models with multiple forward arguments!
I tried wrapping the model in a function that forwards the extra arguments:
```python
from fvcore.nn import FlopCountAnalysis, parameter_count_table

def modelCount(model, input_tensor, *args, **kwargs):
    def _wrapped_forward(x):
        return model(x, *args, **kwargs)

    flops = FlopCountAnalysis(_wrapped_forward, input_tensor).total()
    params = parameter_count_table(model)
    return flops, params
```
But it doesn't help; I still get an error.
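Pass all of your positional arguments to FlopCountAnalysis as a single tuple; fvcore handles the unpacking itself. Your wrapper fails because FlopCountAnalysis expects an nn.Module as its first argument, not a plain function like _wrapped_forward.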
```python
import torch
from fvcore.nn import FlopCountAnalysis

# Your model and input tensors, e.g.:
# model = MyModel()
# tensor1 = torch.randn(1, 20)
# tensor2 = torch.randn(1, 5)

# Package the positional inputs into a tuple
inputs = (tensor1, tensor2)

# Pass the model and the input tuple
flops = FlopCountAnalysis(model, inputs).total()
```
fvcore unpacks the tuple and calls the model as model(tensor1, tensor2), i.e. forward(tensor1, tensor2), while tracing it internally. You don't need a wrapper.