
How can I limit the range of parameters in PyTorch?

Normally in PyTorch there is no strict limit on the values a model's parameters can take. What if I want them to stay in the range [0, 1]? Is there a way to prevent an update from moving parameters outside that range?


  • A trick used in some generative adversarial networks (for example, the original Wasserstein GAN clips the discriminator's weights to a fixed range) is to clamp the parameter values after every gradient update. For example:

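    model = YourPyTorchModule()
    for _ in range(epochs):
        loss = ...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Project every parameter back into [0, 1] after the update
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(0.0, 1.0)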
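A self-contained sketch of the same idea is below. It assumes a toy `nn.Linear` model, random data, and plain SGD, all of which are hypothetical and chosen only so the example runs end to end; the helper name `clamp_parameters` is also made up for illustration. The clamp runs inside `torch.no_grad()` because an in-place operation on a leaf tensor that requires grad is not allowed while autograd is recording.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, used only to make the sketch runnable
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
x = torch.randn(32, 4)
y = torch.randn(32, 1) * 10  # large targets tend to push weights outside [0, 1]


def clamp_parameters(module: nn.Module, low: float = 0.0, high: float = 1.0) -> None:
    """Hypothetical helper: clamp every parameter of `module` into [low, high] in place."""
    with torch.no_grad():  # keep the clamp itself out of the autograd graph
        for p in module.parameters():
            p.clamp_(low, high)


clamp_parameters(model)  # start from a point inside the allowed range
for epoch in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()          # the update itself may leave [0, 1] ...
    clamp_parameters(model)   # ... so project back immediately afterwards

for name, p in model.named_parameters():
    print(name, p.min().item(), p.max().item())
```

This amounts to a projected gradient step: the optimizer does not know about the bound, and the clamp projects the updated values back onto the box [0, 1] after each step.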