From 297fe19ef4f32132b8571db57de0692b952d6bd4 Mon Sep 17 00:00:00 2001 From: Max Dymond Date: Wed, 9 Sep 2020 03:09:41 +0100 Subject: [PATCH] Add initial support for Device Code Flow as per RFC 8628 (#113) * Add Client methods to set up the device authorization url and details * Add Client methods to exchange for device codes, and to exchange codes for a token. * Add an example that authorizes using Device Code Flow against Google. --- Cargo.toml | 2 + examples/google_devicecode.rs | 81 ++++++ src/basic.rs | 2 +- src/devicecode.rs | 239 +++++++++++++++ src/lib.rs | 532 ++++++++++++++++++++++++++++++++-- src/tests.rs | 531 ++++++++++++++++++++++++++++++++- src/types.rs | 35 +++ 7 files changed, 1390 insertions(+), 32 deletions(-) create mode 100644 examples/google_devicecode.rs create mode 100644 src/devicecode.rs diff --git a/Cargo.toml b/Cargo.toml index bf9713ed..516d84a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.9" url = { version = "2.1", features = ["serde"] } +chrono = "0.4" [dev-dependencies] hex = "0.4" @@ -36,3 +37,4 @@ hmac = "0.8" uuid = { version = "0.8", features = ["v4"] } anyhow="1.0" tokio = { version = "0.2", features = ["full"] } +async-std = "1.6.3" diff --git a/examples/google_devicecode.rs b/examples/google_devicecode.rs new file mode 100644 index 00000000..5d192701 --- /dev/null +++ b/examples/google_devicecode.rs @@ -0,0 +1,81 @@ +//! +//! This example showcases the Google OAuth2 process for requesting access to the Google Calendar features +//! and the user's profile. +//! +//! Before running it, you'll need to generate your own Google OAuth2 credentials. +//! +//! In order to run the example call: +//! +//! ```sh +//! GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy cargo run --example google +//! ``` +//! +//! ...and follow the instructions. +//! + +use oauth2::basic::BasicClient; +// Alternatively, this can be oauth2::curl::http_client or a custom. +use oauth2::devicecode::{DeviceAuthorizationResponse, ExtraDeviceAuthorizationFields}; +use oauth2::reqwest::http_client; +use oauth2::{AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationUrl, Scope, TokenUrl}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::env; + +#[derive(Debug, Serialize, Deserialize)] +struct StoringFields(HashMap); + +impl ExtraDeviceAuthorizationFields for StoringFields {} +type StoringDeviceAuthorizationResponse = DeviceAuthorizationResponse; + +fn main() { + let google_client_id = ClientId::new( + env::var("GOOGLE_CLIENT_ID").expect("Missing the GOOGLE_CLIENT_ID environment variable."), + ); + let google_client_secret = ClientSecret::new( + env::var("GOOGLE_CLIENT_SECRET") + .expect("Missing the GOOGLE_CLIENT_SECRET environment variable."), + ); + let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string()) + .expect("Invalid authorization endpoint URL"); + let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string()) + .expect("Invalid token endpoint URL"); + let device_auth_url = + DeviceAuthorizationUrl::new("https://oauth2.googleapis.com/device/code".to_string()) + .expect("Invalid device authorization endpoint URL"); + + // Set up the config for the Google OAuth2 process. + // + // Google's OAuth endpoint expects the client_id to be in the request body, + // so ensure that option is set. + let device_client = BasicClient::new( + google_client_id, + Some(google_client_secret), + auth_url, + Some(token_url), + ) + .set_device_authorization_url(device_auth_url) + .set_auth_type(AuthType::RequestBody); + + // Request the set of codes from the Device Authorization endpoint. + let details: StoringDeviceAuthorizationResponse = device_client + .exchange_device_code() + .add_scope(Scope::new("profile".to_string())) + .request(http_client) + .expect("Failed to request codes from device auth endpoint"); + + // Display the URL and user-code. + println!( + "Open this URL in your browser:\n{}\nand enter the code: {}", + details.verification_uri().to_string(), + details.user_code().secret().to_string() + ); + + // Now poll for the token + let token = device_client + .exchange_device_access_token(&details) + .request(http_client, std::thread::sleep, None) + .expect("Failed to get token"); + + println!("Google returned the following token:\n{:?}\n", token); +} diff --git a/src/basic.rs b/src/basic.rs index af0b190d..f74f437a 100644 --- a/src/basic.rs +++ b/src/basic.rs @@ -117,7 +117,7 @@ pub enum BasicErrorResponseType { Extension(String), } impl BasicErrorResponseType { - fn from_str(s: &str) -> Self { + pub(crate) fn from_str(s: &str) -> Self { match s { "invalid_client" => BasicErrorResponseType::InvalidClient, "invalid_grant" => BasicErrorResponseType::InvalidGrant, diff --git a/src/devicecode.rs b/src/devicecode.rs new file mode 100644 index 00000000..f3d01dc2 --- /dev/null +++ b/src/devicecode.rs @@ -0,0 +1,239 @@ +use std::error::Error; +use std::fmt::Error as FormatterError; +use std::fmt::{Debug, Display, Formatter}; +use std::marker::PhantomData; +use std::time::Duration; + +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +use super::{ + DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError, + StandardErrorResponse, TokenResponse, TokenType, UserCode, +}; +use crate::basic::BasicErrorResponseType; +use crate::types::VerificationUriComplete; + +/// The minimum amount of time in seconds that the client SHOULD wait +/// between polling requests to the token endpoint. If no value is +/// provided, clients MUST use 5 as the default. +fn default_devicecode_interval() -> u64 { + 5 +} + +/// +/// Trait for adding extra fields to the `DeviceAuthorizationResponse`. +/// +pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {} + +#[derive(Clone, Debug, Deserialize, Serialize)] +/// +/// Empty (default) extra token fields. +/// +pub struct EmptyExtraDeviceAuthorizationFields {} +impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {} + +/// +/// Standard OAuth2 device authorization response. +/// +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct DeviceAuthorizationResponse +where + EF: ExtraDeviceAuthorizationFields, +{ + /// The device verification code. + device_code: DeviceCode, + + /// The end-user verification code. + user_code: UserCode, + + /// The end-user verification URI on the authorization The URI should be + /// short and easy to remember as end users will be asked to manually type + /// it into their user agent. + /// + /// The `verification_url` alias here is a deviation from the RFC, as + /// implementations of device code flow predate RFC 8628. + #[serde(alias = "verification_url")] + verification_uri: EndUserVerificationUrl, + + /// A verification URI that includes the "user_code" (or other information + /// with the same function as the "user_code"), which is designed for + /// non-textual transmission. + #[serde(skip_serializing_if = "Option::is_none")] + verification_uri_complete: Option, + + /// The lifetime in seconds of the "device_code" and "user_code". + expires_in: u64, + + /// The minimum amount of time in seconds that the client SHOULD wait + /// between polling requests to the token endpoint. If no value is + /// provided, clients MUST use 5 as the default. + #[serde(default = "default_devicecode_interval")] + interval: u64, + + #[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)] + extra_fields: EF, +} + +impl DeviceAuthorizationResponse +where + EF: ExtraDeviceAuthorizationFields, +{ + /// The device verification code. + pub fn device_code(&self) -> &DeviceCode { + &self.device_code + } + + /// The end-user verification code. + pub fn user_code(&self) -> &UserCode { + &self.user_code + } + + /// The end-user verification URI on the authorization The URI should be + /// short and easy to remember as end users will be asked to manually type + /// it into their user agent. + pub fn verification_uri(&self) -> &EndUserVerificationUrl { + &self.verification_uri + } + + /// A verification URI that includes the "user_code" (or other information + /// with the same function as the "user_code"), which is designed for + /// non-textual transmission. + pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> { + self.verification_uri_complete.as_ref() + } + + /// The lifetime in seconds of the "device_code" and "user_code". + pub fn expires_in(&self) -> Duration { + Duration::from_secs(self.expires_in) + } + + /// The minimum amount of time in seconds that the client SHOULD wait + /// between polling requests to the token endpoint. If no value is + /// provided, clients MUST use 5 as the default. + pub fn interval(&self) -> Duration { + Duration::from_secs(self.interval) + } + + /// Any extra fields returned on the response. + pub fn extra_fields(&self) -> &EF { + &self.extra_fields + } +} + +/// +/// Standard implementation of DeviceAuthorizationResponse which throws away +/// extra received response fields. +/// +pub type StandardDeviceAuthorizationResponse = + DeviceAuthorizationResponse; + +/// +/// Basic access token error types. +/// +/// These error types are defined in +/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2) and +/// [Section 3.5 of RFC 6749](https://tools.ietf.org/html/rfc8628#section-3.5) +/// +#[derive(Clone, PartialEq)] +pub enum DeviceCodeErrorResponseType { + /// + /// The authorization request is still pending as the end user hasn't + /// yet completed the user-interaction steps. The client SHOULD repeat the + /// access token request to the token endpoint. Before each new request, + /// the client MUST wait at least the number of seconds specified by the + /// "interval" parameter of the device authorization response, or 5 seconds + /// if none was provided, and respect any increase in the polling interval + /// required by the "slow_down" error. + /// + AuthorizationPending, + /// + /// A variant of "authorization_pending", the authorization request is + /// still pending and polling should continue, but the interval MUST be + /// increased by 5 seconds for this and all subsequent requests. + SlowDown, + /// + /// The authorization request was denied. + /// + AccessDenied, + /// + /// The "device_code" has expired, and the device authorization session has + /// concluded. The client MAY commence a new device authorization request + /// but SHOULD wait for user interaction before restarting to avoid + /// unnecessary polling. + ExpiredToken, + /// + /// A Basic response type + /// + Basic(BasicErrorResponseType), +} +impl DeviceCodeErrorResponseType { + fn from_str(s: &str) -> Self { + match BasicErrorResponseType::from_str(s) { + BasicErrorResponseType::Extension(ext) => match ext.as_str() { + "authorization_pending" => DeviceCodeErrorResponseType::AuthorizationPending, + "slow_down" => DeviceCodeErrorResponseType::SlowDown, + "access_denied" => DeviceCodeErrorResponseType::AccessDenied, + "expired_token" => DeviceCodeErrorResponseType::ExpiredToken, + _ => DeviceCodeErrorResponseType::Basic(BasicErrorResponseType::Extension(ext)), + }, + basic => DeviceCodeErrorResponseType::Basic(basic), + } + } +} +impl AsRef for DeviceCodeErrorResponseType { + fn as_ref(&self) -> &str { + match self { + DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending", + DeviceCodeErrorResponseType::SlowDown => "slow_down", + DeviceCodeErrorResponseType::AccessDenied => "access_denied", + DeviceCodeErrorResponseType::ExpiredToken => "expired_token", + DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(), + } + } +} +impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + let variant_str = String::deserialize(deserializer)?; + Ok(Self::from_str(&variant_str)) + } +} +impl serde::ser::Serialize for DeviceCodeErrorResponseType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + serializer.serialize_str(self.as_ref()) + } +} +impl ErrorResponseType for DeviceCodeErrorResponseType {} +impl Debug for DeviceCodeErrorResponseType { + fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { + Display::fmt(self, f) + } +} + +impl Display for DeviceCodeErrorResponseType { + fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { + write!(f, "{}", self.as_ref()) + } +} + +/// +/// Error response specialization for device code OAuth2 implementation. +/// +pub type DeviceCodeErrorResponse = StandardErrorResponse; + +pub(crate) enum DeviceAccessTokenPollResult +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + RE: Error + 'static, +{ + ContinueWithNewPollInterval(Duration), + Done(Result>, PhantomData), +} diff --git a/src/lib.rs b/src/lib.rs index 63c0be2e..f37eb79d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ //! * [Implicit Grant](#implicit-grant) //! * [Resource Owner Password Credentials Grant](#resource-owner-password-credentials-grant) //! * [Client Credentials Grant](#client-credentials-grant) +//! * [Device Code Flow](#device-code-flow) //! * [Other examples](#other-examples) //! * [Contributed Examples](#contributed-examples) //! @@ -55,13 +56,13 @@ //! ``` //! //! Synchronous HTTP clients should implement the following trait: -//! ```ignore +//! ```rust,ignore //! FnOnce(HttpRequest) -> Result //! where RE: std::error::Error + 'static //! ``` //! //! Async/await HTTP clients should implement the following trait: -//! ```ignore +//! ```rust,ignore //! FnOnce(HttpRequest) -> F //! where //! F: Future>, @@ -344,6 +345,63 @@ //! # } //! ``` //! +//! # Device Code Flow +//! +//! Device Code Flow allows users to sign in on browserless or input-constrained +//! devices. This is a two-stage process; first a user-code and verification +//! URL are obtained by using the `Client::exchange_client_credentials` +//! method. Those are displayed to the user, then are used in a second client +//! to poll the token endpoint for a token. +//! +//! ## Example +//! +//! ```rust,no_run +//! use anyhow; +//! use oauth2::{ +//! AuthUrl, +//! ClientId, +//! ClientSecret, +//! DeviceAuthorizationUrl, +//! Scope, +//! TokenResponse, +//! TokenUrl +//! }; +//! use oauth2::basic::BasicClient; +//! use oauth2::devicecode::StandardDeviceAuthorizationResponse; +//! use oauth2::reqwest::http_client; +//! use url::Url; +//! +//! # fn err_wrapper() -> Result<(), anyhow::Error> { +//! let device_auth_url = DeviceAuthorizationUrl::new("http://deviceauth".to_string())?; +//! let client = +//! BasicClient::new( +//! ClientId::new("client_id".to_string()), +//! Some(ClientSecret::new("client_secret".to_string())), +//! AuthUrl::new("http://authorize".to_string())?, +//! Some(TokenUrl::new("http://token".to_string())?), +//! ) +//! .set_device_authorization_url(device_auth_url); +//! +//! let details: StandardDeviceAuthorizationResponse = client +//! .exchange_device_code() +//! .add_scope(Scope::new("read".to_string())) +//! .request(http_client)?; +//! +//! println!( +//! "Open this URL in your browser:\n{}\nand enter the code: {}", +//! details.verification_uri().to_string(), +//! details.user_code().secret().to_string() +//! ); +//! +//! let token_result = +//! client +//! .exchange_device_access_token(&details) +//! .request(http_client, std::thread::sleep, None)?; +//! +//! # Ok(()) +//! # } +//! ``` +//! //! # Other examples //! //! More specific implementations are available as part of the examples: @@ -357,12 +415,14 @@ //! //! - [`actix-web-oauth2`](https://github.com/pka/actix-web-oauth2) (version 2.x of this crate) //! +use chrono::{DateTime, Utc}; use std::borrow::Cow; use std::error::Error; use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; use std::future::Future; use std::marker::PhantomData; +use std::sync::Arc; use std::time::Duration; use http::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; @@ -384,6 +444,16 @@ pub mod basic; #[cfg(feature = "curl")] pub mod curl; +/// +/// Device Code Flow OAuth2 implementation +/// ([RFC 8628](https://tools.ietf.org/html/rfc8628)). +/// +pub mod devicecode; +use devicecode::{ + DeviceAccessTokenPollResult, DeviceAuthorizationResponse, DeviceCodeErrorResponse, + DeviceCodeErrorResponseType, ExtraDeviceAuthorizationFields, +}; + /// /// Helper methods used by OAuth2 implementations/extensions. /// @@ -408,9 +478,10 @@ pub use http; pub use url; pub use types::{ - AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, + AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, + DeviceAuthorizationUrl, DeviceCode, EndUserVerificationUrl, PkceCodeChallenge, PkceCodeChallengeMethod, PkceCodeVerifier, RedirectUrl, RefreshToken, ResourceOwnerPassword, - ResourceOwnerUsername, ResponseType, Scope, TokenUrl, + ResourceOwnerUsername, ResponseType, Scope, TokenUrl, UserCode, }; const CONTENT_TYPE_JSON: &str = "application/json"; @@ -447,6 +518,7 @@ where auth_type: AuthType, token_url: Option, redirect_url: Option, + device_authorization_url: Option, phantom_te: PhantomData, phantom_tr: PhantomData, phantom_tt: PhantomData, @@ -491,6 +563,7 @@ where auth_type: AuthType::BasicAuth, token_url, redirect_url: None, + device_authorization_url: None, phantom_te: PhantomData, phantom_tr: PhantomData, phantom_tt: PhantomData, @@ -519,6 +592,19 @@ where self } + /// + /// Sets the the device authorization URL used by the device authorization endpoint. + /// Used for Device Code Flow, as per [RFC 8628](https://tools.ietf.org/html/rfc8628). + /// + pub fn set_device_authorization_url( + mut self, + device_authorization_url: DeviceAuthorizationUrl, + ) -> Self { + self.device_authorization_url = Some(device_authorization_url); + + self + } + /// /// Generates an authorization URL for a new authorization request. /// @@ -641,6 +727,46 @@ where _phantom: PhantomData, } } + + /// + /// Perform a device authorization request as per + /// https://tools.ietf.org/html/rfc8628#section-3.1 + /// + pub fn exchange_device_code(&self) -> DeviceAuthorizationRequest { + DeviceAuthorizationRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + scopes: Vec::new(), + device_authorization_url: self.device_authorization_url.as_ref(), + _phantom: PhantomData, + } + } + + /// + /// Perform a device access token request as per + /// https://tools.ietf.org/html/rfc8628#section-3.4 + /// + pub fn exchange_device_access_token<'a, 'b, 'c, EF>( + &'a self, + auth_response: &'b DeviceAuthorizationResponse, + ) -> DeviceAccessTokenRequest<'b, 'c, TR, TT, EF> + where + 'a: 'b, + EF: ExtraDeviceAuthorizationFields, + { + DeviceAccessTokenRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + token_url: self.token_url.as_ref(), + dev_auth_resp: auth_response, + time_fn: Arc::new(Utc::now), + _phantom: PhantomData, + } + } } /// @@ -879,7 +1005,7 @@ where params.push(("code_verifier", pkce_verifier.secret())); } - Ok(token_request( + Ok(endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -887,7 +1013,8 @@ where self.redirect_url, None, self.token_url - .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))?, + .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))? + .url(), params, )) } @@ -902,7 +1029,7 @@ where { http_client(self.prepare_request()?) .map_err(RequestTokenError::Request) - .and_then(token_response) + .and_then(endpoint_response) } /// @@ -921,7 +1048,7 @@ where let http_response = http_client(http_request) .await .map_err(RequestTokenError::Request)?; - token_response(http_response) + endpoint_response(http_response) } } @@ -994,7 +1121,7 @@ where { http_client(self.prepare_request()?) .map_err(RequestTokenError::Request) - .and_then(token_response) + .and_then(endpoint_response) } /// /// Asynchronously sends the request to the authorization server and awaits a response. @@ -1012,14 +1139,14 @@ where let http_response = http_client(http_request) .await .map_err(RequestTokenError::Request)?; - token_response(http_response) + endpoint_response(http_response) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(token_request( + Ok(endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -1027,7 +1154,8 @@ where None, Some(&self.scopes), self.token_url - .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))?, + .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))? + .url(), vec![ ("grant_type", "refresh_token"), ("refresh_token", self.refresh_token.secret()), @@ -1106,7 +1234,7 @@ where { http_client(self.prepare_request()?) .map_err(RequestTokenError::Request) - .and_then(token_response) + .and_then(endpoint_response) } /// @@ -1125,14 +1253,14 @@ where let http_response = http_client(http_request) .await .map_err(RequestTokenError::Request)?; - token_response(http_response) + endpoint_response(http_response) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(token_request( + Ok(endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -1140,7 +1268,8 @@ where None, Some(&self.scopes), self.token_url - .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))?, + .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))? + .url(), vec![ ("grant_type", "password"), ("username", self.username), @@ -1218,7 +1347,7 @@ where { http_client(self.prepare_request()?) .map_err(RequestTokenError::Request) - .and_then(token_response) + .and_then(endpoint_response) } /// @@ -1237,14 +1366,14 @@ where let http_response = http_client(http_request) .await .map_err(RequestTokenError::Request)?; - token_response(http_response) + endpoint_response(http_response) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(token_request( + Ok(endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -1252,21 +1381,22 @@ where None, Some(&self.scopes), self.token_url - .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))?, + .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))? + .url(), vec![("grant_type", "client_credentials")], )) } } #[allow(clippy::too_many_arguments)] -fn token_request<'a>( +fn endpoint_request<'a>( auth_type: &'a AuthType, client_id: &'a ClientId, client_secret: Option<&'a ClientSecret>, extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)], redirect_url: Option<&'a RedirectUrl>, scopes: Option<&'a Vec>>, - token_url: &'a TokenUrl, + url: &'a Url, params: Vec<(&'a str, &'a str)>, ) -> HttpRequest { let mut headers = HeaderMap::new(); @@ -1343,21 +1473,20 @@ fn token_request<'a>( .into_bytes(); HttpRequest { - url: token_url.url().to_owned(), + url: url.to_owned(), method: http::method::Method::POST, headers, body, } } -fn token_response( +fn endpoint_response( http_response: HttpResponse, -) -> Result> +) -> Result> where RE: Error + 'static, TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, + DO: DeserializeOwned, { if http_response.status_code != StatusCode::OK { let reason = http_response.body.as_slice(); @@ -1408,6 +1537,355 @@ where } } +/// +/// The request for a set of verification codes from the authorization server. +/// +/// See https://tools.ietf.org/html/rfc8628#section-3.1. +/// +#[derive(Debug)] +pub struct DeviceAuthorizationRequest<'a, TE> +where + TE: ErrorResponse, +{ + auth_type: &'a AuthType, + client_id: &'a ClientId, + client_secret: Option<&'a ClientSecret>, + extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + scopes: Vec>, + device_authorization_url: Option<&'a DeviceAuthorizationUrl>, + _phantom: PhantomData, +} + +impl<'a, TE> DeviceAuthorizationRequest<'a, TE> +where + TE: ErrorResponse + 'static, +{ + /// + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + /// + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// + /// Appends a new scope to the token request. + /// + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + fn prepare_request(self) -> Result> + where + RE: Error + 'static, + { + Ok(endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + Some(&self.scopes), + self.device_authorization_url + .ok_or_else(|| { + RequestTokenError::Other("no device authorization_url provided".to_string()) + })? + .url(), + vec![], + )) + } + + /// + /// Synchronously sends the request to the authorization server and awaits a response. + /// + pub fn request( + self, + http_client: F, + ) -> Result, RequestTokenError> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + EF: ExtraDeviceAuthorizationFields, + { + http_client(self.prepare_request()?) + .map_err(RequestTokenError::Request) + .and_then(endpoint_response) + } + + /// + /// Asynchronously sends the request to the authorization server and returns a Future. + /// + pub async fn request_async( + self, + http_client: C, + ) -> Result, RequestTokenError> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + EF: ExtraDeviceAuthorizationFields, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request) + .await + .map_err(RequestTokenError::Request)?; + endpoint_response(http_response) + } +} + +/// +/// The request for an device access token from the authorization server. +/// +/// See https://tools.ietf.org/html/rfc8628#section-3.4. +/// +#[derive(Clone)] +pub struct DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> +where + TR: TokenResponse, + TT: TokenType, + EF: ExtraDeviceAuthorizationFields, +{ + auth_type: &'a AuthType, + client_id: &'a ClientId, + client_secret: Option<&'a ClientSecret>, + extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + token_url: Option<&'a TokenUrl>, + dev_auth_resp: &'a DeviceAuthorizationResponse, + time_fn: Arc DateTime + 'b + Send + Sync>, + _phantom: PhantomData<(TR, TT, EF)>, +} + +impl<'a, 'b, TR, TT, EF> DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> +where + TR: TokenResponse, + TT: TokenType, + EF: ExtraDeviceAuthorizationFields, +{ + /// + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + /// + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// + /// Specifies a function for returning the current time. + /// + /// This function is used while polling the authorization server. + /// + pub fn set_time_fn(mut self, time_fn: T) -> Self + where + T: Fn() -> DateTime + 'b + Send + Sync, + { + self.time_fn = Arc::new(time_fn); + self + } + + /// + /// Synchronously polls the authorization server for a response, waiting + /// using a user defined sleep function. + /// + pub fn request( + self, + http_client: F, + sleep_fn: S, + timeout: Option, + ) -> Result> + where + F: Fn(HttpRequest) -> Result, + S: Fn(Duration), + RE: Error + 'static, + { + // Get the request timeout and starting interval + let timeout_dt = self.compute_timeout(timeout)?; + let mut interval = self.dev_auth_resp.interval(); + + // Loop while requesting a token. + loop { + let now = (*self.time_fn)(); + if now > timeout_dt { + break Err(RequestTokenError::Other("Device code expired".to_string())); + } + + match self.process_response(http_client(self.prepare_request()?), interval) { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { + interval = new_interval + } + DeviceAccessTokenPollResult::Done(res, _) => break res, + } + + // Sleep here using the provided sleep function. + sleep_fn(interval); + } + } + + /// + /// Asynchronously sends the request to the authorization server and awaits a response. + /// + pub async fn request_async( + self, + http_client: C, + sleep_fn: S, + timeout: Option, + ) -> Result> + where + C: Fn(HttpRequest) -> F, + F: Future>, + S: Fn(Duration) -> SF, + SF: Future, + RE: Error + 'static, + { + // Get the request timeout and starting interval + let timeout_dt = self.compute_timeout(timeout)?; + let mut interval = self.dev_auth_resp.interval(); + + // Loop while requesting a token. + loop { + let now = (*self.time_fn)(); + if now > timeout_dt { + break Err(RequestTokenError::Other("Device code expired".to_string())); + } + + match self.process_response(http_client(self.prepare_request()?).await, interval) { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { + interval = new_interval + } + DeviceAccessTokenPollResult::Done(res, _) => break res, + } + + // Sleep here using the provided sleep function. + sleep_fn(interval); + } + } + + fn prepare_request( + &self, + ) -> Result> + where + RE: Error + 'static, + { + Ok(endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + None, + self.token_url + .ok_or_else(|| RequestTokenError::Other("no token_url provided".to_string()))? + .url(), + vec![ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("device_code", self.dev_auth_resp.device_code().secret()), + ], + )) + } + + fn process_response( + &self, + res: Result, + current_interval: Duration, + ) -> DeviceAccessTokenPollResult + where + RE: Error + 'static, + { + let http_response = match res { + Ok(inner) => inner, + Err(_) => { + // Try and double the current interval. If that fails, just use the current one. + let new_interval = current_interval.checked_mul(2).unwrap_or(current_interval); + return DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval); + } + }; + + // Explicitly process the response with a DeviceCodeErrorResponse + let res = endpoint_response::(http_response); + match res { + // On a ServerResponse error, the error needs inspecting as a DeviceCodeErrorResponse + // to work out whether a retry needs to happen. + Err(RequestTokenError::ServerResponse(dcer)) => { + match dcer.error() { + // On AuthorizationPending, a retry needs to happen with the same poll interval. + DeviceCodeErrorResponseType::AuthorizationPending => { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(current_interval) + } + // On SlowDown, a retry needs to happen with a larger poll interval. + DeviceCodeErrorResponseType::SlowDown => { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval( + current_interval + Duration::from_secs(5), + ) + } + + // On any other error, just return the error. + _ => DeviceAccessTokenPollResult::Done( + Err(RequestTokenError::ServerResponse(dcer)), + PhantomData, + ), + } + } + + // On any other success or failure, return the failure. + res => DeviceAccessTokenPollResult::Done(res, PhantomData), + } + } + + fn compute_timeout( + &self, + timeout: Option, + ) -> Result, RequestTokenError> + where + RE: Error + 'static, + { + // Calculate the request timeout - if the user specified a timeout, + // use that, otherwise use the value given by the device authorization + // response. + let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in()); + let chrono_timeout = chrono::Duration::from_std(timeout_dur) + .map_err(|_| RequestTokenError::Other("Failed to convert duration".to_string()))?; + + // Calculate the DateTime at which the request times out. + let timeout_dt = (*self.time_fn)() + .checked_add_signed(chrono_timeout) + .ok_or_else(|| RequestTokenError::Other("Failed to calculate timeout".to_string()))?; + + Ok(timeout_dt) + } +} + /// /// Trait for OAuth2 access tokens. /// diff --git a/src/tests.rs b/src/tests.rs index acad62b0..4f612d67 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -5,6 +5,7 @@ use url::form_urlencoded::byte_serialize; use url::Url; use super::basic::*; +use super::devicecode::*; use super::*; fn new_client() -> BasicClient { @@ -19,12 +20,15 @@ fn new_client() -> BasicClient { fn mock_http_client( request_headers: Vec<(HeaderName, &'static str)>, request_body: &'static str, + request_url: Option, response: HttpResponse, -) -> impl FnOnce(HttpRequest) -> Result { +) -> impl Fn(HttpRequest) -> Result { move |request: HttpRequest| { assert_eq!( - request.url, - Url::parse("https://example.com/token").unwrap() + &request.url, + request_url + .as_ref() + .unwrap_or(&Url::parse("https://example.com/token").unwrap()) ); assert_eq!( request.headers, @@ -35,7 +39,7 @@ fn mock_http_client( ); assert_eq!(&String::from_utf8(request.body).unwrap(), request_body); - Ok(response) + Ok(response.clone()) } } @@ -293,6 +297,7 @@ fn test_exchange_code_successful_with_minimal_json_response() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: HeaderMap::new(), @@ -330,6 +335,7 @@ fn test_exchange_code_successful_with_complete_json_response() { (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -394,6 +400,7 @@ fn test_exchange_client_credentials_with_basic_auth() { (AUTHORIZATION, "Basic YWFhJTJGJTNCJTI2OmJiYiUyRiUzQiUyNg=="), ], "grant_type=client_credentials", + None, HttpResponse { status_code: StatusCode::OK, headers: HeaderMap::new(), @@ -434,6 +441,7 @@ fn test_exchange_client_credentials_with_body_auth_and_scope() { (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=client_credentials&scope=read+write&client_id=aaa&client_secret=bbb", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -478,6 +486,7 @@ fn test_exchange_refresh_token_with_basic_auth() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=refresh_token&refresh_token=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: HeaderMap::new(), @@ -516,6 +525,7 @@ fn test_exchange_refresh_token_with_json_response() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=refresh_token&refresh_token=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: HeaderMap::new(), @@ -560,6 +570,7 @@ fn test_exchange_password_with_json_response() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=password&username=user&password=pass&scope=read+write", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -607,6 +618,7 @@ fn test_exchange_code_successful_with_redirect_url() { ], "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ redirect_uri=https%3A%2F%2Fredirect%2Fhere", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -654,6 +666,7 @@ fn test_exchange_code_successful_with_basic_auth() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc&redirect_uri=https%3A%2F%2Fredirect%2Fhere", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -709,6 +722,7 @@ fn test_exchange_code_successful_with_pkce_and_extension() { &code_verifier=dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk\ &redirect_uri=https%3A%2F%2Fredirect%2Fhere\ &foo=bar", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -757,6 +771,7 @@ fn test_exchange_refresh_token_successful_with_extension() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=refresh_token&refresh_token=ccc&foo=bar", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -801,6 +816,7 @@ fn test_exchange_code_with_simple_json_error() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::BAD_REQUEST, headers: vec![( @@ -887,6 +903,7 @@ fn test_exchange_code_with_json_parse_error() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -923,6 +940,7 @@ fn test_exchange_code_with_unexpected_content_type() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![(CONTENT_TYPE, HeaderValue::from_str("text/plain").unwrap())] @@ -963,6 +981,7 @@ fn test_exchange_code_with_invalid_token_type() { (AUTHORIZATION, "Basic YWFhOg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -1001,6 +1020,7 @@ fn test_exchange_code_with_400_status_code() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::BAD_REQUEST, headers: vec![( @@ -1145,6 +1165,7 @@ fn test_extension_successful_with_minimal_json_response() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -1197,6 +1218,7 @@ fn test_extension_successful_with_complete_json_response() { (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", + None, HttpResponse { status_code: StatusCode::OK, headers: vec![( @@ -1266,6 +1288,7 @@ fn test_extension_with_simple_json_error() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::BAD_REQUEST, headers: vec![( @@ -1372,6 +1395,7 @@ fn test_extension_with_custom_json_error() { (AUTHORIZATION, "Basic YWFhOmJiYg=="), ], "grant_type=authorization_code&code=ccc", + None, HttpResponse { status_code: StatusCode::BAD_REQUEST, headers: vec![( @@ -1456,6 +1480,489 @@ fn test_secret_redaction() { assert_eq!("ClientSecret([redacted])", format!("{:?}", secret)); } +fn new_device_auth_details(expires_in: u32) -> StandardDeviceAuthorizationResponse { + let body = format!( + "{{\ + \"device_code\": \"12345\", \ + \"verification_uri\": \"https://verify/here\", \ + \"user_code\": \"abcde\", \ + \"verification_uri_complete\": \"https://verify/here?abcde\", \ + \"expires_in\": {}, \ + \"interval\": 1 \ + }}", + expires_in + ); + + let device_auth_url = + DeviceAuthorizationUrl::new("https://deviceauth/here".to_string()).unwrap(); + + let client = new_client().set_device_authorization_url(device_auth_url.clone()); + client + .exchange_device_code() + .add_extra_param("foo", "bar") + .add_scope(Scope::new("openid".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "scope=openid&foo=bar", + Some(device_auth_url.url().to_owned()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: body.into_bytes(), + }, + )) + .unwrap() +} + +struct IncreasingTime { + times: std::ops::RangeFrom, +} + +impl IncreasingTime { + fn new() -> Self { + Self { times: (0..) } + } + fn next(&mut self) -> DateTime { + let next_value = self.times.next().unwrap(); + let naive = chrono::NaiveDateTime::from_timestamp(next_value, 0); + DateTime::::from_utc(naive, chrono::Utc) + } +} + +/// Creates a time function that increments by one second each time. +fn mock_time_fn() -> impl Fn() -> DateTime + Send + Sync { + let timer = std::sync::Mutex::new(IncreasingTime::new()); + move || timer.lock().unwrap().next() +} + +/// Mock sleep function that doesn't actually sleep. +fn mock_sleep_fn(_: Duration) {} + +#[test] +fn test_exchange_device_code_and_token() { + let details = new_device_auth_details(3600); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(3600), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_device_token_authorization_timeout() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"authorization_pending\", \ + \"error_description\": \"Still waiting for user\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::Other(msg) => assert_eq!(msg, "Device code expired"), + _ => unreachable!("Error should be an expiry"), + } +} + +#[test] +fn test_device_token_access_denied() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"access_denied\", \ + \"error_description\": \"Access Denied\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::ServerResponse(msg) => { + assert_eq!(msg.error(), &DeviceCodeErrorResponseType::AccessDenied) + } + _ => unreachable!("Error should be Access Denied"), + } +} + +#[test] +fn test_device_token_expired() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"expired_token\", \ + \"error_description\": \"Token has expired\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::ServerResponse(msg) => { + assert_eq!(msg.error(), &DeviceCodeErrorResponseType::ExpiredToken) + } + _ => unreachable!("Error should be ExpiredToken"), + } +} + +fn mock_http_client_success_fail( + request_url: Option, + request_headers: Vec<(HeaderName, &'static str)>, + request_body: &'static str, + failure_response: HttpResponse, + num_failures: usize, + success_response: HttpResponse, +) -> impl Fn(HttpRequest) -> Result { + let responses: Vec = std::iter::repeat(failure_response) + .take(num_failures) + .chain(std::iter::once(success_response)) + .collect(); + let sync_responses = std::sync::Mutex::new(responses); + + move |request: HttpRequest| { + assert_eq!( + &request.url, + request_url + .as_ref() + .unwrap_or(&Url::parse("https://example.com/token").unwrap()) + ); + assert_eq!( + request.headers, + request_headers + .iter() + .map(|(name, value)| (name.clone(), HeaderValue::from_str(value).unwrap())) + .collect(), + ); + assert_eq!(&String::from_utf8(request.body).unwrap(), request_body); + + { + let mut rsp_vec = sync_responses.lock().unwrap(); + if rsp_vec.len() == 0 { + Err(FakeError::Err) + } else { + Ok(rsp_vec.remove(0)) + } + } + } +} + +#[test] +fn test_device_token_pending_then_success() { + let details = new_device_auth_details(20); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(20), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"authorization_pending\", \ + \"error_description\": \"Still waiting for user\"\ + }" + .to_string() + .into_bytes(), + }, + 5, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_device_token_slowdown_then_success() { + let details = new_device_auth_details(3600); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(3600), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"slow_down\", \ + \"error_description\": \"Woah there partner\"\ + }" + .to_string() + .into_bytes(), + }, + 5, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + #[test] fn test_send_sync_impl() { fn is_sync_and_send() {}; @@ -1532,6 +2039,22 @@ fn test_send_sync_impl() { is_sync_and_send::>>( ); + is_sync_and_send::(); + is_sync_and_send::(); + is_sync_and_send::(); + is_sync_and_send::(); + is_sync_and_send::(); + is_sync_and_send::< + DeviceAccessTokenRequest< + StandardTokenResponse, + BasicTokenType, + EmptyExtraDeviceAuthorizationFields, + >, + >(); + is_sync_and_send::>>(); + is_sync_and_send::(); + is_sync_and_send::(); + #[cfg(feature = "curl")] is_sync_and_send::(); #[cfg(feature = "reqwest-010")] diff --git a/src/types.rs b/src/types.rs index 15eff314..75b3480c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -347,6 +347,18 @@ new_url_type![ /// RedirectUrl ]; +new_url_type![ + /// + /// URL of the client's device authorization endpoint. + /// + DeviceAuthorizationUrl +]; +new_url_type![ + /// + /// URL of the end-user verification URI on the authorization server. + /// + EndUserVerificationUrl +]; new_type![ /// /// Authorization endpoint response (grant) type defined in @@ -547,3 +559,26 @@ new_secret_type![ #[derive(Clone)] ResourceOwnerPassword(String) ]; +new_secret_type![ + /// + /// Device code returned by the device authorization endpoint and used to query the token endpoint. + /// + #[derive(Clone, Deserialize, Serialize)] + DeviceCode(String) +]; +new_secret_type![ + /// + /// Verification URI returned by the device authorization endpoint and visited by the user + /// to authorize. Contains the user code. + /// + #[derive(Clone, Deserialize, Serialize)] + VerificationUriComplete(String) +]; +new_secret_type![ + /// + /// User code returned by the device authorization endpoint and used by the user to authorize at + /// the verification URI. + /// + #[derive(Clone, Deserialize, Serialize)] + UserCode(String) +];