I'm comparing the gradient computed for a BatchNorm2d layer in TensorFlow and PyTorch, and the two frameworks disagree significantly: the gradient with respect to the input is near zero in PyTorch, while TensorFlow returns values close to one.
Here's the code I used for the comparison:
import numpy as np
# TensorFlow/Keras
import tensorflow as tf
# PyTorch
import torch
import torch.nn as nn
# Set a common random seed for reproducibility
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)
torch.manual_seed(seed)
# Create a random input tensor with the same shape for all frameworks
input_shape = (4, 3, 5, 5)
x_np = np.random.randn(*input_shape).astype(np.float32)
# TensorFlow/Keras BatchNorm2d
class TFModel(tf.keras.Model):
    def __init__(self):
        super(TFModel, self).__init__()
        self.bn = tf.keras.layers.BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)

    def call(self, x):
        return self.bn(x)
# Instantiate the model
tf_model = TFModel()
# Convert the numpy array to a tensor and ensure it's being watched
x_tf = tf.convert_to_tensor(x_np)
x_tf = tf.Variable(x_tf)
with tf.GradientTape() as tape:
    y_tf = tf_model(x_tf)
    y_tf_sum = tf.reduce_sum(y_tf)
# Compute the gradient of the output with respect to the input
grad_tf = tape.gradient(y_tf_sum, x_tf)
# Convert the gradient to a numpy array for comparison
grad_tf = grad_tf.numpy()
# PyTorch BatchNorm2d
class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)
torch_model = TorchModel()
x_torch = torch.tensor(x_np, requires_grad=True)
y_torch = torch_model(x_torch)
y_torch.sum().backward()
grad_torch = x_torch.grad.detach().numpy()
# Calculate the difference between TensorFlow and PyTorch gradients
diff_tf_torch = np.mean(np.abs(grad_tf - grad_torch))
# Print the differences
print(f"Difference between TensorFlow and PyTorch gradients: {diff_tf_torch:.6f}")
print("grad pytorch : ", grad_torch[0])
print("grad tf : ", grad_tf[0])
grad_pytorch: shows a gradient of all zeros (on the order of 10^-10).
grad_tf: shows non-zero gradients.
I tried to compare both under the same conditions, and a zero gradient for the input seems odd to me. Could there be an issue with how the layers are initialized, or with how the inputs are processed in PyTorch, that leads to this discrepancy? Any insights or explanations would be greatly appreciated.
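BatchNorm is one of the few layers that behaves differently at training and evaluation time. A Keras layer called directly, as in your code, runs in inference mode unless you pass training=True, while a PyTorch module stays in training mode until you call .eval().
The gradients are a side effect of this, and the outputs of the two models also differ between the modes. In training mode the layer normalizes with the statistics of the current batch, so the sum of its outputs does not change with the input and the gradient of that sum is zero. In evaluation mode it uses the running statistics (initially mean 0 and variance 1), so the gradient is close to one.
Here is a simplified version of your code that sets the mode explicitly: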
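To see where the two numbers in the question come from without either framework, here is a minimal NumPy sketch of the batch-norm input gradient in both modes. It assumes the default parameters (scale 1, offset 0, running variance 1) and the NCHW layout; the function name `bn_input_grad` is made up for this illustration and is not part of TensorFlow or PyTorch.

```python
import numpy as np

def bn_input_grad(x, dy, training, running_var=1.0, eps=1e-5):
    """Gradient w.r.t. x of a batch-norm layer with scale 1 and offset 0 (NCHW)."""
    axes = (0, 2, 3)  # statistics are computed per channel
    if not training:
        # Evaluation: y = (x - running_mean) / sqrt(running_var + eps),
        # a fixed affine map, so the gradient is just a constant factor.
        return dy / np.sqrt(running_var + eps)
    mu = x.mean(axis=axes, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
    x_hat = (x - mu) * inv_std
    # Training: the batch mean and variance depend on x as well,
    # which adds the two correction terms of the usual backward pass.
    return inv_std * (
        dy
        - dy.mean(axis=axes, keepdims=True)
        - x_hat * (dy * x_hat).mean(axis=axes, keepdims=True)
    )

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 5, 5))
dy = np.ones_like(x)  # upstream gradient of sum(y)

grad_train = bn_input_grad(x, dy, training=True)
grad_eval = bn_input_grad(x, dy, training=False)
print("max |grad| in training mode:", np.abs(grad_train).max())
print("grad in evaluation mode:", grad_eval.flat[0])
```

With an upstream gradient of all ones, both correction terms cancel it in training mode, which leaves only rounding noise; that is consistent with the tiny values PyTorch reports. In evaluation mode the gradient is 1 / sqrt(1 + eps), which matches the "close to one" from TensorFlow.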
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
# Specify if we are training or evaluating
is_training = False
# Create a random input tensor with the same shape for all frameworks
input_shape = (2, 1, 1, 1)
np.random.seed(0)
x_np = np.random.randn(*input_shape).astype(np.float32)
# TensorFlow
model_tf = tf.keras.layers.BatchNormalization(axis=1, epsilon=1e-05, momentum=0.1)
x_tf = tf.Variable(tf.convert_to_tensor(x_np))
with tf.GradientTape() as tape:
    y_tf = model_tf(x_tf, training=is_training) # <--- specify if we are training to tensorflow
    y_tf_sum = tf.reduce_sum(y_tf)
grad_tf = tape.gradient(y_tf_sum, x_tf).numpy()
# PyTorch
model_pt = nn.BatchNorm2d(1)
model_pt.train(is_training) # <--- set the pytorch model to train/eval mode
x_pt = torch.tensor(x_np, requires_grad=True)
y_pt = model_pt(x_pt)
y_pt.sum().backward()
grad_pt = x_pt.grad.numpy()
# Calculate the differences
diff_y = np.sum(np.abs(y_tf.numpy() - y_pt.detach().numpy()))
diff_dydx = np.sum(np.abs(grad_pt - grad_tf))
# Print the differences
print("diff forward : ", diff_y)
print("diff grads : ", diff_dydx)
print("y_pt : ", y_pt.detach().numpy().flatten())
print("y_tf : ", y_tf.numpy().flatten())
print("dydx_pt : ", grad_pt.flatten())
print("dydx_tf : ", grad_tf.flatten())