From c5c8a41aa606c9fb4d42815dcd49104fedf9c8e2 Mon Sep 17 00:00:00 2001 From: Martin Bartlett Date: Mon, 23 Oct 2023 11:45:27 +0200 Subject: [PATCH 1/3] Implementation --- tower-http/Cargo.toml | 2 + tower-http/src/builder.rs | 17 ++ tower-http/src/compression/future.rs | 1 + tower-http/src/lib.rs | 5 +- tower-http/src/propagate_extension.rs | 215 ++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 tower-http/src/propagate_extension.rs diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index b88ddeb3..c9de6366 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -68,6 +68,7 @@ full = [ "map-response-body", "metrics", "normalize-path", + "propagate-extension", "propagate-header", "redirect", "request-id", @@ -91,6 +92,7 @@ map-request-body = [] map-response-body = [] metrics = ["tokio/time"] normalize-path = [] +propagate-extension = [] propagate-header = [] redirect = [] request-id = ["uuid"] diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 2cb4f94a..5784274a 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -54,6 +54,16 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { header: HeaderName, ) -> ServiceBuilder>; + /// Propagate an extension from the request to the response. + /// + /// See [`tower_http::propagate_extension`] for more details. + /// + /// [`tower_http::propagate_extension`]: crate::propagate_extension + #[cfg(feature = "propagate-extension")] + fn propagate_extension( + self + ) -> ServiceBuilder, L>>; + /// Add some shareable value to [request extensions]. /// /// See [`tower_http::add_extension`] for more details. @@ -380,6 +390,13 @@ impl ServiceBuilderExt for ServiceBuilder { self.layer(crate::propagate_header::PropagateHeaderLayer::new(header)) } + #[cfg(feature = "propagate-extension")] + fn propagate_extension( + self, + ) -> ServiceBuilder, L>> { + self.layer(crate::propagate_extension::PropagateExtensionLayer::::new()) + } + #[cfg(feature = "add-extension")] fn add_extension( self, diff --git a/tower-http/src/compression/future.rs b/tower-http/src/compression/future.rs index 426bb161..bfccff52 100644 --- a/tower-http/src/compression/future.rs +++ b/tower-http/src/compression/future.rs @@ -73,6 +73,7 @@ where CompressionBody::new(BodyInner::zstd(WrapBody::new(body, self.quality))) } #[cfg(feature = "fs")] + #[allow(unreachable_patterns)] (true, _) => { // This should never happen because the `AcceptEncoding` struct which is used to determine // `self.encoding` will only enable the different compression algorithms if the diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 6719ddbd..c10b8189 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -231,6 +231,9 @@ pub mod auth; #[cfg(feature = "set-header")] pub mod set_header; +#[cfg(any(test,feature = "propagate-extension"))] +pub mod propagate_extension; + #[cfg(feature = "propagate-header")] pub mod propagate_header; @@ -242,7 +245,7 @@ pub mod propagate_header; ))] pub mod compression; -#[cfg(feature = "add-extension")] +#[cfg(any(test,feature = "add-extension"))] pub mod add_extension; #[cfg(feature = "sensitive-headers")] diff --git a/tower-http/src/propagate_extension.rs b/tower-http/src/propagate_extension.rs new file mode 100644 index 00000000..eb3f9e06 --- /dev/null +++ b/tower-http/src/propagate_extension.rs @@ -0,0 +1,215 @@ +//! Propagate an extension from the request to the response. +//! +//! This middleware is intended to wrap a Request->Response service handler that is _unaware_ of the +//! extension. Consequently it _removes_ the extension from the request before forwarding the request, and then +//! inserts it into the response when the response is ready. As a usage example, if you have pre-service mappers +//! that need to share state with post-service mappers, you can store the state in the Request extensions, +//! and this middleware will ensure that it is available to the post service mappers via the Response extensions. +//! +//! # Example +//! +//! ```rust +//! use http::{Request, Response}; +//! use std::convert::Infallible; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use tower_http::add_extension::AddExtensionLayer; +//! use tower_http::propagate_extension::PropagateExtensionLayer; +//! use hyper::Body; +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! async fn handle(req: Request) -> Result, Infallible> { +//! // ... +//! # Ok(Response::new(Body::empty())) +//! } +//! +//! // +//! // Note that while the state object must _implement_ Clone, it should never actually +//! // _be_ cloned due to the manner in which it is used within the middleware. +//! // +//! #[derive(Clone)] +//! struct MyState { +//! state_message: String +//! }; +//! +//! let my_state = MyState { state_message: "propagated state".to_string() }; +//! +//! let mut svc = ServiceBuilder::new() +//! .layer(AddExtensionLayer::new(my_state)) // any other way of adding the extension to the request is OK too +//! .layer(PropagateExtensionLayer::::new()) +//! .service_fn(handle); +//! +//! // Call the service. +//! let request = Request::builder() +//! .body(Body::empty())?; +//! +//! let response = svc.ready().await?.call(request).await?; +//! +//! assert_eq!(response.extensions().get::().unwrap().state_message, "propagated state"); +//! # +//! # Ok(()) +//! # } +//! ``` + +use futures_util::ready; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::{ + pin::Pin, + task::{Context, Poll}, + marker::PhantomData, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Layer that applies [`PropagateExtension`] which propagates an extension from the request to the response. +/// +/// This middleware is intended to wrap a Request->Response service handler that is _unaware_ of the +/// extension. Consequently it _removes_ the extension from the request before forwarding the request, and then +/// inserts it into the response when the response is ready. As a usage example, if you have pre-service mappers +/// that need to share state with post-service mappers, you can store the state in the Request extensions, +/// and this middleware will ensure that it is available to the post service mappers via the Response extensions. +/// +/// See the [module docs](crate::propagate_extension) for more details. +#[derive(Clone, Debug)] +pub struct PropagateExtensionLayer { + _phantom: PhantomData +} + +impl PropagateExtensionLayer { + /// Create a new [`PropagateExtensionLayer`]. + pub fn new() -> Self { + Self { _phantom: PhantomData } + } +} + +impl Layer for PropagateExtensionLayer { + type Service = PropagateExtension; + + fn layer(&self, inner: S) -> Self::Service { + PropagateExtension:: { + inner, + _phantom: PhantomData + } + } +} + +/// Middleware that propagates extensions from requests to responses. +/// +/// If the extension is present on the request it'll be removed from the request and +/// inserted into the response. +/// +/// See the [module docs](crate::propagate_extension) for more details. +#[derive(Clone,Debug)] +pub struct PropagateExtension { + inner: S, + _phantom: PhantomData +} + +impl PropagateExtension { + /// Create a new [`PropagateExtension`] that propagates the given extension type. + pub fn new(inner: S) -> Self { + Self { inner, _phantom: PhantomData } + } + + define_inner_service_accessors!(); + + /// Returns a new [`Layer`] that wraps services with a `PropagateExtension` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer() -> PropagateExtensionLayer { + PropagateExtensionLayer::::new() + } +} + +impl Service> for PropagateExtension +where + X: Sync + Send + 'static, + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let extension: Option = req.extensions_mut().remove(); + + ResponseFuture { + future: self.inner.call(req), + extension, + } + } +} + +pin_project! { + /// Response future for [`PropagateExtension`]. + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + future: F, + extension: Option, + } +} + +impl Future for ResponseFuture +where + X: Sync + Send + 'static, + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let mut res = ready!(this.future.poll(cx)?); + + if let Some(extension) = this.extension.take() { + res.extensions_mut().insert(extension); + } + + Poll::Ready(Ok(res)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use http::{Request, Response}; + use std::convert::Infallible; + use tower::{Service, ServiceExt, ServiceBuilder}; + use crate::add_extension::AddExtensionLayer; + use hyper::Body; + + async fn handle(_req: Request) -> Result, Infallible> { + Ok(Response::new(Body::empty())) + } + + #[derive(Clone)] + struct MyState { + state_message: String + } + + #[test] + fn basic_test() { + + let my_state = MyState { state_message: "propagated state".to_string() }; + + let mut svc = ServiceBuilder::new() + .layer(AddExtensionLayer::new(my_state)) // any other way of adding the extension to the request is OK too + .layer(PropagateExtensionLayer::::new()) + .service_fn(handle); + + let request = Request::builder().body(Body::empty()).expect("Expected an empty body"); + + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.extensions().get::().unwrap().state_message, "propagated state"); + } +} From b38eda1aa97bb52651731776424736cec6e62492 Mon Sep 17 00:00:00 2001 From: Martin Bartlett Date: Mon, 23 Oct 2023 15:50:42 +0200 Subject: [PATCH 2/3] Discovering --all-features on cargo test! --- tower-http/src/lib.rs | 4 ++-- tower-http/src/propagate_extension.rs | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index c10b8189..a5777b24 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -231,7 +231,7 @@ pub mod auth; #[cfg(feature = "set-header")] pub mod set_header; -#[cfg(any(test,feature = "propagate-extension"))] +#[cfg(feature = "propagate-extension")] pub mod propagate_extension; #[cfg(feature = "propagate-header")] @@ -245,7 +245,7 @@ pub mod propagate_header; ))] pub mod compression; -#[cfg(any(test,feature = "add-extension"))] +#[cfg(feature = "add-extension")] pub mod add_extension; #[cfg(feature = "sensitive-headers")] diff --git a/tower-http/src/propagate_extension.rs b/tower-http/src/propagate_extension.rs index eb3f9e06..2d6b0eb7 100644 --- a/tower-http/src/propagate_extension.rs +++ b/tower-http/src/propagate_extension.rs @@ -14,6 +14,7 @@ //! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; //! use tower_http::add_extension::AddExtensionLayer; //! use tower_http::propagate_extension::PropagateExtensionLayer; +//! use tower_http::ServiceBuilderExt; //! use hyper::Body; //! //! # #[tokio::main] @@ -35,8 +36,8 @@ //! let my_state = MyState { state_message: "propagated state".to_string() }; //! //! let mut svc = ServiceBuilder::new() -//! .layer(AddExtensionLayer::new(my_state)) // any other way of adding the extension to the request is OK too -//! .layer(PropagateExtensionLayer::::new()) +//! .add_extension(my_state) // any other way of adding the extension to the request is OK too +//! .propagate_extension::() //! .service_fn(handle); //! //! // Call the service. @@ -184,6 +185,7 @@ mod tests { use std::convert::Infallible; use tower::{Service, ServiceExt, ServiceBuilder}; use crate::add_extension::AddExtensionLayer; + use crate::builder::ServiceBuilderExt; use hyper::Body; async fn handle(_req: Request) -> Result, Infallible> { @@ -207,6 +209,24 @@ mod tests { let request = Request::builder().body(Body::empty()).expect("Expected an empty body"); + // Call the service. + let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); + let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); + assert_eq!(response.extensions().get::().unwrap().state_message, "propagated state"); + } + + #[test] + fn test_server_builder_ext() { + + let my_state = MyState { state_message: "propagated state".to_string() }; + + let mut svc = ServiceBuilder::new() + .add_extension(my_state) // any other way of adding the extension to the request is OK too + .propagate_extension::() + .service_fn(handle); + + let request = Request::builder().body(Body::empty()).expect("Expected an empty body"); + // Call the service. let ready = futures::executor::block_on(svc.ready()).expect("Expected the service to be ready"); let response = futures::executor::block_on(ready.call(request)).expect("Expected the service to be successful"); From 6b5e9a181052900063a8a2cffa5b594542d2c5f5 Mon Sep 17 00:00:00 2001 From: Martin Bartlett Date: Mon, 23 Oct 2023 18:11:57 +0200 Subject: [PATCH 3/3] Add debug/tracing --- tower-http/src/propagate_extension.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tower-http/src/propagate_extension.rs b/tower-http/src/propagate_extension.rs index 2d6b0eb7..3438cf2d 100644 --- a/tower-http/src/propagate_extension.rs +++ b/tower-http/src/propagate_extension.rs @@ -64,6 +64,15 @@ use std::{ use tower_layer::Layer; use tower_service::Service; +#[allow(unused_imports)] +use tracing::{ + trace, + debug, + info, + warn, + error, +}; + /// Layer that applies [`PropagateExtension`] which propagates an extension from the request to the response. /// /// This middleware is intended to wrap a Request->Response service handler that is _unaware_ of the @@ -140,6 +149,7 @@ where fn call(&mut self, mut req: Request) -> Self::Future { let extension: Option = req.extensions_mut().remove(); + debug!("Removed state from request extensions. is_some? {}", extension.is_some()); ResponseFuture { future: self.inner.call(req), @@ -170,7 +180,10 @@ where let mut res = ready!(this.future.poll(cx)?); if let Some(extension) = this.extension.take() { + debug!("Inserting state into response extensions"); res.extensions_mut().insert(extension); + } else { + debug!("No state to insert into response"); } Poll::Ready(Ok(res))