axum/
extension.rs

1use crate::{extract::rejection::*, response::IntoResponseParts};
2use axum_core::{
3    extract::FromRequestParts,
4    response::{IntoResponse, Response, ResponseParts},
5};
6use http::{request::Parts, Request};
7use std::{
8    convert::Infallible,
9    task::{Context, Poll},
10};
11use tower_service::Service;
12
/// Extractor and response for extensions.
///
/// # As extractor
///
/// This is commonly used to share state across handlers. The extractor clones
/// the stored value out of the request's extensions, so state that is expensive
/// to clone is best wrapped in an [`Arc`](std::sync::Arc), as in the example
/// below.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into every incoming request's
///     // extensions.
///     .layer(Extension(state));
/// # let _: Router = app;
/// ```
///
/// If the extension is missing it will reject the request with a `500 Internal
/// Server Error` response.
///
/// # As response
///
/// Response extensions can be used to share state with middleware. Returning an
/// `Extension`, either on its own or as part of a tuple, inserts the wrapped
/// value into the response's extensions.
///
/// ```rust
/// use axum::Extension;
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);
71
72impl<T, S> FromRequestParts<S> for Extension<T>
73where
74    T: Clone + Send + Sync + 'static,
75    S: Send + Sync,
76{
77    type Rejection = ExtensionRejection;
78
79    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
80        let value = req
81            .extensions
82            .get::<T>()
83            .ok_or_else(|| {
84                MissingExtension::from_err(format!(
85                    "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
86                    std::any::type_name::<T>()
87                ))
88            }).cloned()?;
89
90        Ok(Extension(value))
91    }
92}
93
94axum_core::__impl_deref!(Extension);
95
96impl<T> IntoResponseParts for Extension<T>
97where
98    T: Clone + Send + Sync + 'static,
99{
100    type Error = Infallible;
101
102    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
103        res.extensions_mut().insert(self.0);
104        Ok(res)
105    }
106}
107
108impl<T> IntoResponse for Extension<T>
109where
110    T: Clone + Send + Sync + 'static,
111{
112    fn into_response(self) -> Response {
113        let mut res = ().into_response();
114        res.extensions_mut().insert(self.0);
115        res
116    }
117}
118
119impl<S, T> tower_layer::Layer<S> for Extension<T>
120where
121    T: Clone + Send + Sync + 'static,
122{
123    type Service = AddExtension<S, T>;
124
125    fn layer(&self, inner: S) -> Self::Service {
126        AddExtension {
127            inner,
128            value: self.0.clone(),
129        }
130    }
131}
132
/// Middleware that inserts a clone of a stored value into the [request extensions]
/// of each request before handing the request to the wrapped service.
///
/// See [Passing state from middleware to handlers](index.html#passing-state-from-middleware-to-handlers)
/// for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives each request.
    pub(crate) inner: S,
    /// The value cloned into each request's extensions.
    pub(crate) value: T,
}
144
145impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
146where
147    S: Service<Request<ResBody>>,
148    T: Clone + Send + Sync + 'static,
149{
150    type Response = S::Response;
151    type Error = S::Error;
152    type Future = S::Future;
153
154    #[inline]
155    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        self.inner.poll_ready(cx)
157    }
158
159    fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
160        req.extensions_mut().insert(self.value.clone());
161        self.inner.call(req)
162    }
163}