I have the following SymPy expression for RMSNorm (easier to read in a Jupyter notebook):
```python
import sympy as sp
# Define the symbols
x = sp.Symbol('x') # Input variable
n = sp.Symbol('n') # Number of elements
gamma = sp.Symbol('gamma')
epsilon = sp.Symbol('epsilon') # Small constant to avoid division by zero
# Define the RMS normalization equation
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = x * gamma / rms
# Display the equation
sp.pprint(fwd_out)
```
I have an issue with the `rms` term when I take the derivative of `fwd_out` with respect to `x`:

```python
d_activation = sp.diff(fwd_out, x)
```
SymPy does not treat `rms` as a function of `x`; it treats it as a constant, because it evaluates `rms` over `n`. For example, the following displays `0`:

```python
sp.diff(rms, x)
```
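A quick way to see what SymPy is doing is to inspect the free symbols. In `sp.Sum(x**2, (x, 1, n))` the symbol `x` is also the summation (dummy) variable, so it is bound inside the sum and is not a free symbol of `rms`; the derivative of `rms` is therefore 0 and `d_activation` reduces to `gamma/rms`. A minimal check that reuses the question's definitions:

```python
import sympy as sp

x = sp.Symbol('x')
n = sp.Symbol('n')
gamma = sp.Symbol('gamma')
epsilon = sp.Symbol('epsilon')

# x doubles as the summation variable here, so the Sum is bound over x
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = x * gamma / rms

print(rms.free_symbols)     # only n and epsilon; x is not listed
print(sp.diff(rms, x))      # 0
print(sp.diff(fwd_out, x))  # gamma/rms: rms is treated as a constant
```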
But according to the RMSNorm paper, `rms` should be treated as a function of `x`.
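For reference, this is the dependence on `x` written out by hand. It is a worked derivation under the question's simplification that $\gamma$ is a single scalar (the paper uses a per-element gain), with the $\epsilon$ term from the question's code and a separate summation index $k$:

```latex
\begin{aligned}
\mathrm{RMS}(\mathbf{x}) &= \sqrt{\frac{1}{n}\sum_{k=1}^{n} x_k^2 + \epsilon},
\qquad y_i = \frac{\gamma\, x_i}{\mathrm{RMS}(\mathbf{x})}, \\
\frac{\partial\,\mathrm{RMS}}{\partial x_j}
  &= \frac{1}{2\,\mathrm{RMS}}\cdot\frac{2\,x_j}{n}
   = \frac{x_j}{n\,\mathrm{RMS}}, \\
\frac{\partial y_i}{\partial x_j}
  &= \gamma\left(\frac{\delta_{ij}}{\mathrm{RMS}}
     - \frac{x_i\,x_j}{n\,\mathrm{RMS}^3}\right).
\end{aligned}
```

Here $\delta_{ij}$ is the Kronecker delta. The second term is exactly what is lost when `rms` is treated as a constant, which leaves only $\gamma\,\delta_{ij}/\mathrm{RMS}$.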
Is there a way to force SymPy to treat `rms` as a function of `x`?
I am using Python 3.12.9 and SymPy 1.12.1.
Complete answer based on @smichr's answer:

```python
from sympy import *
from sympy.abc import n, gamma, epsilon

x = IndexedBase("x")          # input vector with elements x[1], ..., x[n]
i = symbols('i', cls=Idx)     # integer index used in the sum
mean_squared = Sum(x[i] ** 2, (i, 1, n)) / n
rms = sqrt(mean_squared + epsilon)
fwd_out = x[i] * gamma / rms  # i-th output element

# diff wrt x[i]
d_fwd_out = diff(fwd_out, x[i])
d_rms = diff(rms, x[i])
```
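One caveat with the code above: `i` is both the summation index and the element being differentiated, so `diff(rms, x[i])` differentiates every term of the sum and yields `Sum(2*x[i], (i, 1, n))/...` rather than the single term `2*x[i]` that the derivative of $\sum_k x_k^2$ with respect to one element should give. The sketch below keeps them apart. It assumes plain integer symbols (not `Idx`): `k` only as the summation dummy and `i`, `j` as free element indices. SymPy then produces a `KroneckerDelta` inside the sum, and the result is checked numerically against the hand-derived Jacobian. `at_point` is a hypothetical helper written only for this check.

```python
from sympy import IndexedBase, KroneckerDelta, Rational, Sum, diff, sqrt, symbols

# Assumption: plain integer symbols instead of Idx.
# k is only the summation dummy; i and j are free element indices.
i, j, k = symbols('i j k', integer=True)
n = symbols('n', integer=True, positive=True)
gamma, epsilon = symbols('gamma epsilon', positive=True)  # scalar gain, as in the question
x = IndexedBase('x')

mean_square = Sum(x[k]**2, (k, 1, n)) / n
rms = sqrt(mean_square + epsilon)
y_i = gamma * x[i] / rms  # i-th output element

# Differentiating through the Sum turns d x[k] / d x[j] into a KroneckerDelta of j and k
d_rms = diff(rms, x[j])
d_y = diff(y_i, x[j])  # Jacobian entry dy_i/dx_j

# Hand-derived closed forms to compare against
d_rms_closed = x[j] / (n * rms)
d_y_closed = gamma * (KroneckerDelta(i, j) / rms - x[i] * x[j] / (n * rms**3))


def at_point(expr, values, i_val, j_val, eps_val, gamma_val=2):
    """Hypothetical helper: fix n = len(values), expand the sums with doit(),
    then substitute concrete indices and numbers and return a float."""
    e = expr.subs(n, len(values)).doit().subs({i: i_val, j: j_val})
    point = {x[m]: v for m, v in enumerate(values, start=1)}
    point[epsilon] = eps_val
    point[gamma] = gamma_val
    return float(e.subs(point))


if __name__ == "__main__":
    vals, eps = [1, 2, 3], Rational(1, 100)
    for iv, jv in [(1, 1), (1, 2), (3, 2)]:
        # symbolic derivative vs closed form at the same point
        print(iv, jv, at_point(d_y, vals, iv, jv, eps), at_point(d_y_closed, vals, iv, jv, eps))
```

Using a distinct dummy index is the design choice that matters here. With it, `diff` returns a sum weighted by a Kronecker delta, and the delta collapses to the single element once the sum is evaluated.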
References:
- RMSNorm paper: https://arxiv.org/pdf/1910.07467
- PyTorch API: https://pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
In the paper the sum runs over an index of the elements, i.e. `Sum(a[i]**2, (i, 1, n))`, not over the variable itself. If you create an indexed variable, you can differentiate with respect to it:
```python
from sympy import *
from sympy.abc import i, n, gamma, epsilon

a = IndexedBase('a')
# Define the RMS normalization equation
mean_square = Sum(a[i]**2, (i, 1, n)) / n
rms = sqrt(mean_square + epsilon)
fwd_out = a[i] * gamma / rms
```

```
>>> print(str(rms.diff(a[i])))
Sum(2*a[i], (i, 1, n))/(2*n*sqrt(epsilon + Sum(a[i]**2, (i, 1, n))/n))
```
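If a symbolic `n` is not needed, a different route (not the one used in the answer) is to fix the vector length and use ordinary symbols. The full Jacobian then comes out directly from `Matrix.jacobian`. This is a minimal sketch, assuming n = 3 and a scalar `gamma`; `closed_form` is a hypothetical helper that restates the hand-derived formula for comparison.

```python
from sympy import Matrix, simplify, sqrt, symbols

# Assumption: concrete length n = 3 and a scalar gain gamma.
n = 3
xs = symbols(f'x1:{n + 1}')  # (x1, x2, x3)
gamma, epsilon = symbols('gamma epsilon', positive=True)

rms = sqrt(sum(v**2 for v in xs) / n + epsilon)
y = Matrix([gamma * v / rms for v in xs])  # RMSNorm output vector
J = y.jacobian(Matrix(xs))                 # J[a, b] = dy_a / dx_b (0-based)


def closed_form(a, b):
    """gamma * (delta_ab / rms - x_a * x_b / (n * rms**3)) for 0-based a, b."""
    delta = 1 if a == b else 0
    return gamma * (delta / rms - xs[a] * xs[b] / (n * rms**3))


if __name__ == "__main__":
    print(J[0, 1])  # off-diagonal entry: only the rms term contributes
```

Because every element is its own symbol here, `rms` visibly depends on each `x_k` and no index bookkeeping is needed. The trade-off is that the result holds only for the chosen length.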