from authlib.common.errors import ContinueIteration
from .authenticate_client import ClientAuthentication
from .requests import OAuth2Request, JsonRequest
from .errors import (
OAuth2Error,
InvalidScopeError,
UnsupportedResponseTypeError,
UnsupportedGrantTypeError,
)
from .util import scope_to_list
class AuthorizationServer:
"""Authorization server that handles Authorization Endpoint and Token
Endpoint.
:param scopes_supported: A list of supported scopes by this authorization server.
"""
def __init__(self, scopes_supported=None):
self.scopes_supported = scopes_supported
self._token_generators = {}
self._client_auth = None
self._authorization_grants = []
self._token_grants = []
self._endpoints = {}
def query_client(self, client_id):
"""Query OAuth client by client_id. The client model class MUST
implement the methods described by
:class:`~authlib.oauth2.rfc6749.ClientMixin`.
"""
raise NotImplementedError()
def save_token(self, token, request):
"""Define function to save the generated token into database."""
raise NotImplementedError()
def generate_token(self, grant_type, client, user=None, scope=None,
expires_in=None, include_refresh_token=True):
"""Generate the token dict.
:param grant_type: current requested grant_type.
:param client: the client that making the request.
:param user: current authorized user.
:param expires_in: if provided, use this value as expires_in.
:param scope: current requested scope.
:param include_refresh_token: should refresh_token be included.
:return: Token dict
"""
# generator for a specified grant type
func = self._token_generators.get(grant_type)
if not func:
# default generator for all grant types
func = self._token_generators.get('default')
if not func:
raise RuntimeError('No configured token generator')
return func(
grant_type=grant_type, client=client, user=user, scope=scope,
expires_in=expires_in, include_refresh_token=include_refresh_token)
def register_token_generator(self, grant_type, func):
"""Register a function as token generator for the given ``grant_type``.
Developers MUST register a default token generator with a special
``grant_type=default``::
def generate_bearer_token(grant_type, client, user=None, scope=None,
expires_in=None, include_refresh_token=True):
token = {'token_type': 'Bearer', 'access_token': ...}
if include_refresh_token:
token['refresh_token'] = ...
...
return token
authorization_server.register_token_generator('default', generate_bearer_token)
If you register a generator for a certain grant type, that generator will only works
for the given grant type::
authorization_server.register_token_generator('client_credentials', generate_bearer_token)
:param grant_type: string name of the grant type
:param func: a function to generate token
"""
self._token_generators[grant_type] = func
def authenticate_client(self, request, methods, endpoint='token'):
"""Authenticate client via HTTP request information with the given
methods, such as ``client_secret_basic``, ``client_secret_post``.
"""
if self._client_auth is None and self.query_client:
self._client_auth = ClientAuthentication(self.query_client)
return self._client_auth(request, methods, endpoint)
def register_client_auth_method(self, method, func):
"""Add more client auth method. The default methods are:
* none: The client is a public client and does not have a client secret
* client_secret_post: The client uses the HTTP POST parameters
* client_secret_basic: The client uses HTTP Basic
:param method: Name of the Auth method
:param func: Function to authenticate the client
The auth method accept two parameters: ``query_client`` and ``request``,
an example for this method::
def authenticate_client_via_custom(query_client, request):
client_id = request.headers['X-Client-Id']
client = query_client(client_id)
do_some_validation(client)
return client
authorization_server.register_client_auth_method(
'custom', authenticate_client_via_custom)
"""
if self._client_auth is None and self.query_client:
self._client_auth = ClientAuthentication(self.query_client)
self._client_auth.register(method, func)
def get_error_uri(self, request, error):
"""Return a URI for the given error, framework may implement this method."""
return None
def send_signal(self, name, *args, **kwargs):
"""Framework integration can re-implement this method to support
signal system.
"""
raise NotImplementedError()
def create_oauth2_request(self, request) -> OAuth2Request:
"""This method MUST be implemented in framework integrations. It is
used to create an OAuth2Request instance.
:param request: the "request" instance in framework
:return: OAuth2Request instance
"""
raise NotImplementedError()
def create_json_request(self, request) -> JsonRequest:
"""This method MUST be implemented in framework integrations. It is
used to create an HttpRequest instance.
:param request: the "request" instance in framework
:return: HttpRequest instance
"""
raise NotImplementedError()
def handle_response(self, status, body, headers):
"""Return HTTP response. Framework MUST implement this function."""
raise NotImplementedError()
def validate_requested_scope(self, scope, state=None):
"""Validate if requested scope is supported by Authorization Server.
Developers CAN re-write this method to meet your needs.
"""
if scope and self.scopes_supported:
scopes = set(scope_to_list(scope))
if not set(self.scopes_supported).issuperset(scopes):
raise InvalidScopeError(state=state)
def register_grant(self, grant_cls, extensions=None):
"""Register a grant class into the endpoint registry. Developers
can implement the grants in ``authlib.oauth2.rfc6749.grants`` and
register with this method::
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
def authenticate_user(self, credential):
# ...
authorization_server.register_grant(AuthorizationCodeGrant)
:param grant_cls: a grant class.
:param extensions: extensions for the grant class.
"""
if hasattr(grant_cls, 'check_authorization_endpoint'):
self._authorization_grants.append((grant_cls, extensions))
if hasattr(grant_cls, 'check_token_endpoint'):
self._token_grants.append((grant_cls, extensions))
def register_endpoint(self, endpoint):
"""Add extra endpoint to authorization server. e.g.
RevocationEndpoint::
authorization_server.register_endpoint(RevocationEndpoint)
:param endpoint_cls: A endpoint class or instance.
"""
if isinstance(endpoint, type):
endpoint = endpoint(self)
else:
endpoint.server = self
endpoints = self._endpoints.setdefault(endpoint.ENDPOINT_NAME, [])
endpoints.append(endpoint)
def get_authorization_grant(self, request):
"""Find the authorization grant for current request.
:param request: OAuth2Request instance.
:return: grant instance
"""
for (grant_cls, extensions) in self._authorization_grants:
if grant_cls.check_authorization_endpoint(request):
return _create_grant(grant_cls, extensions, request, self)
raise UnsupportedResponseTypeError(request.response_type)
def get_consent_grant(self, request=None, end_user=None):
"""Validate current HTTP request for authorization page. This page
is designed for resource owner to grant or deny the authorization.
"""
request = self.create_oauth2_request(request)
request.user = end_user
grant = self.get_authorization_grant(request)
grant.validate_no_multiple_request_parameter(request)
grant.validate_consent_request()
return grant
def get_token_grant(self, request):
"""Find the token grant for current request.
:param request: OAuth2Request instance.
:return: grant instance
"""
for (grant_cls, extensions) in self._token_grants:
if grant_cls.check_token_endpoint(request):
return _create_grant(grant_cls, extensions, request, self)
raise UnsupportedGrantTypeError(request.grant_type)
def create_endpoint_response(self, name, request=None):
"""Validate endpoint request and create endpoint response.
:param name: Endpoint name
:param request: HTTP request instance.
:return: Response
"""
if name not in self._endpoints:
raise RuntimeError(f'There is no "{name}" endpoint.')
endpoints = self._endpoints[name]
for endpoint in endpoints:
request = endpoint.create_endpoint_request(request)
try:
return self.handle_response(*endpoint(request))
except ContinueIteration:
continue
except OAuth2Error as error:
return self.handle_error_response(request, error)
def create_authorization_response(self, request=None, grant_user=None):
"""Validate authorization request and create authorization response.
:param request: HTTP request instance.
:param grant_user: if granted, it is resource owner. If denied,
it is None.
:returns: Response
"""
if not isinstance(request, OAuth2Request):
request = self.create_oauth2_request(request)
try:
grant = self.get_authorization_grant(request)
except UnsupportedResponseTypeError as error:
return self.handle_error_response(request, error)
try:
redirect_uri = grant.validate_authorization_request()
args = grant.create_authorization_response(redirect_uri, grant_user)
return self.handle_response(*args)
except OAuth2Error as error:
return self.handle_error_response(request, error)
def create_token_response(self, request=None):
"""Validate token request and create token response.
:param request: HTTP request instance
"""
request = self.create_oauth2_request(request)
try:
grant = self.get_token_grant(request)
except UnsupportedGrantTypeError as error:
return self.handle_error_response(request, error)
try:
grant.validate_token_request()
args = grant.create_token_response()
return self.handle_response(*args)
except OAuth2Error as error:
return self.handle_error_response(request, error)
def handle_error_response(self, request, error):
return self.handle_response(*error(self.get_error_uri(request, error)))
def _create_grant(grant_cls, extensions, request, server):
grant = grant_cls(request, server)
if extensions:
for ext in extensions:
ext(grant)
return grant