axum_core/extract/
request_parts.rs1use super::{rejection::*, FromRequest, FromRequestParts, Request};
2use crate::{body::Body, RequestExt};
3use bytes::{BufMut, Bytes, BytesMut};
4use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version};
5use http_body_util::BodyExt;
6use std::convert::Infallible;
7
8impl<S> FromRequest<S> for Request
9where
10 S: Send + Sync,
11{
12 type Rejection = Infallible;
13
14 async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
15 Ok(req)
16 }
17}
18
19impl<S> FromRequestParts<S> for Method
20where
21 S: Send + Sync,
22{
23 type Rejection = Infallible;
24
25 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
26 Ok(parts.method.clone())
27 }
28}
29
30impl<S> FromRequestParts<S> for Uri
31where
32 S: Send + Sync,
33{
34 type Rejection = Infallible;
35
36 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
37 Ok(parts.uri.clone())
38 }
39}
40
41impl<S> FromRequestParts<S> for Version
42where
43 S: Send + Sync,
44{
45 type Rejection = Infallible;
46
47 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
48 Ok(parts.version)
49 }
50}
51
52impl<S> FromRequestParts<S> for HeaderMap
58where
59 S: Send + Sync,
60{
61 type Rejection = Infallible;
62
63 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
64 Ok(parts.headers.clone())
65 }
66}
67
68impl<S> FromRequest<S> for BytesMut
69where
70 S: Send + Sync,
71{
72 type Rejection = BytesRejection;
73
74 async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
75 let mut body = req.into_limited_body();
76 let mut bytes = BytesMut::new();
77 body_to_bytes_mut(&mut body, &mut bytes).await?;
78 Ok(bytes)
79 }
80}
81
82async fn body_to_bytes_mut(body: &mut Body, bytes: &mut BytesMut) -> Result<(), BytesRejection> {
83 while let Some(frame) = body
84 .frame()
85 .await
86 .transpose()
87 .map_err(FailedToBufferBody::from_err)?
88 {
89 let Ok(data) = frame.into_data() else {
90 return Ok(());
91 };
92 bytes.put(data);
93 }
94
95 Ok(())
96}
97
98impl<S> FromRequest<S> for Bytes
99where
100 S: Send + Sync,
101{
102 type Rejection = BytesRejection;
103
104 async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
105 let bytes = req
106 .into_limited_body()
107 .collect()
108 .await
109 .map_err(FailedToBufferBody::from_err)?
110 .to_bytes();
111
112 Ok(bytes)
113 }
114}
115
116impl<S> FromRequest<S> for String
117where
118 S: Send + Sync,
119{
120 type Rejection = StringRejection;
121
122 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
123 let bytes = Bytes::from_request(req, state)
124 .await
125 .map_err(|err| match err {
126 BytesRejection::FailedToBufferBody(inner) => {
127 StringRejection::FailedToBufferBody(inner)
128 }
129 })?;
130
131 let string = String::from_utf8(bytes.into()).map_err(InvalidUtf8::from_err)?;
132
133 Ok(string)
134 }
135}
136
137impl<S> FromRequestParts<S> for Parts
138where
139 S: Send + Sync,
140{
141 type Rejection = Infallible;
142
143 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
144 Ok(parts.clone())
145 }
146}
147
148impl<S> FromRequestParts<S> for Extensions
149where
150 S: Send + Sync,
151{
152 type Rejection = Infallible;
153
154 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
155 Ok(parts.extensions.clone())
156 }
157}
158
159impl<S> FromRequest<S> for Body
160where
161 S: Send + Sync,
162{
163 type Rejection = Infallible;
164
165 async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
166 Ok(req.into_body())
167 }
168}
169
#[cfg(test)]
mod tests {
    use axum::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    /// A handler taking `http::request::Parts` should see the request's method,
    /// URI, version, headers and any extension inserted by a layer.
    #[crate::test]
    async fn extract_request_parts() {
        #[derive(Clone)]
        struct Marker;

        async fn inspect(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            assert!(parts.extensions.get::<Marker>().is_some());
        }

        let app = Router::new()
            .route("/", get(inspect))
            .layer(Extension(Marker));
        let client = TestClient::new(app);

        let response = client.get("/").header("x-foo", "123").await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}