axum/middleware/
from_fn.rs

1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5    any::type_name,
6    convert::Infallible,
7    fmt,
8    future::Future,
9    marker::PhantomData,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tower::util::BoxCloneSyncService;
14use tower::ServiceBuilder;
15use tower_layer::Layer;
16use tower_service::Service;
17
18/// Create a middleware from an async function.
19///
20/// `from_fn` requires the function given to
21///
22/// 1. Be an `async fn`.
23/// 2. Take zero or more [`FromRequestParts`] extractors.
24/// 3. Take exactly one [`FromRequest`] extractor as the second to last argument.
25/// 4. Take [`Next`](Next) as the last argument.
26/// 5. Return something that implements [`IntoResponse`].
27///
28/// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
29///
30/// # Example
31///
32/// ```rust
33/// use axum::{
34///     Router,
35///     http,
36///     routing::get,
37///     response::Response,
38///     middleware::{self, Next},
39///     extract::Request,
40/// };
41///
42/// async fn my_middleware(
43///     request: Request,
44///     next: Next,
45/// ) -> Response {
46///     // do something with `request`...
47///
48///     let response = next.run(request).await;
49///
50///     // do something with `response`...
51///
52///     response
53/// }
54///
55/// let app = Router::new()
56///     .route("/", get(|| async { /* ... */ }))
57///     .layer(middleware::from_fn(my_middleware));
58/// # let app: Router = app;
59/// ```
60///
61/// # Running extractors
62///
63/// ```rust
64/// use axum::{
65///     Router,
66///     extract::Request,
67///     http::{StatusCode, HeaderMap},
68///     middleware::{self, Next},
69///     response::Response,
70///     routing::get,
71/// };
72///
73/// async fn auth(
74///     // run the `HeaderMap` extractor
75///     headers: HeaderMap,
76///     // you can also add more extractors here but the last
77///     // extractor must implement `FromRequest` which
78///     // `Request` does
79///     request: Request,
80///     next: Next,
81/// ) -> Result<Response, StatusCode> {
82///     match get_token(&headers) {
83///         Some(token) if token_is_valid(token) => {
84///             let response = next.run(request).await;
85///             Ok(response)
86///         }
87///         _ => {
88///             Err(StatusCode::UNAUTHORIZED)
89///         }
90///     }
91/// }
92///
93/// fn get_token(headers: &HeaderMap) -> Option<&str> {
94///     // ...
95///     # None
96/// }
97///
98/// fn token_is_valid(token: &str) -> bool {
99///     // ...
100///     # false
101/// }
102///
103/// let app = Router::new()
104///     .route("/", get(|| async { /* ... */ }))
105///     .route_layer(middleware::from_fn(auth));
106/// # let app: Router = app;
107/// ```
108///
109/// [extractors]: crate::extract::FromRequest
110/// [`State`]: crate::extract::State
111pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
112    from_fn_with_state((), f)
113}
114
115/// Create a middleware from an async function with the given state.
116///
117/// For the requirements for the function supplied see [`from_fn`].
118///
119/// See [`State`](crate::extract::State) for more details about accessing state.
120///
121/// # Example
122///
123/// ```rust
124/// use axum::{
125///     Router,
126///     http::StatusCode,
127///     routing::get,
128///     response::{IntoResponse, Response},
129///     middleware::{self, Next},
130///     extract::{Request, State},
131/// };
132///
133/// #[derive(Clone)]
134/// struct AppState { /* ... */ }
135///
136/// async fn my_middleware(
137///     State(state): State<AppState>,
138///     // you can add more extractors here but the last
139///     // extractor must implement `FromRequest` which
140///     // `Request` does
141///     request: Request,
142///     next: Next,
143/// ) -> Response {
144///     // do something with `request`...
145///
146///     let response = next.run(request).await;
147///
148///     // do something with `response`...
149///
150///     response
151/// }
152///
153/// let state = AppState { /* ... */ };
154///
155/// let app = Router::new()
156///     .route("/", get(|| async { /* ... */ }))
157///     .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
158///     .with_state(state);
159/// # let _: axum::Router = app;
160/// ```
161pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
162    FromFnLayer {
163        f,
164        state,
165        _extractor: PhantomData,
166    }
167}
168
/// A [`tower::Layer`] built from an async function.
///
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)s.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]; see those functions for details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    f: F,
    state: S,
    _extractor: PhantomData<fn() -> T>,
}

impl<F, S, T> Clone for FromFnLayer<F, S, T>
where
    F: Clone,
    S: Clone,
{
    /// Clones the function and then the state; `T` needs no `Clone` bound.
    fn clone(&self) -> Self {
        let Self { f, state, .. } = self;
        Self {
            f: f.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
194
195impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
196where
197    F: Clone,
198    S: Clone,
199{
200    type Service = FromFn<F, S, I, T>;
201
202    fn layer(&self, inner: I) -> Self::Service {
203        FromFn {
204            f: self.f.clone(),
205            state: self.state.clone(),
206            inner,
207            _extractor: PhantomData,
208        }
209    }
210}
211
212impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
213where
214    S: fmt::Debug,
215{
216    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217        f.debug_struct("FromFnLayer")
218            // Write out the type name, without quoting it as `&type_name::<F>()` would
219            .field("f", &format_args!("{}", type_name::<F>()))
220            .field("state", &self.state)
221            .finish()
222    }
223}
224
/// A middleware service created from an async function.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]; see those functions for details.
pub struct FromFn<F, S, I, T> {
    f: F,
    inner: I,
    state: S,
    _extractor: PhantomData<fn() -> T>,
}

impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where
    F: Clone,
    I: Clone,
    S: Clone,
{
    /// Clones the function, the wrapped service, and the state, in that order.
    fn clone(&self) -> Self {
        let Self { f, inner, state, .. } = self;
        Self {
            f: f.clone(),
            inner: inner.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
250
251macro_rules! impl_service {
252    (
253        [$($ty:ident),*], $last:ident
254    ) => {
255        #[allow(non_snake_case, unused_mut)]
256        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
257        where
258            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
259            $( $ty: FromRequestParts<S> + Send, )*
260            $last: FromRequest<S> + Send,
261            Fut: Future<Output = Out> + Send + 'static,
262            Out: IntoResponse + 'static,
263            I: Service<Request, Error = Infallible>
264                + Clone
265                + Send
266                + Sync
267                + 'static,
268            I::Response: IntoResponse,
269            I::Future: Send + 'static,
270            S: Clone + Send + Sync + 'static,
271        {
272            type Response = Response;
273            type Error = Infallible;
274            type Future = ResponseFuture;
275
276            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
277                self.inner.poll_ready(cx)
278            }
279
280            fn call(&mut self, req: Request) -> Self::Future {
281                let not_ready_inner = self.inner.clone();
282                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
283
284                let mut f = self.f.clone();
285                let state = self.state.clone();
286
287                let future = Box::pin(async move {
288                    let (mut parts, body) = req.into_parts();
289
290                    $(
291                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
292                            Ok(value) => value,
293                            Err(rejection) => return rejection.into_response(),
294                        };
295                    )*
296
297                    let req = Request::from_parts(parts, body);
298
299                    let $last = match $last::from_request(req, &state).await {
300                        Ok(value) => value,
301                        Err(rejection) => return rejection.into_response(),
302                    };
303
304                    let inner = ServiceBuilder::new()
305                        .layer_fn(BoxCloneSyncService::new)
306                        .map_response(IntoResponse::into_response)
307                        .service(ready_inner);
308                    let next = Next { inner };
309
310                    f($($ty,)* $last, next).await.into_response()
311                });
312
313                ResponseFuture {
314                    inner: future
315                }
316            }
317        }
318    };
319}
320
321all_the_tuples!(impl_service);
322
323impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
324where
325    S: fmt::Debug,
326    I: fmt::Debug,
327{
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        f.debug_struct("FromFnLayer")
330            .field("f", &format_args!("{}", type_name::<F>()))
331            .field("inner", &self.inner)
332            .field("state", &self.state)
333            .finish()
334    }
335}
336
337/// The remainder of a middleware stack, including the handler.
338#[derive(Debug, Clone)]
339pub struct Next {
340    inner: BoxCloneSyncService<Request, Response, Infallible>,
341}
342
343impl Next {
344    /// Execute the remaining middleware stack.
345    pub async fn run(mut self, req: Request) -> Response {
346        match self.inner.call(req).await {
347            Ok(res) => res,
348            Err(err) => match err {},
349        }
350    }
351}
352
353impl Service<Request> for Next {
354    type Response = Response;
355    type Error = Infallible;
356    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
357
358    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359        self.inner.poll_ready(cx)
360    }
361
362    fn call(&mut self, req: Request) -> Self::Future {
363        self.inner.call(req)
364    }
365}
366
367/// Response future for [`FromFn`].
368pub struct ResponseFuture {
369    inner: BoxFuture<'static, Response>,
370}
371
372impl Future for ResponseFuture {
373    type Output = Result<Response, Infallible>;
374
375    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
376        self.inner.as_mut().poll(cx).map(Ok)
377    }
378}
379
380impl fmt::Debug for ResponseFuture {
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        f.debug_struct("ResponseFuture").finish()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A `from_fn` middleware can modify the request before it reaches the handler.
    #[crate::test]
    async fn basic() {
        // Middleware: tag every request with an `x-axum-test` header.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            req.headers_mut()
                .insert("x-axum-test", "ok".parse().unwrap());
            next.run(req).await
        }

        // Handler: echo back the header value the middleware inserted.
        async fn handle(headers: HeaderMap) -> String {
            let value = headers["x-axum-test"].to_str().unwrap();
            value.to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let res = app.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let body = res.collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }
}