Skip to content

Commit

Permalink
Add initial support for Device Code Flow as per RFC 8628 (#113)
Browse files Browse the repository at this point in the history
* Add Client methods to set up the device authorization url and details
* Add Client methods to exchange for device codes, and to exchange codes
   for a token.
* Add an example that authorizes using Device Code Flow against Google.
  • Loading branch information
maxdymond authored Sep 9, 2020
1 parent b50328e commit 297fe19
Show file tree
Hide file tree
Showing 7 changed files with 1,390 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.9"
url = { version = "2.1", features = ["serde"] }
chrono = "0.4"

[dev-dependencies]
hex = "0.4"
hmac = "0.8"
uuid = { version = "0.8", features = ["v4"] }
anyhow="1.0"
tokio = { version = "0.2", features = ["full"] }
async-std = "1.6.3"
81 changes: 81 additions & 0 deletions examples/google_devicecode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//!
//! This example showcases the Google OAuth2 process for requesting access to the Google Calendar features
//! and the user's profile.
//!
//! Before running it, you'll need to generate your own Google OAuth2 credentials.
//!
//! In order to run the example call:
//!
//! ```sh
//! GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy cargo run --example google
//! ```
//!
//! ...and follow the instructions.
//!

use oauth2::basic::BasicClient;
// Alternatively, this can be oauth2::curl::http_client or a custom.
use oauth2::devicecode::{DeviceAuthorizationResponse, ExtraDeviceAuthorizationFields};
use oauth2::reqwest::http_client;
use oauth2::{AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationUrl, Scope, TokenUrl};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;

#[derive(Debug, Serialize, Deserialize)]
struct StoringFields(HashMap<String, serde_json::Value>);

impl ExtraDeviceAuthorizationFields for StoringFields {}
type StoringDeviceAuthorizationResponse = DeviceAuthorizationResponse<StoringFields>;

fn main() {
let google_client_id = ClientId::new(
env::var("GOOGLE_CLIENT_ID").expect("Missing the GOOGLE_CLIENT_ID environment variable."),
);
let google_client_secret = ClientSecret::new(
env::var("GOOGLE_CLIENT_SECRET")
.expect("Missing the GOOGLE_CLIENT_SECRET environment variable."),
);
let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
.expect("Invalid token endpoint URL");
let device_auth_url =
DeviceAuthorizationUrl::new("https://oauth2.googleapis.com/device/code".to_string())
.expect("Invalid device authorization endpoint URL");

// Set up the config for the Google OAuth2 process.
//
// Google's OAuth endpoint expects the client_id to be in the request body,
// so ensure that option is set.
let device_client = BasicClient::new(
google_client_id,
Some(google_client_secret),
auth_url,
Some(token_url),
)
.set_device_authorization_url(device_auth_url)
.set_auth_type(AuthType::RequestBody);

// Request the set of codes from the Device Authorization endpoint.
let details: StoringDeviceAuthorizationResponse = device_client
.exchange_device_code()
.add_scope(Scope::new("profile".to_string()))
.request(http_client)
.expect("Failed to request codes from device auth endpoint");

// Display the URL and user-code.
println!(
"Open this URL in your browser:\n{}\nand enter the code: {}",
details.verification_uri().to_string(),
details.user_code().secret().to_string()
);

// Now poll for the token
let token = device_client
.exchange_device_access_token(&details)
.request(http_client, std::thread::sleep, None)
.expect("Failed to get token");

println!("Google returned the following token:\n{:?}\n", token);
}
2 changes: 1 addition & 1 deletion src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub enum BasicErrorResponseType {
Extension(String),
}
impl BasicErrorResponseType {
fn from_str(s: &str) -> Self {
pub(crate) fn from_str(s: &str) -> Self {
match s {
"invalid_client" => BasicErrorResponseType::InvalidClient,
"invalid_grant" => BasicErrorResponseType::InvalidGrant,
Expand Down
239 changes: 239 additions & 0 deletions src/devicecode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use std::error::Error;
use std::fmt::Error as FormatterError;
use std::fmt::{Debug, Display, Formatter};
use std::marker::PhantomData;
use std::time::Duration;

use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use super::{
DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError,
StandardErrorResponse, TokenResponse, TokenType, UserCode,
};
use crate::basic::BasicErrorResponseType;
use crate::types::VerificationUriComplete;

/// The minimum amount of time in seconds that the client SHOULD wait
/// between polling requests to the token endpoint. If no value is
/// provided, clients MUST use 5 as the default.
fn default_devicecode_interval() -> u64 {
5
}

///
/// Trait for adding extra fields to the `DeviceAuthorizationResponse`.
///
pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {}

#[derive(Clone, Debug, Deserialize, Serialize)]
///
/// Empty (default) extra token fields.
///
pub struct EmptyExtraDeviceAuthorizationFields {}
impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {}

///
/// Standard OAuth2 device authorization response.
///
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeviceAuthorizationResponse<EF>
where
EF: ExtraDeviceAuthorizationFields,
{
/// The device verification code.
device_code: DeviceCode,

/// The end-user verification code.
user_code: UserCode,

/// The end-user verification URI on the authorization The URI should be
/// short and easy to remember as end users will be asked to manually type
/// it into their user agent.
///
/// The `verification_url` alias here is a deviation from the RFC, as
/// implementations of device code flow predate RFC 8628.
#[serde(alias = "verification_url")]
verification_uri: EndUserVerificationUrl,

/// A verification URI that includes the "user_code" (or other information
/// with the same function as the "user_code"), which is designed for
/// non-textual transmission.
#[serde(skip_serializing_if = "Option::is_none")]
verification_uri_complete: Option<VerificationUriComplete>,

/// The lifetime in seconds of the "device_code" and "user_code".
expires_in: u64,

/// The minimum amount of time in seconds that the client SHOULD wait
/// between polling requests to the token endpoint. If no value is
/// provided, clients MUST use 5 as the default.
#[serde(default = "default_devicecode_interval")]
interval: u64,

#[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)]
extra_fields: EF,
}

impl<EF> DeviceAuthorizationResponse<EF>
where
EF: ExtraDeviceAuthorizationFields,
{
/// The device verification code.
pub fn device_code(&self) -> &DeviceCode {
&self.device_code
}

/// The end-user verification code.
pub fn user_code(&self) -> &UserCode {
&self.user_code
}

/// The end-user verification URI on the authorization The URI should be
/// short and easy to remember as end users will be asked to manually type
/// it into their user agent.
pub fn verification_uri(&self) -> &EndUserVerificationUrl {
&self.verification_uri
}

/// A verification URI that includes the "user_code" (or other information
/// with the same function as the "user_code"), which is designed for
/// non-textual transmission.
pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> {
self.verification_uri_complete.as_ref()
}

/// The lifetime in seconds of the "device_code" and "user_code".
pub fn expires_in(&self) -> Duration {
Duration::from_secs(self.expires_in)
}

/// The minimum amount of time in seconds that the client SHOULD wait
/// between polling requests to the token endpoint. If no value is
/// provided, clients MUST use 5 as the default.
pub fn interval(&self) -> Duration {
Duration::from_secs(self.interval)
}

/// Any extra fields returned on the response.
pub fn extra_fields(&self) -> &EF {
&self.extra_fields
}
}

///
/// Standard implementation of DeviceAuthorizationResponse which throws away
/// extra received response fields.
///
pub type StandardDeviceAuthorizationResponse =
DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields>;

///
/// Basic access token error types.
///
/// These error types are defined in
/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2) and
/// [Section 3.5 of RFC 6749](https://tools.ietf.org/html/rfc8628#section-3.5)
///
#[derive(Clone, PartialEq)]
pub enum DeviceCodeErrorResponseType {
///
/// The authorization request is still pending as the end user hasn't
/// yet completed the user-interaction steps. The client SHOULD repeat the
/// access token request to the token endpoint. Before each new request,
/// the client MUST wait at least the number of seconds specified by the
/// "interval" parameter of the device authorization response, or 5 seconds
/// if none was provided, and respect any increase in the polling interval
/// required by the "slow_down" error.
///
AuthorizationPending,
///
/// A variant of "authorization_pending", the authorization request is
/// still pending and polling should continue, but the interval MUST be
/// increased by 5 seconds for this and all subsequent requests.
SlowDown,
///
/// The authorization request was denied.
///
AccessDenied,
///
/// The "device_code" has expired, and the device authorization session has
/// concluded. The client MAY commence a new device authorization request
/// but SHOULD wait for user interaction before restarting to avoid
/// unnecessary polling.
ExpiredToken,
///
/// A Basic response type
///
Basic(BasicErrorResponseType),
}
impl DeviceCodeErrorResponseType {
fn from_str(s: &str) -> Self {
match BasicErrorResponseType::from_str(s) {
BasicErrorResponseType::Extension(ext) => match ext.as_str() {
"authorization_pending" => DeviceCodeErrorResponseType::AuthorizationPending,
"slow_down" => DeviceCodeErrorResponseType::SlowDown,
"access_denied" => DeviceCodeErrorResponseType::AccessDenied,
"expired_token" => DeviceCodeErrorResponseType::ExpiredToken,
_ => DeviceCodeErrorResponseType::Basic(BasicErrorResponseType::Extension(ext)),
},
basic => DeviceCodeErrorResponseType::Basic(basic),
}
}
}
impl AsRef<str> for DeviceCodeErrorResponseType {
fn as_ref(&self) -> &str {
match self {
DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending",
DeviceCodeErrorResponseType::SlowDown => "slow_down",
DeviceCodeErrorResponseType::AccessDenied => "access_denied",
DeviceCodeErrorResponseType::ExpiredToken => "expired_token",
DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(),
}
}
}
impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let variant_str = String::deserialize(deserializer)?;
Ok(Self::from_str(&variant_str))
}
}
impl serde::ser::Serialize for DeviceCodeErrorResponseType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
serializer.serialize_str(self.as_ref())
}
}
impl ErrorResponseType for DeviceCodeErrorResponseType {}
impl Debug for DeviceCodeErrorResponseType {
fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
Display::fmt(self, f)
}
}

impl Display for DeviceCodeErrorResponseType {
fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
write!(f, "{}", self.as_ref())
}
}

///
/// Error response specialization for device code OAuth2 implementation.
///
pub type DeviceCodeErrorResponse = StandardErrorResponse<DeviceCodeErrorResponseType>;

pub(crate) enum DeviceAccessTokenPollResult<TR, RE, TE, TT>
where
TE: ErrorResponse + 'static,
TR: TokenResponse<TT>,
TT: TokenType,
RE: Error + 'static,
{
ContinueWithNewPollInterval(Duration),
Done(Result<TR, RequestTokenError<RE, TE>>, PhantomData<TT>),
}
Loading

0 comments on commit 297fe19

Please sign in to comment.