Tags: machine-learning, julia, semantic-segmentation, unet-neural-network, flux.jl

MethodError in Training Minimal U-Net


I need to implement a U-Net in Julia using Flux.jl. The ultimate goal is to train a neural network for a scientific problem. As a first step, I decided to experiment with the KITTI benchmark dataset.

Since I have more experience with Python, I initially implemented the U-Net in PyTorch, which worked perfectly. Then, I tried to translate my code into Julia while learning Flux.jl. Unfortunately, my implementation hasn't worked as expected.

To simplify the problem, I scaled back and attempted to implement a minimal U-Net model in Julia with synthetic data. However, I keep running into the following error during training:

ERROR: MethodError: no method matching (::var"#17#19"{var"#loss#18"{…}, Array{…}, Array{…}})(::@NamedTuple{layers::Tuple{…}})

The function #17 exists, but no method is defined for this combination of argument types.

The numbers in #17 change depending on the code.

What I Have Done:

I suspect the issue might be with the loss function or the gradient call.

Example

Here is the code that reproduces the issue. I'm using Julia version 1.11.1 and Flux v0.16.0.

using Flux 
using Flux: Conv, ConvTranspose, relu, MaxPool, Dense, Chain, params
using Base.Iterators: partition
using Random
using Plots 

# Summary:
# This code defines a U-Net architecture for image segmentation using the Flux library in Julia.
# It creates synthetic data, prepares batches, trains the U-Net model, and tests the trained model.
# The problem is to ensure the U-Net model is correctly implemented and trained on the synthetic dataset.

"""
    unet(input_channels::Int, output_channels::Int)

Assemble the model as a `Chain` of three stages and return it.

1. A downsampling `Chain`: four 3×3 `Conv` layers (`pad=1`) taking the channel count
   through `input_channels => 64 => 128 => 256 => 512`, each followed by `relu` and a
   2×2 `MaxPool` with stride 2.
2. An upsampling `Chain`: four 3×3 `ConvTranspose` layers (`stride=2`, `pad=1`) taking
   the channel count through `512 => 256 => 128 => 64 => output_channels`, with `relu`
   after every layer except the last.
3. An anonymous function that indexes its input as `x[:, 1:64, 1:64, :]`.
"""
function unet(input_channels::Int, output_channels::Int)
    # Downsampling path: convolution, activation, pooling -- repeated four times.
    down_path = Chain(
        Conv((3, 3), input_channels => 64, pad=1),
        relu,
        MaxPool((2, 2), stride=(2, 2)),
        Conv((3, 3), 64 => 128, pad=1),
        relu,
        MaxPool((2, 2), stride=(2, 2)),
        Conv((3, 3), 128 => 256, pad=1),
        relu,
        MaxPool((2, 2), stride=(2, 2)),
        Conv((3, 3), 256 => 512, pad=1),
        relu,
        MaxPool((2, 2), stride=(2, 2)),
    )

    # Upsampling path: transposed convolutions; no activation after the final layer.
    up_path = Chain(
        ConvTranspose((3, 3), 512 => 256, stride=2, pad=1),
        relu,
        ConvTranspose((3, 3), 256 => 128, stride=2, pad=1),
        relu,
        ConvTranspose((3, 3), 128 => 64, stride=2, pad=1),
        relu,
        ConvTranspose((3, 3), 64 => output_channels, stride=2, pad=1),
    )

    # The last stage takes indices 1:64 along the second and third array dimensions.
    return Chain(down_path, up_path, x -> x[:, 1:64, 1:64, :])
end

"""
    create_test_data(num_samples::Int)

Create a synthetic dataset for image segmentation.

Return a vector of `num_samples` `(image, mask)` tuples. Each `image` is a
`64×64×1` `Array{Float32, 3}` of uniform random values in `[0, 1)` and each `mask` is a
`64×64×1` `Array{Bool, 3}` of random `true`/`false` values. A non-positive
`num_samples` yields an empty vector.

The result is built with a comprehension so that it is concretely typed rather than a
`Vector{Any}`.
"""
function create_test_data(num_samples::Int)
    # Tuple elements are evaluated left to right: the image is drawn before its mask.
    return [(rand(Float32, 64, 64, 1), rand(Bool, 64, 64, 1)) for _ in 1:num_samples]
end

"""
    prepare_batches(data, batch_size::Int)

Split the dataset into batches.

`data` is an iterable of `(input, mask)` pairs. Consecutive groups of up to
`batch_size` samples are formed with `partition`; within each group the inputs are
concatenated along dimension 4, and likewise the masks. The last batch may hold fewer
than `batch_size` samples.

Return a vector of `(input_batch, mask_batch)` tuples. The vector is built with a
comprehension, so its element type follows the batches instead of being `Any`.
"""
function prepare_batches(data, batch_size::Int)
    # Concatenate one field (1 = input, 2 = mask) of every sample along dimension 4.
    collate(batch, field) = cat((sample[field] for sample in batch)...; dims=4)

    return [
        (collate(batch, 1), collate(batch, 2)) for batch in partition(data, batch_size)
    ]
end

# Implement a training loop for the U-Net model.
#
# Arguments:
#   model         - callable applied to each input batch inside the loss
#   train_data    - iterable of (input_batch, mask_batch) pairs
#   num_epochs    - number of passes over train_data
#   learning_rate - value handed to the ADAM optimiser constructor
#
# Prints "Epoch N complete" after every pass over train_data.
function train_unet(model, train_data, num_epochs::Int, learning_rate::Float64)
    opt = ADAM(learning_rate)
    # The loss closes over `model`; the mask is converted with `float` before it is
    # compared against the model output.
    loss(x, y) = Flux.logitcrossentropy(model(x), float(y))
    
    for epoch in 1:num_epochs
        for (input_batch, mask_batch) in train_data
            # NOTE(review): this is the implicit-parameter style -- a zero-argument
            # closure plus a separately supplied parameter collection. Confirm that
            # the installed Flux version still accepts this call form.
            gs = gradient(() -> loss(input_batch, mask_batch), Flux.trainable(model))
            Flux.Optimise.update!(opt, Flux.trainable(model), gs)
        end
        println("Epoch $epoch complete")
    end
end

"""
    test_unet(model, test_image)

Apply `model` to `test_image` and return a figure with two panels side by side: the
`[:, :, 1, 1]` slice of the input, titled "Input Image", and the same slice of the model
output, titled "Predicted Mask".
"""
function test_unet(model, test_image)
    predicted = model(test_image)

    input_panel = plot(test_image[:, :, 1, 1], title="Input Image")
    output_panel = plot(predicted[:, :, 1, 1], title="Predicted Mask")

    return plot(input_panel, output_panel, layout=(1, 2))
end

# Example usage: build a 1-channel-in / 1-channel-out model, generate 100 synthetic
# samples, group them into batches of 8, train for 10 epochs, then plot one sample.
model = unet(1, 1)
data = create_test_data(100)
batches = prepare_batches(data, 8)
train_unet(model, batches, 10, 0.001)

# Take the image of the first (image, mask) sample.
test_image = first(first(data))
test_unet(model, test_image)

Why does the above code produce this error, and how can I fix it?

I've tried ensuring that the shapes of the model's output and target match, but I suspect the issue lies in either the loss function or the gradient call.


Solution

  • The first problem is that this model does not accept the data:

    julia> batches[1][1] |> summary  # batch of 8 images, 1 channel
    "64×64×1×8 Array{Float32, 4}"
    
    julia> model(batches[1][1])
    ERROR: BoundsError: attempt to access 49×49×1×8 Array{Float32, 4} at index [1:49, 1:64, 1:64, 1:8]
    

    I presume your x -> x[:, 1:64, 1:64, :] is intended to trim the image size, but it acts on one image axis and the channel dimension. It should be x -> x[1:64, 1:64, :, :].

    However, 49×49×1×8 is still too small for this. Perhaps the stride is wrong? Here's a version which runs:

    julia> function unet(input_channels::Int, output_channels::Int)
               # Encoder: four 3×3 convolutions (pad=1) with relu given as the layer's
               # activation, each followed by 2×2 max-pooling with stride 2.
               encoder = Chain(
                   Conv((3, 3), input_channels => 64, pad=1, relu), MaxPool((2, 2), stride=(2, 2)),
                   Conv((3, 3), 64 => 128, relu, pad=1), MaxPool((2, 2), stride=(2, 2)),
                   Conv((3, 3), 128 => 256, relu, pad=1), MaxPool((2, 2), stride=(2, 2)),
                   Conv((3, 3), 256 => 512, relu, pad=1), MaxPool((2, 2), stride=(2, 2))
               )
               # Decoder: four 3×3 transposed convolutions (pad=1); the first three use
               # stride 2 with relu, the last uses stride 4 and has no activation.
               decoder = Chain(
                   ConvTranspose((3, 3), 512 => 256, relu; stride=2, pad=1),  # relu inside!
                   ConvTranspose((3, 3), 256 => 128, relu; stride=2, pad=1),
                   ConvTranspose((3, 3), 128 => 64, relu, stride=2, pad=1),
                   ConvTranspose((3, 3), 64 => output_channels, stride=4, pad=1)  # changed stride?
               )
               # The final stage keeps indices 1:64 along the first two array dimensions.
               return Chain(encoder, decoder, x -> x[1:64, 1:64, :, :])  # select on image axes
           end;
    
    julia> model = unet(1, 1);
    
    julia> model(batches[1][1]) |> summary
    "64×64×1×8 Array{Float32, 4}"
    

    Here's how training should look. Writing ADAM suggests that you are following a very old guide, and you should not use the old implicit style, gradient(() -> ..., params(model)). The recommended way looks like this:

    julia> function train_unet(model, train_data, num_epochs::Int, learning_rate::Float64)
               # Set up the optimiser for this model:
               opt_state = Flux.setup(Adam(learning_rate), model)
               
               # The loss is always an explicit function of the model:
               loss(m, x, y) = Flux.logitcrossentropy(m(x), float(y))
               
               for epoch in 1:num_epochs
                   for (input_batch, mask_batch) in train_data
                       # Gradient with respect to the model itself:
                       grads = Flux.gradient(m -> loss(m, input_batch, mask_batch), model)
                       # Only `model` is differentiated, so its gradient is grads[1];
                       # update! uses it together with the optimiser state.
                       Flux.update!(opt_state, model, grads[1])
                   end
                   # Prints `epoch = N` once per pass over train_data.
                   @show epoch
               end
           end;
    
    julia> train_unet(model, batches, 10, 0.001)
    epoch = 1
    epoch = 2
    epoch = 3
    epoch = 4
    epoch = 5
    epoch = 6
    epoch = 7
    epoch = 8
    epoch = 9
    epoch = 10
    

    Now you can try it out... although this is all just noise of course:

    julia> test_image, _ = data[1];
    
    julia> test_image |> summary  # lacks a batch dimension
    "64×64×1 Array{Float32, 3}"
    
    julia> test_unet(model, test_image)
    ERROR: DimensionMismatch: layer Conv((3, 3), 1 => 64, relu, pad=1) expects ndims(input) == 4, but got 64×64×1 Array{Float32, 3}
    
    julia> test_unet(model, reshape(test_image,64,64,1,1))  # `plot` is drawing lines
    
    julia> function test_unet(model, test_image)
               # Run the model, then draw the input and the prediction as two heatmaps
               # side by side (one row, two columns).
               prediction = model(test_image)
               # [:, :, 1, 1] selects a 2-D slice: index 1 along dimensions 3 and 4.
               plot(heatmap(test_image[:, :, 1, 1], title="Input Image"),
                    heatmap(prediction[:, :, 1, 1], title="Predicted Mask"),
                    layout=(1, 2))
           end
    
           # Example usage
    test_unet (generic function with 1 method)
    
    julia> test_unet(model, reshape(test_image,64,64,1,1))