
How do I use PyTorch models in Deep Java Library (DJL)?


I would like to run EasyNMT in Java.
However, I don't know how to load and run the model.

I loaded the model as follows:

URI uri = new URI("file:////Users/.../prior.pth");
Path modelDir = Paths.get(uri);
Model model = Model.newInstance("model.pth", Device.cpu(), "PyTorch");
model.load(modelDir);

However, I do not know what to do after this.
In Python, EasyNMT translates a sentence like this:

model.translate("Dies ist ein Satz in Deutsch.", target_lang='en', max_new_tokens=1000)

How does DJL perform translations?


Solution

  • You need to create your own Translator to do the pre-processing and post-processing. DJL has a Jupyter notebook that explains how a Translator works.

    For an NMT model, see this example in the DJL repository: https://github.com/deepjavalibrary/djl/blob/master/examples/docs/neural_machine_translation.md
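To make the Translator's role concrete, here is a dependency-free sketch of the three stages an NMT pipeline goes through: pre-processing (text to token ids), autoregressive decoding up to a `maxNewTokens` cap (mirroring EasyNMT's `max_new_tokens`), and post-processing (ids back to text). Everything here is an assumption for illustration: `ToyVocab`, `DecoderStep` and the toy "model" are hypothetical stand-ins, not DJL or EasyNMT APIs, and greedy decoding is a simpler substitute for the beam search a real NMT model may use. In an actual DJL `Translator`, `processInput` would wrap the `encode()` ids in an `NDList` (for example via `ctx.getNDManager().create(long[])`), and `processOutput` would read ids back with `NDArray.toLongArray()`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Dependency-free sketch of an NMT Translator's stages. ToyVocab, DecoderStep and
// TOY_MODEL are hypothetical stand-ins, not DJL or EasyNMT APIs.
class NmtSketch {
    static final long BOS = 0, EOS = 1, UNK = 2;

    // Hypothetical whitespace vocabulary; real NMT models use subword tokenizers.
    static final class ToyVocab {
        private final Map<String, Long> ids = new HashMap<>();
        private final List<String> tokens = new ArrayList<>(List.of("<s>", "</s>", "<unk>"));

        ToyVocab(String... words) {
            for (String w : words) {
                ids.put(w, (long) tokens.size());
                tokens.add(w);
            }
        }

        // Pre-processing (what processInput would do): text -> token ids ending in EOS.
        long[] encode(String text) {
            String[] words = text.trim().split("\\s+");
            long[] out = new long[words.length + 1];
            for (int i = 0; i < words.length; i++) {
                out[i] = ids.getOrDefault(words[i], UNK);
            }
            out[words.length] = EOS;
            return out;
        }

        // Post-processing (what processOutput would do): ids -> text, skipping BOS/EOS.
        String decode(List<Long> generated) {
            StringBuilder sb = new StringBuilder();
            for (long id : generated) {
                if (id == BOS || id == EOS) continue;
                if (sb.length() > 0) sb.append(' ');
                sb.append(tokens.get((int) id));
            }
            return sb.toString();
        }
    }

    // Hypothetical: one decoder forward pass plus argmax, returning the next token id.
    interface DecoderStep {
        long next(long[] source, List<Long> prefix);
    }

    // Greedy autoregressive decoding: stop at EOS or after maxNewTokens new tokens.
    static List<Long> greedyDecode(long[] source, DecoderStep model, int maxNewTokens) {
        List<Long> out = new ArrayList<>(List.of(BOS));
        for (int step = 0; step < maxNewTokens; step++) {
            long next = model.next(source, out);
            out.add(next);
            if (next == EOS) break;
        }
        return out;
    }

    static final ToyVocab SRC = new ToyVocab("Dies", "ist", "ein", "Satz");
    static final ToyVocab TGT = new ToyVocab("This", "is", "a", "sentence");

    // Toy "model": emits the target id aligned with the source word at the same position.
    static final DecoderStep TOY_MODEL = (source, prefix) -> {
        int pos = prefix.size() - 1;
        return pos < source.length - 1 ? source[pos] : EOS;
    };

    static String translate(String text, int maxNewTokens) {
        long[] ids = SRC.encode(text);
        return TGT.decode(greedyDecode(ids, TOY_MODEL, maxNewTokens));
    }

    public static void main(String[] args) {
        System.out.println(translate("Dies ist ein Satz", 1000)); // This is a sentence
        System.out.println(translate("Dies ist ein Satz", 2));    // This is
    }
}
```

With a real model, `DecoderStep.next` is where the predictor call goes, and the loop and vocabulary are what the linked DJL example handles with a proper tokenizer and encoder/decoder pair.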