
Can we use Pydantic models (BaseModel) directly inside model.predict() using FastAPI, and if not, why?


I'm using a Pydantic model (BaseModel) with FastAPI, converting the input into a dictionary and then into a Pandas DataFrame so that I can pass it to the model.predict() function for Machine Learning predictions, as shown below:

from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
import pandas as pd
from typing import List

app = FastAPI()

# `classifier` is a trained model loaded elsewhere (not shown here)

class Inputs(BaseModel):
    f1: float
    f2: float
    f3: str

@app.post('/predict')
def predict(features: List[Inputs]):
    output = []

    # loop the list of input features
    for data in features:
        result = {}

        # Convert data into dict() and then into a DataFrame
        data = data.dict()
        df = pd.DataFrame([data])

        # get predictions
        prediction = classifier.predict(df)[0]

        # get probability
        probability = classifier.predict_proba(df).max()

        # assign to dictionary
        result["prediction"] = prediction
        result["probability"] = probability

        # append dictionary to list (many outputs)
        output.append(result)

    return output

It works fine, but I'm not sure whether it is optimized or the right way to do it, since I convert the input twice to get the predictions. I'm also not sure whether it will be fast enough with a huge number of inputs. Are there any improvements? Is there a way (even one that doesn't use Pydantic models) to work with the data directly and avoid the conversions and the loop?


Solution

  • First, you should use more descriptive names for your variables/objects. For example:

    @app.post('/predict')
    def predict(inputs: List[Inputs]):
        for i in inputs:
            ...  # process each input item here
    

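    You cannot pass the Pydantic model directly to the predict() function: predict() (as in scikit-learn) expects array-like input of shape (n_samples, n_features), such as a list of lists, a NumPy array or a Pandas DataFrame, not a Pydantic model. Available options are listed below.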
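As a minimal sketch of what "array-like of shape (n_samples, n_features)" means here, assuming the model was trained on the features in the order f1, f2, f3, the helper below turns a list of Inputs objects into one row per sample and one column per feature. `to_feature_rows` and `FEATURE_ORDER` are hypothetical names; using getattr() sidesteps the dict()/model_dump() difference between Pydantic V1 and V2.

```python
from typing import List

import numpy as np
from pydantic import BaseModel


class Inputs(BaseModel):
    f1: float
    f2: float
    f3: str


# Hypothetical: the feature order the model was trained with
FEATURE_ORDER = ["f1", "f2", "f3"]


def to_feature_rows(inputs: List[Inputs]) -> List[list]:
    """Return a 2D list: one row per input item, one column per feature."""
    # getattr() behaves the same in Pydantic V1 and V2
    return [[getattr(item, name) for name in FEATURE_ORDER] for item in inputs]


rows = to_feature_rows([Inputs(f1=1.0, f2=2.5, f3="a"), Inputs(f1=0.5, f2=3.0, f3="b")])
print(rows)                                 # [[1.0, 2.5, 'a'], [0.5, 3.0, 'b']]
print(np.asarray(rows, dtype=object).shape)  # (2, 3) -> (n_samples, n_features)
```

Listing the feature names explicitly makes the column order independent of how the Pydantic fields happen to be declared.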

    Option 1

    You could use the following (where i represents an item from the inputs list):

    # Getting prediction
    prediction = model.predict([[i.f1, i.f2, i.f3]])[0]
    
    # Getting probability
    probability = model.predict_proba([[i.f1, i.f2, i.f3]])
    

    Option 2

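    You could use the __dict__ attribute to get the values of all the model's fields and convert them into a list (the values come out in the order the fields are declared, which must match the feature order the model was trained with):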

    # Getting prediction
    prediction = model.predict([list(i.__dict__.values())])[0]
    
    # Getting probability
    probability = model.predict_proba([list(i.__dict__.values())])
    

    or, preferably, use Pydantic's dict() method (note: in Pydantic V2, dict() is deprecated in favour of model_dump()):

    # Getting prediction
    prediction = model.predict([list(i.dict().values())])[0]
    
    # Getting probability
    probability = model.predict_proba([list(i.dict().values())])
    

    Option 3

    Use a Pandas DataFrame as follows (again, in Pydantic V2 dict() has been replaced by model_dump()):

    import pandas as pd
    
    # Converting input data into a Pandas DataFrame
    df = pd.DataFrame([i.dict()])
    
    # Getting prediction
    prediction = model.predict(df)[0]
    
    # Getting probability
    probability = model.predict_proba(df)
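The DataFrame keeps the feature names as column names, which matters when the model was trained on a DataFrame and selects columns by name. The sketch below assumes such a model: a scikit-learn Pipeline whose ColumnTransformer one-hot encodes f3, trained on hypothetical toy data only so the example runs. In that setup the DataFrame input works, while the same values as a plain list of lists (as in Options 1 and 2) are rejected with a ValueError.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

# Hypothetical toy training data, only there to make the sketch runnable
train = pd.DataFrame({
    "f1": [0.1, 0.4, 2.3, 2.9, 0.2, 3.1],
    "f2": [1.0, 0.8, 3.5, 4.0, 1.2, 3.8],
    "f3": ["a", "b", "a", "b", "a", "b"],
})
labels = [0, 0, 1, 1, 0, 1]

# The ColumnTransformer picks the "f3" column by name; f1 and f2 pass through unchanged
preprocess = ColumnTransformer(
    [("cat", OneHotEncoder(handle_unknown="ignore"), ["f3"])],
    remainder="passthrough",
)
model = Pipeline([("prep", preprocess), ("clf", LogisticRegression())])
model.fit(train, labels)

# One request item, converted as in Option 3
df = pd.DataFrame([{"f1": 2.5, "f2": 3.6, "f3": "a"}])
prediction = model.predict(df)[0]
probability = model.predict_proba(df)  # shape (1, n_classes)
print(prediction, probability)

# A plain list of lists carries no column names,
# so selecting "f3" by name fails inside the ColumnTransformer
list_input_rejected = False
try:
    model.predict([[2.5, 3.6, "a"]])
except ValueError as exc:
    list_input_rejected = True
    print("List input rejected:", exc)
```

If your model was trained on a plain NumPy array instead, the list-based options work just as well.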
    

    Option 4

    You could avoid looping over individual items and calling the predict() function once per item by passing all the inputs to the model as a single batch (once again, in Pydantic V2, replace dict() with model_dump()):

    import pandas as pd
    
    df = pd.DataFrame([i.dict() for i in inputs])
    prediction = model.predict(df)
    probability = model.predict_proba(df)
    return {'prediction': prediction.tolist(), 'probability': probability.tolist()}
    

    or, if you'd rather not use a Pandas DataFrame:

    inputs_list = [list(i.dict().values()) for i in inputs]
    prediction = model.predict(inputs_list)
    probability = model.predict_proba(inputs_list)
    return {'prediction': prediction.tolist(), 'probability': probability.tolist()}