multithreading rust tcp

Do I have to synchronize TcpStream::write_all calls to avoid interleaving?


I have a TcpStream (from the standard library) and would like to write to it from multiple threads. TcpStream allows this without additional synchronization due to impl Write for &TcpStream. The payloads are packaged such that I make a single .write_all() call per payload.

use std::io::Write;
use std::net::TcpStream;

pub struct Publisher {
    stream: TcpStream,
}

impl Publisher {
    pub fn send(&self, payload: &[u8]) {
        // ignore errors for now
        let _ = (&self.stream).write_all(payload);
    }
}

But does this really work?

My worry is that .write_all() may involve multiple .write() calls to send the full payload, and thus concurrent calls may end up interleaving writes from the different threads. I don't see any special handling for TcpStream::write_all and thus it just uses the default trait implementation.
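To make the concern concrete, here is a simplified sketch of the loop that the default write_all performs. It is an approximation written for illustration, not the exact standard library source. ChunkedWriter and calls_needed are hypothetical helpers that stand in for a socket accepting only a few bytes per write call.

```rust
use std::io::{self, ErrorKind, Write};

/// Simplified sketch of the default `Write::write_all`:
/// call `write` repeatedly until the whole buffer has been accepted.
fn write_all_sketch<W: Write>(w: &mut W, mut buf: &[u8]) -> io::Result<()> {
    while !buf.is_empty() {
        match w.write(buf) {
            Ok(0) => {
                return Err(io::Error::new(
                    ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ))
            }
            // Only part of the buffer may have been accepted; retry with the rest.
            Ok(n) => buf = &buf[n..],
            Err(e) if e.kind() == ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}

/// Hypothetical writer that accepts at most `limit` bytes per `write` call
/// and counts how often it was called. The bytes themselves are discarded.
struct ChunkedWriter {
    limit: usize,
    calls: usize,
}

impl Write for ChunkedWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.calls += 1;
        Ok(buf.len().min(self.limit))
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

/// Returns how many `write` calls a payload of `len` bytes needed.
fn calls_needed(len: usize, limit: usize) -> usize {
    let mut w = ChunkedWriter { limit, calls: 0 };
    write_all_sketch(&mut w, &vec![0u8; len]).unwrap();
    w.calls
}
```

A 10-byte payload sent through a writer that takes 3 bytes at a time needs four separate write calls. Nothing in this loop prevents another thread from issuing its own write between two of them.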

Is my concern well-founded? Is there a "clever" way to avoid the problem? Or do I simply need to wrap it in a Mutex regardless?


Solution

  • Yes, you are responsible for synchronizing writes to a TcpStream, so your concern is well-founded. The impl Write for &TcpStream says more about the underlying implementation than about the atomicity of whole payloads: the operating system identifies an open socket by a plain integer handle (a file descriptor on Unix), so writing to it requires neither exclusive access nor Rust-level mutability.

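    In fact, nothing in the documentation suggests that write_all calls on a TcpStream are synchronized, and reviewing the code doesn't reveal any internal synchronization either. The default write_all simply calls write in a loop until the whole buffer has been accepted, so a write from another thread can land between two iterations. Indeed, you can observe interleaving with the following test program, given a sufficiently large N: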

    use std::collections::HashMap;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    // Bytes written per thread; large enough that one write_all needs many write calls.
    const N: usize = 50_000_000;

    fn main() -> Result<(), std::io::Error> {
        let listener = TcpListener::bind(("127.0.0.1", 0))?;
        let address = listener.local_addr()?;
        let handle = thread::spawn(read(listener));

        let stream = TcpStream::connect(address)?;
        // Both threads write through the same &TcpStream without any locking.
        thread::scope(|s| {
            s.spawn(write(b'a', &stream));
            s.spawn(write(b'b', &stream));
        });
        // Close the connection so the reader sees EOF instead of blocking forever.
        drop(stream);
        handle.join().unwrap();
        Ok(())
    }

    fn read(listener: TcpListener) -> impl FnOnce() {
        move || {
            while let Ok((mut s, _)) = listener.accept() {
                let mut chars = HashMap::new();
                let mut buf = [0u8; 1024];
                while let Ok(n) = s.read(&mut buf) {
                    if n == 0 {
                        // EOF: the connection closed and the check below never fired.
                        eprintln!("no interleaving observed {chars:?}");
                        return;
                    }
                    for &c in &buf[..n] {
                        chars.entry(c).and_modify(|v| *v += 1).or_insert(1);
                    }
                    // Interleaving: both payloads are partially received at the same time.
                    if ![Some(N), None].contains(&chars.get(&b'a').copied())
                        && ![Some(N), None].contains(&chars.get(&b'b').copied())
                    {
                        if chars[&b'a'] > chars[&b'b'] {
                            eprintln!("received 'b' before done receiving 'a' {chars:?}");
                        } else {
                            eprintln!("received 'a' before done receiving 'b' {chars:?}");
                        }
                        return;
                    }
                }
            }
        }
    }

    // Returns a closure that writes N copies of `c` to the stream, ignoring errors.
    fn write(c: u8, mut s: &TcpStream) -> impl FnOnce() + '_ {
        let data = vec![c; N];
        move || {
            _ = s.write_all(&data);
        }
    }
    


    I would use a Mutex to add the required synchronization.
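    A minimal sketch of that approach, adapted from the Publisher in the question. Making the struct generic over W, the new and into_inner helpers, and the TcpPublisher alias are assumptions added here so the sketch can be exercised without a socket; they are not part of the original code.

```rust
use std::io::{self, Write};
use std::sync::Mutex;

/// Sketch of the question's `Publisher` with a `Mutex` around the writer.
/// Generic over `W` only so it can be tried with an in-memory writer.
pub struct Publisher<W: Write> {
    stream: Mutex<W>,
}

/// For the original use case, the writer would be a `TcpStream`.
pub type TcpPublisher = Publisher<std::net::TcpStream>;

impl<W: Write> Publisher<W> {
    pub fn new(stream: W) -> Self {
        Publisher {
            stream: Mutex::new(stream),
        }
    }

    /// Holds the lock for the whole `write_all`, so every `write` call made
    /// for one payload completes before another thread's payload can start.
    pub fn send(&self, payload: &[u8]) -> io::Result<()> {
        // Panics if a previous holder of the lock panicked (poisoned mutex).
        let mut stream = self.stream.lock().expect("publisher mutex poisoned");
        stream.write_all(payload)
    }

    /// Consumes the publisher and returns the underlying writer.
    pub fn into_inner(self) -> W {
        self.stream.into_inner().expect("publisher mutex poisoned")
    }
}
```

    The lock is taken per payload rather than per write call, which is what keeps each payload contiguous on the wire. Payloads from different threads are still sent in whatever order the threads acquire the lock.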

    The same applies to the other types which have an impl Write for &T in the standard library: