I am trying to build an RL model where my actor network has some pruned connections. When using the data collector SyncDataCollector from torchrl, the deepcopy fails (see error below).
This seems to be due to the pruned connections: pruning leaves the pruned layer's weight as a tensor with a grad_fn (rather than a leaf parameter with requires_grad=True), as suggested in this post.
Here is an example of the code I would like to run, where SyncDataCollector attempts a deepcopy of the model:
import torch
from torch import nn
import torch.nn.utils.prune as prune
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector

device = torch.device("cpu")

# actor network with one pruned layer
model = nn.Sequential(
    nn.Linear(1, 5),
    nn.Linear(5, 1),
)
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
prune.custom_from_mask(model[0], name='weight', mask=mask)

policy_module = TensorDictModule(
    model, in_keys=["in"], out_keys=["out"]
)

env = FlyEnv()  # my custom environment (definition omitted)

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=1,
    total_frames=2,
    split_trajs=False,
    device=device,
)
And here is a minimal example producing the error
import torch
from torch import nn
from copy import deepcopy
import torch.nn.utils.prune as prune
device = torch.device("cpu")
model = nn.Sequential(
    nn.Linear(1, 5),
    nn.Linear(5, 1),
)
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
prune.custom_from_mask(model[0], name='weight', mask=mask)
new_model = deepcopy(model)
where the error is
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001
I tried to remove the pruning with prune.remove(model[0], 'weight') and then calling model[0].requires_grad_(), which fixes the error, but then all the weights are trained (the pruning is lost)...
I think it might work to mask the pruned weights "manually", by re-applying the mask before each forward pass (see the sketch below), but that does not seem efficient (nor elegant).
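For reference, here is a minimal sketch of what I mean by masking "manually", using a forward pre-hook to zero out the pruned weights before every forward call (the hook name apply_mask is just a placeholder I made up):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(1, 5), nn.Linear(5, 1))
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1).float()

def apply_mask(module, inputs):
    # zero out the pruned entries in-place before each forward pass;
    # the weight stays a leaf Parameter, so deepcopy still works
    with torch.no_grad():
        module.weight.mul_(mask)

model[0].register_forward_pre_hook(apply_mask)

The downside is that the masking has to be repeated after every optimizer step (hence "before each forward pass"), which is exactly the inefficiency I would like to avoid.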
The error is caused by the pruning reparametrization: the original parameter is moved to <param>_orig, and the masked value is stored alongside it as a plain (non-leaf) tensor. When the SyncDataCollector takes the params and buffers out and puts them on the "meta" device to create a stateless policy, these additional values are ignored because they're not parameters anymore (and hence not caught by the call to "to").
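To make this concrete, inspecting the pruned layer from your minimal example shows the reparametrization (just an illustration of the mechanism, not part of the fix):

# after prune.custom_from_mask(model[0], name='weight', mask=mask)
print(type(model[0].weight_orig))  # torch.nn.parameter.Parameter (a leaf)
print(type(model[0].weight_mask))  # a plain tensor registered as a buffer
print(model[0].weight.is_leaf)     # False: weight is recomputed as weight_orig * weight_mask
print(model[0].weight.grad_fn)     # not None, which is what makes deepcopy choke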
What you can do as a fix is to call

policy_module.module[0].weight = policy_module.module[0].weight.detach()

before creating the collector. That should be OK because the weight attribute will be recomputed during the next forward call anyway.
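Applied to your minimal example, the same workaround makes the deepcopy go through (a sketch, keeping the pruning in place):

import torch
from torch import nn
from copy import deepcopy
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(1, 5), nn.Linear(5, 1))
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
prune.custom_from_mask(model[0], name='weight', mask=mask)

# detach the recomputed weight so it no longer holds a grad_fn;
# it is rebuilt from weight_orig * weight_mask at the next forward call
model[0].weight = model[0].weight.detach()

new_model = deepcopy(model)  # no longer raises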
TorchRL should maybe handle the deepcopy better, although in this case the error is caused by a tensor requiring gradients in a place where it shouldn't. IMO the pruning methods should compute the "weight" during the forward call (as they do) but then detach it, rather than leaving a non-leaf tensor attached to the module.