I'm using FastAPI with WebSockets to "push" SVGs to the client. The problem is that if the iterations run continuously, they block the async event loop, and the socket therefore can't listen for other messages.
Running the loop as a background task is not suitable, because each iteration is CPU heavy and the data must be returned to the client.
Is there a different approach, or will I need to trigger each step from the client? I thought multiprocessing could work, but I'm not sure how it would work with asynchronous code like await websocket.send_text().
    @app.websocket("/ws")
    async def read_websocket(websocket: WebSocket) -> None:
        await websocket.accept()
        while True:
            data = await websocket.receive_text()

            async def run_continuous_iterations():
                # needed to run the steps until the user sends "stop"
                while True:
                    svg_string = get_step_data()
                    await websocket.send_text(svg_string)

            if data == "status":
                await run_continuous_iterations()
            # this code can't run if the event loop is blocked by run_continuous_iterations
            if data == "stop":
                is_running = False
                print("Stopping process")
"...each iteration is CPU heavy and the data must be returned to the client".
As described in this answer, a "coroutine suspends its execution only when it explicitly requests to be suspended", for example, when there is an await call to an asynchronous operation/function, normally to non-blocking I/O-bound tasks such as the ones described here. (Note: FastAPI/Starlette runs I/O-bound methods, such as reading File contents, in an external threadpool using the async run_in_threadpool() function, and awaits them; hence, calling such File operations from your async def endpoint, e.g., await file.read(), won't block the event loop. Have a look at the linked answer above for more details.) This, however, does not apply to blocking I/O-bound or CPU-bound operations, such as the ones mentioned here. Running such operations directly inside an async def endpoint will block the event loop, and hence, any further client requests (or, in your case, any further websocket messages) will be blocked until the blocking operation is completed.
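To make this concrete, here is a minimal, standard-library-only sketch (cpu_heavy(), ticker() and the two step coroutines are hypothetical stand-ins, not part of the question's code). A coroutine that calls a CPU-bound function directly keeps the event loop busy until that function returns, whereas awaiting the same call through loop.run_in_executor() suspends the coroutine and lets other tasks, such as a receive loop, run in the meantime.

```python
import asyncio


def cpu_heavy(n: int) -> int:
    # Hypothetical CPU-bound step: a pure-Python loop that never awaits.
    total = 0
    for i in range(n):
        total += i * i
    return total


async def ticker(events: list, ticks: int) -> None:
    # Stands in for other work on the loop, e.g. a receive_text() loop.
    for _ in range(ticks):
        events.append("tick")
        await asyncio.sleep(0.01)


async def blocking_step(events: list) -> None:
    # Direct call: the loop cannot switch to another task until it returns.
    cpu_heavy(500_000)
    events.append("done")


async def offloaded_step(events: list) -> None:
    # Default ThreadPoolExecutor: the await suspends this coroutine,
    # so the loop is free to run ticker() while the thread computes.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, cpu_heavy, 500_000)
    events.append("done")


async def demo(step) -> list:
    events: list = []
    # The step is scheduled before the ticker, so it gets the first turn.
    await asyncio.gather(step(events), ticker(events, 3))
    return events


print(asyncio.run(demo(blocking_step)))   # ['done', 'tick', 'tick', 'tick']
print(asyncio.run(demo(offloaded_step)))  # starts with 'tick'
```

In the blocking case the ticker only gets its first turn after the whole computation has finished, which is exactly what happens to websocket.receive_text() in the question's code.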
Additionally, from the code snippet you provided, it seems that you would like to send data back to the client while, at the same time, listening for new messages (to check whether the client sent a "stop" message, in order to stop the process). Thus, awaiting the operation to complete is not the way to go; instead, executing that task in a separate thread or process (if it is a CPU-bound task), as demonstrated in this answer but without awaiting it, should be more suitable. (Note: processes have their own memory, and hence, sharing websocket connections among multiple processes is not natively feasible; have a look here and here for available options on that.) Solutions using a separate thread are given below.
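One alternative worth considering, sketched here under stated assumptions: keep run_continuous_iterations() on the server's own event loop by starting it with asyncio.create_task() (without awaiting it), and offload only the blocking get_step_data() call with loop.run_in_executor(). Since asyncio objects are generally not thread-safe, this avoids calling websocket.send_text() from a second event loop in another thread. To keep the example runnable without a server, FakeWebSocket is a hypothetical in-memory stand-in exposing only receive_text() and send_text(), get_step_data() is a dummy CPU-bound function, and an asyncio.Event replaces the is_running flag.

```python
import asyncio
from typing import Optional


class FakeWebSocket:
    """Hypothetical in-memory stand-in exposing only receive_text()/send_text()."""

    def __init__(self) -> None:
        self.incoming: asyncio.Queue = asyncio.Queue()
        self.sent: list = []

    async def receive_text(self) -> str:
        return await self.incoming.get()

    async def send_text(self, text: str) -> None:
        self.sent.append(text)


def get_step_data(step: int) -> str:
    # Hypothetical CPU-bound step; the real get_step_data() is the asker's own.
    checksum = sum(i * i for i in range(20_000))
    return f"<svg><!-- step {step}: {checksum} --></svg>"


async def run_continuous_iterations(websocket, stop: asyncio.Event) -> None:
    loop = asyncio.get_running_loop()
    step = 0
    while not stop.is_set():
        # Only the blocking computation goes to a worker thread...
        svg_string = await loop.run_in_executor(None, get_step_data, step)
        # ...and the result is sent from the loop that owns the connection.
        await websocket.send_text(svg_string)
        step += 1


async def websocket_endpoint(websocket) -> None:
    stop = asyncio.Event()
    task: Optional[asyncio.Task] = None
    while True:
        data = await websocket.receive_text()  # keeps listening while the task runs
        if data == "status" and task is None:
            stop.clear()
            # Started but NOT awaited, so this loop goes back to receive_text().
            task = asyncio.create_task(run_continuous_iterations(websocket, stop))
        elif data == "stop" and task is not None:
            stop.set()
            await task  # the current step finishes, then the iteration loop exits
            task = None


async def demo() -> tuple:
    ws = FakeWebSocket()
    server = asyncio.create_task(websocket_endpoint(ws))
    await ws.incoming.put("status")
    await asyncio.sleep(0.2)   # several SVGs are pushed meanwhile
    await ws.incoming.put("stop")
    await asyncio.sleep(0.1)   # "stop" is received even though iterations are running
    sent_at_stop = len(ws.sent)
    await asyncio.sleep(0.1)   # nothing else should be sent after the stop
    server.cancel()            # stands in for the client disconnecting
    try:
        await server
    except asyncio.CancelledError:
        pass
    return sent_at_stop, len(ws.sent)


print(asyncio.run(demo()))
```

To use this with FastAPI, the endpoint would take the real WebSocket, call await websocket.accept() first, and, on WebSocketDisconnect, set the stop event (or cancel the task) so the iterations don't keep running for a closed connection. A pure-Python get_step_data() running in a thread still competes for the GIL; if that becomes a bottleneck, a concurrent.futures.ProcessPoolExecutor can be passed to run_in_executor() instead (provided the function and its arguments are picklable), since only the returned string has to cross the process boundary and the websocket itself stays in the main process.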
asyncio's loop.run_in_executor() with the default ThreadPoolExecutor
Passing None as the executor argument in loop.run_in_executor() causes the default executor, which is a ThreadPoolExecutor, to be used. Please have a look at this answer for more details and the difference between using the default executor and a custom ThreadPoolExecutor (as shown later on in this answer).
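A quick way to see what passing None does: the callable below is executed in a worker thread of the loop's default ThreadPoolExecutor, not in the thread running the event loop. run_in_executor() forwards extra positional arguments only, so keyword arguments have to be bound with functools.partial, as the asyncio documentation recommends. blocking_work() is a hypothetical placeholder.

```python
import asyncio
import functools
import threading


def blocking_work(label: str, repeat: int = 1) -> tuple:
    # Return a result plus the identity of the thread that executed the call.
    return label * repeat, threading.get_ident()


async def main() -> tuple:
    loop = asyncio.get_running_loop()
    loop_thread = threading.get_ident()
    # executor=None selects the loop's default ThreadPoolExecutor.
    # Keyword arguments are bound with functools.partial.
    result, worker_thread = await loop.run_in_executor(
        None, functools.partial(blocking_work, "svg", repeat=2)
    )
    return result, worker_thread != loop_thread


print(asyncio.run(main()))  # ('svgsvg', True)
```

On Python 3.9+, asyncio.to_thread(func, *args, **kwargs) is a shorthand that also uses the default executor and accepts keyword arguments directly.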
    from fastapi import FastAPI, WebSocket, WebSocketDisconnect
    from websockets.exceptions import ConnectionClosed
    import asyncio

    app = FastAPI()

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        is_running = True
        await websocket.accept()

        try:
            while True:
                data = await websocket.receive_text()

                async def run_continuous_iterations():
                    # reads the enclosing is_running, so a later "stop" ends the loop
                    while is_running:
                        svg_string = get_step_data()  # synchronous/blocking function
                        await websocket.send_text(svg_string)

                if data == "status":
                    is_running = True
                    loop = asyncio.get_running_loop()
                    # not awaited, so the while loop keeps receiving messages
                    loop.run_in_executor(None, lambda: asyncio.run(run_continuous_iterations()))

                if data == "stop":
                    is_running = False
                    print("Stopping process")

        except (WebSocketDisconnect, ConnectionClosed):
            is_running = False
            print("Client disconnected")
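It may help to note what lambda: asyncio.run(run_continuous_iterations()) does in the snippet above: asyncio.run() creates a brand-new event loop in the worker thread, runs the coroutine on it, and closes that loop when the coroutine returns. So the iterations, including each websocket.send_text() call, run on a different loop from the one that accepted the connection. A minimal standard-library check (current_loop() is a hypothetical helper):

```python
import asyncio


async def current_loop() -> asyncio.AbstractEventLoop:
    # Return whichever event loop is running this coroutine.
    return asyncio.get_running_loop()


async def main() -> tuple:
    outer = asyncio.get_running_loop()
    # Same pattern as the snippet above: asyncio.run() inside an executor thread.
    inner = await outer.run_in_executor(None, lambda: asyncio.run(current_loop()))
    # asyncio.run() created a fresh loop in that thread and closed it on return.
    return inner is outer, inner.is_closed()


print(asyncio.run(main()))  # (False, True)
```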
asyncio's loop.run_in_executor() with a custom ThreadPoolExecutor
    import concurrent.futures
    # ... rest of the code is the same as above

    @app.on_event("startup")
    def startup_event():
        # instantiate the ThreadPool
        app.state.pool = concurrent.futures.ThreadPoolExecutor()

    @app.on_event("shutdown")
    def shutdown_event():
        # terminate the ThreadPool
        app.state.pool.shutdown()

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        # ... rest of the code is the same as above
        try:
            while True:
                # ... rest of the code is the same as above
                if data == "status":
                    is_running = True
                    loop = asyncio.get_running_loop()
                    loop.run_in_executor(app.state.pool, lambda: asyncio.run(run_continuous_iterations()))
                # ... rest of the code is the same as above
        except (WebSocketDisconnect, ConnectionClosed):
            # ... rest of the code is the same as above
Note that the startup and shutdown event handlers have been deprecated and might be completely removed in the future. Thus, one should rather use the recently introduced lifespan event handler to instantiate the ThreadPoolExecutor at application startup and make it available through request.state (by yielding a state dictionary from the lifespan handler), as demonstrated in this answer and the linked answer above. In the case of websockets, request.state can be accessed using websocket.state, as shown below. You can adjust the number of worker threads as required (see the linked answer above for more details on ThreadPoolExecutor and the maximum/optimal number of threads).
    from fastapi import FastAPI
    from contextlib import asynccontextmanager
    import concurrent.futures

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=20)
        yield {'pool': pool}  # exposed as request.state.pool / websocket.state.pool
        pool.shutdown()

    app = FastAPI(lifespan=lifespan)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        # ... rest of the code is the same as above
        try:
            while True:
                # ... rest of the code is the same as above
                if data == "status":
                    is_running = True
                    loop = asyncio.get_running_loop()
                    loop.run_in_executor(websocket.state.pool, lambda: asyncio.run(run_continuous_iterations()))
                # ... rest of the code is the same as above
        except (WebSocketDisconnect, ConnectionClosed):
            # ... rest of the code is the same as above
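As a rough illustration of what max_workers controls, the sketch below submits four blocking calls to a dedicated pool capped at two threads: at most two distinct worker threads ever run them, and the remaining calls wait in the pool's queue. With the pattern above, each "status" message occupies one worker for as long as its loop keeps running, so with max_workers=20 a 21st concurrent loop would have to wait for a free thread. slow_step() and the "svg-worker" thread-name prefix are hypothetical.

```python
import asyncio
import concurrent.futures
import threading
import time


def slow_step(i: int) -> str:
    time.sleep(0.05)  # hypothetical blocking step
    return threading.current_thread().name


async def main() -> set:
    loop = asyncio.get_running_loop()
    # A dedicated, bounded pool: at most 2 of these calls run at the same time.
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=2, thread_name_prefix="svg-worker"
    ) as pool:
        names = await asyncio.gather(
            *(loop.run_in_executor(pool, slow_step, i) for i in range(4))
        )
    return set(names)


print(sorted(asyncio.run(main())))
```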
threading's Thread
    import threading
    # ... rest of the code is the same as above

    if data == "status":
        is_running = True
        thread = threading.Thread(target=lambda: asyncio.run(run_continuous_iterations()))
        thread.start()

    # ... rest of the code is the same as above
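In the snippets above, is_running is a plain variable that is read from the worker thread and written from the event loop thread. A threading.Event is one explicit alternative for that kind of cross-thread stop signal, and its wait() also lets the worker pause between steps while still reacting immediately to a stop. The sketch below is standalone: run_steps() and the SVG strings are hypothetical placeholders for the real iteration, and daemon=True is an assumption that keeps a forgotten worker from blocking interpreter exit.

```python
import threading
import time


def run_steps(stop: threading.Event, results: list) -> None:
    step = 0
    while not stop.is_set():
        # Hypothetical stand-in for get_step_data(); a real step would be CPU-bound.
        results.append(f"<svg><!-- step {step} --></svg>")
        step += 1
        stop.wait(0.01)  # pause between steps, but wake up at once when stopped


stop = threading.Event()
results: list = []
worker = threading.Thread(target=run_steps, args=(stop, results), daemon=True)
worker.start()
time.sleep(0.05)   # let a few steps run
stop.set()         # what the "stop" message handler would do
worker.join(timeout=1)
print(worker.is_alive(), len(results) > 0)  # False True
```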