Tags: javascript, python, websocket, fastapi

FastAPI - Websockets, Multiple connections for one client


I'm trying to implement WebSockets in my FastAPI application. However, when I connect to the WebSocket from the JavaScript side, it opens 4 connections. I have implemented a workaround on the backend that checks whether a given customer is already connected; however, this means a client couldn't, for example, connect to the WebSocket on mobile while it's open on their computer.

class ConnectionManager:
    """Track at most one open WebSocket per customer.

    A second connection for a customer who already has a registered
    connection is closed and rejected with ExistingConnectionError.
    """

    def __init__(self):
        # customer_id -> the single accepted WebSocket for that customer
        self.connections: dict[int, WebSocket] = {}

    async def connect(self, websocket, customer_id):
        """Accept *websocket* and register it for *customer_id*.

        Raises:
            ExistingConnectionError: if *customer_id* already has a registered
                connection; *websocket* is closed without being accepted.
        """
        if customer_id in self.connections:
            await websocket.close()
            raise ExistingConnectionError("Existing connection for customer_id")

        await websocket.accept()
        # Remember the owner on the socket itself so disconnect() can find
        # the matching entry later.
        websocket.customer_id = customer_id
        self.connections[customer_id] = websocket

    async def disconnect(self, websocket):
        """Unregister *websocket*; a no-op if it was never registered."""
        customer_id = getattr(websocket, "customer_id", None)
        # Only drop the entry if it is this very socket, so a stray call can
        # never evict a different live connection for the same customer.
        if self.connections.get(customer_id) is websocket:
            del self.connections[customer_id]

    async def broadcast(self, message):
        """Send *message* as JSON to every registered connection."""
        # Iterate over a snapshot: each send_json() awaits, and a disconnect()
        # running in the meantime would otherwise mutate the dict mid-loop
        # ("RuntimeError: dictionary changed size during iteration").
        for websocket in list(self.connections.values()):
            await websocket.send_json(message)

    async def broadcast_to_customer(self, customer_id, message):
        """Send *message* as JSON to *customer_id*'s connection, if any."""
        websocket = self.connections.get(customer_id)
        if websocket is not None:
            await websocket.send_json(message)


connection_manager = ConnectionManager()


@router.websocket("/sock")
async def websocket_endpoint(websocket: WebSocket, customer_id: int):
    """Serve one customer's WebSocket at /sock.

    The socket is registered with ``connection_manager``. If the customer is
    already registered, ``connect`` raises ExistingConnectionError and the new
    socket is rejected.
    """
    try:
        await connection_manager.connect(websocket, customer_id)
    except ExistingConnectionError:
        print("Existing connection detected, rejecting the new connection")
        return

    try:
        while True:
            # Incoming messages are currently read and discarded.
            await websocket.receive_json()
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister. Previously only WebSocketDisconnect triggered
        # cleanup, so any other exception from the loop left a stale entry,
        # and connect() would then reject this customer forever.
        await connection_manager.disconnect(websocket)

JavaScript:

let socket;

    // NOTE(review): `socket` must be declared once at module scope (outside
    // whatever function/component runs this snippet) for the guard below to
    // have any effect. If this whole snippet re-runs, `socket` is always
    // undefined here and every run opens a new connection.
    //
    // Reuse a socket that is OPEN or still CONNECTING; only open a new one when
    // none exists or the previous one is CLOSING/CLOSED. Checking only for OPEN
    // lets a second call made during the handshake open a duplicate connection.
    if (!socket || socket.readyState === WebSocket.CLOSING || socket.readyState === WebSocket.CLOSED) {
        socket = new WebSocket(`ws://localhost:3002/sock?customer_id=${encodeURIComponent(customer_id)}`);

        socket.onopen = () => {
            console.log('Connected to WebSocket server');
        };

        let i = 0;

        socket.onmessage = (event) => {
            const data = JSON.parse(event.data);
            console.log('Incoming data:', data);
            console.log('i:', i);
            i = i + 1;
        };
    }

Backend logs:

INFO:     connection open
INFO:     ('127.0.0.1', 33366) - "WebSocket /sock?customer_id=185" 403
Existing connection detected, rejecting the new connection
INFO:     connection rejected (403 Forbidden)
INFO:     connection closed
INFO:     ('127.0.0.1', 33368) - "WebSocket /sock?customer_id=185" 403
Existing connection detected, rejecting the new connection
INFO:     connection rejected (403 Forbidden)
INFO:     connection closed
INFO:     ('127.0.0.1', 33384) - "WebSocket /sock?customer_id=185" 403
Existing connection detected, rejecting the new connection
INFO:     connection rejected (403 Forbidden)
INFO:     connection closed

Frontend logs:

Connected to WebSocket server
Firefox can’t establish a connection to the server at ws://localhost:3002/sock?customer_id=185.
Firefox can’t establish a connection to the server at ws://localhost:3002/sock?customer_id=185.
Firefox can’t establish a connection to the server at ws://localhost:3002/sock?customer_id=185.

I tried to implement a WebSocket in FastAPI, but it opens multiple connections per client instead of one.


Solution

  • I think you need something like this:

    On the backend, store a connection for every device of every client:

    class ConnectionManager:
        """Track one WebSocket per (customer, device) pair.

        A customer may be connected from several devices at once; a second
        connection from the same device is closed and rejected with
        ExistingConnectionError.
        """

        def __init__(self):
            # customer_id -> {device_hash -> accepted WebSocket}
            self.clients: dict[int, dict[str, WebSocket]] = {}

        async def connect(self, websocket, customer_id, device_hash):
            """Accept *websocket* and register it for (*customer_id*, *device_hash*).

            Raises:
                ExistingConnectionError: if this device of this customer already
                    has a registered connection; *websocket* is closed without
                    being accepted.
            """
            if device_hash in self.clients.get(customer_id, {}):
                await websocket.close()
                raise ExistingConnectionError("Existing connection for customer_id")

            await websocket.accept()
            # Remember the owner on the socket so disconnect() can find it.
            websocket.customer_id = customer_id
            websocket.device_hash = device_hash
            # Register only after a successful accept(), so a failed handshake
            # leaves no (empty) per-customer entry behind.
            self.clients.setdefault(customer_id, {})[device_hash] = websocket

        async def disconnect(self, websocket):
            """Unregister *websocket*; a no-op if it was never registered."""
            customer_id = getattr(websocket, "customer_id", None)
            device_hash = getattr(websocket, "device_hash", None)
            devices = self.clients.get(customer_id)
            # Only remove the entry if it is this very socket.
            if devices is None or devices.get(device_hash) is not websocket:
                return
            del devices[device_hash]
            if not devices:
                # Drop the customer once their last device leaves, so the
                # mapping does not keep an empty dict for every customer seen.
                del self.clients[customer_id]

        async def broadcast(self, message):
            """Send *message* as JSON to every registered connection."""
            # Snapshot before awaiting: a disconnect() running during a
            # send_json() would otherwise mutate the dicts mid-iteration.
            targets = [ws for devices in self.clients.values() for ws in devices.values()]
            for websocket in targets:
                await websocket.send_json(message)

        async def broadcast_to_customer(self, customer_id, message):
            """Send *message* as JSON to every device of *customer_id*, if any."""
            # Fixed: previously read `self.client`, which raised AttributeError.
            for websocket in list(self.clients.get(customer_id, {}).values()):
                await websocket.send_json(message)
    
    
    connection_manager = ConnectionManager()


    @app.websocket("/sock")
    async def websocket_endpoint(websocket: WebSocket, customer_id: int, device_hash: str):
        """Serve one device's WebSocket for a customer at /sock.

        The socket is registered with ``connection_manager`` under
        (customer_id, device_hash). If that device is already registered,
        ``connect`` raises ExistingConnectionError and the new socket is
        rejected.
        """
        try:
            await connection_manager.connect(websocket, customer_id, device_hash)
        except ExistingConnectionError:
            print("Existing connection detected, rejecting the new connection")
            return

        try:
            while True:
                # Incoming messages are currently read and discarded.
                await websocket.receive_json()
        except WebSocketDisconnect:
            pass
        finally:
            # Always unregister. Previously only WebSocketDisconnect triggered
            # cleanup, so any other exception from the loop left a stale entry,
            # and connect() would then reject this device forever.
            await connection_manager.disconnect(websocket)
    

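    When you connect to the WebSocket from the frontend, add an additional query parameter, device_hash, which should be unique for every user device. You can also generate device_hash on the server side from the request headers (e.g., from the User-Agent). Note, however, that two devices with the same browser can send an identical User-Agent, so it is not guaranteed to be unique.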