python pytorch torchtext

Saving the vocabulary object from PyTorch's torchtext library


I'm building a text classification model using PyTorch's torchtext. The vocabulary object lives on the data.Field:

import torch
from torchtext import data  # legacy API; in torchtext 0.9-0.11 use torchtext.legacy.data

def create_tabularDataset_object(self, csv_path):
    self.TEXT = data.Field(tokenize=self.tokenizer, batch_first=True, include_lengths=True)
    self.LABEL = data.LabelField(dtype=torch.float, batch_first=True)

def get_vocab_with_glov(self, train_data):
    # Build the vocabulary and attach pretrained GloVe embeddings
    self.TEXT.build_vocab(train_data, min_freq=100, vectors="glove.6B.100d")
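Roughly speaking, build_vocab counts token frequencies over the dataset, drops tokens seen fewer than min_freq times, and assigns indices with the special tokens first. The sketch below reproduces that idea in plain Python to make clear what state TEXT.vocab holds. It assumes the legacy defaults ('<unk>' and '<pad>' as specials, frequency-descending order with alphabetical tie-breaking); build_vocab_sketch is a hypothetical helper, not a torchtext function.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def build_vocab_sketch(
    tokenized_texts: Iterable[List[str]],
    min_freq: int = 1,
    specials: Tuple[str, ...] = ("<unk>", "<pad>"),
) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical approximation of a legacy torchtext build_vocab call.

    Returns (itos, stoi): the index-to-token list and token-to-index dict.
    """
    counter = Counter(tok for text in tokenized_texts for tok in text)
    # Special tokens are placed up front, not counted as regular entries
    for tok in specials:
        counter.pop(tok, None)
    # Sort alphabetically, then by descending frequency (stable sort),
    # so tokens with equal counts stay in alphabetical order
    by_freq = sorted(counter.items(), key=lambda kv: kv[0])
    by_freq.sort(key=lambda kv: kv[1], reverse=True)
    itos = list(specials) + [tok for tok, freq in by_freq if freq >= min_freq]
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return itos, stoi


texts = [["the", "movie", "was", "good"], ["the", "movie", "was", "bad"], ["good", "good"]]
itos, stoi = build_vocab_sketch(texts, min_freq=2)
# itos == ['<unk>', '<pad>', 'good', 'movie', 'the', 'was']  ("bad" occurs once, so it is dropped)
```

Because these indices depend on the training data, they cannot be recomputed at serving time from the model weights alone.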

After training, when serving the model in production, how do I keep the TEXT object around? At prediction time I need it to map word tokens to indices:

[TEXT.vocab.stoi[t] for t in tokenized_sentence]
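The lookup above only needs the token-to-index mapping, so whatever is restored in production must reproduce TEXT.vocab.stoi exactly. If that mapping is ever copied into a plain dict, unseen words would raise KeyError, so the sketch below falls back to the '<unk>' index. numericalize is a hypothetical helper, and unk_index=0 assumes '<unk>' is the first special token, as under the legacy defaults.

```python
from typing import Dict, List


def numericalize(tokens: List[str], stoi: Dict[str, int], unk_index: int = 0) -> List[int]:
    """Map tokens to vocabulary indices, sending unseen tokens to <unk>.

    Hypothetical helper: unk_index=0 assumes '<unk>' is the first special token.
    """
    return [stoi.get(tok, unk_index) for tok in tokens]


# Toy mapping standing in for TEXT.vocab.stoi
stoi = {"<unk>": 0, "<pad>": 1, "good": 2, "movie": 3, "the": 4, "was": 5}
tokenized_sentence = ["the", "movie", "was", "terrible"]
indices = numericalize(tokenized_sentence, stoi)
# indices == [4, 3, 5, 0]  ("terrible" is out of vocabulary)
```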

Am I missing something, and is it actually unnecessary to keep that object? Do I need any file other than the model weights?


Solution

  • I found that saving TEXT.vocab as a pickle (.pkl) file works:

    import pickle

    def save_vocab(vocab, path):
        # Serialize the whole Vocab object (stoi, itos, and any loaded GloVe vectors)
        with open(path, 'wb') as output:
            pickle.dump(vocab, output)
    

    Where

    vocab = TEXT.vocab 
    

    and loading it back with pickle.load when the model is served.
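The loading side mirrors the saving side: unpickle the vocabulary once when the service starts and reuse it for every request. Below is a self-contained sketch of the round trip. Since torchtext may not be installed where you try it, a types.SimpleNamespace with stoi/itos attributes stands in for the real Vocab, and load_vocab is a hypothetical counterpart to save_vocab. Keep in mind that pickle stores a reference to the object's class, so the real TEXT.vocab can only be unpickled where a compatible torchtext version is installed.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace


def save_vocab(vocab, path: str) -> None:
    # Same idea as save_vocab above, using a context manager
    with open(path, "wb") as output:
        pickle.dump(vocab, output)


def load_vocab(path: str):
    # Hypothetical loader; only unpickle files you created, since pickle can run code
    with open(path, "rb") as f:
        return pickle.load(f)


# Training side: persist the vocabulary next to the model weights.
# SimpleNamespace is a stand-in for torchtext's Vocab (assumption).
itos = ["<unk>", "<pad>", "good", "movie", "the", "was"]
vocab = SimpleNamespace(itos=itos, stoi={tok: i for i, tok in enumerate(itos)})
path = os.path.join(tempfile.mkdtemp(), "vocab.pkl")
save_vocab(vocab, path)

# Serving side: load once at startup, then index incoming tokens
restored = load_vocab(path)
indices = [restored.stoi.get(t, 0) for t in ["the", "movie", "was", "great"]]
# indices == [4, 3, 5, 0]
```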