I need to implement a U-Net in Julia using Flux.jl. The ultimate goal is to train a neural network for a scientific problem. As a first step, I decided to experiment with the KITTI benchmark dataset.
Since I have more experience with Python, I initially implemented the U-Net in PyTorch, which worked perfectly. Then, I tried to translate my code into Julia while learning Flux.jl. Unfortunately, my implementation hasn't worked as expected.
To simplify the problem, I scaled back and attempted to implement a minimal U-Net model in Julia with synthetic data. However, I keep running into the following error during training:
ERROR: MethodError: no method matching (::var"#17#19"{var"#loss#18"{…}, Array{…}, Array{…}})(::@NamedTuple{layers::Tuple{…}})
The function #17 exists, but no method is defined for this combination of argument types.
The numbers in #17 change depending on the code.
Created a minimal U-Net model with a few convolutional and transposed convolutional layers.
Used synthetic 64x64 input images and binary masks as the dataset.
Attempted to train the model with a basic training loop using Flux.jl's gradient function and the logitcrossentropy loss.
Experimenting with different U-Net implementations (e.g., this repo).
I suspect the issue might be with:
Here is the code that reproduces the issue: I´m using Julia Version 1.11.1 and Flux v0.16.0
using Flux
using Flux: Conv, ConvTranspose, relu, MaxPool, Dense, Chain, params
using Base.Iterators: partition
using Random
using Plots
# Summary:
# This code defines a U-Net architecture for image segmentation using the Flux library in Julia.
# It creates synthetic data, prepares batches, trains the U-Net model, and tests the trained model.
# The problem is to ensure the U-Net model is correctly implemented and trained on the synthetic dataset.
# Define the U-Net architecture
function unet(input_channels::Int, output_channels::Int)
encoder = Chain(
Conv((3, 3), input_channels => 64, pad=1), relu, MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 64 => 128, pad=1), relu, MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 128 => 256, pad=1), relu, MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 256 => 512, pad=1), relu, MaxPool((2, 2), stride=(2, 2))
)
decoder = Chain(
ConvTranspose((3, 3), 512 => 256, stride=2, pad=1), relu,
ConvTranspose((3, 3), 256 => 128, stride=2, pad=1), relu,
ConvTranspose((3, 3), 128 => 64, stride=2, pad=1), relu,
ConvTranspose((3, 3), 64 => output_channels, stride=2, pad=1)
)
return Chain(encoder, decoder, x -> x[:, 1:64, 1:64, :])
end
# Create a synthetic dataset for image segmentation
function create_test_data(num_samples::Int)
data = []
for _ in 1:num_samples
image = rand(Float32, 64, 64, 1)
mask = rand(Bool, 64, 64, 1)
push!(data, (image, mask))
end
return data
end
# Split the synthetic dataset into batches
function prepare_batches(data, batch_size::Int)
batches = []
for batch in partition(data, batch_size)
input_batch = cat([x[1] for x in batch]..., dims=4)
mask_batch = cat([x[2] for x in batch]..., dims=4)
push!(batches, (input_batch, mask_batch))
end
return batches
end
# Implement a training loop for the U-Net model
function train_unet(model, train_data, num_epochs::Int, learning_rate::Float64)
opt = ADAM(learning_rate)
loss(x, y) = Flux.logitcrossentropy(model(x), float(y))
for epoch in 1:num_epochs
for (input_batch, mask_batch) in train_data
gs = gradient(() -> loss(input_batch, mask_batch), Flux.trainable(model))
Flux.Optimise.update!(opt, Flux.trainable(model), gs)
end
println("Epoch $epoch complete")
end
end
# Test a trained U-Net model
function test_unet(model, test_image)
prediction = model(test_image)
plot(plot(test_image[:, :, 1, 1], title="Input Image"),
plot(prediction[:, :, 1, 1], title="Predicted Mask"),
layout=(1, 2))
end
# Example usage
model = unet(1, 1)
data = create_test_data(100)
batches = prepare_batches(data, 8)
train_unet(model, batches, 10, 0.001)
test_image, _ = data[1]
test_unet(model, test_image)
Why does the above code result in the said error, and how can I fix it?
I've tried ensuring that the shapes of the model's output and target match, but I suspect the issue lies in either the loss function or the gradient call.
The first problem is that this model does not accept the data:
julia> batches[1][1] |> summary # batch of 8 images, 1 channel
"64×64×1×8 Array{Float32, 4}"
julia> model(batches[1][1])
ERROR: BoundsError: attempt to access 49×49×1×8 Array{Float32, 4} at index [1:49, 1:64, 1:64, 1:8]
I presume your x -> x[:, 1:64, 1:64, :]
intends to trim the image size, but acts on 1 image axis and the channel dim. x -> x[1:64, 1:64, :, :]
.
However, 49×49×1×8
is still too small for this. Perhaps the stride is wrong? Here's a version which runs:
julia> function unet(input_channels::Int, output_channels::Int)
encoder = Chain(
Conv((3, 3), input_channels => 64, pad=1, relu), MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 64 => 128, relu, pad=1), MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 128 => 256, relu, pad=1), MaxPool((2, 2), stride=(2, 2)),
Conv((3, 3), 256 => 512, relu, pad=1), MaxPool((2, 2), stride=(2, 2))
)
decoder = Chain(
ConvTranspose((3, 3), 512 => 256, relu; stride=2, pad=1), # relu inside!
ConvTranspose((3, 3), 256 => 128, relu; stride=2, pad=1),
ConvTranspose((3, 3), 128 => 64, relu, stride=2, pad=1),
ConvTranspose((3, 3), 64 => output_channels, stride=4, pad=1) # changed stride?
)
return Chain(encoder, decoder, x -> x[1:64, 1:64, :, :]) # select on image axes
end;
julia> model = unet(1, 1);
julia> model(batches[1][1]) |> summary
"64×64×1×8 Array{Float32, 4}"
Here's how training should look. Writing ADAM
means it is likely you are following some very old guide, and you should not follow the whole weird implicit gradient(() -> ..., params(model))
path. The recommended way looks like this:
julia> function train_unet(model, train_data, num_epochs::Int, learning_rate::Float64)
# Set up the optimiser for this model:
opt_state = Flux.setup(Adam(learning_rate), model)
# The loss is always an explicit function of the model:
loss(m, x, y) = Flux.logitcrossentropy(m(x), float(y))
for epoch in 1:num_epochs
for (input_batch, mask_batch) in train_data
# Gradient with respect to the model itself:
grads = Flux.gradient(m -> loss(m, input_batch, mask_batch), model)
Flux.update!(opt_state, model, grads[1])
end
@show epoch
end
end;
julia> train_unet(model, batches, 10, 0.001)
epoch = 1
epoch = 2
epoch = 3
epoch = 4
epoch = 5
epoch = 6
epoch = 7
epoch = 8
epoch = 9
epoch = 10
Now you can try it out... although this is all just noise of course:
julia> test_image, _ = data[1];
julia> test_image |> summary # lacks a batch dimension
"64×64×1 Array{Float32, 3}"
julia> test_unet(model, test_image)
ERROR: DimensionMismatch: layer Conv((3, 3), 1 => 64, relu, pad=1) expects ndims(input) == 4, but got 64×64×1 Array{Float32, 3}
julia> test_unet(model, reshape(test_image,64,64,1,1)) # `plot` is drawing lines
julia> function test_unet(model, test_image)
prediction = model(test_image)
plot(heatmap(test_image[:, :, 1, 1], title="Input Image"),
heatmap(prediction[:, :, 1, 1], title="Predicted Mask"),
layout=(1, 2))
end
# Example usage
test_unet (generic function with 1 method)
julia> test_unet(model, reshape(test_image,64,64,1,1))