How do I call a custom MXNet operator from DJL? For example, the my_gemm operator from the MXNet extension examples (example/extensions/lib_custom_op).
This is possible by calling JnaUtils directly, the same way DJL's built-in MXNet engine does, but with your custom library loaded first. For the my_gemm example, it looks like this:
import ai.djl.Device;
import ai.djl.mxnet.jna.FunctionInfo;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.PairList;
import java.util.Map;
// Load the external mxnet operator library
JnaUtils.loadLib("path/to/incubator-mxnet/example/extensions/lib_custom_op/libgemm_lib.so", 1);
// get a handle to the loaded operator
Map<String, FunctionInfo> allFunctionsAfterLoading = JnaUtils.getNdArrayFunctions();
FunctionInfo myGemmFunction = allFunctionsAfterLoading.get("my_gemm");
// create a CPU manager for the example; try-with-resources closes it and the arrays it owns
try (NDManager ndManager = NDManager.newBaseManager(Device.cpu())) {
    // create the inputs for the gemm call: a is (2, 3), b is (3, 1)
    NDArray a = ndManager.create(new float[][] {{1, 2, 3}, {4, 5, 6}});
    NDArray b = ndManager.create(new float[][] {{7}, {8}, {9}});
    // call the function directly; NDManager.invoke will not work, because it caches the
    // built-in MXNet operators and ignores externally loaded ones
    PairList<String, Object> params = new PairList<>();
    NDArray result = myGemmFunction.invoke(ndManager, new NDArray[] {a, b}, params)[0];
    // prints (same as the Python example):
    // ND: (2, 1) cpu() float32
    // [[ 50.],
    //  [122.],
    // ]
    System.out.println(result);
}
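The printed values can be double-checked without MXNet. The sketch below is a plain-Java reference multiply that computes the same (2, 3) x (3, 1) product as the my_gemm call above. The class `GemmReference` and its `gemm` method are hypothetical names used only for illustration; they are not part of DJL or MXNet. Comparing this output against `result.toFloatArray()` (which returns the elements flattened in row-major order) is one way to confirm that the custom operator, and not some other code path, produced the result.

```java
import java.util.Arrays;

public class GemmReference {
    // Naive matrix multiply: (m x k) * (k x n) -> (m x n).
    // A plain-Java stand-in for my_gemm, used only to check its output.
    static float[][] gemm(float[][] a, float[][] b) {
        int m = a.length;
        int k = b.length;
        int n = b[0].length;
        if (a[0].length != k) {
            throw new IllegalArgumentException(
                    "inner dimensions differ: " + a[0].length + " vs " + k);
        }
        float[][] c = new float[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                float sum = 0f;
                // dot product of row i of a with column j of b
                for (int p = 0; p < k; p++) {
                    sum += a[i][p] * b[p][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        // same inputs as the DJL example above
        float[][] a = {{1, 2, 3}, {4, 5, 6}};
        float[][] b = {{7}, {8}, {9}};
        float[][] c = gemm(a, b);
        // prints [[50.0], [122.0]]
        System.out.println(Arrays.deepToString(c));
    }
}
```

The expected values follow directly: 1*7 + 2*8 + 3*9 = 50 and 4*7 + 5*8 + 6*9 = 122.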