Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.1.1] - 2026-03-??

- Modernize type annotations: replace `Union[X, Y]` with `X | Y`, `Optional[X]` with `X | None`, and built-in generics (`dict`, `list`, `set`) instead of their `typing` counterparts. Requires Python 3.10+.

## [1.1.0] - 2026-03-10

- Add support for ES* algorithms (`ES256`, `ES384`, `ES512`) for EC keys in
Expand Down
6 changes: 3 additions & 3 deletions guardpost/abc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Any, Iterable, List, Optional, Type, TypeVar, Union
from typing import Any, Iterable, Type, TypeVar

from rodi import ContainerProtocol

Expand All @@ -19,7 +19,7 @@ def __init__(self) -> None:


class BaseStrategy(ABC):
def __init__(self, container: Optional[ContainerProtocol] = None) -> None:
def __init__(self, container: ContainerProtocol | None = None) -> None:
super().__init__()
self._container = container

Expand All @@ -39,7 +39,7 @@ def _get_di_scope(self, scope: Any):
except AttributeError:
return None

def _get_instances(self, items: List[Union[T, Type[T]]], scope: Any) -> Iterable[T]:
def _get_instances(self, items: list[T | Type[T]], scope: Any) -> Iterable[T]:
"""
Yields instances of types, optionally activated through dependency injection.

Expand Down
40 changes: 19 additions & 21 deletions guardpost/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from logging import Logger
from typing import Any, List, Optional, Sequence, Type, Union
from typing import Any, Sequence, Type

from rodi import ContainerProtocol

Expand All @@ -19,20 +19,20 @@ class Identity:

def __init__(
self,
claims: Optional[dict] = None,
authentication_mode: Optional[str] = None,
claims: dict | None = None,
authentication_mode: str | None = None,
):
self.claims = claims or {}
self.authentication_mode = authentication_mode
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self.access_token: str | None = None
self.refresh_token: str | None = None

@property
def sub(self) -> Optional[str]:
def sub(self) -> str | None:
return self.get("sub")

@property
def roles(self) -> Optional[str]:
def roles(self) -> str | None:
return self.get("roles")

def is_authenticated(self) -> bool:
Expand All @@ -58,15 +58,15 @@ def has_role(self, name: str) -> bool:

class User(Identity):
@property
def id(self) -> Optional[str]:
def id(self) -> str | None:
return self.get("id") or self.sub

@property
def name(self) -> Optional[str]:
def name(self) -> str | None:
return self.get("name")

@property
def email(self) -> Optional[str]:
def email(self) -> str | None:
return self.get("email")


Expand All @@ -79,7 +79,7 @@ def scheme(self) -> str:
return self.__class__.__name__

@abstractmethod
def authenticate(self, context: Any) -> Optional[Identity]:
def authenticate(self, context: Any) -> Identity | None:
"""Obtains an identity from a context."""


Expand All @@ -90,9 +90,7 @@ def _is_async_handler(handler_type: Type[AuthenticationHandler]) -> bool:
return inspect.iscoroutinefunction(handler_type.authenticate)


AuthenticationHandlerConfType = Union[
AuthenticationHandler, Type[AuthenticationHandler]
]
AuthenticationHandlerConfType = AuthenticationHandler | Type[AuthenticationHandler]


class AuthenticationSchemesNotFound(ValueError):
Expand All @@ -110,9 +108,9 @@ class AuthenticationStrategy(BaseStrategy):
def __init__(
self,
*handlers: AuthenticationHandlerConfType,
container: Optional[ContainerProtocol] = None,
rate_limiter: Optional[RateLimiter] = None,
logger: Optional[Logger] = None,
container: ContainerProtocol | None = None,
rate_limiter: RateLimiter | None = None,
logger: Logger | None = None,
):
"""
Initializes an AuthenticationStrategy instance.
Expand Down Expand Up @@ -144,9 +142,9 @@ def __iadd__(

def _get_handlers_by_schemes(
self,
authentication_schemes: Optional[Sequence[str]] = None,
authentication_schemes: Sequence[str] | None = None,
context: Any = None,
) -> List[AuthenticationHandler]:
) -> list[AuthenticationHandler]:
if not authentication_schemes:
return list(self._get_instances(self.handlers, context))

Expand All @@ -168,8 +166,8 @@ def _get_handlers_by_schemes(
return handlers

async def authenticate(
self, context: Any, authentication_schemes: Optional[Sequence[str]] = None
) -> Optional[Identity]:
self, context: Any, authentication_schemes: Sequence[str] | None = None
) -> Identity | None:
"""
Tries to obtain the user for a context, applying authentication rules and
optional rate limiting.
Expand Down
44 changes: 22 additions & 22 deletions guardpost/authorization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from typing import Any, Callable, Iterable, List, Optional, Sequence, Set, Type, Union
from typing import Any, Callable, Iterable, Sequence, Type

from rodi import ContainerProtocol

Expand Down Expand Up @@ -46,7 +46,7 @@ class RolesRequirement(Requirement):

__slots__ = ("_roles",)

def __init__(self, roles: Optional[Sequence[str]] = None):
def __init__(self, roles: Sequence[str] | None = None):
self._roles = list(roles) if roles else None

def handle(self, context: "AuthorizationContext"):
Expand All @@ -61,7 +61,7 @@ def handle(self, context: "AuthorizationContext"):
context.succeed(self)


RequirementConfType = Union[Requirement, Type[Requirement]]
RequirementConfType = Requirement | Type[Requirement]


@lru_cache(maxsize=None)
Expand All @@ -79,11 +79,11 @@ class UnauthorizedError(AuthorizationError):

def __init__(
self,
forced_failure: Optional[str],
forced_failure: str | None,
failed_requirements: Sequence[Requirement],
scheme: Optional[str] = None,
error: Optional[str] = None,
error_description: Optional[str] = None,
scheme: str | None = None,
error: str | None = None,
error_description: str | None = None,
):
"""
Creates a new instance of UnauthorizedError, with details.
Expand Down Expand Up @@ -132,11 +132,11 @@ class AuthorizationContext:
def __init__(self, identity: Identity, requirements: Sequence[Requirement]):
self.identity = identity
self.requirements = requirements
self._succeeded: Set[Requirement] = set()
self._failed_forced: Optional[str] = None
self._succeeded: set[Requirement] = set()
self._failed_forced: str | None = None

@property
def pending_requirements(self) -> List[Requirement]:
def pending_requirements(self) -> list[Requirement]:
return [item for item in self.requirements if item not in self._succeeded]

@property
Expand All @@ -146,7 +146,7 @@ def has_succeeded(self) -> bool:
return all(requirement in self._succeeded for requirement in self.requirements)

@property
def forced_failure(self) -> Optional[str]:
def forced_failure(self) -> str | None:
return None if self._failed_forced is None else str(self._failed_forced)

def fail(self, reason: str):
Expand Down Expand Up @@ -208,16 +208,16 @@ class AuthorizationStrategy(BaseStrategy):
def __init__(
self,
*policies: Policy,
container: Optional[ContainerProtocol] = None,
default_policy: Optional[Policy] = None,
identity_getter: Optional[Callable[..., Identity]] = None,
container: ContainerProtocol | None = None,
default_policy: Policy | None = None,
identity_getter: Callable[..., Identity] | None = None,
):
super().__init__(container)
self.policies = list(policies)
self.default_policy = default_policy
self.identity_getter = identity_getter

def get_policy(self, name: str) -> Optional[Policy]:
def get_policy(self, name: str) -> Policy | None:
for policy in self.policies:
if policy.name == name:
return policy
Expand All @@ -237,10 +237,10 @@ def with_default_policy(self, policy: Policy) -> "AuthorizationStrategy":

async def authorize(
self,
policy_name: Optional[str],
policy_name: str | None,
identity: Identity,
scope: Any = None,
roles: Optional[Sequence[str]] = None,
roles: Sequence[str] | None = None,
):
if policy_name:
policy = self.get_policy(policy_name)
Expand Down Expand Up @@ -268,7 +268,7 @@ async def authorize(
raise UnauthorizedError("The resource requires authentication", [])

def _get_requirements(
self, policy: Policy, scope: Any, roles: Optional[Sequence[str]] = None
self, policy: Policy, scope: Any, roles: Sequence[str] | None = None
) -> Iterable[Requirement]:
if roles:
yield RolesRequirement(roles=roles)
Expand All @@ -279,15 +279,15 @@ async def _handle_with_policy(
policy: Policy,
identity: Identity,
scope: Any,
roles: Optional[Sequence[str]] = None,
roles: Sequence[str] | None = None,
):
with AuthorizationContext(
identity, list(self._get_requirements(policy, scope, roles))
) as context:
await self._handle_context(identity, context)

async def _handle_with_roles(
self, identity: Identity, roles: Optional[Sequence[str]] = None
self, identity: Identity, roles: Sequence[str] | None = None
):
# This method is to be used only when the user specified roles without a policy
with AuthorizationContext(identity, [RolesRequirement(roles=roles)]) as context:
Expand All @@ -310,13 +310,13 @@ async def _handle_context(self, identity: Identity, context: AuthorizationContex
)

async def _handle_with_identity_getter(
self, policy_name: Optional[str], *args, **kwargs
self, policy_name: str | None, *args, **kwargs
):
if self.identity_getter is None:
raise TypeError("Missing identity getter function.")
await self.authorize(policy_name, self.identity_getter(*args, **kwargs))

def __call__(self, policy: Optional[str] = None):
def __call__(self, policy: str | None = None):
"""
Decorates a function to apply authorization logic on each call.
"""
Expand Down
4 changes: 2 additions & 2 deletions guardpost/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from typing import Mapping as MappingType
from typing import Sequence, Union
from typing import Sequence

from .authorization import AuthorizationContext, Policy, Requirement

Expand Down Expand Up @@ -35,7 +35,7 @@ def handle(self, context: AuthorizationContext):
context.succeed(self)


RequiredClaimsType = Union[MappingType[str, str], Sequence[str], str]
RequiredClaimsType = MappingType[str, str] | Sequence[str] | str


class ClaimsRequirement(Requirement):
Expand Down
18 changes: 9 additions & 9 deletions guardpost/jwks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type
from typing import Type

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand All @@ -24,7 +24,7 @@ def _raise_if_missing(value: dict, *keys: str) -> None:
raise ValueError(f"Missing {key}")


_EC_CURVES: Dict[str, Type[EllipticCurve]] = {
_EC_CURVES: dict[str, Type[EllipticCurve]] = {
"P-256": SECP256R1,
"P-384": SECP384R1,
"P-521": SECP521R1,
Expand Down Expand Up @@ -62,14 +62,14 @@ class JWK:

kty: KeyType
pem: bytes
kid: Optional[str] = None
kid: str | None = None
# RSA parameters
n: Optional[str] = None
e: Optional[str] = None
n: str | None = None
e: str | None = None
# EC parameters
crv: Optional[str] = None
x: Optional[str] = None
y: Optional[str] = None
crv: str | None = None
x: str | None = None
y: str | None = None

@classmethod
def from_dict(cls, value) -> "JWK":
Expand Down Expand Up @@ -104,7 +104,7 @@ def from_dict(cls, value) -> "JWK":

@dataclass
class JWKS:
keys: List[JWK]
keys: list[JWK]

def update(self, new_set: "JWKS"):
self.keys = list({key.kid: key for key in self.keys + new_set.keys}.values())
Expand Down
5 changes: 2 additions & 3 deletions guardpost/jwks/caching.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import Optional

from . import JWK, JWKS, KeysProvider

Expand All @@ -24,7 +23,7 @@ def __init__(
if not keys_provider:
raise TypeError("Missing KeysProvider")

self._keys: Optional[JWKS] = None
self._keys: JWKS | None = None
self._cache_time = cache_time
self._refresh_time = refresh_time
self._last_fetch_time: float = 0
Expand Down Expand Up @@ -57,7 +56,7 @@ async def get_keys(self) -> JWKS:
return self._keys
return await self._fetch_keys()

async def get_key(self, kid: str) -> Optional[JWK]:
async def get_key(self, kid: str) -> JWK | None:
"""
Tries to get a JWK by kid. If the JWK is not found and the last time the keys
were fetched is older than `refresh_time` (default 120 seconds), it fetches
Expand Down
Loading