pytorchneural-networkcomplex-numbersautograd

Understanding pytorch custom activation Function


I am trying to implement a custom activation function that works with complex numbers, in pytorch. I am re-implementing an existing implementation of this activation function available at neuroptica under the "ElectroOpticActivation" class. My implementation of it in pytorch-based torch-onn library is as follows:

import numpy as np
import torch
from torch.types import Device
from torchonn.layers.base_layer import ONNBaseLayer
from torch.nn import Parameter, init
from torch import Tensor
from torch.autograd import Function

__all__ = [
    "ElectroOptic"
    ]

class EOActivation(Function):
    '''
    Electro-optic activations as described in - 
        {Williamson, Ian AD, et al. "Reprogrammable electro-optic nonlinear 
        activation functions for optical neural networks." IEEE Journal of 
        Selected Topics in Quantum Electronics 26.1 (2019): 1-12.}
    '''
    @staticmethod
    def forward(Z: Tensor,
                alpha: Tensor,
                g: Tensor,
                phi_b: Tensor) -> Tensor:
        '''
        Forward-pass of EO nactivation function
        Z: tensor, Input tensor
        alpha: tensor, parameter 'alpha'
        g: tensor, parameter 'g'
        phi_b: tensor, parameter 'phi_b'
        '''
        Z =  1j * torch.sqrt(1 - alpha) * torch.exp(
            -1j * 0.5 * g * torch.square(torch.abs(Z)) - 1j * 0.5 * phi_b) * torch.cos(
                0.5 * g * torch.square(torch.abs(Z)) + 0.5 * phi_b) * Z
        return Z
    
    @staticmethod
    def setup_context(ctx, inputs, output):
        '''
        ctx: Context object
        inputs: Inputs are the inputs to forward()
        output: Output tensor of forward()
        '''
        # Save parameters and output of forward for backward pass
        input, alpha, g, phi_b = inputs
        ctx.save_for_backward(input, alpha, g, phi_b)
    
    @staticmethod
    def backward(ctx, grad_Z: Tensor) -> Tensor:
        '''
        ctx: context object
        grad_Z: backpropagated gradient signal from (l+1)th layer
        '''
        # get the parameters and input field to the forward pass
        Z, alpha, g, phi_b = ctx.saved_tensors
        zR, zI = Z.real, Z.imag
        # df_dRe - Gradient w.r.t. real part of the input
        df_dRe = torch.sqrt(1 - alpha) * torch.exp((-0.5 * 1j) * g * (zR - 1j * zI) * (
            zR + 1j * zI) - (0.5 * 1j) * phi_b) * (zR * g * (zI - 1j * zR) * torch.sin(
                0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b) + (
                    zR ** 2 * g + 1j * zR * zI * g + 1j) * torch.cos(
                        0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b))
        #df_dIm - Gradient w.r.t. imaginary part of the input
        df_dIm = torch.sqrt(1 - alpha) * torch.exp((-0.5 * 1j) * g * (zR - 1j * zI) * (
            zR + 1j * zI) - (0.5 * 1j) * phi_b) * (zI * g * (zI - 1j * zR) * torch.sin(
                0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b) + (
                    zR * zI * g + 1j * zI ** 2 * g - 1) * torch.cos(
                        0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b))
        # Return the gradient and 'None' for parameters in forward()
        return (grad_Z * df_dRe).real - 1j * (grad_Z * df_dIm).real, None, None, None

class ElectroOptic(ONNBaseLayer):
    '''
    
    '''

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = False,
            #miniblock: int = 4,
            alpha: float = 0.2,
            responsivity: float = 0.8,
            area: float = 1.0,
            V_pi: float = 10.0,
            V_bias: float = 10.0,
            R: float = 1e3,
            impedance = 120 * np.pi,
            g: float = None,
            phi_b: float = None,
            photodetect: bool = True,
            device: Device = torch.device("cpu")
            ):
        super(ElectroOptic, self).__init__(device=device)
        self.in_features = in_features
        self.out_features = out_features
        self.photodetect = photodetect
        """ self.miniblock = miniblock
        self.grid_dim_x = int(np.ceil(self.in_features / miniblock))
        self.grid_dim_y = int(np.ceil(self.out_features / miniblock))
        self.in_features_pad = self.grid_dim_x * miniblock
        self.out_features_pad = self.grid_dim_y * miniblock """

        self.alpha = Parameter(torch.tensor(alpha).to(self.device), requires_grad=False)
        if g is not None and phi_b is not None:
            self.g = Parameter(torch.tensor(g).to(self.device), requires_grad=False)
            self.phi_b = Parameter(torch.tensor(phi_b).to(self.device), requires_grad=False)
        else:
            # convert "feedforward phase gain" and "phase bias" parameters
            self.g = Parameter(torch.tensor(np.pi * alpha * R * responsivity * 
                                  area * 1e-12 / 2 / V_pi / impedance).to(self.device), requires_grad=False)
            self.phi_b = Parameter(torch.tensor(np.pi * V_bias / V_pi).to(self.device), requires_grad=False)

        if bias:
            self.bias = Parameter(torch.Tensor(out_features).to(self.device))
            init.uniform_(self.bias, 0, 0)
        else:
            self.register_parameter("bias", None)

    def forward(self, Z: Tensor) -> Tensor:
        '''
        Z: Input tensor from (l-1)th layer
        Z_out: Output tensor after forward propagation (activation)
        '''
        Z_out = EOActivation.apply(Z, self.alpha, self.g, self.phi_b)       
        if self.photodetect:
            Z_out = Z_out.square()
        if self.bias is not None:
            Z_out = Z_out + self.bias.unsqueeze(0)

        return Z_out

I followed pytorch's tutorial to write this implementation. However, this implementation is not working as expected and I am struggling to figure out why.

Is it possible that I am not handling the propagated gradients correctly? I am testing it with a network: linear layer(4x4) -> activation(4x4) -> linear(4x1) -> activation(1x1) -> abs(output). I am testing it with 4 input XOR samples using complex numbers (complex part of inputs is 0, however, but the weights can have non-zero imaginary parts).

Update 1: I replaced the activations with a custom complex Relu that thresholds real and imaginary parts of the tensors individually and, the network trains well. As for the complex relu, the gradients are implicitly computed by pytorch's autograd, I strongly believe that my EOActivation's gradients are not being handled correctly. Have to look into it.

gradcheck error The gradcheck output shows only a sign mismatch.


Solution

  • This issue can be tracked here and a solution has been proposed.