1use crate::{extract::rejection::*, response::IntoResponseParts};
2use axum_core::{
3 extract::FromRequestParts,
4 response::{IntoResponse, Response, ResponseParts},
5};
6use http::{request::Parts, Request};
7use std::{
8 convert::Infallible,
9 task::{Context, Poll},
10};
11use tower_service::Service;
12
/// Wrapper used to move a shared value through request and response extensions.
///
/// - As an extractor, it clones the `T` stored in the request's extensions and
///   rejects (with an [`ExtensionRejection`]) when no such value is present.
/// - As a response part or a response, it inserts the wrapped value into the
///   response's extensions.
/// - As a `tower_layer::Layer`, it wraps a service in [`AddExtension`], which
///   inserts a clone of the wrapped value into every request it forwards.
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);
71
72impl<T, S> FromRequestParts<S> for Extension<T>
73where
74 T: Clone + Send + Sync + 'static,
75 S: Send + Sync,
76{
77 type Rejection = ExtensionRejection;
78
79 async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
80 let value = req
81 .extensions
82 .get::<T>()
83 .ok_or_else(|| {
84 MissingExtension::from_err(format!(
85 "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
86 std::any::type_name::<T>()
87 ))
88 }).cloned()?;
89
90 Ok(Extension(value))
91 }
92}
93
// Macro from `axum_core`; judging by its name it presumably implements
// `Deref`/`DerefMut` so `Extension<T>` can be used like the wrapped `T`.
// NOTE(review): confirm against the macro definition in `axum_core`.
axum_core::__impl_deref!(Extension);
95
96impl<T> IntoResponseParts for Extension<T>
97where
98 T: Clone + Send + Sync + 'static,
99{
100 type Error = Infallible;
101
102 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
103 res.extensions_mut().insert(self.0);
104 Ok(res)
105 }
106}
107
108impl<T> IntoResponse for Extension<T>
109where
110 T: Clone + Send + Sync + 'static,
111{
112 fn into_response(self) -> Response {
113 let mut res = ().into_response();
114 res.extensions_mut().insert(self.0);
115 res
116 }
117}
118
119impl<S, T> tower_layer::Layer<S> for Extension<T>
120where
121 T: Clone + Send + Sync + 'static,
122{
123 type Service = AddExtension<S, T>;
124
125 fn layer(&self, inner: S) -> Self::Service {
126 AddExtension {
127 inner,
128 value: self.0.clone(),
129 }
130 }
131}
132
/// Middleware service that inserts a clone of `value` into the extensions of
/// every request before forwarding the request to `inner`.
///
/// Instances are produced by the `tower_layer::Layer` implementation on
/// [`Extension`].
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives each request.
    pub(crate) inner: S,
    /// The value cloned into each request's extensions.
    pub(crate) value: T,
}
144
145impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
146where
147 S: Service<Request<ResBody>>,
148 T: Clone + Send + Sync + 'static,
149{
150 type Response = S::Response;
151 type Error = S::Error;
152 type Future = S::Future;
153
154 #[inline]
155 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 self.inner.poll_ready(cx)
157 }
158
159 fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
160 req.extensions_mut().insert(self.value.clone());
161 self.inner.call(req)
162 }
163}