// axum/routing/route.rs

use crate::{
    body::{Body, HttpBody},
    response::Response,
};
use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
use http::{
    header::{self, CONTENT_LENGTH},
    HeaderMap, HeaderValue, Method,
};
use pin_project_lite::pin_project;
use std::{
    convert::Infallible,
    fmt,
    future::Future,
    pin::Pin,
    task::{ready, Context, Poll},
};
use tower::{
    util::{BoxCloneSyncService, MapErrLayer, MapResponseLayer, Oneshot},
    ServiceExt,
};
use tower_layer::Layer;
use tower_service::Service;

26/// How routes are stored inside a [`Router`](super::Router).
27///
28/// You normally shouldn't need to care about this type. It's used in
29/// [`Router::layer`](super::Router::layer).
30pub struct Route<E = Infallible>(BoxCloneSyncService<Request, Response, E>);
31
32impl<E> Route<E> {
33    pub(crate) fn new<T>(svc: T) -> Self
34    where
35        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
36        T::Response: IntoResponse + 'static,
37        T::Future: Send + 'static,
38    {
39        Self(BoxCloneSyncService::new(
40            svc.map_response(IntoResponse::into_response),
41        ))
42    }
43
44    /// Variant of [`Route::call`] that takes ownership of the route to avoid cloning.
45    pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
46        let req = req.map(Body::new);
47        self.oneshot_inner_owned(req).not_top_level()
48    }
49
50    pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
51        let method = req.method().clone();
52        RouteFuture::new(method, self.0.clone().oneshot(req))
53    }
54
55    /// Variant of [`Route::oneshot_inner`] that takes ownership of the route to avoid cloning.
56    pub(crate) fn oneshot_inner_owned(self, req: Request) -> RouteFuture<E> {
57        let method = req.method().clone();
58        RouteFuture::new(method, self.0.oneshot(req))
59    }
60
61    pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
62    where
63        L: Layer<Route<E>> + Clone + Send + 'static,
64        L::Service: Service<Request> + Clone + Send + Sync + 'static,
65        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
66        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
67        <L::Service as Service<Request>>::Future: Send + 'static,
68        NewError: 'static,
69    {
70        let layer = (
71            MapErrLayer::new(Into::into),
72            MapResponseLayer::new(IntoResponse::into_response),
73            layer,
74        );
75
76        Route::new(layer.layer(self))
77    }
78}
79
80impl<E> Clone for Route<E> {
81    #[track_caller]
82    fn clone(&self) -> Self {
83        Self(self.0.clone())
84    }
85}
86
87impl<E> fmt::Debug for Route<E> {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("Route").finish()
90    }
91}
92
93impl<B, E> Service<Request<B>> for Route<E>
94where
95    B: HttpBody<Data = bytes::Bytes> + Send + 'static,
96    B::Error: Into<axum_core::BoxError>,
97{
98    type Response = Response;
99    type Error = E;
100    type Future = RouteFuture<E>;
101
102    #[inline]
103    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        Poll::Ready(Ok(()))
105    }
106
107    #[inline]
108    fn call(&mut self, req: Request<B>) -> Self::Future {
109        self.oneshot_inner(req.map(Body::new)).not_top_level()
110    }
111}
112
113pin_project! {
114    /// Response future for [`Route`].
115    pub struct RouteFuture<E> {
116        #[pin]
117        inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
118        method: Method,
119        allow_header: Option<Bytes>,
120        top_level: bool,
121    }
122}
123
124impl<E> RouteFuture<E> {
125    fn new(
126        method: Method,
127        inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
128    ) -> Self {
129        Self {
130            inner,
131            method,
132            allow_header: None,
133            top_level: true,
134        }
135    }
136
137    pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
138        self.allow_header = Some(allow_header);
139        self
140    }
141
142    pub(crate) fn not_top_level(mut self) -> Self {
143        self.top_level = false;
144        self
145    }
146}
147
148impl<E> Future for RouteFuture<E> {
149    type Output = Result<Response, E>;
150
151    #[inline]
152    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153        let this = self.project();
154        let mut res = ready!(this.inner.poll(cx))?;
155
156        if *this.method == Method::CONNECT && res.status().is_success() {
157            // From https://httpwg.org/specs/rfc9110.html#CONNECT:
158            // > A server MUST NOT send any Transfer-Encoding or
159            // > Content-Length header fields in a 2xx (Successful)
160            // > response to CONNECT.
161            if res.headers().contains_key(&CONTENT_LENGTH)
162                || res.headers().contains_key(&header::TRANSFER_ENCODING)
163                || res.size_hint().lower() != 0
164            {
165                error!("response to CONNECT with nonempty body");
166                res = res.map(|_| Body::empty());
167            }
168        } else if *this.top_level {
169            set_allow_header(res.headers_mut(), this.allow_header);
170
171            // make sure to set content-length before removing the body
172            set_content_length(res.size_hint(), res.headers_mut());
173
174            if *this.method == Method::HEAD {
175                *res.body_mut() = Body::empty();
176            }
177        }
178
179        Poll::Ready(Ok(res))
180    }
181}
182
183fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
184    match allow_header.take() {
185        Some(allow_header) if !headers.contains_key(header::ALLOW) => {
186            headers.insert(
187                header::ALLOW,
188                HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
189            );
190        }
191        _ => {}
192    }
193}
194
195fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
196    if headers.contains_key(CONTENT_LENGTH) {
197        return;
198    }
199
200    if let Some(size) = size_hint.exact() {
201        let header_value = if size == 0 {
202            #[allow(clippy::declare_interior_mutable_const)]
203            const ZERO: HeaderValue = HeaderValue::from_static("0");
204
205            ZERO
206        } else {
207            let mut buffer = itoa::Buffer::new();
208            HeaderValue::from_str(buffer.format(size)).unwrap()
209        };
210
211        headers.insert(CONTENT_LENGTH, header_value);
212    }
213}
214
215pin_project! {
216    /// A [`RouteFuture`] that always yields a [`Response`].
217    pub struct InfallibleRouteFuture {
218        #[pin]
219        future: RouteFuture<Infallible>,
220    }
221}
222
223impl InfallibleRouteFuture {
224    pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
225        Self { future }
226    }
227}
228
229impl Future for InfallibleRouteFuture {
230    type Output = Response;
231
232    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233        match futures_util::ready!(self.project().future.poll(cx)) {
234            Ok(response) => Poll::Ready(response),
235            Err(err) => match err {},
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks of the auto traits that routing relies on.
    #[test]
    fn traits() {
        use crate::test_helpers::*;

        // Defined locally so this check does not depend on which helpers `test_helpers`
        // exports. A local item takes precedence over the glob import above.
        fn assert_sync<T: Sync>() {}

        assert_send::<Route<()>>();
        // `Route` boxes a `BoxCloneSyncService` specifically so it can be shared across
        // threads, so pin that property down as well.
        assert_sync::<Route<()>>();
        assert_send::<RouteFuture<()>>();
    }
}