pytorch

How can I limit the range of parameters in pytorch?


By default, PyTorch places no constraint on the values of a model's parameters, but what if I wanted them to stay in the range [0, 1]? Is there a way to prevent updates from pushing parameters outside that range?


Solution

  • A trick used in some generative adversarial networks (some of which require the discriminator's parameters to lie within a fixed range) is to clamp the parameter values after every gradient update. For example:

    model = YourPyTorchModule()
    
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = ...            # compute your loss here
        loss.backward()
        optimizer.step()
        # Project the parameters back into [0, 1]. Do this outside autograd;
        # p.data.clamp_ also works, but torch.no_grad() is the recommended
        # idiom in recent PyTorch versions.
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(0.0, 1.0)
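    To see the projection in action, here is a minimal self-contained sketch (the tiny `nn.Linear` model, the random data, and the learning rate are all arbitrary choices for illustration) that trains for a few steps and then checks that every parameter stayed in [0, 1]:

    ```python
    import torch
    
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)  # stand-in for any module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
    
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)
    
    for _ in range(20):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        # Clamp after each step so parameters never leave [0, 1].
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(0.0, 1.0)
    
    in_range = all((p >= 0).all() and (p <= 1).all()
                   for p in model.parameters())
    print(in_range)  # True
    ```

    Note that this is a projection applied after the optimizer step, so the gradients themselves are unconstrained; the optimizer is unaware of the bound and may repeatedly push parameters against it.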