I am relatively new to object-oriented programming and I am currently working on a Generative Adversarial Networks project. I came across the maxout activation function, which is defined through a maxout class. See the code below:
class maxout(torch.nn.Module):
    def __init__(self, num_pieces):
        super(maxout, self).__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        assert x.shape[1] % self.num_pieces == 0  # e.g. 240 % 5 == 0
        ret = x.view(*x.shape[:1],                # batch_size
                     x.shape[1] // self.num_pieces,
                     self.num_pieces,             # num_pieces
                     *x.shape[2:])
        ret, _ = ret.max(dim=2)
        return ret
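To see what the reshaping does, here is a small check (a sketch with made-up shapes: a batch of 3 vectors of 240 features, split into groups of 5, leaving 48 outputs):

```python
import torch

class maxout(torch.nn.Module):
    def __init__(self, num_pieces):
        super(maxout, self).__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        assert x.shape[1] % self.num_pieces == 0
        # (batch, features) -> (batch, features // num_pieces, num_pieces)
        ret = x.view(*x.shape[:1],
                     x.shape[1] // self.num_pieces,
                     self.num_pieces,
                     *x.shape[2:])
        # take the max over each group of num_pieces consecutive features
        ret, _ = ret.max(dim=2)
        return ret

m = maxout(5)
x = torch.randn(3, 240)  # batch of 3, 240 features each
y = m(x)                 # note: the instance is called like a function
print(y.shape)           # torch.Size([3, 48])
```

Note that `y = m(x)` — not `m.forward(x)` — is the idiomatic way to run a module; this is the calling convention your question is about.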
This maxout function is later used in a discriminator class. The following is the code for the discriminator class.
class discriminator(torch.nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fcn = torch.nn.Sequential(
            # Fully connected layer 1
            torch.nn.Linear(
                in_features=784,
                out_features=240,
                bias=True
            ),
            maxout(5),
            # Fully connected layer 2
            torch.nn.Linear(
                in_features=48,
                out_features=1,
                bias=True
            )
        )

    def forward(self, batch):
        inputs = batch.view(batch.size(0), -1)
        outputs = self.fcn(inputs)
        outputs = outputs.mean(0)
        return outputs.view(1)  # it will return a single value
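For completeness, here is a self-contained run of the discriminator on a random batch (a sketch; the 16×1×28×28 input shape is an assumption matching MNIST-sized images, since 28 * 28 = 784):

```python
import torch

class maxout(torch.nn.Module):
    def __init__(self, num_pieces):
        super(maxout, self).__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        assert x.shape[1] % self.num_pieces == 0
        ret = x.view(*x.shape[:1],
                     x.shape[1] // self.num_pieces,
                     self.num_pieces,
                     *x.shape[2:])
        ret, _ = ret.max(dim=2)
        return ret

class discriminator(torch.nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fcn = torch.nn.Sequential(
            torch.nn.Linear(in_features=784, out_features=240, bias=True),
            maxout(5),                      # 240 features -> 48 after max over groups of 5
            torch.nn.Linear(in_features=48, out_features=1, bias=True)
        )

    def forward(self, batch):
        inputs = batch.view(batch.size(0), -1)  # flatten to (batch, 784)
        outputs = self.fcn(inputs)              # (batch, 1)
        outputs = outputs.mean(0)               # average over the batch
        return outputs.view(1)                  # a single scalar in a 1-element tensor

d = discriminator()
fake_images = torch.randn(16, 1, 28, 28)  # assumed input shape: 16 MNIST-sized images
score = d(fake_images)                    # again, d(...) not d.forward(...)
print(score.shape)                        # torch.Size([1])
```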
The code is working fine, but as per my naive understanding of object-oriented programming, the value of 'x' in the forward() function of the maxout class should be provided through the __init__() function.
My question is: how is the input 'x' received by the forward() function of the maxout class without being passed through the __init__() function?
Another way to put this question: how is the output of the Linear layer in the discriminator class passed to the maxout function as 'x'?
You are passing layers to the constructor of the Sequential
model, which is assigned to self.fcn.
Then, in discriminator.forward,
you are calling this model. Its __call__
method then calls the forward functions of all the layers it contains, in order.
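The mechanism itself is plain Python, not anything PyTorch-specific: defining __call__ lets an instance be used like a function, and nn.Module's __call__ delegates to forward. A minimal sketch (hypothetical class names, no torch required):

```python
# A stripped-down imitation of what torch.nn.Module does:
# __call__ forwards its arguments on to forward(), so writing
# module(x) ends up invoking module.forward(x).
class Module:
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

class Doubler(Module):
    def forward(self, x):
        return 2 * x

d = Doubler()
print(d(21))  # 42 -- d(21) triggers __call__, which calls forward(21)
```

So 'x' never goes through __init__; it arrives as the argument of the call `maxout_instance(x)`, which Python routes to forward(). (The real nn.Module.__call__ also runs hooks and other bookkeeping around the forward call.)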
You can imagine something like this is going on (purely illustrative — in reality the layers stored in self.fcn are reused on every call, not constructed fresh):
...
def forward(self, batch):
    return torch.nn.Linear(
        in_features=48,
        out_features=1,
        bias=True
    ).forward(
        maxout(5).forward(
            torch.nn.Linear(
                in_features=784,
                out_features=240,
                bias=True
            ).forward(batch.view(batch.size(0), -1))
        )
    ).mean(0).view(1)
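Put differently, Sequential's forward is essentially a loop that feeds each layer's output into the next. A sketch of that idea in plain Python (a hypothetical Sequential, using simple callables in place of real layers):

```python
# Minimal imitation of how nn.Sequential chains its children:
# each module's output becomes the next module's input.
class Sequential:
    def __init__(self, *modules):
        self.modules = modules

    def __call__(self, x):
        for m in self.modules:  # calling m routes to its forward()
            x = m(x)
        return x

pipeline = Sequential(lambda x: x + 1, lambda x: x * 3)
print(pipeline(2))  # (2 + 1) * 3 = 9
```

In your discriminator, the first Linear's output becomes maxout's 'x' in exactly this way.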