I am trying to train a CNN model with PyTorch so that its output behaves differently for different types of inputs (i.e., if the input images are of human beings, it outputs pattern A, but if the input images are of other animals, it outputs pattern B).
After some searching online, it seems that Siamese networks are related to this. So I have the following two questions:
(1) Is Siamese network really a good way to train such a model?
(2) From an implementation point of view, how should I implement this in PyTorch?
```python
import torch.nn as nn


class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Shared feature extractor. ReflectionPad2d(1) followed by a 3x3
        # convolution keeps the spatial size, so 1x100x100 inputs stay 100x100.
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )
        # Fully connected head mapping the flattened features to a 5-d output.
        self.fc1 = nn.Sequential(
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 5))

    def forward_once(self, x):
        # Embed a single batch of images.
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)  # flatten to (batch, 8*100*100)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # Apply the same (shared) weights to both inputs of a pair.
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
```
Currently, I am trying an existing implementation I found online, shown in the class definition above. It works, but the model always takes two inputs and returns two outputs. I agree that this is convenient for training, but ideally, during inference it should take only one input and return one output (two outputs is also fine).
Could someone give me some guidance on how to modify the code so that it takes a single input?
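You can call `forward_once` during inference: it takes a single input and returns a single output. Note that calling `forward_once` directly bypasses `nn.Module.__call__`, so it will not run any forward or backward hooks you have registered on the module itself (hooks registered on submodules such as `cnn1` still fire, because those are still invoked through `__call__`).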
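A minimal sketch of both points, using a scaled-down copy of the question's network. For speed it assumes 1×28×28 inputs and a single conv block, so the first `Linear` takes `8*28*28` features instead of `8*100*100`. `forward_once` embeds one batch of images at inference time, and a forward hook registered on the top-level module fires for `model(x1, x2)` but not for a direct `model.forward_once(x)` call. The hook and the `calls` list are hypothetical, for illustration only.

```python
import torch
import torch.nn as nn


class SiameseNetwork(nn.Module):
    # Scaled-down stand-in for the question's model (assumes 1x28x28 inputs).
    def __init__(self):
        super().__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 8, kernel_size=3),  # padding keeps the 28x28 size
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 28 * 28, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 5),
        )

    def forward_once(self, x):
        # One batch of images in, one batch of embeddings out.
        output = self.cnn1(x)
        output = output.view(output.size(0), -1)
        return self.fc1(output)

    def forward(self, input1, input2):
        # Training-time interface: the same weights are applied to both inputs.
        return self.forward_once(input1), self.forward_once(input2)


model = SiameseNetwork()
model.eval()  # use BatchNorm running statistics at inference time

# Hypothetical hook: records calls that go through nn.Module.__call__ on `model`.
calls = []
model.register_forward_hook(lambda module, inputs, output: calls.append(len(inputs)))

x = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    out1, out2 = model(x, x)           # goes through __call__: hook fires
    embedding = model.forward_once(x)  # plain method call: hook does not fire

print(embedding.shape)  # torch.Size([4, 5])
print(calls)            # [2]
```

In eval mode the single-input path gives the same result as the first output of the paired call, so nothing is lost by skipping `forward` at inference; only the top-level hooks are skipped.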
Alternatively, you can make `forward_once` your module's `forward` function and have your training function call the model twice, once per input of each pair. This arguably makes more sense: a Siamese network is a training method, not part of a network's architecture.
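A minimal sketch of that alternative. It assumes a hypothetical `EmbeddingNet` (a trimmed version of the question's model for 1×28×28 inputs) and a standard margin-based contrastive loss, which is a common choice for Siamese training but is not something the question specifies. Here `label` is 0 for similar pairs and 1 for dissimilar pairs, and `margin=2.0` is an arbitrary assumption. The pairing lives entirely in the hypothetical `train_step`; the model itself has an ordinary single-input `forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    # Hypothetical single-input network: the old forward_once becomes forward.
    def __init__(self, embedding_dim=5):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )
        self.fc = nn.Sequential(
            nn.Linear(8 * 28 * 28, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, embedding_dim),
        )

    def forward(self, x):
        x = self.cnn(x)
        return self.fc(torch.flatten(x, 1))


def contrastive_loss(emb1, emb2, label, margin=2.0):
    # label == 0: similar pair, pull the embeddings together.
    # label == 1: dissimilar pair, push them at least `margin` apart.
    dist = F.pairwise_distance(emb1, emb2)
    similar = (1 - label) * dist.pow(2)
    dissimilar = label * torch.clamp(margin - dist, min=0).pow(2)
    return (similar + dissimilar).mean()


def train_step(model, optimizer, img1, img2, label):
    # The "Siamese" part: the same model (same weights) is called twice.
    model.train()
    optimizer.zero_grad()
    emb1 = model(img1)
    emb2 = model(img2)
    loss = contrastive_loss(emb1, emb2, label)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = EmbeddingNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in data: 8 image pairs and their same/different labels.
img1 = torch.randn(8, 1, 28, 28)
img2 = torch.randn(8, 1, 28, 28)
label = torch.randint(0, 2, (8,)).float()
loss = train_step(model, optimizer, img1, img2, label)

# Inference: one input, one output.
model.eval()
with torch.no_grad():
    embedding = model(torch.randn(1, 1, 28, 28))
print(embedding.shape)  # torch.Size([1, 5])
```

With this layout every call, at training and inference time, goes through `model(...)`, so hooks and other `nn.Module.__call__` machinery behave normally on the single-input path.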