1use crate::{
2 extract::FromRequestParts,
3 response::{IntoResponse, Response},
4};
5use futures_util::{future::BoxFuture, ready};
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9 fmt,
10 future::Future,
11 marker::PhantomData,
12 pin::Pin,
13 task::{Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
90 from_extractor_with_state(())
91}
92
93pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
97 FromExtractorLayer {
98 state,
99 _marker: PhantomData,
100 }
101}
102
/// [`Layer`] that wraps services in [`FromExtractor`].
///
/// Created with [`from_extractor`] or [`from_extractor_with_state`].
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // Only the extractor *type* is needed. `fn() -> E` is always `Send + Sync`, so `E`
    // does not influence the layer's auto traits.
    _marker: PhantomData<fn() -> E>,
    // State handed to the extractor; cloned into every wrapped service.
    state: S,
}
114
115impl<E, S> Clone for FromExtractorLayer<E, S>
116where
117 S: Clone,
118{
119 fn clone(&self) -> Self {
120 Self {
121 state: self.state.clone(),
122 _marker: PhantomData,
123 }
124 }
125}
126
127impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
128where
129 S: fmt::Debug,
130{
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("FromExtractorLayer")
133 .field("state", &self.state)
134 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
135 .finish()
136 }
137}
138
139impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
140where
141 S: Clone,
142{
143 type Service = FromExtractor<T, E, S>;
144
145 fn layer(&self, inner: T) -> Self::Service {
146 FromExtractor {
147 inner,
148 state: self.state.clone(),
149 _extractor: PhantomData,
150 }
151 }
152}
153
/// Middleware service that runs the extractor `E` before calling the inner service `T`.
///
/// Produced by [`FromExtractorLayer`], which is created with [`from_extractor`] or
/// [`from_extractor_with_state`].
pub struct FromExtractor<T, E, S> {
    // State passed to `E::from_request_parts`; cloned for every request.
    state: S,
    // The wrapped service.
    inner: T,
    // `fn() -> E` keeps `E` out of the service's `Send`/`Sync` computation.
    _extractor: PhantomData<fn() -> E>,
}
162
#[test]
fn traits() {
    use crate::test_helpers::*;

    // `E` only appears behind `PhantomData<fn() -> E>`, so an extractor type that is
    // neither `Send` nor `Sync` must not make the middleware service lose those traits.
    type Svc = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Svc>();
    assert_sync::<Svc>();
}
169
170impl<T, E, S> Clone for FromExtractor<T, E, S>
171where
172 T: Clone,
173 S: Clone,
174{
175 fn clone(&self) -> Self {
176 Self {
177 inner: self.inner.clone(),
178 state: self.state.clone(),
179 _extractor: PhantomData,
180 }
181 }
182}
183
184impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
185where
186 T: fmt::Debug,
187 S: fmt::Debug,
188{
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 f.debug_struct("FromExtractor")
191 .field("inner", &self.inner)
192 .field("state", &self.state)
193 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
194 .finish()
195 }
196}
197
198impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
199where
200 E: FromRequestParts<S> + 'static,
201 B: Send + 'static,
202 T: Service<Request<B>> + Clone,
203 T::Response: IntoResponse,
204 S: Clone + Send + Sync + 'static,
205{
206 type Response = Response;
207 type Error = T::Error;
208 type Future = ResponseFuture<B, T, E, S>;
209
210 #[inline]
211 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 self.inner.poll_ready(cx)
213 }
214
215 fn call(&mut self, req: Request<B>) -> Self::Future {
216 let state = self.state.clone();
217 let extract_future = Box::pin(async move {
218 let (mut parts, body) = req.into_parts();
219 let extracted = E::from_request_parts(&mut parts, &state).await;
220 let req = Request::from_parts(parts, body);
221 (req, extracted)
222 });
223
224 ResponseFuture {
225 state: State::Extracting {
226 future: extract_future,
227 },
228 svc: Some(self.inner.clone()),
229 }
230 }
231}
232
pin_project! {
    #[allow(missing_debug_implementations)]
    /// Response future for [`FromExtractor`].
    ///
    /// Drives the extractor first. If extraction succeeds, it calls the inner service and
    /// then drives that call's future, converting its output into a [`Response`].
    pub struct ResponseFuture<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Current phase: extracting, or awaiting the inner service's future.
        #[pin]
        state: State<B, T, E, S>,
        // Inner service; taken (leaving `None`) when moving to the `Call` phase.
        svc: Option<T>,
    }
}
246
pin_project! {
    // Phases of `ResponseFuture`.
    #[project = StateProj]
    enum State<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Running the extractor. The boxed future yields the reassembled request together
        // with the extraction result. It is not `#[pin]` because the box already pins the
        // future on the heap.
        Extracting {
            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
        },
        // Awaiting the inner service's response future.
        Call { #[pin] future: T::Future },
    }
}
260
261impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
262where
263 E: FromRequestParts<S>,
264 T: Service<Request<B>>,
265 T::Response: IntoResponse,
266{
267 type Output = Result<Response, T::Error>;
268
269 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
270 loop {
271 let mut this = self.as_mut().project();
272
273 let new_state = match this.state.as_mut().project() {
274 StateProj::Extracting { future } => {
275 let (req, extracted) = ready!(future.as_mut().poll(cx));
276
277 match extracted {
278 Ok(_) => {
279 let mut svc = this.svc.take().expect("future polled after completion");
280 let future = svc.call(req);
281 State::Call { future }
282 }
283 Err(err) => {
284 let res = err.into_response();
285 return Poll::Ready(Ok(res));
286 }
287 }
288 }
289 StateProj::Call { future } => {
290 return future
291 .poll(cx)
292 .map(|result| result.map(IntoResponse::into_response));
293 }
294 };
295
296 this.state.set(new_state);
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        // Succeeds only when the `Authorization` header equals the `Secret` in the state.
        struct RequireAuth;

        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(expected) = Secret::from_ref(state);
                let authorized = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .is_some_and(|provided| provided == expected);

                if authorized {
                    Ok(Self)
                } else {
                    Err(StatusCode::UNAUTHORIZED)
                }
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));
        let client = TestClient::new(app);

        // No credentials: the extractor rejects and the handler never runs.
        let res = client.get("/").await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        // Matching credentials: the request reaches the handler.
        let res = client
            .get("/")
            .header(header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check (never executed, hence `dead_code` is allowed): a router using
    // `from_extractor` must also accept `RequestBodyLimitLayer`.
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}