I am training a model in PyTorch and would like to be able to programmatically swap some components of the model architecture to check which works best, without any if-blocks in forward(). Consider a toy example:
import torch

class Model(torch.nn.Module):
    def __init__(self, layers: str, d_in: int, d_out: int):
        super().__init__()
        self.layers = layers
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        if self.layers == "parallel":
            x1 = self.linears[0](x1)
            x2 = self.linears[1](x2)
            x = x1 + x2
        elif self.layers == "sequential":
            x = x1 + x2
            x = self.linears[0](x)
            x = self.linears[1](x)
        return x
My first intuition was to write external functions, e.g.
def parallel(x1, x2):
    x1 = self.linears[0](x1)
    x2 = self.linears[1](x2)
    return x1 + x2
and pass them to the model, like this:
from typing import Callable

class Model(torch.nn.Module):
    def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
        super().__init__()
        self.layers = layers
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.fn = fn

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x = self.fn(x1, x2)
        return x
but of course the function has no access to self.linears, and I would also like to avoid passing each and every architectural element into the function.
Do I wish for too much? Do I have to "bite the sour apple", as the German saying goes, and either accept larger function signatures, or use if-conditions, or something else? Or is there a solution to my problem?
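To make the "larger function signatures" option concrete, the workaround I would like to avoid looks roughly like this, with every needed sub-module threaded through explicitly:

def parallel(linears, x1, x2):
    # works, but every architectural element has to be passed in by hand
    x1 = linears[0](x1)
    x2 = linears[1](x2)
    return x1 + x2

# and inside the model, forward() would have to call:
#     x = self.fn(self.linears, x1, x2)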
You could just move the if statement into __init__ (or into another method), for example:
import torch
from enum import Enum

class ModelType(Enum):
    Parallel = 1
    Sequential = 2

class Model(torch.nn.Module):
    def __init__(self, d_in: int, d_out: int, model_type: ModelType):
        super().__init__()
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.model_type = model_type
        self.initialize()

    def initialize(self):
        # choose the forward implementation once, at construction time
        if self.model_type == ModelType.Parallel:
            self.fn = self.parallel
        elif self.model_type == ModelType.Sequential:
            self.fn = self.sequential

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x = self.fn(x1, x2)
        return x

    def parallel(self, x1, x2):
        x1 = self.linears[0](x1)
        x2 = self.linears[1](x2)
        x = x1 + x2
        return x

    def sequential(self, x1, x2):
        x = x1 + x2
        x = self.linears[0](x)
        x = self.linears[1](x)
        return x
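The variant is then fixed once when the model is constructed and forward() stays branch-free, e.g. (the dimensions and batch size here are just placeholders):

model = Model(d_in=16, d_out=16, model_type=ModelType.Parallel)
x1 = torch.randn(4, 16)
x2 = torch.randn(4, 16)
out = model(x1, x2)  # dispatches to parallel() without any if in forward()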
I hope it helps.