I know that with Flux.jl I can do
julia> Flux.params(model)
to get the parameters, but the output does not tell me how many parameters the model contains in total. Is there a function to check this, or a programmatic way to calculate it?
As @mcabbott mentions in the comments, you can pass the whole model to the `params` function and sum the lengths of the returned arrays to get the total count (`sum(length, params(model))`), or loop through each layer as follows:
julia> model = Chain(
resnet[1:end-2],
Dense(2048, 1000),
Dense(1000, 256),
Dense(256, 2), # we get 2048 features out, and we have 2 classes
)
Chain(Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad=1, stride=2), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), Chain(Conv((1, 1), 512=>1024), BatchNorm(1024))), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), 
Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>256), Conv((3, 3), 256=>256), Conv((1, 1), 256=>1024)), (BatchNorm(256), BatchNorm(256), BatchNorm(1024)), identity), Metalhead.ResidualBlock((Conv((1, 1), 1024=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), Chain(Conv((1, 1), 1024=>2048), BatchNorm(2048))), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), Metalhead.ResidualBlock((Conv((1, 1), 2048=>512), Conv((3, 3), 512=>512), Conv((1, 1), 512=>2048)), (BatchNorm(512), BatchNorm(512), BatchNorm(2048)), identity), MeanPool((7, 7)), #103), Dense(2048, 1000), Dense(1000, 256), Dense(256, 2))
julia> paramCount = 0
0
julia> for layer in model
paramCount += sum(length, params(layer))
end
julia> paramCount
25840234
In this example, I am just incrementing a running total, but you could instead append the count from each layer to an array, for example, to keep track of each layer's count individually.