python pytorch torchtext

TypeError: 'Vocab' object is not callable


I'm following the torchtext transformer tutorial, which was published for PyTorch 1.9. However, because I'm working on a Tegra TX2, I'm stuck with torchtext 0.6.0 rather than 0.10.0 (which is what I assume the tutorial uses).

Following the tutorial, the following throws an error:

    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

The error is:

TypeError: 'Vocab' object is not callable

I understand what the error means. What I don't know is what calling Vocab is expected to return in this case.

Looking at the documentation for TorchText 0.6.0 I see that it has:

Is the example expecting the vectors from Vocab?

EDIT:

I looked up the 0.10.0 documentation, and it doesn't list a __call__ method for Vocab.


Solution

  • Looking at the source for the implementation of Vocab in 0.10.0, apparently it is a subclass of torch.nn.Module, which means it inherits __call__ from there (calling it is roughly equivalent to calling its forward() method, but with some additional machinery for implementing hooks and such).

    We can also see that it wraps some underlying VocabPyBind object (equivalent to the Vocab class in older versions), and its forward() method just calls that object's lookup_indices method.

    So in short, it seems the equivalent in older versions of the library would be to call vocab.lookup_indices(tokenizer(item)).

    Update: Apparently in 0.6.0 the Vocab class does not even have a lookup_indices method, but judging by the source of the newer version, lookup_indices is just equivalent to:

    [vocab[token] for token in tokenizer(item)]
    

    If you're ever able to upgrade, for the sake of forward-compatibility you could write a wrapper like:

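    from torchtext.vocab import Vocab as _Vocab
    
    class Vocab(_Vocab):
        # Mirror the newer API so call sites keep working after an upgrade.
        def lookup_indices(self, tokens):
            # Look tokens up on self, not on a global `vocab` variable.
            return [self[token] for token in tokens]
    
        def __call__(self, tokens):
            return self.lookup_indices(tokens)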
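
To see the wrapper pattern above in isolation, here is a minimal sketch that runs without torchtext installed. LegacyVocabStandIn, CallableVocab, and the whitespace tokenizer are hypothetical names made up for this illustration; the stand-in only mimics the vocab[token] lookup of an older-style Vocab and assumes unknown tokens map to index 0.

```python
class LegacyVocabStandIn:
    """Hypothetical stand-in for an older-style Vocab (not the real class).

    It models only token -> index lookup via vocab[token].
    """

    def __init__(self, tokens):
        # Assumption for this sketch: index 0 is reserved for unknown tokens.
        self.itos = ["<unk>"] + sorted(set(tokens))
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __getitem__(self, token):
        # Fall back to the unknown-token index when the token is missing.
        return self.stoi.get(token, 0)


class CallableVocab(LegacyVocabStandIn):
    """Adds the newer-style lookup_indices() and makes instances callable."""

    def lookup_indices(self, tokens):
        return [self[token] for token in tokens]

    def __call__(self, tokens):
        # Delegating here is what makes vocab(tokens) work.
        return self.lookup_indices(tokens)


def tokenizer(text):
    # Hypothetical whitespace tokenizer, standing in for the tutorial's one.
    return text.lower().split()


vocab = CallableVocab(tokenizer("the quick brown fox"))
indices = vocab(tokenizer("The lazy fox"))
print(indices)  # [4, 0, 2]
```

With a wrapper like this, the tutorial line torch.tensor(vocab(tokenizer(item)), dtype=torch.long) should not need to change, because vocab(...) returns a plain list of integer indices.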