
How to read leb128 from TcpStream in Rust


In short, I need to read a string from a TCP connection sent by a C# client. The client uses the BinaryWriter, which prefixes the actual string with a length in leb128 format. I'm using tokio::net::TcpStream in Rust and have searched for a crate to help me retrieve that length prefix from the stream, but I couldn't find anything suitable. Most solutions require the source from which you are reading to implement the io::Read trait, but tokio::net::TcpStream does not implement it.

I managed to get it working with this ugly code, but I was suspicious of it from the beginning. I recently discovered that it sometimes leads to some kind of race condition. I'm not entirely sure, but I think that it gets blocked on the let file_name_len = leb128::read::unsigned(&mut socket)?;, which somehow causes my TCP listener to stop accepting new connections, which is even stranger.

let mut socket = socket.into_std()?;
socket.set_nonblocking(false)?;

let file_name_len = leb128::read::unsigned(&mut socket)?;

let mut socket = tokio::net::TcpStream::from_std(socket)?;

Does anyone know the right way to do this?


Solution

  • I'm not entirely sure, but I think that it gets blocked on the let file_name_len = leb128::read::unsigned(&mut socket)?;, which somehow causes my TCP listener to stop accepting new connections, which is even stranger.

    The above code is blocking:

    1. You set the socket to blocking (set_nonblocking(false)).
    2. Then block in leb128::read::unsigned(&mut socket)?;.

    This blocks the tokio worker thread running the task, and with it every other task scheduled on that thread.

    It shouldn't block the TCP listener if the TCP listener runs in a separate task and you are using the (default) multi-threaded tokio runtime... unless, of course, you have multiple such LEB tasks blocking every single tokio thread.


    Unfortunately, the standard library has no async counterpart to io::Read (tokio defines its own AsyncRead trait), and the leb128 crate doesn't provide any tokio integration, so it's going to require a bit of work.

    Not too much, though, because &[u8] implements Read, and reading through a &mut &[u8] advances the slice so that it points to the unread bytes afterwards.
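The slice behaviour is easy to check with std alone. A minimal sketch, where `read_one` is a hypothetical helper written only for this illustration: it reads a single byte through `Read` and reports how many bytes the read consumed, which is the same bookkeeping a decoder reading from a slice allows.

```rust
use std::io::Read;

/// Hypothetical helper (not part of std or leb128): reads one byte from
/// `input` and reports how many bytes that read consumed.
///
/// `input` is a `&mut &[u8]`: reading through it moves the start of the
/// inner slice forward, so the consumed count is simply the difference in
/// length before and after the read.
fn read_one(input: &mut &[u8]) -> std::io::Result<(u8, usize)> {
    let before = input.len();
    let mut byte = [0u8; 1];

    // On an empty slice this fails with `ErrorKind::UnexpectedEof`.
    input.read_exact(&mut byte)?;

    Ok((byte[0], before - input.len()))
}
```

After a successful call, the caller's slice starts at the first unread byte, so nothing has to be tracked separately.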

    Since you are using TCP, I assume you already have some kind of buffer in place for the bytes you receive -- to pass them to the decoder -- so you should just use that buffer.

    //  Read until sufficient bytes are obtained to determine length.
    let length = loop {
        const UNEXPECTED_EOF: io::ErrorKind = io::ErrorKind::UnexpectedEof;
    
        let mut slice = buffer.readable();
    
        let result = leb128::read::unsigned(&mut slice);
    
        match result {
            Ok(length) => {
                let consumed = buffer.readable().len() - slice.len();
    
                buffer.advance(consumed);
    
                break length;
            },
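            //  Incomplete prefix: fall through and read more bytes. (A
            //  `continue` here would spin without ever touching the socket.)
            Err(leb128::read::Error::IoError(e)) if e.kind() == UNEXPECTED_EOF => {},
            Err(e) => return Err(e.into()),
        }
    
        socket.readable().await?;
    
        //  `readable()` can report false positives, in which case `try_read`
        //  fails with `WouldBlock`: just go around again.
        let received = match socket.try_read(buffer.writable()) {
            Ok(received) => received,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
            Err(e) => return Err(e.into()),
        };
    
        todo!("Handle received, beware 0 means EOF");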
    };
    
    //  Do something with buffered bytes, perhaps waiting for more
    //  (now that you know how many you need).
    todo!();
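The loop calls `readable`, `writable` and `advance` on a `buffer` whose type is never shown. Here is a minimal std-only sketch of what such a buffer could look like. The name `ReadBuffer` and the extra `commit` method are assumptions made for this illustration, not an existing crate's API.

```rust
/// Hypothetical fixed-capacity receive buffer matching the calls used in
/// the read loop. `commit` is an extra method this sketch needs in order
/// to record how many bytes a read wrote.
pub struct ReadBuffer {
    data: Vec<u8>,
    start: usize, // index of the first unread byte
    end: usize,   // one past the last received byte
}

impl ReadBuffer {
    pub fn with_capacity(capacity: usize) -> Self {
        Self { data: vec![0; capacity], start: 0, end: 0 }
    }

    /// Bytes received but not yet consumed by the decoder.
    pub fn readable(&self) -> &[u8] {
        &self.data[self.start..self.end]
    }

    /// Free space to hand to a read call. Compacts first, so the space
    /// released by `advance` becomes writable again.
    pub fn writable(&mut self) -> &mut [u8] {
        if self.start > 0 {
            self.data.copy_within(self.start..self.end, 0);
            self.end -= self.start;
            self.start = 0;
        }
        &mut self.data[self.end..]
    }

    /// Records that `n` bytes were written into the slice returned by the
    /// most recent `writable` call.
    pub fn commit(&mut self, n: usize) {
        assert!(n <= self.data.len() - self.end);
        self.end += n;
    }

    /// Marks `n` readable bytes as consumed.
    pub fn advance(&mut self, n: usize) {
        assert!(n <= self.end - self.start);
        self.start += n;
    }
}
```

With this shape, the loop's `todo!` after `try_read` would check the returned byte count for 0 (EOF) and otherwise pass it to `commit`.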
    

    I don't find the structure of that read loop very... nice, though. Mixing async I/O and decoding means you cannot test the decoding on its own, which is painful. I really advise preferring a Sans IO design when possible.

    Instead, I would encourage you to write a Framer or Decoder which will take care of part (or all) of the decoding logic, and just cleanly separate I/O from framing/decoding.

    The idea is relatively simple: push bytes into it, get framed bytes or decoded messages out.

    Since I don't have your decoder, I'll go for a framer instead, whose role is to isolate a single frame (encoded message) in the stream.

    Once you have a framer, it's actually relatively simple:

    socket.readable().await?;
    
    let mut buffer = [0; 16 * 1024];
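    //  Note: `try_read` fails with `WouldBlock` if the readiness event was a
    //  false positive; treat that as "try again", not as a fatal error.
    let length = socket.try_read(&mut buffer)?;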
    
    if length == 0 {
        todo!("handle EOF");
    }
    
    framer.push(&buffer[..length]);
    
    //  May want to limit the number of frames handled at once, to avoid blocking
    //  other clients.
    
    while let Some(message) = framer.pop()? {
        todo!("handle message");
    }
    

    And quite importantly, it's very easy to test that the framer can handle all kinds of messages.

    The actual framer code is relatively straightforward:

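    //  Disclaimer: I have not even _compiled_ this code, don't expect it to handle
    //              all the edge cases.
    //
    //  `Error` is left to you: it needs an `Overflow` variant and a
    //  `From<leb128::read::Error>` impl.
    
    use std::{io, num::NonZeroU64};
    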
    #[derive(Clone, Debug, Default)]
    pub struct Framer {
        state: FramerState,
        consumed: usize,
        buffer: Vec<u8>,
    }
    
    impl Framer {
        /// Constructs a framer with a specific capacity.
        pub fn with_capacity(capacity: usize) -> Self {
            let state = FramerState::default();
            let consumed = 0;
            let buffer = Vec::with_capacity(capacity);
    
            Self { state, consumed, buffer }
        }
    
        /// Pushes bytes into the framer.
        pub fn push(&mut self, bytes: &[u8]) {
            //  Trick here: draining in push is easier, and avoids O(N²) pop.
            if self.consumed > 0 {
                self.buffer.drain(..self.consumed);
                //  The consumed prefix is gone, so the offset restarts at 0.
                self.consumed = 0;
            }
    
            self.buffer.extend_from_slice(bytes);
        }
    
        /// Pops a frame, if possible.
        ///
        /// Call repeatedly to pop all buffered frames, when no complete frame is
        /// buffered any longer, returns `Ok(None)`.
        ///
        /// Returns an error if the underlying stream is faulty, for example has
        /// an overflowing length.
        pub fn pop(&mut self) -> Result<Option<&[u8]>, Error> {
            match self.state {
                FramerState::WaitingForLength => {
                    let length = self.pop_length()?;
    
                    let Some(length) = length else { return Ok(None) };
    
                    let Some(length) = NonZeroU64::new(length) else {
                        return Ok(Some(&[]));
                    };
    
                    //  FIXME: may want a maximum length here, regardless of
                    //         overflow, as otherwise a client accidentally
                    //         sending 2^63-1 LEB encoded will lead the server
                    //         to wait forever.
    
                    self.state = FramerState::WaitingForMessage(length);
    
                    self.pop_message(length)
                }
                FramerState::WaitingForMessage(length) => {
                    self.pop_message(length)
                }
            }
        }
    
        //  Pops the length.
        //
        //  # Panics
        //
        //  In Debug, if called when state is not WaitingForLength.
        fn pop_length(&mut self) -> Result<Option<u64>, Error> {
            const UNEXPECTED_EOF: io::ErrorKind = io::ErrorKind::UnexpectedEof;
    
            debug_assert_eq!(FramerState::WaitingForLength, self.state);
    
            let mut available = &self.buffer[self.consumed..];
    
            match leb128::read::unsigned(&mut available) {
                Ok(length) => {
                    let consumed = self.buffer.len() - self.consumed - available.len();
                    self.consumed += consumed;
    
                    Ok(Some(length))
                },
                Err(leb128::read::Error::IoError(e)) if e.kind() == UNEXPECTED_EOF => Ok(None),
                Err(e) => Err(e.into()),
            }
        }
    
        //  Pops the actual frame, according to the length.
        //
        //  # Panics
        //
        //  In Debug, if called when state is not WaitingForMessage(length).
        fn pop_message(&mut self, length: NonZeroU64) -> Result<Option<&[u8]>, Error> {
            debug_assert_eq!(FramerState::WaitingForMessage(length), self.state);
    
            let length = length.get().try_into().map_err(|_| Error::Overflow)?;
    
            let Some((frame, _)) = self.buffer[self.consumed..].split_at_checked(length) else {
                return Ok(None);
            };
    
            self.state = FramerState::WaitingForLength;
            self.consumed += frame.len();
    
            Ok(Some(frame))
        }
    }
    
    //  `Default` is required by `#[default]` below and by `Framer`'s derive.
    #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
    enum FramerState {
        #[default]
        WaitingForLength,
        WaitingForMessage(NonZeroU64),
    }
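To illustrate the testability point, here is a std-only stand-in that can be exercised without a socket. It is a simplified sketch, not the framer above: `MiniFramer`, `FrameError`, `decode_length` and `frames_from_chunks` are hypothetical names, the LEB128 decoding is hand-rolled instead of using the leb128 crate, frames are returned as owned `Vec<u8>`, and the prefix is limited to 5 bytes on the assumption that the length is a 32-bit integer, as with BinaryWriter. It also applies the maximum-length cap suggested in the FIXME.

```rust
#[derive(Debug, PartialEq, Eq)]
pub enum FrameError {
    /// The length prefix is longer than this sketch accepts.
    PrefixTooLong,
    /// The announced length exceeds the configured cap.
    TooLong(u64),
}

const MAX_PREFIX_BYTES: usize = 5;

/// Decodes an unsigned LEB128 value from the start of `bytes`.
/// Returns `Ok(None)` while the value is incomplete, otherwise the value
/// and the number of bytes it occupied.
pub fn decode_length(bytes: &[u8]) -> Result<Option<(u64, usize)>, FrameError> {
    let mut value = 0u64;
    for (i, &byte) in bytes.iter().enumerate().take(MAX_PREFIX_BYTES) {
        value |= u64::from(byte & 0x7F) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok(Some((value, i + 1)));
        }
    }
    if bytes.len() >= MAX_PREFIX_BYTES {
        Err(FrameError::PrefixTooLong)
    } else {
        Ok(None)
    }
}

pub struct MiniFramer {
    buffer: Vec<u8>,
    max_len: usize,
}

impl MiniFramer {
    pub fn new(max_len: usize) -> Self {
        Self { buffer: Vec::new(), max_len }
    }

    pub fn push(&mut self, bytes: &[u8]) {
        self.buffer.extend_from_slice(bytes);
    }

    /// Pops one complete frame, or `Ok(None)` if more bytes are needed.
    pub fn pop(&mut self) -> Result<Option<Vec<u8>>, FrameError> {
        let Some((length, prefix)) = decode_length(&self.buffer)? else {
            return Ok(None);
        };
        if length > self.max_len as u64 {
            return Err(FrameError::TooLong(length));
        }
        let end = prefix + length as usize; // fits: length <= max_len
        if self.buffer.len() < end {
            return Ok(None);
        }
        let frame = self.buffer[prefix..end].to_vec();
        self.buffer.drain(..end);
        Ok(Some(frame))
    }
}

/// Feeds `stream` to a fresh framer in pieces of `chunk` bytes (which must
/// be non-zero) and collects every frame that comes out.
pub fn frames_from_chunks(
    stream: &[u8],
    chunk: usize,
    max_len: usize,
) -> Result<Vec<Vec<u8>>, FrameError> {
    let mut framer = MiniFramer::new(max_len);
    let mut frames = Vec::new();
    for piece in stream.chunks(chunk) {
        framer.push(piece);
        while let Some(frame) = framer.pop()? {
            frames.push(frame);
        }
    }
    Ok(frames)
}
```

A useful property to assert in tests is that the same byte stream yields the same frames whatever the chunk size, including one byte at a time, since TCP gives no guarantee about where reads are split.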