python sympy

RMSNorm derivative using sympy -- problem with summation over fixed number of elements


I have the following SymPy expression for RMSNorm (easier to read in a Jupyter notebook):

import sympy as sp

# Define the symbols
x = sp.Symbol('x')  # Input variable
n = sp.Symbol('n')  # Number of elements
gamma = sp.Symbol('gamma')
epsilon = sp.Symbol('epsilon')  # Small constant to avoid division by zero

# Define the RMS normalization equation
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)
fwd_out = x * gamma / rms

# Display the equation
sp.pprint(fwd_out)

I have a problem with the rms term when I take the derivative of fwd_out with respect to x:

d_activation = sp.diff(fwd_out, x)

SymPy does not treat rms as a function of x. It treats it as a constant, because x is used up as the summation variable running from 1 to n, so the following displays 0:

sp.diff(rms, x)
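One way to see why this happens, as a small sketch reusing the names from the code above: in `Sum(x**2, (x, 1, n))` the symbol `x` is the bound summation variable, so it is not among the free symbols of `rms`, and the derivative with respect to it is zero.

```python
import sympy as sp

x, n, epsilon = sp.symbols('x n epsilon')

# Same construction as in the question: x doubles as the summation variable
mean_square = sp.Sum(x**2, (x, 1, n)) / n
rms = sp.sqrt(mean_square + epsilon)

# x is bound by the Sum, so only n and epsilon remain as free symbols
print(rms.free_symbols)
print(x in rms.free_symbols)  # False

# Differentiating with respect to a symbol that is not free gives 0
print(sp.diff(rms, x))  # 0
```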

But according to the RMSNorm paper, rms should be considered a function of x.

Is there a way to force SymPy to treat rms as a function of x?

I am using Python 3.12.9 and Sympy 1.12.1.


Complete answer, based on @smichr's answer:

from sympy import *
from sympy.abc import n, gamma, epsilon

x = IndexedBase("x")
i = symbols('i', cls = Idx)
mean_squared = Sum(x[i] ** 2, (i, 1, n)) / n
rms = sqrt(mean_squared + epsilon)
fwd_out = x[i] * gamma / rms

# diff wrt x[i]
d_fwd_out = diff(fwd_out, x[i])

d_rms = diff(rms, x[i])
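As a cross-check that does not rely on indexed symbols, here is a minimal sketch for a small fixed size. The size `N = 3` and the element names `x1, x2, x3` are assumptions chosen for illustration and are not part of the original post. With one plain symbol per element, `rms` depends explicitly on every element, and `diff` should reproduce the hand-derived results d(rms)/d(x_j) = x_j / (N * rms) and, for k != j, d(y_k)/d(x_j) = -gamma * x_k * x_j / (N * rms**3).

```python
import sympy as sp

N = 3  # hypothetical fixed number of elements
xs = sp.symbols(f'x1:{N + 1}')  # (x1, x2, x3), one plain symbol per element
gamma, epsilon = sp.symbols('gamma epsilon', positive=True)

# RMS over the explicit elements
rms = sp.sqrt(sum(v**2 for v in xs) / N + epsilon)

# Forward output, one entry per element
y = sp.Matrix([gamma * v / rms for v in xs])

# Derivative of rms with respect to the first element
d_rms = sp.diff(rms, xs[0])
print(sp.simplify(d_rms - xs[0] / (N * rms)))  # 0

# Full N x N Jacobian of the forward output
J = y.jacobian(sp.Matrix(xs))
off_diag = -gamma * xs[0] * xs[1] / (N * rms**3)
print(sp.simplify(J[0, 1] - off_diag))  # 0
```

This only scales to small sizes, since every element is its own symbol, but it gives a concrete expression to compare the indexed result against.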

Ref:

RMSNorm Paper: https://arxiv.org/pdf/1910.07467
PyTorch API: https://pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html


Solution

  • In the paper the sum runs over indexed elements, i.e. Sum(a[i]**2, (i, 1, n)). If you create an indexed variable, you can differentiate with respect to it:

    from sympy import *
    from sympy.abc import i, n, gamma, epsilon
    a = IndexedBase('a')
    # Define the RMS normalization equation
    mean_square = Sum(a[i]**2, (i, 1, n)) / n
    rms = sqrt(mean_square + epsilon)
    fwd_out = a[i] * gamma / rms
    
    >>> print(str(rms.diff(a[i])))
    Sum(2*a[i], (i, 1, n))/(2*n*sqrt(epsilon + Sum(a[i]**2, (i, 1, n))/n))
    

    cf here
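Note that in the snippet above the differentiation variable `a[i]` shares its index with the summation index. A possible variant, sketched here with separate indices `j` and `k` (hypothetical names that do not appear in the original answer), lets SymPy record which element is being differentiated by placing a `KroneckerDelta` inside the sum. The `integer` and `positive` assumptions on the symbols are also my additions.

```python
from sympy import IndexedBase, KroneckerDelta, Sum, sqrt, symbols

n = symbols('n', integer=True, positive=True)
gamma, epsilon = symbols('gamma epsilon', positive=True)
# i is the summation index; j and k are separate free indices (assumed names)
i, j, k = symbols('i j k', integer=True)

a = IndexedBase('a')
mean_square = Sum(a[i]**2, (i, 1, n)) / n
rms = sqrt(mean_square + epsilon)

# Derivative of rms with respect to one particular element a[j]
d_rms = rms.diff(a[j])
print(d_rms)

# Output element k differentiated with respect to input element j
fwd_out_k = a[k] * gamma / rms
d_fwd_out = fwd_out_k.diff(a[j])
print(d_fwd_out)
```

For 1 <= j <= n the delta picks out the single term i = j of the sum, which corresponds to the hand-derived a[j] / (n * rms) for the derivative of `rms`.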