axum/extract/state.rs
1use axum_core::extract::{FromRef, FromRequestParts};
2use http::request::Parts;
3use std::{
4 convert::Infallible,
5 ops::{Deref, DerefMut},
6};
7
8/// Extractor for state.
9///
10/// See ["Accessing state in middleware"][state-from-middleware] for how to
11/// access state in middleware.
12///
/// State is global and available to every request that a router with state receives.
14/// For accessing data derived from requests, such as authorization data, see [`Extension`].
15///
16/// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
17/// [`Extension`]: crate::Extension
18///
19/// # With `Router`
20///
21/// ```
22/// use axum::{Router, routing::get, extract::State};
23///
24/// // the application state
25/// //
26/// // here you can put configuration, database connection pools, or whatever
27/// // state you need
28/// #[derive(Clone)]
29/// struct AppState {}
30///
31/// let state = AppState {};
32///
33/// // create a `Router` that holds our state
34/// let app = Router::new()
35/// .route("/", get(handler))
36/// // provide the state so the router can access it
37/// .with_state(state);
38///
39/// async fn handler(
40/// // access the state via the `State` extractor
41/// // extracting a state of the wrong type results in a compile error
42/// State(state): State<AppState>,
43/// ) {
44/// // use `state`...
45/// }
46/// # let _: axum::Router = app;
47/// ```
48///
49/// Note that `State` is an extractor, so be sure to put it before any body
50/// extractors, see ["the order of extractors"][order-of-extractors].
51///
52/// [order-of-extractors]: crate::extract#the-order-of-extractors
53///
54/// ## Combining stateful routers
55///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
57/// When combining [`Router`]s with one of these methods, the [`Router`]s must have
58/// the same state type. Generally, this can be inferred automatically:
59///
60/// ```
61/// use axum::{Router, routing::get, extract::State};
62///
63/// #[derive(Clone)]
64/// struct AppState {}
65///
66/// let state = AppState {};
67///
68/// // create a `Router` that will be nested within another
69/// let api = Router::new()
70/// .route("/posts", get(posts_handler));
71///
72/// let app = Router::new()
73/// .nest("/api", api)
74/// .with_state(state);
75///
76/// async fn posts_handler(State(state): State<AppState>) {
77/// // use `state`...
78/// }
79/// # let _: axum::Router = app;
80/// ```
81///
82/// However, if you are composing [`Router`]s that are defined in separate scopes,
83/// you may need to annotate the [`State`] type explicitly:
84///
85/// ```
86/// use axum::{Router, routing::get, extract::State};
87///
88/// #[derive(Clone)]
89/// struct AppState {}
90///
91/// fn make_app() -> Router {
92/// let state = AppState {};
93///
94/// Router::new()
95/// .nest("/api", make_api())
96/// .with_state(state) // the outer Router's state is inferred
97/// }
98///
99/// // the inner Router must specify its state type to compose with the
100/// // outer router
101/// fn make_api() -> Router<AppState> {
102/// Router::new()
103/// .route("/posts", get(posts_handler))
104/// }
105///
106/// async fn posts_handler(State(state): State<AppState>) {
107/// // use `state`...
108/// }
109/// # let _: axum::Router = make_app();
110/// ```
111///
112/// In short, a [`Router`]'s generic state type defaults to `()`
113/// (no state) unless [`Router::with_state`] is called or the value
114/// of the generic type is given explicitly.
115///
116/// [`Router`]: crate::Router
117/// [`Router::merge`]: crate::Router::merge
118/// [`Router::nest`]: crate::Router::nest
119/// [`Router::with_state`]: crate::Router::with_state
120///
121/// # With `MethodRouter`
122///
123/// ```
124/// use axum::{routing::get, extract::State};
125///
126/// #[derive(Clone)]
127/// struct AppState {}
128///
129/// let state = AppState {};
130///
131/// let method_router_with_state = get(handler)
132/// // provide the state so the handler can access it
133/// .with_state(state);
134/// # let _: axum::routing::MethodRouter = method_router_with_state;
135///
136/// async fn handler(State(state): State<AppState>) {
137/// // use `state`...
138/// }
139/// ```
140///
141/// # With `Handler`
142///
143/// ```
144/// use axum::{routing::get, handler::Handler, extract::State};
145///
146/// #[derive(Clone)]
147/// struct AppState {}
148///
149/// let state = AppState {};
150///
151/// async fn handler(State(state): State<AppState>) {
152/// // use `state`...
153/// }
154///
155/// // provide the state so the handler can access it
156/// let handler_with_state = handler.with_state(state);
157///
158/// # async {
159/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
160/// axum::serve(listener, handler_with_state.into_make_service()).await.unwrap();
161/// # };
162/// ```
163///
164/// # Substates
165///
/// [`State`] only allows a single state type, but you can use [`FromRef`] to extract "substates":
167///
168/// ```
169/// use axum::{Router, routing::get, extract::{State, FromRef}};
170///
171/// // the application state
172/// #[derive(Clone)]
173/// struct AppState {
174/// // that holds some api specific state
175/// api_state: ApiState,
176/// }
177///
178/// // the api specific state
179/// #[derive(Clone)]
180/// struct ApiState {}
181///
/// // support converting an `AppState` into an `ApiState`
183/// impl FromRef<AppState> for ApiState {
184/// fn from_ref(app_state: &AppState) -> ApiState {
185/// app_state.api_state.clone()
186/// }
187/// }
188///
189/// let state = AppState {
190/// api_state: ApiState {},
191/// };
192///
193/// let app = Router::new()
194/// .route("/", get(handler))
195/// .route("/api/users", get(api_users))
196/// .with_state(state);
197///
198/// async fn api_users(
199/// // access the api specific state
200/// State(api_state): State<ApiState>,
201/// ) {
202/// }
203///
204/// async fn handler(
///     // we can still access the top-level state
206/// State(state): State<AppState>,
207/// ) {
208/// }
209/// # let _: axum::Router = app;
210/// ```
211///
/// For convenience, `FromRef` can also be derived using `#[derive(FromRef)]`.
213///
214/// # For library authors
215///
216/// If you're writing a library that has an extractor that needs state, this is the recommended way
217/// to do it:
218///
219/// ```rust
220/// use axum_core::extract::{FromRequestParts, FromRef};
221/// use http::request::Parts;
222/// use std::convert::Infallible;
223///
224/// // the extractor your library provides
225/// struct MyLibraryExtractor;
226///
227/// impl<S> FromRequestParts<S> for MyLibraryExtractor
228/// where
229/// // keep `S` generic but require that it can produce a `MyLibraryState`
230/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
231/// MyLibraryState: FromRef<S>,
232/// S: Send + Sync,
233/// {
234/// type Rejection = Infallible;
235///
236/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
237/// // get a `MyLibraryState` from a reference to the state
238/// let state = MyLibraryState::from_ref(state);
239///
240/// // ...
241/// # todo!()
242/// }
243/// }
244///
245/// // the state your library needs
246/// struct MyLibraryState {
247/// // ...
248/// }
249/// ```
250///
251/// # Shared mutable state
252///
/// [As state is global within a `Router`][global], you can't directly get a mutable reference to
254/// the state.
255///
256/// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
257/// your use case. See [the tokio docs] for more details.
258///
259/// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
260/// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
261/// consider using a `tokio::sync::Mutex` instead.
262///
263/// ## Example
264///
265/// ```
266/// use axum::{Router, routing::get, extract::State};
267/// use std::sync::{Arc, Mutex};
268///
269/// #[derive(Clone)]
270/// struct AppState {
271/// data: Arc<Mutex<String>>,
272/// }
273///
274/// async fn handler(State(state): State<AppState>) {
275/// {
276/// let mut data = state.data.lock().expect("mutex was poisoned");
277/// *data = "updated foo".to_owned();
278/// }
279///
280/// // ...
281/// }
282///
283/// let state = AppState {
284/// data: Arc::new(Mutex::new("foo".to_owned())),
285/// };
286///
287/// let app = Router::new()
288/// .route("/", get(handler))
289/// .with_state(state);
290/// # let _: Router = app;
291/// ```
292///
293/// [global]: crate::Router::with_state
294/// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
// A transparent newtype: the single public field holds the extracted value,
// which is what lets handlers destructure it as `State(state)`. Each derived
// impl is only available when `S` itself implements the corresponding trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct State<S>(pub S);
297
298impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
299where
300 InnerState: FromRef<OuterState>,
301 OuterState: Send + Sync,
302{
303 type Rejection = Infallible;
304
305 async fn from_request_parts(
306 _parts: &mut Parts,
307 state: &OuterState,
308 ) -> Result<Self, Self::Rejection> {
309 let inner_state = InnerState::from_ref(state);
310 Ok(Self(inner_state))
311 }
312}
313
314impl<S> Deref for State<S> {
315 type Target = S;
316
317 fn deref(&self) -> &Self::Target {
318 &self.0
319 }
320}
321
322impl<S> DerefMut for State<S> {
323 fn deref_mut(&mut self) -> &mut Self::Target {
324 &mut self.0
325 }
326}