1use super::{rejection::*, FromRequestParts};
2use http::{request::Parts, Uri};
3use serde::de::DeserializeOwned;
4
/// Extractor that deserializes the query string of the request URI into `T`.
///
/// `T` must implement `serde::de::DeserializeOwned`. A request whose URI has no
/// query component is deserialized from an empty string, so a `T` whose fields
/// are all `Option`s still extracts successfully.
///
/// If the query string cannot be deserialized into `T`, extraction fails with
/// [`QueryRejection`], which responds with `400 Bad Request`.
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
52
53impl<T, S> FromRequestParts<S> for Query<T>
54where
55 T: DeserializeOwned,
56 S: Send + Sync,
57{
58 type Rejection = QueryRejection;
59
60 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61 Self::try_from_uri(&parts.uri)
62 }
63}
64
65impl<T> Query<T>
66where
67 T: DeserializeOwned,
68{
69 pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
89 let query = value.query().unwrap_or_default();
90 let deserializer =
91 serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
92 let params = serde_path_to_error::deserialize(deserializer)
93 .map_err(FailedToDeserializeQueryString::from_err)?;
94 Ok(Query(params))
95 }
96}
97
// Lets a `Query<T>` be dereferenced to the inner `T`, so fields of `T` can be read
// directly through the wrapper (e.g. `result.foo` in the tests).
// NOTE(review): the macro lives in axum_core; check there whether it also provides `DerefMut`.
axum_core::__impl_deref!(Query);
99
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::{body::Body, extract::FromRequest};
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Builds a request for `uri`, runs the `Query<T>` extractor on it and
    /// asserts that the extracted value equals `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        // No query component at all: every optional field stays `None`.
        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            res.text().await,
            "Failed to deserialize query string: n: invalid digit found in string"
        );
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // The field names must match the query keys. Otherwise deserialization
        // fails with "missing field" and the test passes without ever
        // exercising the invalid `bar` value.
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        let Err(rejection) = result else {
            panic!("expected `bar=invalid` to be rejected");
        };
        assert_eq!(rejection.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            rejection.body_text(),
            "Failed to deserialize query string: bar: invalid digit found in string"
        );
    }
}