1use std::{fmt, future::Future, time::Duration};
2
3use tokio::{
4 io::{self, AsyncRead, AsyncWrite},
5 net::{TcpListener, TcpStream},
6};
7
/// Types that can listen for and accept incoming connections.
pub trait Listener: Send + 'static {
    /// The IO object yielded for each accepted connection.
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// The address type, returned both alongside each accepted connection and
    /// by [`Listener::local_addr`].
    type Addr: Send;

    /// Accepts the next incoming connection, yielding its IO object and address.
    ///
    /// The returned future resolves to a plain tuple, not a `Result`, so an
    /// implementation whose underlying accept call can fail must handle those
    /// failures itself (for example by logging and retrying).
    fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;

    /// Returns the local address of this listener.
    ///
    /// # Errors
    ///
    /// Propagates any IO error raised while querying the address.
    fn local_addr(&self) -> io::Result<Self::Addr>;
}
25
26impl Listener for TcpListener {
27 type Io = TcpStream;
28 type Addr = std::net::SocketAddr;
29
30 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
31 loop {
32 match Self::accept(self).await {
33 Ok(tup) => return tup,
34 Err(e) => handle_accept_error(e).await,
35 }
36 }
37 }
38
39 #[inline]
40 fn local_addr(&self) -> io::Result<Self::Addr> {
41 Self::local_addr(self)
42 }
43}
44
45#[cfg(unix)]
46impl Listener for tokio::net::UnixListener {
47 type Io = tokio::net::UnixStream;
48 type Addr = tokio::net::unix::SocketAddr;
49
50 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
51 loop {
52 match Self::accept(self).await {
53 Ok(tup) => return tup,
54 Err(e) => handle_accept_error(e).await,
55 }
56 }
57 }
58
59 #[inline]
60 fn local_addr(&self) -> io::Result<Self::Addr> {
61 Self::local_addr(self)
62 }
63}
64
65pub trait ListenerExt: Listener + Sized {
67 fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
90 where
91 F: FnMut(&mut Self::Io) + Send + 'static,
92 {
93 TapIo {
94 listener: self,
95 tap_fn,
96 }
97 }
98}
99
100impl<L: Listener> ListenerExt for L {}
101
/// A listener adapter that runs a closure on every accepted IO object before
/// handing it back from `accept`.
///
/// Constructed via `ListenerExt::tap_io`.
pub struct TapIo<L, F> {
    /// The wrapped listener.
    listener: L,
    /// Closure applied to each accepted IO object.
    tap_fn: F,
}

/// Formats the wrapped listener. The tap closure is elided (`..`) because `F`
/// is not required to implement `Debug`.
///
/// Only `L: Debug` is required: printing the wrapper does not depend on `L`
/// being a listener, so no `Listener` bound is imposed here.
impl<L, F> fmt::Debug for TapIo<L, F>
where
    L: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TapIo")
            .field("listener", &self.listener)
            .finish_non_exhaustive()
    }
}
120
121impl<L, F> Listener for TapIo<L, F>
122where
123 L: Listener,
124 F: FnMut(&mut L::Io) + Send + 'static,
125{
126 type Io = L::Io;
127 type Addr = L::Addr;
128
129 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
130 let (mut io, addr) = self.listener.accept().await;
131 (self.tap_fn)(&mut io);
132 (io, addr)
133 }
134
135 fn local_addr(&self) -> io::Result<Self::Addr> {
136 self.listener.local_addr()
137 }
138}
139
140async fn handle_accept_error(e: io::Error) {
141 if is_connection_error(&e) {
142 return;
143 }
144
145 error!("accept error: {e}");
157 tokio::time::sleep(Duration::from_secs(1)).await;
158}
159
/// Returns `true` when `e` is one of the per-connection error kinds
/// (`ConnectionRefused`, `ConnectionAborted`, `ConnectionReset`).
fn is_connection_error(e: &io::Error) -> bool {
    const CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    CONNECTION_KINDS.contains(&e.kind())
}