tensorflow keras tensorflow2.0

Using TensorFlow gather for 3D tensors


I have a 3D integer tensor V with shape (P,Q,R) and a 2D tensor W with shape (P,S). I want to create a tensor Z of shape (P,Q,R) where Z[i,j,k] = W[i,V[i,j,k]]. In other words, row i of W is the lookup table for slice i of V. I have tried tf.gather and tf.gather_nd, but the output is either 2D or 4D. I've tried variations on this:

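import tensorflow as tf
# V holds integer column indices into W (here only 0 or 1)
V = tf.random.uniform((2, 3, 4), minval=0, maxval=2, dtype=tf.int32)
W = tf.random.uniform((2, 20), minval=0, maxval=4, dtype=tf.int32)
Z = tf.gather(params=W, indices=V, axis=1, batch_dims=0)
print(Z.shape)  # (2, 2, 3, 4): 4D instead of the desired (2, 3, 4)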

I am using TensorFlow 2.15. Can someone please help?


Solution

  • I realized I should set batch_dims=1, so that the first dimension of W and V is treated as a shared batch dimension and row i of W is paired with slice i of V. This now works:

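    import tensorflow as tf
    # V holds integer column indices into W (here only 0 or 1)
    V = tf.random.uniform((2, 3, 4), minval=0, maxval=2, dtype=tf.int32)
    print(V[0, :])  # first slice of V, for inspection
    W = tf.random.uniform((2, 5), minval=0, maxval=4, dtype=tf.int32)
    print('W=', W)
    # batch_dims=1: gather along axis 1 of W separately for each batch entry
    Z = tf.gather(params=W, indices=V, axis=1, batch_dims=1)
    print(Z.shape)
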

    The shape is (2,3,4), as desired. Simple in the end, but a pain to figure out!
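
For anyone who wants to convince themselves that this really computes Z[i,j,k] = W[i,V[i,j,k]], here is a minimal sketch that compares the tf.gather result against a plain loop over the definition. The sizes P, Q, R, S and the names `expected`, `matches` and `unbatched` are my own illustrative choices, not part of the original code, and the index range is widened to [0, S) so every column of W can be hit.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes (assumptions, not taken from the post above).
P, Q, R, S = 2, 3, 4, 5

# Indices must lie in [0, S) so they are valid columns of W.
V = tf.random.uniform((P, Q, R), minval=0, maxval=S, dtype=tf.int32)
W = tf.random.uniform((P, S), minval=0, maxval=100, dtype=tf.int32)

# batch_dims=1 pairs row i of W with slice i of V.
Z = tf.gather(params=W, indices=V, axis=1, batch_dims=1)

# Reference result built element by element from the definition.
V_np, W_np = V.numpy(), W.numpy()
expected = np.empty((P, Q, R), dtype=W_np.dtype)
for i in range(P):
    for j in range(Q):
        for k in range(R):
            expected[i, j, k] = W_np[i, V_np[i, j, k]]

matches = bool(np.array_equal(Z.numpy(), expected))
print(Z.shape, matches)

# For comparison: with batch_dims=0 every row of W is gathered with
# every slice of V, which is where the extra dimension comes from.
unbatched = tf.gather(params=W, indices=V, axis=1, batch_dims=0)
print(unbatched.shape)  # (P, P, Q, R)
```

The loop is only there as a check; for real tensors the single tf.gather call is the one to use.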