python multithreading session fastapi

How to assign a FastAPI Uvicorn thread to a session id?


I am integrating an AI model into a web app (this model needs to have some context in order to maintain a fluid conversation with the user) on a local deployment. The problem is the intrinsic structure of the thread.

I know how a thread pool works. The problem is that, when making multiple POST requests (for chatting with the bot), there is a possibility that a different thread from the one used so far answers a given request. The context will then be saved in the memory of that thread, not in the one that we have been using previously.

Main problem: context is being saved in different threads with different memory each.

Firstly, I want to mention that the solution should neither rely on cookies nor create a file for saving the context. The idea is to assign a thread per session_token.

I have tried the following:

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import AzureOpenAI
import os
from typing import Dict, List
from uuid import UUID, uuid4

# Load variables from a .env file into the process environment (read below via os.getenv).
load_dotenv()
# Router on which this module's endpoints are registered.
router = APIRouter()

class Message(BaseModel):
    """A single message of a conversation: who sent it and its text."""
    role: str  # either 'user' or 'assistant'
    content: str  # text of the message

class Prompt(BaseModel):
    """Request body of the chat endpoint: the session and the user's new text."""
    session_id: UUID  # key of the conversation history to continue
    prompt: str  # text the user sends to the model

class NewSessionResponse(BaseModel):
    """Response of the new-session endpoint."""
    session_id: UUID  # id generated for the new session
    
class ResetRequest(BaseModel):
    """Request body identifying a session by its id.

    NOTE(review): not used by any endpoint in this snippet; presumably
    meant for a reset endpoint -- confirm.
    """
    session_id: UUID

# In-memory store of conversation histories, keyed by session id.
# NOTE(review): this dict lives in the memory of a single Python process, so
# it is not shared if the app is served by more than one worker process --
# confirm against the deployment.
conversations: Dict[UUID, List[Message]] = {}

@router.post("/ai/chat")
async def chat(prompt: Prompt):
    """Send the user's prompt, together with the session history, to the model.

    The prompt and the model's reply are recorded in ``conversations`` only
    after the model call has succeeded, so a failed call does not leave an
    unanswered user message in the session history.

    Args:
        prompt: Request body with the session id and the user's text.

    Returns:
        A dict of the form ``{"response": <model reply>}``.

    Raises:
        HTTPException: status 500 when the API key or the endpoint URL is
            missing from the environment.
    """
    api_key = os.getenv("key")
    api_url = os.getenv("endpoint url")

    if not api_key or not api_url:
        raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')

    client = AzureOpenAI("here goes some parameters")

    # Fetch the history of this session, creating an empty one on first use.
    history = conversations.setdefault(prompt.session_id, [])

    # Build the request from the stored history plus the new message without
    # touching the stored history yet.
    user_message = Message(role="user", content=prompt.prompt)
    context = [
        {'role': msg.role, 'content': msg.content}
        for msg in [*history, user_message]
    ]

    # Request completion from the model
    response = client.chat.completions.create(
        model="gpt-35-turbo-4k-0613",
        messages=context
    )
    model_response = response.choices[0].message.content

    # Commit both turns together, only now that the call has succeeded.
    history.extend([user_message, Message(role='assistant', content=model_response)])
    return {"response": model_response}

@router.post("/ai/new_session", response_model=NewSessionResponse)
async def newSession():
    """Open a new session with an empty history and return its id."""
    fresh_id = uuid4()
    conversations[fresh_id] = []  # the session starts with no messages
    return NewSessionResponse(session_id=fresh_id)

I also tried to implement threading when managing requests, but it did not work. I suppose this happens because the Uvicorn threads and the threads created here are not the same.


Solution

  • I implemented a Redis database that stores the messages under the session ID that the endpoint receives.

    Note that I had to modify the docker-compose.yaml file to connect to the DB.

    Here you have the code:

    from fastapi import APIRouter, HTTPException
    from pydantic import BaseModel
    from dotenv import load_dotenv
    from openai import AzureOpenAI
    import os
    from typing import List
    from uuid import UUID, uuid4
    import json
    from app.core.redis_config import redis_db
    
    # Load variables from a .env file into the process environment (read below via os.getenv).
    load_dotenv()
    # Router on which this module's endpoints are registered.
    router = APIRouter()
    
    class Message(BaseModel):
        """A single message of a conversation: who sent it and its text."""
        role: str  # set to "user" or "assistant" by the chat endpoint
        content: str  # text of the message
    
    class Prompt(BaseModel):
        """Request body of the chat endpoint."""
        session_id: UUID  # key under which the conversation is stored
        prompt: str  # text the user sends to the model
        ai_model: str  # passed as the `model` argument of the completion request
    
    class NewSessionResponse(BaseModel):
        """Response of the new-session endpoint."""
        session_id: UUID  # id generated for the new session
    
    class ResetRequest(BaseModel):
        """Request body of the reset endpoint."""
        session_id: UUID  # session whose stored conversation should be deleted
    
    def get_conversation(session_id: UUID) -> List[Message]:
        """Load the messages stored for a session.

        Returns an empty list when nothing (or an empty value) is stored
        under the session id.
        """
        raw = redis_db.get(str(session_id))
        if not raw:
            return []

        # The value is a JSON array of message objects; rebuild each one.
        history = []
        for entry in json.loads(raw):
            history.append(Message(**entry))
        return history
    
    def save_conversation(session_id: UUID, messages: List[Message]):
        """Serialize the messages to JSON and store them under the session id.

        Any value previously stored under the same key is overwritten.

        Args:
            session_id: Session whose conversation is being stored.
            messages: Full conversation to store, in order.
        """
        # pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``;
        # use the new method when the installed version provides it and fall
        # back to the old one so the code also runs on pydantic v1.
        payload = [
            msg.model_dump() if hasattr(msg, "model_dump") else msg.dict()
            for msg in messages
        ]
        redis_db.set(str(session_id), json.dumps(payload))
    
    @router.post("/ai/chat")
    async def chat(prompt: Prompt):
        """Answer the user's prompt using the stored history of the session.

        Loads the conversation for ``prompt.session_id``, appends the user's
        message, asks the model named by ``prompt.ai_model`` for a completion,
        appends the reply and stores the updated conversation.

        Returns:
            A dict of the form ``{"response": <model reply>}``.

        Raises:
            HTTPException: status 500 when the API key or the endpoint URL is
                missing from the environment.
        """
        api_key = os.getenv("OPENAI_API_KEY", "")
        api_url = os.getenv("OPENAI_API_BASE", "")
        if not (api_key and api_url):
            raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')

        client = AzureOpenAI("some parameters")

        history = get_conversation(prompt.session_id)
        history.append(Message(role="user", content=prompt.prompt))

        # The whole history is sent so the model sees the earlier turns.
        completion = client.chat.completions.create(
            model=prompt.ai_model,
            messages=[{'role': m.role, 'content': m.content} for m in history],
        )
        reply = completion.choices[0].message.content

        # Persist only after the model has answered.
        history.append(Message(role='assistant', content=reply))
        save_conversation(prompt.session_id, history)

        return {"response": reply}
    
    @router.post("/ai/reset")
    async def resetConversation(reset_request: ResetRequest):
        """Delete the stored conversation of a session.

        Args:
            reset_request: Request body carrying the session id.

        Returns:
            ``{"response": True}`` when a stored conversation was deleted.

        Raises:
            HTTPException: status 400 when nothing is stored under the
                session id.
        """
        # Delete in a single call and use its result instead of checking
        # exists() first: this avoids a second round trip and the gap between
        # the check and the delete.
        # NOTE(review): assumes redis_db is a redis-py client, whose delete()
        # returns the number of keys removed -- confirm.
        removed = redis_db.delete(str(reset_request.session_id))
        if not removed:
            raise HTTPException(status_code=400, detail='Invalid session ID')
        return {"response": True}
    
    @router.post("/ai/new_session", response_model=NewSessionResponse)
    async def newSession():
        """Create a session id, store an empty conversation under it and return it."""
        fresh_id = uuid4()
        save_conversation(fresh_id, [])  # register the session with no messages yet
        return NewSessionResponse(session_id=fresh_id)