I only want to allow websocket connections when they come from a specific website.
I have a CORS layer in my axum server setup, but it only seems to apply to HTTPS, not WSS. All websocket connections get through, regardless of "origin". This seems to be correct behavior for websockets.
So how do I do the origin filtering for websockets?
I figured that I can check the Origin header in my handler, but then every handler needs to have this check (I admit that there is only one at the moment). Is it possible to put this filter/block on the router?
Original CORS filtering:
let app = Router::new()
    .route("/client/{session_id}", get(client_handler))
    .layer(
        CorsLayer::new()
            .allow_origin("https://www.example.com".parse::<HeaderValue>().unwrap())
            .allow_methods([Method::GET, Method::POST])
            .allow_headers([CONTENT_TYPE]),
    );
Handler workaround, using axum-extra with the typed-header feature:
async fn client_handler(
    ws: WebSocketUpgrade,
    extract::Path(session_id): extract::Path<String>,
    origin: TypedHeader<headers::Origin>,
) -> Result<impl IntoResponse, StatusCode> {
    let origin_host = origin.hostname();
    // check origin_host value
You're correct that web browsers do not apply CORS when opening WebSockets, and that a workaround is to check the Origin header on the server. You can add a middleware, like CorsLayer but customized, to avoid duplicating the check in each WebSocket route.
use axum::{
    Router,
    body::Body,
    extract::{Request, WebSocketUpgrade},
    http::{self, HeaderValue, Method, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
};
use tower_http::cors::CorsLayer;

const ORIGIN: &str = "https://www.example.com";
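
// Rejects WebSocket upgrade requests whose Origin header is missing or
// differs from ORIGIN; every other request passes through unchanged.
async fn ws_origin_filter(request: Request, next: Next) -> Response {
    if request
        .headers()
        .get(http::header::UPGRADE)
        // The Upgrade token is case-insensitive, so accept e.g. "WebSocket" too.
        .map(|u| u.as_bytes().eq_ignore_ascii_case(b"websocket"))
        .unwrap_or(false)
        && request
            .headers()
            .get(http::header::ORIGIN)
            .map(|o| o.as_bytes() != ORIGIN.as_bytes())
            .unwrap_or(true)
    {
        return Response::builder()
            .status(StatusCode::FORBIDDEN)
            .body(Body::new("forbidden origin".to_owned()))
            .unwrap();
    }
    next.run(request).await
}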

#[tokio::main]
async fn main() {
    let router = Router::new()
        .route("/", get(|| async { "Hello world! " }))
        .route("/ws", get(client_handler))
        .layer(middleware::from_fn(ws_origin_filter))
        .layer(
            CorsLayer::new()
                .allow_origin(HeaderValue::from_static(ORIGIN))
                .allow_headers([http::header::CONTENT_TYPE])
                .allow_methods([Method::GET, Method::POST]),
        );
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, router).await.unwrap();
}

async fn client_handler(ws: WebSocketUpgrade) -> impl IntoResponse {
    ws.on_upgrade(|_ws| async move {
        // TODO
    })
}
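The condition inside the middleware can also be pulled out into a plain function, which makes it easy to unit-test without starting a server. The following is a minimal sketch rather than part of the code above: `should_reject` and `ALLOWED_ORIGIN` are hypothetical names, and it assumes the header values have already been converted to `&str` (for example with `HeaderValue::to_str`). It compares the `Upgrade` value case-insensitively, since that token is case-insensitive in the WebSocket handshake.

```rust
/// Hypothetical constant: the single origin allowed to open WebSockets.
const ALLOWED_ORIGIN: &str = "https://www.example.com";

/// Hypothetical helper: returns true when a request should be rejected.
///
/// `upgrade` and `origin` are the values of the `Upgrade` and `Origin`
/// request headers, or `None` when the header is absent.
fn should_reject(upgrade: Option<&str>, origin: Option<&str>) -> bool {
    // Only WebSocket upgrade requests are subject to the origin check.
    let is_websocket = upgrade.is_some_and(|u| u.eq_ignore_ascii_case("websocket"));
    // A missing Origin header is treated the same as a wrong one.
    let origin_ok = origin.is_some_and(|o| o == ALLOWED_ORIGIN);
    is_websocket && !origin_ok
}
```

With this shape, the middleware body reduces to reading the two headers, calling the function, and returning 403 when it yields true.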
You should keep the CorsLayer for HTTP requests, for which browsers don't always send the Origin header. Just like CORS, the server-side Origin check can only deter untrusted code running in legitimate web browsers that send accurate Origin headers.
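To illustrate that limitation: outside a browser, the Origin header is just text chosen by whoever writes the request. Below is a rough sketch of the opening handshake a non-browser client could send. `handshake_request` is a hypothetical helper, the key is the sample value from RFC 6455, and the host and path are whatever the caller passes in.

```rust
/// Hypothetical helper: builds the text of a WebSocket opening handshake.
/// Every field, including `origin`, is supplied by the caller.
fn handshake_request(host: &str, path: &str, origin: &str) -> String {
    format!(
        "GET {path} HTTP/1.1\r\n\
         Host: {host}\r\n\
         Upgrade: websocket\r\n\
         Connection: Upgrade\r\n\
         Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
         Sec-WebSocket-Version: 13\r\n\
         Origin: {origin}\r\n\
         \r\n"
    )
}
```

A client that writes these bytes to a TCP socket with the allowed origin filled in gets past the origin filter, regardless of where it actually runs.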