
Shape mismatch, 2D Input & 2D Labels


I want to create a neural network that, simply speaking, creates a (greyscale) image from another image. I have successfully created a dataset of 3200 examples of input and output (label) images. (I know the dataset should be larger, but that is not the problem right now.)

The input [Xin] has the shape (3200, 50, 30), since each input image is 50*30 pixels. The output [yout] has the shape (3200, 30, 20), since each output image is 30*20 pixels.
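As a quick sanity check on these shapes, the NumPy sketch below uses zero-filled stand-in arrays (not the real data) to show what Flatten() does to the input and how many values the final Dense(30*20) layer emits per sample.

```python
import numpy as np

# Zero-filled stand-ins with the shapes described above (not the real images).
Xin = np.zeros((3200, 50, 30), dtype=np.float32)   # 3200 inputs, 50x30 pixels each
yout = np.zeros((3200, 30, 20), dtype=np.float32)  # 3200 labels, 30x20 pixels each

# Flatten() keeps the batch axis and merges the rest, like this reshape.
x_flat = Xin.reshape(len(Xin), -1)

# The last layer, Dense(30*20), emits this many values per sample.
out_units = 30 * 20

print(x_flat.shape)  # (3200, 1500)
print(out_units)     # 600
print(yout.shape)    # (3200, 30, 20)
```

So each label image holds exactly as many pixels as the model outputs per sample, but in a (30, 20) layout rather than a flat row of 600.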

I want to try a fully connected network first (later on a CNN). The fully connected model is built like this:

import tensorflow as tf

# 5 Create Model
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())                                # (batch, 50, 30) -> (batch, 1500)
model.add(tf.keras.layers.Dense(256, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(30*20, activation=tf.nn.relu))      # 600 outputs per sample

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 6 Train the model
model.fit(Xin, yout, epochs=1)

After that I get the following error:

ValueError: Shape mismatch: The shape of labels (received (19200,)) should equal the shape of logits except for the last dimension (received (32, 600)).
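The sizes in this message can be traced with plain arithmetic. When no batch_size is passed, model.fit uses mini-batches of 32 samples, so the logits of one batch have shape (32, 600), and one batch of labels contains 32 × 30 × 20 = 19200 pixel values. The NumPy sketch below only reproduces these counts (it does not call TensorFlow) and shows that the total is the same whether each label is a 30x20 image or a flattened row of 600.

```python
import numpy as np

batch_size = 32  # default batch size of model.fit when none is given

# Logits for one batch: 600 output values per sample.
logits = np.zeros((batch_size, 30 * 20))

# One batch of labels, as given (30x20) and after flattening each image to 600.
labels_as_images = np.zeros((batch_size, 30, 20))
labels_as_rows = labels_as_images.reshape(batch_size, -1)

print(logits.shape)                        # (32, 600)
print(labels_as_images.reshape(-1).shape)  # (19200,)
print(labels_as_rows.reshape(-1).shape)    # (19200,)
```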

I already tried to flatten yout:

youtflat = yout.transpose(1,0,2).reshape(-1,yout.shape[1]*yout.shape[2])

but this resulted in the same error.
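Apart from the error itself, the transpose(1, 0, 2) in this attempt changes which pixels end up in each row. The NumPy example below uses a hypothetical toy array (4 images of 3x2 pixels, with every pixel of image i set to i) to compare that approach with a plain reshape that keeps the batch axis first.

```python
import numpy as np

# Toy stand-in for yout: 4 "images" of 3x2 pixels; every pixel of image i equals i.
n, h, w = 4, 3, 2
yout = np.repeat(np.arange(n), h * w).reshape(n, h, w)

# The attempted flattening: transpose first, then reshape.
mixed = yout.transpose(1, 0, 2).reshape(-1, h * w)

# A plain reshape keeps the batch axis first: one image per row.
plain = yout.reshape(n, h * w)

print(mixed[0])  # [0 0 1 1 2 2] -> pixels from images 0, 1 and 2
print(plain[0])  # [0 0 0 0 0 0] -> all pixels from image 0
```

Both results have the same shape, so the mix-up is easy to miss: it only shows up in the values.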


Solution

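  • Keep the batch dimension in your labels. If your original yout has a shape of (3200, 30, 20), reshape it to (3200, 30*20), i.e. (3200, 600), so that each row holds the pixels of exactly one label image and lines up with one row of the model's (batch, 600) output:

    yout = numpy.reshape(yout, (3200, 600))

    Avoid the transpose(1, 0, 2) step from your attempt: it moves the sample axis out of first position before reshaping, so each resulting row mixes pixels from different images.

    NOTE The reshape alone will not remove this particular error, though. The (19200,) in the message is 32 × 600: sparse_categorical_crossentropy expects one integer class index per sample, not a 600-value target, so it flattens the whole batch of labels. That is why your flattened yout gave the same error. For the task you're trying to perform (getting an image as output), you cannot use sparse_categorical_crossentropy as loss and accuracy as metric. Use 'mse' or 'mae' instead; with such a loss, the (3200, 600) labels match the model's output shape.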
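A minimal NumPy sketch of the suggested changes, assuming a model whose last layer outputs 600 values per sample. Random arrays stand in for yout and for the model's predictions, and the hypothetical mse_per_sample helper mirrors the per-sample reduction of Keras's 'mse' loss (the mean squared difference over the last axis). The last step shows how (batch, 600) predictions can be viewed as 30x20 images again.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins: real labels would come from yout, predictions from model.predict(Xin).
yout = rng.random((3200, 30, 20))
preds = rng.random((3200, 30 * 20))

# Keep the batch axis: one 600-value row per label image.
yout_flat = np.reshape(yout, (len(yout), -1))

def mse_per_sample(y_true, y_pred):
    """Hypothetical helper: mean squared error over the last axis, one value per sample."""
    return np.mean((y_true - y_pred) ** 2, axis=-1)

losses = mse_per_sample(yout_flat, preds)
print(yout_flat.shape, losses.shape)  # (3200, 600) (3200,)
print(losses.mean())                  # overall mean over all samples

# Predicted rows can be viewed as images again for inspection.
pred_images = preds.reshape(-1, 30, 20)
print(pred_images.shape)  # (3200, 30, 20)
```

Because the reshape only regroups values without reordering them, reshaping yout_flat back to (-1, 30, 20) recovers the original label images exactly.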