
Python OOP issue: How does a method in a class receive its input value when that value is not a parameter of __init__()?


I am relatively new to object-oriented programming and I am currently working on a Generative Adversarial Network (GAN) project. I came across the maxout activation function, which is defined through a maxout class. See the code below:

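import torch

class maxout(torch.nn.Module):
    def __init__(self, num_pieces):
        super(maxout, self).__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        # The number of features must split evenly into groups of num_pieces
        assert x.shape[1] % self.num_pieces == 0  # e.g. 240 % 5 == 0

        # (batch_size, C, ...) -> (batch_size, C // num_pieces, num_pieces, ...)
        ret = x.view(*x.shape[:1],                   # batch_size
                     x.shape[1] // self.num_pieces,  # number of groups
                     self.num_pieces,                # num_pieces
                     *x.shape[2:])

        # Keep the maximum of each group; max() also returns indices, which are discarded
        ret, _ = ret.max(dim=2)

        return ret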
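To make the view/max pair concrete, here is a minimal pure-Python sketch for a single sample (no PyTorch involved; the function name maxout_1d is hypothetical). It assumes the features are grouped consecutively, which is what reshaping dimension 1 into (C // num_pieces, num_pieces) does for a contiguous tensor.

```python
def maxout_1d(features, num_pieces):
    """Hypothetical pure-Python maxout for one sample (a flat list of numbers)."""
    # Same precondition as the assert in maxout.forward
    assert len(features) % num_pieces == 0
    # Equivalent of x.view(..., C // num_pieces, num_pieces) for one sample
    groups = [features[i:i + num_pieces] for i in range(0, len(features), num_pieces)]
    # Equivalent of ret.max(dim=2): reduce each group to its maximum
    return [max(g) for g in groups]

# 10 features with num_pieces=5 -> 2 outputs
print(maxout_1d([0.1, 0.7, -0.2, 0.3, 0.0, 1.5, -1.0, 0.4, 0.9, 0.2], 5))  # [0.7, 1.5]
```

With 240 input features and num_pieces=5 this yields 48 outputs, one per group.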

This maxout module is later used in a discriminator class. Here is the code for the discriminator class:

class discriminator(torch.nn.Module):

    def __init__(self):
        super(discriminator, self).__init__()

        self.fcn = torch.nn.Sequential(
            # Fully connected layer 1: 784 -> 240 features
            torch.nn.Linear(in_features=784, out_features=240, bias=True),
            # Maxout over groups of 5: 240 -> 240 // 5 = 48 features
            maxout(5),
            # Fully connected layer 2: 48 -> 1
            torch.nn.Linear(in_features=48, out_features=1, bias=True),
        )

    def forward(self, batch):
        inputs = batch.view(batch.size(0), -1)  # flatten each sample
        outputs = self.fcn(inputs)              # runs the three layers in order
        outputs = outputs.mean(0)               # average over the batch
        return outputs.view(1)                  # it will return a single value
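The layer sizes only fit together because maxout(5) shrinks 240 features to 48. The following hypothetical helper (not part of the original code) traces the tensor shape through discriminator.forward, assuming each input sample flattens to 784 values (for example a 28x28 image).

```python
def discriminator_shapes(batch_size, in_features=784, hidden=240, num_pieces=5):
    """Hypothetical helper: shape after each step of discriminator.forward."""
    assert hidden % num_pieces == 0          # same check as in maxout.forward
    pooled = hidden // num_pieces            # 240 // 5 = 48 = in_features of layer 2
    return [
        ("flatten", (batch_size, in_features)),  # batch.view(batch.size(0), -1)
        ("linear1", (batch_size, hidden)),       # Linear(784, 240)
        ("maxout", (batch_size, pooled)),        # maxout(5)
        ("linear2", (batch_size, 1)),            # Linear(48, 1)
        ("output", (1,)),                        # .mean(0), then .view(1)
    ]

for step, shape in discriminator_shapes(batch_size=64):
    print(step, shape)
```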

The code works fine, but according to my naive understanding of object-oriented programming, the value of 'x' in the forward() method of the maxout class should be provided through the __init__() function.

My question is: how does the forward() method of the maxout class receive the input 'x' without getting it through __init__()?

Another way to put this question: how is the output of the Linear layer in the discriminator class passed to the maxout module as 'x'?


Solution

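  • You are passing the layers to the constructor of the Sequential model, which is assigned to self.fcn, and then calling this model in discriminator.forward. Calling a module runs its __call__ method (inherited from torch.nn.Module), which in turn calls its forward(). Sequential's forward() passes the input through the contained layers in order, calling each layer with the previous layer's output, so the output of the first Linear layer becomes the 'x' argument of maxout.forward(). The __init__() of maxout only receives the configuration (num_pieces); the data arrives later, as the argument of the call.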

    You can imagine that something like this is going on:

    ...
    def forward(self, batch):
        # The three layers that were passed to Sequential, in order
        linear1, act, linear2 = self.fcn
        return linear2.forward(
            act.forward(  # act is maxout(5): its 'x' is the output of linear1
                linear1.forward(batch.view(batch.size(0), -1))
            )
        ).mean(0).view(1)
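The same mechanism can be reproduced without PyTorch. The sketch below is a deliberately simplified model, not PyTorch's real implementation (which also runs hooks and does other bookkeeping): a base class whose __call__ routes to forward(), and a sequential container whose forward() feeds each layer's output into the next layer. The names TinyModule, TinySequential, AddOne and TinyMaxout are hypothetical.

```python
class TinyModule:
    def __call__(self, *args):
        # Calling the object, e.g. layer(x), is routed to its forward() method
        return self.forward(*args)


class TinySequential(TinyModule):
    def __init__(self, *layers):
        # Only the layers are stored here; no input data is involved yet
        self.layers = layers

    def forward(self, x):
        # The output of each layer becomes the 'x' of the next one
        for layer in self.layers:
            x = layer(x)
        return x


class AddOne(TinyModule):
    def forward(self, x):
        return [v + 1 for v in x]


class TinyMaxout(TinyModule):
    def __init__(self, num_pieces):
        # Like maxout.__init__: only the setting is stored, not the data
        self.num_pieces = num_pieces

    def forward(self, x):
        # x is whatever the previous layer returned
        assert len(x) % self.num_pieces == 0
        return [max(x[i:i + self.num_pieces]) for i in range(0, len(x), self.num_pieces)]


model = TinySequential(AddOne(), TinyMaxout(2))
print(model([1, 5, 3, 2]))  # AddOne -> [2, 6, 4, 3], then TinyMaxout(2) -> [6, 4]
```

Here TinyMaxout(2) plays the role of maxout(5): the constructor receives only num_pieces, and forward() receives its input when the container calls it.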