java mxnet djl

How to call a custom MXNet operator in DJL (Deep Java Library)?


How do I call a custom MXNet operator from DJL, for example the my_gemm operator from the MXNet lib_custom_op extension example?


Solution

  • This is possible by calling JnaUtils directly, the same way DJL's built-in MXNet engine does, but with your custom library loaded first. For the my_gemm example, it looks like this:

    import ai.djl.Device;
    import ai.djl.mxnet.jna.FunctionInfo;
    import ai.djl.mxnet.jna.JnaUtils;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.util.PairList;
    
    import java.util.Map;
    
    // Load the external mxnet operator library
    JnaUtils.loadLib("path/to/incubator-mxnet/example/extensions/lib_custom_op/libgemm_lib.so", 1);
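    // get a handle to the loaded operator (the map has no entry if the library did not register it)
    Map<String, FunctionInfo> allFunctionsAfterLoading = JnaUtils.getNdArrayFunctions();
    FunctionInfo myGemmFunction = allFunctionsAfterLoading.get("my_gemm");
    if (myGemmFunction == null) {
        throw new IllegalStateException("my_gemm not found; check that the library was loaded");
    }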
    // create a manager to execute the example with
    try (NDManager ndManager = NDManager.newBaseManager().newSubManager(Device.cpu())) {
        // create input for the gemm call
        NDArray a = ndManager.create(new float[][]{{1, 2, 3}, {4, 5, 6}});
        NDArray b = ndManager.create(new float[][]{{7}, {8}, {9}});
        // call the function manually (NDManager.invoke will not work, as it caches the mxnet
        // engine operators and ignores external ones)
        PairList<String, Object> params = new PairList<>();
        NDArray result = myGemmFunction.invoke(ndManager, new NDArray[]{a, b}, params)[0];
        // prints
        // ND: (2, 1) cpu() float32
        //[[ 50.],
        // [122.],
        //]
        // (same as the python example)
        System.out.println(result);
    }
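
To sanity-check the values printed above without MXNet, the same matrix product can be computed in plain Java. This is only a minimal reference sketch: `GemmReference` and `multiply` are hypothetical names chosen for illustration, not part of DJL or MXNet, and the naive triple loop is assumed to match what my_gemm computes for these inputs.

```java
import java.util.Arrays;

/** Plain-Java reference for the matrix product in the example above (hypothetical helper). */
final class GemmReference {

    /** Multiplies an (n x k) matrix by a (k x m) matrix with the naive triple loop. */
    static float[][] multiply(float[][] a, float[][] b) {
        int n = a.length;
        int k = b.length;
        int m = b[0].length;
        float[][] c = new float[n][m];
        for (int i = 0; i < n; i++) {
            // every row of a must have as many columns as b has rows
            if (a[i].length != k) {
                throw new IllegalArgumentException("inner dimensions do not match");
            }
            for (int j = 0; j < m; j++) {
                float sum = 0f;
                for (int p = 0; p < k; p++) {
                    sum += a[i][p] * b[p][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        // same inputs as the DJL example above
        float[][] a = {{1, 2, 3}, {4, 5, 6}};
        float[][] b = {{7}, {8}, {9}};
        // prints [[50.0], [122.0]]
        System.out.println(Arrays.deepToString(multiply(a, b)));
    }
}
```

The values agree with the (2, 1) result from the custom operator: 1*7 + 2*8 + 3*9 = 50 and 4*7 + 5*8 + 6*9 = 122. In a DJL program, the flattened output of `result.toFloatArray()` could be compared against this reference.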