Tags: rust, rust-axum, jwt-simple

How to pass parameters to axum handlers?


I am trying to build a simple server that provides and decodes JWT tokens and I am missing a big part. Here is the code:

/// Configuration for the HTTP server that hands out JWTs.
pub struct Server {
    /// Host part of the bind address; joined with `port` when the server starts.
    pub host: String,
    /// Listening port, kept as text.
    pub port: String,
    /// Public half of the signing key pair (assumed PEM-encoded — confirm).
    pub public_key: String,
    /// Private half of the signing key pair (assumed PEM-encoded — confirm).
    pub private_key: String,
}

impl Server {
    pub async fn start(&self) {
        let routes = Router::new()
            .route("/", get(check))
            .route("/auth", post(auth));

        let hostport = format!("{}:{}", self.host, self.port);
        let addr: SocketAddr = hostport.parse().expect("invalid host:port pair");
        axum::Server::bind(&addr)
            .serve(routes.into_make_service())
            .await
            .unwrap();
    }
}

async fn auth(Json(payload): Json<LoginInput>) -> impl IntoResponse {
    let token = RS384PublicKey::from_pem("id_like_to_put_Server::public_key_here")
        .sign(Claims::create(Duration::from_hours(1)))?;

    let lo = LoginOutput { token };
    (StatusCode::OK, Json(lo))
}

As you can see, `Server` holds the routing logic and applies the configuration. One of the configuration values is a public key that I'd like to use to sign the JWT (I am using jwt_simple for that). Since the public key is an attribute of `Server`, I want to pass that value to the `auth` handler, but I can't figure out how. How can I pass a parameter to an Axum handler so that the token generated inside it can be signed?


Solution

  • Although you can use either Extension or State for this, State is preferred (axum docs):

    You should prefer using State if possible since it’s more type safe. The downside is that its less dynamic than request extensions.

    In Cargo.toml:

    [dependencies]
    axum = "0.6.0-rc.2"
    serde = { version = "1.0.147", features = ["derive"] }
    tokio = { version = "1.21.2", features = ["macros", "rt-multi-thread"] }
    

    Using State:

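    You would store your configuration (the question's `Server`, renamed `ServerConfig` below) in the Router using with_state, and then retrieve it in your handler using the State extractor. The state type must implement `Clone` because the extractor gives each handler its own clone: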

    use std::net::SocketAddr;
    
    use axum::{
        extract::State,
        response::IntoResponse,
        routing::{get, post},
        Json, Router,
    };
    use serde::Deserialize;
    
    /// Server settings shared with every route as router state.
    ///
    /// Derives `Clone` so it can be used as axum router state.
    #[derive(Clone)]
    pub struct ServerConfig {
        /// Host part of the bind address.
        pub host: String,
        /// Listening port, kept as text.
        pub port: String,
        /// Public half of the signing key pair.
        pub public_key: String,
        /// Private half of the signing key pair.
        pub private_key: String,
    }
    
    /// JSON body accepted by `POST /auth`.
    #[derive(Deserialize)]
    pub struct LoginInput {
        /// Login name supplied by the client.
        username: String,
        /// Password supplied by the client.
        password: String,
    }
    
    /// Builds the configuration, exposes it to every route as router state,
    /// and serves HTTP on `host:port`.
    #[tokio::main]
    async fn main() {
        let server_config = ServerConfig {
            host: "0.0.0.0".into(),
            port: "8080".into(),
            public_key: "public_key".into(),
            private_key: "private_key".into(),
        };

        let bind_to = format!("{}:{}", server_config.host, server_config.port);
        let addr = bind_to.parse::<SocketAddr>().unwrap();

        // `Router::with_state` (axum 0.6.0-rc.2, as pinned in Cargo.toml) takes ownership of the
        // state and makes it available to every route added below.
        // NOTE(review): later axum 0.6 releases build with `Router::new()` and attach state via
        // `.with_state(..)` after the routes — verify against the pinned version.
        let app = Router::with_state(server_config)
            .route("/", get(check))
            .route("/auth", post(auth));

        let server = axum::Server::bind(&addr).serve(app.into_make_service());
        server.await.unwrap();
    }
    
    /// `GET /`: always responds with the plain text `check`.
    async fn check() -> &'static str {
        const REPLY: &str = "check";
        REPLY
    }
    
    /// `POST /auth`: receives the shared `ServerConfig` from router state plus the JSON
    /// login body, logs the configured host, and answers with a placeholder token.
    async fn auth(
        State(server_config): State<ServerConfig>, // router state injected by axum
        // `Json` deserializes the request body into any `serde::Deserialize` type
        Json(payload): Json<LoginInput>,
    ) -> impl IntoResponse {
        // Placeholder: real code would check `payload` and sign a JWT using `server_config`.
        let host = &server_config.host;
        println!("host: {}", host);

        "jwt"
    }
    
    
    

    Using Extension:

    You would insert an Extension holding your ServerConfig as a .layer() when building your Router. Then you would get it in your handler via the Extension extractor.

    use std::net::SocketAddr;
    
    use axum::{
        response::IntoResponse,
        routing::{get, post},
        Extension, Json, Router,
    };
    use serde::Deserialize;
    
    /// Server settings handed to handlers through a request extension.
    ///
    /// Derives `Clone` so it can be stored in an `Extension` layer.
    #[derive(Clone)]
    pub struct ServerConfig {
        /// Host part of the bind address.
        pub host: String,
        /// Listening port, kept as text.
        pub port: String,
        /// Public half of the signing key pair.
        pub public_key: String,
        /// Private half of the signing key pair.
        pub private_key: String,
    }
    
    /// Credentials posted as JSON to `/auth`.
    #[derive(Deserialize)]
    pub struct LoginInput {
        /// Client-supplied login name.
        username: String,
        /// Client-supplied password.
        password: String,
    }
    
    /// Builds the configuration, attaches it to the router as an `Extension` layer,
    /// and serves HTTP on `host:port`.
    #[tokio::main]
    async fn main() {
        let server_config = ServerConfig {
            host: "0.0.0.0".into(),
            port: "8080".into(),
            public_key: "public_key".into(),
            private_key: "private_key".into(),
        };

        // Compute the address before `server_config` is moved into the layer below.
        let bind_to = format!("{}:{}", server_config.host, server_config.port);
        let addr = bind_to.parse::<SocketAddr>().unwrap();

        let config_layer = Extension(server_config);
        let app = Router::new()
            .route("/", get(check))
            .route("/auth", post(auth))
            .layer(config_layer);

        let server = axum::Server::bind(&addr).serve(app.into_make_service());
        server.await.unwrap();
    }
    
    /// Health endpoint for `GET /`; the body is always `check`.
    async fn check() -> &'static str {
        let reply: &'static str = "check";
        reply
    }
    
    /// `POST /auth`: reads `ServerConfig` from the request extension and the JSON
    /// login body, logs the configured host, and answers with a placeholder token.
    async fn auth(
        Extension(server_config): Extension<ServerConfig>,
        // `Json` deserializes the request body into any `serde::Deserialize` type
        Json(payload): Json<LoginInput>,
    ) -> impl IntoResponse {
        // Placeholder: real code would check `payload` and sign a JWT using `server_config`.
        let host = &server_config.host;
        println!("host: {}", host);

        "jwt"
    }