I am trying to build a simple server that issues and decodes JWTs, and I am missing a big part. Here is the code:
```rust
pub struct Server {
    pub host: String,
    pub port: String,
    pub public_key: String,
    pub private_key: String
}

impl Server {
    pub async fn start(&self) {
        let routes = Router::new()
            .route("/", get(check))
            .route("/auth", post(auth));
        let hostport = format!("{}:{}", self.host, self.port);
        let addr: SocketAddr = hostport.parse().expect("invalid host:port pair");
        axum::Server::bind(&addr)
            .serve(routes.into_make_service())
            .await
            .unwrap();
    }
}

async fn auth(Json(payload): Json<LoginInput>) -> impl IntoResponse {
    let token = RS384PublicKey::from_pem("id_like_to_put_Server::public_key_here")
        .sign(Claims::create(Duration::from_hours(1)))?;
    let lo = LoginOutput { token };
    (StatusCode::OK, Json(lo))
}
```
As you can see, `Server` holds the routing logic and applies the configuration. The configuration includes a public key that I'd like to use to sign the JWT (I am using `jwt_simple` for that). Since the public key is an attribute of `Server`, I want to pass that value to the `auth` handler, but I can't figure out how. How can I pass a parameter to an Axum handler and sign the token that is generated inside it?
Although you can use either `Extension` or `State` for this, `State` is preferred (axum docs):

> You should prefer using `State` if possible since it's more type safe. The downside is that it's less dynamic than request extensions.
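To make "less dynamic" concrete: request extensions live in a map keyed by type, so a lookup can miss at runtime (axum answers a missing `Extension` with a 500 error), whereas a handler that asks for a `State` type the router does not provide fails to compile. The toy `TypeMap` below is a simplified, hypothetical stand-in for that kind of map, not axum's or `http`'s actual implementation.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// A toy map keyed by type (hypothetical, simplified). Values are stored and
/// looked up by their Rust type, and a lookup can come back empty at runtime.
#[derive(Default)]
pub struct TypeMap {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl TypeMap {
    /// Store a value; a later insert of the same type replaces it.
    pub fn insert<T: Any + Send + Sync>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Look a value up by type. `None` means nobody inserted a `T`, which the
    /// compiler cannot catch: the runtime failure mode of `Extension`.
    pub fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}
```

With `State`, the equivalent mistake (forgetting to provide the value) shows up as a type error when the router is built, not as a failed lookup while serving a request.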
In `Cargo.toml`:

```toml
[dependencies]
axum = "0.6.0-rc.2"
serde = { version = "1.0.147", features = ["derive"] }
tokio = { version = "1.21.2", features = ["macros", "rt-multi-thread"] }
```
**`State`:** You would store your configuration (here a `ServerConfig` struct) in the `Router` using `with_state`, and then retrieve it in your handler with the `State` extractor:
```rust
use std::net::SocketAddr;

use axum::{
    extract::State,
    response::IntoResponse,
    routing::{get, post},
    Json, Router,
};
use serde::Deserialize;

#[derive(Clone)]
pub struct ServerConfig {
    pub host: String,
    pub port: String,
    pub public_key: String,
    pub private_key: String,
}

#[derive(Deserialize)]
pub struct LoginInput {
    username: String,
    password: String,
}

#[tokio::main]
async fn main() {
    let server_config = ServerConfig {
        host: "0.0.0.0".into(),
        port: "8080".into(),
        public_key: "public_key".into(),
        private_key: "private_key".into(),
    };

    let addr: SocketAddr = format!("{}:{}", server_config.host, server_config.port)
        .parse()
        .unwrap();

    let routes = Router::with_state(server_config) // state will be available to all the routes
        .route("/", get(check))
        .route("/auth", post(auth));

    axum::Server::bind(&addr)
        .serve(routes.into_make_service())
        .await
        .unwrap();
}

async fn check() -> &'static str {
    "check"
}

async fn auth(
    State(server_config): State<ServerConfig>, // extract state in this handler
    // `Json` supports any type that implements `serde::Deserialize`
    Json(payload): Json<LoginInput>,
) -> impl IntoResponse {
    // use server_config and payload to run the `auth` logic
    println!("host: {}", server_config.host);
    "jwt"
}
```
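axum hands each handler its own clone of the state every time `State` is extracted, which is why `ServerConfig` derives `Clone`. If the configuration grows, a common pattern is to wrap it in an `Arc`, i.e. `Router::with_state(Arc::new(server_config))` plus `State(server_config): State<Arc<ServerConfig>>` in the handler, so each clone only bumps a reference count. The sketch below shows that effect with the standard library only; `describe` and `per_request_clone` are hypothetical stand-ins for a handler body and for what the router does per request, not axum code.

```rust
use std::sync::Arc;

pub struct ServerConfig {
    pub host: String,
    pub port: String,
    pub public_key: String,
    pub private_key: String,
}

/// Stand-in for a handler body. In axum it would take
/// `State(config): State<Arc<ServerConfig>>`; here it is a plain function
/// so the sketch runs without any crates.
pub fn describe(config: Arc<ServerConfig>) -> String {
    format!("{}:{}", config.host, config.port)
}

/// Stand-in for the per-request clone: cloning an `Arc` copies a pointer and
/// increments the reference count; the `String` fields are not duplicated.
pub fn per_request_clone(shared: &Arc<ServerConfig>) -> Arc<ServerConfig> {
    Arc::clone(shared)
}
```

Both clones point at the same allocation, and the count drops back once the "handler" finishes with its copy.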
**`Extension`:** You would insert an `Extension` holding your `ServerConfig` as a `.layer()` when building your `Router`, and then get it in your handler via the `Extension` extractor.
```rust
use std::net::SocketAddr;

use axum::{
    response::IntoResponse,
    routing::{get, post},
    Extension, Json, Router,
};
use serde::Deserialize;

#[derive(Clone)]
pub struct ServerConfig {
    pub host: String,
    pub port: String,
    pub public_key: String,
    pub private_key: String,
}

#[derive(Deserialize)]
pub struct LoginInput {
    username: String,
    password: String,
}

#[tokio::main]
async fn main() {
    let server_config = ServerConfig {
        host: "0.0.0.0".into(),
        port: "8080".into(),
        public_key: "public_key".into(),
        private_key: "private_key".into(),
    };

    let addr: SocketAddr = format!("{}:{}", server_config.host, server_config.port)
        .parse()
        .unwrap();

    let routes = Router::new()
        .route("/", get(check))
        .route("/auth", post(auth))
        .layer(Extension(server_config));

    axum::Server::bind(&addr)
        .serve(routes.into_make_service())
        .await
        .unwrap();
}

async fn check() -> &'static str {
    "check"
}

async fn auth(
    Extension(server_config): Extension<ServerConfig>,
    // `Json` supports any type that implements `serde::Deserialize`
    Json(payload): Json<LoginInput>,
) -> impl IntoResponse {
    // use server_config and payload to run the `auth` logic
    println!("host: {}", server_config.host);
    "jwt"
}
```