Tags: rust, rust-axum

Error in signature when trying to pass shared state to my authorization middleware in Axum


I'm trying to pass the DB pool, which is stored in `State(app_state): State<AppState>`, to the `authorize` function in my `auth.rs`. I call it from `api.rs`, where I use it as JWT middleware to protect my URLs, like this:

api.rs

// NOTE(review): this is the code that triggers the compiler error quoted below.
// `middleware::from_fn` runs `authorize` with `()` as its state (the error names
// `FromFn<..., (), Route, _>`), so the `State<AppState>` extractor inside
// `authorize` cannot be satisfied; `middleware::from_fn_with_state(state, f)`
// is the variant that supplies a state to the middleware.
pub fn api_routes() -> Router<AppState> {
    let route = Router::new()
        .route("/api", get(main_page_get).post(main_page_post))
        .route("/api/{query}", get(id_page_get))
        .route("/demo.json", get(get_demo_json).put(put_demo_json)
        // Note the parentheses: this `.layer` is called on the `/demo.json`
        // `MethodRouter` (see `MethodRouter::<S, E>::layer` in the error), so
        // only that route is wrapped by the middleware.
        .layer(middleware::from_fn(auth::authorize))
    );

    return route;
}

auth.rs

// NOTE(review): per the axum `middleware::from_fn` docs, `Next` must be the last
// argument and the request must come right before it; `State` (like other
// extractors that only read request parts) belongs before the request. Here
// `req` comes first, which is a second reason this function is not accepted
// as middleware.
pub async fn authorize(
    mut req: Request<Body>,
    State(app_state): State<AppState>,
    next: Next
) -> Result<Response<Body>, AuthError> {

I'm getting an error on the line `middleware::from_fn(auth::authorize)`:

the trait bound `axum::middleware::FromFn<fn(http::Request<Body>, axum::extract::State<AppState>, Next) -> impl Future<Output = Result<Response<Body>, AuthError>> {authorize}, (), Route, _>: tower_service::Service<http::Request<Body>>` is not satisfied
the trait `tower_service::Service<http::Request<Body>>` is not implemented for `FromFn<fn(Request<Body>, ..., ...) -> ... {authorize}, ..., ..., ...>`
the following other types implement trait `tower_service::Service<Request>`:
  axum::middleware::FromFn<F, S, I, (T1, T2)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8)>
  axum::middleware::FromFn<F, S, I, (T1, T2, T3, T4, T5, T6, T7, T8, T9)>
and 8 others
rustc: click for full compiler diagnostic
api.rs(17, 10): required by a bound introduced by this call
method_routing.rs(967, 21): required by a bound in `MethodRouter::<S, E>::layer`

I'm not sure why, because the `authorize` signature looks fine to me.

EDIT: My shared state is implemented like this:

use std::sync::Arc;
use sqlx::postgres::PgPool;

/// Application-wide state handed to handlers and middleware.
///
/// Every field is an `Arc`, so the derived `Clone` only bumps reference counts.
#[derive(Clone)]
pub struct AppState {
    pub html_path: Arc<String>,
    pub db_pool: Arc<PgPool>,
}

impl AppState {
    /// Wraps the HTML path and the database pool in `Arc`s and bundles them.
    pub fn new(html_path: String, db_pool: PgPool) -> Self {
        let shared_path = Arc::new(html_path);
        let shared_pool = Arc::new(db_pool);
        Self {
            html_path: shared_path,
            db_pool: shared_pool,
        }
    }
}

Solution

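  • This is a relatively simple error. `middleware::from_fn` runs the middleware with `()` as its state (that is the `()` in `FromFn<..., (), Route, _>` in the error), so a `State<AppState>` extractor can never be satisfied. Pass the state explicitly with `middleware::from_fn_with_state`. This is how the code should look: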

    /// Builds the API router, running `auth::authorize` in front of `/demo.json`.
    ///
    /// `middleware::from_fn_with_state` hands a clone of `app_state` to the middleware, so
    /// its `State<AppState>` extractor can be satisfied. The final `with_state` supplies the
    /// same state to any handler that extracts `State<AppState>` and turns the router into
    /// a plain `Router` (no missing state), which is what the return type requires.
    pub fn api_routes(app_state: AppState) -> Router {
        Router::new()
            .route("/api", get(main_page_get).post(main_page_post))
            .route("/api/{query}", get(id_page_get))
            .route(
                "/demo.json",
                // This layer wraps only the `/demo.json` method router. To protect every
                // route above as well, apply `route_layer` to the whole `Router` instead.
                get(get_demo_json)
                    .put(put_demo_json)
                    .layer(middleware::from_fn_with_state(app_state.clone(), auth::authorize)),
            )
            .with_state(app_state)
    }
    

    Also, `PgPool` is already cheap to clone (it is a reference-counted handle to the pool), so wrapping it in an `Arc` adds nothing. `String` is cloneable as well; a clone copies the string, which is negligible for a short path (use `Arc<str>` if it ever matters). `Arc` fields are mainly needed when a field must change at runtime (then combine `Arc` with a `RwLock` or `Mutex`) or when a field is expensive or impossible to clone. When your app state is not expected to change, deriving `Clone` is enough.

    #[derive(Debug, Clone)]
    pub struct AppState {
        html_path: String,
        db_pool: PgPool,
    }
    
    impl AppState {
        pub async fn new(config: &AppConfig) -> Result<Self> {
            let html_path = //Get it from your config
            let pool = //Initialize your pool here
            Ok(Self {
                html_path,
                pool
                })
        }
        
        pub fn html_path(&self) -> String {
            self.html_path.clone()
        }
    
        pub fn db_pool(&self) -> PgPool {
            self.pool.clone()
        }
    
    }
    

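    Also, in `auth.rs`, `State` should be the first parameter, followed by the request and then `next`. In `from_fn` middleware, extractors such as `State` must come before the request, and `Next` must be the last argument.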

    // Argument order matters for axum `from_fn` middleware: `State` first, then the
    // request, and `next: Next` as the final argument.
    pub async fn authorize(
        State(app_state): State<AppState>,
        mut req: Request,
        next: Next
    ) -> Result<Response<Body>, AuthError> {
    

    edit:

    This is how I pass my app state to the router function:

    use axum::Router;
    
    mod branch;
    mod health;
    
    use crate::state::SharedAppState;
    
    /// Builds the complete router for the StaffHub service.
    ///
    /// The branch and health routers are mounted side by side, and the result is served
    /// under the `/staff-hub/api/v1` prefix, i.e. `/staff-hub/api/v1/branch/...` and
    /// `/staff-hub/api/v1/health/...`.
    ///
    /// # Arguments
    ///
    /// * `app_state` - Shared application state handed to each feature router.
    ///
    /// # Returns
    ///
    /// A `Router` with all routes configured.
    pub fn init_routes(app_state: SharedAppState) -> Router {
        // Each feature router lives under its own sub-path.
        let health_router =
            Router::new().nest("/health", health::init_health_routes(app_state.clone()));
        let branch_router =
            Router::new().nest("/branch", branch::init_branch_routes(app_state.clone()));

        // Both feature routers together form the v1 API surface.
        let v1_surface = Router::new().merge(branch_router).merge(health_router);

        // Wrap the surface in its version, API and service prefixes, innermost first.
        ["/v1", "/api", "/staff-hub"]
            .into_iter()
            .fold(v1_surface, |inner, prefix| Router::new().nest(prefix, inner))
    }
    

    state.rs

    mod jwks;
    mod keycloak;
    
    use crate::config::StaffHubConfig;
    use crate::state::jwks::JwksClient;
    use crate::state::keycloak::KeycloakClient;
    use anyhow::Context;
    use sqlx::{PgPool, postgres::PgPoolOptions};
    use std::{sync::Arc, time::Duration};
    
    /// Application state structure.
    ///
    /// Holds the PostgreSQL connection pool, the Keycloak client and the JWKS client.
    /// The struct does not derive `Clone`; share it through [`SharedAppState`].
    ///
    /// Fields:
    /// - `pool`: The PostgreSQL connection pool.
    /// - `keycloak`: The Keycloak client.
    /// - `jwks`: The JWKS client.
    pub struct AppState {
        pool: PgPool,
        keycloak: KeycloakClient,
        jwks: JwksClient,
    }

    /// Shared application state type.
    ///
    /// A reference-counted handle to [`AppState`]; cloning it only bumps the count.
    pub type SharedAppState = Arc<AppState>;

    impl AppState {
        /// Initializes the application state.
        ///
        /// Configures the PostgreSQL pool with `connect_lazy` (connections are opened on
        /// first use rather than here), creates the Keycloak client and builds the JWKS
        /// client from the Keycloak JWKS URI.
        ///
        /// # Parameters
        ///
        /// - `config`: The `StaffHubConfig` holding the database and Keycloak settings.
        ///
        /// # Errors
        ///
        /// Returns an error with context "Database Pool Setup" if the pool cannot be
        /// created from the connection string, or "Jwks Client Setup" if the JWKS client
        /// cannot be built.
        pub async fn init_state(config: &StaffHubConfig) -> anyhow::Result<Self> {
            let db = config.database();
            let pool = PgPoolOptions::new()
                .min_connections(db.min_connections())
                .max_connections(db.max_connections())
                .acquire_timeout(Duration::from_secs(db.connect_timeout()))
                .max_lifetime(Some(Duration::from_secs(db.max_lifetime())))
                .idle_timeout(Duration::from_secs(db.idle_timeout()))
                .connect_lazy(db.connection_string().as_str())
                // Annotate like the other fallible setup steps so failures name their step.
                .context("Database Pool Setup")?;
            let keycloak = KeycloakClient::new(config.keycloak());
            let jwks = JwksClient::new(config.keycloak().jwks_uri())
                .await
                .context("Jwks Client Setup")?;

            Ok(Self {
                pool,
                keycloak,
                jwks,
            })
        }

        /// Consumes the state and wraps it in an `Arc`.
        ///
        /// # Returns
        ///
        /// A `SharedAppState` that can be cloned cheaply and handed to routers and tasks.
        pub fn get_shared_state(self) -> SharedAppState {
            Arc::new(self)
        }

        /// Returns a clone of the PostgreSQL connection pool handle.
        ///
        /// # Returns
        ///
        /// A `PgPool` for running queries.
        pub fn pool(&self) -> PgPool {
            self.pool.clone()
        }

        /// Returns a reference to the Keycloak client.
        ///
        /// # Returns
        ///
        /// A reference to the `KeycloakClient`.
        pub fn keycloak(&self) -> &KeycloakClient {
            &self.keycloak
        }

        /// Returns a reference to the JWKS client.
        ///
        /// # Returns
        ///
        /// A reference to the `JwksClient`.
        pub fn jwks(&self) -> &JwksClient {
            &self.jwks
        }
    }
    

    main.rs

    use crate::config::StaffHubConfig;
    use crate::state::AppState;
    use anyhow::Context;
    use tokio::net::TcpListener;
    
    mod config;
    mod dtos;
    mod entities;
    mod errors;
    mod handlers;
    mod middlewares;
    mod routes;
    mod state;
    mod utils;
    
    /// Initializes and runs the StaffHub service.
    ///
    /// This asynchronous function performs the following steps:
    /// 1. Loads the application configuration.
    /// 2. Initializes the application state and wraps it in an `Arc`.
    /// 3. Spawns the background JWKS refresh task (detached; its `JoinHandle` is dropped).
    /// 4. Sets up the application routes.
    /// 5. Binds the server to the configured address and port.
    /// 6. Serves the application until the server future completes.
    ///
    /// # Returns
    ///
    /// `Ok(())` once the server has stopped. Because the serve future is awaited, this
    /// function does not return while the service is running.
    ///
    /// # Errors
    ///
    /// Returns an error, annotated with the failing step, if any step fails.
    pub async fn init_service() -> anyhow::Result<()> {
        // Load the application configuration.
        let app_config = StaffHubConfig::load_config().context("Configuration Load")?;

        // Initialize the application state.
        let app_state = AppState::init_state(&app_config)
            .await
            .context("App State Initialization")?;

        let shared_state = app_state.get_shared_state();

        // The background task gets its own handle to the shared state.
        tokio::spawn(utils::background_jwks_refresh(shared_state.clone()));

        // Last use of `shared_state`, so it is moved instead of cloned.
        let app_routes = routes::init_routes(shared_state);

        // Bind the server to the specified address and port.
        let listener = TcpListener::bind(app_config.server().addr())
            .await
            .context("Server Port Bind")?;

        // Serve the application; this resolves only when the server stops.
        axum::serve(listener, app_routes.into_make_service())
            .await
            .context("Application Serve")?;

        Ok(())
    }