julia flux.jl

How to define a custom loss function in Flux.jl?


Looking at the Flux.jl docs, I see there are a ton of built-in loss functions: https://fluxml.ai/Flux.jl/stable/models/losses/. My question is: how can I define and use my own loss function in Flux if I want something more esoteric?


Solution

  • You can use any differentiable function that returns a single scalar float as your loss. The built-in loss functions are only there for convenience, so you can pass anything, e.g.:

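    using Flux
    # Cross-entropy on raw logits ŷ against one-hot targets y (classes along dims = 1), summed over the batch
    yourcustomloss(ŷ, y) = sum(.- sum(y .* logsoftmax(ŷ), dims = 1))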
    

    You can then compute its gradient or pass it to Flux.train! just like a built-in loss.
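A minimal end-to-end sketch of both uses, assuming Flux 0.14 or later (the explicit-parameter API built around Flux.setup, where the loss passed to train! receives the model as its first argument; older releases used the implicit form Flux.train!(loss, Flux.params(model), data, opt) instead). The toy data, model sizes, learning rate, and epoch count are arbitrary illustrative choices, not part of the original answer.

```julia
using Flux

# The custom loss from the answer: cross-entropy on raw logits ŷ against
# one-hot targets y (classes along dims = 1), summed over the batch.
yourcustomloss(ŷ, y) = sum(.- sum(y .* logsoftmax(ŷ), dims = 1))

# Hypothetical toy data: 4 features, 3 classes, 8 samples.
x = rand(Float32, 4, 8)
y = Flux.onehotbatch(rand(1:3, 8), 1:3)

# Model outputs raw logits (no softmax layer, since the loss applies logsoftmax itself).
model = Chain(Dense(4 => 16, relu), Dense(16 => 3))

# 1) Gradient of the custom loss with respect to the model's parameters.
#    Flux.gradient returns a tuple with one entry per argument, here just the model.
grads = Flux.gradient(m -> yourcustomloss(m(x), y), model)

# 2) Training with the custom loss. Each element of `data` is a tuple that
#    train! splats after the model, so the loss is called as loss(model, xb, yb).
opt_state = Flux.setup(Adam(0.01), model)
data = [(x, y)]
loss_before = yourcustomloss(model(x), y)
for epoch in 1:100
    Flux.train!((m, xb, yb) -> yourcustomloss(m(xb), yb), model, data, opt_state)
end
loss_after = yourcustomloss(model(x), y)
println("loss before: ", loss_before, "  after: ", loss_after)
```

Nothing here is specific to cross-entropy: any function that maps the model output and targets to a scalar can be swapped in for yourcustomloss, as long as every operation inside it is differentiable by Zygote.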