Tags: python, pytorch

Programmatically change components of pytorch Model?


I am training a model in pytorch and would like to be able to programmatically change some components of the model architecture to check which works best without any if-blocks in the forward(). Consider a toy example:

import torch

class Model(torch.nn.Module):
    """Toy model with two linear layers combined in parallel or sequentially.

    Args:
        layers: Either ``"parallel"`` or ``"sequential"``; selects the
            combination strategy used in ``forward``.
        d_in: Input feature dimension of each linear layer.
        d_out: Output feature dimension of each linear layer.
    """

    def __init__(self, layers: str, d_in: int, d_out: int):
        super().__init__()
        self.layers = layers
        # Must be assigned to ``self`` so the layers are registered as
        # sub-modules (their parameters trained) and visible in forward().
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # NOTE: the "sequential" branch feeds linears[1] with the output of
        # linears[0], so it only works when d_in == d_out.
        if self.layers == "parallel":
            x1 = self.linears[0](x1)
            x2 = self.linears[1](x2)
            x = x1 + x2
        elif self.layers == "sequential":
            x = x1 + x2
            x = self.linears[0](x)
            x = self.linears[1](x)
        else:
            # Fail loudly instead of raising NameError on an unbound ``x``.
            raise ValueError(f"unknown layers mode: {self.layers!r}")
        return x

My first intuition was to provide external functions, e.g.

def parallel(x1, x2):
   # NOTE(review): ``self`` is unbound here — this free function has no
   # access to the model instance, which is exactly the problem the
   # question describes; it will raise NameError if called as-is.
   x1 = self.linears[0](x1)
   x2 = self.linears[1](x2)
   return x1 + x2

and provide them to the model, like

class Model(torch.nn.Module):
    """Toy model that delegates the layer combination to an injected callable.

    Args:
        layers: Name of the combination strategy (kept for reference).
        d_in: Input feature dimension of each linear layer.
        d_out: Output feature dimension of each linear layer.
        fn: Callable taking ``(x1, x2)`` and returning a tensor. Stored on
            the instance, so it is invoked as a plain function — it is NOT
            bound as a method and cannot see ``self.linears``.
    """

    def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
        super().__init__()
        self.layers = layers
        # Must be assigned to ``self`` so the layers are registered as
        # sub-modules of the model.
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.fn = fn

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x = self.fn(x1, x2)
        # The original snippet dropped the result; return it to the caller.
        return x

but of course the function's scope does not know self.linears and I would also like to avoid having to pass each and every architectural element to the function.

Do I wish for too much? Do I have to "bite the sour apple" as it says in German and either have larger function signatures, or use if-conditions, or something else? Or is there a solution to my problem?


Solution

  • You could just use the if statement in the init function or in another function, for example:

    from enum import Enum

    class ModelType(Enum):
        """Selects how the two linear layers are combined in forward()."""
        Parallel = 1
        Sequential = 2

    class Model(torch.nn.Module):
        """Model whose combination strategy is bound once at construction.

        ``initialize`` picks the bound method to use, so ``forward``
        contains no if-branches at all.
        """

        def __init__(self, layers: str, d_in: int, d_out: int, model_type: ModelType):
            super().__init__()
            self.layers = layers
            # Must be assigned to ``self`` so the layers are registered
            # as sub-modules of the model.
            self.linears = torch.nn.ModuleList([
                torch.nn.Linear(d_in, d_out),
                torch.nn.Linear(d_in, d_out),
            ])
            self.model_type = model_type
            self.initialize()

        def initialize(self):
            """Bind ``self.fn`` to the strategy chosen by ``model_type``."""
            if self.model_type == ModelType.Parallel:
                self.fn = self.parallel
            # ``else if`` / ``::`` in the original were syntax errors.
            elif self.model_type == ModelType.Sequential:
                self.fn = self.sequential
            else:
                raise ValueError(f"unsupported model_type: {self.model_type!r}")

        def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
            x = self.fn(x1, x2)
            return x

        def parallel(self, x1, x2):
            # Apply one layer per input, then sum the results.
            x1 = self.linears[0](x1)
            x2 = self.linears[1](x2)  # was linears[0] twice in the original
            x = x1 + x2
            return x

        def sequential(self, x1, x2):
            # Sum first, then pass through both layers in turn
            # (requires d_in == d_out for the second layer).
            x = x1 + x2
            x = self.linears[0](x)
            x = self.linears[1](x)  # was linears[0] twice in the original
            return x

    I hope it helps.