
How to use a Dense layer with an input that has a dynamically sized dimension?


I have a model whose input is a batch of images with shape (height, width, time, channels), where the time dimension is only known at runtime. However, a Dense layer requires the last dimension of its input to be defined. Minimal example:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Input

# Define an input whose time dimension is undefined (None).
# Keras' `shape` excludes the batch axis: (height, width, time, channels).
input_tensor = Input(shape=(256, 256, None, 13))

# Flatten, then apply a Dense layer (Dense needs the last input dimension to be known)
x = Flatten()(input_tensor)
x = Dense(10)(x)

# Build the model
model = tf.keras.models.Model(inputs=input_tensor, outputs=x)

model.summary()

This raises the error:

ValueError: The last dimension of the inputs to a Dense layer should be defined. Found None.

How can I make it work using Flatten instead of alternatives like GlobalAveragePooling3D? Essentially, I’m looking for a way to create a 1D array with the original pixel values, but compatible with the Dense layer.


Solution

  • This is just not possible, because a Dense layer has a fixed number of weights. When you call a Dense layer after flattening, each of its output units effectively computes

    w_0 * x_0 + w_1 * x_1 + w_2 * x_2 + ... + w_(n-1) * x_(n-1) + bias, where the w's are the weights and the x's are the flattened input feature values.

    So if n cannot be known ahead of time because of your unknown dimension, the network cannot be built with the appropriate number of weights.
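To make this concrete, here is a small sketch, assuming TensorFlow 2.x with `tf.keras`. It builds a `Dense(10)` layer for a known input width and shows that its kernel has one row per input feature, then shows that flattening an input whose time axis is `None` leaves the feature count unknown. The helper names `dense_kernel_shape` and `flattened_shape`, and the width `n = 8`, are purely illustrative.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Input


def dense_kernel_shape(n_features, units=10):
    """Build a Dense layer for a known input width; return (kernel shape, param count)."""
    layer = Dense(units)
    layer.build((None, n_features))  # (batch, n): n must be a concrete integer
    return tuple(layer.kernel.shape), layer.count_params()


def flattened_shape(time_steps=None):
    """Flatten a (height, width, time, channels) input and return the symbolic shape."""
    inp = Input(shape=(256, 256, time_steps, 13))
    return tuple(Flatten()(inp).shape)


if __name__ == "__main__":
    # Kernel is (n, units): one weight per input feature for every output unit.
    print(dense_kernel_shape(8))
    # With time=None the flattened width is None, so no kernel shape can be chosen.
    print(flattened_shape(None))
    # With a fixed time length the width is 256 * 256 * 4 * 13.
    print(flattened_shape(4))
```

The only difference between the last two calls is whether the time axis is fixed, which is exactly what decides whether `Dense` can allocate its kernel.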

    Even if you knew the maximum time and wanted to preallocate enough weights to cover it, the network would likely suffer from two problems:

    1. The network becomes very large.
    2. The network overfits, because it treats each pixel at each time step as a completely separate feature, which inflates the dimensionality with no real benefit.
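A rough parameter count illustrates the first problem. This sketch assumes the question's 256x256 images with 13 channels feeding `Dense(10)`, and a hypothetical maximum of 100 time steps (not from the question), with shorter clips zero-padded up to that length. `padded_dense_params` is an illustrative helper, not a Keras function.

```python
def padded_dense_params(height, width, max_time, channels, units):
    """Weights plus biases of Dense(units) applied to a flattened, time-padded input."""
    n = height * width * max_time * channels  # flattened feature count
    return n * units + units  # kernel of shape (n, units) plus one bias per unit


MAX_TIME = 100  # hypothetical upper bound on the time axis

params = padded_dense_params(256, 256, MAX_TIME, 13, units=10)
megabytes = params * 4 / 1e6  # float32 weights take 4 bytes each
print(f"{params:,} parameters, about {megabytes:,.0f} MB in float32")
```

Under these assumptions the first layer alone needs 851,968,010 parameters (roughly 3.4 GB in float32), and every extra padded time step adds another 256 * 256 * 13 * 10 = 8,519,680 weights.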

    So the alternatives for capturing the time axis are either a sequence model such as a recurrent neural network (for example, an LSTM), or a 3D convolutional network that relies on pooling across time.
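Below is a minimal sketch of both alternatives, assuming TensorFlow 2.x `tf.keras`, the (height, width, time, channels) layout from the question, and 10 outputs. The function names and layer sizes (pooling factor 8, 32 LSTM units, 16 filters) are arbitrary illustrative choices. In the recurrent variant, `Flatten` is applied per frame through `TimeDistributed`; that works because each individual frame has a fixed size, and only the LSTM has to deal with the variable-length time axis.

```python
import tensorflow as tf
from tensorflow.keras import layers


def recurrent_model(height=256, width=256, channels=13, units=10):
    inp = layers.Input(shape=(height, width, None, channels))
    # Move time to axis 1: (batch, time, height, width, channels).
    x = layers.Permute((3, 1, 2, 4))(inp)
    # Downsample each frame, then flatten it; the per-frame size is fixed.
    x = layers.TimeDistributed(layers.MaxPooling2D(pool_size=8))(x)
    x = layers.TimeDistributed(layers.Flatten())(x)
    # The LSTM accepts any number of time steps and returns its final output.
    x = layers.LSTM(32)(x)
    out = layers.Dense(units)(x)
    return tf.keras.Model(inp, out)


def conv3d_model(height=256, width=256, channels=13, units=10):
    inp = layers.Input(shape=(height, width, None, channels))
    # 3D convolution over height, width and time.
    x = layers.Conv3D(16, kernel_size=3, padding="same", activation="relu")(inp)
    # Pool over all three axes, including the variable time axis -> (batch, 16).
    x = layers.GlobalMaxPooling3D()(x)
    out = layers.Dense(units)(x)
    return tf.keras.Model(inp, out)
```

Both models accept clips of any length. In this sketch the recurrent variant still flattens a downsampled copy of each frame, which is closest to what the question asks for, while the 3D-convolution variant collapses height, width and time with a global pooling layer.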