1use std::{
4 convert::Infallible,
5 fmt::Debug,
6 future::{poll_fn, Future, IntoFuture},
7 io,
8 marker::PhantomData,
9 sync::Arc,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::{pin_mut, FutureExt};
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Serve the service with the supplied listener.
///
/// The returned [`Serve`] does nothing on its own: it must be awaited (it
/// implements [`IntoFuture`]). For every connection accepted from `listener`,
/// `make_service` is called with an [`IncomingStream`] to build the service
/// `S` that handles that connection's requests.
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    let _marker = PhantomData;
    Serve {
        make_service,
        listener,
        _marker,
    }
}
104
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Value returned by [`serve`].
///
/// It only stores its inputs; no connection is accepted until it is turned
/// into a future through its [`IntoFuture`] impl and polled.
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    // Source of incoming connections.
    listener: L,
    // Called once per accepted connection to build that connection's service.
    make_service: M,
    // Records the per-connection service type `S`; no `S` value is stored.
    _marker: PhantomData<S>,
}
113
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Attach a shutdown signal, enabling graceful shutdown.
    ///
    /// Once `signal` completes, the server stops accepting new connections,
    /// asks every in-flight connection to shut down gracefully, and resolves
    /// after all of them have finished.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            signal,
            listener,
            make_service,
            _marker,
        }
    }

    /// Return the local address the underlying listener is bound to.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        L::local_addr(&self.listener)
    }
}
157
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    /// Prints the listener and the make-service; the `S` marker holds no data
    /// and is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Serve` becomes a compile
        // error here, so this output cannot silently go stale.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
178
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Run the server without a shutdown trigger.
    ///
    /// This reuses the graceful-shutdown machinery with a signal that never
    /// completes, so shutdown is never initiated.
    fn into_future(self) -> Self::IntoFuture {
        let never_fires = std::future::pending::<()>();
        let graceful = self.with_graceful_shutdown(never_fires);
        graceful.into_future()
    }
}
197
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Server with graceful shutdown enabled.
///
/// Created by [`Serve::with_graceful_shutdown`]. Like [`Serve`], it only stores
/// its inputs until it is turned into a future via [`IntoFuture`] and polled.
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    // Source of incoming connections.
    listener: L,
    // Called once per accepted connection to build that connection's service.
    make_service: M,
    // Future whose completion starts graceful shutdown.
    signal: F,
    // Records the per-connection service type `S`; no `S` value is stored.
    _marker: PhantomData<S>,
}
207
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Return the local address the underlying listener is bound to.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
218
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    /// Prints the listener, the make-service and the shutdown signal.
    ///
    /// `S` appears only inside a `PhantomData` marker that is never printed, so
    /// it is not required to implement `Debug`. This matches the `Debug` impl
    /// for `Serve`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring keeps this impl in sync with the struct.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
242
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Build the server future.
    ///
    /// The returned future is lazy: nothing is spawned and nothing is accepted
    /// until it is first polled. Once polled it:
    ///
    /// 1. spawns a task that waits for `signal` and then drops the only
    ///    receiver of the shutdown watch channel, resolving every
    ///    `signal_tx.closed()` future;
    /// 2. accepts connections in a loop, builds one service per connection
    ///    via `make_service`, and serves each connection on its own task;
    /// 3. after the signal fires, stops accepting, has every connection task
    ///    call `graceful_shutdown` on its connection, and resolves with
    ///    `Ok(())` once all connection tasks have finished.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        private::ServeFuture(Box::pin(async move {
            // `signal_rx` is the only receiver of this channel. Dropping it closes
            // the channel, which resolves every `signal_tx.closed()` future: both
            // the accept loop below and each connection task wait on that.
            let (signal_tx, mut signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                tokio::select! {
                    _ = signal => {
                        trace!("received graceful shutdown signal. Telling tasks to shutdown");
                    }
                    // Nothing is ever sent on `signal_tx`, so this branch only
                    // completes (with `Err`) once every `signal_tx` clone is gone:
                    // the server future was dropped and all connection tasks have
                    // ended. Without it, a server dropped before `signal` fires
                    // (always the case for `Serve`, whose signal never completes)
                    // would leave this task, and `signal`, alive forever.
                    _ = signal_rx.changed() => {}
                }
                drop(signal_rx);
            });

            // Every connection task holds a clone of `close_rx`. `close_tx.closed()`
            // resolves only after all clones are dropped, which is how shutdown
            // waits for in-flight connections.
            let (close_tx, close_rx) = watch::channel(());

            loop {
                let (io, remote_addr) = tokio::select! {
                    conn = listener.accept() => conn,
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };

                let io = TokioIo::new(io);

                trace!("connection {remote_addr:?} accepted");

                // `Error = Infallible`: the empty match proves this cannot fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        io: &io,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    // hyper hands us `Request<Incoming>`; the service takes axum's `Body`.
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                let signal_tx = Arc::clone(&signal_tx);

                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    // `builder` is only mutated when the `http2` feature is enabled.
                    #[allow(unused_mut)]
                    let mut builder = Builder::new(TokioExecutor::new());
                    // Enables the extended CONNECT protocol; presumably needed for
                    // WebSockets over HTTP/2 (confirm).
                    #[cfg(feature = "http2")]
                    builder.http2().enable_connect_protocol();
                    let conn = builder.serve_connection_with_upgrades(io, hyper_service);
                    pin_mut!(conn);

                    // Fused so it can safely be polled again on later loop
                    // iterations after it has already completed.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    // Signals the main future that this connection is done.
                    drop(close_rx);
                });
            }

            // Release our own receiver and stop listening, so only the
            // connection tasks keep `close_tx` from closing.
            drop(close_rx);
            drop(listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
352
/// A connection that was just accepted by the listener.
///
/// Handed to the make-service once per accepted connection, before the
/// per-connection service is built.
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
    L: Listener,
{
    // Borrowed only while the make-service runs; the stream itself is moved
    // into the connection task afterwards.
    io: &'a TokioIo<L::Io>,
    // Peer address returned alongside the stream by `Listener::accept`.
    remote_addr: L::Addr,
}
366
367impl<L> IncomingStream<'_, L>
368where
369 L: Listener,
370{
371 pub fn io(&self) -> &L::Io {
373 self.io.inner()
374 }
375
376 pub fn remote_addr(&self) -> &L::Addr {
378 &self.remote_addr
379 }
380}
381
382mod private {
383 use std::{
384 future::Future,
385 io,
386 pin::Pin,
387 task::{Context, Poll},
388 };
389
390 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
391
392 impl Future for ServeFuture {
393 type Output = io::Result<()>;
394
395 #[inline]
396 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397 self.0.as_mut().poll(cx)
398 }
399 }
400
401 impl std::fmt::Debug for ServeFuture {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.debug_struct("ServeFuture").finish_non_exhaustive()
404 }
405 }
406}
407
#[cfg(test)]
mod tests {
    use std::{
        future::{pending, IntoFuture as _},
        net::{IpAddr, Ipv4Addr},
    };

    use axum_core::{body::Body, extract::Request};
    use http::StatusCode;
    use hyper_util::rt::TokioIo;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::{
        io::{self, AsyncRead, AsyncWrite},
        net::TcpListener,
    };

    #[cfg(unix)]
    use super::IncomingStream;
    use super::{serve, Listener};
    #[cfg(unix)]
    use crate::extract::connect_info::Connected;
    use crate::{
        body::to_bytes,
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        serve::ListenerExt,
        Router,
    };

    /// Compile-only check: this function is never called (hence `dead_code`).
    /// It exists so the type checker confirms that `serve` accepts every
    /// combination below of listener (TCP, TCP with `tap_io`, Unix socket) and
    /// service / make-service (router, method router, handler, each with and
    /// without connect info). Most `Serve` values are deliberately not awaited,
    /// hence `unused_must_use`.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Connect-info type for Unix-socket connections.
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;

        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }

        let router: Router = Router::new();

        let addr = "0.0.0.0:0";

        // TCP listener whose accepted streams get TCP_NODELAY set via `tap_io`;
        // a failure is only reported, not propagated.
        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // `Router` used directly as the make-service.
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());

        // `Router::into_make_service`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );

        // `Router::into_make_service_with_connect_info`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // `MethodRouter` from `get(handler)`, used directly.
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));

        // `MethodRouter::into_make_service`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );

        // `MethodRouter::into_make_service_with_connect_info`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // Bare handler converted with `into_service`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());

        // Bare handler with explicit unit state.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));

        // Bare handler converted with `into_make_service`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());

        // Bare handler converted with `into_make_service_with_connect_info`.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }

    // Minimal handler used by the compile-only checks above.
    async fn handler() {}

    // `Serve::local_addr` reports the bound address with the OS-assigned port.
    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Same as above, but through `WithGracefulShutdown::local_addr`.
    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
            .with_graceful_shutdown(pending());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // `into_future` is called outside any runtime context (only the bind runs
    // inside `block_on`). It must not panic, because the future it returns does
    // no work until polled.
    #[test]
    fn into_future_outside_tokio() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();

        _ = serve(listener, router).into_future();
    }

    // End-to-end request over an in-memory duplex stream, served through a
    // custom `Listener` implementation.
    #[crate::test]
    async fn serving_on_custom_io_type() {
        // Yields its single stored stream on the first `accept`, then never
        // resolves again.
        struct ReadyListener<T>(Option<T>);

        impl<T> Listener for ReadyListener<T>
        where
            T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        {
            type Io = T;
            type Addr = ();

            async fn accept(&mut self) -> (Self::Io, Self::Addr) {
                match self.0.take() {
                    Some(server) => (server, ()),
                    None => std::future::pending().await,
                }
            }

            fn local_addr(&self) -> io::Result<Self::Addr> {
                Ok(())
            }
        }

        let (client, server) = io::duplex(1024);
        let listener = ReadyListener(Some(server));

        let app = Router::new().route("/", get(|| async { "Hello, World!" }));

        tokio::spawn(serve(listener, app).into_future());

        // Drive the client half of the duplex with hyper's HTTP/1 client.
        let stream = TokioIo::new(client);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::spawn(conn);

        let request = Request::builder().body(Body::empty()).unwrap();

        let response = sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = Body::new(response.into_body());
        let body = to_bytes(body, usize::MAX).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }
}