I'm trying to implement an async read wrapper that adds read timeout functionality. The objective is that the API stays plain `AsyncRead`. In other words, I don't want to add `io.read(buf).timeout(t)` everywhere in the code. Instead, the reader itself should return `io::ErrorKind::TimedOut` after the given timeout expires.

I can't poll the `delay` to `Ready`, though. It's always `Pending`. I've tried `async-std`, `futures` and `smol-timeout`, all with the same result. The timeout does trigger when awaited, but never when polled. I know timeouts aren't easy; something needs to wake it up. What am I doing wrong? How do I pull this through?
```rust
use async_std::{
    future::Future,
    io,
    pin::Pin,
    task::{sleep, Context, Poll},
};
use std::time::Duration;

pub struct PrudentIo<IO> {
    expired: Option<Pin<Box<dyn Future<Output = ()> + Sync + Send>>>,
    timeout: Duration,
    io: IO,
}

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self {
        PrudentIo {
            expired: None,
            timeout,
            io,
        }
    }
}

fn delay(t: Duration) -> Option<Pin<Box<dyn Future<Output = ()> + Sync + Send + 'static>>> {
    if t.is_zero() {
        return None;
    }
    Some(Box::pin(sleep(t)))
}

impl<IO: io::Read + Unpin> io::Read for PrudentIo<IO> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(ref mut expired) = self.expired {
            match expired.as_mut().poll(cx) {
                Poll::Ready(_) => {
                    println!("expired ready");
                    // too much time passed since last read/write
                    return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                }
                Poll::Pending => {
                    println!("expired pending");
                    // in good time
                }
            }
        }
        let res = Pin::new(&mut self.io).poll_read(cx, buf);
        println!("read {:?}", res);
        match res {
            Poll::Pending => {
                if self.expired.is_none() {
                    // No data, start checking for a timeout
                    self.expired = delay(self.timeout);
                }
            }
            Poll::Ready(_) => self.expired = None,
        }
        res
    }
}

impl<IO: io::Write + Unpin> io::Write for PrudentIo<IO> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.io).poll_write(cx, buf)
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_flush(cx)
    }
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.io).poll_close(cx)
    }
}
```
```rust
#[cfg(test)]
mod io_tests {
    use super::*;
    use async_std::io::ReadExt;
    use async_std::prelude::FutureExt;
    use std::time::Duration;

    #[async_std::test]
    async fn fail_read_after_timeout() -> io::Result<()> {
        let mut output = b"______".to_vec();
        let io = PendIo;
        let mut io = PrudentIo::new(Duration::from_millis(5), io);
        let mut io = Pin::new(&mut io);
        // The outer 1 s timeout must not fire; the read itself should fail with TimedOut.
        let res = io.read(&mut output[..]).timeout(Duration::from_secs(1)).await;
        let err = res
            .expect("outer timeout fired: PrudentIo never woke up")
            .expect_err("read should have timed out");
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        Ok(())
    }

    #[async_std::test]
    async fn timeout_expires() {
        // Awaiting the delay directly does complete.
        delay(Duration::from_millis(1)).expect("non-zero timeout").await;
    }

    /// Mock IO that is always pending (and never wakes the task)
    struct PendIo;

    impl io::Read for PendIo {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Pending
        }
    }

    impl io::Write for PendIo {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Pending
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Pending
        }
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Pending
        }
    }
}
```
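Async timeouts work as follows:

1. You create the timeout future.
2. The runtime calls `poll` into the timeout, which checks whether the timeout has expired.
3. If it has expired, it returns `Ready` and is done.
4. If it has not expired, it registers a callback (typically with the runtime's timer) that, once the right time has passed, calls `cx.waker().wake()`, or similar.
5. When the time comes, that callback calls `wake()` on the proper waker, which instructs the runtime to call `poll` again.
6. This time `poll` will return `Ready`. Done!

The problem with your code is that you create the delay from inside the `poll_read()` implementation, `self.expired = delay(self.timeout);`, but then you return `Pending` without polling the timeout even once. This way, no callback is registered anywhere that would call the `Waker`. No waker, no timeout. And since your `PendIo` mock never registers the waker either, nothing ever schedules another `poll_read`.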
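As a sketch of that mechanism, here is a std-only timer future. `Delay`, `ThreadWaker` and `block_on` are hypothetical names, not async-std APIs, and a helper thread per timer plays the role of the registered callback; real runtimes typically use a shared timer or reactor instead.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical thread-backed timer (not async-std's `sleep`).
pub struct Delay {
    deadline: Instant,
    // Latest waker handed to `poll`; the helper thread takes and wakes it.
    waker: Arc<Mutex<Option<Waker>>>,
    armed: bool,
}

impl Delay {
    pub fn new(dur: Duration) -> Self {
        Delay { deadline: Instant::now() + dur, waker: Arc::new(Mutex::new(None)), armed: false }
    }
}

impl Future for Delay {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Check the deadline and store the waker under one lock, so the helper
        // thread cannot fire in between and wake a stale waker only.
        let mut slot = self.waker.lock().unwrap();
        if Instant::now() >= self.deadline {
            return Poll::Ready(());
        }
        *slot = Some(cx.waker().clone());
        drop(slot);
        if !self.armed {
            // Register the "callback": a thread that calls wake() at the deadline.
            self.armed = true;
            let (deadline, shared) = (self.deadline, Arc::clone(&self.waker));
            thread::spawn(move || {
                thread::sleep(deadline.saturating_duration_since(Instant::now()));
                if let Some(waker) = shared.lock().unwrap().take() {
                    waker.wake();
                }
            });
        }
        Poll::Pending
    }
}

/// Waker for a tiny single-future executor: waking unparks the blocked thread.
struct ThreadWaker(thread::Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            // Wait for some wake(); a spurious unpark just leads to a re-poll.
            Poll::Pending => thread::park(),
        }
    }
}
```

`block_on(Delay::new(Duration::from_millis(30)))` returns after roughly 30 ms. If `poll` returned `Pending` without storing the waker and arming the thread, `block_on` would park forever, which is the situation the wrapper ends up in.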
I see several solutions:
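A. Do not initialize `PrudentIo::expired` to `None`, but create the timeout directly in the constructor. That way the timeout will always be polled before the `io` at least once, and it will be woken. But you will always create a timeout, even if it is not actually needed. Note also that the `Poll::Ready(_)` arm resets `expired` to `None`, so after the first successful read the next `Pending` creates an unpolled timeout again; you would have to re-create the timeout in that arm instead of clearing it.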
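B. When creating the timeout, do a recursive poll:

```rust
Poll::Pending => {
    if self.expired.is_none() {
        // No data, start checking for a timeout
        self.expired = delay(self.timeout);
        // `delay` returns None for a zero timeout; only recurse when a timer
        // was actually created, otherwise this would recurse forever.
        if self.expired.is_some() {
            // Re-enter poll_read: the new timeout is now polled (registering
            // the waker) before the inner io.
            return self.poll_read(cx, buf);
        }
    }
}
```

This will poll the `io` twice, unnecessarily, so it may not be optimal.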
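C. Add a call to `poll` after creating the timeout:

```rust
Poll::Pending => {
    if self.expired.is_none() {
        // No data, start checking for a timeout
        self.expired = delay(self.timeout);
        // Poll the fresh timeout once so it registers `cx`'s waker.
        // `delay` returns None for a zero timeout, hence `if let` instead of `unwrap()`.
        if let Some(expired) = self.expired.as_mut() {
            let _ = expired.as_mut().poll(cx);
        }
    }
}
```

You could match the result of that `poll` in case it already returns `Ready`, but for a freshly created, non-zero timeout it is most likely still `Pending`, and in practice this seems to work nicely.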
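Putting option C together with that `match`, here is a std-only sketch of the whole read path. To stay self-contained it assumes a hypothetical `PollRead` trait with the same `poll_read` signature as `futures_io::AsyncRead`, and a hypothetical thread-backed `Delay` in place of `async_std::task::sleep`; with async-std you would keep `io::Read` and `delay()` from the question.

```rust
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for `futures_io::AsyncRead` (same `poll_read` signature).
pub trait PollRead {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
        -> Poll<io::Result<usize>>;
}

/// Hypothetical thread-backed timer standing in for `async_std::task::sleep`.
pub struct Delay { deadline: Instant, waker: Arc<Mutex<Option<Waker>>>, armed: bool }

impl Delay {
    pub fn new(d: Duration) -> Self {
        Delay { deadline: Instant::now() + d, waker: Arc::new(Mutex::new(None)), armed: false }
    }
}

impl Future for Delay {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut slot = self.waker.lock().unwrap();
        if Instant::now() >= self.deadline {
            return Poll::Ready(());
        }
        *slot = Some(cx.waker().clone()); // register interest in being woken
        drop(slot);
        if !self.armed {
            self.armed = true;
            let (deadline, shared) = (self.deadline, Arc::clone(&self.waker));
            thread::spawn(move || {
                thread::sleep(deadline.saturating_duration_since(Instant::now()));
                if let Some(w) = shared.lock().unwrap().take() {
                    w.wake();
                }
            });
        }
        Poll::Pending
    }
}

type Timer = Pin<Box<dyn Future<Output = ()> + Send>>;

pub struct PrudentIo<IO> { expired: Option<Timer>, timeout: Duration, io: IO }

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self { PrudentIo { expired: None, timeout, io } }
}

impl<IO: PollRead + Unpin> PollRead for PrudentIo<IO> {
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
        -> Poll<io::Result<usize>> {
        if let Some(timer) = self.expired.as_mut() {
            if timer.as_mut().poll(cx).is_ready() {
                self.expired = None; // do not poll a finished future again
                return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
            }
        }
        let res = Pin::new(&mut self.io).poll_read(cx, buf);
        match res {
            Poll::Ready(_) => self.expired = None,
            // A zero timeout disables the timer, like `delay()` returning None.
            Poll::Pending if self.expired.is_none() && !self.timeout.is_zero() => {
                // Option C: poll the fresh timer once so it registers `cx`'s
                // waker, and report a timeout if it is somehow already Ready.
                let mut timer: Timer = Box::pin(Delay::new(self.timeout));
                if timer.as_mut().poll(cx).is_ready() {
                    return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                }
                self.expired = Some(timer);
            }
            Poll::Pending => {}
        }
        res
    }
}

/// Mock reader that is always pending and never wakes, like the question's `PendIo`.
pub struct PendIo;

impl PollRead for PendIo {
    fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, _: &mut [u8])
        -> Poll<io::Result<usize>> {
        Poll::Pending
    }
}
```

To check it, poll the wrapper once with a `Waker` that records wake-ups, sleep past the timeout, and confirm that the waker fired and that the next `poll_read` returns `TimedOut`. With the question's version, which never polls the fresh timer, no wake-up is recorded.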