axum/extract/
query.rs

1use super::{rejection::*, FromRequestParts};
2use http::{request::Parts, Uri};
3use serde::de::DeserializeOwned;
4
/// Extractor that turns the request's query string into a value of type `T`.
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// # Examples
///
/// ```rust,no_run
/// use axum::{extract::Query, routing::get, Router};
/// use serde::Deserialize;
///
/// // Parses query strings such as `?page=2&per_page=30`.
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: usize,
///     per_page: usize,
/// }
///
/// async fn list_things(Query(pagination): Query<Pagination>) {
///     let _ = (pagination.page, pagination.per_page);
///     // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # let _: Router = app;
/// ```
///
/// A query string that cannot be deserialized into `T` causes the request to be
/// rejected with a `400 Bad Request` response.
///
/// To tell empty values apart from missing ones, see the
/// [query-params-with-empty-strings][example] example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
///
/// When the same parameter may appear several times (e.g. `?foo=1&foo=2&foo=3`),
/// use [`axum_extra::extract::Query`] instead.
///
/// [`axum_extra::extract::Query`]: https://docs.rs/axum-extra/latest/axum_extra/extract/struct.Query.html
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
52
53impl<T, S> FromRequestParts<S> for Query<T>
54where
55    T: DeserializeOwned,
56    S: Send + Sync,
57{
58    type Rejection = QueryRejection;
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        Self::try_from_uri(&parts.uri)
62    }
63}
64
65impl<T> Query<T>
66where
67    T: DeserializeOwned,
68{
69    /// Attempts to construct a [`Query`] from a reference to a [`Uri`].
70    ///
71    /// # Example
72    /// ```
73    /// use axum::extract::Query;
74    /// use http::Uri;
75    /// use serde::Deserialize;
76    ///
77    /// #[derive(Deserialize)]
78    /// struct ExampleParams {
79    ///     foo: String,
80    ///     bar: u32,
81    /// }
82    ///
83    /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
84    /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap();
85    /// assert_eq!(result.foo, String::from("hello"));
86    /// assert_eq!(result.bar, 42);
87    /// ```
88    pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
89        let query = value.query().unwrap_or_default();
90        let deserializer =
91            serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
92        let params = serde_path_to_error::deserialize(deserializer)
93            .map_err(FailedToDeserializeQueryString::from_err)?;
94        Ok(Query(params))
95    }
96}
97
98axum_core::__impl_deref!(Query);
99
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::{body::Body, extract::FromRequest};
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Builds a request for `uri`, runs the `Query` extractor on it and asserts that
    /// the extracted value equals `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize query string: n: invalid digit found in string"
        );
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // Field names must match the query keys. Otherwise deserialization fails on a
        // missing field instead of on the unparsable `bar` value this test targets.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)] // fields only drive deserialization and are never read
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        // Pin *why* it failed: the error must point at `bar`'s bad digit.
        let rejection = result.unwrap_err();
        assert_eq!(rejection.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            rejection.body_text(),
            "Failed to deserialize query string: bar: invalid digit found in string"
        );
    }
}