
How can I build an embedding encoder with FastAPI?


I just want to use a pre-trained open-source embedding model from SentenceTransformers to encode plain text.

The goal is to use Swagger as the GUI: type in a sentence and get its embedding back.

from fastapi import Depends, FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("./assets/BAAI/bge-small-en")

app = FastAPI()

class EmbeddingRequest(BaseModel):
    text: str
    
class EmbeddingResponse(BaseModel):
    embeddings: float

@app.post("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, model: embedding_model):
    embeddings_result = model.encode(request.text)
    return EmbeddingResponse(embeddings=embeddings_result)

Solution

  • Note that I don't have access to your "./assets/BAAI/bge-small-en" model, so I used all-mpnet-base-v2 instead.

    That said, there are two issues with your implementation:

    1. You're declaring the model instance as the type of an input parameter (model: embedding_model). That isn't a valid type annotation, and passing the model in isn't necessary anyway: just use the global embedding_model directly.
    2. The type of your embeddings field is wrong (unless your model really outputs a single float). embedding_model.encode returns an np.ndarray with one value per embedding dimension, which you can convert to a plain list of floats using embeddings_result.tolist().
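To see why a single float field can't hold the result and why .tolist() is needed, here is a small sketch. It does not load any model: a hand-made NumPy array stands in for the output of encode (the values are made up), and the standard json module stands in for FastAPI's own serialization.

```python
import json

import numpy as np

# Hypothetical stand-in for embedding_model.encode("some sentence"): for a single
# string, encode() returns a 1-D float32 array. These 3 values are made up; a real
# model returns hundreds of dimensions.
fake_embedding = np.array([0.12, -0.05, 0.33], dtype=np.float32)

# A raw ndarray is not JSON serializable...
try:
    json.dumps({"embeddings": fake_embedding})
    raw_serializable = True
except TypeError:
    raw_serializable = False

# ...but .tolist() yields a plain list of Python floats, which is.
embedding_list = fake_embedding.tolist()
payload = json.dumps({"embeddings": embedding_list})

print(raw_serializable)      # False
print(type(embedding_list))  # <class 'list'>
```

The vector has one entry per dimension, so the response model needs a list type such as List[float] rather than float.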

    The following works for me:

    from typing import List
    
    from fastapi import FastAPI
    from pydantic import BaseModel
    from sentence_transformers import SentenceTransformer
    
    # Loaded once at import time and shared by all requests
    embedding_model = SentenceTransformer("all-mpnet-base-v2")
    
    app = FastAPI()
    
    
    class EmbeddingRequest(BaseModel):
        text: str
    
    
    class EmbeddingResponse(BaseModel):
        # One float per embedding dimension
        embeddings: List[float]
    
    
    @app.post("/embeddings", response_model=EmbeddingResponse)
    async def get_embeddings(request: EmbeddingRequest):
        # encode() returns an np.ndarray; convert it to a list so it fits List[float]
        embeddings_result = embedding_model.encode(request.text)
        return EmbeddingResponse(embeddings=embeddings_result.tolist())