First of all, sorry for the vague title.
As I was interested in learning more about TensorFlow and image segmentation, I was following their tutorial (https://www.tensorflow.org/tutorials/images/segmentation). However, I noticed something that I could not quite grasp, even after some Googling.
In this section:
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]
What is the reason for first adding a new axis to the pred_mask tensor, only to select the first element right afterwards? Why isn't it written the way I expected, as below:
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    return pred_mask
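To make the difference between the two versions concrete, here is a minimal sketch that runs both on a hypothetical batch of predictions. It uses NumPy as a stand-in for TensorFlow (assumption: np.argmax and np.newaxis behave like tf.argmax and tf.newaxis for these shapes), and the batch of 2 images of 128x128 pixels with 3 classes is made up for illustration.

```python
import numpy as np

def create_mask_tutorial(pred_mask):
    # Mirrors the tutorial: argmax over the class axis,
    # add a trailing channel axis, then keep only the first image of the batch.
    pred_mask = np.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., np.newaxis]
    return pred_mask[0]

def create_mask_expected(pred_mask):
    # The simpler version from the question: argmax only, whole batch kept.
    return np.argmax(pred_mask, axis=-1)

# Hypothetical model output: 2 images, 128x128 pixels, 3 class scores per pixel.
logits = np.random.rand(2, 128, 128, 3)

print(create_mask_tutorial(logits).shape)  # (128, 128, 1)
print(create_mask_expected(logits).shape)  # (2, 128, 128)
```

Both versions compute the same class indices; they only differ in which image(s) are kept and whether a trailing channel axis is present.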
Calling tf.argmax with axis=-1 makes the tensor lose its last (channel) dimension. That dimension is then added back as a singleton channel via tf.newaxis.
Then you return the first element of the batch. In short:
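(batch_size, height, width, num_classes)  # original tensor shape: one score per class
(batch_size, height, width)               # after tf.argmax(..., axis=-1)
(batch_size, height, width, 1)            # after [..., tf.newaxis]
(height, width, 1)                        # after [0]: this is what is returned
So the function returns a single predicted mask in the same (height, width, 1) layout as the ground-truth masks, which is the single-image shape the tutorial's display helper expects when it plots the input image, the true mask and the predicted mask side by side.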
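The shape trace above can be followed step by step on a tiny, hand-written example. This sketch again uses NumPy as a stand-in for TensorFlow (assumption: np.argmax and np.newaxis mirror tf.argmax and tf.newaxis here), and the logits values are invented purely so the argmax result is easy to check by eye.

```python
import numpy as np

# Hypothetical logits: batch of 1 image, 2x2 pixels, 3 class scores per pixel.
logits = np.array([[[[0.1, 2.0, 0.3],     # pixel (0, 0): class 1 has the highest score
                     [5.0, 0.2, 0.1]],    # pixel (0, 1): class 0
                    [[0.0, 0.1, 0.9],     # pixel (1, 0): class 2
                     [0.4, 0.3, 0.2]]]])  # pixel (1, 1): class 0
print(logits.shape)  # (1, 2, 2, 3)

# argmax over the last axis replaces the 3 scores with the winning class index.
mask = np.argmax(logits, axis=-1)
print(mask.shape)    # (1, 2, 2)
print(mask[0])       # [[1 0]
                     #  [2 0]]

# Add the channel axis back as a singleton dimension.
mask = mask[..., np.newaxis]
print(mask.shape)    # (1, 2, 2, 1)

# Keep only the first image of the batch.
single = mask[0]
print(single.shape)  # (2, 2, 1)
```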