My goal is to evaluate a basic symbolic expression such as `a*d*(b + c)` with my own custom implementations of multiplication and addition.

I'm trying to use `lambdify` to replace the two core SymPy functions (`Add` and `Mul`) with my own functions, but I can't get them recognised. At this stage I'm just trying to get `Add` working. The code I have is below.
```python
from sympy import *
import numpy as np

x, y = symbols('x y')
A = [1, 1]
B = [2, 2]

def addVectors(inA, inB):
    print("running addVectors")
    return np.add(inA, inB)

# Test vector addition
print(addVectors(A, B))

# Now using lambdify
f = lambdify([x, y], x + y, {"add": addVectors})
print(f(A, B))  # <------- expect [3, 3] and addVectors to be run a second time
# but I get the same as this
print(A + B)
```
which yields:

```
running addVectors
[3 3]
[1, 1, 2, 2]
[1, 1, 2, 2]
```
I was expecting the `+` operator in the expression to be evaluated using the custom `addVectors` function, which would mean the results look like this:

```
running addVectors
[3 3]
running addVectors
[3 3]
[1, 1, 2, 2]
```
I tried several different configurations of the `lambdify` line, and they all give the same original result.

```python
f = lambdify([x, y], x + y, {"add": addVectors})
f = lambdify([x, y], x + y, {"Add": addVectors})
f = lambdify([x, y], x + y, {"+": addVectors})
f = lambdify([x, y], Add(x, y), {"Add": addVectors})
f = lambdify([x, y], x + y)
```
To confirm I had the syntax correct, I used an example closer to the documentation and replaced the symbolic `cos` function with a `sin` implementation.
```python
from sympy import *
import numpy as np

x = symbols('x')

def mysin(x):
    print('taking the sin of', x)
    return np.sin(x)

print(mysin(1))
f = lambdify(x, cos(x), {'cos': mysin})
print(f(1))
```
which works as expected and yields:

```
taking the sin of 1
0.8414709848078965
taking the sin of 1
0.8414709848078965
```
Is it even possible to implement my own `Add` and `Mul` functions using `lambdify`?

I suspect my trouble is that `Add` (and `Mul`) are not SymPy 'functions'. The documentation refers to them as an 'expression', and that somehow means they don't get recognised for substitution in the `lambdify` process.

Some links that I've been reading: SymPy cos, SymPy Add, SymPy Lambdify

Any pointers would be appreciated. Thanks for reading this far.
EDIT: Got a more general case working

This uses a combination of the `lambdify` and `replace` functions to replace `Add` and `Mul`. This example then evaluates an expression of the form `a*d*(b + c)`, which was the goal.
```python
from sympy import *
import numpy as np

w, x, y, z = symbols('w x y z')
A = [3, 3]
B = [2, 2]
C = [1, 1]
D = [4, 4]

def addVectors(*args):
    result = args[0]
    for arg in args[1:]:
        result = np.add(result, arg)
    return result

def mulVectors(*args):
    result = args[0]
    for arg in args[1:]:
        result = np.multiply(result, arg)
    return result

expr = w*z*(x + y)
print(expr)
expr = expr.replace(Add, lambda *args: lerchphi(*args))
expr = expr.replace(Mul, lambda *args: Max(*args))
print(expr)
f = lambdify([w, x, y, z], expr, {"lerchphi": addVectors, "Max": mulVectors})
print(f(A, B, C, D))
print(mulVectors(A, D, addVectors(B, C)))
```
which yields:

```
w*z*(x + y)
Max(w, z, lerchphi(x, y))
[36 36]
[36 36]
```
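Borrowing existing SymPy functions such as `lerchphi` and `Max` means inheriting whatever argument rules and automatic simplifications SymPy attaches to them. A possible variation, sketched here as an assumption rather than a tested recipe, is to replace `Add` and `Mul` with *undefined* functions created by `Function(...)`, which carry no built-in behaviour. The names `vadd` and `vmul` are hypothetical placeholders; any unused names should work the same way.

```python
import numpy as np
from sympy import Add, Function, Mul, lambdify, symbols

w, x, y, z = symbols('w x y z')

def addVectors(*args):
    # Element-wise sum of any number of vectors
    result = args[0]
    for arg in args[1:]:
        result = np.add(result, arg)
    return result

def mulVectors(*args):
    # Element-wise product of any number of vectors
    result = args[0]
    for arg in args[1:]:
        result = np.multiply(result, arg)
    return result

# Hypothetical placeholder functions: undefined SymPy functions accept
# any number of arguments and have no simplification rules attached
vadd = Function('vadd')
vmul = Function('vmul')

expr = w*z*(x + y)
# type -> func: every Add node becomes vadd(*node.args), and so on
expr = expr.replace(Add, vadd)
expr = expr.replace(Mul, vmul)
print(expr)  # something like vmul(w, z, vadd(x, y)); argument order may vary

f = lambdify([w, x, y, z], expr, {"vadd": addVectors, "vmul": mulVectors})
result = f([3, 3], [2, 2], [1, 1], [4, 4])
print(result)  # [36 36]
```

Since each operation gets its own undefined function, this is not limited to replacing two operations.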
A few things to note with this solution:

- With the `replace` function you can replace a type with a function (type -> func). See the docs.
- I needed functions that accept `*args` as an input. These were `Min`, `Max` and `lerchphi`.
- SymPy simplifies combinations of `Min` and `Max`, since `Max(x, Min(x, y)) = x`. That meant I couldn't use `Min` and `Max` together, so I used `lerchphi` and `Max`. These functions are arbitrary, as I'll be translating their implementation to a custom function in the next step. However, this means I can only replace two operations.
- The final step is to map `lerchphi` and `Max` to the custom functions in the `lambdify` call.

With SymPy, addition is an operation (it is printed as `+`), not a function call. Hence, I'm not sure if it's possible to achieve your goal by passing in custom modules...
However, at the heart of `lambdify` is the printing module. Essentially, `lambdify` uses a printer to generate a string representation of the expression to be evaluated. If you look at `lambdify`'s signature, you'll see that it's possible to pass a custom printer.
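To see why the `cos` example in the question works but the `Add` one does not, you can inspect the code `lambdify` generates (with `inspect.getsource`, as in the code further down). Both expressions are trees of SymPy objects, and `Add` is a class just like `cos`. The difference is in printing: `cos(x)` is written as a call to the name `cos`, which is then looked up in the dictionary you pass, while `x + y` is written as the Python `+` operator, so the `"add"` entry is never consulted. With plain lists that `+` concatenates, which matches the `[1, 1, 2, 2]` output in the question. A minimal sketch:

```python
import inspect

import numpy as np
from sympy import cos, lambdify, srepr, symbols

x, y = symbols('x y')

def addVectors(inA, inB):
    return np.add(inA, inB)

def mysin(v):
    return np.sin(v)

# Both are expression trees; Add is a node type just like cos
print(srepr(x + y))   # Add(Symbol('x'), Symbol('y'))
print(srepr(cos(x)))  # cos(Symbol('x'))

f_add = lambdify([x, y], x + y, {"add": addVectors})
f_cos = lambdify(x, cos(x), {"cos": mysin})

# The generated source shows what the printer emitted for each node
src_add = inspect.getsource(f_add)
src_cos = inspect.getsource(f_cos)
print(src_add)  # body is "return x + y": no name is looked up for Add
print(src_cos)  # body calls cos(x), which resolves to mysin

print(f_add([1, 1], [2, 2]))  # list concatenation: [1, 1, 2, 2]
```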
In a printer class, the `+` that appears in the generated code is produced by the `_print_Add` method. One way to achieve your goal is to override this method in a subclass of `NumPyPrinter`.
```python
import inspect

from sympy import lambdify, symbols
from sympy.printing.lambdarepr import NumPyPrinter

x, y = symbols('x y')
A = [1, 1]
B = [2, 2]

class MyNumPyPrinter(NumPyPrinter):
    def _print_Add(self, expr, **kwargs):
        # Print each term, then emit a call to a function named "add"
        str_args = [self._print(t) for t in expr.args]
        return "add(*[%s])" % ", ".join(str_args)

f = lambdify([x, y], x + y, printer=MyNumPyPrinter)
print(inspect.getsource(f))
# def _lambdifygenerated(x, y):
#     return add(*[x, y])
print(f(A, B))
# [3 3]
```
Note that I have no idea what implications this might have. That's for you to find out...
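In the code above, `add` is resolved from the default `lambdify` namespace, which (with NumPy installed) contains NumPy's `add`, so `addVectors` itself is never called. NumPy's `add` is also a binary ufunc, so a sum of three or more terms would not be handled correctly. The sketch below combines a custom printer with a `modules` dictionary so that both `Add` and `Mul` are routed to the custom functions and `a*d*(b + c)` is evaluated directly. `VectorOpPrinter` is a hypothetical name. Only `Add` and `Mul` nodes are rerouted; `Pow` (which SymPy also uses for division, since `x/y` is `x*y**-1`) and other nodes are still printed by `NumPyPrinter` as usual.

```python
import inspect

import numpy as np
from sympy import lambdify, symbols
from sympy.printing.lambdarepr import NumPyPrinter

w, x, y, z = symbols('w x y z')

def addVectors(*args):
    # Element-wise sum of any number of vectors
    result = args[0]
    for arg in args[1:]:
        result = np.add(result, arg)
    return result

def mulVectors(*args):
    # Element-wise product of any number of vectors
    result = args[0]
    for arg in args[1:]:
        result = np.multiply(result, arg)
    return result

class VectorOpPrinter(NumPyPrinter):
    # Hypothetical printer: every Add node becomes add(...) and every Mul
    # node becomes mul(...); all other nodes use NumPyPrinter's rules
    def _print_Add(self, expr, **kwargs):
        return "add(%s)" % ", ".join(self._print(arg) for arg in expr.args)

    def _print_Mul(self, expr, **kwargs):
        return "mul(%s)" % ", ".join(self._print(arg) for arg in expr.args)

# A dict passed as `modules` is used as the generated function's namespace,
# so "add" and "mul" resolve to the custom implementations
f = lambdify([w, x, y, z], w*z*(x + y),
             modules={"add": addVectors, "mul": mulVectors},
             printer=VectorOpPrinter)

src = inspect.getsource(f)
print(src)  # the body is a nested call such as mul(w, z, add(x, y))
result = f([3, 3], [2, 2], [1, 1], [4, 4])
print(result)  # [36 36]
```

Because the dictionary replaces the default namespace here, anything else the generated code references would also need to be supplied, for example by passing a list such as `[{"add": addVectors, "mul": mulVectors}, "numpy"]` as `modules`.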