I am trying to implement a custom activation function that works with complex numbers in PyTorch. I am re-implementing an existing implementation of this activation function, available in neuroptica under the "ElectroOpticActivation" class. My implementation of it in the PyTorch-based torch-onn library is as follows:
import numpy as np
import torch
from torch.types import Device
from torchonn.layers.base_layer import ONNBaseLayer
from torch.nn import Parameter, init
from torch import Tensor
from torch.autograd import Function
__all__ = [
"ElectroOptic"
]
class EOActivation(Function):
    '''
    Electro-optic activation as described in:
    Williamson, Ian A. D., et al. "Reprogrammable electro-optic nonlinear
    activation functions for optical neural networks." IEEE Journal of
    Selected Topics in Quantum Electronics 26.1 (2019): 1-12.
    '''
    @staticmethod
    def forward(Z: Tensor,
                alpha: Tensor,
                g: Tensor,
                phi_b: Tensor) -> Tensor:
        '''
        Forward pass of the EO activation function
        Z: tensor, input tensor
        alpha: tensor, parameter 'alpha'
        g: tensor, parameter 'g'
        phi_b: tensor, parameter 'phi_b'
        '''
        Z = 1j * torch.sqrt(1 - alpha) * torch.exp(
            -1j * 0.5 * g * torch.square(torch.abs(Z)) - 1j * 0.5 * phi_b) * torch.cos(
            0.5 * g * torch.square(torch.abs(Z)) + 0.5 * phi_b) * Z
        return Z
    @staticmethod
    def setup_context(ctx, inputs, output):
        '''
        ctx: context object
        inputs: the inputs to forward()
        output: the output tensor of forward()
        '''
        # Save the input and the parameters for the backward pass
        input, alpha, g, phi_b = inputs
        ctx.save_for_backward(input, alpha, g, phi_b)
    @staticmethod
    def backward(ctx, grad_Z: Tensor) -> Tensor:
        '''
        ctx: context object
        grad_Z: backpropagated gradient signal from the (l+1)-th layer
        '''
        # Get the parameters and the input field to the forward pass
        Z, alpha, g, phi_b = ctx.saved_tensors
        zR, zI = Z.real, Z.imag
        # df_dRe - gradient w.r.t. the real part of the input
        df_dRe = torch.sqrt(1 - alpha) * torch.exp((-0.5 * 1j) * g * (zR - 1j * zI) * (
            zR + 1j * zI) - (0.5 * 1j) * phi_b) * (zR * g * (zI - 1j * zR) * torch.sin(
            0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b) + (
            zR ** 2 * g + 1j * zR * zI * g + 1j) * torch.cos(
            0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b))
        # df_dIm - gradient w.r.t. the imaginary part of the input
        df_dIm = torch.sqrt(1 - alpha) * torch.exp((-0.5 * 1j) * g * (zR - 1j * zI) * (
            zR + 1j * zI) - (0.5 * 1j) * phi_b) * (zI * g * (zI - 1j * zR) * torch.sin(
            0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b) + (
            zR * zI * g + 1j * zI ** 2 * g - 1) * torch.cos(
            0.5 * zR ** 2 * g + 0.5 * zI ** 2 * g + 0.5 * phi_b))
        # Return the gradient for Z and 'None' for the non-trainable parameters of forward()
        return (grad_Z * df_dRe).real - 1j * (grad_Z * df_dIm).real, None, None, None
class ElectroOptic(ONNBaseLayer):
    '''
    Electro-optic activation layer that applies EOActivation with fixed
    (non-trainable) parameters alpha, g and phi_b, optionally followed by
    photodetection and a bias term.
    '''
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        # miniblock: int = 4,
        alpha: float = 0.2,
        responsivity: float = 0.8,
        area: float = 1.0,
        V_pi: float = 10.0,
        V_bias: float = 10.0,
        R: float = 1e3,
        impedance: float = 120 * np.pi,
        g: float = None,
        phi_b: float = None,
        photodetect: bool = True,
        device: Device = torch.device("cpu")
    ):
        super(ElectroOptic, self).__init__(device=device)
        self.in_features = in_features
        self.out_features = out_features
        self.photodetect = photodetect
        """ self.miniblock = miniblock
        self.grid_dim_x = int(np.ceil(self.in_features / miniblock))
        self.grid_dim_y = int(np.ceil(self.out_features / miniblock))
        self.in_features_pad = self.grid_dim_x * miniblock
        self.out_features_pad = self.grid_dim_y * miniblock """
        self.alpha = Parameter(torch.tensor(alpha).to(self.device), requires_grad=False)
        if g is not None and phi_b is not None:
            self.g = Parameter(torch.tensor(g).to(self.device), requires_grad=False)
            self.phi_b = Parameter(torch.tensor(phi_b).to(self.device), requires_grad=False)
        else:
            # Convert the physical device parameters into the "feedforward phase gain"
            # and "phase bias" parameters
            self.g = Parameter(torch.tensor(np.pi * alpha * R * responsivity *
                area * 1e-12 / 2 / V_pi / impedance).to(self.device), requires_grad=False)
            self.phi_b = Parameter(torch.tensor(np.pi * V_bias / V_pi).to(self.device), requires_grad=False)
        if bias:
            self.bias = Parameter(torch.Tensor(out_features).to(self.device))
            init.uniform_(self.bias, 0, 0)  # initialize the bias to zero
        else:
            self.register_parameter("bias", None)
    def forward(self, Z: Tensor) -> Tensor:
        '''
        Z: input tensor from the (l-1)-th layer
        Z_out: output tensor after forward propagation (activation)
        '''
        Z_out = EOActivation.apply(Z, self.alpha, self.g, self.phi_b)
        if self.photodetect:
            Z_out = Z_out.square()
        if self.bias is not None:
            Z_out = Z_out + self.bias.unsqueeze(0)
        return Z_out
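For reference, the layer is used like any other nn.Module; something along these lines (the shapes, batch size and photodetect flag below are just illustrative, and the snippet assumes the definitions above are in scope):

act = ElectroOptic(in_features=4, out_features=4, photodetect=False)
Z_in = torch.randn(8, 4, dtype=torch.cfloat)  # batch of 8 complex-valued samples
Z_out = act(Z_in)                             # complex output of the same shape
print(Z_out.shape, Z_out.dtype)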
I followed PyTorch's tutorial to write this implementation. However, it is not working as expected, and I am struggling to figure out why.
Is it possible that I am not handling the propagated gradients correctly? I am testing it with a network: linear layer (4x4) -> activation (4x4) -> linear (4x1) -> activation (1x1) -> abs(output). I am feeding it the 4 XOR input samples as complex numbers (the imaginary part of the inputs is 0, but the weights can have non-zero imaginary parts).
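One way I could check the custom backward in isolation from the XOR network is torch.autograd.gradcheck, which compares the analytical gradient returned by EOActivation.backward against a numerical estimate. A minimal sketch (the double-precision dtypes and the parameter values 0.2, 0.05 and 1.0 are arbitrary choices for the check, not values from my network; EOActivation is the Function defined above):

import torch
from torch.autograd import gradcheck

Z = torch.randn(4, 4, dtype=torch.cdouble, requires_grad=True)
alpha = torch.tensor(0.2, dtype=torch.double)
g = torch.tensor(0.05, dtype=torch.double)
phi_b = torch.tensor(1.0, dtype=torch.double)

# Only Z requires grad, so only the gradient w.r.t. Z is checked; gradcheck
# raises an error listing the mismatching entries if the analytical and
# numerical gradients disagree.
ok = gradcheck(EOActivation.apply, (Z, alpha, g, phi_b), eps=1e-6, atol=1e-4)
print(ok)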
Update 1: I replaced the activations with a custom complex ReLU that thresholds the real and imaginary parts of the tensors individually, and the network trains well. Since the complex ReLU's gradients are computed implicitly by PyTorch's autograd, I strongly suspect that my EOActivation's gradients are not being handled correctly. I have to look into it.
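For reference, the complex ReLU I swapped in looks roughly like this (the name complex_relu is just what I call it here; only built-in differentiable ops are used, so no custom backward is needed):

import torch
from torch import Tensor

def complex_relu(Z: Tensor) -> Tensor:
    # Threshold the real and imaginary parts independently and recombine them
    # into a complex tensor; autograd differentiates this automatically.
    return torch.relu(Z.real) + 1j * torch.relu(Z.imag)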
This issue can be tracked here and a solution has been proposed.