Tags: python, tensorflow, deep-learning, conv-neural-network, tensorflow2.0

Using TensorFlow, how to operate on the top_k entries and then create a new tensor from the original tensor and the modified top_k entries


I've been trying to do something pretty simple, but with no success. I have a tensor X of shape (None, 128) containing scores; in other words, each sample in the batch has 128 scores. I then apply Y = tf.math.top_k(X, k=a).indices, where a is the number of top scores to select. For simplicity, let a = 95. The tensor Y then has shape (None, 95). So far this is fine.

My original data tensor has shape (None, 3969, 128). I want to apply an operation to the data entries that have the top_k scores, so I extracted them using:

ti = tf.reshape(Y, [Y.shape[-1], -1])   # Here ti is of shape (95, None)
fs = tf.gather(data, ti[:, 0], axis=-1) # data is the (None, 3969, 128) tensor; fs is of shape (None, 3969, 95)

and then applied my operation, say Z = fs * 0.7, so Z has shape (None, 3969, 95). This also worked.

Now I want to create a new tensor F of shape (None, 3969, 128) that contains both the unchanged data (entries whose scores are not in the top_k) and the modified data (entries whose scores are in the top_k and were modified in Z). The order must be the same as in the original data; that is, the modified entries should stay in their original positions. This is where I am stuck.

I am relatively new to TensorFlow, so apologies if I'm missing something simple or being unclear. I have been stuck on this for a few days now.

Thanks!


Solution

  • One way to do this is to use tf.tensor_scatter_nd_update. You'll first have to convert the indices returned by tf.math.top_k into indices that address your data tensor. For that, you can use a combination of tf.tile and tf.unravel_index to go from the shape of X to the shape of your data.

    If we assume 3D data like you have, you could use something similar to this:

    import tensorflow as tf

    # X: scores of shape (B,N); data: tensor of shape (B,D,N); k: number of top scores
    # getting the dimensions of the data tensor
    B, D, N = tf.unstack(tf.shape(data))
    topk = tf.math.top_k(X, k=k)
    # to get the absolute indices in the original tensor, we need to:
    # - tile to get indices from (B,k) to (B,D,k)
    # - do index arithmetic to get indices into the flattened tensor
    topk_idx_tiled = tf.tile(topk.indices[:,None,:], [1,D,1])
    flattened_indices = tf.reshape(tf.reshape(tf.range(B*D)*N,(B,D))[...,None] + topk_idx_tiled, [-1])
    # unraveling to get (batch, row, column) indices compatible with tensor_scatter_nd_update
    sc_idx = tf.transpose(tf.unravel_index(flattened_indices, tf.shape(data)))
    # the updates are the selected data entries (not the scores in topk.values), scaled by 0.7
    updates = tf.gather_nd(data, sc_idx)*0.7
    # scattering the updates into a copy of the original data
    F = tf.tensor_scatter_nd_update(data, sc_idx, updates)
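
As a rough, self-contained sketch of the scatter approach, the snippet below wraps it in a function and compares it with a simpler mask-based variant. The function names (scale_topk_scatter, scale_topk_mask), the factor argument, and the small random shapes are made up for illustration. It also assumes the intended update is the selected data entries multiplied by 0.7, as in Z = fs * 0.7 from the question.

```python
import tensorflow as tf


def scale_topk_scatter(data, scores, k, factor=0.7):
    """Scale the columns of `data` (B, D, N) picked by top-k of `scores` (B, N),
    using tf.tensor_scatter_nd_update. Hypothetical helper for illustration."""
    B, D, N = tf.unstack(tf.shape(data))
    idx = tf.math.top_k(scores, k=k).indices               # (B, k)
    idx_tiled = tf.tile(idx[:, None, :], [1, D, 1])        # (B, D, k)
    # Offset of each (batch, row) pair in the flattened data tensor.
    row_offsets = tf.reshape(tf.range(B * D) * N, [B, D])[..., None]  # (B, D, 1)
    flat_idx = tf.reshape(row_offsets + idx_tiled, [-1])   # (B*D*k,)
    # Convert flat positions back to (batch, row, column) triples.
    sc_idx = tf.transpose(tf.unravel_index(flat_idx, tf.shape(data)))  # (B*D*k, 3)
    updates = tf.gather_nd(data, sc_idx) * factor
    return tf.tensor_scatter_nd_update(data, sc_idx, updates)


def scale_topk_mask(data, scores, k, factor=0.7):
    """Same result via a boolean mask: no index arithmetic needed."""
    idx = tf.math.top_k(scores, k=k).indices               # (B, k)
    n = tf.shape(scores)[-1]
    # one_hot gives (B, k, N); reducing over k marks every selected column.
    mask = tf.reduce_max(tf.one_hot(idx, depth=n, dtype=data.dtype), axis=1)  # (B, N)
    # Broadcast the (B, 1, N) mask over the D axis.
    return tf.where(mask[:, None, :] > 0, data * factor, data)


# Small made-up example: batch of 2, 5 rows, 8 scores per sample, top 3 selected.
tf.random.set_seed(0)
data = tf.random.normal([2, 5, 8])
scores = tf.random.normal([2, 8])
F_scatter = scale_topk_scatter(data, scores, k=3)
F_mask = scale_topk_mask(data, scores, k=3)
```

Both functions keep every entry in its original position. The mask variant may be easier to read when the operation is element-wise, while the scatter variant generalizes to updates that are computed separately, such as Z in the question.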