deep-learning, neural-network, generative-adversarial-network, stylegan

What does Lambda do in this code (Python/Keras)?


from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, Conv2D, Dense, Lambda, UpSampling2D


def AdaIN(x):
    # Normalize x[0] (image representation) over the spatial axes
    mean = K.mean(x[0], axis = [1, 2], keepdims = True)
    std = K.std(x[0], axis = [1, 2], keepdims = True) + 1e-7
    y = (x[0] - mean) / std

    # Reshape scale and bias parameters so they broadcast over the channels
    pool_shape = [-1, 1, 1, y.shape[-1]]
    scale = K.reshape(x[1], pool_shape)
    bias = K.reshape(x[2], pool_shape)

    # Multiply by x[1] (gamma) and add x[2] (beta)
    return y * scale + bias

    

def g_block(input_tensor, latent_vector, filters):
    gamma = Dense(filters, bias_initializer = 'ones')(latent_vector)
    beta = Dense(filters)(latent_vector)
    
    out = UpSampling2D()(input_tensor)
    out = Conv2D(filters, 3, padding = 'same')(out)
    out = Lambda(AdaIN)([out, gamma, beta])
    out = Activation('relu')(out)
    
    return out

Please see the code above. I am currently studying StyleGAN and trying to convert this code to PyTorch, but I can't seem to understand what Lambda does in g_block. Based on its declaration, AdaIN takes only one input, yet somehow gamma and beta are also passed in. Can you explain what Lambda does in this code?

Thank you very much.


Solution

  • Lambda layers in Keras are used to call custom functions inside the model. In g_block, Lambda calls the AdaIN function and passes out, gamma, and beta as arguments inside a single list. AdaIN receives these three tensors encapsulated in one list, x, and accesses them by indexing it (x[0], x[1], x[2]).
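
    To make this concrete, here is a minimal sketch of the same pattern (the add_pair function and the input names are illustrative, not from the original code): the Lambda layer is called on a list of tensors, and the wrapped function receives that list as its single argument.

    from tensorflow.keras import Input, Model
    from tensorflow.keras.layers import Lambda

    def add_pair(x):
        # x is the list [a, b] handed to the Lambda layer
        return x[0] + x[1]

    a = Input(shape=(4,))
    b = Input(shape=(4,))
    summed = Lambda(add_pair)([a, b])  # list in, single tensor out
    model = Model(inputs=[a, b], outputs=summed)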

    Here's the PyTorch equivalent:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class AdaIN(nn.Module):
        def forward(self, out, gamma, beta):
            # Normalize each channel of `out` over its spatial dimensions
            bs, ch = out.size()[:2]
            mean   = out.reshape(bs, ch, -1).mean(dim=2).reshape(bs, ch, 1, 1)
            std    = out.reshape(bs, ch, -1).std(dim=2).reshape(bs, ch, 1, 1) + 1e-7
            y      = (out - mean) / std
            # Broadcast gamma (scale) and beta (bias) from (bs, ch) to the
            # full (bs, ch, H, W) shape of `out`
            bias   = beta.unsqueeze(-1).unsqueeze(-1).expand_as(out)
            scale  = gamma.unsqueeze(-1).unsqueeze(-1).expand_as(out)
            return y * scale + bias
    
    class g_block(nn.Module):
        def __init__(self, filters, latent_vector_shape, input_tensor_channels):
            super().__init__()
            self.gamma = nn.Linear(in_features = latent_vector_shape, out_features = filters)
            # Initializes all bias to 1
            self.gamma.bias.data = torch.ones(filters)
            self.beta  = nn.Linear(in_features = latent_vector_shape, out_features = filters)
        # padding=1 gives 'same' spatial size for a 3x3 kernel with stride 1
        self.conv  = nn.Conv2d(input_tensor_channels, filters, 3, 1, padding=1)
            self.adain = AdaIN()
    
        def forward(self, input_tensor, latent_vector):
            gamma = self.gamma(latent_vector)
            beta  = self.beta(latent_vector)
            # Keras UpSampling2D defaults to nearest-neighbor interpolation,
            # so mode='nearest' matches the original code
            out   = F.interpolate(input_tensor, scale_factor=2, mode='nearest')
            out   = self.conv(out)
            out   = self.adain(out, gamma, beta)
            out   = torch.relu(out)        
            return out
    
    # Sample:
    input_tensor  = torch.randn((1, 3, 10, 10))
    latent_vector = torch.randn((1, 5))
    g   = g_block(3, latent_vector.shape[1], input_tensor.shape[1])
    out = g(input_tensor, latent_vector)
    print(out)
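
    For the sample above, out has shape (1, 3, 20, 20): the interpolation doubles the 10×10 spatial input, and the padded 3×3 convolution and AdaIN both preserve that size.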
    

    Note: you need to pass the latent_vector and input_tensor shapes when creating g_block; unlike Keras, PyTorch does not infer input shapes, so nn.Linear and nn.Conv2d need them up front.