Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various simplification of the code #24

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 66 additions & 82 deletions sanction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,50 @@
# vim: set ts=4 sw=)

from functools import wraps
from json import loads
import json
from datetime import datetime, timedelta
from time import mktime
try:
from urllib import urlencode
from urllib2 import Request, urlopen
from urlparse import urlsplit, urlunsplit, parse_qsl

# monkeypatch httpmessage
from httplib import HTTPMessage
def get_charset(self):
def get_content_charset(self, failobj=None):
try:
data = filter(lambda s: 'Content-Type' in s, self.headers)[0]
# Example: Content-Type: text/html; charset=ISO-8859-1
# https://tools.ietf.org/html/rfc7231#section-3.1.1.1
data = self.headers.getheader('Content-Type')
if 'charset' in data:
cs = data[data.index(';') + 1:-2].split('=')[1].lower()
return cs
return data.split(';')[1].split('=')[1].lower()
except IndexError:
pass

return 'utf-8'
HTTPMessage.get_content_charset = get_charset
return failobj
# monkeypatch HTTPMmessage
HTTPMessage.get_content_charset = get_content_charset
except ImportError: # pragma: no cover
from urllib.parse import urlencode, urlsplit, urlunsplit, parse_qsl
from urllib.request import Request, urlopen


def _request(url, data, method):
# This is not even really needed. It is in case somebody wants to use
# method other than GET / POST, which is a bit of a stretch.
try:
# Python 3.3+ (2012).
return Request(url, data=data, method=method)
except TypeError:
req = Request(url, data=data)
req.get_method = lambda: method
return req


class Client(object):
""" OAuth 2.0 client object
"""
"""OAuth 2.0 client object"""

def __init__(self, auth_endpoint=None, token_endpoint=None,
resource_endpoint=None, client_id=None, client_secret=None,
resource_endpoint="", client_id=None, client_secret=None,
token_transport=None):
""" Instantiates a `Client` to authorize and authenticate a user
"""Instantiates a `Client` to authorize and authenticate a user

:param auth_endpoint: The authorization endpoint as issued by the
provider. This is where the user should be
Expand All @@ -47,6 +57,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None,
:param client_id: The client ID as issued by the provider.
:param client_secret: The client secret as issued by the provider. This
must not be shared.
:param token_transport: (optional) Callable to construct requests.
"""
assert token_transport is None or hasattr(token_transport, '__call__')

Expand All @@ -60,10 +71,8 @@ def __init__(self, auth_endpoint=None, token_endpoint=None,
self.token_expires = -1
self.refresh_token = None

def auth_uri(self, redirect_uri=None, scope=None, scope_delim=None,
state=None, **kwargs):

""" Builds the auth URI for the authorization endpoint
def auth_uri(self, redirect_uri=None, scope=None, state=None, **kwargs):
"""Builds the auth URI for the authorization endpoint

:param scope: (optional) The `scope` parameter to pass for
authorization. The format should match that expected by
Expand All @@ -76,24 +85,23 @@ def auth_uri(self, redirect_uri=None, scope=None, scope_delim=None,
:param **kwargs: Any other querystring parameters to be passed to the
provider.
"""
kwargs.update({
'client_id': self.client_id,
'response_type': 'code',
})
kwargs['client_id'] = self.client_id,
kwargs['response_type'] = 'code'

if scope is not None:
if scope:
kwargs['scope'] = scope

if state is not None:
if state:
kwargs['state'] = state

if redirect_uri is not None:
if redirect_uri:
kwargs['redirect_uri'] = redirect_uri

return '%s?%s' % (self.auth_endpoint, urlencode(kwargs))

def request_token(self, parser=None, redirect_uri=None, **kwargs):
""" Request an access token from the token endpoint.
""" Requests an access token from the token endpoint.

This is largely a helper method and expects the client code to
understand what the server expects. Anything that's passed into
``**kwargs`` will be sent (``urlencode``d) to the endpoint. Client
Expand All @@ -111,31 +119,28 @@ def request_token(self, parser=None, redirect_uri=None, **kwargs):
'grant_type': 'refresh_token',
}

:param parser: Callback to deal with returned data. Not all providers
use JSON.
:param parser: Callback to deal with returned data. By default JSON.
"""
kwargs = kwargs and kwargs or {}

parser = parser or _default_parser
kwargs.update({
'client_id': self.client_id,
'client_secret': self.client_secret,
'grant_type': 'grant_type' in kwargs and kwargs['grant_type'] or \
'authorization_code'
})
if redirect_uri is not None:
kwargs.update({'redirect_uri': redirect_uri})

kwargs['client_id'] = self.client_id,
kwargs['client_secret'] = self.client_secret,

if 'grant_type' not in kwargs:
kwargs['grant_type'] = 'authorization_code'

if redirect_uri:
kwargs['redirect_uri'] = redirect_uri

# TODO: maybe raise an exception here if status code isn't 200?
msg = urlopen(self.token_endpoint, urlencode(kwargs).encode(
'utf-8'))
data = parser(msg.read().decode(msg.info().get_content_charset() or
'utf-8'))
data = parser(msg.read().decode(msg.info().get_content_charset()))

for key in data:
setattr(self, key, data[key])

# expires_in is RFC-compliant. if anything else is used by the
# expires_in is RFC-compliant. If anything else is used by the
# provider, token_expires must be set manually
if hasattr(self, 'expires_in'):
try:
Expand All @@ -157,70 +162,49 @@ def request(self, url, method=None, data=None, headers=None, parser=None):
in which case it defaults to ``POST``
:param data: Data to be POSTed to the resource endpoint
:param parser: Parser callback to deal with the returned data. Defaults
to ``json.loads`.`
to ``json.loads`, and dict(parse_qsl()) as fallback.`
"""
assert self.access_token is not None
parser = parser or loads
assert self.access_token
parser = parser or _default_parser

if not method:
method = 'GET' if not data else 'POST'

req = self.token_transport('{0}{1}'.format(self.resource_endpoint,
url), self.access_token, data=data, method=method, headers=headers)
full_url = '{0}{1}'.format(self.resource_endpoint, url)
req = self.token_transport(full_url, self.access_token,
data=data, method=method, headers=headers)

resp = urlopen(req)
data = resp.read()
try:
return parser(data.decode(resp.info().get_content_charset() or
'utf-8'))
# try to decode it first using either the content charset, falling
# back to utf-8

# Try to decode it first using either the content charset, falling
# back to UTF-8
return parser(data.decode(resp.info().get_content_charset(failobj="utf-8"))))
except UnicodeDecodeError:
# if we've gotten a decoder error, the calling code better know how
# to deal with it. some providers (i.e. stackexchange) like to gzip
# If we've gotten a decoder error, the calling code better know how
# to deal with it. Some providers (i.e. stackexchange) like to gzip
# their responses, so this allows the client code to handle it
# directly.
return parser(data)


def transport_headers(url, access_token, data=None, method=None, headers=None):
try:
req = Request(url, data=data, method=method)
except TypeError:
req = Request(url, data=data)
req.get_method = lambda: method

add_headers = {'Authorization': 'Bearer {0}'.format(access_token)}
if headers is not None:
add_headers.update(headers)

req.headers.update(add_headers)
def transport_headers(url, access_token, data=None, method=None, headers={}):
req = _request(url, data=data, method=method)
req.headers['Authorization'] = 'Bearer {0}'.format(access_token)
req.headers.update(headers)
return req


def transport_query(url, access_token, data=None, method=None, headers=None):
def transport_query(url, access_token, data=None, method=None, headers={}):
parts = urlsplit(url)
query = dict(parse_qsl(parts.query))
query.update({
'access_token': access_token
})
query['access_token'] = access_token
url = urlunsplit((parts.scheme, parts.netloc, parts.path,
urlencode(query), parts.fragment))
try:
req = Request(url, data=data, method=method)
except TypeError:
req = Request(url, data=data)
req.get_method = lambda: method

if headers is not None:
req.headers.update(headers)

req = _request(url, data=data, method=method)
req.headers.update(headers)
return req


def _default_parser(data):
try:
return loads(data)
return json.loads(data)
except ValueError:
return dict(parse_qsl(data))