In PyTorch (v1.10) DistributedDataParallel, parameters in a model that don't contribute to the final loss can raise a RuntimeError (as mentioned in this other question and this PyTorch forums thread):
"RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss."
Although it's possible to inspect which parameters are affected at error time (as mentioned above, or by setting the env var TORCH_DISTRIBUTED_DEBUG="INFO"), it seems like there should be a way to statically inspect a model to locate (and presumably prune or disable gradients on) parameters that aren't contributing to the current loss objective.
So, given a torch.nn.Module-based model whose forward() function returns some loss tensor (maybe alongside others): how can we programmatically, before starting to train, find all parameters (including those in nested modules) that aren't contributing to loss?
By default, PyTorch tensors that are the result of some computation involving a tensor that requires gradients record their history, that is, their ancestors. This is needed for the backward pass to compute the gradient.
We can make use of this to find all tensors that contribute to a new tensor by walking through its whole history.
Note that this works for a static network that always has the same architecture. As soon as you have conditionals that depend on, e.g., some intermediate value, this won't work, and I claim that in that case it is impossible to find in advance which tensors are involved. (It's similar to the halting problem.)
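To make the caveat about conditionals concrete, here is a minimal sketch. BranchyNet and used_param_names are hypothetical names made up for this illustration, and instead of walking the graph it uses a simpler check (run backward once and see which parameters received a gradient). The point is only that the set of contributing parameters can change from one input to the next:

```python
import torch
import torch.nn as nn


class BranchyNet(nn.Module):
    """Hypothetical model whose used parameters depend on the input."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(1, 1)
        self.b = nn.Linear(1, 1)

    def forward(self, x):
        # Data-dependent branch: only one of the two layers is used per call.
        return self.a(x) if x.sum() > 0 else self.b(x)


def used_param_names(model, x):
    """Names of parameters that receive a gradient for this particular input."""
    model.zero_grad(set_to_none=True)  # reset grads to None before the check
    model(x).sum().backward()
    return {n for n, p in model.named_parameters() if p.grad is not None}


net = BranchyNet()
pos = used_param_names(net, torch.ones(1, 1))   # takes the `a` branch
neg = used_param_names(net, -torch.ones(1, 1))  # takes the `b` branch
print(pos)
print(neg)
```

A single inspection before training would report either a or b as unused, depending on the sample it happened to see, so for models like this the result only holds for the inputs that were actually tried.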
import torch
import torch.nn as nn


# Example of a simple network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.tensor([999999.0]))  # not contributing
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1)) for _ in range(3)]
        )

    def forward(self, x):
        for m in self.layers:
            x = m(x) + x
        return x


net = Net()
x = torch.ones((1, 1))

# compute the forward pass to create the computation graph
y = net(x)


# use computation graph to find all contributing tensors
def get_contributing_params(y, top_level=True):
    nf = y.grad_fn.next_functions if top_level else y.next_functions
    for f, _ in nf:
        try:
            yield f.variable
        except AttributeError:
            pass  # node has no tensor
        if f is not None:
            yield from get_contributing_params(f, top_level=False)


contributing_parameters = set(get_contributing_params(y))
all_parameters = set(net.parameters())
non_contributing = all_parameters - contributing_parameters
print(non_contributing)  # prints a set containing the [999999.0] parameter
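Since the question also asks about disabling gradients on the parameters found, here is a hedged variation on the same idea. It is a sketch rather than a definitive implementation: contributing_param_ids is a hypothetical helper name, the walk is iterative with a visited set (so shared sub-graphs such as residual connections are not traversed repeatedly), and the small Net below is redefined only to keep the example self-contained. It reports the unused parameters by name via named_parameters() and then freezes them:

```python
import torch
import torch.nn as nn


def contributing_param_ids(output):
    """Iteratively walk the autograd graph and collect ids of leaf tensors."""
    found, seen = set(), set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # AccumulateGrad nodes expose their leaf tensor as `.variable`;
        # other node types have no such attribute.
        var = getattr(fn, "variable", None)
        if var is not None:
            found.add(id(var))
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return found


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.unused = nn.Parameter(torch.tensor([999999.0]))  # never used in forward
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        return self.fc(x)


net = Net()
used = contributing_param_ids(net(torch.ones(1, 1)))
unused_names = [n for n, p in net.named_parameters() if id(p) not in used]
print(unused_names)

# Freeze the parameters that did not take part in producing the output.
for name, param in net.named_parameters():
    if name in unused_names:
        param.requires_grad_(False)
```

The assumption here is that this runs before the model is wrapped in DistributedDataParallel, so that the frozen parameters are already excluded from gradient tracking when the wrapper is built. The same caveat as above applies: the result is only valid for models whose graph does not change between iterations.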