One option I tried is pickling the vocab and saving it with the _extra_files argument:
import torch
import pickle
class Vocab(object):
    pass
vocab = Vocab()
with open('path/to/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)
m = torch.jit.ScriptModule()
# I am not sure about the usage of this argument; the docs didn't help me
extra_files = torch._C.ExtraFilesMap()
# the value is stored as the file's contents, so this saves the path string itself
extra_files['vocab.pkl'] = 'path/to/vocab.pkl'
# I also tried pickle.dumps(vocab), and directly vocab
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
# Load with extra files.
files = {'vocab.pkl': ''}
torch.jit.load('scriptmodule.pt', _extra_files=files)
This gives:
TypeError: import_ir_module(): incompatible function arguments. The following argument types are supported:
1. (arg0: Callable[[List[str]], torch._C.ScriptModule], arg1: str, arg2: object, arg3: torch._C.ExtraFilesMap) -> None
The other option is obviously to load the pickle separately, but I was looking for a single-file option.
It would be nice if one could just add the vocab to the TorchScript file. It would also be nice to know whether there is some reason for not doing this that I am not aware of.
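Whichever map type is used, the value stored under 'vocab.pkl' is the file's contents, so it should be the serialized vocab (for example the bytes from pickle.dumps(vocab)), not a path to a .pkl file. A minimal stdlib-only sketch of that serialization step, where Vocab, vocab_to_bytes and vocab_from_bytes are hypothetical names and a plain dict stands in for the extra-files map:

```python
import pickle


class Vocab:
    """Hypothetical minimal vocabulary mapping tokens to indices."""

    def __init__(self, tokens):
        self.itos = list(tokens)  # index -> token
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}  # token -> index


def vocab_to_bytes(vocab):
    # Pickle only the token list (builtins), so unpickling does not need the
    # Vocab class to be importable under the same module path.
    return pickle.dumps(vocab.itos)


def vocab_from_bytes(data):
    # Rebuild the lookup tables from the stored token list.
    return Vocab(pickle.loads(data))


vocab = Vocab(["<unk>", "hello", "world"])
# The extra-file value is the contents themselves, not 'path/to/vocab.pkl'.
# A plain dict stands in for the extra-files mapping here.
extra_files = {"vocab.pkl": vocab_to_bytes(vocab)}

restored = vocab_from_bytes(extra_files["vocab.pkl"])
print(restored.stoi["hello"])  # 1
```

Pickling the Vocab instance directly also works, as long as the class can be imported under the same module path wherever the model is loaded.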
I believe the documentation for torch.jit.load is incorrect: you need to create a torch._C.ExtraFilesMap() object to load the saved files.
Here is how I got it to work.
Step 1: Save the model
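# traced_script_module is a ScriptModule (e.g. from torch.jit.trace); serialized_model_path is the output path
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
traced_script_module.save(serialized_model_path, _extra_files=extra_files)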
Step 2: Load the model
files = torch._C.ExtraFilesMap()
files['foo.txt'] = ''  # placeholder; replaced with the saved contents on load
loaded_model = torch.jit.load(serialized_model_path, _extra_files=files)
print(files)
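In newer PyTorch releases, the torch.jit.save and torch.jit.load docs show a plain dict being accepted for _extra_files, with the loaded contents written back into the dict. The sketch below shows the single-file vocab approach under that assumption; Model and the token list are hypothetical, and it swaps pickle for JSON so that json.loads works whether the contents come back as str or bytes. On older versions that raise the TypeError above, torch._C.ExtraFilesMap() is still needed.

```python
import json
import os
import tempfile

import torch


class Model(torch.nn.Module):
    """Hypothetical stand-in for the real model."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


tokens = ["<unk>", "hello", "world"]  # hypothetical vocab contents
scripted = torch.jit.script(Model())
path = os.path.join(tempfile.mkdtemp(), "model_with_vocab.pt")

# Save: each key is a file name inside the archive, each value is its contents.
torch.jit.save(scripted, path,
               _extra_files={"vocab.json": json.dumps(tokens).encode("utf-8")})

# Load: list the names to read; their values are replaced with the contents.
files = {"vocab.json": ""}
loaded = torch.jit.load(path, _extra_files=files)

# json.loads accepts str or bytes, whichever this PyTorch version returns.
restored = json.loads(files["vocab.json"])
print(restored)
```

Besides sidestepping the str/bytes question, a text format like JSON keeps the vocab readable by loaders that cannot unpickle Python objects.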