1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5 any::type_name,
6 convert::Infallible,
7 fmt,
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tower::util::BoxCloneSyncService;
14use tower::ServiceBuilder;
15use tower_layer::Layer;
16use tower_service::Service;
17
18pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
112 from_fn_with_state((), f)
113}
114
115pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
162 FromFnLayer {
163 f,
164 state,
165 _extractor: PhantomData,
166 }
167}
168
/// A [`Layer`] that wraps services in [`FromFn`] middleware.
///
/// Values of this type are created by [`from_fn`] and [`from_fn_with_state`].
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// The middleware function; cloned into every service this layer produces.
    f: F,
    /// The state; cloned into every service this layer produces.
    state: S,
    /// Records the extractor tuple `T` without storing a value of that type.
    _extractor: PhantomData<fn() -> T>,
}
180
181impl<F, S, T> Clone for FromFnLayer<F, S, T>
182where
183 F: Clone,
184 S: Clone,
185{
186 fn clone(&self) -> Self {
187 Self {
188 f: self.f.clone(),
189 state: self.state.clone(),
190 _extractor: self._extractor,
191 }
192 }
193}
194
195impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
196where
197 F: Clone,
198 S: Clone,
199{
200 type Service = FromFn<F, S, I, T>;
201
202 fn layer(&self, inner: I) -> Self::Service {
203 FromFn {
204 f: self.f.clone(),
205 state: self.state.clone(),
206 inner,
207 _extractor: PhantomData,
208 }
209 }
210}
211
212impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
213where
214 S: fmt::Debug,
215{
216 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
217 f.debug_struct("FromFnLayer")
218 .field("f", &format_args!("{}", type_name::<F>()))
220 .field("state", &self.state)
221 .finish()
222 }
223}
224
/// Middleware service that runs a function in front of an inner service.
///
/// Instances are produced by [`FromFnLayer`]'s `Layer::layer` implementation.
pub struct FromFn<F, S, I, T> {
    /// The middleware function; cloned for every call.
    f: F,
    /// The wrapped service that the middleware function can forward to.
    inner: I,
    /// State passed by reference to the extractors on every call.
    state: S,
    /// Records the extractor tuple `T` without storing a value of that type.
    _extractor: PhantomData<fn() -> T>,
}
234
235impl<F, S, I, T> Clone for FromFn<F, S, I, T>
236where
237 F: Clone,
238 I: Clone,
239 S: Clone,
240{
241 fn clone(&self) -> Self {
242 Self {
243 f: self.f.clone(),
244 inner: self.inner.clone(),
245 state: self.state.clone(),
246 _extractor: self._extractor,
247 }
248 }
249}
250
251macro_rules! impl_service {
252 (
253 [$($ty:ident),*], $last:ident
254 ) => {
255 #[allow(non_snake_case, unused_mut)]
256 impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
257 where
258 F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
259 $( $ty: FromRequestParts<S> + Send, )*
260 $last: FromRequest<S> + Send,
261 Fut: Future<Output = Out> + Send + 'static,
262 Out: IntoResponse + 'static,
263 I: Service<Request, Error = Infallible>
264 + Clone
265 + Send
266 + Sync
267 + 'static,
268 I::Response: IntoResponse,
269 I::Future: Send + 'static,
270 S: Clone + Send + Sync + 'static,
271 {
272 type Response = Response;
273 type Error = Infallible;
274 type Future = ResponseFuture;
275
276 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
277 self.inner.poll_ready(cx)
278 }
279
280 fn call(&mut self, req: Request) -> Self::Future {
281 let not_ready_inner = self.inner.clone();
282 let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
283
284 let mut f = self.f.clone();
285 let state = self.state.clone();
286
287 let future = Box::pin(async move {
288 let (mut parts, body) = req.into_parts();
289
290 $(
291 let $ty = match $ty::from_request_parts(&mut parts, &state).await {
292 Ok(value) => value,
293 Err(rejection) => return rejection.into_response(),
294 };
295 )*
296
297 let req = Request::from_parts(parts, body);
298
299 let $last = match $last::from_request(req, &state).await {
300 Ok(value) => value,
301 Err(rejection) => return rejection.into_response(),
302 };
303
304 let inner = ServiceBuilder::new()
305 .layer_fn(BoxCloneSyncService::new)
306 .map_response(IntoResponse::into_response)
307 .service(ready_inner);
308 let next = Next { inner };
309
310 f($($ty,)* $last, next).await.into_response()
311 });
312
313 ResponseFuture {
314 inner: future
315 }
316 }
317 }
318 };
319}
320
321all_the_tuples!(impl_service);
322
323impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
324where
325 S: fmt::Debug,
326 I: fmt::Debug,
327{
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("FromFnLayer")
330 .field("f", &format_args!("{}", type_name::<F>()))
331 .field("inner", &self.inner)
332 .field("state", &self.state)
333 .finish()
334 }
335}
336
337#[derive(Debug, Clone)]
339pub struct Next {
340 inner: BoxCloneSyncService<Request, Response, Infallible>,
341}
342
343impl Next {
344 pub async fn run(mut self, req: Request) -> Response {
346 match self.inner.call(req).await {
347 Ok(res) => res,
348 Err(err) => match err {},
349 }
350 }
351}
352
353impl Service<Request> for Next {
354 type Response = Response;
355 type Error = Infallible;
356 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
357
358 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
359 self.inner.poll_ready(cx)
360 }
361
362 fn call(&mut self, req: Request) -> Self::Future {
363 self.inner.call(req)
364 }
365}
366
367pub struct ResponseFuture {
369 inner: BoxFuture<'static, Response>,
370}
371
372impl Future for ResponseFuture {
373 type Output = Result<Response, Infallible>;
374
375 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
376 self.inner.as_mut().poll(cx).map(Ok)
377 }
378}
379
380impl fmt::Debug for ResponseFuture {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 f.debug_struct("ResponseFuture").finish()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A `from_fn` middleware can modify the request before the handler runs.
    #[crate::test]
    async fn basic() {
        // Middleware under test: adds `x-axum-test: ok` to the request.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let value = "ok".parse().unwrap();
            req.headers_mut().insert("x-axum-test", value);

            next.run(req).await
        }

        // Handler: returns the value of the header set by the middleware.
        async fn handle(headers: HeaderMap) -> String {
            let value = &headers["x-axum-test"];
            value.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));
        let request = Request::builder().uri("/").body(Body::empty()).unwrap();

        let res = app.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let body = res.collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }
}