1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use axum_core::response::{IntoResponse, Response};
4use axum_core::RequestExt;
5use http::header::CONTENT_TYPE;
6use http::StatusCode;
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9
/// URL-encoded form extractor and response.
///
/// # As extractor
///
/// `Form<T>` obtains the raw form bytes via [`RawForm`] and deserializes them
/// into `T` (which must implement `serde::de::DeserializeOwned`) with
/// `serde_urlencoded`.
///
/// Extraction fails with a [`FormRejection`]:
///
/// - Rejections produced by [`RawForm`] (bytes / invalid content type) are
///   forwarded as the matching [`FormRejection`] variant.
/// - A deserialization failure on a `GET` or `HEAD` request becomes
///   `FailedToDeserializeForm`; the tests in this module observe
///   `400 Bad Request` for that case.
/// - A deserialization failure on any other method becomes
///   `FailedToDeserializeFormBody`; the tests in this module observe
///   `422 Unprocessable Entity` for that case.
///
/// # As response
///
/// `Form<T>` serializes `T` (which must implement `serde::Serialize`) with
/// `serde_urlencoded` and sets the `Content-Type` header to
/// `mime::APPLICATION_WWW_FORM_URLENCODED`. If serialization fails, the
/// response is `500 Internal Server Error` with the error message as body.
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Form<T>(pub T);
73
74impl<T, S> FromRequest<S> for Form<T>
75where
76 T: DeserializeOwned,
77 S: Send + Sync,
78{
79 type Rejection = FormRejection;
80
81 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
82 let is_get_or_head =
83 req.method() == http::Method::GET || req.method() == http::Method::HEAD;
84
85 match req.extract().await {
86 Ok(RawForm(bytes)) => {
87 let deserializer =
88 serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
89 let value = serde_path_to_error::deserialize(deserializer).map_err(
90 |err| -> FormRejection {
91 if is_get_or_head {
92 FailedToDeserializeForm::from_err(err).into()
93 } else {
94 FailedToDeserializeFormBody::from_err(err).into()
95 }
96 },
97 )?;
98 Ok(Form(value))
99 }
100 Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
101 Err(RawFormRejection::InvalidFormContentType(r)) => {
102 Err(FormRejection::InvalidFormContentType(r))
103 }
104 }
105 }
106}
107
108impl<T> IntoResponse for Form<T>
109where
110 T: Serialize,
111{
112 fn into_response(self) -> Response {
113 match serde_urlencoded::to_string(&self.0) {
114 Ok(body) => (
115 [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
116 body,
117 )
118 .into_response(),
119 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
120 }
121 }
122}
123
// Invokes axum-core's `__impl_deref!` helper macro for `Form`.
// NOTE(review): judging by its name it presumably generates `Deref`/`DerefMut`
// impls that expose the inner `T` — confirm against the macro definition in axum-core.
axum_core::__impl_deref!(Form);
125
#[cfg(test)]
mod tests {
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };

    use super::*;
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    /// Sample form payload with two optional numeric fields.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Builds a body-less request for `uri` (no method is set on the builder),
    /// extracts a `Form<T>` from it and asserts that the result equals `value`.
    async fn check_query<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let request = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();

        let Form(extracted) = Form::<T>::from_request(request, &()).await.unwrap();
        assert_eq!(extracted, value);
    }

    /// URL-encodes `value` into the body of a `POST` request carrying the form
    /// content type, extracts a `Form<T>` from it and asserts the round trip
    /// yields `value` again.
    async fn check_body<T>(value: T)
    where
        T: Serialize + DeserializeOwned + PartialEq + Debug,
    {
        let encoded = serde_urlencoded::to_string(&value).unwrap();
        let request = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(encoded))
            .unwrap();

        let Form(extracted) = Form::<T>::from_request(request, &()).await.unwrap();
        assert_eq!(extracted, value);
    }

    #[crate::test]
    async fn test_form_query() {
        // (request URI, expected extraction result)
        let cases = [
            (
                "http://example.com/test",
                Pagination {
                    size: None,
                    page: None,
                },
            ),
            (
                "http://example.com/test?size=10",
                Pagination {
                    size: Some(10),
                    page: None,
                },
            ),
            (
                "http://example.com/test?size=10&page=20",
                Pagination {
                    size: Some(10),
                    page: Some(20),
                },
            ),
        ];

        for (uri, expected) in cases {
            check_query(uri, expected).await;
        }
    }

    #[crate::test]
    async fn test_form_body() {
        // Each payload must survive a serialize -> request body -> extract round trip.
        let payloads = [
            Pagination {
                size: None,
                page: None,
            },
            Pagination {
                size: Some(10),
                page: None,
            },
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        ];

        for payload in payloads {
            check_body(payload).await;
        }
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        let encoded = serde_urlencoded::to_string(&Pagination {
            size: Some(10),
            page: None,
        })
        .unwrap();

        // The body is valid URL-encoded data, but it is labelled as JSON.
        let request = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(encoded))
            .unwrap();

        let rejection = Form::<Pagination>::from_request(request, &())
            .await
            .unwrap_err();
        assert!(matches!(
            rejection,
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    #[tokio::test]
    async fn deserialize_error_status_codes() {
        // The field is never read; it only has to fail deserialization.
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let methods = MethodFilter::GET.or(MethodFilter::POST);
        let app = Router::new().route("/", on(methods, |_: Form<Payload>| async {}));

        let client = TestClient::new(app);

        // GET: the invalid value arrives in the query string.
        let get_response = client.get("/?a=false").await;
        assert_eq!(get_response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            get_response.text().await,
            "Failed to deserialize form: a: invalid digit found in string"
        );

        // POST: the same invalid value arrives in the request body.
        let post_response = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(post_response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            post_response.text().await,
            "Failed to deserialize form body: a: invalid digit found in string"
        );
    }
}