diff --git a/sdk/storage/src/authorization/authorization_policy.rs b/sdk/storage/src/authorization/authorization_policy.rs index c5334ce079..aee7c1b049 100644 --- a/sdk/storage/src/authorization/authorization_policy.rs +++ b/sdk/storage/src/authorization/authorization_policy.rs @@ -52,6 +52,7 @@ impl Policy for AuthorizationPolicy { request } StorageCredentials::SASToken(query_pairs) => { + // Ensure the signature param is not already present if !request.url().query_pairs().any(|(k, _)| &*k == "sig") { request .url_mut() @@ -225,3 +226,77 @@ fn lexy_sort<'a>( values.sort_unstable(); values } + +#[cfg(test)] +mod tests { + use super::*; + use azure_core::{BytesStream, Response}; + + #[derive(Debug, Clone)] + struct AssertSigHeaderUniqueMockPolicy; + + #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] + #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] + impl Policy for AssertSigHeaderUniqueMockPolicy { + async fn send( + &self, + _ctx: &Context, + request: &mut Request, + _next: &[Arc], + ) -> PolicyResult { + let sig_header_count = request + .url() + .query_pairs() + .filter(|param| param.0 == "sig") + .count(); + assert_eq!(sig_header_count, 1); + + Ok(Response::new( + azure_core::StatusCode::Accepted, + Headers::new(), + Box::pin(BytesStream::new(vec![])), + )) + } + } + + const SAMPLE_SAS_TOKEN: &str = "sp=r&st=1970-01-01T00:00:00Z&se=1970-01-01T00:00:00Z&spr=https&sv=1970-01-01&sr=c&sig=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"; + + #[tokio::test] + async fn authorization_policy_applies_sas_token() { + let ctx = Context::default(); + let storage_credentials = StorageCredentials::sas_token(SAMPLE_SAS_TOKEN).unwrap(); + let auth_policy = AuthorizationPolicy::new(storage_credentials); + let mut request = Request::new(Url::parse("https://example.com").unwrap(), Method::Get); + + let assert_sig_header_unique_mock_policy = Arc::new(AssertSigHeaderUniqueMockPolicy); + + auth_policy + .send(&ctx, &mut request, &[assert_sig_header_unique_mock_policy]) + .await + .unwrap(); + } + + #[tokio::test] + async fn authorization_policy_with_sas_token_does_not_apply_twice() { + let ctx = Context::default(); + let storage_credentials = StorageCredentials::sas_token(SAMPLE_SAS_TOKEN).unwrap(); + let auth_policy = AuthorizationPolicy::new(storage_credentials); + let mut request = Request::new(Url::parse("https://example.com").unwrap(), Method::Get); + + let assert_sig_header_unique_mock_policy = Arc::new(AssertSigHeaderUniqueMockPolicy); + + // apply policy twice + auth_policy + .send( + &ctx, + &mut request, + &[assert_sig_header_unique_mock_policy.clone()], + ) + .await + .unwrap(); + auth_policy + .send(&ctx, &mut request, &[assert_sig_header_unique_mock_policy]) + .await + .unwrap(); + } +}