diff --git a/CHANGELOG.md b/CHANGELOG.md index 776225a..42d0fbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.1] - 2026-03-?? + +- Modernize type annotations: replace `Union[X, Y]` with `X | Y`, `Optional[X]` with `X | None`, and built-in generics (`dict`, `list`, `set`) instead of their `typing` counterparts. Requires Python 3.10+. + ## [1.1.0] - 2026-03-10 - Add support for ES* algorithms (`ES256`, `ES384`, `ES512`) for EC keys in diff --git a/guardpost/abc.py b/guardpost/abc.py index 2303ab4..b2beb30 100644 --- a/guardpost/abc.py +++ b/guardpost/abc.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Iterable, List, Optional, Type, TypeVar, Union +from typing import Any, Iterable, Type, TypeVar from rodi import ContainerProtocol @@ -19,7 +19,7 @@ def __init__(self) -> None: class BaseStrategy(ABC): - def __init__(self, container: Optional[ContainerProtocol] = None) -> None: + def __init__(self, container: ContainerProtocol | None = None) -> None: super().__init__() self._container = container @@ -39,7 +39,7 @@ def _get_di_scope(self, scope: Any): except AttributeError: return None - def _get_instances(self, items: List[Union[T, Type[T]]], scope: Any) -> Iterable[T]: + def _get_instances(self, items: list[T | Type[T]], scope: Any) -> Iterable[T]: """ Yields instances of types, optionally activated through dependency injection. diff --git a/guardpost/authentication.py b/guardpost/authentication.py index 21e8688..24fc346 100644 --- a/guardpost/authentication.py +++ b/guardpost/authentication.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from functools import lru_cache from logging import Logger -from typing import Any, List, Optional, Sequence, Type, Union +from typing import Any, Sequence, Type from rodi import ContainerProtocol @@ -19,20 +19,20 @@ class Identity: def __init__( self, - claims: Optional[dict] = None, - authentication_mode: Optional[str] = None, + claims: dict | None = None, + authentication_mode: str | None = None, ): self.claims = claims or {} self.authentication_mode = authentication_mode - self.access_token: Optional[str] = None - self.refresh_token: Optional[str] = None + self.access_token: str | None = None + self.refresh_token: str | None = None @property - def sub(self) -> Optional[str]: + def sub(self) -> str | None: return self.get("sub") @property - def roles(self) -> Optional[str]: + def roles(self) -> str | None: return self.get("roles") def is_authenticated(self) -> bool: @@ -58,15 +58,15 @@ def has_role(self, name: str) -> bool: class User(Identity): @property - def id(self) -> Optional[str]: + def id(self) -> str | None: return self.get("id") or self.sub @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self.get("name") @property - def email(self) -> Optional[str]: + def email(self) -> str | None: return self.get("email") @@ -79,7 +79,7 @@ def scheme(self) -> str: return self.__class__.__name__ @abstractmethod - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: """Obtains an identity from a context.""" @@ -90,9 +90,7 @@ def _is_async_handler(handler_type: Type[AuthenticationHandler]) -> bool: return inspect.iscoroutinefunction(handler_type.authenticate) -AuthenticationHandlerConfType = Union[ - AuthenticationHandler, Type[AuthenticationHandler] -] +AuthenticationHandlerConfType = AuthenticationHandler | Type[AuthenticationHandler] class AuthenticationSchemesNotFound(ValueError): @@ -110,9 +108,9 @@ class AuthenticationStrategy(BaseStrategy): def __init__( self, *handlers: AuthenticationHandlerConfType, - container: Optional[ContainerProtocol] = None, - rate_limiter: Optional[RateLimiter] = None, - logger: Optional[Logger] = None, + container: ContainerProtocol | None = None, + rate_limiter: RateLimiter | None = None, + logger: Logger | None = None, ): """ Initializes an AuthenticationStrategy instance. @@ -144,9 +142,9 @@ def __iadd__( def _get_handlers_by_schemes( self, - authentication_schemes: Optional[Sequence[str]] = None, + authentication_schemes: Sequence[str] | None = None, context: Any = None, - ) -> List[AuthenticationHandler]: + ) -> list[AuthenticationHandler]: if not authentication_schemes: return list(self._get_instances(self.handlers, context)) @@ -168,8 +166,8 @@ def _get_handlers_by_schemes( return handlers async def authenticate( - self, context: Any, authentication_schemes: Optional[Sequence[str]] = None - ) -> Optional[Identity]: + self, context: Any, authentication_schemes: Sequence[str] | None = None + ) -> Identity | None: """ Tries to obtain the user for a context, applying authentication rules and optional rate limiting. diff --git a/guardpost/authorization.py b/guardpost/authorization.py index e5538a4..5a0f6fc 100644 --- a/guardpost/authorization.py +++ b/guardpost/authorization.py @@ -1,7 +1,7 @@ import inspect from abc import ABC, abstractmethod from functools import lru_cache, wraps -from typing import Any, Callable, Iterable, List, Optional, Sequence, Set, Type, Union +from typing import Any, Callable, Iterable, Sequence, Type from rodi import ContainerProtocol @@ -46,7 +46,7 @@ class RolesRequirement(Requirement): __slots__ = ("_roles",) - def __init__(self, roles: Optional[Sequence[str]] = None): + def __init__(self, roles: Sequence[str] | None = None): self._roles = list(roles) if roles else None def handle(self, context: "AuthorizationContext"): @@ -61,7 +61,7 @@ def handle(self, context: "AuthorizationContext"): context.succeed(self) -RequirementConfType = Union[Requirement, Type[Requirement]] +RequirementConfType = Requirement | Type[Requirement] @lru_cache(maxsize=None) @@ -79,11 +79,11 @@ class UnauthorizedError(AuthorizationError): def __init__( self, - forced_failure: Optional[str], + forced_failure: str | None, failed_requirements: Sequence[Requirement], - scheme: Optional[str] = None, - error: Optional[str] = None, - error_description: Optional[str] = None, + scheme: str | None = None, + error: str | None = None, + error_description: str | None = None, ): """ Creates a new instance of UnauthorizedError, with details. @@ -132,11 +132,11 @@ class AuthorizationContext: def __init__(self, identity: Identity, requirements: Sequence[Requirement]): self.identity = identity self.requirements = requirements - self._succeeded: Set[Requirement] = set() - self._failed_forced: Optional[str] = None + self._succeeded: set[Requirement] = set() + self._failed_forced: str | None = None @property - def pending_requirements(self) -> List[Requirement]: + def pending_requirements(self) -> list[Requirement]: return [item for item in self.requirements if item not in self._succeeded] @property @@ -146,7 +146,7 @@ def has_succeeded(self) -> bool: return all(requirement in self._succeeded for requirement in self.requirements) @property - def forced_failure(self) -> Optional[str]: + def forced_failure(self) -> str | None: return None if self._failed_forced is None else str(self._failed_forced) def fail(self, reason: str): @@ -208,16 +208,16 @@ class AuthorizationStrategy(BaseStrategy): def __init__( self, *policies: Policy, - container: Optional[ContainerProtocol] = None, - default_policy: Optional[Policy] = None, - identity_getter: Optional[Callable[..., Identity]] = None, + container: ContainerProtocol | None = None, + default_policy: Policy | None = None, + identity_getter: Callable[..., Identity] | None = None, ): super().__init__(container) self.policies = list(policies) self.default_policy = default_policy self.identity_getter = identity_getter - def get_policy(self, name: str) -> Optional[Policy]: + def get_policy(self, name: str) -> Policy | None: for policy in self.policies: if policy.name == name: return policy @@ -237,10 +237,10 @@ def with_default_policy(self, policy: Policy) -> "AuthorizationStrategy": async def authorize( self, - policy_name: Optional[str], + policy_name: str | None, identity: Identity, scope: Any = None, - roles: Optional[Sequence[str]] = None, + roles: Sequence[str] | None = None, ): if policy_name: policy = self.get_policy(policy_name) @@ -268,7 +268,7 @@ async def authorize( raise UnauthorizedError("The resource requires authentication", []) def _get_requirements( - self, policy: Policy, scope: Any, roles: Optional[Sequence[str]] = None + self, policy: Policy, scope: Any, roles: Sequence[str] | None = None ) -> Iterable[Requirement]: if roles: yield RolesRequirement(roles=roles) @@ -279,7 +279,7 @@ async def _handle_with_policy( policy: Policy, identity: Identity, scope: Any, - roles: Optional[Sequence[str]] = None, + roles: Sequence[str] | None = None, ): with AuthorizationContext( identity, list(self._get_requirements(policy, scope, roles)) @@ -287,7 +287,7 @@ async def _handle_with_policy( await self._handle_context(identity, context) async def _handle_with_roles( - self, identity: Identity, roles: Optional[Sequence[str]] = None + self, identity: Identity, roles: Sequence[str] | None = None ): # This method is to be used only when the user specified roles without a policy with AuthorizationContext(identity, [RolesRequirement(roles=roles)]) as context: @@ -310,13 +310,13 @@ async def _handle_context(self, identity: Identity, context: AuthorizationContex ) async def _handle_with_identity_getter( - self, policy_name: Optional[str], *args, **kwargs + self, policy_name: str | None, *args, **kwargs ): if self.identity_getter is None: raise TypeError("Missing identity getter function.") await self.authorize(policy_name, self.identity_getter(*args, **kwargs)) - def __call__(self, policy: Optional[str] = None): + def __call__(self, policy: str | None = None): """ Decorates a function to apply authorization logic on each call. """ diff --git a/guardpost/common.py b/guardpost/common.py index e60abf0..bc2b316 100644 --- a/guardpost/common.py +++ b/guardpost/common.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from typing import Mapping as MappingType -from typing import Sequence, Union +from typing import Sequence from .authorization import AuthorizationContext, Policy, Requirement @@ -35,7 +35,7 @@ def handle(self, context: AuthorizationContext): context.succeed(self) -RequiredClaimsType = Union[MappingType[str, str], Sequence[str], str] +RequiredClaimsType = MappingType[str, str] | Sequence[str] | str class ClaimsRequirement(Requirement): diff --git a/guardpost/jwks/__init__.py b/guardpost/jwks/__init__.py index ee861f4..ba588a9 100644 --- a/guardpost/jwks/__init__.py +++ b/guardpost/jwks/__init__.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Type +from typing import Type from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -24,7 +24,7 @@ def _raise_if_missing(value: dict, *keys: str) -> None: raise ValueError(f"Missing {key}") -_EC_CURVES: Dict[str, Type[EllipticCurve]] = { +_EC_CURVES: dict[str, Type[EllipticCurve]] = { "P-256": SECP256R1, "P-384": SECP384R1, "P-521": SECP521R1, @@ -62,14 +62,14 @@ class JWK: kty: KeyType pem: bytes - kid: Optional[str] = None + kid: str | None = None # RSA parameters - n: Optional[str] = None - e: Optional[str] = None + n: str | None = None + e: str | None = None # EC parameters - crv: Optional[str] = None - x: Optional[str] = None - y: Optional[str] = None + crv: str | None = None + x: str | None = None + y: str | None = None @classmethod def from_dict(cls, value) -> "JWK": @@ -104,7 +104,7 @@ def from_dict(cls, value) -> "JWK": @dataclass class JWKS: - keys: List[JWK] + keys: list[JWK] def update(self, new_set: "JWKS"): self.keys = list({key.kid: key for key in self.keys + new_set.keys}.values()) diff --git a/guardpost/jwks/caching.py b/guardpost/jwks/caching.py index f0e769f..79ccf85 100644 --- a/guardpost/jwks/caching.py +++ b/guardpost/jwks/caching.py @@ -1,5 +1,4 @@ import time -from typing import Optional from . import JWK, JWKS, KeysProvider @@ -24,7 +23,7 @@ def __init__( if not keys_provider: raise TypeError("Missing KeysProvider") - self._keys: Optional[JWKS] = None + self._keys: JWKS | None = None self._cache_time = cache_time self._refresh_time = refresh_time self._last_fetch_time: float = 0 @@ -57,7 +56,7 @@ async def get_keys(self) -> JWKS: return self._keys return await self._fetch_keys() - async def get_key(self, kid: str) -> Optional[JWK]: + async def get_key(self, kid: str) -> JWK | None: """ Tries to get a JWK by kid. If the JWK is not found and the last time the keys were fetched is older than `refresh_time` (default 120 seconds), it fetches diff --git a/guardpost/jwts/__init__.py b/guardpost/jwts/__init__.py index 5499760..cb5123d 100644 --- a/guardpost/jwts/__init__.py +++ b/guardpost/jwts/__init__.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Protocol, Sequence, Union +from typing import Any, Protocol, Sequence import jwt from essentials.secrets import Secret @@ -36,7 +36,7 @@ def __init__(self): super().__init__("Token expired.") -def get_kid(token: str) -> Optional[str]: +def get_kid(token: str) -> str | None: """ Extracts a kid (key id) from a JWT. """ @@ -49,7 +49,7 @@ def get_kid(token: str) -> Optional[str]: class JWTValidatorProtocol(Protocol): """Protocol defining the interface for JWT validators""" - async def validate_jwt(self, access_token: str) -> Dict[str, Any]: + async def validate_jwt(self, access_token: str) -> dict[str, Any]: """ Validates an access token and returns its claims. """ @@ -72,7 +72,7 @@ def __init__( self.logger = get_logger() @abstractmethod - async def validate_jwt(self, access_token: str) -> Dict[str, Any]: + async def validate_jwt(self, access_token: str) -> dict[str, Any]: """ Validates an access token and returns its claims. """ @@ -90,11 +90,11 @@ def __init__( *, valid_issuers: Sequence[str], valid_audiences: Sequence[str], - authority: Optional[str] = None, + authority: str | None = None, algorithms: Sequence[str] = ["RS256"], require_kid: bool = True, - keys_provider: Optional[KeysProvider] = None, - keys_url: Optional[str] = None, + keys_provider: KeysProvider | None = None, + keys_url: str | None = None, cache_time: float = 10800, refresh_time: float = 120, ) -> None: @@ -109,7 +109,7 @@ def __init__( Sequence of acceptable issuers (iss). valid_audiences : Sequence[str] Sequence of acceptable audiences (aud). - authority : Optional[str], optional + authority : str | None, optional If provided, keys are obtained from a standard well-known endpoint. This parameter is ignored if `keys_provider` is given. algorithms : Sequence[str], optional @@ -119,10 +119,10 @@ def __init__( this parameter lets control whether access tokens missing `kid` in their headers should be handled or rejected. By default True, thus only JWTs having `kid` header are accepted. - keys_provider : Optional[KeysProvider], optional + keys_provider : KeysProvider | None, optional If provided, the exact `KeysProvider` to be used when fetching keys. By default None - keys_url : Optional[str], optional + keys_url : str | None, optional If provided, keys are obtained from the given URL through HTTP GET. This parameter is ignored if `keys_provider` is given. cache_time : float @@ -169,7 +169,7 @@ async def get_jwk(self, kid: str) -> JWK: raise InvalidAccessToken("kid not recognized") return key - def _validate_jwt_by_key(self, access_token: str, jwk: JWK) -> Dict[str, Any]: + def _validate_jwt_by_key(self, access_token: str, jwk: JWK) -> dict[str, Any]: try: return jwt.decode( access_token, @@ -185,7 +185,7 @@ def _validate_jwt_by_key(self, access_token: str, jwk: JWK) -> Dict[str, Any]: self.logger.debug("Invalid access token: ", exc_info=exc) raise InvalidAccessToken() from exc - async def validate_jwt(self, access_token: str) -> Dict[str, Any]: + async def validate_jwt(self, access_token: str) -> dict[str, Any]: """ Validates the given JWT and returns its payload. This method throws exception if the JWT is not valid (i.e. its signature cannot be verified, for example @@ -229,7 +229,7 @@ def __init__( *, valid_issuers: Sequence[str], valid_audiences: Sequence[str], - secret_key: Union[str, bytes, Secret], + secret_key: str | bytes | Secret, algorithms: Sequence[str] = ["HS256"], ) -> None: """ @@ -242,7 +242,7 @@ def __init__( Sequence of acceptable issuers (iss). valid_audiences : Sequence[str] Sequence of acceptable audiences (aud). - secret_key : Union[str, bytes, Secret] + secret_key : str | bytes | Secret The secret key used for symmetric validation. algorithms : Sequence[str], optional Sequence of acceptable algorithms, by default ["HS256"]. @@ -270,7 +270,7 @@ def __init__( raise TypeError("secret_key must be a str, bytes, or Secret instance.") self._secret_key = secret_key - async def validate_jwt(self, access_token: str) -> Dict[str, Any]: + async def validate_jwt(self, access_token: str) -> dict[str, Any]: """ Validates the given JWT using symmetric key and returns its payload. This method throws exception if the JWT is not valid. @@ -292,20 +292,20 @@ async def validate_jwt(self, access_token: str) -> Dict[str, Any]: class CompositeJWTValidator(BaseJWTValidator): - def __init__(self, validators: List[JWTValidatorProtocol]) -> None: + def __init__(self, validators: list[JWTValidatorProtocol]) -> None: """ Creates a composite validator that tries multiple validation strategies. Useful when you need to support both symmetric and asymmetric validation. Parameters ---------- - validators : List[JWTValidatorProtocol] + validators : list[JWTValidatorProtocol] List of validators to try in sequence """ self._validators = validators self.logger = get_logger() - async def validate_jwt(self, access_token: str) -> Dict[str, Any]: + async def validate_jwt(self, access_token: str) -> dict[str, Any]: """ Attempts to validate the JWT using each validator in sequence. Returns the first successful validation result or raises InvalidAccessToken diff --git a/guardpost/protection.py b/guardpost/protection.py index aba2f5f..5a63cbd 100644 --- a/guardpost/protection.py +++ b/guardpost/protection.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from logging import Logger -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Sequence from guardpost.errors import InvalidCredentialsError, RateLimitExceededError @@ -75,7 +75,7 @@ class AuthenticationAttemptsStore(ABC): @abstractmethod async def get_failed_attempts( self, key: str - ) -> Optional[FailedAuthenticationAttempts]: + ) -> FailedAuthenticationAttempts | None: """ Returns the record tracking the number of failed authentication attempts for a given context key (e.g. client IP), or none if no failed attempt exists for the @@ -115,7 +115,7 @@ def __init__(self) -> None: async def get_failed_attempts( self, key: str - ) -> Optional[FailedAuthenticationAttempts]: + ) -> FailedAuthenticationAttempts | None: try: return self._attempts[key] except KeyError: @@ -140,9 +140,9 @@ def __init__( key_getter: Callable[[Any], str], threshold: int = 20, block_time: int = 60, - store: Optional[AuthenticationAttemptsStore] = None, - trusted_keys: Optional[Sequence[str]] = None, - logger: Optional[Logger] = None, + store: AuthenticationAttemptsStore | None = None, + trusted_keys: Sequence[str] | None = None, + logger: Logger | None = None, ) -> None: """ Initialize a RateLimiter instance for brute-force protection. @@ -283,7 +283,7 @@ def __init__(self, cleanup_interval: int = 300, max_entry_age: int = 3600) -> No async def get_failed_attempts( self, key: str - ) -> Optional[FailedAuthenticationAttempts]: + ) -> FailedAuthenticationAttempts | None: await self._cleanup_if_needed() return await super().get_failed_attempts(key) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3ba08d1..d307f61 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from uuid import uuid4 import pytest @@ -122,7 +122,7 @@ class MockHandler(AuthenticationHandler): def __init__(self, identity): self.identity = identity - async def authenticate(self, context: Any) -> Optional[Identity]: + async def authenticate(self, context: Any) -> Identity | None: context.user = self.identity return context.user @@ -175,11 +175,11 @@ async def test_authentication_strategy_by_scheme_throws_for_missing_scheme(): def test_default_authentication_scheme_name_matches_class_name(): class Basic(AuthenticationHandler): - async def authenticate(self, context: Any) -> Optional[Identity]: + async def authenticate(self, context: Any) -> Identity | None: pass class Foo(AuthenticationHandler): - async def authenticate(self, context: Any) -> Optional[Identity]: + async def authenticate(self, context: Any) -> Identity | None: pass assert Basic().scheme == "Basic" @@ -193,7 +193,7 @@ class Foo: class InjectedAuthenticationHandler(AuthenticationHandler): service: Foo - def authenticate(self, context) -> Optional[Identity]: + def authenticate(self, context) -> Identity | None: return None @@ -220,7 +220,7 @@ async def test_authenticate_set_identity_context_attribute_error_handling(): container = Container() class TestHandler(AuthenticationHandler): - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: return Identity({"sub": test_id}) container.register(TestHandler) diff --git a/tests/test_common.py b/tests/test_common.py index 6ad42fb..575bb62 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pytest from pytest import raises @@ -69,11 +69,11 @@ def test_authentication_strategy_iadd_method(): strategy = AuthenticationStrategy() class ExampleOne(AuthenticationHandler): - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: pass class ExampleTwo(AuthenticationHandler): - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: pass one = ExampleOne() @@ -93,11 +93,11 @@ def test_authentication_strategy_add_method(): strategy = AuthenticationStrategy() class ExampleOne(AuthenticationHandler): - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: pass class ExampleTwo(AuthenticationHandler): - def authenticate(self, context: Any) -> Optional[Identity]: + def authenticate(self, context: Any) -> Identity | None: pass one = ExampleOne()