python lambda mpmath

Plotting lambda functions in Python and mpmath.plot


I'm using the mpmath plot function (which simply uses pyplot, as far as I understand).

Consider the following code:

from math import cos, sin
import mpmath as mp

mp.plot([sin, cos], [0, 3]) # this is fine

l = [sin, cos]
mp.plot([lambda x: f(2*x) for f in l], [0, 3])
# this only plots sin(2x)!

Is there anything I'm missing here, or is it a bug in the plot function?


Solution

  • See the relevant documentation here. Here's a quick fix that does what you want.

    l = [sin, cos]
    mp.plot([lambda x, f=f: f(2*x) for f in l], [0, 3])
    

    So, what's going on here? The key is that each lambda x: f(2*x) is equivalent to something of the form

    def func(x):
        return f(2*x)
    

    Importantly, the f inside each lambda is NOT replaced by the corresponding function from l while the list comprehension runs; the function body literally contains the name f. So whenever the lambda is called, Python looks up f at that moment. f is not defined within the scope of the function, so Python uses the value of f from the next level up, which is the scope of the list comprehension (interestingly, in Python 3, f is not a variable in the main scope).

    By the time any of the lambdas is called, the list has already been constructed, so f within the scope of the list comprehension refers to the last item of l, namely cos. For that reason, both functions in the list [lambda x: f(2*x) for f in l] compute cos(2x).
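    The late lookup described above can be checked without plotting anything. This is a minimal sketch using only the standard math module; the names funcs, x and results are just for illustration.

```python
from math import cos, sin

l = [sin, cos]
funcs = [lambda x: f(2 * x) for f in l]

# Both lambdas look up f only when they are called.
# By then the comprehension has finished and f refers to cos.
x = 0.5
results = [g(x) for g in funcs]

print(results[0] == results[1])   # True: both lambdas give the same value
print(results[0] == cos(2 * x))   # True: that value is cos(2x)
print(results[0] == sin(2 * x))   # False: sin(2x) is never computed
```

    So the two curves handed to mp.plot are the same function, which is why only one curve shows up.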

    The quick fix I provide puts the desired f into the scope of the function through an optional argument. If we write out the first lambda x, f=f: f(2*x) as a long form function definition, we have the following:

    def func(x, f=sin):
        return f(2*x)
    

    The key here is that, unlike the f in f(2*x), the f on the right-hand side of the = in f=f DOES get replaced by the corresponding function from l in the list comprehension, because default values are evaluated once, when the lambda is created. So there is no need for Python to go outside the local scope of the function.
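    To see that each lambda really carries its own f, you can look at the stored defaults. A small sketch, again with plain math functions; funcs and x are illustrative names.

```python
from math import cos, sin

l = [sin, cos]
funcs = [lambda x, f=f: f(2 * x) for f in l]

# The default for f is evaluated when each lambda is created,
# so every lambda stores a different function.
print(funcs[0].__defaults__[0] is sin)  # True
print(funcs[1].__defaults__[0] is cos)  # True

x = 0.5
print(funcs[0](x) == sin(2 * x))  # True
print(funcs[1](x) == cos(2 * x))  # True
```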


    Here's an alternative approach, using the "function factory" method suggested here.

    def make_func(f):
        def func(x):
            return f(2*x)
        return func
    
    l = [sin, cos]
    mp.plot([make_func(f) for f in l], [0, 3])
    

    Or, sticking to lambdas,

    def make_func(f):
        return lambda x: f(2*x)
    
    l = [sin, cos]
    mp.plot([make_func(f) for f in l], [0, 3])
    

    The key here is that the value of f (i.e. sin or cos) is kept within the scope of the make_func function call.
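    If you want to confirm where f lives in the factory version, the closure cell of each returned lambda can be inspected. This is only a sketch for illustration; captured is a made-up name, and __closure__ is the standard attribute holding a function's captured variables.

```python
from math import cos, sin

def make_func(f):
    # Each call to make_func has its own local f, which the lambda closes over.
    return lambda x: f(2 * x)

funcs = [make_func(f) for f in [sin, cos]]

# Pull the captured f out of each lambda's closure.
captured = [g.__closure__[0].cell_contents for g in funcs]
print(captured[0] is sin)  # True
print(captured[1] is cos)  # True

x = 0.5
print(funcs[0](x) == sin(2 * x))  # True
print(funcs[1](x) == cos(2 * x))  # True
```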

    A potential advantage of this method is that the resulting lambda functions do not have f as an optional parameter. In a sense, this amounts to "decorating lambdas", as you suggest in your comment.
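    The difference in signatures is easy to demonstrate. The sketch below compares the two approaches; with_default, with_factory, overridden and rejected are illustrative names, not anything from mpmath.

```python
import inspect
from math import cos, sin

def make_func(f):
    return lambda x: f(2 * x)

with_default = [lambda x, f=f: f(2 * x) for f in [sin, cos]]
with_factory = [make_func(f) for f in [sin, cos]]

print(list(inspect.signature(with_default[0]).parameters))  # ['x', 'f']
print(list(inspect.signature(with_factory[0]).parameters))  # ['x']

x = 0.5

# In the default-argument version, a stray second argument silently replaces f.
overridden = with_default[0](x, cos)
print(overridden == cos(2 * x))  # True

# In the factory version, the same call is rejected.
try:
    with_factory[0](x, cos)
    rejected = False
except TypeError:
    rejected = True
print(rejected)  # True
```

    This only matters if some caller might pass a second positional argument; for a call like mp.plot(..., [0, 3]) in the question, both versions gave the expected two curves.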