python, huggingface-transformers, text-classification, large-language-model, zeroshot-classification

How does Huggingface's zero-shot classification work in production/webapp, do I need to train the model first?


I have already used Hugging Face's zero-shot classification with the "facebook/bart-large-mnli" model, as described here (https://huggingface.co/tasks/zero-shot-classification). The accuracy is quite good for my task.

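My question is how to put this into production (e.g. a web app). Should I (i) first train the "facebook/bart-large-mnli" model, save it in a pickle file, and then predict new (unseen) sentences using the pickle file, or (ii) simply import the "facebook/bart-large-mnli" model and compute the predictions directly in the production/webapp code? The latter scenario would be preferable, but I am not sure whether loading the model from scratch would produce the same output as loading the pickle file with the saved "facebook/bart-large-mnli" model.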

Thank you in advance.


Solution

  • Q: How does zero-shot classification work? Do I need to train/tune the model to use it in production?

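    Options: (i) train the "facebook/bart-large-mnli" model first, save it in a pickle file, then predict new (unseen) sentences from the pickle file; or (ii) simply load the "facebook/bart-large-mnli" model and compute predictions directly in the production/webapp code.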

    A (human): (ii). You can load the model with pipeline("zero-shot-classification", model="facebook/bart-large-mnli") once when the server starts, then reuse the pipeline for every request without re-initializing it.
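A minimal sketch of the "load once, reuse for every request" idea, assuming a single-process Python web server. `get_classifier` and `classify` are hypothetical helper names (not part of transformers); `functools.lru_cache(maxsize=1)` is just one way to keep a single shared pipeline, and the `classifier` argument lets you inject a stub so the helper can be exercised without downloading the model.

```python
from functools import lru_cache


@lru_cache(maxsize=1)
def get_classifier():
    """Build the zero-shot pipeline on the first call; later calls return the same object."""
    # Imported lazily so importing this module does not pull in transformers/torch.
    from transformers import pipeline

    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


def classify(text, candidate_labels, classifier=None):
    """Return (best_label, best_score) for one request.

    By default the shared, lazily-loaded pipeline is reused across requests;
    `classifier` can be injected instead (e.g. a stub in tests).
    """
    clf = classifier if classifier is not None else get_classifier()
    result = clf(text, candidate_labels)
    # The pipeline returns labels sorted by descending score.
    return result["labels"][0], result["scores"][0]
```

Calling `get_classifier()` once at server startup warms the cache, so no request pays the loading cost, and it also avoids two threads racing to build the pipeline on the very first request (`lru_cache` does not lock around the wrapped call).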

    When you use the model off-the-shelf, it is zero-shot; if you fine-tune a model with limited training data, people commonly call that "few-shot". Take a look at https://github.com/huggingface/setfit for few-shot learning.


    The proof is in the pudding: check whether the model you pick fits the task you want. Also, there's more than one way to wield the shiny hammer =)

    Disclaimer: Your Mileage May Vary...

    Zero-shot classification

    TL;DR: I don't want to train anything and I don't have labeled data; just do something with some labels that I come up with.

    from transformers import pipeline
    
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    
    text = "Catan (Base Game) | Ages 10+ | for 3 to 4 Players | Average Playtime 60 Minutes | Made by Catan Studio | TRADE, BUILD AND SETTLE: Embark on a quest to settle the isle of Catan! Guide your settlers to victory by clever trading and cunning development. But beware! Someone might cut off your road or buy a monopoly. And you never know when the wily robber might steal some of your precious games!"
    
    candidate_labels = ['Beauty & Wellness', 'Electronics', 'Toys & Games']
    
    classifier(text, candidate_labels)
    

    [out]:

    {'sequence': 'Catan (Base Game) | Ages 10+ | for 3 to 4 Players | Average Playtime 60 Minutes | Made by Catan Studio | TRADE, BUILD AND SETTLE: Embark on a quest to settle the isle of Catan! Guide your settlers to victory by clever trading and cunning development. But beware! Someone might cut off your road or buy a monopoly. And you never know when the wily robber might steal some of your precious games!',
     'labels': ['Toys & Games', 'Electronics', 'Beauty & Wellness'],
     'scores': [0.511284351348877, 0.38416239619255066, 0.10455326735973358]}
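Why this needs no training: the zero-shot pipeline reuses a model already fine-tuned on natural language inference (MNLI). Each candidate label is turned into a hypothesis (the pipeline's default template is "This example is {}."), the text is the premise, and, in the default single-label mode, the entailment logits of all (premise, hypothesis) pairs are softmaxed across the labels. The sketch below re-implements that scoring by hand; the helper names are hypothetical, and it assumes the model config names its entailment class "entailment" (the case for facebook/bart-large-mnli).

```python
import math


def build_pairs(text, labels, template="This example is {}."):
    """One (premise, hypothesis) pair per candidate label."""
    return [(text, template.format(label)) for label in labels]


def rank_labels(labels, entailment_logits):
    """Single-label scoring: softmax of the entailment logits across the labels."""
    m = max(entailment_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in entailment_logits]
    total = sum(exps)
    ranked = sorted(zip(labels, (e / total for e in exps)), key=lambda p: p[1], reverse=True)
    return {"labels": [label for label, _ in ranked], "scores": [s for _, s in ranked]}


def entailment_logits(text, labels, model_name="facebook/bart-large-mnli"):
    """Run the NLI model on every (text, hypothesis) pair and keep the entailment logit."""
    # Lazy imports: only needed when actually running the model.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    pairs = build_pairs(text, labels)
    enc = tokenizer(
        [premise for premise, _ in pairs],
        [hypothesis for _, hypothesis in pairs],
        return_tensors="pt",
        padding=True,
        truncation="only_first",  # truncate the text, never the hypothesis
    )
    with torch.no_grad():
        logits = model(**enc).logits
    # Assumption: the label is literally called "entailment" in the model config.
    entail_id = model.config.label2id["entailment"]
    return logits[:, entail_id].tolist()
```

`rank_labels(candidate_labels, entailment_logits(text, candidate_labels))` should give a ranking close to the pipeline output above. It also bears on the pickle worry in the question: the weights are only read, never updated, so loading the same checkpoint again should give the same scores.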
    

    Don't classify, translate (or seq2seq)

    Inspiration: https://arxiv.org/abs/1812.05774

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    
    model_name = "google/flan-t5-large"
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    text = "Catan (Base Game) | Ages 10+ | for 3 to 4 Players | Average Playtime 60 Minutes | Made by Catan Studio | TRADE, BUILD AND SETTLE: Embark on a quest to settle the isle of Catan! Guide your settlers to victory by clever trading and cunning development. But beware! Someone might cut off your road or buy a monopoly. And you never know when the wily robber might steal some of your precious games!"
    
    
    prompt=f"""Which category is this product?
    QUERY:{text}
    OPTIONS:
     - Beauty & Wellness
     - Electronics
     - Toys & Games
    """
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    
    print(tokenizer.decode(model.generate(input_ids)[0], skip_special_tokens=True))
    

    [out]:

    Toys & Games
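The prompt above hard-codes the three options, and a generative model returns free text rather than a label index. A small sketch, assuming you want to reuse the same QUERY/OPTIONS prompt format for any label list: `build_prompt` rebuilds the prompt string, and `match_label` (a hypothetical helper, using simple case-insensitive string matching as one possible policy) maps the generated text back to one of your labels, or `None` if the model answered with something else.

```python
def build_prompt(text, labels, question="Which category is this product?"):
    """Rebuild the QUERY/OPTIONS prompt for an arbitrary list of labels."""
    options = "\n".join(f" - {label}" for label in labels)
    return f"{question}\nQUERY:{text}\nOPTIONS:\n{options}\n"


def match_label(generated, labels):
    """Map free-form generated text back onto one of the candidate labels (or None)."""
    answer = generated.strip().lower()
    if not answer:
        return None
    # Prefer an exact (case-insensitive) match ...
    for label in labels:
        if answer == label.lower():
            return label
    # ... then fall back to the first label mentioned inside the generated text.
    for label in labels:
        if label.lower() in answer:
            return label
    return None
```

With the flan-t5 code above, you would pass `build_prompt(text, candidate_labels)` to the tokenizer and feed the decoded string to `match_label`; returning `None` makes "the model went off-script" an explicit case your app can handle.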
    

    And for the fun of it =)

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    
    model_name = "google/flan-t5-large"
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    prompt=f"""How does zero-shot classification work? 
    QUERY: Do I need tune/modify the model to use in production?
    OPTIONS:
     - (i) train the "facebook/bart-large-mnli" model first, secondly save the model in a pickle file, and then predict a new (unseen) sentence using the pickle file
     - (ii) can I simply import the "facebook/bart-large-mnli" library and compute the prediction for the production/webapp code
    """
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    
    print(tokenizer.decode(model.generate(input_ids)[0], skip_special_tokens=True))
    

    [out]:

    (ii)
    

    Q: What if neither of the methods above works?

    A: Try more models from https://huggingface.co/models, or try different tasks and be creative in how you use what's available to fit your data and solve the problem.

    Q: What if none of the models/tasks works?

    A: Then it's time to think about what data you can/need to collect to train the model you need. But before collecting the data, it'll be prudent to first decide how you want to evaluate/measure the success of the model, e.g. F1-score, accuracy, etc.
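To make "decide how to measure success" concrete, here is a small pure-Python sketch of two of the metrics mentioned, accuracy and macro-averaged F1, computed on a list of gold labels versus predicted labels (for example, the top label returned by the zero-shot pipeline on a small hand-labeled sample). The function names are illustrative; scikit-learn's `accuracy_score` and `f1_score(..., average="macro")` compute the same quantities if you already depend on it.

```python
from collections import Counter


def accuracy(gold, pred):
    """Fraction of examples whose predicted label equals the gold label."""
    if len(gold) != len(pred) or not gold:
        raise ValueError("gold and pred must be non-empty and the same length")
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)


def macro_f1(gold, pred):
    """Unweighted mean of per-label F1, where F1 = 2TP / (2TP + FP + FN)."""
    if len(gold) != len(pred) or not gold:
        raise ValueError("gold and pred must be non-empty and the same length")
    tp, fp, fn = Counter(), Counter(), Counter()
    for g, p in zip(gold, pred):
        if g == p:
            tp[g] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[g] += 1  # missed the true label g
    labels = sorted(set(gold) | set(pred))
    f1s = []
    for label in labels:
        denom = 2 * tp[label] + fp[label] + fn[label]
        f1s.append(2 * tp[label] / denom if denom else 0.0)
    return sum(f1s) / len(f1s)
```

Macro-F1 weights every label equally, so it exposes a model that does well only on the most frequent category, which plain accuracy can hide.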

    This is how I personally approach NLP problems that fit the "X problem, Y approach" frame: https://hackernoon.com/what-kind-of-scientist-are-you (shameless plug)

    Q: How do I deploy a model after I've found the model+task I want?

    There are several ways, but that is out of scope for this question, since it asks how zero-shot classification works and, more pertinently, "Can I use zero-shot classification models off-the-shelf without training?".

    To deploy a model, take a look at: