I am writing a simple websocket script that registers and unregisters clients and then broadcasts a random message to them every 5 seconds. This is the code:
import asyncio, websockets, random, string
import websockets.asyncio.server

class web_socket(websockets.asyncio.server.ServerConnection):
    pass

connections: set[web_socket] = set()
connections_lock = asyncio.Lock()

async def register(websocket: web_socket):
    async with connections_lock:
        connections.add(websocket)

async def unregister(websocket: web_socket):
    async with connections_lock:
        connections.discard(websocket)

async def cleanup_connections():
    while True:
        async with connections_lock:
            closed_clients = [client for client in connections if client.close_code == 1000]
            for client in closed_clients:
                connections.discard(client)
        await asyncio.sleep(1)

async def handler(websocket: web_socket):
    await register(websocket)
    try:
        await websocket.wait_closed()
    except Exception as e:
        print(f"Error: {e}")

async def random_messages():
    while True:
        message = '-'.join(random.choices(string.ascii_letters + string.digits, k=10))
        async with connections_lock:
            send_tasks = [client.send(message) for client in connections if client.close_code != 1000]
            if send_tasks:
                await asyncio.gather(*send_tasks)
                print(f"Broadcasted message: {message}")
        await asyncio.sleep(5)

async def respond_to_messages():
    while True:
        async with connections_lock:
            if connections:
                for client in connections:
                    message = await client.recv()
                    if message:
                        print(message)
            else: await asyncio.sleep(1)

async def main():
    async with websockets.serve(handler, "0.0.0.0", 8765):
        asyncio.create_task(random_messages())
        asyncio.create_task(cleanup_connections())
        asyncio.create_task(respond_to_messages())
        await asyncio.Future()

asyncio.run(main())
As soon as someone connects, the broadcasting stops and the loop gets stuck at the message await. How do I make sure that listening for messages sent by clients does not affect the main loop?
I tried this:
message = await asyncio.wait(client.recv())
Now the client can connect without stopping the main loop but it throws an error:
Task exception was never retrieved
future: <Task finished name='Task-5' coro=<respond_to_messages() done, defined at /home/j/socket-test/server.py:44> exception=TypeError('expect a list of futures, not coroutine')>
Traceback (most recent call last):
File "/home/j/socket-test/server.py", line 49, in respond_to_messages
message = await asyncio.wait(client.recv())
File "/usr/lib/python3.10/asyncio/tasks.py", line 366, in wait
raise TypeError(f"expect a list of futures, not {type(fs).__name__}")
TypeError: expect a list of futures, not coroutine
/usr/lib/python3.10/asyncio/base_events.py:1910: RuntimeWarning: coroutine 'Connection.recv' was never awaited
handle = None # Needed to break cycles when an exception occurs.
RuntimeWarning: Enable tracemalloc to get the object allocation traceback
It throws this error and keeps going. The client gets all the broadcast messages, but whatever the client sends is never printed (and may not be received at all). I know that asyncio.wait() expects a list of futures, so I tried passing one explicitly: asyncio.wait([client.recv()]). Running this gave a deprecation warning:
DeprecationWarning: The explicit passing of coroutine objects to asyncio.wait() is deprecated since Python 3.8, and scheduled for removal in Python 3.11.
message = await asyncio.wait([client.recv()])
and the client must send a message every time to receive a single broadcast message.
If I do this instead: asyncio.wait([await client.recv()]), then I get this error:
raise TypeError('An asyncio.Future, a coroutine or an awaitable ' TypeError: An asyncio.Future, a coroutine or an awaitable is required
and the client must send one message to start receiving the broadcasts, after which they keep coming.
I tried GPT, but it just keeps telling me to use message = await client.recv(), which was the original problem.
I also tried to run the original code with coroutine threadsafe:
async def main():
    async with websockets.serve(handler, "0.0.0.0", 8765):
        asyncio.create_task(random_messages())
        asyncio.create_task(cleanup_connections())
        loop = asyncio.get_running_loop()
        asyncio.run_coroutine_threadsafe(respond_to_messages(), loop=loop)
        await asyncio.Future()
While this did not produce any errors, the client still needs to send a first message before it starts receiving broadcasts. If this is the way to go, that would be good to know. Another thing to note: only the first client has to send a message; the others start receiving broadcasts as soon as they connect.
Any help is appreciated. If you can point me to any article or something, that's fine too. I am a beginner so please keep that in mind.
In respond_to_messages, you're basically holding the lock forever. Inside the top-level loop you grab the lock and then start iterating through the connections; on each connection you await client.recv(), which blocks (holding the lock) until a new message arrives. You need to release the lock before you call .recv().
This raises the question of what exactly the lock is protecting. Let's say it's the connection list itself – you want to avoid overlapping connections.add() calls – but each individual websocket object can probably be used concurrently.
You usually want to spend the absolute minimum amount of time holding the lock. In this case that means you'll want to make a copy of the connection list inside the lock, and then iterate through it, acknowledging that you might not have the absolute-most-current values.
async def respond_to_messages():
    while True:
        async with connections_lock:
            my_connections = list(connections)
        for client in my_connections:
            message = await client.recv()
            if message:
                print(message)
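To see the difference without running a websocket server, here is a minimal sketch that uses only asyncio. The names receiver and broadcaster_can_run are made up for this illustration, and an asyncio.Event that is never set stands in for a client that never sends anything. It checks whether a second task can acquire the lock while the receiver is waiting.

```python
import asyncio


async def receiver(lock: asyncio.Lock, clients: set,
                   never_arrives: asyncio.Event, hold_lock: bool) -> None:
    """Wait for a message that never comes, with or without holding the lock."""
    if hold_lock:
        async with lock:
            for _ in clients:
                await never_arrives.wait()  # the lock stays held while we wait
    else:
        async with lock:
            snapshot = list(clients)        # copy under the lock, then release it
        for _ in snapshot:
            await never_arrives.wait()      # waiting no longer blocks other tasks


async def broadcaster_can_run(hold_lock: bool) -> bool:
    """Return True if another task can take the lock while the receiver waits."""
    lock = asyncio.Lock()
    never_arrives = asyncio.Event()
    task = asyncio.create_task(
        receiver(lock, {"client-1"}, never_arrives, hold_lock)
    )
    await asyncio.sleep(0)                  # give the receiver a chance to start
    try:
        await asyncio.wait_for(lock.acquire(), timeout=0.1)
    except asyncio.TimeoutError:
        acquired = False
    else:
        acquired = True
        lock.release()
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)
    return acquired


if __name__ == "__main__":
    print("lock held across await:", asyncio.run(broadcaster_can_run(True)))
    print("snapshot, then await:  ", asyncio.run(broadcaster_can_run(False)))
```

With hold_lock=True the second task times out waiting for the lock, which matches the stalled broadcast in the question; with the snapshot approach it gets the lock right away.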
The next problem you'll run into is conceptually similar: because you wait for messages serially, and client.recv() is blocking, you'll wait for the first socket to send a message back before doing anything on the next one. You'd really like all of these receives to happen together.
You can write a function that handles all of the traffic on a single websocket:
from websockets.asyncio.server import ServerConnection

async def respond_to_connection(client: ServerConnection) -> None:
    """Handle all of the messages on a client and return when the websocket closes."""
    async for message in client:
        print(message)
Now you need to create a new concurrent copy of that function every time you accept a new websocket. You're already using asyncio.create_task() and that will work here; an asyncio.TaskGroup is potentially a better tool. Note that you need to save a copy of the returned task, and I might keep that together with the websocket.
from asyncio import Lock, Task, create_task
from dataclasses import dataclass

@dataclass
class ActiveConnection:
    """A websocket and the task handling it."""
    client: web_socket
    task: Task

connections: dict[web_socket, ActiveConnection] = {}
connections_lock = Lock()

async def register(websocket: web_socket):
    # Start the per-connection task before taking the lock.
    task = create_task(respond_to_connection(websocket))
    active_connection = ActiveConnection(websocket, task)
    async with connections_lock:
        connections[websocket] = active_connection

async def unregister(websocket: web_socket):
    async with connections_lock:
        active_connection = connections.pop(websocket, None)
    # Cancel outside the lock; this has no effect if the task already finished.
    if active_connection:
        active_connection.task.cancel()
So every time we get a new connection, we start a new task and save the task and connection together; and every time we shut down a connection, we cancel the task (assuming it hasn't completed on its own). Again notice that we're doing as much work as possible outside the async with connections_lock critical block, and limiting that block to only inserting and deleting from the dictionary.
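If you want to try the one-task-per-connection idea without a real server, the following is a rough sketch under some assumptions: FakeClient and Registry are hypothetical helpers invented for this example (they are not part of the websockets library), and an asyncio.Queue plays the role of incoming messages. It shows that a client that stays silent no longer holds up a client that does send something.

```python
import asyncio


class FakeClient:
    """Hypothetical stand-in for a websocket connection (not the websockets API)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.inbox = asyncio.Queue()  # messages "sent" by this client

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        # Like iterating a real connection: wait until the next message arrives.
        return await self.inbox.get()


class Registry:
    """Hypothetical container for the lock, the per-client tasks and a log."""

    def __init__(self) -> None:
        self.lock = asyncio.Lock()
        self.tasks = {}   # client -> task handling that client
        self.log = []     # what print() would have shown

    async def respond_to_connection(self, client: FakeClient) -> None:
        async for message in client:
            self.log.append(f"{client.name}: {message}")

    async def register(self, client: FakeClient) -> None:
        # Create the task outside the lock and keep a reference to it.
        task = asyncio.create_task(self.respond_to_connection(client))
        async with self.lock:
            self.tasks[client] = task

    async def unregister(self, client: FakeClient) -> None:
        async with self.lock:
            task = self.tasks.pop(client, None)
        if task is not None:
            task.cancel()
            # Wait for the cancellation to finish so nothing is left dangling.
            await asyncio.gather(task, return_exceptions=True)


async def demo() -> list:
    registry = Registry()
    quiet, chatty = FakeClient("quiet"), FakeClient("chatty")
    await registry.register(quiet)    # registered first, never sends anything
    await registry.register(chatty)
    await chatty.inbox.put("hello")
    await asyncio.sleep(0.01)         # let the per-client tasks run
    await registry.unregister(quiet)
    await registry.unregister(chatty)
    return registry.log


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

In a real server you would call unregister from a finally block in the connection handler, so the task is cancelled even when the connection closes with an error.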