Here are the data (batch size 2) and the batch index:
import mxnet as mx
data=mx.nd.array(range(24)).reshape(2,3,4)
index=mx.nd.array([[0,1],[1,2]])
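To pin down what "selected data" means here, this is a plain reference definition: for each batch element b, take the rows of data[b] listed in index[b]. It is a sketch in NumPy, used only so it runs without MXNet. The helper name select_rows_loop is hypothetical and is not part of either library.

```python
import numpy as np

# Same values as the MXNet arrays above, rebuilt with NumPy for illustration
data = np.arange(24).reshape(2, 3, 4)
index = np.array([[0, 1], [1, 2]])

def select_rows_loop(data, index):
    """Hypothetical reference: out[b, j] = data[b, index[b, j]]."""
    batch, k = index.shape
    out = np.empty((batch, k, data.shape[2]), dtype=data.dtype)
    for b in range(batch):
        for j in range(k):
            out[b, j] = data[b, index[b, j]]
    return out

expected = select_rows_loop(data, index)
# expected[0] holds rows 0 and 1 of data[0]; expected[1] holds rows 1 and 2 of data[1]
```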
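How do I get the selected data, i.e. for each batch element b, the rows of data[b] listed in index[b], giving a result of shape (2, 2, 4)? I tried the pick and take functions, but I don't know how to make them do this.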
It seems that gather_nd works:
# first row of the index array: batch ids; second row: row ids within each batch element
mx.nd.gather_nd(data, mx.nd.array([[0, 0, 1, 1], [0, 1, 1, 2]])).reshape(2, 2, 4)
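The hard-coded index array follows gather_nd's convention: for an indices array of shape (M, N), output[i] = data[indices[0, i], ..., indices[M-1, i]], so the M rows index the first M axes of data and the remaining axes are copied whole. Below is a minimal NumPy sketch, assuming that convention, that emulates it and builds the (2, B*K) index array from index instead of typing it by hand. The helpers gather_nd_np and build_gather_index are hypothetical; the array they produce is what would be wrapped with mx.nd.array(...) in the call above.

```python
import numpy as np

def gather_nd_np(data, indices):
    """Hypothetical NumPy emulation of the gather_nd convention:
    row m of `indices` indexes axis m of `data`."""
    return data[tuple(indices)]

def build_gather_index(index):
    """Hypothetical helper: turn a (B, K) per-batch row index into a
    (2, B*K) array (row 0: batch ids, row 1: row ids within the batch)."""
    batch, k = index.shape
    batch_ids = np.repeat(np.arange(batch), k)  # [0, 0, 1, 1] for B=2, K=2
    return np.stack([batch_ids, index.reshape(-1)])

data = np.arange(24).reshape(2, 3, 4)
index = np.array([[0, 1], [1, 2]])

gather_index = build_gather_index(index)  # [[0, 0, 1, 1], [0, 1, 1, 2]]
# gather yields shape (B*K, 4); reshape back to (B, K, 4)
selected = gather_nd_np(data, gather_index).reshape(index.shape + data.shape[2:])

# Equivalent NumPy advanced indexing: broadcast a (B, 1) batch column against index
direct = data[np.arange(index.shape[0])[:, None], index]
```

Building the index from arange and repeat keeps it correct for any batch size B and number of selected rows K, rather than only for this 2x2 example.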