
How to extract specific rows per batch in MXNet?


Here is my data (batch size 2) and, for each batch, the indices of the rows I want to select:

import mxnet as mx
data=mx.nd.array(range(24)).reshape(2,3,4)
index=mx.nd.array([[0,1],[1,2]])


How can I get the selected rows, i.e. rows 0 and 1 of the first batch and rows 1 and 2 of the second? I tried the pick and take functions but could not work out how to do it.


Solution

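  • gather_nd seems to work. Its index array has one row per indexed axis (batch indices first, then row indices), and the (4, 4) result is reshaped back to (batch, rows, columns):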

    mx.nd.gather_nd(data,mx.nd.array([[0,0,1,1],[0,1,1,2]])).reshape(2,2,4)