tensorflow image-segmentation argmax

Why is a new axis created, when only the first element is needed?


First of all, sorry for the vague title.

I wanted to learn more about TensorFlow and image segmentation, so I followed the official tutorial (https://www.tensorflow.org/tutorials/images/segmentation). However, I noticed something that I could not quite grasp, even after some Googling.

In this section:

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

Why add a new axis to the pred_mask tensor, only to pick the first element right afterwards? Why isn't it written the way I expected, as shown below?

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    return pred_mask

Solution

  • Calling tf.argmax with axis=-1 makes the tensor lose its last (channel) axis. That axis is added back as a singleton channel via tf.newaxis.

    Then pred_mask[0] returns the first element of the batch. In short:

    (batch_size, height, width, channels)  # original tensor shape
    (batch_size, height, width)            # after argmax
    (batch_size, height, width, 1)         # after [..., tf.newaxis]
    (height, width, 1)                     # this is what you are returning
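
The shape changes listed above can be checked with a minimal sketch. It uses NumPy as a stand-in for TensorFlow so it runs without the tutorial's model; np.argmax and np.newaxis follow the same shape rules as tf.argmax and tf.newaxis. The input shape (2, 4, 5, 3) and the random scores are made up for illustration.

```python
import numpy as np


def create_mask(pred_mask):
    # pred_mask holds per-class scores: (batch_size, height, width, channels)
    class_ids = np.argmax(pred_mask, axis=-1)   # (batch_size, height, width)
    class_ids = class_ids[..., np.newaxis]      # (batch_size, height, width, 1)
    return class_ids[0]                         # (height, width, 1)


# Hypothetical batch: 2 images, 4x5 pixels, 3 classes.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 5, 3))

mask = create_mask(logits)
print(mask.shape)                          # (4, 5, 1)

# The shorter version from the question keeps the batch axis
# and has no channel axis at all.
print(np.argmax(logits, axis=-1).shape)    # (2, 4, 5)
```

The trailing axis of size 1 yields a (height, width, channels) array, the layout that image-plotting utilities typically expect. That is probably why the tutorial restores it before displaying the mask, although the answer above does not state this explicitly.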