
What is this syntax with a function from an empty tuple in Julia Flux gradient?


I can't find reference documentation for the gradient function in Julia's Flux; there are only a few tutorial examples.

I understand how gradient is used to compute gradients of functions, e.g. the syntax

f(x, y) = x^2 + y^2
df(x, y) = gradient(f, x, y)

will essentially yield df(x, y) = (2x, 2y). However, the tutorial later uses the following syntax without any explanation:

gs = gradient(() -> loss(x, y), Flux.params(W, b))

I think there are ways to interpret the gradient of () -> loss(x, y) from a mathematical point of view, but I am not sure that this is what is going on here. So what is this anonymous function, and why was gradient designed that way? A link to the full documentation of gradient would be appreciated.


Solution

  • First, the link to the gradient docs: https://fluxml.ai/Zygote.jl/dev/#Zygote.gradient

    gradient can also return derivatives with respect to variables that are referenced inside a function but not passed to it as arguments, which is the case in the example in the question. The manual calls this the implicit style. The parameters to differentiate with respect to are passed in a value of type Params, which in the example is created by the helper function Flux.params. When using the implicit style, the function to differentiate is passed as a zero-argument function (see manual).
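    To make the two calling styles concrete, here is a minimal sketch that uses Zygote directly (the package gradient comes from). It assumes Zygote is installed and is a version that still supports the implicit Params style. The values of W, b, x and y and the body of loss are made up for illustration; Zygote.Params([W, b]) stands in for the Flux.params(W, b) call from the question.

```julia
using Zygote

# Explicit style: differentiate with respect to the function's own arguments.
f(x, y) = x^2 + y^2
explicit = gradient(f, 3.0, 4.0)   # one entry per argument: (2x, 2y) at (3, 4)

# Implicit style: W and b are never passed as arguments. The zero-argument
# function only refers to them, and Params says which ones to differentiate.
W = zeros(2, 3)
b = zeros(2)
x = [1.0, 2.0, 3.0]
y = [1.0, 1.0]

# Hypothetical squared-error loss, chosen only for this example.
loss(x, y) = sum((W * x .+ b .- y) .^ 2)

gs = gradient(() -> loss(x, y), Zygote.Params([W, b]))

# The result is indexed by the parameter objects themselves.
dW = gs[W]   # same shape as W
db = gs[b]   # same shape as b
```

    The zero-argument function is needed because there is nothing left to pass in: x and y are fixed data, and W and b are picked up from the enclosing scope rather than from an argument list.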

    Regarding the syntax itself (irrespective of use with gradient):

    () -> loss(x, y) is an anonymous function (sometimes called a lambda function in functional programming). It behaves like a regular function, but it is meant for one-time use, so there is no need to choose a name for it. The empty parentheses before the arrow are its (empty) argument list, not an empty tuple.

    The Julia manual link regarding these is https://docs.julialang.org/en/v1/manual/functions/#man-anonymous-functions

    Here are some more examples:

    With 1 parameter: map(x -> x^2 + 2x - 1, [1, 3, -1])

    With 2 parameters: (x,y) -> 2*x + y

    With no parameters: () -> time()
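    The examples above can be run as-is in plain Julia. The short sketch below also shows the point that matters for gradient: a zero-argument anonymous function can still use variables from the surrounding scope. The names a, b, sq_err and thunk are made up for illustration.

```julia
# One parameter, used inline with map.
squares = map(x -> x^2 + 2x - 1, [1, 3, -1])

# Two parameters, bound to a name so it can be called later.
g = (x, y) -> 2*x + y
g_val = g(3, 4)

# No parameters: the body refers to a and b from the enclosing scope,
# just as () -> loss(x, y) refers to x and y in the question.
a, b = 2.0, 5.0
sq_err(u, v) = (u - v)^2
thunk = () -> sq_err(a, b)
thunk_val = thunk()   # called with an empty argument list
```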