AFAIK there are two ways to express a ResNet block in PyTorch, which lead to two kinds of code:
def forward(self, x):
    y = x
    x = self.conv1(x)
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv2(x)
    x = self.norm2(x)
    x += y
    x = self.act2(x)
    return x

def forward(self, x):
    y = self.conv1(x)
    y = self.norm1(y)
    y = self.act1(y)
    y = self.conv2(y)
    y = self.norm2(y)
    y += x
    y = self.act2(y)
    return y
Are they identical? Which one is preferred? Why?
It doesn't matter, so long as you retain some reference to the input.
At a high level, you are trying to compute output = activation(input + f(input)). Both methods shown accomplish this. As long as you don't lose the reference to the input or modify the input with an in-place operation, you should be fine.
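To make that caveat concrete, here is a minimal sketch with a bare tensor outside any module (my own illustration, not from the original answer): y = x only stores another reference, so an out-of-place reassignment leaves the saved input intact, while an in-place op changes it underneath you.

import torch

x = torch.randn(4)
y = x                     # y is another reference to the same tensor, not a copy

x = x * 2                 # out-of-place: x now points to a new tensor, y still holds the input
print(torch.equal(y, x))  # False

x = y                     # point x back at the input
x *= 2                    # in-place: modifies the shared storage, so the "saved" input y changes too
print(torch.equal(y, x))  # True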
For what it's worth, I would separate out the residual connection and the sub-block just for clarity:
class Block(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.conv1 = ...
        self.norm1 = ...
        self.act = ...
        self.conv2 = ...
        self.norm2 = ...

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x

class ResBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.act = ...

    def forward(self, x):
        return self.act(x + self.block(x))
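As a usage sketch, here is one hypothetical way to fill in the placeholders and wire the two classes together; the 3x3 convs, BatchNorm, ReLU, and the channel count of 64 are my own choices, not part of the original answer:

import torch
import torch.nn as nn

class Block(nn.Module):
    # Sub-block with the placeholders filled in: 3x3 convs + BatchNorm + ReLU.
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x

class ResBlock(nn.Module):
    # Wraps any shape-preserving sub-block with a residual connection.
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(x + self.block(x))

res = ResBlock(Block(64))
out = res(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32])

Note that the wrapped sub-block has to return the same shape as its input (or you need a projection on the skip path, as in the original ResNet paper), otherwise the addition in x + self.block(x) will fail.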