rust, rust-tokio, rust-warp

Accessing main-scoped awaitable hashmap from within a task without using channels


I have been evaluating WebSocket frameworks for Rust. I ended up watching a tutorial video on Warp (https://www.youtube.com/watch?v=fuiFycJpCBw), recreated that project, and then compared it with Warp's own example implementation of a chat server.

I started picking pieces from each approach and ended up modifying Warp's own example to my liking. Then I began introducing errors on purpose to see what effect they have on the code.

Specifically, I was trying to understand which error-handling branch gets executed under which circumstances.

These examples keep a main-scoped hashmap that maps each user id to that user's transmit channel, so iterating over the hashmap makes it possible to send a message to every connected user.

Each new connection inserts a mapping via users.write().await.insert(my_id, tx);, and on disconnection the mapping is removed via users.write().await.remove(&my_id);.
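As a minimal illustration of that id-to-sender mapping, here is a std-only sketch. It is an analog, not the question's code: std::sync::RwLock and std::sync::mpsc stand in for tokio's async versions, messages are plain String values instead of warp::ws::Message, and connect, disconnect and broadcast are hypothetical helper names.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, RwLock};

// std-only analog of the question's `Users` type (std RwLock and mpsc
// instead of tokio's), used only to illustrate the id -> sender mapping.
type Users = Arc<RwLock<HashMap<usize, Sender<String>>>>;

static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);

// On connect: allocate an id and register the transmit half of a new channel.
fn connect(users: &Users) -> (usize, Receiver<String>) {
    let id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
    let (tx, rx) = channel();
    users.write().unwrap().insert(id, tx);
    (id, rx)
}

// On disconnect: remove the mapping so later broadcasts skip this user.
fn disconnect(users: &Users, id: usize) {
    users.write().unwrap().remove(&id);
}

// Iterate over the map and send to everyone except the author.
fn broadcast(users: &Users, from: usize, text: &str) {
    for (&uid, tx) in users.read().unwrap().iter() {
        if uid != from {
            // Send errors are ignored here; pruning stale entries is a separate step.
            let _ = tx.send(format!("user {from}: {text}"));
        }
    }
}
```

Here each returned Receiver plays the role of the per-connection channel that the spawned task drains.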

To provoke a send error, I deliberately do not remove the user mapping when the client disconnects. When a new message then arrives and the hashmap is iterated, the obsolete entry is still there, the server tries to send the message through it, and execution correctly ends up in the error branch of the send() call.

The issue is that this error branch lives inside a tokio::spawn block, and from there I would like to issue the users.write().await.remove(&my_id); call that I removed from the normal flow.
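The failure sequence described here can be modeled with std threads. This is an assumption-laden analog: std::thread and std::sync::mpsc stand in for tokio::spawn and tokio's channels, and a bool stands in for the closed WebSocket. In this model the first message after the disconnect is still accepted by the channel, because the forwarding task still holds the receiver; the task then hits the socket error and exits. Only after that does a send through the stale entry fail, which is the case the remove Vec in the question's broadcast loop handles. forwarding_task and broadcast are hypothetical names.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Hypothetical stand-in for the per-connection forwarding task: it drains
// its channel and "writes" to a socket. `socket_open == false` models a
// client that has already disconnected, so the first write fails and the
// task exits, dropping its Receiver.
fn forwarding_task(rx: Receiver<String>, socket_open: bool) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        for msg in rx {
            if !socket_open {
                eprintln!("websocket send error, dropping {msg:?}");
                break; // the Receiver is dropped when the loop ends
            }
        }
    })
}

// Broadcast to every entry and prune the ones whose channel is closed,
// mirroring the `remove` Vec in the question's broadcast loop.
fn broadcast(users: &mut HashMap<usize, Sender<String>>, msg: &str) -> Vec<usize> {
    let mut stale = Vec::new();
    for (&uid, tx) in users.iter() {
        if tx.send(msg.to_string()).is_err() {
            stale.push(uid);
        }
    }
    for uid in &stale {
        users.remove(uid);
    }
    stale
}
```

If this model matches what warp does, the stale entry in the question's code is likely pruned by the broadcast loop on a later message, even without the remove-user control channel.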

I might be mistaken, but I believe this is not possible, because I see no way for this task to access and modify the hashmap. If I understand the problem correctly, I am supposed to create an additional channel that the task can use to send a message back to the main scope, asking it to remove the entry from the hashmap.

For this I use an additional mpsc::unbounded_channel() and call its send() method from the error-handling branch to send the removal request.
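The channel-back design can be sketched the same way. This is again a std analog (threads and std::sync::mpsc instead of tokio::spawn and mpsc::unbounded_channel()), Control::RemoveUser is a hypothetical replacement for the ("remove-user", id) tuple, and the map values are plain String labels for brevity. The owner loop is the only code that mutates the map, and it ends once every Sender for the control channel has been dropped.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Hypothetical control message, replacing the ("remove-user", id) tuple.
enum Control {
    RemoveUser(usize),
}

// Worker side: after a (simulated) socket write error, ask the owner to
// remove the entry instead of touching the map directly.
fn spawn_failing_sender(id: usize, control: Sender<Control>) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        // ... the write to the socket would have failed here ...
        control
            .send(Control::RemoveUser(id))
            .expect("owner dropped the control receiver");
    })
}

// Owner side: drains control messages until every Sender is gone,
// applying each removal to the map it owns.
fn run_owner(
    mut users: HashMap<usize, String>,
    control: Receiver<Control>,
) -> HashMap<usize, String> {
    for cmd in control {
        match cmd {
            Control::RemoveUser(id) => {
                users.remove(&id);
            }
        }
    }
    users
}
```

Note that the owner has to drop its own copy of the Sender; otherwise its loop would wait forever, which is closely related to the select! behavior described further down.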

But this means I also need to await next() on the receiving end of that channel, which is a problem, because that code path is already blocked in a while let Some(result) = ws_rx.next().await loop, waiting for the next incoming WebSocket message.

So I added a tokio::select! block that listens both for new WebSocket messages and for the removal messages sent by the task when it encounters an error. This works: I receive WebSocket messages as well as messages from the new "control" channel.

Yet this creates a new problem. When the client disconnects, I would expect the ws_rx.next() branch of the tokio::select! block (ws_rx being the WebSocket receive stream) to yield an error or some other signal, so that I could treat the connection as closed and remove the client from the hashmap. Nothing of the kind happens.

Previously, without the tokio::select! block, the while let Some(result) = ws_rx.next().await loop terminated as soon as the client disconnected, without raising an error.
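One likely explanation, based on the documented tokio::select! semantics: when a branch's pattern does not match, as with Some(result) = ws_rx.next() receiving None after the disconnect, that branch is disabled for the rest of that select! call instead of ending the loop, and the else branch runs only once every branch is disabled. The other branch, mpsc_stream_rx_2.next(), stays pending for as long as mpsc_tx_2 is alive inside the spawned task, and that task keeps running while mpsc_tx is still in users. The std sketch below (std::sync::mpsc standing in for tokio's unbounded channel) shows the underlying rule: a receiver reports the channel as closed only after every Sender clone has been dropped.

```rust
use std::sync::mpsc::{channel, TryRecvError};

// Returns what `try_recv` reports on an empty channel while a Sender clone
// is still alive, and then after every Sender has been dropped.
fn closed_only_after_all_senders_dropped() -> (TryRecvError, TryRecvError) {
    let (tx, rx) = channel::<String>();
    let tx_held_by_task = tx.clone(); // like `mpsc_tx_2` moved into the spawned task
    drop(tx);
    let while_alive = rx.try_recv().unwrap_err(); // Empty: the channel is still open
    drop(tx_held_by_task);
    let after_drop = rx.try_recv().unwrap_err(); // Disconnected: the stream has ended
    (while_alive, after_drop)
}
```

Inside tokio::select!, binding the whole Option instead of matching Some(...), e.g. msg = ws_rx.next() => match msg { Some(result) => { /* ... */ } None => break }, is one way to make the disconnect visible to the loop.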

Instead of sending a removal request back over an additional channel, I also tried calling drop(ws_tx), which didn't work either. The core of the problem is that I want to be able to modify the hashmap from within that task.

Below is the code, which can be copy-pasted into a new project. It contains both variants, one with the tokio::select! block and one with the while let Some(result) = ws_rx.next().await loop; switch between them by setting the boolean in if true { /*select*/ } else { /*while let*/ }.

There are two problems to inspect:

  1. When using the while let block, comment out the very last line, users.write().await.remove(&current_id);, to trigger the send error.
  2. When using the tokio::select! block, observe that select! never reports a disconnection on the ws_rx.next() branch, so the final users.write().await.remove(&current_id); is never reached.

What I would like is to avoid tokio::select! and the extra channel, keep the simpler while let variant, and modify the users hashmap from within the tokio::task::spawn code.

Apparently I can use the hashmap there, but then I can't continue using it in the main scope.

This is the code that exhibits the problems, main.rs:


//###########################################################################
use std::collections::HashMap;
use std::sync::{atomic::{AtomicUsize, Ordering}, Arc};
use env_logger::Env;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use warp::ws::{Message, WebSocket};
use warp::Filter;
use colored::Colorize;
//###########################################################################


//###########################################################################
static NEXT_USER_ID: AtomicUsize = AtomicUsize::new(1);
type Users = Arc<RwLock<HashMap<usize, mpsc::UnboundedSender<Message>>>>;
//###########################################################################


//###########################################################################
#[tokio::main]
async fn main() {    
  env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
  let users = Users::default();
  let users = warp::any().map(move || users.clone());
  let websocket = warp::path("ws")
    .and(warp::ws())
    .and(users)
    .map(|ws: warp::ws::Ws, users| {
        ws.on_upgrade(move |socket| connect(socket, users))
    });
  let files = warp::fs::dir("./static");
  let port = 8186;
  println!("running server at 0.0.0.0:{}", port.to_string().yellow());
  warp::serve(files.or(websocket)).run(([0, 0, 0, 0], port)).await;
}
//###########################################################################


//###########################################################################
async fn connect(ws: WebSocket, users: Users) {
  let current_id = NEXT_USER_ID.fetch_add(1, Ordering::Relaxed);
  println!("user {} connected", current_id.to_string().green());
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mut ws_tx, mut ws_rx) = ws.split();
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mpsc_tx, mpsc_rx) = mpsc::unbounded_channel(); // For passing WS messages between tasks
  let mut mpsc_stream_rx = UnboundedReceiverStream::new(mpsc_rx);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  let (mpsc_tx_2, mpsc_rx_2) = mpsc::unbounded_channel(); // For sending `remove-request` messages
  let mut mpsc_stream_rx_2: UnboundedReceiverStream<(String, usize)> = UnboundedReceiverStream::new(mpsc_rx_2);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  tokio::task::spawn(async move {
    while let Some(message) = mpsc_stream_rx.next().await {
      //----------------------------------------------------------------
      match ws_tx.send(message).await {
        Ok(_) => {
          // println!("websocket send success (current_id={})", current_id);
        },
        Err(e) => {
          eprintln!("=============================================================");
          eprintln!("websocket send error (current_id={}): {}", current_id, e);
          eprintln!("=============================================================");
          mpsc_tx_2.send(("remove-user".to_string(), current_id)).expect("unable to send remove-user message");
          break;
        }
      };
      //----------------------------------------------------------------
    };
    // NOTE: Problem here: cannot use "users"
    // users.write().await.remove(&current_id);
    // eprintln!("websocket send task ended (current_id={})", current_id);
    // eprintln!("=============================================================");
  });
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  users.write().await.insert(current_id, mpsc_tx);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if false { // <------------------ TOGGLE THIS
    loop {
      tokio::select! {
        Some(result) = ws_rx.next() => {
          //------------------------------------------------------------------
          let msg = match result {
            Ok(msg) => msg,
            Err(e) => {
              eprintln!("=============================================================");
              eprintln!("websocket receive error(current_id={}): {}", current_id, e);
              eprintln!("=============================================================");
              break;
            }
          };
          //------------------------------------------------------------------
          if let Ok(text) = msg.to_str() {
            //----------------------------------------------------------------
            println!("got message '{}' from user {}", text, current_id);
            let new_msg = Message::text(format!("user {}: {}", current_id, text));
            //----------------------------------------------------------------
            let mut remove = Vec::new();
            for (&uid, mpsc_tx) in users.read().await.iter() {
              if current_id != uid {
                println!(" -> forwarding message '{}' to channel of user {}", text, uid);
                if let Err(e) = mpsc_tx.send(new_msg.clone()) {
                  eprintln!("=============================================================");
                  eprintln!("websocket channel error (current_id={}, uid={}): {}", current_id, uid.clone(), e);
                  eprintln!("=============================================================");
                  remove.push(uid);
                }
              }
            }
            //----------------------------------------------------------------
            if remove.len() > 0 {
              for uid in remove {
                eprintln!("removing from users (uid={})", uid);
                eprintln!("=============================================================");
                users.write().await.remove(&uid);
              }
            }
            //----------------------------------------------------------------
          };
          //------------------------------------------------------------------
        }

        Some(result) = mpsc_stream_rx_2.next() => {
          let (command, uid) = result;
          if command == "remove-user" {
            eprintln!("=============================================================");
            eprintln!("removing user {}", uid);
            eprintln!("=============================================================");
            users.write().await.remove(&uid);
          }
          else {
            eprintln!("=============================================================");
            eprintln!("unknown command {}", command);
            eprintln!("=============================================================");
          }
          break;
        }
        else => break
      }
    }
  }
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  else {
    while let Some(result) = ws_rx.next().await {
      //------------------------------------------------------------------
      let msg = match result {
        Ok(msg) => msg,
        Err(e) => {
          eprintln!("=============================================================");
          eprintln!("websocket receive error(current_id={}): {}", current_id, e);
          eprintln!("=============================================================");
          break;
        }
      };
      //------------------------------------------------------------------
      if let Ok(text) = msg.to_str() {
        //----------------------------------------------------------------
        println!("got message '{}' from user {}", text, current_id);
        let new_msg = Message::text(format!("user {}: {}", current_id, text));
        //----------------------------------------------------------------
        let mut remove = Vec::new();
        for (&uid, mpsc_tx) in users.read().await.iter() {
          if current_id != uid {
            println!(" -> forwarding message '{}' to channel of user {}", text, uid);
            if let Err(e) = mpsc_tx.send(new_msg.clone()) {
              eprintln!("=============================================================");
              eprintln!("websocket channel error (current_id={}, uid={}): {}", current_id, uid.clone(), e);
              eprintln!("=============================================================");
              remove.push(uid);
            }
          }
        }
        //----------------------------------------------------------------
        if remove.len() > 0 {
          for uid in remove {
            eprintln!("removing from users (uid={})", uid);
            eprintln!("=============================================================");
            users.write().await.remove(&uid);
          }
        }
        //----------------------------------------------------------------
      };
      //------------------------------------------------------------------
    }
  }
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  println!("user {} disconnected", current_id.to_string().red());
  users.write().await.remove(&current_id);
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
}
//###########################################################################


The source of this code is primarily based on these files:

https://github.com/seanmonstar/warp/blob/master/examples/websockets_chat.rs

https://github.com/ddprrt/warp-websockets-example/blob/main/src/main.rs

This is the Cargo.toml file content:

[package]
name = "websocket-3"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tokio = { version = "1", features = ["full"] }
warp = "0.3.3"
tokio-stream = "0.1.10"
futures = "0.3.24"
env_logger = "0.9.1"
colored = "2"

And this is the index.html file residing in the static/ directory:

<!DOCTYPE html>
<html lang="en">
    <head>
        <style>
          html, body {
            color: rgba(128, 128, 128);
            background-color: rgb(32, 32, 32);
          }
        </style>
        <title>Warp Websocket 3 8186 Chat</title>
    </head>
    <body>
        <h1>Warp Websocket 3 8186 Chat</h1>
        <div id="chat">
            <p><em>Connecting...</em></p>
        </div>
        <input type="text" id="text" />
        <button type="button" id="send">Send</button>
        <script type="text/javascript">
        const chat = document.getElementById('chat');
        const text = document.getElementById('text');
        const uri = 'ws://' + location.host + '/ws';
        const ws = new WebSocket(uri);
        function message(data) {
            const line = document.createElement('p');
            line.innerText = data;
            chat.appendChild(line);
        }
        ws.onopen = function() {
            chat.innerHTML = '<p><em>Connected!</em></p>';
        };
        ws.onmessage = function(msg) {
            message(msg.data);
        };
        ws.onclose = function() {
            chat.getElementsByTagName('em')[0].innerText = 'Disconnected!';
        };
        send.onclick = function() {
            const msg = text.value;
            ws.send(msg);
            text.value = '';
            message('you: ' + msg);
        };
        </script>
    </body>
</html>

Solution

  • To be honest, I didn't read your entire question. It's a little too long.

    Either way, I skimmed it and stumbled across this:

    Apparently I can use the hashmap there, but then I can't continue using it in the main scope.

    This is incorrect. It is only true if you move the users Arc itself into the async block.

    Arcs work a little differently with move closures and async move blocks: you have to clone them first and then move the clone in:

    async fn connect(ws: WebSocket, users: Users) {
        // .. some code ..
        tokio::task::spawn({
            let users = Arc::clone(&users);
            async move {
                // `users` in here is the cloned one,
                // the original one still exists
            }
        });
        // `users` can still be used here
    }
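    For comparison, here is a std-only analog of this clone-then-move pattern, assuming std::thread::spawn and std::sync::RwLock in place of tokio::task::spawn and tokio::sync::RwLock (with tokio, the removal inside the task would be users.write().await.remove(&current_id);). spawn_worker is a hypothetical name, and the map values are plain String labels for brevity.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

type Users = Arc<RwLock<HashMap<usize, String>>>;

// Clone the Arc, move the clone into the spawned worker, and let the worker
// remove its own entry when it finishes (e.g. after a send error).
fn spawn_worker(users: &Users, id: usize) -> thread::JoinHandle<()> {
    let users = Arc::clone(users); // the clone is moved; the caller keeps its handle
    thread::spawn(move || {
        // ... the forwarding loop would run here until a send fails ...
        users.write().unwrap().remove(&id);
    })
}
```

    The caller's users handle stays valid after spawning, so the main scope can keep reading and writing the same map.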