Tags: rust, rust-tokio, rust-tonic

Abort tokio task when channel is dropped


I have a gRPC service implemented with tonic that returns a stream of values. The stream is produced inside a tokio task and sent to the client over a tokio mpsc channel.

The problem is that the spawned task sending the partial results is not aborted after the client disconnects and the Receiver is dropped, so every subsequent send on the channel fails.

Simplified code:

use futures::stream::{BoxStream, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

#[tonic::async_trait]
impl ServiceTrait for MyService {
    type MyStream = BoxStream<'static, Result<MyResponse, tonic::Status>>;

    async fn get_stream(
        &self,
        _request: tonic::Request<()>,
    ) -> Result<tonic::Response<Self::MyStream>, tonic::Status> {
        let (tx, rx) = mpsc::channel::<Result<MyResponse, tonic::Status>>(1);

        // I need this task to be aborted when rx is dropped
        let producer_task_handle = tokio::spawn(async move {
            // spawn many parallel tasks with rate limiting
            ...
            // each task sends its result to tx
            tx.send(response).await.unwrap(); // panics once rx is dropped after the client disconnects
        });

        Ok(tonic::Response::new(ReceiverStream::new(rx).boxed()))
    }
}
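The panic itself comes from calling unwrap() on the Err that send returns once the receiver is gone; tokio's mpsc::Sender::send reports a closed channel as Err(SendError). As a minimal, crate-free sketch, the code below reproduces the situation with std::sync::mpsc standing in for the tokio channel (an assumption: only the disconnect behaviour is meant to carry over), and the helper produce_until_receiver_dropped is a hypothetical name. Breaking out of the loop on the first failed send lets the producer stop instead of panicking, although it only notices the disconnect on its next send.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical helper: a producer sends 0, 1, 2, ... until the receiver is
/// gone; returns how many values were sent successfully before it stopped.
/// std::sync::mpsc stands in for tokio::sync::mpsc in this sketch.
fn produce_until_receiver_dropped(take: usize) -> usize {
    // Bounded channel with capacity 1, like mpsc::channel(1) in the question.
    let (tx, rx) = mpsc::sync_channel::<u64>(1);

    let producer = thread::spawn(move || {
        let mut sent: usize = 0;
        for value in 0u64.. {
            // Instead of unwrap(), treat a send error as "the client went away".
            if tx.send(value).is_err() {
                break;
            }
            sent += 1;
        }
        sent
    });

    // The "client" reads a few values and then disconnects by dropping rx.
    for _ in 0..take {
        rx.recv().expect("producer is still running");
    }
    drop(rx);

    producer.join().expect("producer thread panicked")
}

fn main() {
    let sent = produce_until_receiver_dropped(3);
    println!("producer stopped after sending {sent} values");
}
```

With tokio, the producer does not have to wait for a failed send: Sender::closed() is a future that completes once the receiver has been dropped, so it can be raced against the work.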

How can I abort the producer task when the channel is closed? Or is there a better way to do this? I had a working version which returned streams, but that is no longer an option.


Solution

  • After some consideration, I decided to use CancellationToken together with its DropGuard (both from tokio_util::sync).

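    Wrapping the receiver stream in a struct that owns the DropGuard ties the guard's lifetime to the stream: once tonic drops the stream (for example because the client disconnected), the guard is dropped, the cancellation token is cancelled, and the producer task can be aborted.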

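    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures::Stream;
    use pin_project::pin_project;

    /// Forwards every poll to `inner`; `_data` is never read and is only kept
    /// so that it is dropped together with the stream.
    #[derive(Debug)]
    #[pin_project]
    pub struct StreamWithData<T: Stream, D> {
        #[pin]
        inner: T,
        _data: D,
    }

    impl<T: Stream, D> StreamWithData<T, D> {
        pub fn new(inner: T, data: D) -> Self {
            Self { inner, _data: data }
        }
    }

    impl<T: Stream, D> Stream for StreamWithData<T, D> {
        type Item = T::Item;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            self.project().inner.poll_next(cx)
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            self.inner.size_hint()
        }
    }

    /// Extension trait: makes `stream.attach_data(value)` available on every Stream.
    pub trait DataStream: Stream + Sized {
        fn attach_data<D>(self, data: D) -> StreamWithData<Self, D> {
            StreamWithData::new(self, data)
        }
    }

    impl<T: Stream> DataStream for T {}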
    

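    Usage (inside get_stream, after creating the channel):

        use tokio::task;
        use tokio_util::sync::CancellationToken;

        let cancellation_token = CancellationToken::new();
        // Dropping the guard cancels the token and every clone of it.
        let drop_guard = cancellation_token.clone().drop_guard();
        // The output stream owns the guard, so dropping the stream cancels the token.
        let rx = ReceiverStream::new(rx).attach_data(drop_guard);

        let producer_task_handle = task::spawn(async move {
            // do stuff
        });

        // Watcher task: abort the producer once the token is cancelled.
        // abort() only cancels this one task; worker tasks it spawned with
        // tokio::spawn keep running unless they are aborted as well (e.g. by
        // keeping them in a JoinSet owned by the producer, which aborts its
        // tasks when dropped) or watch the token themselves.
        let abort_handle = producer_task_handle.abort_handle();
        task::spawn(async move {
            cancellation_token.cancelled().await;
            abort_handle.abort();
        });

        Ok(tonic::Response::new(rx.boxed()))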
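The StreamWithData wrapper above does not depend on async code: the same attach-and-forward trick works for std::iter::Iterator, which makes it easy to check the one property the solution relies on, namely that the attached value is dropped exactly when the wrapper is. The sketch below is an Iterator analogue written for illustration only; IterWithData and AttachData are hypothetical names mirroring StreamWithData and DataStream, and an Rc clone plays the role of the DropGuard so that its strong count shows whether the attached data is still alive.

```rust
use std::rc::Rc;

/// Hypothetical Iterator counterpart of StreamWithData: forwards to `inner`
/// and holds `_data` only so it is dropped together with the iterator.
pub struct IterWithData<I, D> {
    inner: I,
    _data: D,
}

impl<I: Iterator, D> Iterator for IterWithData<I, D> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

/// Extension trait with a blanket impl, so every iterator gets attach_data.
pub trait AttachData: Iterator + Sized {
    fn attach_data<D>(self, data: D) -> IterWithData<Self, D> {
        IterWithData { inner: self, _data: data }
    }
}

impl<I: Iterator> AttachData for I {}

fn main() {
    // The Rc clone stands in for the DropGuard.
    let marker = Rc::new(());
    let iter = vec![1, 2, 3].into_iter().attach_data(Rc::clone(&marker));
    println!("live handles while iterating: {}", Rc::strong_count(&marker));

    // collect() takes the wrapper by value, so it (and the clone) is dropped here.
    let items: Vec<i32> = iter.collect();
    println!("{items:?}, live handles afterwards: {}", Rc::strong_count(&marker));
}
```

Because the attached value is an ordinary field, no custom Drop impl is needed on the wrapper; Rust drops all fields when the wrapper goes away, which is exactly what fires the DropGuard in the Stream version.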
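The watcher-task pattern from the usage code can also be modelled with plain threads, as a rough, std-only sketch (all names here, such as Guard and run_until_guard_dropped, are hypothetical). A Guard owns the only Sender of a channel, so dropping it makes recv() on the Receiver return Err, which plays the role of cancellation_token.cancelled().await. OS threads cannot be aborted, so instead of abort() the watcher sets a flag that the producer checks between work steps; this cooperative check is a substitute for JoinHandle::abort, not a model of it.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for a DropGuard: it owns the only Sender, so
/// dropping it disconnects the channel and wakes the watcher.
struct Guard {
    _tx: mpsc::Sender<()>,
}

/// Drops the guard after `client_lifetime` and returns how many work steps
/// the producer completed before it saw the stop flag.
fn run_until_guard_dropped(client_lifetime: Duration) -> u32 {
    let (tx, cancelled) = mpsc::channel::<()>();
    let guard = Guard { _tx: tx };
    let stop = Arc::new(AtomicBool::new(false));

    // Producer: checks the flag between units of work (threads cannot be aborted).
    let producer = {
        let stop = Arc::clone(&stop);
        thread::spawn(move || {
            let mut steps: u32 = 0;
            while !stop.load(Ordering::Acquire) {
                thread::sleep(Duration::from_millis(1)); // one unit of work
                steps += 1;
            }
            steps
        })
    };

    // Watcher: the analogue of cancelled().await followed by abort().
    let watcher = {
        let stop = Arc::clone(&stop);
        thread::spawn(move || {
            // recv() returns Err once every Sender (here: the guard's) is dropped.
            let _ = cancelled.recv();
            stop.store(true, Ordering::Release);
        })
    };

    thread::sleep(client_lifetime);
    drop(guard); // the "client" disconnects

    watcher.join().expect("watcher panicked");
    producer.join().expect("producer panicked")
}

fn main() {
    let steps = run_until_guard_dropped(Duration::from_millis(20));
    println!("producer stopped after {steps} steps");
}
```

Inside tokio, the cooperative variant is commonly written by racing the work against token.cancelled() in tokio::select!, which avoids the separate watcher task and lets the producer clean up before it exits.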