
Execute middleware before and after request in Rust warp


I would like to track in-flight connections in warp such that a metrics counter is incremented before the request is handled and decremented after it was processed.

I attempted to solve this by using a "no-op" filter at the start of the chain and a custom logging filter at the end of the chain, something like this:

/// Increment the request count metric before the requests starts.
fn with_start_call_metrics() -> impl Filter<Extract = (), Error = Infallible> + Clone {
    warp::any()
        .and(path::full())
        .map(|path: FullPath| {
            HttpMetrics::inc_in_flight(path.as_str());
        })
        .untuple_one()
}

/// Decrement the request count metric after the request ended.
fn with_end_call_metrics() -> Log<fn(Info<'_>)> {
    warp::log::custom(|info| {
        HttpMetrics::dec_in_flight(info.path());
        // ... track more metrics, e.g. info.elapsed() ...
    })
}

The problem arises when a long-running request (/slow in the code below) is started and the connection is dropped before the request could be processed completely (e.g. CTRL-C on curl).

In this case, the slow route is simply aborted by warp and the with_end_call_metrics filter below is never reached:

#[tokio::main]
async fn main() {
    let hello = warp::path!("hello" / String).and_then(hello);
    let slow = warp::path!("slow").and_then(slow);

    warp::serve(
        with_start_call_metrics()
            .and(
                hello.or(slow), // ... and more ...
            )
            // If the call (e.g. of `slow`) is cancelled, this is never reached.
            .with(with_end_call_metrics()),
    )
    .run(([127, 0, 0, 1], 8080))
    .await;
}

async fn hello(name: String) -> Result<impl warp::Reply, warp::Rejection> {
    Ok(format!("Hello, {}!", name))
}

async fn slow() -> Result<impl warp::Reply, warp::Rejection> {
    tokio::time::sleep(Duration::from_secs(5)).await;
    Ok("That was slow.")
}

I understand this is normal behavior and the recommended way is to rely on the Drop implementation of a type in the request, as that would always be called, so something like:

async fn in_theory<F, T, E>(filter: F) -> Result<T, E>
where
    F: Filter<Extract = T, Error = E>
{
    let guard = TrackingGuard::new();
    filter.await
}

But that doesn't work. I tried using wrap_fn like so:

pub fn in_theory<F>(filter: F) -> Result<F::Extract, F::Error>
where
    F: Filter + Clone,
{
    warp::any()
        .and(filter)
        .wrap_fn(|f| async { 
             // ... magic here ...
             f.await 
        })
}

but regardless of what I try, it always ends up with an error like this:

error[E0277]: the trait bound `<F as warp::filter::FilterBase>::Error: reject::sealed::CombineRejection<Infallible>` is not satisfied
   --> src/metrics.rs:255:25
    |
255 |         warp::any().and(filter).wrap_fn(|f| async { f.await })
    |                     --- ^^^^^^ the trait `reject::sealed::CombineRejection<Infallible>` is not implemented for `<F as warp::filter::FilterBase>::Error`
    |                     |
    |                     required by a bound introduced by this call

And that cannot be specified, because reject::sealed is not a public module. Any help is appreciated!


Solution

  • As was suggested in the comments, moving away from warp and using Tower for building the middleware helped. I had to rewrite the code for hosting the server to use hyper::Server directly but this was only a mild inconvenience.


    I started off with an HttpCallMetrics service wrapping an inner service S. Since I am tracking HTTP responses, I need that service to ultimately produce a hyper::Response, whose body type is indicated here by the type parameter O.

    The phantom data lets me name O on the struct; leaving O off the struct would make the Service implementation fail to compile, because O would then appear only in the trait bounds of the impl and be left unconstrained.

    #[derive(Clone)]
    pub struct HttpCallMetrics<S, O> {
        inner: S,
        _phantom: PhantomData<O>,
    }
    
    impl<S, O> HttpCallMetrics<S, O> {
        pub fn new(inner: S) -> Self {
            Self {
                inner,
                _phantom: PhantomData,
            }
        }
    }
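
To illustrate why the phantom field is needed, here is a minimal sketch using only the standard library. MiniService, Converting and Len are made-up names for this example; MiniService is a heavily simplified stand-in for tower::Service (no readiness, no futures), not the real trait.

```rust
use std::marker::PhantomData;

/// Simplified stand-in for `tower::Service` (assumption: illustration only).
trait MiniService<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

/// Wrapper that converts the inner service's response into `O`.
struct Converting<S, O> {
    inner: S,
    // Without `O` on the struct, `O` would appear only in the `where`
    // clause and associated type of the impl below, and the compiler
    // would reject it as an unconstrained type parameter (E0207).
    _phantom: PhantomData<O>,
}

impl<S, O> Converting<S, O> {
    fn new(inner: S) -> Self {
        Self { inner, _phantom: PhantomData }
    }
}

impl<S, Req, O> MiniService<Req> for Converting<S, O>
where
    S: MiniService<Req>,
    S::Response: Into<O>,
{
    type Response = O;

    fn call(&mut self, req: Req) -> O {
        // Forward to the inner service, then convert its response.
        self.inner.call(req).into()
    }
}

/// Toy inner service: answers with the request length as a `u8`.
struct Len;

impl MiniService<String> for Len {
    type Response = u8;

    fn call(&mut self, req: String) -> u8 {
        req.len() as u8
    }
}

/// Wraps `Len` so that its `u8` response comes out as a `u32`.
fn convert_demo() -> u32 {
    let mut svc: Converting<Len, u32> = Converting::new(Len);
    svc.call("hello".to_string())
}
```

The target type is picked at the use site (Converting<Len, u32> above), which is only possible because O is part of the wrapper's type.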
    

    Because it is about HTTP metrics, the service also specifically deals with HTTP requests and hence implements Service<Request<B>> for any body type B. Likewise, the wrapped service needs to be the same and its output needs to be convertible to a Response<O>.

    The HttpCallMetrics service will produce a custom future HttpCallMetricsFuture that takes care of the metrics tracking; this is to avoid boxing here. Apart from that, since metrics never block, it forwards its poll_ready call to the wrapped inner service.

    When called, a HttpCallMetricTracker instance is created from the request. This is a struct that holds basic request information (HTTP method, version, path, start instant) and implements Drop: when dropped, it registers that the request terminated. This works whether the request finishes successfully or is cancelled.

    impl<S, B, O> Service<Request<B>> for HttpCallMetrics<S, O>
    where
        S: Service<Request<B>>,
        S::Response: Into<hyper::Response<O>>,
    {
        type Response = hyper::Response<O>;
        type Error = S::Error;
        type Future = HttpCallMetricsFuture<S::Future, O, Self::Error>;
    
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }
    
        fn call(&mut self, request: Request<B>) -> Self::Future {
            let tracker = HttpCallMetricTracker::start(&request);
            HttpCallMetricsFuture::new(self.inner.call(request), tracker)
        }
    }
    

    The implemented future again requires a phantom data hack for keeping track of the success variant O and error variant E of the service's future.

    #[pin_project]
    pub struct HttpCallMetricsFuture<F, O, E> {
        #[pin]
        future: F,
        tracker: HttpCallMetricTracker,
        _phantom: PhantomData<(O, E)>,
    }
    
    impl<F, O, E> HttpCallMetricsFuture<F, O, E> {
        fn new(future: F, tracker: HttpCallMetricTracker) -> Self {
            Self {
                future,
                tracker,
                _phantom: PhantomData::default(),
            }
        }
    }
    

    The implementation is then comparatively simple: In essence, the poll call is forwarded to the wrapped inner future, and the method exits if that future is still Poll::Pending.

    The moment the future returns Poll::Ready it will be inspected for its result variant and if it is an Ok(result) the result is converted into a hyper::Response. Metrics are then updated and the response is returned.

    In case of an error variant, the error is essentially returned as is.

    impl<F, R, O, E> Future for HttpCallMetricsFuture<F, O, E>
    where
        F: Future<Output = Result<R, E>>,
        R: Into<hyper::Response<O>>,
    {
        type Output = Result<hyper::Response<O>, E>;
    
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.project();
            let response = match this.future.poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(reply) => reply,
            };
    
            let result = match response {
                Ok(reply) => {
                    let response = reply.into();
                    this.tracker
                        .set_state(ResultState::Result(response.status(), response.version()));
                    Ok(response)
                }
                Err(e) => {
                    this.tracker.set_state(ResultState::Failed);
                    Err(e)
                }
            };
            Poll::Ready(result)
        }
    }
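
The cancellation behaviour can be checked without a server or runtime. The following is a rough sketch using only the standard library; InFlightGuard, Tracked, Never and NoopWake are names invented for this example, and boxing the inner future stands in for pin_project. It polls a never-finishing future once by hand and then drops it, which is roughly what happens when the client disconnects.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical guard: bumps a shared counter on creation, undoes it on drop.
struct InFlightGuard(Arc<AtomicUsize>);

impl InFlightGuard {
    fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        Self(counter)
    }
}

impl Drop for InFlightGuard {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

/// A future bundled with a guard that lives exactly as long as the future.
struct Tracked<F> {
    inner: Pin<Box<F>>,
    _guard: InFlightGuard,
}

impl<F: Future> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Simply forward to the wrapped future.
        self.inner.as_mut().poll(cx)
    }
}

/// Never completes; stands in for a slow handler.
struct Never;

impl Future for Never {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        Poll::Pending
    }
}

/// Waker that does nothing; enough to poll by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Returns (count while the future is pending, count after dropping it).
fn cancel_demo() -> (usize, usize) {
    let counter = Arc::new(AtomicUsize::new(0));
    let mut fut = Tracked {
        inner: Box::pin(Never),
        _guard: InFlightGuard::new(counter.clone()),
    };

    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());

    let during = counter.load(Ordering::SeqCst);
    drop(fut); // simulates the connection going away mid-request
    (during, counter.load(Ordering::SeqCst))
}
```

The counter goes back down even though poll never returned Poll::Ready, because the guard is a field of the future and is dropped together with it.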
    

    The HttpCallMetricTracker is more or less trivial: it increments call metrics when constructed and decrements them when dropped.

    The only interesting point here is the state: Cell<ResultState> field, which lets the Drop implementation infer whether something should be logged. It's not strictly required here.

    struct HttpCallMetricTracker {
        version: Version,
        method: hyper::Method,
        path: String,
        start: Instant,
        state: Cell<ResultState>,
    }
    
    pub enum ResultState {
        /// The result was already processed.
        None,
        /// Request was started.
        Started,
        /// The result failed with an error.
        Failed,
        /// The result is an actual HTTP response.
        Result(StatusCode, Version),
    }
    
    impl HttpCallMetricTracker {
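        fn start<B>(request: &Request<B>) -> Self {
            // increase "requests in flight" metric
            Self {
                // ...
                // Begin in `Started` so that a tracker dropped without a result
                // (e.g. a cancelled request) still reaches the decrement in `Drop`.
                state: Cell::new(ResultState::Started),
            }
        }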
    
        fn set_state(&self, state: ResultState) {
            self.state.set(state)
        }
    
        fn duration(&self) -> Duration {
            self.start.elapsed()
        }
    }
    
    impl Drop for HttpCallMetricTracker {
        fn drop(&mut self) {
            match self.state.replace(ResultState::None) {
                ResultState::None => {
                    // This was already handled; don't decrement metrics again.
                    return;
                }
                ResultState::Started => {
                    // The request was started but never produced a result
                    // (e.g. it was cancelled).
                }
                ResultState::Failed => {
                    // handle "fail" state
                }
                ResultState::Result(status, version) => {
                    // handle "meaningful result" state
                }
            }
    
            // decrease "requests in flight" metric
        }
    }
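
As a self-contained illustration of the state-driven Drop, here is a simplified sketch using only the standard library. Metrics, Outcome and Tracker are hypothetical stand-ins for the real metrics backend and the types above; the status code is a plain u16 instead of hyper's StatusCode.

```rust
use std::cell::Cell;
use std::sync::{Arc, Mutex};

/// Simplified result state (assumption: mirrors `ResultState` loosely).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum Outcome {
    /// Dropped before any result was recorded, e.g. cancelled.
    Started,
    /// The handler returned an error.
    Failed,
    /// The handler produced a response with this status code.
    Status(u16),
}

/// Hypothetical metrics sink: in-flight count plus a log of outcomes.
#[derive(Default)]
struct Metrics {
    in_flight: Mutex<usize>,
    outcomes: Mutex<Vec<Outcome>>,
}

struct Tracker {
    metrics: Arc<Metrics>,
    state: Cell<Outcome>,
}

impl Tracker {
    fn start(metrics: Arc<Metrics>) -> Self {
        *metrics.in_flight.lock().unwrap() += 1;
        Self { metrics, state: Cell::new(Outcome::Started) }
    }

    fn set_state(&self, state: Outcome) {
        self.state.set(state)
    }
}

impl Drop for Tracker {
    fn drop(&mut self) {
        // Record whatever state was reached, then release the in-flight slot.
        self.metrics.outcomes.lock().unwrap().push(self.state.get());
        *self.metrics.in_flight.lock().unwrap() -= 1;
    }
}

/// Runs one "cancelled" and one "finished" request through the tracker.
fn tracker_demo() -> (usize, Vec<Outcome>) {
    let metrics = Arc::new(Metrics::default());

    let cancelled = Tracker::start(metrics.clone());
    let finished = Tracker::start(metrics.clone());
    assert_eq!(*metrics.in_flight.lock().unwrap(), 2);

    drop(cancelled); // no result was ever set
    finished.set_state(Outcome::Status(200));
    drop(finished);

    let in_flight = *metrics.in_flight.lock().unwrap();
    let outcomes = metrics.outcomes.lock().unwrap().clone();
    (in_flight, outcomes)
}
```

Both trackers release their in-flight slot; the recorded outcome is the only thing that differs between the cancelled and the completed request.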
    

    As far as hosting goes, the code now looks something like this:

    let make_svc = make_service_fn(|_conn| {
        let tx = shutdown_tx.clone();
    
        async move {
            // Convert the warp filter into a Tower service.
            let svc = warp::service(
                hello
                    .or(slow)
                    .or(filters::metrics_endpoint())
                    .or(filters::health_endpoints())
                    .or(filters::shutdown_endpoint(tx)),
            );
    
            // Wrap it into the metrics service.
            let svc = services::HttpCallMetrics::new(svc);
    
            Ok::<_, Infallible>(svc)
        }
    });
    
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    let listener = TcpListener::bind(addr).unwrap();
    
    // Using a ServiceBuilder is not strictly required.
    let builder = ServiceBuilder::new().service(make_svc);
    
    Server::from_tcp(listener)
        .unwrap()
        .serve(builder)
        .with_graceful_shutdown(async move {
            shutdown_rx.recv().await.ok();
        })
        .await?;
    

    That said, there also exists tower_http::trace which indeed seems to support all of the above. I will likely migrate to that later on, but this exercise helped me tremendously in understanding Tower in the first place.