I would just like to use ForwardDiff.jl to define a function and plot its gradient (evaluated using ForwardDiff.gradient). It does not seem to be working because the output of ForwardDiff.gradient is this weird Dual type thing, and it's not easily converted to the desired type (in my case, a 1-D array of Float32s).
using Plots
using ForwardDiff
my_func(x::Array{Float32,1}) = 1f0. / (1f0 .+ exp(3f0 .* x)) # doesn't matter what this is, just a sigmoid function here
grad_f(x::Array{Float32,1}) = ForwardDiff.gradient(my_func, x)
x_values = collect(Float32,0:0.01:10)
plot(x_values,my_func(x_values)); # this works fine
plot!(x_values,grad_f(x_values)); # this throws an error
And this is the error I get:
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float64,12})
When I inspect the type of grad_f(x_values), I get this:
Array{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,12},1},1}
Why doesn't that happen in the example in the ForwardDiff documentation, for instance? See here: https://github.com/JuliaDiff/ForwardDiff.jl
Thanks in advance.
EDIT: After Kristoffer Carlsson's comments: I tried this but it still doesn't work. I don't understand what is so different about what I tried here versus what he suggested:
function g(x::Float32)
return x / (1f0 + exp(10f0 * (x - 5f0)))
end
function ∂g∂x(x::Float32)
return ForwardDiff.derivative(g, x)
end
x_vals = collect(Float32,0:0.01:10)
plot(x_vals,g.(x_vals))
plot!(x_vals,∂g∂x.(x_vals))
With the error now being:
no method matching g(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,1})
And this error occurs as soon as I call ∂g∂x(x), whether or not I'm using the broadcasted version ∂g∂x.(x). I guess it has something to do with the function definition, but I don't see how my definition differs from Kristoffer's version, other than that it's not defined in a single line... This is so confusing.
This should work because, according to ForwardDiff's documentation, you just need the types of the inputs to be a sub-type of Real, and Float32 is a sub-type of Real.
EDIT: I realise now, having read the comments from others, that you need to define your functions generically enough to accept any input of the abstract type Real (ForwardDiff evaluates the function on Dual numbers, which are a sub-type of Real but are not Float32), which I didn't quite glean from the documentation. Apologies for the confusion.
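A minimal sketch of why the signature matters, assuming a recent ForwardDiff.jl release. The names g_narrow and g_generic are hypothetical and are used only to put the two signatures side by side; the function body is the one from the edit above.

```julia
using ForwardDiff

# Too narrow: ForwardDiff calls the function with a Dual number,
# and a Dual is not a Float32, so no method matches.
g_narrow(x::Float32) = x / (1f0 + exp(10f0 * (x - 5f0)))

# Generic: Dual is a sub-type of Real, so this method accepts it.
g_generic(x::Real) = x / (1f0 + exp(10f0 * (x - 5f0)))

# Dual numbers are Reals, which is why the ::Real annotation is enough.
println(ForwardDiff.Dual <: Real)

# Differentiating the narrow version fails at the call to g_narrow.
try
    ForwardDiff.derivative(g_narrow, 2f0)
catch err
    println("narrow signature failed with a ", typeof(err))
end

# Differentiating the generic version returns a plain number, not a Dual.
d = ForwardDiff.derivative(g_generic, 2f0)
println(d, " :: ", typeof(d))
```

Dropping the annotation entirely (g(x) = ...) would work just as well; the ::Real annotation only documents the intent.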
You are defining functions on arrays instead of scalars, and you are also restricting the input types too much. Also, for scalar functions you should use ForwardDiff.derivative. Try something like:
using Plots
using ForwardDiff
my_func(x::Real) = 1f0 / (1f0 + exp(3f0 * x))
my_func_derivative(x::Real) = ForwardDiff.derivative(my_func, x)
plot(my_func, xlimits = (0, 10))
plot!(my_func_derivative)
giving a plot of my_func and its derivative.
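For the original goal of getting the derivative values as a 1-D array of Float32s, one possible sketch is to broadcast the scalar derivative over the input array. The variable name dy is an assumption introduced here for illustration; the functions are the ones defined in this answer.

```julia
using ForwardDiff

# Scalar sigmoid-like function with a generic signature, as in the answer.
my_func(x::Real) = 1f0 / (1f0 + exp(3f0 * x))
my_func_derivative(x::Real) = ForwardDiff.derivative(my_func, x)

# Same grid as in the question.
x_values = collect(Float32, 0:0.01:10)

# The dot applies the scalar derivative to each element separately,
# so the result holds ordinary numbers rather than Dual numbers.
dy = my_func_derivative.(x_values)

println(typeof(dy))
println(length(dy) == length(x_values))
```

With Plots loaded, the resulting array can then be passed to plot!(x_values, dy) like any other vector.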