// axum/middleware/from_extractor.rs

1use crate::{
2    extract::FromRequestParts,
3    response::{IntoResponse, Response},
4};
5use futures_util::{future::BoxFuture, ready};
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9    fmt,
10    future::Future,
11    marker::PhantomData,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18/// Create a middleware from an extractor.
19///
20/// If the extractor succeeds the value will be discarded and the inner service
21/// will be called. If the extractor fails the rejection will be returned and
22/// the inner service will _not_ be called.
23///
24/// This can be used to perform validation of requests if the validation doesn't
25/// produce any useful output, and run the extractor for several handlers
26/// without repeating it in the function signature.
27///
28/// Note that if the extractor consumes the request body, as `String` or
29/// [`Bytes`] does, an empty body will be left in its place. Thus won't be
30/// accessible to subsequent extractors or handlers.
31///
32/// # Example
33///
34/// ```rust
35/// use axum::{
36///     extract::FromRequestParts,
37///     middleware::from_extractor,
38///     routing::{get, post},
39///     Router,
40///     http::{header, StatusCode, request::Parts},
41/// };
42///
43/// // An extractor that performs authorization.
44/// struct RequireAuth;
45///
46/// impl<S> FromRequestParts<S> for RequireAuth
47/// where
48///     S: Send + Sync,
49/// {
50///     type Rejection = StatusCode;
51///
52///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
53///         let auth_header = parts
54///             .headers
55///             .get(header::AUTHORIZATION)
56///             .and_then(|value| value.to_str().ok());
57///
58///         match auth_header {
59///             Some(auth_header) if token_is_valid(auth_header) => {
60///                 Ok(Self)
61///             }
62///             _ => Err(StatusCode::UNAUTHORIZED),
63///         }
64///     }
65/// }
66///
67/// fn token_is_valid(token: &str) -> bool {
68///     // ...
69///     # false
70/// }
71///
72/// async fn handler() {
73///     // If we get here the request has been authorized
74/// }
75///
76/// async fn other_handler() {
77///     // If we get here the request has been authorized
78/// }
79///
80/// let app = Router::new()
81///     .route("/", get(handler))
82///     .route("/foo", post(other_handler))
83///     // The extractor will run before all routes
84///     .route_layer(from_extractor::<RequireAuth>());
85/// # let _: Router = app;
86/// ```
87///
88/// [`Bytes`]: bytes::Bytes
89pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
90    from_extractor_with_state(())
91}
92
93/// Create a middleware from an extractor with the given state.
94///
95/// See [`State`](crate::extract::State) for more details about accessing state.
96pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
97    FromExtractorLayer {
98        state,
99        _marker: PhantomData,
100    }
101}
102
/// A [`Layer`] that wraps services in [`FromExtractor`]. That middleware runs
/// the extractor `E` on every request and throws the extracted value away.
///
/// Built by [`from_extractor`] or [`from_extractor_with_state`]; see
/// [`from_extractor`] for more details.
#[must_use]
pub struct FromExtractorLayer<E, S> {
    /// State handed to the extractor; cloned into every wrapped service.
    state: S,
    // `fn() -> E` instead of `E`: the layer never owns an `E`, and this keeps
    // the layer `Send`/`Sync` whatever `E` is.
    _marker: PhantomData<fn() -> E>,
}
114
115impl<E, S> Clone for FromExtractorLayer<E, S>
116where
117    S: Clone,
118{
119    fn clone(&self) -> Self {
120        Self {
121            state: self.state.clone(),
122            _marker: PhantomData,
123        }
124    }
125}
126
127impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
128where
129    S: fmt::Debug,
130{
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("FromExtractorLayer")
133            .field("state", &self.state)
134            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
135            .finish()
136    }
137}
138
139impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
140where
141    S: Clone,
142{
143    type Service = FromExtractor<T, E, S>;
144
145    fn layer(&self, inner: T) -> Self::Service {
146        FromExtractor {
147            inner,
148            state: self.state.clone(),
149            _extractor: PhantomData,
150        }
151    }
152}
153
/// Middleware [`Service`] that runs the extractor `E` before calling the inner
/// service `T`, discarding whatever the extractor produced.
///
/// Built by [`FromExtractorLayer`]; see [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
    /// The wrapped service.
    inner: T,
    /// State passed to the extractor for every request.
    state: S,
    // `fn() -> E` so the middleware neither owns an `E` nor inherits its auto
    // traits; it stays `Send`/`Sync` even when `E` is not.
    _extractor: PhantomData<fn() -> E>,
}
162
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, assert_sync, NotSendSync};

    // Even with an extractor type that is neither `Send` nor `Sync`, the
    // middleware must be both, thanks to its `PhantomData<fn() -> E>` marker.
    type Middleware = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Middleware>();
    assert_sync::<Middleware>();
}
169
170impl<T, E, S> Clone for FromExtractor<T, E, S>
171where
172    T: Clone,
173    S: Clone,
174{
175    fn clone(&self) -> Self {
176        Self {
177            inner: self.inner.clone(),
178            state: self.state.clone(),
179            _extractor: PhantomData,
180        }
181    }
182}
183
184impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
185where
186    T: fmt::Debug,
187    S: fmt::Debug,
188{
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        f.debug_struct("FromExtractor")
191            .field("inner", &self.inner)
192            .field("state", &self.state)
193            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
194            .finish()
195    }
196}
197
198impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
199where
200    E: FromRequestParts<S> + 'static,
201    B: Send + 'static,
202    T: Service<Request<B>> + Clone,
203    T::Response: IntoResponse,
204    S: Clone + Send + Sync + 'static,
205{
206    type Response = Response;
207    type Error = T::Error;
208    type Future = ResponseFuture<B, T, E, S>;
209
210    #[inline]
211    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        self.inner.poll_ready(cx)
213    }
214
215    fn call(&mut self, req: Request<B>) -> Self::Future {
216        let state = self.state.clone();
217        let extract_future = Box::pin(async move {
218            let (mut parts, body) = req.into_parts();
219            let extracted = E::from_request_parts(&mut parts, &state).await;
220            let req = Request::from_parts(parts, body);
221            (req, extracted)
222        });
223
224        ResponseFuture {
225            state: State::Extracting {
226                future: extract_future,
227            },
228            svc: Some(self.inner.clone()),
229        }
230    }
231}
232
233pin_project! {
234    /// Response future for [`FromExtractor`].
235    #[allow(missing_debug_implementations)]
236    pub struct ResponseFuture<B, T, E, S>
237    where
238        E: FromRequestParts<S>,
239        T: Service<Request<B>>,
240    {
241        #[pin]
242        state: State<B, T, E, S>,
243        svc: Option<T>,
244    }
245}
246
247pin_project! {
248    #[project = StateProj]
249    enum State<B, T, E, S>
250    where
251        E: FromRequestParts<S>,
252        T: Service<Request<B>>,
253    {
254        Extracting {
255            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
256        },
257        Call { #[pin] future: T::Future },
258    }
259}
260
261impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
262where
263    E: FromRequestParts<S>,
264    T: Service<Request<B>>,
265    T::Response: IntoResponse,
266{
267    type Output = Result<Response, T::Error>;
268
269    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
270        loop {
271            let mut this = self.as_mut().project();
272
273            let new_state = match this.state.as_mut().project() {
274                StateProj::Extracting { future } => {
275                    let (req, extracted) = ready!(future.as_mut().poll(cx));
276
277                    match extracted {
278                        Ok(_) => {
279                            let mut svc = this.svc.take().expect("future polled after completion");
280                            let future = svc.call(req);
281                            State::Call { future }
282                        }
283                        Err(err) => {
284                            let res = err.into_response();
285                            return Poll::Ready(Ok(res));
286                        }
287                    }
288                }
289                StateProj::Call { future } => {
290                    return future
291                        .poll(cx)
292                        .map(|result| result.map(IntoResponse::into_response));
293                }
294            };
295
296            this.state.set(new_state);
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        // Succeeds only when the `Authorization` header equals the secret
        // obtained from the state.
        struct RequireAuth;

        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(secret) = Secret::from_ref(state);
                parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .filter(|auth| *auth == secret)
                    .map(|_| Self)
                    .ok_or(StatusCode::UNAUTHORIZED)
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));
        let client = TestClient::new(app);

        // No header: the extractor rejects the request.
        let rejected = client.get("/").await;
        assert_eq!(rejected.status(), StatusCode::UNAUTHORIZED);

        // Matching header: the request reaches the handler.
        let accepted = client
            .get("/")
            .header(http::header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(accepted.status(), StatusCode::OK);
    }

    // Compile-only check that `from_extractor` composes with
    // `RequestBodyLimitLayer` in a `Router`. It is never called, so
    // `dead_code` is allowed.
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}