Tags: optimization, julia, iteration, ode

What can I do if the output of my optimization function is not a scalar?


I am using this Optimisation script:

using DifferentialEquations, Optimization, OptimizationPolyalgorithms,  OptimizationOptimJL, Plots, SciMLSensitivity, Zygote

parameters = []

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

u0 = [1.0, 1.0]
tspan = (0.0, 40.0)
tsteps = 0.0:0.1:40.0
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, tspan, p)

function loss(p)
    sol = solve(prob, Tsit5(), p=p, saveat = tsteps)
    return sum(abs2, sol .- 1), sol
end

function callback(p, l, pred)
    push!(parameters, p.u)
    display(l)
    plt = plot(pred, ylim = (0, 6))
    display(plt)
    return l <= 0.0135
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result_ode = Optimization.solve(optprob, PolyOpt(), callback = callback, maxiters = 600)

And it works until the loss gets to a bit under l = 0.013. If you run it, you can see it working up to this point. I added a stopping condition to the callback function (return l <= 0.0135) so that it would stop when the loss reaches 0.0135, but I am still getting this error message:

ERROR: Output should be scalar; gradients are not defined for output

How can I fix this?

I tried removing "sol" from the loss function, but this gave me unintelligible errors (sol is needed for the plot and is used in the last few lines).

I tried the stopping condition, which should tell the iterations to stop after a certain point, but even when they do stop (I changed the stopping value around, and it stopped at whatever value I chose), the error still comes up afterwards. How can I fix this?


Solution

  • Your loss function, loss(p), should return a scalar.

    The error you are getting has nothing to do with the loss value going below 0.0135. The error occurs because your loss and callback functions have incorrect signatures. Your use of PolyOpt() also plays a role in masking the real error.

    Fix loss and callback

    If you take a look at the documentation for OptimizationFunction, you will see that the constructor takes an argument f, which is the function you want to minimize. Taken from the docs:

    f(u,p): the function to optimize [...] This should return a scalar, the loss value, as the return output.

    Therefore you should change your function loss(p) to return a scalar loss value, not a Tuple. E.g.

    function loss(p)
        sol = solve(prob, Tsit5(), p=p, saveat = tsteps)
        return sum(abs2, sol .- 1)
    end
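    A minimal way to see where the message comes from, assuming only that Zygote.jl is installed (no ODE involved; loss_tuple and loss_scalar are hypothetical names): Zygote's gradient rejects a Tuple output with the same "Output should be scalar" error, accepts the scalar version, and also accepts a wrapper that takes first of the tuple.

```julia
using Zygote

# Mimics the original loss(p): returns (loss, extra), which Zygote cannot seed
loss_tuple(x) = (sum(abs2, x .- 1), x)

# Returns only the scalar loss, as OptimizationFunction expects
loss_scalar(x) = sum(abs2, x .- 1)

x0 = [2.0, 3.0]

# Works: the gradient of sum((x .- 1).^2) is 2 .* (x .- 1)
g = Zygote.gradient(loss_scalar, x0)[1]
println(g)  # [2.0, 4.0]

# Fails: reverse mode needs a scalar output to start the backward pass from
err = try
    Zygote.gradient(loss_tuple, x0)
    nothing
catch e
    e
end
println(sprint(showerror, err))

# Taking the first element of the tuple gives a scalar again, so this works too
g2 = Zygote.gradient(x -> first(loss_tuple(x)), x0)[1]
```

    The first(...) wrapper is one way to keep a tuple-returning function around for plotting while still handing the optimizer a scalar, although simply returning the scalar, as in the corrected loss(p), is simpler.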
    

    Now, if you look at the documentation for the Optimization solver options, you'll read the following about callbacks:

    The callback function callback is a function which is called after every optimizer step. Its signature is:

    callback = (state, loss_val) -> false

    So, you should follow this signature in your own callback function. Of course, this means you cannot pass in the ODE solution, sol, that was computed in your loss function. But never mind: you can access the current iteration's parameters through state.u and just solve the ODE again. E.g.

    function callback(state, l)
        display(l)
        push!(parameters, copy(state.u))  # copy in case the optimizer reuses this array
        sol = solve(prob, Tsit5(), p=state.u, saveat = tsteps)
        display(plot(sol, ylim=(0, 6)))
        return false
    end
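    The same callback signature can be tried on a toy problem before going back to the ODE. This is a sketch assuming a recent Optimization.jl release (where callbacks receive an OptimizationState) together with OptimizationOptimJL and Zygote; the Rosenbrock objective, the 1e-10 threshold and the names history, losses and record_callback are illustrative only. Returning true from the callback is what stops the run early, which is what the original return l <= 0.0135 was meant to do.

```julia
using Optimization, OptimizationOptimJL, Zygote

# Toy objective standing in for the ODE loss; its minimum is at u = [1, 1]
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

history = Vector{Vector{Float64}}()  # parameters seen by the callback
losses = Float64[]                   # loss values seen by the callback

# Signature (state, loss_val), as described in the Optimization.jl docs
function record_callback(state, l)
    push!(history, copy(state.u))  # copy in case the optimizer reuses the array
    push!(losses, l)
    return l <= 1e-10              # returning true stops the optimization early
end

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = Optimization.solve(prob, BFGS(); callback = record_callback, maxiters = 100)

println(length(losses), " callback calls, final loss ", losses[end])
```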
    

    If you use these new functions in your optimization you'll be just fine.

    Note

    If you are worried about having to solve the ODE twice as often, don't be. The system here is small enough that you would never notice. However, if you had a really big system, you could do something similar to what you did with parameters: cache each ODE solve in a global variable that you then reference in callback.
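    A sketch of that caching idea, under several assumptions: OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimJL, Zygote and ChainRulesCore are installed; the cache (last_solve), the counters and the short maxiters = 20 run are hypothetical choices for illustration; and ChainRulesCore.ignore_derivatives is used to keep the cache update out of the gradient computation. Because an optimizer such as BFGS can evaluate the loss at trial points during its line search, the callback only reuses the cached solution when its parameters match state.u, and otherwise solves the ODE again.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimJL, Zygote
using ChainRulesCore: ignore_derivatives

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

const tsteps = 0.0:0.1:40.0
const prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 40.0), [1.5, 1.0, 3.0, 1.0])

# Hypothetical cache: (parameters, solution) from the most recent loss evaluation
const last_solve = Ref{Any}(nothing)
const solutions = Any[]   # stand-in for whatever the callback would plot
const resolves = Ref(0)   # how often the callback had to solve the ODE itself

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    # Updating the cache is a side effect, so hide it from Zygote
    ignore_derivatives() do
        last_solve[] = (copy(p), sol)
    end
    return sum(abs2, sol .- 1)
end

function callback(state, l)
    cached = last_solve[]
    # The loss may also have been evaluated at trial points (e.g. in a line
    # search), so reuse the cache only if it belongs to the current iterate
    sol = if cached !== nothing && cached[1] == state.u
        cached[2]
    else
        resolves[] += 1
        solve(prob, Tsit5(), p = state.u, saveat = tsteps)
    end
    push!(solutions, sol)  # e.g. display(plot(sol, ylim = (0, 6))) instead
    return false
end

p0 = [1.5, 1.0, 3.0, 1.0]
optf = OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = OptimizationProblem(optf, p0)
result = Optimization.solve(optprob, BFGS(); callback = callback, maxiters = 20)

println("callback re-solved the ODE ", resolves[], " of ", length(solutions), " times")
```

    Plotting is replaced by push!(solutions, sol) so the sketch can run without a display; swap the display call back in for interactive use.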

    Why does it work at all in the first place?

    Why does your code seem to work for a few iterations and then suddenly error?

    Because of what PolyOpt() is doing. From the docs:

    PolyOpt: Runs Adam followed by BFGS for an equal number of iterations.

    What is happening is that your loss and callback functions work fine for Adam but not for BFGS. So when PolyOpt() switches from Adam to BFGS then the error 'suddenly' appears.

    You can see this for yourself by replacing PolyOpt() first with Optimisers.Adam() (this is the version of Adam used by OptimizationPolyalgorithms) and then with BFGS(). E.g.

    using Optimisers
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, p)
    result_ode = Optimization.solve(optprob, Optimisers.Adam(), callback = callback, maxiters = 600)
    

    Works with your original loss and callback, but

    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, p)
    result_ode = Optimization.solve(optprob, BFGS(), callback = callback, maxiters = 600)
    

    Throws an error immediately.

    Now, I cannot say for sure why Optimisers.Adam somehow accepts a different interface than BFGS and the other optimization algorithms do; I don't know enough about the packages involved and their design. My guess is that at some point the Optimisers.Adam integration adopted the new interface to fit in with the rest of the Optimization ecosystem, but kept supporting the old one so as not to break older code that relied on it. In any case, I would stick to the interface described in the Optimization.jl docs.
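    Following that advice, PolyOpt's Adam-then-BFGS sequence can also be written out by hand, which makes it easy to check that a scalar loss and a (state, l) callback work for both stages. This is a rough sketch, not PolyOpt's actual implementation: it assumes OptimizationOptimisers, OptimizationOptimJL, Optimisers and Zygote are installed, uses a toy Rosenbrock objective instead of the ODE loss, and the learning rate and iteration counts are arbitrary.

```julia
using Optimization, OptimizationOptimisers, OptimizationOptimJL, Optimisers, Zygote

# Toy objective; its minimum is at u = [1, 1]
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Scalar loss plus a (state, l) callback: accepted by both stages
cb = (state, l) -> false

# Stage 1: a fixed budget of Adam steps (qualified to avoid clashing with Optim's Adam)
res_adam = Optimization.solve(prob, Optimisers.Adam(0.01); callback = cb, maxiters = 300)

# Stage 2: restart BFGS from wherever Adam stopped
prob2 = remake(prob; u0 = res_adam.u)
res_bfgs = Optimization.solve(prob2, BFGS(); callback = cb, maxiters = 300)

println("Adam: ", res_adam.objective, "  then BFGS: ", res_bfgs.objective)
```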