rust, async-await, rust-tokio

try_join multiple vecs of Futures of different types


We have a fixed number of vecs of futures, each vec of unknown length. Each vec's futures have a different Result type, so we cannot simply flatten them into one collection. We would like something like try_join: wait until either one future returns an error or all of them return successfully.

The join already happens inside an async context. All of the futures being joined must be active simultaneously, so we cannot just try_join the vecs one at a time.

What is a good way to implement this in async Rust?


Solution

  • You can add an intermediate mapping step that converts each future's result into a common type (here, by wrapping each error type in a variant of a shared enum) before passing all of the futures to try_join_all:

    use futures::future::FutureExt;
    use std::future::Future;
    use std::pin::Pin;
    
    #[tokio::main]
    async fn main() {
        let int_futures: Vec<Pin<Box<dyn Future<Output = Result<String, u8>> + Send>>> = vec![
            async { Ok("ok".to_owned()) }.boxed(),
            async { Err(0) }.boxed(),
        ];
    
        let bool_futures: Vec<Pin<Box<dyn Future<Output = Result<String, bool>> + Send>>> = vec![
            async { Ok("ok".to_owned()) }.boxed(),
            async { Err(false) }.boxed(),
        ];
    
        #[derive(Debug)]
        #[allow(unused)]
        enum EitherError {
            Int(u8),
            Bool(bool),
        }
    
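        // Use the `futures` crate here too, matching the import at the top,
        // so the example needs only one futures dependency.
        let result: Result<Vec<String>, EitherError> = futures::future::try_join_all(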
            int_futures
                .into_iter()
                .map(|f| -> Pin<Box<dyn Future<Output = _>>> {
                    f.map(|r| r.map_err(EitherError::Int)).boxed()
                })
                .chain(
                    bool_futures
                        .into_iter()
                        .map(|f| -> Pin<Box<dyn Future<Output = _>>> {
                            f.map(|r| r.map_err(EitherError::Bool)).boxed()
                        }),
                ),
        )
        .await;
    
        // e.g. Err(Int(0))
        println!("{result:?}");
    }