1use crate::{
2 body::{Body, HttpBody},
3 response::Response,
4};
5use axum_core::{extract::Request, response::IntoResponse};
6use bytes::Bytes;
7use http::{
8 header::{self, CONTENT_LENGTH},
9 HeaderMap, HeaderValue, Method,
10};
11use pin_project_lite::pin_project;
12use std::{
13 convert::Infallible,
14 fmt,
15 future::Future,
16 pin::Pin,
17 task::{ready, Context, Poll},
18};
19use tower::{
20 util::{BoxCloneSyncService, MapErrLayer, MapResponseLayer, Oneshot},
21 ServiceExt,
22};
23use tower_layer::Layer;
24use tower_service::Service;
25
/// A type-erased, cloneable request handler.
///
/// Wraps a [`BoxCloneSyncService`] that turns a [`Request`] into a [`Response`],
/// failing with `E` (which defaults to [`Infallible`], i.e. no failure is possible).
pub struct Route<E = Infallible>(BoxCloneSyncService<Request, Response, E>);
31
32impl<E> Route<E> {
33 pub(crate) fn new<T>(svc: T) -> Self
34 where
35 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
36 T::Response: IntoResponse + 'static,
37 T::Future: Send + 'static,
38 {
39 Self(BoxCloneSyncService::new(
40 svc.map_response(IntoResponse::into_response),
41 ))
42 }
43
44 pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
46 let req = req.map(Body::new);
47 self.oneshot_inner_owned(req).not_top_level()
48 }
49
50 pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
51 let method = req.method().clone();
52 RouteFuture::new(method, self.0.clone().oneshot(req))
53 }
54
55 pub(crate) fn oneshot_inner_owned(self, req: Request) -> RouteFuture<E> {
57 let method = req.method().clone();
58 RouteFuture::new(method, self.0.oneshot(req))
59 }
60
61 pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
62 where
63 L: Layer<Route<E>> + Clone + Send + 'static,
64 L::Service: Service<Request> + Clone + Send + Sync + 'static,
65 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
66 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
67 <L::Service as Service<Request>>::Future: Send + 'static,
68 NewError: 'static,
69 {
70 let layer = (
71 MapErrLayer::new(Into::into),
72 MapResponseLayer::new(IntoResponse::into_response),
73 layer,
74 );
75
76 Route::new(layer.layer(self))
77 }
78}
79
80impl<E> Clone for Route<E> {
81 #[track_caller]
82 fn clone(&self) -> Self {
83 Self(self.0.clone())
84 }
85}
86
87impl<E> fmt::Debug for Route<E> {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("Route").finish()
90 }
91}
92
93impl<B, E> Service<Request<B>> for Route<E>
94where
95 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
96 B::Error: Into<axum_core::BoxError>,
97{
98 type Response = Response;
99 type Error = E;
100 type Future = RouteFuture<E>;
101
102 #[inline]
103 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104 Poll::Ready(Ok(()))
105 }
106
107 #[inline]
108 fn call(&mut self, req: Request<B>) -> Self::Future {
109 self.oneshot_inner(req.map(Body::new)).not_top_level()
110 }
111}
112
pin_project! {
    /// Response future for [`Route`].
    ///
    /// Drives the boxed service call and post-processes the response once it
    /// resolves (see the `Future` impl for the exact rules applied).
    pub struct RouteFuture<E> {
        // The in-flight call to the boxed route service.
        #[pin]
        inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
        // Method of the request, captured before the request was moved into `inner`.
        method: Method,
        // Value for the `Allow` header; only inserted if the response lacks one.
        allow_header: Option<Bytes>,
        // When false, the top-level header/HEAD post-processing is skipped.
        top_level: bool,
    }
}
123
124impl<E> RouteFuture<E> {
125 fn new(
126 method: Method,
127 inner: Oneshot<BoxCloneSyncService<Request, Response, E>, Request>,
128 ) -> Self {
129 Self {
130 inner,
131 method,
132 allow_header: None,
133 top_level: true,
134 }
135 }
136
137 pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
138 self.allow_header = Some(allow_header);
139 self
140 }
141
142 pub(crate) fn not_top_level(mut self) -> Self {
143 self.top_level = false;
144 self
145 }
146}
147
148impl<E> Future for RouteFuture<E> {
149 type Output = Result<Response, E>;
150
151 #[inline]
152 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
153 let this = self.project();
154 let mut res = ready!(this.inner.poll(cx))?;
155
156 if *this.method == Method::CONNECT && res.status().is_success() {
157 if res.headers().contains_key(&CONTENT_LENGTH)
162 || res.headers().contains_key(&header::TRANSFER_ENCODING)
163 || res.size_hint().lower() != 0
164 {
165 error!("response to CONNECT with nonempty body");
166 res = res.map(|_| Body::empty());
167 }
168 } else if *this.top_level {
169 set_allow_header(res.headers_mut(), this.allow_header);
170
171 set_content_length(res.size_hint(), res.headers_mut());
173
174 if *this.method == Method::HEAD {
175 *res.body_mut() = Body::empty();
176 }
177 }
178
179 Poll::Ready(Ok(res))
180 }
181}
182
183fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
184 match allow_header.take() {
185 Some(allow_header) if !headers.contains_key(header::ALLOW) => {
186 headers.insert(
187 header::ALLOW,
188 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
189 );
190 }
191 _ => {}
192 }
193}
194
195fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
196 if headers.contains_key(CONTENT_LENGTH) {
197 return;
198 }
199
200 if let Some(size) = size_hint.exact() {
201 let header_value = if size == 0 {
202 #[allow(clippy::declare_interior_mutable_const)]
203 const ZERO: HeaderValue = HeaderValue::from_static("0");
204
205 ZERO
206 } else {
207 let mut buffer = itoa::Buffer::new();
208 HeaderValue::from_str(buffer.format(size)).unwrap()
209 };
210
211 headers.insert(CONTENT_LENGTH, header_value);
212 }
213}
214
pin_project! {
    /// A [`RouteFuture`] whose error type is [`Infallible`], resolving directly
    /// to a [`Response`].
    pub struct InfallibleRouteFuture {
        // The wrapped future; it can never yield `Err`.
        #[pin]
        future: RouteFuture<Infallible>,
    }
}
222
223impl InfallibleRouteFuture {
224 pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
225 Self { future }
226 }
227}
228
229impl Future for InfallibleRouteFuture {
230 type Output = Response;
231
232 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 match futures_util::ready!(self.project().future.poll(cx)) {
234 Ok(response) => Poll::Ready(response),
235 Err(err) => match err {},
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks of the auto traits the routing types must implement.
    #[test]
    fn traits() {
        use crate::test_helpers::*;

        // `Route` wraps a `BoxCloneSyncService`, so it should be both `Send` and `Sync`.
        assert_send::<Route<()>>();
        assert_sync::<Route<()>>();

        // The futures handed out by routes must be `Send` so they can be driven
        // from multi-threaded executors.
        assert_send::<RouteFuture<()>>();
        assert_send::<InfallibleRouteFuture>();
    }
}