axum/
serve.rs

1//! Serve services.
2
3use std::{
4    convert::Infallible,
5    fmt::Debug,
6    future::{poll_fn, Future, IntoFuture},
7    io,
8    marker::PhantomData,
9    sync::Arc,
10};
11
12use axum_core::{body::Body, extract::Request, response::Response};
13use futures_util::{pin_mut, FutureExt};
14use hyper::body::Incoming;
15use hyper_util::rt::{TokioExecutor, TokioIo};
16#[cfg(any(feature = "http1", feature = "http2"))]
17use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
18use tokio::sync::watch;
19use tower::ServiceExt as _;
20use tower_service::Service;
21
22mod listener;
23
24pub use self::listener::{Listener, ListenerExt, TapIo};
25
/// Serve the service with the supplied listener.
///
/// This is intentionally the most basic way to run a service and offers no configuration
/// knobs; if you need to tune the server, use hyper or hyper-util directly.
///
/// Both HTTP/1 and HTTP/2 are supported.
///
/// The returned [`Serve`] does nothing until it is awaited. To stop accepting new connections
/// and let open ones finish once some future completes, call
/// [`Serve::with_graceful_shutdown`] before awaiting.
///
/// # Examples
///
/// Serving a [`Router`]:
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`Router::into_make_service_with_connect_info`].
///
/// Serving a [`MethodRouter`]:
///
/// ```
/// use axum::routing::get;
///
/// # async {
/// let router = get(|| async { "Hello, World!" });
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`MethodRouter::into_make_service_with_connect_info`].
///
/// Serving a [`Handler`]:
///
/// ```
/// use axum::handler::HandlerWithoutStateExt;
///
/// # async {
/// async fn handler() -> &'static str {
///     "Hello, World!"
/// }
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and
/// [`HandlerService::into_make_service_with_connect_info`].
///
/// [`Router`]: crate::Router
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
/// [`MethodRouter`]: crate::routing::MethodRouter
/// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `S` is never stored; the marker only records it, and the `where` clause ties it to the
    // make-service's response type.
    let _marker = PhantomData;

    Serve {
        listener,
        make_service,
        _marker,
    }
}
104
/// Future returned by [`serve`].
///
/// Awaiting it (through its [`IntoFuture`] implementation) runs the server. Use
/// [`Serve::with_graceful_shutdown`] to attach a shutdown signal before awaiting.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    // Source of incoming connections.
    listener: L,
    // Called once per accepted connection to produce that connection's service.
    make_service: M,
    // Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
113
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Prepares a server to handle graceful shutdown when the provided future completes.
    ///
    /// Once `signal` resolves, the server stops accepting new connections, asks every open
    /// connection to shut down gracefully, and then resolves after all of them have closed.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            listener,
            make_service,
            signal,
            _marker,
        }
    }

    /// Returns the local address this server is bound to.
    ///
    /// This forwards to [`Listener::local_addr`] on the wrapped listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let listener = &self.listener;
        listener.local_addr()
    }
}
157
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    /// Formats the listener and the make-service; the `PhantomData` marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Serve` forces this impl to be revisited.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
178
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Runs the server through the graceful-shutdown path, using a shutdown signal that
    /// never completes, so no signal ever triggers shutdown.
    fn into_future(self) -> Self::IntoFuture {
        let never_fires = std::future::pending::<()>();
        IntoFuture::into_future(self.with_graceful_shutdown(never_fires))
    }
}
197
/// Serve future with graceful shutdown enabled.
///
/// Created by [`Serve::with_graceful_shutdown`]. Awaiting it runs the server until `signal`
/// completes, then stops accepting connections and waits for open connections to finish.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    // Source of incoming connections.
    listener: L,
    // Called once per accepted connection to produce that connection's service.
    make_service: M,
    // Future whose completion triggers graceful shutdown.
    signal: F,
    // Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
207
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Returns the local address this server is bound to.
    ///
    /// Delegates to [`Listener::local_addr`] on the wrapped listener, so it can be queried
    /// before the server is awaited.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        Listener::local_addr(&self.listener)
    }
}
218
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    /// Formats the listener, the make-service and the shutdown signal.
    ///
    /// The service type `S` needs no `Debug` bound: it only appears in the `PhantomData`
    /// marker, which is skipped here (matching the `Debug` impl for `Serve`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
242
/// Turns a [`WithGracefulShutdown`] into the future that actually runs the server.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the boxed server future.
    ///
    /// When polled, the returned future:
    /// 1. accepts connections until `signal` completes,
    /// 2. serves each accepted connection on its own spawned task (with upgrade support),
    /// 3. after `signal` completes, stops accepting, asks every live connection to shut down
    ///    gracefully, and resolves once all connection tasks have finished.
    ///
    /// Nothing is spawned until the returned future is polled (every `tokio::spawn` lives
    /// inside the `async move` block), so calling this outside a Tokio runtime is fine.
    ///
    /// NOTE(review): the body only ever returns `Ok(())`; `Listener::accept` yields
    /// `(Io, Addr)` rather than a `Result`, so no I/O error surfaces here.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        private::ServeFuture(Box::pin(async move {
            // Shutdown is broadcast by dropping the *receiver*: once `signal` completes, the
            // task below drops `signal_rx`, which makes every `signal_tx.closed()` resolve.
            let (signal_tx, signal_rx) = watch::channel(());
            // Shared (via `Arc::clone`) with every connection task so each can watch for it.
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            // Used purely as a liveness counter: each connection task holds a clone of
            // `close_rx`, and `close_tx.closed()` resolves once all of them are dropped.
            let (close_tx, close_rx) = watch::channel(());

            loop {
                // Race accepting the next connection against the shutdown signal.
                let (io, remote_addr) = tokio::select! {
                    conn = listener.accept() => conn,
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };

                let io = TokioIo::new(io);

                trace!("connection {remote_addr:?} accepted");

                // Wait until the make-service is ready. Its error type is `Infallible`, so
                // the empty `match` proves this cannot fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                // Build this connection's service, adapting hyper's `Incoming` request body
                // into axum's `Body`.
                let tower_service = make_service
                    .call(IncomingStream {
                        io: &io,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                let signal_tx = Arc::clone(&signal_tx);

                // Keeps the final `close_tx.closed().await` pending while this connection lives.
                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    // `mut` is only needed when the `http2` feature configures the builder below.
                    #[allow(unused_mut)]
                    let mut builder = Builder::new(TokioExecutor::new());
                    // CONNECT protocol needed for HTTP/2 websockets
                    #[cfg(feature = "http2")]
                    builder.http2().enable_connect_protocol();
                    let conn = builder.serve_connection_with_upgrades(io, hyper_service);
                    pin_mut!(conn);

                    // Fused so it can safely be polled again on later loop iterations after it
                    // has already completed.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    // Keep driving the connection after requesting graceful shutdown; only the
                    // connection itself finishing ends this loop.
                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                // `_err` is only read by `trace!` (presumably the underscore
                                // avoids an unused warning when tracing is compiled out — confirm).
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    // This connection is done; release its hold on `close_tx`.
                    drop(close_rx);
                });
            }

            // Drop our own receiver so only connection tasks keep `close_tx` open, and stop
            // listening for new connections.
            drop(close_rx);
            drop(listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
352
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// A value is handed to the make-service once per accepted connection. It borrows that
/// connection's IO and owns the remote address reported by the listener.
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
// NOTE(review): `#[derive(Debug)]` adds an `L: Debug` bound even though no `L` value is
// stored; confirm whether a manual impl bounded on the field types is preferable.
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
    L: Listener,
{
    // Borrowed IO of the accepted connection, wrapped in hyper-util's `TokioIo` adapter.
    io: &'a TokioIo<L::Io>,
    // Peer address returned by `Listener::accept` for this connection.
    remote_addr: L::Addr,
}
366
367impl<L> IncomingStream<'_, L>
368where
369    L: Listener,
370{
371    /// Get a reference to the inner IO type.
372    pub fn io(&self) -> &L::Io {
373        self.io.inner()
374    }
375
376    /// Returns the remote address that this stream is bound to.
377    pub fn remote_addr(&self) -> &L::Addr {
378        &self.remote_addr
379    }
380}
381
382mod private {
383    use std::{
384        future::Future,
385        io,
386        pin::Pin,
387        task::{Context, Poll},
388    };
389
390    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
391
392    impl Future for ServeFuture {
393        type Output = io::Result<()>;
394
395        #[inline]
396        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
397            self.0.as_mut().poll(cx)
398        }
399    }
400
401    impl std::fmt::Debug for ServeFuture {
402        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403            f.debug_struct("ServeFuture").finish_non_exhaustive()
404        }
405    }
406}
407
// Tests for `serve`, its futures, and the `Listener` integration.
#[cfg(test)]
mod tests {
    use std::{
        future::{pending, IntoFuture as _},
        net::{IpAddr, Ipv4Addr},
    };

    use axum_core::{body::Body, extract::Request};
    use http::StatusCode;
    use hyper_util::rt::TokioIo;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::{
        io::{self, AsyncRead, AsyncWrite},
        net::TcpListener,
    };

    #[cfg(unix)]
    use super::IncomingStream;
    use super::{serve, Listener};
    #[cfg(unix)]
    use crate::extract::connect_info::Connected;
    use crate::{
        body::to_bytes,
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        serve::ListenerExt,
        Router,
    };

    // Compile-only check. It has no test attribute and is never called, so nothing here binds
    // or serves at runtime. It exists so the compiler type-checks `serve` against every
    // combination of listener (TCP, TCP with `tap_io`, Unix) and service shape below.
    // `unused_must_use` is allowed because most `Serve` values are deliberately not awaited.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Minimal connect-info type, used to check Unix listeners with connect-info services.
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;

        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }

        let router: Router = Router::new();

        let addr = "0.0.0.0:0";

        // A TCP listener whose accepted streams get TCP_NODELAY set via `tap_io`.
        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // router
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // method router
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );

        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // handler
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());

        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }

    // Trivial handler shared by the compile-only checks above.
    async fn handler() {}

    // Binding to port 0 lets the OS pick a port; `local_addr` must report the real one.
    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Same as above, but through the graceful-shutdown wrapper's `local_addr`.
    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
            .with_graceful_shutdown(pending());
        let address = server.local_addr().unwrap();

        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Regression test: building the serve future must not require an ambient Tokio runtime.
    // Only the listener binding runs on a runtime; `into_future` is called outside it.
    #[test]
    fn into_future_outside_tokio() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();

        let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();

        // Call Serve::into_future outside of a tokio context. This used to panic.
        _ = serve(listener, router).into_future();
    }

    // End-to-end check that `serve` works with a custom `Listener` over an in-memory duplex
    // stream instead of a socket.
    #[crate::test]
    async fn serving_on_custom_io_type() {
        // Yields its single pre-made IO object once, then never accepts again.
        struct ReadyListener<T>(Option<T>);

        impl<T> Listener for ReadyListener<T>
        where
            T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        {
            type Io = T;
            type Addr = ();

            async fn accept(&mut self) -> (Self::Io, Self::Addr) {
                match self.0.take() {
                    Some(server) => (server, ()),
                    None => std::future::pending().await,
                }
            }

            fn local_addr(&self) -> io::Result<Self::Addr> {
                Ok(())
            }
        }

        let (client, server) = io::duplex(1024);
        let listener = ReadyListener(Some(server));

        let app = Router::new().route("/", get(|| async { "Hello, World!" }));

        tokio::spawn(serve(listener, app).into_future());

        // Drive the client half of the duplex with a raw hyper HTTP/1 client connection.
        let stream = TokioIo::new(client);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::spawn(conn);

        let request = Request::builder().body(Body::empty()).unwrap();

        let response = sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = Body::new(response.into_body());
        let body = to_bytes(body, usize::MAX).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }
}