
Julia Flux withgradient operation


I am a newbie to Julia and Flux, with some experience in TensorFlow/Keras and Python. I tried to use the Flux.withgradient function to write a user-defined training function with more flexibility. Here is the training part of my code:

    loss, grad = Flux.withgradient(modelDQN.evalParameters) do
        qEval = modelDQN.evalModel(evalInput)
        Flux.mse(qEval, qTarget)
    end
    Flux.update!(modelDQN.optimizer, modelDQN.evalParameters, grad)

This code works just fine. But if I move the line qEval = modelDQN.evalModel(evalInput) outside the do ... end block, as follows:

    qEval = modelDQN.evalModel(evalInput)
    loss, grad = Flux.withgradient(modelDQN.evalParameters) do
        Flux.mse(qEval, qTarget)
    end
    Flux.update!(modelDQN.optimizer, modelDQN.evalParameters, grad)

then the model parameters are not updated. As far as I know, the do ... end block works as an anonymous function that takes 0 arguments. So why does the line qEval = modelDQN.evalModel(evalInput) need to be inside the block for the model to be updated?


Solution

  • The short answer is that anything to be differentiated has to happen inside the (anonymous) function which you pass to gradient (or withgradient), because this is very much not a standard function call -- Zygote (Flux's auto-differentiation library) traces its execution to compute the derivative, and can't transform what it can't see.
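
    To address the syntax point directly: the do ... end block is not a loop but Julia's syntax for an anonymous function, which is passed to withgradient as its first argument. So your working snippet is exactly equivalent to:

    loss, grad = Flux.withgradient(() -> Flux.mse(modelDQN.evalModel(evalInput), qTarget),
                                   modelDQN.evalParameters)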

    Longer, this is Zygote's "implicit" mode, which relies on global references to arrays. The simplest use is something like this:

    julia> using Zygote
    
    julia> x = [2.0, 3.0];
    
    julia> g = gradient(() -> sum(x .^ 2), Params([x]))
    Grads(...)
    
    julia> g[x]  # lookup by objectid(x)
    2-element Vector{Float64}:
     4.0
     6.0
    

    If you move some of that calculation outside, then you make a new array y with a new objectid. Julia has no memory of where this came from; as far as Zygote is concerned, it is completely unrelated to x. These are ordinary arrays, not a special tracked type.

    So if you refer to y in the gradient, Zygote cannot infer how this depends on x:

    julia> y = x .^ 2  # calculate this outside of gradient
    2-element Vector{Float64}:
     4.0
     9.0
    
    julia> g2 = gradient(() -> sum(y), Params([x]))
    Grads(...)
    
    julia> g2[x] === nothing  # represents zero
    true
    

    Zygote doesn't have to be used in this way. It also has an "explicit" mode which does not rely on global references. This is perhaps less confusing:

    julia> gradient(x1 -> sum(x1 .^ 2), x)  # x1 is a local variable
    ([4.0, 6.0],)
    
    julia> gradient(x1 -> sum(y), x)  # sum(y) is obviously independent of x1
    (nothing,)
    
    julia> gradient((x1, y1) -> sum(y1), x, y)
    (nothing, Fill(1.0, 2))
    
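    (Here Fill(1.0, 2) is a lazy representation of the vector [1.0, 1.0], from the FillArrays package, which Zygote uses as an efficient gradient for sum.)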

    Flux is in the process of changing to use this second form. On v0.13.9 or later, something like this ought to work:

    opt_state = Flux.setup(modelDQN.optimizer, modelDQN.evalModel)  # do this once
    
    loss, grads = Flux.withgradient(modelDQN.evalModel) do m
        qEval = m(evalInput)  # m is a local variable, standing for the model
        Flux.mse(qEval, qTarget)
    end
    
    Flux.update!(opt_state, modelDQN.evalModel, grads[1])
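
    For completeness, here is a self-contained sketch of that explicit-mode training step, with a toy Dense layer standing in for modelDQN.evalModel (the layer sizes, optimiser, and data are made up for illustration):

    using Flux
    
    model = Dense(4 => 2)                      # stand-in for modelDQN.evalModel
    opt_state = Flux.setup(Adam(1e-3), model)  # do this once, before the training loop
    
    evalInput = rand(Float32, 4, 8)            # a batch of 8 inputs
    qTarget   = rand(Float32, 2, 8)            # matching targets
    
    loss, grads = Flux.withgradient(model) do m
        Flux.mse(m(evalInput), qTarget)        # everything to be differentiated happens here
    end
    
    Flux.update!(opt_state, model, grads[1])   # grads[1] is the gradient for the 1st argument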