I would like to classify incoming TCP streams by their first n bytes and then pass them to different handlers according to the classification.
I do not want to consume any of the bytes in the stream; otherwise I would be passing invalid streams to the handlers, streams that start with the (n+1)th byte instead of the first.
So poll_peek looks almost like what I need, as it waits for data to be available before it peeks.
However, what I would ideally need is a poll_peek_exact that does not return until the passed buffer is full.
This method does not seem to exist on TcpStream, so I'm not sure what the correct way is to peek at the first n bytes of a TcpStream without consuming them.
I could do something like:
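use std::future::poll_fn;
use tokio::io::ReadBuf;
// Keep peeking until we have enough bytes to decide.
loop {
    // poll_peek takes a ReadBuf (not a &mut [u8]) and resolves to io::Result<usize>.
    let num_bytes = poll_fn(|cx| {
        tcp_stream.poll_peek(cx, &mut ReadBuf::new(&mut buf))
    }).await?;
    if num_bytes >= n {
        return classify(&buf);
    }
    // num_bytes == 0 means the peer closed the connection; without a check for
    // that, this loop never terminates.
}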
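But I think that would be busy-waiting: as soon as any bytes are buffered, poll_peek returns Ready without consuming them, so while fewer than n bytes have arrived the loop just keeps spinning. I could add a sleep to the loop, but that does not seem like good style either.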
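The short-read behaviour behind this can be reproduced with a small std-only sketch. It uses blocking std::net instead of tokio so that it runs without extra crates; the 3-byte header size and the payload are illustrative assumptions. std's TcpStream::peek, like tokio's peek/poll_peek, leaves the data in the receive queue and may return fewer bytes than the buffer holds:

```rust
use std::io::{self, Read, Write};
use std::net::{TcpListener, TcpStream};

/// The peer sends only 2 of the 3 header bytes we want; we peek twice, then read.
/// Returns (first peek length, second peek length, bytes actually read).
fn short_peek_demo() -> io::Result<(usize, usize, Vec<u8>)> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let mut client = TcpStream::connect(listener.local_addr()?)?;
    // Fewer bytes than the classifier needs.
    client.write_all(b"12")?;

    let (mut server, _) = listener.accept()?;
    let mut buf = [0u8; 3];
    // Blocks until at least one byte is available, then returns whatever is
    // buffered (at most 2 bytes here) without removing it from the socket.
    let first = server.peek(&mut buf)?;
    // Returns again right away with the same data: looping on this until
    // it reports >= 3 bytes would spin.
    let second = server.peek(&mut buf)?;

    // A real read still sees the bytes that were only peeked before.
    let mut read_back = vec![0u8; second];
    server.read_exact(&mut read_back)?;
    Ok((first, second, read_back))
}

fn main() -> io::Result<()> {
    let (first, second, read_back) = short_peek_demo()?;
    println!("first peek: {first}, second peek: {second}, read: {read_back:?}");
    Ok(())
}
```

Both peeks report fewer than 3 bytes, and neither waits for the third byte that never comes, which is exactly what makes a peek-until-full loop spin.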
So what's the right way to do that?
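One alternative to peeking is to consume the header and then put it back in front of the stream with a chained reader; the attempt that follows implements this by hand. tokio's AsyncReadExt::chain plays that role for async readers. The sketch below uses the synchronous std::io::Read::chain so it runs with std alone, with a Cursor standing in for the socket; split_header is a hypothetical helper name:

```rust
use std::io::{self, Chain, Cursor, Read};

/// Hypothetical helper: consume the first `N` bytes of `reader` and return
/// them, plus a reader that replays those bytes before the rest of the stream.
fn split_header<R: Read, const N: usize>(
    mut reader: R,
) -> io::Result<([u8; N], Chain<Cursor<[u8; N]>, R>)> {
    let mut header = [0u8; N];
    // Fails with ErrorKind::UnexpectedEof if the stream ends early.
    reader.read_exact(&mut header)?;
    // `[u8; N]` is Copy: the caller gets one copy, the chained reader another.
    Ok((header, Cursor::new(header).chain(reader)))
}

fn main() -> io::Result<()> {
    // A Cursor stands in for the TCP stream in this sketch.
    let stream = Cursor::new(b"123HelloWorld!\n".to_vec());
    let (header, mut replay) = split_header::<_, 3>(stream)?;
    let mut everything = Vec::new();
    replay.read_to_end(&mut everything)?;
    println!("header: {:?}", header);
    println!("replayed: {:?}", String::from_utf8_lossy(&everything));
    Ok(())
}
```

The handler then reads from the chained reader and sees the complete stream, header included, while the classifier already has the header bytes to inspect.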
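Here is my attempt. Instead of peeking, it reads the header with read_exact and replays it from poll_read before forwarding further reads to the socket: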
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use std::error::Error;

#[pin_project]
struct HeaderExtractor<const S: usize> {
    #[pin]
    socket: TcpStream,
    header: [u8; S],
    num_forwarded: usize,
}

impl<const S: usize> HeaderExtractor<S> {
    /// Consumes the first S bytes of `socket`; poll_read replays them later.
    pub async fn read_header(socket: TcpStream) -> std::io::Result<Self> {
        let mut this = Self {
            socket,
            header: [0; S],
            num_forwarded: 0,
        };
        this.socket.read_exact(&mut this.header).await?;
        Ok(this)
    }

    // Only reads the header, so a shared borrow is enough.
    pub fn get_header(&self) -> &[u8; S] {
        &self.header
    }
}
impl<const S: usize> AsyncRead for HeaderExtractor<S> {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.project();
        if *this.num_forwarded < this.header.len() {
            // Replay the part of the header the caller has not seen yet.
            let leftover = &this.header[*this.num_forwarded..];
            let num_forward_now = leftover.len().min(buf.remaining());
            let forward = &leftover[..num_forward_now];
            buf.put_slice(forward);
            *this.num_forwarded += num_forward_now;
            std::task::Poll::Ready(Ok(()))
        } else {
            // Header fully replayed: delegate to the underlying socket.
            this.socket.poll_read(cx, buf)
        }
    }
}
impl<const S: usize> AsyncWrite for HeaderExtractor<S> {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        this.socket.poll_write(cx, buf)
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.socket.poll_flush(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let this = self.project();
        this.socket.poll_shutdown(cx)
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let listener = TcpListener::bind("127.0.0.1:12345").await?;
    loop {
        // Asynchronously wait for an inbound socket.
        let (socket, _) = listener.accept().await?;
        tokio::spawn(async move {
            // Read the header inside the task: a slow client then cannot stall
            // the accept loop, and a short stream cannot end the server via `?`.
            let mut socket = match HeaderExtractor::<3>::read_header(socket).await {
                Ok(socket) => socket,
                Err(e) => {
                    eprintln!("failed to read header: {}", e);
                    return;
                }
            };
            println!("Got header: {:?}", socket.get_header());
            let mut buf = vec![0; 1024];
            // In a loop, read data from the socket and print it.
            loop {
                let n = socket
                    .read(&mut buf)
                    .await
                    .expect("failed to read data from socket");
                if n == 0 {
                    println!("Connection closed.");
                    return;
                }
                println!("Received: {:?}", &buf[..n]);
            }
        });
    }
}
When I run echo "123HelloWorld!" | nc -N localhost 12345 in another console, I get:
Got header: [49, 50, 51]
Received: [49, 50, 51]
Received: [72, 101, 108, 108, 111, 87, 111, 114, 108, 100, 33, 10]
Connection closed.
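The AsyncWrite impl of HeaderExtractor only forwards to the socket. One way to avoid that boilerplate is to split the connection and wrap only the read side; in tokio that would be TcpStream::into_split, which yields separate owned read and write halves. Below is a blocking std-only sketch of the same structure, using TcpStream::try_clone as the split and Read::chain for the replay; echo_with_header, the echo behaviour and the 3-byte header are illustrative assumptions:

```rust
use std::io::{self, Cursor, Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream};

/// Accepts one connection, reads a 3-byte header, then echoes the whole
/// stream (header included) back through a separate write handle.
fn echo_with_header(listener: &TcpListener) -> io::Result<[u8; 3]> {
    let (stream, _) = listener.accept()?;
    // Second handle to the same socket, used only for writing.
    let mut writer = stream.try_clone()?;

    let mut read_side = stream;
    let mut header = [0u8; 3];
    read_side.read_exact(&mut header)?;
    // Only the read side needs the header replayed in front of it.
    let mut reader = Cursor::new(header).chain(read_side);

    let mut body = Vec::new();
    reader.read_to_end(&mut body)?; // until the client shuts down its write side
    writer.write_all(&body)?;
    writer.shutdown(Shutdown::Write)?; // signal EOF to the client
    Ok(header)
}

fn main() -> io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let mut client = TcpStream::connect(listener.local_addr()?)?;
    client.write_all(b"123HelloWorld!\n")?;
    client.shutdown(Shutdown::Write)?;

    let header = echo_with_header(&listener)?;
    let mut echoed = Vec::new();
    client.read_to_end(&mut echoed)?;
    println!("header: {:?}, echoed: {:?}", header, String::from_utf8_lossy(&echoed));
    Ok(())
}
```

As in the output above, the header bytes reach the consumer first and the rest of the stream follows, but here no write-side forwarding code is needed.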