diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c005545e..3bb48a2d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,12 +15,13 @@ on: workflow_dispatch: jobs: - check: + tests: runs-on: ubuntu-22.04 + name: tests strategy: matrix: - python-version: ['3.10', '3.11'] + python-version: ['3.11'] steps: @@ -60,59 +61,44 @@ jobs: - name: Build docker images run: docker compose -f test-docker-compose.yaml build --no-cache - - name: Run pycodestyle - run: | - pycodestyle --max-line-length=100 api/*.py - - name: Run API containers run: | docker compose -f test-docker-compose.yaml up -d test - - name: Run pylint - run: | - docker compose -f test-docker-compose.yaml exec -T test pylint --extension-pkg-whitelist=pydantic api/ - docker compose -f test-docker-compose.yaml exec -T test pylint tests/unit_tests - docker compose -f test-docker-compose.yaml exec -T test pylint tests/e2e_tests - - name: Stop docker containers if: always() run: | docker compose -f test-docker-compose.yaml down lint: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest + container: + image: ubuntu:24.04 name: Lint steps: + - name: Install base tooling + run: | + apt-get update + apt-get install -y --no-install-recommends git python3 python3-venv python3-pip + + - name: Create Python virtual environment + run: | + python3 -m venv /opt/venv + /opt/venv/bin/pip install --no-cache-dir --upgrade pip + - name: Check out source code - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 32 # This is necessary to get the commits - - name: Get changed python files between base and head - run: > - echo "CHANGED_FILES=$(echo $(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- | grep \.py$))" >> $GITHUB_ENV + - name: Install linting dependencies + run: | + /opt/venv/bin/pip install --no-cache-dir ruff - - if: env.CHANGED_FILES - name: Set up Python - uses: actions/setup-python@master - with: - python-version: "3.10" + - name: Run ruff format check (line-length policy) + run: /opt/venv/bin/ruff format --check --diff --line-length 110 api tests - - if: env.CHANGED_FILES - name: Install Python packages + - name: Run ruff lint check (E/W/F/I + complexity policy) run: | - pip install -r docker/api/requirements-tests.txt - - - if: env.CHANGED_FILES - uses: marian-code/python-lint-annotate@v4 - with: - python-root-list: ${{ env.CHANGED_FILES }} - use-black: false - use-flake8: false - use-isort: false - use-mypy: false - use-pycodestyle: true - use-pydocstyle: false - use-vulture: false - python-version: "3.10" + /opt/venv/bin/ruff check --line-length 110 --select E,W,F,I,C901 --ignore E203 api tests diff --git a/api/admin.py b/api/admin.py index c0123d12..d673cba2 100644 --- a/api/admin.py +++ b/api/admin.py @@ -9,11 +9,12 @@ """Command line utility for creating an admin user""" -import asyncio import argparse -import sys +import asyncio import getpass import os +import sys + import pymongo from .auth import Authentication @@ -28,10 +29,7 @@ async def setup_admin_user(db, username, email, password=None): if not password: password = os.getenv("KCI_INITIAL_PASSWORD") if not password: - print( - "Password is empty and KCI_INITIAL_PASSWORD is not set, " - "aborting." - ) + print("Password is empty and KCI_INITIAL_PASSWORD is not set, aborting.") return None else: retyped = getpass.getpass(f"Retype password for user '{username}': ") @@ -42,13 +40,15 @@ async def setup_admin_user(db, username, email, password=None): hashed_password = Authentication.get_password_hash(password) print(f"Creating {username} user...") try: - return await db.create(User( - username=username, - hashed_password=hashed_password, - email=email, - is_superuser=1, - is_verified=1, - )) + return await db.create( + User( + username=username, + hashed_password=hashed_password, + email=email, + is_superuser=1, + is_verified=1, + ) + ) except pymongo.errors.DuplicateKeyError as exc: err = str(exc) if "username" in err: @@ -67,25 +67,23 @@ async def main(args): db = Database(args.mongo, args.database) await db.initialize_beanie() await db.create_indexes() - created = await setup_admin_user( - db, args.username, args.email, password=args.password - ) + created = await setup_admin_user(db, args.username, args.email, password=args.password) return created is not None -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser("Create KernelCI API admin user") - parser.add_argument('--mongo', default='mongodb://db:27017', - help="Mongo server connection string") - parser.add_argument('--username', default='admin', - help="Admin username") - parser.add_argument('--database', default='kernelci', - help="KernelCI database name") - parser.add_argument('--email', required=True, - help="Admin user email address") parser.add_argument( - '--password', - default='', + "--mongo", + default="mongodb://db:27017", + help="Mongo server connection string", + ) + parser.add_argument("--username", default="admin", help="Admin username") + parser.add_argument("--database", default="kernelci", help="KernelCI database name") + parser.add_argument("--email", required=True, help="Admin user email address") + parser.add_argument( + "--password", + default="", help="Admin password (if empty, falls back to KCI_INITIAL_PASSWORD)", ) arguments = parser.parse_args() diff --git a/api/auth.py b/api/auth.py index 2b9fa02d..7a324ed0 100644 --- a/api/auth.py +++ b/api/auth.py @@ -6,12 +6,13 @@ """User authentication utilities""" -from passlib.context import CryptContext from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, JWTStrategy, ) +from passlib.context import CryptContext + from .config import AuthSettings @@ -34,7 +35,7 @@ def get_jwt_strategy(self) -> JWTStrategy: return JWTStrategy( secret=self._settings.secret_key, algorithm=self._settings.algorithm, - lifetime_seconds=self._settings.access_token_expire_seconds + lifetime_seconds=self._settings.access_token_expire_seconds, ) def get_user_authentication_backend(self): diff --git a/api/config.py b/api/config.py index 815b859c..111357a0 100644 --- a/api/config.py +++ b/api/config.py @@ -12,6 +12,7 @@ # pylint: disable=too-few-public-methods class AuthSettings(BaseSettings): """Authentication settings""" + secret_key: str algorithm: str = "HS256" # Set to None so tokens don't expire @@ -23,6 +24,7 @@ class AuthSettings(BaseSettings): # pylint: disable=too-few-public-methods class PubSubSettings(BaseSettings): """Pub/Sub settings loaded from the environment""" + cloud_events_source: str = "https://api.kernelci.org/" redis_host: str = "redis" redis_db_number: int = 1 @@ -36,6 +38,7 @@ class PubSubSettings(BaseSettings): # pylint: disable=too-few-public-methods class EmailSettings(BaseSettings): """Email settings""" + smtp_host: str smtp_port: int email_sender: EmailStr diff --git a/api/db.py b/api/db.py index b6254b16..ed2301b3 100644 --- a/api/db.py +++ b/api/db.py @@ -6,14 +6,19 @@ """Database abstraction""" -from bson import ObjectId from beanie import init_beanie +from bson import ObjectId from fastapi_pagination.ext.motor import paginate -from motor import motor_asyncio -from redis import asyncio as aioredis from kernelci.api.models import ( - EventHistory, Hierarchy, Node, TelemetryEvent, parse_node_obj + EventHistory, + Hierarchy, + Node, + TelemetryEvent, + parse_node_obj, ) +from motor import motor_asyncio +from redis import asyncio as aioredis + from .models import User, UserGroup @@ -26,33 +31,30 @@ class Database: """ COLLECTIONS = { - User: 'user', - Node: 'node', - UserGroup: 'usergroup', - EventHistory: 'eventhistory', - TelemetryEvent: 'telemetry', + User: "user", + Node: "node", + UserGroup: "usergroup", + EventHistory: "eventhistory", + TelemetryEvent: "telemetry", } OPERATOR_MAP = { - 'lt': '$lt', - 'lte': '$lte', - 'gt': '$gt', - 'gte': '$gte', - 'ne': '$ne', - 're': '$regex', - 'in': '$in', - 'nin': '$nin', + "lt": "$lt", + "lte": "$lte", + "gt": "$gt", + "gte": "$gte", + "ne": "$ne", + "re": "$regex", + "in": "$in", + "nin": "$nin", } - BOOL_VALUE_MAP = { - 'true': True, - 'false': False - } + BOOL_VALUE_MAP = {"true": True, "false": False} - def __init__(self, service='mongodb://db:27017', db_name='kernelci'): + def __init__(self, service="mongodb://db:27017", db_name="kernelci"): self._motor = motor_asyncio.AsyncIOMotorClient(service) # TBD: Make redis host configurable - self._redis = aioredis.from_url('redis://redis:6379') + self._redis = aioredis.from_url("redis://redis:6379") self._db = self._motor[db_name] async def initialize_beanie(self): @@ -143,14 +145,13 @@ def _translate_operators(self, attributes): for op_name, op_value in value.items(): op_key = self.OPERATOR_MAP.get(op_name) if op_key: - if op_key in ('$in', '$nin'): + if op_key in ("$in", "$nin"): # Create a list of values from ',' separated string op_value = op_value.split(",") if isinstance(op_value, str) and op_value.isdecimal(): op_value = int(op_value) if translated_attributes.get(key): - translated_attributes[key].update({ - op_key: op_value}) + translated_attributes[key].update({op_key: op_value}) else: translated_attributes[key] = {op_key: op_value} return translated_attributes @@ -160,7 +161,7 @@ def _convert_int_values(cls, attributes): for key, val in attributes.items(): if isinstance(val, dict): for sub_key, sub_val in val.items(): - if sub_key == 'int': + if sub_key == "int": attributes[key] = int(sub_val) return attributes @@ -205,14 +206,13 @@ async def find_by_attributes_nonpaginated(self, model, attributes): query = self._prepare_query(attributes) # find "limit" and "offset" keys in the query, retrieve them and # remove them from the query - limit = query.pop('limit', None) - offset = query.pop('offset', None) + limit = query.pop("limit", None) + offset = query.pop("offset", None) # convert to int if limit and offset are strings limit = int(limit) if limit is not None else None offset = int(offset) if offset is not None else None if limit is not None and offset is not None: - return await (col.find(query) - .skip(offset).limit(limit).to_list(None)) + return await col.find(query).skip(offset).limit(limit).to_list(None) if limit is not None: return await col.find(query).limit(limit).to_list(None) if offset is not None: @@ -239,7 +239,7 @@ async def create(self, obj): """ if obj.id is not None: raise ValueError(f"Object cannot be created with id: {obj.id}") - delattr(obj, 'id') + delattr(obj, "id") col = self._get_collection(obj.__class__) res = await col.insert_one(obj.model_dump(by_alias=True)) obj.id = res.inserted_id @@ -251,8 +251,7 @@ async def insert_many(self, model, documents): result = await col.insert_many(documents) return result.inserted_ids - async def _create_recursively(self, hierarchy: Hierarchy, parent: Node, - cls, col): + async def _create_recursively(self, hierarchy: Hierarchy, parent: Node, cls, col): obj = parse_node_obj(hierarchy.node) if parent: obj.parent = parent.id @@ -260,13 +259,11 @@ async def _create_recursively(self, hierarchy: Hierarchy, parent: Node, obj.update() if obj.parent == obj.id: raise ValueError("Parent cannot be the same as the object") - res = await col.replace_one( - {'_id': ObjectId(obj.id)}, obj.dict(by_alias=True) - ) + res = await col.replace_one({"_id": ObjectId(obj.id)}, obj.dict(by_alias=True)) if res.matched_count == 0: raise ValueError(f"No object found with id: {obj.id}") else: - delattr(obj, 'id') + delattr(obj, "id") res = await col.insert_one(obj.dict(by_alias=True)) obj.id = res.inserted_id obj = cls(**await col.find_one(ObjectId(obj.id))) @@ -296,9 +293,7 @@ async def update(self, obj): obj.update() if obj.parent == obj.id: raise ValueError("Parent cannot be the same as the object") - res = await col.replace_one( - {'_id': ObjectId(obj.id)}, obj.dict(by_alias=True) - ) + res = await col.replace_one({"_id": ObjectId(obj.id)}, obj.dict(by_alias=True)) if res.matched_count == 0: raise ValueError(f"No object found with id: {obj.id}") return obj.__class__(**await col.find_one(ObjectId(obj.id))) diff --git a/api/email_sender.py b/api/email_sender.py index 4e3f3a1e..489f11ba 100644 --- a/api/email_sender.py +++ b/api/email_sender.py @@ -7,30 +7,30 @@ """SMTP Email Sender module""" -from email.mime.multipart import MIMEMultipart import email import email.mime.text import smtplib +from email.mime.multipart import MIMEMultipart + from fastapi import HTTPException, status + from .config import EmailSettings class EmailSender: # pylint: disable=too-few-public-methods """Class to send email report using SMTP""" + def __init__(self): self._settings = EmailSettings() def _smtp_connect(self): """Method to create a connection with SMTP server""" if self._settings.smtp_port == 465: - smtp = smtplib.SMTP_SSL(self._settings.smtp_host, - self._settings.smtp_port) + smtp = smtplib.SMTP_SSL(self._settings.smtp_host, self._settings.smtp_port) else: - smtp = smtplib.SMTP(self._settings.smtp_host, - self._settings.smtp_port) + smtp = smtplib.SMTP(self._settings.smtp_host, self._settings.smtp_port) smtp.starttls() - smtp.login(self._settings.email_sender, - self._settings.email_password) + smtp.login(self._settings.email_sender, self._settings.email_password) return smtp def _create_email(self, email_subject, email_content, email_recipient): @@ -38,12 +38,12 @@ def _create_email(self, email_subject, email_content, email_recipient): sender, and receiver""" email_msg = MIMEMultipart() email_text = email.mime.text.MIMEText(email_content, "plain", "utf-8") - email_text.replace_header('Content-Transfer-Encoding', 'quopri') - email_text.set_payload(email_content, 'utf-8') + email_text.replace_header("Content-Transfer-Encoding", "quopri") + email_text.set_payload(email_content, "utf-8") email_msg.attach(email_text) - email_msg['To'] = email_recipient - email_msg['From'] = self._settings.email_sender - email_msg['Subject'] = email_subject + email_msg["To"] = email_recipient + email_msg["From"] = self._settings.email_sender + email_msg["Subject"] = email_subject return email_msg def _send_email(self, email_msg): @@ -57,13 +57,10 @@ def _send_email(self, email_msg): print(f"Error in sending email: {str(exc)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to send email" - )from exc + detail="Failed to send email", + ) from exc - def create_and_send_email(self, email_subject, email_content, - email_recipient): + def create_and_send_email(self, email_subject, email_content, email_recipient): """Method to create and send email""" - email_msg = self._create_email( - email_subject, email_content, email_recipient - ) + email_msg = self._create_email(email_subject, email_content, email_recipient) self._send_email(email_msg) diff --git a/api/main.py b/api/main.py index ac67e6c4..17b10ac1 100644 --- a/api/main.py +++ b/api/main.py @@ -9,89 +9,89 @@ """KernelCI API main module""" +import asyncio +import ipaddress import os import re -import asyncio -import traceback import secrets -import ipaddress -import pymongo -from typing import List, Union, Optional -from datetime import datetime, timedelta, timezone +import traceback from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone +from typing import List, Optional, Union + +import pymongo +from beanie import PydanticObjectId +from bson import ObjectId, errors from fastapi import ( + Body, Depends, FastAPI, - HTTPException, - status, - Request, Form, Header, + HTTPException, Query, - Body, + Request, Response, + status, ) from fastapi.encoders import jsonable_encoder from fastapi.responses import ( + FileResponse, + HTMLResponse, JSONResponse, PlainTextResponse, - FileResponse, - HTMLResponse ) from fastapi.security import OAuth2PasswordRequestForm from fastapi_pagination import add_pagination -from fastapi_versioning import VersionedFastAPI -from bson import ObjectId, errors from fastapi_users import FastAPIUsers -from beanie import PydanticObjectId -from pydantic import BaseModel +from fastapi_versioning import VersionedFastAPI from jose import jwt from jose.exceptions import JWTError from kernelci.api.models import ( - Node, + EventHistory, Hierarchy, - PublishEvent, - parse_node_obj, KernelVersion, - EventHistory, + Node, + PublishEvent, TelemetryEvent, + parse_node_obj, ) +from pydantic import BaseModel + from .auth import Authentication +from .config import AuthSettings from .db import Database -from .pubsub_mongo import PubSub -from .user_manager import get_user_manager, create_user_manager +from .maintenance import purge_old_nodes +from .metrics import Metrics from .models import ( + InviteAcceptRequest, + InviteUrlResponse, PageModel, Subscription, SubscriptionStats, User, - UserRead, UserCreate, UserCreateRequest, + UserGroup, + UserGroupCreateRequest, UserInviteRequest, UserInviteResponse, + UserRead, UserUpdate, UserUpdateRequest, - UserGroup, - UserGroupCreateRequest, - InviteAcceptRequest, - InviteUrlResponse, ) -from .metrics import Metrics -from .maintenance import purge_old_nodes -from .config import AuthSettings +from .pubsub_mongo import PubSub +from .user_manager import create_user_manager, get_user_manager SUBSCRIPTION_CLEANUP_INTERVAL_MINUTES = 15 # How often to run cleanup task -SUBSCRIPTION_MAX_AGE_MINUTES = 15 # Max age before stale -SUBSCRIPTION_CLEANUP_RETRY_MINUTES = 1 # Retry interval if cleanup fails +SUBSCRIPTION_MAX_AGE_MINUTES = 15 # Max age before stale +SUBSCRIPTION_CLEANUP_RETRY_MINUTES = 1 # Retry interval if cleanup fails DEFAULT_MONGO_SERVICE = "mongodb://db:27017" def _validate_startup_environment(): """Validate required environment variables before app initialization.""" - required_env_vars = ( - "SECRET_KEY", - ) + required_env_vars = ("SECRET_KEY",) missing = [] empty = [] for name in required_env_vars: @@ -104,17 +104,12 @@ def _validate_startup_environment(): if missing or empty: details = [] if missing: - details.append( - "missing: " + ", ".join(sorted(missing)) - ) + details.append("missing: " + ", ".join(sorted(missing))) if empty: - details.append( - "empty: " + ", ".join(sorted(empty)) - ) + details.append("empty: " + ", ".join(sorted(empty))) raise RuntimeError( "Startup environment validation failed. " - "Set required environment variables before starting the API. " - + "; ".join(details) + "Set required environment variables before starting the API. " + "; ".join(details) ) @@ -131,14 +126,14 @@ async def lifespan(app: FastAPI): # pylint: disable=redefined-outer-name await ensure_legacy_node_editors() yield + # List of all the supported API versions. This is a placeholder until the API # actually supports multiple versions with different sets of endpoints and # models etc. -API_VERSIONS = ['v0'] +API_VERSIONS = ["v0"] metrics = Metrics() app = FastAPI(lifespan=lifespan, debug=True, docs_url=None, redoc_url=None) - db = Database(service=os.getenv("MONGO_SERVICE", DEFAULT_MONGO_SERVICE)) auth = Authentication(token_url="user/login") pubsub = None # pylint: disable=invalid-name @@ -162,10 +157,9 @@ async def subscription_cleanup_task(): while True: try: await asyncio.sleep(SUBSCRIPTION_CLEANUP_INTERVAL_MINUTES * 60) - cleaned = await pubsub.cleanup_stale_subscriptions( - SUBSCRIPTION_MAX_AGE_MINUTES) + cleaned = await pubsub.cleanup_stale_subscriptions(SUBSCRIPTION_MAX_AGE_MINUTES) if cleaned > 0: - metrics.add('subscriptions_cleaned', 1) + metrics.add("subscriptions_cleaned", 1) print(f"Cleaned up {cleaned} stale subscriptions") except (ConnectionError, OSError, RuntimeError) as e: print(f"Subscription cleanup error: {e}") @@ -189,8 +183,8 @@ async def initialize_beanie(): async def ensure_legacy_node_editors(): """Grant legacy node edit privileges to specific users.""" - legacy_usernames = {'staging.kernelci.org', 'production'} - group_name = 'node:edit:any' + legacy_usernames = {"staging.kernelci.org", "production"} + group_name = "node:edit:any" group = await db.find_one(UserGroup, name=group_name) if not group: group = await db.create(UserGroup(name=group_name)) @@ -213,21 +207,22 @@ async def ensure_initial_admin_user(): initial_password = os.getenv("KCI_INITIAL_PASSWORD") if not initial_password: raise RuntimeError( - "No admin user exists. Set KCI_INITIAL_PASSWORD to bootstrap " - "the initial admin user." + "No admin user exists. Set KCI_INITIAL_PASSWORD to bootstrap the initial admin user." ) username = os.getenv("KCI_INITIAL_ADMIN_USERNAME") or "admin" email = os.getenv("KCI_INITIAL_ADMIN_EMAIL") or f"{username}@kernelci.org" try: - await db.create(User( - username=username, - hashed_password=Authentication.get_password_hash(initial_password), - email=email, - is_superuser=1, - is_verified=1, - )) + await db.create( + User( + username=username, + hashed_password=Authentication.get_password_hash(initial_password), + email=email, + is_superuser=1, + is_verified=1, + ) + ) print(f"Created initial admin user '{username}' ({email}).") except pymongo.errors.DuplicateKeyError as exc: # Handle startup races across multiple API instances. @@ -252,9 +247,7 @@ async def value_error_exception_handler(request: Request, exc: ValueError): @app.exception_handler(errors.InvalidId) -async def invalid_id_exception_handler( - request: Request, - exc: errors.InvalidId): +async def invalid_id_exception_handler(request: Request, exc: errors.InvalidId): """Global exception handler for `errors.InvalidId` The exception is raised from Database when invalid ObjectId is received""" return JSONResponse( @@ -263,27 +256,30 @@ async def invalid_id_exception_handler( ) -@app.get('/') +@app.get("/") async def root(): """Root endpoint handler""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - index_path = os.path.join(root_dir, 'templates', 'index.html') - with open(index_path, 'r', encoding='utf-8') as file: + index_path = os.path.join(root_dir, "templates", "index.html") + with open(index_path, "r", encoding="utf-8") as file: return HTMLResponse(file.read()) + # ----------------------------------------------------------------------------- # Users -def get_current_user(user: User = Depends( - fastapi_users_instance.current_user(active=True))): +def get_current_user( + user: User = Depends(fastapi_users_instance.current_user(active=True)), +): """Get current active user""" return user -def get_current_superuser(user: User = Depends( - fastapi_users_instance.current_user(active=True, superuser=True))): +def get_current_superuser( + user: User = Depends(fastapi_users_instance.current_user(active=True, superuser=True)), +): """Get current active superuser""" return user @@ -340,44 +336,54 @@ def _decode_invite_token(token: str) -> dict: return payload +def _is_proxy_request(request: Request) -> bool: + client = request.client + if not client or not client.host: + return False + + try: + ip_addr = ipaddress.ip_address(client.host) + except ValueError: + return False + + return not ip_addr.is_global + + +def _parse_forwarded_header(header: str) -> tuple[str | None, str | None]: + forwarded_host = None + forwarded_proto = None + + first_hop = header.split(",", 1)[0] + for pair in first_hop.split(";"): + if "=" not in pair: + continue + key, value = pair.split("=", 1) + key = key.strip().lower() + value = value.strip().strip('"') + if key == "host": + forwarded_host = value + elif key == "proto": + forwarded_proto = value + + return forwarded_host, forwarded_proto + + def _resolve_public_base_url(request: Request) -> str: settings = AuthSettings() if settings.public_base_url: return settings.public_base_url.rstrip("/") - def _is_proxy_addr() -> bool: - client = request.client - if not client or not client.host: - return False - try: - ip_addr = ipaddress.ip_address(client.host) - except ValueError: - return False - return not ip_addr.is_global - + is_proxy_request = _is_proxy_request(request) forwarded_host = None forwarded_proto = None + forwarded_header = request.headers.get("forwarded") - if forwarded_header and _is_proxy_addr(): - first_hop = forwarded_header.split(",", 1)[0] - for pair in first_hop.split(";"): - if "=" not in pair: - continue - key, value = pair.split("=", 1) - key = key.strip().lower() - value = value.strip().strip('"') - if key == "host": - forwarded_host = value - elif key == "proto": - forwarded_proto = value - - if _is_proxy_addr(): - forwarded_host = forwarded_host or request.headers.get( - "x-forwarded-host" - ) - forwarded_proto = forwarded_proto or request.headers.get( - "x-forwarded-proto" - ) + if forwarded_header and is_proxy_request: + forwarded_host, forwarded_proto = _parse_forwarded_header(forwarded_header) + + if is_proxy_request: + forwarded_host = forwarded_host or request.headers.get("x-forwarded-host") + forwarded_proto = forwarded_proto or request.headers.get("x-forwarded-proto") if forwarded_host: scheme = forwarded_proto or request.url.scheme @@ -391,7 +397,7 @@ def _accept_invite_url(public_base_url: str) -> str: async def _find_existing_user_for_invite( - invite: UserInviteRequest, + invite: UserInviteRequest, ) -> User | None: existing_by_username = await db.find_one(User, username=invite.username) if existing_by_username: @@ -423,8 +429,8 @@ def _validate_invite_resend(existing_user: User, invite: UserInviteRequest): async def _create_user_for_invite( - request: Request, - invite: UserInviteRequest, + request: Request, + invite: UserInviteRequest, ) -> User: groups: List[UserGroup] = [] if invite.groups: @@ -439,8 +445,7 @@ async def _create_user_for_invite( ) user_create.groups = groups - created_user = await register_router.routes[0].endpoint( - request, user_create, user_manager) + created_user = await register_router.routes[0].endpoint(request, user_create, user_manager) if invite.is_superuser: user_from_id = await db.find_by_id(User, created_user.id) @@ -451,20 +456,25 @@ async def _create_user_for_invite( app.include_router( - fastapi_users_instance.get_auth_router(auth_backend, - requires_verification=True), + fastapi_users_instance.get_auth_router(auth_backend, requires_verification=True), prefix="/user", - tags=["user"] + tags=["user"], ) -register_router = fastapi_users_instance.get_register_router( - UserRead, UserCreate) +register_router = fastapi_users_instance.get_register_router(UserRead, UserCreate) -@app.post("/user/register", response_model=UserRead, tags=["user"], - response_model_by_alias=False) -async def register(request: Request, user: UserCreateRequest, - current_user: User = Depends(get_current_superuser)): +@app.post( + "/user/register", + response_model=UserRead, + tags=["user"], + response_model_by_alias=False, +) +async def register( + request: Request, + user: UserCreateRequest, + current_user: User = Depends(get_current_superuser), +): """User registration route Custom user registration router to ensure unique username. @@ -481,23 +491,30 @@ async def register(request: Request, user: UserCreateRequest, @app.get("/user/invite", response_class=HTMLResponse, include_in_schema=False) async def invite_user_page(): """Web UI for inviting a user (admin token required)""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - page_path = os.path.join(root_dir, 'templates', 'invite.html') - with open(page_path, 'r', encoding='utf-8') as file: + page_path = os.path.join(root_dir, "templates", "invite.html") + with open(page_path, "r", encoding="utf-8") as file: return HTMLResponse(file.read()) -@app.post("/user/invite", response_model=UserInviteResponse, tags=["user"], - response_model_by_alias=False) -async def invite_user(request: Request, invite: UserInviteRequest, - current_user: User = Depends(get_current_superuser)): +@app.post( + "/user/invite", + response_model=UserInviteResponse, + tags=["user"], + response_model_by_alias=False, +) +async def invite_user( + request: Request, + invite: UserInviteRequest, + current_user: User = Depends(get_current_superuser), +): """Invite a user (admin-only) Creates the user with a random password and sends a single invite link to set a password and verify the account. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) existing_user = await _find_existing_user_for_invite(invite) if existing_user: _validate_invite_resend(existing_user, invite) @@ -539,19 +556,17 @@ async def invite_user(request: Request, invite: UserInviteRequest, ) async def accept_invite_page(): """Web UI for accepting an invite (sets password + verifies)""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - page_path = os.path.join(root_dir, 'templates', 'accept-invite.html') - with open(page_path, 'r', encoding='utf-8') as file: + page_path = os.path.join(root_dir, "templates", "accept-invite.html") + with open(page_path, "r", encoding="utf-8") as file: return HTMLResponse(file.read()) @app.get("/user/invite/url", response_model=InviteUrlResponse, tags=["user"]) -async def invite_url_preview(request: Request, - current_user: User = Depends( - get_current_superuser)): +async def invite_url_preview(request: Request, current_user: User = Depends(get_current_superuser)): """Preview the resolved public URL used in invite links (admin-only)""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) public_base_url = _resolve_public_base_url(request) return InviteUrlResponse( public_base_url=public_base_url, @@ -559,11 +574,15 @@ async def invite_url_preview(request: Request, ) -@app.post("/user/accept-invite", response_model=UserRead, tags=["user"], - response_model_by_alias=False) +@app.post( + "/user/accept-invite", + response_model=UserRead, + tags=["user"], + response_model_by_alias=False, +) async def accept_invite(accept: InviteAcceptRequest): """Accept an invite token, set password, and verify the user""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) payload = _decode_invite_token(accept.token) user_id = payload.get("sub") email = payload.get("email") @@ -590,9 +609,7 @@ async def accept_invite(accept: InviteAcceptRequest): detail="Invite already accepted", ) - user_from_id.hashed_password = user_manager.password_helper.hash( - accept.password - ) + user_from_id.hashed_password = user_manager.password_helper.hash(accept.password) user_from_id.is_verified = True updated_user = await db.update(user_from_id) @@ -609,41 +626,45 @@ async def accept_invite(accept: InviteAcceptRequest): tags=["user"], ) -users_router = fastapi_users_instance.get_users_router( - UserRead, UserUpdate, requires_verification=True) +users_router = fastapi_users_instance.get_users_router(UserRead, UserUpdate, requires_verification=True) app.add_api_route( path="/whoami", tags=["user"], methods=["GET"], description="Get current user information", - endpoint=users_router.routes[0].endpoint) + endpoint=users_router.routes[0].endpoint, +) app.add_api_route( path="/user/{id}", tags=["user"], methods=["GET"], description="Get user information by ID", dependencies=[Depends(get_current_user)], - endpoint=users_router.routes[2].endpoint) + endpoint=users_router.routes[2].endpoint, +) app.add_api_route( path="/user/{id}", tags=["user"], methods=["DELETE"], description="Delete user by ID", dependencies=[Depends(get_current_superuser)], - endpoint=users_router.routes[4].endpoint) + endpoint=users_router.routes[4].endpoint, +) -@app.patch("/user/me", response_model=UserRead, tags=["user"], - response_model_by_alias=False) -async def update_me(request: Request, user: UserUpdateRequest, - current_user: User = Depends(get_current_user)): +@app.patch("/user/me", response_model=UserRead, tags=["user"], response_model_by_alias=False) +async def update_me( + request: Request, + user: UserUpdateRequest, + current_user: User = Depends(get_current_user), +): """User update route Custom user update router handler will only allow users to update its own profile. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) if user.groups: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -663,26 +684,29 @@ async def update_me(request: Request, user: UserUpdateRequest, if not group: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "User group does not exist with name: " - f"{group_name}" - ), + detail=(f"User group does not exist with name: {group_name}"), ) groups.append(group) - user_update = UserUpdate(**(user.model_dump( - exclude={'groups', 'is_superuser'}, exclude_none=True))) + user_update = UserUpdate(**(user.model_dump(exclude={"groups", "is_superuser"}, exclude_none=True))) if groups: user_update.groups = groups - return await users_router.routes[1].endpoint( - request, user_update, current_user, user_manager) + return await users_router.routes[1].endpoint(request, user_update, current_user, user_manager) -@app.patch("/user/{user_id}", response_model=UserRead, tags=["user"], - response_model_by_alias=False) -async def update_user(user_id: str, request: Request, user: UserUpdateRequest, - current_user: User = Depends(get_current_superuser)): +@app.patch( + "/user/{user_id}", + response_model=UserRead, + tags=["user"], + response_model_by_alias=False, +) +async def update_user( + user_id: str, + request: Request, + user: UserUpdateRequest, + current_user: User = Depends(get_current_superuser), +): """Router to allow admin users to update other user account""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) user_from_id = await db.find_by_id(User, user_id) if not user_from_id: raise HTTPException( @@ -705,21 +729,15 @@ async def update_user(user_id: str, request: Request, user: UserUpdateRequest, if not group: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "User group does not exist with name: " - f"{group_name}" - ), + detail=(f"User group does not exist with name: {group_name}"), ) groups.append(group) - user_update = UserUpdate(**(user.model_dump( - exclude={'groups'}, exclude_none=True))) + user_update = UserUpdate(**(user.model_dump(exclude={"groups"}, exclude_none=True))) if groups: user_update.groups = groups - updated_user = await users_router.routes[3].endpoint( - user_update, request, user_from_id, user_manager - ) + updated_user = await users_router.routes[3].endpoint(user_update, request, user_from_id, user_manager) # Update superuser explicitly since fastapi-users update route ignores it. if user.is_superuser is not None: user_from_id = await db.find_by_id(User, updated_user.id) @@ -729,25 +747,26 @@ async def update_user(user_id: str, request: Request, user: UserUpdateRequest, @app.get("/user-groups", response_model=PageModel, tags=["user"]) -async def get_user_groups(request: Request, - current_user: User = Depends(get_current_superuser)): +async def get_user_groups(request: Request, current_user: User = Depends(get_current_superuser)): """List user groups (admin-only).""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) - for pg_key in ['limit', 'offset']: + for pg_key in ["limit", "offset"]: query_params.pop(pg_key, None) paginated_resp = await db.find_by_attributes(UserGroup, query_params) - paginated_resp.items = serialize_paginated_data( - UserGroup, paginated_resp.items) + paginated_resp.items = serialize_paginated_data(UserGroup, paginated_resp.items) return paginated_resp -@app.get("/user-groups/{group_id}", response_model=UserGroup, tags=["user"], - response_model_by_alias=False) -async def get_user_group(group_id: str, - current_user: User = Depends(get_current_superuser)): +@app.get( + "/user-groups/{group_id}", + response_model=UserGroup, + tags=["user"], + response_model_by_alias=False, +) +async def get_user_group(group_id: str, current_user: User = Depends(get_current_superuser)): """Get a user group by id (admin-only).""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) group = await db.find_by_id(UserGroup, group_id) if not group: raise HTTPException( @@ -757,13 +776,17 @@ async def get_user_group(group_id: str, return group -@app.post("/user-groups", response_model=UserGroup, tags=["user"], - response_model_by_alias=False) -async def create_user_group(group: UserGroupCreateRequest, - current_user: User = Depends( - get_current_superuser)): +@app.post( + "/user-groups", + response_model=UserGroup, + tags=["user"], + response_model_by_alias=False, +) +async def create_user_group( + group: UserGroupCreateRequest, current_user: User = Depends(get_current_superuser) +): """Create a user group (admin-only).""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) existing = await db.find_one(UserGroup, name=group.name) if existing: raise HTTPException( @@ -773,13 +796,10 @@ async def create_user_group(group: UserGroupCreateRequest, return await db.create(UserGroup(name=group.name)) -@app.delete("/user-groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT, - tags=["user"]) -async def delete_user_group(group_id: str, - current_user: User = Depends( - get_current_superuser)): +@app.delete("/user-groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT, tags=["user"]) +async def delete_user_group(group_id: str, current_user: User = Depends(get_current_superuser)): """Delete a user group (admin-only).""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) group = await db.find_by_id(UserGroup, group_id) if not group: raise HTTPException( @@ -790,10 +810,7 @@ async def delete_user_group(group_id: str, if assigned_count: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail=( - "User group is assigned to users and cannot be deleted. " - "Remove it from users first." - ), + detail=("User group is assigned to users and cannot be deleted. Remove it from users first."), ) await db.delete_by_id(UserGroup, group_id) return Response(status_code=status.HTTP_204_NO_CONTENT) @@ -801,10 +818,10 @@ async def delete_user_group(group_id: str, def _get_node_runtime(node: Node) -> Optional[str]: """Best-effort runtime lookup from node data.""" - data = getattr(node, 'data', None) + data = getattr(node, "data", None) if isinstance(data, dict): - return data.get('runtime') - return getattr(data, 'runtime', None) + return data.get("runtime") + return getattr(data, "runtime", None) def _user_can_edit_node(user: User, node: Node) -> bool: @@ -814,23 +831,20 @@ def _user_can_edit_node(user: User, node: Node) -> bool: if user.username == node.owner: return True user_group_names = {group.name for group in user.groups} - if 'node:edit:any' in user_group_names: + if "node:edit:any" in user_group_names: return True - if any(group_name in user_group_names - for group_name in getattr(node, 'user_groups', [])): + if any(group_name in user_group_names for group_name in getattr(node, "user_groups", [])): return True runtime = _get_node_runtime(node) if runtime: - runtime_editor = ":".join(['runtime', runtime, 'node-editor']) - runtime_admin = ":".join(['runtime', runtime, 'node-admin']) - if (runtime_editor in user_group_names - or runtime_admin in user_group_names): + runtime_editor = ":".join(["runtime", runtime, "node-editor"]) + runtime_admin = ":".join(["runtime", runtime, "node-admin"]) + if runtime_editor in user_group_names or runtime_admin in user_group_names: return True return False -async def authorize_user(node_id: str, - user: User = Depends(get_current_user)): +async def authorize_user(node_id: str, user: User = Depends(get_current_user)): """Return the user if active, authenticated, and authorized""" # Only the user that created the node or any other user from the permitted @@ -839,40 +853,43 @@ async def authorize_user(node_id: str, if not node_from_id: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with id: {node_id}" + detail=f"Node not found with id: {node_id}", ) if not _user_can_edit_node(user, node_from_id): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unauthorized to complete the operation" + detail="Unauthorized to complete the operation", ) return user -@app.get('/users', response_model=PageModel, tags=["user"], - response_model_exclude={"items": {"__all__": { - "hashed_password"}}}) -async def get_users(request: Request, - current_user: User = Depends(get_current_user)): +@app.get( + "/users", + response_model=PageModel, + tags=["user"], + response_model_exclude={"items": {"__all__": {"hashed_password"}}}, +) +async def get_users(request: Request, current_user: User = Depends(get_current_user)): """Get all the users if no request parameters have passed. - Get the matching users otherwise.""" - metrics.add('http_requests_total', 1) + Get the matching users otherwise.""" + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) # Drop pagination parameters from query as they're already in arguments - for pg_key in ['limit', 'offset']: + for pg_key in ["limit", "offset"]: query_params.pop(pg_key, None) paginated_resp = await db.find_by_attributes(User, query_params) - paginated_resp.items = serialize_paginated_data( - User, paginated_resp.items) + paginated_resp.items = serialize_paginated_data(User, paginated_resp.items) return paginated_resp @app.post("/user/update-password", tags=["user"]) -async def update_password(request: Request, - credentials: OAuth2PasswordRequestForm = Depends(), - new_password: str = Form(None)): +async def update_password( + request: Request, + credentials: OAuth2PasswordRequestForm = Depends(), + new_password: str = Form(None), +): """Update user password""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) user = await user_manager.authenticate(credentials) if user is None or not user.is_active: raise HTTPException( @@ -881,17 +898,14 @@ async def update_password(request: Request, ) user_update = UserUpdate(password=new_password) user_from_username = await db.find_one(User, username=credentials.username) - await users_router.routes[3].endpoint( - user_update, request, user_from_username, user_manager - ) + await users_router.routes[3].endpoint(user_update, request, user_from_username, user_manager) # EventHistory is now stored by pubsub.publish_cloudevent() # No need for separate _get_eventhistory function -def _parse_event_id_filter(query_params: dict, event_id: str, - event_ids: str) -> None: +def _parse_event_id_filter(query_params: dict, event_id: str, event_ids: str) -> None: """Parse and validate event id/ids filter parameters. Modifies query_params in place to add _id filter. @@ -900,31 +914,66 @@ def _parse_event_id_filter(query_params: dict, event_id: str, if event_id and event_ids: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Provide either id or ids, not both" + detail="Provide either id or ids, not both", ) if event_id: try: - query_params['_id'] = ObjectId(event_id) + query_params["_id"] = ObjectId(event_id) except (errors.InvalidId, TypeError) as exc: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid id format" - ) from exc + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid id format") from exc elif event_ids: try: - ids_list = [ObjectId(x.strip()) - for x in event_ids.split(',') if x.strip()] + ids_list = [ObjectId(x.strip()) for x in event_ids.split(",") if x.strip()] except (errors.InvalidId, TypeError) as exc: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid ids format" - ) from exc + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid ids format") from exc if not ids_list: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="ids must contain at least one id" + detail="ids must contain at least one id", ) - query_params['_id'] = {'$in': ids_list} + query_params["_id"] = {"$in": ids_list} + + +def _apply_simple_filters(query_params: dict, simple_filters: dict): + for param, field in simple_filters.items(): + value = query_params.pop(param, None) + if value: + query_params[field] = value + + +def _apply_node_filter(query_params: dict, node_id: str): + if not node_id: + return + + if "data.id" in query_params: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Provide either node_id or data.id, not both", + ) + query_params["data.id"] = node_id + + +def _apply_from_filter(query_params: dict, from_ts: str): + if not from_ts: + return + + if isinstance(from_ts, str): + try: + from_ts = datetime.fromisoformat(from_ts) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid 'from' parameter, must be an ISO 8601 datetime", + ) from exc + query_params["timestamp"] = {"$gt": from_ts} + + +def _apply_recursive_validation(recursive, limit): + if recursive and (not limit or int(limit) > 1000): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Recursive limit is too large, max is 1000", + ) def _build_events_query(query_params: dict) -> tuple: @@ -933,68 +982,43 @@ def _build_events_query(query_params: dict) -> tuple: Returns (recursive, processed_query_params). Raises HTTPException on validation errors. """ - # Simple filters: param_name -> query_field simple_filters = { - 'kind': 'data.kind', - 'state': 'data.state', - 'result': 'data.result', - 'op': 'data.op', - 'name': 'data.name', - 'group': 'data.group', - 'owner': 'data.owner', - 'channel': 'channel', + "kind": "data.kind", + "state": "data.state", + "result": "data.result", + "op": "data.op", + "name": "data.name", + "group": "data.group", + "owner": "data.owner", + "channel": "channel", } - recursive = query_params.pop('recursive', None) - limit = query_params.pop('limit', None) - from_ts = query_params.pop('from', None) - node_id = query_params.pop('node_id', None) - path = query_params.pop('path', None) + recursive = query_params.pop("recursive", None) + limit = query_params.pop("limit", None) + from_ts = query_params.pop("from", None) + node_id = query_params.pop("node_id", None) + path = query_params.pop("path", None) - # Apply simple filters - for param, field in simple_filters.items(): - value = query_params.pop(param, None) - if value: - query_params[field] = value + _apply_simple_filters(query_params, simple_filters) + _apply_node_filter(query_params, node_id) - if node_id: - if 'data.id' in query_params: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Provide either node_id or data.id, not both" - ) - query_params['data.id'] = node_id - - event_id = query_params.pop('id', None) - event_ids = query_params.pop('ids', None) + event_id = query_params.pop("id", None) + event_ids = query_params.pop("ids", None) _parse_event_id_filter(query_params, event_id, event_ids) - if from_ts: - if isinstance(from_ts, str): - try: - from_ts = datetime.fromisoformat(from_ts) - except ValueError as exc: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid 'from' parameter, must be an ISO 8601 datetime" - ) from exc - query_params['timestamp'] = {'$gt': from_ts} + _apply_from_filter(query_params, from_ts) if path: - query_params['data.path'] = {'$regex': path} + query_params["data.path"] = {"$regex": path} if limit: - query_params['limit'] = int(limit) + query_params["limit"] = int(limit) - if recursive and (not limit or int(limit) > 1000): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Recursive limit is too large, max is 1000" - ) + _apply_recursive_validation(recursive, limit) return recursive, query_params # TBD: Restrict response by Pydantic model -@app.get('/events') +@app.get("/events") async def get_events(request: Request): """Get all the events if no request parameters have passed. Format: [{event1}, {event2}, ...] or if recursive is set to true, @@ -1017,19 +1041,19 @@ async def get_events(request: Request): - recursive: Retrieve node together with event This API endpoint is under development and may change in future. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) recursive, query_params = _build_events_query(query_params) resp = await db.find_by_attributes_nonpaginated(EventHistory, query_params) resp_list = [] for item in resp: - item['id'] = str(item['_id']) - item.pop('_id') + item["id"] = str(item["_id"]) + item.pop("_id") if recursive: - node = await db.find_by_id(Node, item['data']['id']) + node = await db.find_by_id(Node, item["data"]["id"]) if node: - item['node'] = node + item["node"] = node resp_list.append(item) json_comp = jsonable_encoder(resp_list) return JSONResponse(content=json_comp) @@ -1042,7 +1066,8 @@ async def get_events(request: Request): # query patterns and allows us to optimize indexes and storage # separately. -@app.post('/telemetry', response_model=dict, tags=["telemetry"]) + +@app.post("/telemetry", response_model=dict, tags=["telemetry"]) async def post_telemetry( events: List[dict], current_user: User = Depends(get_current_user), @@ -1053,7 +1078,7 @@ async def post_telemetry( least 'kind' and 'runtime' fields. Events are validated against the TelemetryEvent model before insertion. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) if not events: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -1069,13 +1094,13 @@ async def post_telemetry( detail=f"Invalid telemetry event: {exc}", ) from exc doc = obj.model_dump(by_alias=True) - doc.pop('_id', None) + doc.pop("_id", None) docs.append(doc) inserted_ids = await db.insert_many(TelemetryEvent, docs) return {"inserted": len(inserted_ids)} -@app.get('/telemetry', response_model=PageModel, tags=["telemetry"]) +@app.get("/telemetry", response_model=PageModel, tags=["telemetry"]) async def get_telemetry(request: Request): """Query telemetry events with optional filters. @@ -1083,51 +1108,53 @@ async def get_telemetry(request: Request): via 'since' and 'until' parameters (ISO 8601 format). Results are paginated (default limit=50). """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) - for pg_key in ['limit', 'offset']: + for pg_key in ["limit", "offset"]: query_params.pop(pg_key, None) - since = query_params.pop('since', None) - until = query_params.pop('until', None) + since = query_params.pop("since", None) + until = query_params.pop("until", None) if since or until: ts_filter = {} if since: - ts_filter['$gte'] = datetime.fromisoformat(since) + ts_filter["$gte"] = datetime.fromisoformat(since) if until: - ts_filter['$lte'] = datetime.fromisoformat(until) - query_params['ts'] = ts_filter + ts_filter["$lte"] = datetime.fromisoformat(until) + query_params["ts"] = ts_filter # Convert string 'true'/'false' for boolean fields - if 'is_infra_error' in query_params: - val = query_params['is_infra_error'].lower() - if val not in ['true', 'false']: + if "is_infra_error" in query_params: + val = query_params["is_infra_error"].lower() + if val not in ["true", "false"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Bad is_infra_error value, use 'true' or 'false'", ) - if val == 'true': - query_params['is_infra_error'] = True + if val == "true": + query_params["is_infra_error"] = True else: - query_params['is_infra_error'] = False + query_params["is_infra_error"] = False - paginated_resp = await db.find_by_attributes( - TelemetryEvent, query_params - ) - paginated_resp.items = serialize_paginated_data( - TelemetryEvent, paginated_resp.items - ) + paginated_resp = await db.find_by_attributes(TelemetryEvent, query_params) + paginated_resp.items = serialize_paginated_data(TelemetryEvent, paginated_resp.items) return paginated_resp TELEMETRY_STATS_GROUP_FIELDS = { - 'runtime', 'device_type', 'job_name', 'tree', 'branch', - 'arch', 'kind', 'error_type', + "runtime", + "device_type", + "job_name", + "tree", + "branch", + "arch", + "kind", + "error_type", } -@app.get('/telemetry/stats', tags=["telemetry"]) +@app.get("/telemetry/stats", tags=["telemetry"]) async def get_telemetry_stats(request: Request): """Get aggregated telemetry statistics. @@ -1149,16 +1176,16 @@ async def get_telemetry_stats(request: Request): Returns grouped counts with pass/fail/incomplete/infra_error breakdowns for result-bearing events. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) - group_by_str = query_params.pop('group_by', None) + group_by_str = query_params.pop("group_by", None) if not group_by_str: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="'group_by' parameter is required", ) - group_by = [f.strip() for f in group_by_str.split(',')] + group_by = [f.strip() for f in group_by_str.split(",")] invalid = set(group_by) - TELEMETRY_STATS_GROUP_FIELDS if invalid: raise HTTPException( @@ -1168,76 +1195,82 @@ async def get_telemetry_stats(request: Request): match_stage = { key: query_params.pop(key) - for key in ('kind', 'runtime', 'device_type', 'job_name', - 'tree', 'branch', 'arch') + for key in ( + "kind", + "runtime", + "device_type", + "job_name", + "tree", + "branch", + "arch", + ) if query_params.get(key) } - since = query_params.pop('since', None) - until = query_params.pop('until', None) + since = query_params.pop("since", None) + until = query_params.pop("until", None) if since or until: - match_stage['ts'] = { - **({'$gte': datetime.fromisoformat(since)} if since else {}), - **({'$lte': datetime.fromisoformat(until)} if until else {}), + match_stage["ts"] = { + **({"$gte": datetime.fromisoformat(since)} if since else {}), + **({"$lte": datetime.fromisoformat(until)} if until else {}), } - pipeline = [{'$match': match_stage}] if match_stage else [] - pipeline.append({ - '$group': { - '_id': {f: f'${f}' for f in group_by}, - 'total': {'$sum': 1}, - 'pass': {'$sum': { - '$cond': [{'$eq': ['$result', 'pass']}, 1, 0] - }}, - 'fail': {'$sum': { - '$cond': [{'$eq': ['$result', 'fail']}, 1, 0] - }}, - 'incomplete': {'$sum': { - '$cond': [{'$eq': ['$result', 'incomplete']}, 1, 0] - }}, - 'skip': {'$sum': { - '$cond': [{'$eq': ['$result', 'skip']}, 1, 0] - }}, - 'infra_error': {'$sum': { - '$cond': ['$is_infra_error', 1, 0] - }}, + pipeline = [{"$match": match_stage}] if match_stage else [] + pipeline.append( + { + "$group": { + "_id": {f: f"${f}" for f in group_by}, + "total": {"$sum": 1}, + "pass": {"$sum": {"$cond": [{"$eq": ["$result", "pass"]}, 1, 0]}}, + "fail": {"$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]}}, + "incomplete": {"$sum": {"$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0]}}, + "skip": {"$sum": {"$cond": [{"$eq": ["$result", "skip"]}, 1, 0]}}, + "infra_error": {"$sum": {"$cond": ["$is_infra_error", 1, 0]}}, + } } - }) - pipeline.append({'$sort': {'total': -1}}) + ) + pipeline.append({"$sort": {"total": -1}}) results = await db.aggregate(TelemetryEvent, pipeline) - return JSONResponse(content=jsonable_encoder([ - { - **doc['_id'].copy(), - 'total': doc['total'], - 'pass': doc['pass'], - 'fail': doc['fail'], - 'incomplete': doc['incomplete'], - 'skip': doc['skip'], - 'infra_error': doc['infra_error'], - } - for doc in results - ])) + return JSONResponse( + content=jsonable_encoder( + [ + { + **doc["_id"].copy(), + "total": doc["total"], + "pass": doc["pass"], + "fail": doc["fail"], + "incomplete": doc["incomplete"], + "skip": doc["skip"], + "infra_error": doc["infra_error"], + } + for doc in results + ] + ) + ) + # This is test value, can adjust based on expected query patterns and volumes. ANOMALY_WINDOW_MAP = { - '1h': 1, '3h': 3, '6h': 6, '12h': 12, '24h': 24, '48h': 48, + "1h": 1, + "3h": 3, + "6h": 6, + "12h": 12, + "24h": 24, + "48h": 48, } -@app.get('/telemetry/anomalies', tags=["telemetry"]) +@app.get("/telemetry/anomalies", tags=["telemetry"]) async def get_telemetry_anomalies( - window: str = Query( - '6h', description='Time window: 1h, 3h, 6h, 12h, 24h, 48h' - ), + window: str = Query("6h", description="Time window: 1h, 3h, 6h, 12h, 24h, 48h"), threshold: float = Query( - 0.5, ge=0.0, le=1.0, - description='Min failure/infra error rate to flag (0.0-1.0)' - ), - min_total: int = Query( - 3, ge=1, - description='Min events in window to consider (avoids noise)' + 0.5, + ge=0.0, + le=1.0, + description="Min failure/infra error rate to flag (0.0-1.0)", ), + min_total: int = Query(3, ge=1, description="Min events in window to consider (avoids noise)"), ): """Detect anomalies in telemetry data. @@ -1247,106 +1280,101 @@ async def get_telemetry_anomalies( Returns a list sorted by severity (highest error rate first). """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) hours = ANOMALY_WINDOW_MAP.get(window) if not hours: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid window '{window}'. " - f"Use: {', '.join(ANOMALY_WINDOW_MAP.keys())}", + detail=f"Invalid window '{window}'. Use: {', '.join(ANOMALY_WINDOW_MAP.keys())}", ) since = datetime.utcnow() - timedelta(hours=hours) # Anomaly 1: High infra error / failure rate per runtime+device_type result_pipeline = [ - {'$match': { - 'kind': {'$in': ['job_result', 'test_result']}, - 'ts': {'$gte': since}, - }}, - {'$group': { - '_id': { - 'runtime': '$runtime', - 'device_type': '$device_type', - }, - 'total': {'$sum': 1}, - 'fail': {'$sum': { - '$cond': [{'$eq': ['$result', 'fail']}, 1, 0] - }}, - 'incomplete': {'$sum': { - '$cond': [{'$eq': ['$result', 'incomplete']}, 1, 0] - }}, - 'infra_error': {'$sum': { - '$cond': ['$is_infra_error', 1, 0] - }}, - }}, - {'$match': {'total': {'$gte': min_total}}}, - {'$addFields': { - 'infra_rate': { - '$divide': ['$infra_error', '$total'] - }, - 'fail_rate': { - '$divide': [ - {'$add': ['$fail', '$incomplete']}, '$total' + { + "$match": { + "kind": {"$in": ["job_result", "test_result"]}, + "ts": {"$gte": since}, + } + }, + { + "$group": { + "_id": { + "runtime": "$runtime", + "device_type": "$device_type", + }, + "total": {"$sum": 1}, + "fail": {"$sum": {"$cond": [{"$eq": ["$result", "fail"]}, 1, 0]}}, + "incomplete": {"$sum": {"$cond": [{"$eq": ["$result", "incomplete"]}, 1, 0]}}, + "infra_error": {"$sum": {"$cond": ["$is_infra_error", 1, 0]}}, + } + }, + {"$match": {"total": {"$gte": min_total}}}, + { + "$addFields": { + "infra_rate": {"$divide": ["$infra_error", "$total"]}, + "fail_rate": {"$divide": [{"$add": ["$fail", "$incomplete"]}, "$total"]}, + } + }, + { + "$match": { + "$or": [ + {"infra_rate": {"$gte": threshold}}, + {"fail_rate": {"$gte": threshold}}, ] - }, - }}, - {'$match': { - '$or': [ - {'infra_rate': {'$gte': threshold}}, - {'fail_rate': {'$gte': threshold}}, - ] - }}, - {'$sort': {'infra_rate': -1, 'fail_rate': -1}}, + } + }, + {"$sort": {"infra_rate": -1, "fail_rate": -1}}, ] # Anomaly 2: Submission/connectivity errors per runtime error_pipeline = [ - {'$match': { - 'kind': {'$in': ['runtime_error', 'job_skip']}, - 'ts': {'$gte': since}, - }}, - {'$group': { - '_id': { - 'runtime': '$runtime', - 'error_type': '$error_type', - }, - 'count': {'$sum': 1}, - }}, - {'$match': {'count': {'$gte': min_total}}}, - {'$sort': {'count': -1}}, + { + "$match": { + "kind": {"$in": ["runtime_error", "job_skip"]}, + "ts": {"$gte": since}, + } + }, + { + "$group": { + "_id": { + "runtime": "$runtime", + "error_type": "$error_type", + }, + "count": {"$sum": 1}, + } + }, + {"$match": {"count": {"$gte": min_total}}}, + {"$sort": {"count": -1}}, ] - result_anomalies = await db.aggregate( - TelemetryEvent, result_pipeline - ) - error_anomalies = await db.aggregate( - TelemetryEvent, error_pipeline - ) + result_anomalies = await db.aggregate(TelemetryEvent, result_pipeline) + error_anomalies = await db.aggregate(TelemetryEvent, error_pipeline) output = { - 'window': window, - 'threshold': threshold, - 'min_total': min_total, - 'since': since.isoformat(), - 'result_anomalies': [], - 'error_anomalies': [], + "window": window, + "threshold": threshold, + "min_total": min_total, + "since": since.isoformat(), + "result_anomalies": [], + "error_anomalies": [], } for doc in result_anomalies: - row = doc['_id'].copy() - row['total'] = doc['total'] - row['fail'] = doc['fail'] - row['incomplete'] = doc['incomplete'] - row['infra_error'] = doc['infra_error'] - row['infra_rate'] = round(doc['infra_rate'], 3) - row['fail_rate'] = round(doc['fail_rate'], 3) - output['result_anomalies'].append(row) + row = doc["_id"].copy() + row["total"] = doc["total"] + row["fail"] = doc["fail"] + row["incomplete"] = doc["incomplete"] + row["infra_error"] = doc["infra_error"] + row["infra_rate"] = round(doc["infra_rate"], 3) + row["fail_rate"] = round(doc["fail_rate"], 3) + output["result_anomalies"].append(row) for doc in error_anomalies: - row = doc['_id'].copy() - row['count'] = doc['count'] - output['error_anomalies'].append(row) + row = doc["_id"].copy() + row["count"] = doc["count"] + output["error_anomalies"].append(row) return JSONResponse(content=jsonable_encoder(output)) @@ -1355,17 +1383,17 @@ async def get_telemetry_anomalies( # Nodes def _get_node_event_data(operation, node, is_hierarchy=False): return { - 'op': operation, - 'id': str(node.id), - 'kind': node.kind, - 'name': node.name, - 'path': node.path, - 'group': node.group, - 'state': node.state, - 'result': node.result, - 'owner': node.owner, - 'data': node.data, - 'is_hierarchy': is_hierarchy, + "op": operation, + "id": str(node.id), + "kind": node.kind, + "name": node.name, + "path": node.path, + "group": node.group, + "state": node.state, + "result": node.result, + "owner": node.owner, + "data": node.data, + "is_hierarchy": is_hierarchy, } @@ -1373,22 +1401,21 @@ async def translate_null_query_params(query_params: dict): """Translate null query parameters to None""" translated = query_params.copy() for key, value in query_params.items(): - if value == 'null': + if value == "null": translated[key] = None return translated -@app.get('/node/{node_id}', response_model=Union[Node, None], - response_model_by_alias=False) +@app.get("/node/{node_id}", response_model=Union[Node, None], response_model_by_alias=False) async def get_node(node_id: str): """Get node information from the provided node id""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) try: return await db.find_by_id(Node, node_id) except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with the kind: {str(error)}" + detail=f"Node not found with the kind: {str(error)}", ) from error @@ -1403,19 +1430,19 @@ def serialize_paginated_data(model, data: list): """ serialized_data = [] for obj in data: - serialized_data.append(model(**obj).model_dump(mode='json')) + serialized_data.append(model(**obj).model_dump(mode="json")) return serialized_data -@app.get('/nodes', response_model=PageModel) +@app.get("/nodes", response_model=PageModel) async def get_nodes(request: Request): """Get all the nodes if no request parameters have passed. - Get all the matching nodes otherwise, within the pagination limit.""" - metrics.add('http_requests_total', 1) + Get all the matching nodes otherwise, within the pagination limit.""" + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) # Drop pagination parameters from query as they're already in arguments - for pg_key in ['limit', 'offset']: + for pg_key in ["limit", "offset"]: query_params.pop(pg_key, None) query_params = await translate_null_query_params(query_params) @@ -1426,15 +1453,15 @@ async def get_nodes(request: Request): model = Node translated_params = model.translate_fields(query_params) paginated_resp = await db.find_by_attributes(model, translated_params) - paginated_resp.items = serialize_paginated_data( - model, paginated_resp.items) + paginated_resp.items = serialize_paginated_data(model, paginated_resp.items) return paginated_resp except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with the kind: {str(error)}" + detail=f"Node not found with the kind: {str(error)}", ) from error + add_pagination(app) @@ -1445,7 +1472,7 @@ async def db_find_node_nonpaginated(query_params): return await db.find_by_attributes_nonpaginated(model, translated_params) -@app.get('/nodes/fast', response_model=List[Node]) +@app.get("/nodes/fast", response_model=List[Node]) async def get_nodes_fast(request: Request): """Get all the nodes if no request parameters have passed. This is non-paginated version of get_nodes. @@ -1459,28 +1486,25 @@ async def get_nodes_fast(request: Request): try: # Query using the base Node model, regardless of the specific # node type, use asyncio.wait_for with timeout 30 seconds - resp = await asyncio.wait_for( - db_find_node_nonpaginated(query_params), - timeout=15 - ) + resp = await asyncio.wait_for(db_find_node_nonpaginated(query_params), timeout=15) return resp except asyncio.TimeoutError as error: raise HTTPException( status_code=status.HTTP_504_GATEWAY_TIMEOUT, - detail=f"Timeout while fetching nodes: {str(error)}" + detail=f"Timeout while fetching nodes: {str(error)}", ) from error except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with the kind: {str(error)}" + detail=f"Node not found with the kind: {str(error)}", ) from error -@app.get('/count', response_model=int) +@app.get("/count", response_model=int) async def get_nodes_count(request: Request): """Get the count of all the nodes if no request parameters have passed. - Get the count of all the matching nodes otherwise.""" - metrics.add('http_requests_total', 1) + Get the count of all the matching nodes otherwise.""" + metrics.add("http_requests_total", 1) query_params = dict(request.query_params) query_params = await translate_null_query_params(query_params) @@ -1494,7 +1518,7 @@ async def get_nodes_count(request: Request): except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with the kind: {str(error)}" + detail=f"Node not found with the kind: {str(error)}", ) from error @@ -1504,26 +1528,29 @@ async def _verify_user_group_existence(user_groups: List[str]): if not await db.find_one(UserGroup, name=group_name): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"User group does not exist with name: {group_name}") + detail=f"User group does not exist with name: {group_name}", + ) def _translate_version_fields(node: Node): """Translate Node version fields""" data = node.data if data: - version = data.get('kernel_revision', {}).get('version') + version = data.get("kernel_revision", {}).get("version") if version: version = KernelVersion.translate_version_fields(version) - node.data['kernel_revision']['version'] = version + node.data["kernel_revision"]["version"] = version return node -@app.post('/node', response_model=Node, response_model_by_alias=False) -async def post_node(node: Node, - authorization: str | None = Header(default=None), - current_user: User = Depends(get_current_user)): +@app.post("/node", response_model=Node, response_model_by_alias=False) +async def post_node( + node: Node, + authorization: str | None = Header(default=None), + current_user: User = Depends(get_current_user), +): """Create a new node""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) # [TODO] Remove translation below once we can use it in the pipeline node = _translate_version_fields(node) @@ -1536,7 +1563,7 @@ async def post_node(node: Node, if not parent: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Parent not found with id: {node.parent}" + detail=f"Parent not found with id: {node.parent}", ) await _verify_user_group_existence(node.user_groups) @@ -1546,17 +1573,17 @@ async def post_node(node: Node, # specific kind. The concrete Node submodel (Kbuild, Checkout, etc.) # is only used for data format validation obj = await db.create(node) - data = _get_node_event_data('created', obj) + data = _get_node_event_data("created", obj) attributes = {} - if data.get('owner', None): - attributes['owner'] = data['owner'] + if data.get("owner", None): + attributes["owner"] = data["owner"] # publish_cloudevent now stores to eventhistory collection - await pubsub.publish_cloudevent('node', data, attributes) + await pubsub.publish_cloudevent("node", data, attributes) return obj def is_same_flags(old_node, new_node): - """ Compare processed_by_kcidb_bridge flags + """Compare processed_by_kcidb_bridge flags Returns True if flags are same, False otherwise """ old_flag = old_node.processed_by_kcidb_bridge @@ -1566,18 +1593,21 @@ def is_same_flags(old_node, new_node): return False -@app.put('/node/{node_id}', response_model=Node, response_model_by_alias=False) -async def put_node(node_id: str, node: Node, - user: str = Depends(authorize_user), - noevent: Optional[bool] = Query(None)): +@app.put("/node/{node_id}", response_model=Node, response_model_by_alias=False) +async def put_node( + node_id: str, + node: Node, + user: str = Depends(authorize_user), + noevent: Optional[bool] = Query(None), +): """Update an already added node""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) node.id = ObjectId(node_id) node_from_id = await db.find_by_id(Node, node_id) if not node_from_id: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with id: {node.id}" + detail=f"Node not found with id: {node.id}", ) # [TODO] Remove translation below once we can use it in the pipeline @@ -1586,20 +1616,15 @@ async def put_node(node_id: str, node: Node, # Sanity checks # Note: do not update node ownership fields, don't update 'state' # until we've checked the state transition is valid. - update_data = node.model_dump( - exclude={'owner', 'submitter', 'user_groups', 'state'}) + update_data = node.model_dump(exclude={"owner", "submitter", "user_groups", "state"}) new_node_def = node_from_id.model_copy(update=update_data) # 1- Parse and validate node to specific subtype specialized_node = parse_node_obj(new_node_def) # 2 - State transition checks - is_valid, message = specialized_node.validate_node_state_transition( - node.state) + is_valid, message = specialized_node.validate_node_state_transition(node.state) if not is_valid: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=message - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=message) # if state changes, reset processed_by_kcidb_bridge flag if node.state != new_node_def.state: new_node_def.processed_by_kcidb_bridge = False @@ -1614,31 +1639,31 @@ async def put_node(node_id: str, node: Node, # Update node in the DB obj = await db.update(new_node_def) - data = _get_node_event_data('updated', obj) + data = _get_node_event_data("updated", obj) attributes = {} - if data.get('owner', None): - attributes['owner'] = data['owner'] + if data.get("owner", None): + attributes["owner"] = data["owner"] if not noevent: # publish_cloudevent now stores to eventhistory collection - await pubsub.publish_cloudevent('node', data, attributes) + await pubsub.publish_cloudevent("node", data, attributes) return obj class NodeUpdateRequest(BaseModel): """Request model for updating multiple nodes""" + nodes: List[str] field: str value: str -@app.put('/batch/nodeset', response_model=int) -async def put_batch_nodeset(data: NodeUpdateRequest, - user: str = Depends(get_current_user)): +@app.put("/batch/nodeset", response_model=int) +async def put_batch_nodeset(data: NodeUpdateRequest, user: str = Depends(get_current_user)): """ Set a field to a value for multiple nodes TBD: Make db.bulkupdate to update multiple nodes in one go """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) updated = 0 nodes = data.nodes field = data.field @@ -1648,31 +1673,27 @@ async def put_batch_nodeset(data: NodeUpdateRequest, if not node_from_id: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with id: {node_id}" + detail=f"Node not found with id: {node_id}", ) # Verify authorization, and ignore if not permitted. if not _user_can_edit_node(user, node_from_id): continue # right now we support only field: # processed_by_kcidb_bridge, also value should be boolean - if field == 'processed_by_kcidb_bridge': - if value in ['true', 'True']: + if field == "processed_by_kcidb_bridge": + if value in ["true", "True"]: value = True - elif value in ['false', 'False']: + elif value in ["false", "False"]: value = False setattr(node_from_id, field, value) await db.update(node_from_id) updated += 1 else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Field not supported" - ) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Field not supported") return updated -async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy, - submitter: str, treeid: str): +async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy, submitter: str, treeid: str): """Set node ownership information for a hierarchy of nodes""" if not hierarchy.node.owner: hierarchy.node.owner = user.username @@ -1682,57 +1703,59 @@ async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy, await _set_node_ownership_recursively(user, node, submitter, treeid) -@app.put('/nodes/{node_id}', response_model=List[Node], - response_model_by_alias=False) +@app.put("/nodes/{node_id}", response_model=List[Node], response_model_by_alias=False) async def put_nodes( - node_id: str, nodes: Hierarchy, - authorization: str | None = Header(default=None), - user: str = Depends(authorize_user)): + node_id: str, + nodes: Hierarchy, + authorization: str | None = Header(default=None), + user: str = Depends(authorize_user), +): """Add a hierarchy of nodes to an existing root node""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) nodes.node.id = ObjectId(node_id) # Retrieve the root node from the DB and submitter node_from_id = await db.find_by_id(Node, node_id) if not node_from_id: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Node not found with id: {node_id}" + detail=f"Node not found with id: {node_id}", ) submitter = node_from_id.submitter treeid = node_from_id.treeid await _set_node_ownership_recursively(user, nodes, submitter, treeid) obj_list = await db.create_hierarchy(nodes, Node) - data = _get_node_event_data('updated', obj_list[0], True) + data = _get_node_event_data("updated", obj_list[0], True) attributes = {} - if data.get('owner', None): - attributes['owner'] = data['owner'] + if data.get("owner", None): + attributes["owner"] = data["owner"] # publish_cloudevent now stores to eventhistory collection - await pubsub.publish_cloudevent('node', data, attributes) + await pubsub.publish_cloudevent("node", data, attributes) return obj_list # ----------------------------------------------------------------------------- # Key/Value namespace enabled store -@app.get('/kv/{namespace}/{key}', response_model=Union[str, None]) -async def get_kv(namespace: str, key: str, - user: User = Depends(get_current_user)): - +@app.get("/kv/{namespace}/{key}", response_model=Union[str, None]) +async def get_kv(namespace: str, key: str, user: User = Depends(get_current_user)): """Get a key value pair from the store""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) return await db.get_kv(namespace, key) -@app.post('/kv/{namespace}/{key}', response_model=Optional[str]) -async def post_kv(namespace: str, key: str, - value: Optional[str] = Body(default=None), - user: User = Depends(get_current_user)): +@app.post("/kv/{namespace}/{key}", response_model=Optional[str]) +async def post_kv( + namespace: str, + key: str, + value: Optional[str] = Body(default=None), + user: User = Depends(get_current_user), +): """Set a key-value pair in the store namespace and key are part of the URL value is part of the request body. If value is not provided, we need to call delete_kv to remove the key. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) if not value: await db.del_kv(namespace, key) return "OK" @@ -1743,11 +1766,10 @@ async def post_kv(namespace: str, key: str, # Delete a key-value pair from the store -@app.delete('/kv/{namespace}/{key}', response_model=Optional[str]) -async def delete_kv(namespace: str, key: str, - user: User = Depends(get_current_user)): +@app.delete("/kv/{namespace}/{key}", response_model=Optional[str]) +async def delete_kv(namespace: str, key: str, user: User = Depends(get_current_user)): """Delete a key-value pair from the store""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) await db.del_kv(namespace, key) response = "Key-value pair deleted successfully" return response @@ -1756,17 +1778,21 @@ async def delete_kv(namespace: str, key: str, # ----------------------------------------------------------------------------- # Pub/Sub -@app.post('/subscribe/{channel}', response_model=Subscription) -async def subscribe(channel: str, user: User = Depends(get_current_user), - promisc: Optional[bool] = Query(None), - subscriber_id: Optional[str] = Query( - None, - description="Unique subscriber ID for durable " - "delivery. If provided, missed events " - "will be delivered on reconnection. " - "Without this, events are " - "fire-and-forget." - )): + +@app.post("/subscribe/{channel}", response_model=Subscription) +async def subscribe( + channel: str, + user: User = Depends(get_current_user), + promisc: Optional[bool] = Query(None), + subscriber_id: Optional[str] = Query( + None, + description="Unique subscriber ID for durable " + "delivery. If provided, missed events " + "will be delivered on reconnection. " + "Without this, events are " + "fire-and-forget.", + ), +): """Subscribe handler for Pub/Sub channel Args: @@ -1778,65 +1804,61 @@ async def subscribe(channel: str, user: User = Depends(get_current_user), identifier like "scheduler-prod-1" or "dashboard-main". Without subscriber_id, standard fire-and-forget pub/sub. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) options = {} if promisc: - options['promiscuous'] = promisc + options["promiscuous"] = promisc if subscriber_id: - options['subscriber_id'] = subscriber_id + options["subscriber_id"] = subscriber_id return await pubsub.subscribe(channel, user.username, options) -@app.post('/unsubscribe/{sub_id}') +@app.post("/unsubscribe/{sub_id}") async def unsubscribe(sub_id: int, user: User = Depends(get_current_user)): """Unsubscribe handler for Pub/Sub channel""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) try: await pubsub.unsubscribe(sub_id, user.username) except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Subscription id not found: {str(error)}" + detail=f"Subscription id not found: {str(error)}", ) from error except RuntimeError as error: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(error) - ) from error + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error -@app.get('/listen/{sub_id}') +@app.get("/listen/{sub_id}") async def listen(sub_id: int, user: User = Depends(get_current_user)): """Listen messages from a subscribed Pub/Sub channel""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) try: return await pubsub.listen(sub_id, user.username) except KeyError as error: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Subscription id not found: {str(error)}" + detail=f"Subscription id not found: {str(error)}", ) from error except RuntimeError as error: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error while listening to sub id {sub_id}: {str(error)}" + detail=f"Error while listening to sub id {sub_id}: {str(error)}", ) from error -@app.post('/publish/{channel}') -async def publish(event: PublishEvent, channel: str, - user: User = Depends(get_current_user)): +@app.post("/publish/{channel}") +async def publish(event: PublishEvent, channel: str, user: User = Depends(get_current_user)): """Publish an event on the provided Pub/Sub channel""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) event_dict = PublishEvent.dict(event) # 1 - Extract data and attributes from the event # 2 - Add the owner as an extra attribute # 3 - Collect all the other extra attributes, if available, without # overwriting any of the standard ones in the dict - data = event_dict.pop('data') + data = event_dict.pop("data") extra_attributes = event_dict.pop("attributes") attributes = event_dict - attributes['owner'] = user.username + attributes["owner"] = user.username if extra_attributes: for k in extra_attributes: if k not in attributes: @@ -1844,212 +1866,198 @@ async def publish(event: PublishEvent, channel: str, await pubsub.publish_cloudevent(channel, data, attributes) -@app.post('/push/{list_name}') -async def push(raw: dict, list_name: str, - user: User = Depends(get_current_user)): +@app.post("/push/{list_name}") +async def push(raw: dict, list_name: str, user: User = Depends(get_current_user)): """Push a message on the provided list""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) attributes = dict(raw) - data = attributes.pop('data') + data = attributes.pop("data") await pubsub.push_cloudevent(list_name, data, attributes) -@app.get('/pop/{list_name}') +@app.get("/pop/{list_name}") async def pop(list_name: str, user: User = Depends(get_current_user)): """Pop a message from a given list""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) return await pubsub.pop(list_name) -@app.get('/stats/subscriptions', response_model=List[SubscriptionStats]) +@app.get("/stats/subscriptions", response_model=List[SubscriptionStats]) async def stats(user: User = Depends(get_current_superuser)): """Get details of all existing subscriptions""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) return await pubsub.subscription_stats() -@app.get('/viewer') +@app.get("/viewer") async def viewer(): """Serve simple HTML page to view the API /static/viewer.html Set various no-cache tag we might update it often""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - viewer_path = os.path.join(root_dir, 'templates', 'viewer.html') - with open(viewer_path, 'r', encoding='utf-8') as file: + viewer_path = os.path.join(root_dir, "templates", "viewer.html") + with open(viewer_path, "r", encoding="utf-8") as file: # set header to text/html and no-cache stuff hdr = { - 'Content-Type': 'text/html', - 'Cache-Control': 'no-cache, no-store, must-revalidate', - 'Pragma': 'no-cache', - 'Expires': '0' + "Content-Type": "text/html", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", } return PlainTextResponse(file.read(), headers=hdr) -@app.get('/dashboard') +@app.get("/dashboard") async def dashboard(): """Serve simple HTML page to view the API dashboard.html Set various no-cache tag we might update it often""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - dashboard_path = os.path.join(root_dir, 'templates', 'dashboard.html') - with open(dashboard_path, 'r', encoding='utf-8') as file: + dashboard_path = os.path.join(root_dir, "templates", "dashboard.html") + with open(dashboard_path, "r", encoding="utf-8") as file: # set header to text/html and no-cache stuff hdr = { - 'Content-Type': 'text/html', - 'Cache-Control': 'no-cache, no-store, must-revalidate', - 'Pragma': 'no-cache', - 'Expires': '0' + "Content-Type": "text/html", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", } return PlainTextResponse(file.read(), headers=hdr) -@app.get('/manage') +@app.get("/manage") async def manage(): """Serve simple HTML page to submit custom nodes""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - manage_path = os.path.join(root_dir, 'templates', 'manage.html') - with open(manage_path, 'r', encoding='utf-8') as file: + manage_path = os.path.join(root_dir, "templates", "manage.html") + with open(manage_path, "r", encoding="utf-8") as file: # set header to text/html and no-cache stuff hdr = { - 'Content-Type': 'text/html', - 'Cache-Control': 'no-cache, no-store, must-revalidate', - 'Pragma': 'no-cache', - 'Expires': '0' + "Content-Type": "text/html", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", } return PlainTextResponse(file.read(), headers=hdr) -@app.get('/analytics') +@app.get("/analytics") async def analytics_page(): """Serve pipeline analytics dashboard with telemetry data""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - analytics_path = os.path.join(root_dir, 'templates', 'analytics.html') - with open(analytics_path, 'r', encoding='utf-8') as file: + analytics_path = os.path.join(root_dir, "templates", "analytics.html") + with open(analytics_path, "r", encoding="utf-8") as file: hdr = { - 'Content-Type': 'text/html', - 'Cache-Control': 'no-cache, no-store, must-revalidate', - 'Pragma': 'no-cache', - 'Expires': '0' + "Content-Type": "text/html", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", } return PlainTextResponse(file.read(), headers=hdr) -@app.get('/stats') +@app.get("/stats") async def stats_page(): """Serve simple HTML page to view infrastructure statistics""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - stats_path = os.path.join(root_dir, 'templates', 'stats.html') - with open(stats_path, 'r', encoding='utf-8') as file: + stats_path = os.path.join(root_dir, "templates", "stats.html") + with open(stats_path, "r", encoding="utf-8") as file: # set header to text/html and no-cache stuff hdr = { - 'Content-Type': 'text/html', - 'Cache-Control': 'no-cache, no-store, must-revalidate', - 'Pragma': 'no-cache', - 'Expires': '0' + "Content-Type": "text/html", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0", } return PlainTextResponse(file.read(), headers=hdr) -@app.get('/icons/{icon_name}') +@app.get("/icons/{icon_name}") async def icons(icon_name: str): """Serve icons from /static/icons""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) - if not re.match(r'^[A-Za-z0-9_.-]+\.png$', icon_name): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid icon name" - ) - icon_path = os.path.join(root_dir, 'templates', icon_name) + if not re.match(r"^[A-Za-z0-9_.-]+\.png$", icon_name): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid icon name") + icon_path = os.path.join(root_dir, "templates", icon_name) return FileResponse(icon_path) -@app.get('/static/css/{filename}') +@app.get("/static/css/{filename}") async def serve_css(filename: str): """Serve CSS files from api/static/css/""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) print(f"[CSS] Request for: {filename}") print(f"[CSS] root_dir: {root_dir}") # Security: only allow safe filenames - if not re.match(r'^[A-Za-z0-9_.-]+\.css$', filename): + if not re.match(r"^[A-Za-z0-9_.-]+\.css$", filename): print(f"[CSS] Invalid filename pattern: {filename}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid filename" - ) - file_path = os.path.join(root_dir, 'static', 'css', filename) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename") + file_path = os.path.join(root_dir, "static", "css", filename) print(f"[CSS] Looking for file at: {file_path}") print(f"[CSS] File exists: {os.path.isfile(file_path)}") if not os.path.isfile(file_path): print(f"[CSS] File not found: {file_path}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") print(f"[CSS] Serving file: {file_path}") return FileResponse( file_path, media_type="text/css", headers={ - 'Cache-Control': 'public, max-age=3600', # Cache for 1 hour - } + "Cache-Control": "public, max-age=3600", # Cache for 1 hour + }, ) -@app.get('/static/js/{filename}') +@app.get("/static/js/{filename}") async def serve_js(filename: str): """Serve JavaScript files from api/static/js/""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) root_dir = os.path.dirname(os.path.abspath(__file__)) print(f"[JS] Request for: {filename}") print(f"[JS] root_dir: {root_dir}") # Security: only allow safe filenames - if not re.match(r'^[A-Za-z0-9_.-]+\.js$', filename): + if not re.match(r"^[A-Za-z0-9_.-]+\.js$", filename): print(f"[JS] Invalid filename pattern: {filename}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid filename" - ) - file_path = os.path.join(root_dir, 'static', 'js', filename) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid filename") + file_path = os.path.join(root_dir, "static", "js", filename) print(f"[JS] Looking for file at: {file_path}") print(f"[JS] File exists: {os.path.isfile(file_path)}") if not os.path.isfile(file_path): print(f"[JS] File not found: {file_path}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found" - ) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") print(f"[JS] Serving file: {file_path}") return FileResponse( file_path, media_type="application/javascript", headers={ - 'Cache-Control': 'public, max-age=3600', # Cache for 1 hour - } + "Cache-Control": "public, max-age=3600", # Cache for 1 hour + }, ) -@app.get('/metrics') +@app.get("/metrics") async def get_metrics(): """Get metrics""" - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) # return metrics as plaintext in prometheus format all_metrics = metrics.all() - response = '' + response = "" for key, value in all_metrics.items(): response += f'{key}{{instance="api"}} {value}\n' return PlainTextResponse(response) -@app.get('/maintenance/purge-old-nodes') -async def purge_handler(current_user: User = Depends(get_current_superuser), - days: int = 180, - batch_size: int = 1000): +@app.get("/maintenance/purge-old-nodes") +async def purge_handler( + current_user: User = Depends(get_current_superuser), + days: int = 180, + batch_size: int = 1000, +): """Purge old nodes from the database This is a maintenance operation and should be performed only by superusers. @@ -2057,22 +2065,24 @@ async def purge_handler(current_user: User = Depends(get_current_superuser), - days: Number of days to keep nodes, default is 180. - batch_size: Number of nodes to delete in one batch, default is 1000. """ - metrics.add('http_requests_total', 1) + metrics.add("http_requests_total", 1) return await purge_old_nodes(age_days=days, batch_size=batch_size) -versioned_app = VersionedFastAPI(app, - version_format='{major}', - prefix_format='/v{major}', - enable_latest=True, - default_version=(0, 0), - on_startup=[ - pubsub_startup, - create_indexes, - initialize_beanie, - ensure_legacy_node_editors, - start_background_tasks, - ]) +versioned_app = VersionedFastAPI( + app, + version_format="{major}", + prefix_format="/v{major}", + enable_latest=True, + default_version=(0, 0), + on_startup=[ + pubsub_startup, + create_indexes, + initialize_beanie, + ensure_legacy_node_editors, + start_background_tasks, + ], +) # traceback_exception_handler is a global exception handler that will be @@ -2083,7 +2093,7 @@ def traceback_exception_handler(request: Request, exc: Exception): traceback.print_exception(type(exc), exc, exc.__traceback__) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"message": "Internal server error, check container logs"} + content={"message": "Internal server error, check container logs"}, ) @@ -2092,31 +2102,25 @@ def traceback_exception_handler(request: Request, exc: Exception): # https://github.com/DeanWay/fastapi-versioning/issues/30 for sub_app in versioned_app.routes: if hasattr(sub_app.app, "add_exception_handler"): - sub_app.app.add_exception_handler( - ValueError, value_error_exception_handler - ) - sub_app.app.add_exception_handler( - errors.InvalidId, invalid_id_exception_handler - ) + sub_app.app.add_exception_handler(ValueError, value_error_exception_handler) + sub_app.app.add_exception_handler(errors.InvalidId, invalid_id_exception_handler) # print traceback for all other exceptions - sub_app.app.add_exception_handler( - Exception, traceback_exception_handler - ) + sub_app.app.add_exception_handler(Exception, traceback_exception_handler) @versioned_app.middleware("http") async def redirect_http_requests(request: Request, call_next): """Redirect request with version prefix when no version is provided""" response = None - path = request.scope['path'] - match = re.match(r'^/(v[\d.]+)', path) + path = request.scope["path"] + match = re.match(r"^/(v[\d.]+)", path) if match: prefix = match.group(1) if prefix not in API_VERSIONS: response = PlainTextResponse( f"Unsupported API version: {prefix}", - status_code=status.HTTP_400_BAD_REQUEST + status_code=status.HTTP_400_BAD_REQUEST, ) - elif not path.startswith('/latest'): - request.scope['path'] = '/latest' + path + elif not path.startswith("/latest"): + request.scope["path"] = "/latest" + path return response or await call_next(request) diff --git a/api/maintenance.py b/api/maintenance.py index 2d8a436f..925b83c5 100644 --- a/api/maintenance.py +++ b/api/maintenance.py @@ -9,8 +9,10 @@ This module provides maintenance utilities for the KernelCI API, including functions to purge old nodes from the database and manage MongoDB connections. """ + import datetime import os + from pymongo import MongoClient DEFAULT_MONGO_SERVICE = "mongodb://db:27017" @@ -27,9 +29,7 @@ def purge_ids(db, collection, ids): ids (list): List of document IDs to delete. """ print("Purging", len(ids), "from", collection) - db[collection].delete_many({ - "_id": {"$in": ids} - }) + db[collection].delete_many({"_id": {"$in": ids}}) def connect_to_db(): @@ -58,9 +58,7 @@ async def purge_old_nodes(age_days=180, batch_size=1000): """ date_end = datetime.datetime.today() - datetime.timedelta(days=age_days) db = connect_to_db() - nodes = db["node"].find({ - "created": {"$lt": date_end} - }) + nodes = db["node"].find({"created": {"$lt": date_end}}) # We need to delete node in chunks of {batch_size} # to not block the main thread for too long deleted = 0 @@ -74,9 +72,5 @@ async def purge_old_nodes(age_days=180, batch_size=1000): if del_batch: deleted += len(del_batch) purge_ids(db, "node", del_batch) - db = { - 'response': 'ok', - 'deleted': deleted, - 'age_days': age_days - } + db = {"response": "ok", "deleted": deleted, "age_days": age_days} return db diff --git a/api/metrics.py b/api/metrics.py index 150db54a..a722e969 100644 --- a/api/metrics.py +++ b/api/metrics.py @@ -8,45 +8,46 @@ import threading -class Metrics(): - ''' +class Metrics: + """ Class to store and update various metrics - ''' + """ + def __init__(self): - ''' + """ Initialize metrics dictionary and lock - ''' + """ self.metrics = {} - self.metrics['http_requests_total'] = 0 + self.metrics["http_requests_total"] = 0 self.lock = threading.Lock() # Various internal metrics def update(self): - ''' + """ Update metrics (reserved for future use) - ''' + """ def add(self, key, value): - ''' + """ Add a value to a metric - ''' + """ with self.lock: if key not in self.metrics: self.metrics[key] = 0 self.metrics[key] += value def get(self, key): - ''' + """ Get the value of a metric - ''' + """ self.update() with self.lock: return self.metrics.get(key, 0) def all(self): - ''' + """ Get all the metrics - ''' + """ self.update() with self.lock: return self.metrics diff --git a/api/models.py b/api/models.py index 9e8ad8c8..179a1b9b 100644 --- a/api/models.py +++ b/api/models.py @@ -12,7 +12,18 @@ """Server-side model definitions""" from datetime import datetime -from typing import Optional, TypeVar, List +from typing import List, Optional, TypeVar + +from beanie import ( + Document, + Indexed, + PydanticObjectId, +) +from fastapi import Query +from fastapi_pagination import LimitOffsetPage, LimitOffsetParams +from fastapi_users import schemas +from fastapi_users.db import BeanieBaseUser +from kernelci.api.models_base import DatabaseModel, ModelId from pydantic import ( BaseModel, EmailStr, @@ -20,45 +31,25 @@ field_validator, ) from typing_extensions import Annotated -from fastapi import Query -from fastapi_pagination import LimitOffsetPage, LimitOffsetParams -from fastapi_users.db import BeanieBaseUser -from fastapi_users import schemas -from beanie import ( - Indexed, - Document, - PydanticObjectId, -) -from kernelci.api.models_base import DatabaseModel, ModelId - # PubSub model definitions + class Subscription(BaseModel): """Pub/Sub subscription object model""" - id: int = Field( - description='Subscription ID' - ) - channel: str = Field( - description='Subscription channel name' - ) - user: str = Field( - description=("Username of the user that created the " - "subscription (owner)") - ) - promiscuous: bool = Field( - description='Listen to all users messages', - default=False) + + id: int = Field(description="Subscription ID") + channel: str = Field(description="Subscription channel name") + user: str = Field(description=("Username of the user that created the subscription (owner)")) + promiscuous: bool = Field(description="Listen to all users messages", default=False) class SubscriptionStats(Subscription): """Pub/Sub subscription statistics object model""" - created: datetime = Field( - description='Timestamp of connection creation' - ) + + created: datetime = Field(description="Timestamp of connection creation") last_poll: Optional[datetime] = Field( - default=None, - description='Timestamp when connection last polled for data' + default=None, description="Timestamp when connection last polled for data" ) @@ -66,71 +57,63 @@ class SubscriptionStats(Subscription): # Note: Event storage uses EventHistory model from kernelci-core # (stored in 'eventhistory' collection with sequence_id, channel, owner fields) + class SubscriberState(BaseModel): """Tracks subscriber position for durable event delivery Only created when subscriber_id is provided during subscription. Enables catch-up on missed events after reconnection. """ - subscriber_id: str = Field( - description='Unique subscriber identifier (client-provided)' - ) - channel: str = Field( - description='Subscribed channel name' - ) - user: str = Field( - description='Username of subscriber (for ownership validation)' - ) - promiscuous: bool = Field( - default=False, - description='If true, receive all messages regardless of owner' - ) + + subscriber_id: str = Field(description="Unique subscriber identifier (client-provided)") + channel: str = Field(description="Subscribed channel name") + user: str = Field(description="Username of subscriber (for ownership validation)") + promiscuous: bool = Field(default=False, description="If true, receive all messages regardless of owner") last_event_id: int = Field( - default=0, - description='Last acknowledged event ID (implicit ACK on next poll)' + default=0, description="Last acknowledged event ID (implicit ACK on next poll)" ) created_at: datetime = Field( - default_factory=datetime.utcnow, - description='Subscription creation timestamp' + default_factory=datetime.utcnow, description="Subscription creation timestamp" ) last_poll: Optional[datetime] = Field( - default=None, - description='Last poll timestamp (used for stale cleanup)' + default=None, description="Last poll timestamp (used for stale cleanup)" ) # User model definitions + class UserGroup(DatabaseModel): """API model to group associated user accounts""" - name: str = Field( - description="User group name" - ) + + name: str = Field(description="User group name") @classmethod def get_indexes(cls): """Get an index to bind unique constraint to group name""" return [ - cls.Index('name', {'unique': True}), + cls.Index("name", {"unique": True}), ] class UserGroupCreateRequest(BaseModel): """Create user group request schema for API router""" + name: str = Field(description="User group name") -class User(BeanieBaseUser, Document, # pylint: disable=too-many-ancestors - DatabaseModel): +class User( + BeanieBaseUser, + Document, # pylint: disable=too-many-ancestors + DatabaseModel, +): """API User model""" + username: Annotated[str, Indexed(unique=True)] - groups: List[UserGroup] = Field( - default=[], - description="A list of groups that the user belongs to" - ) + groups: List[UserGroup] = Field(default=[], description="A list of groups that the user belongs to") - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = {group.name for group in groups} if len(unique_names) != len(groups): @@ -139,6 +122,7 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class Settings(BeanieBaseUser.Settings): """Configurations""" + # MongoDB collection name for model name = "user" @@ -146,17 +130,18 @@ class Settings(BeanieBaseUser.Settings): def get_indexes(cls): """Get indices""" return [ - cls.Index('email', {'unique': True}), + cls.Index("email", {"unique": True}), ] class UserRead(schemas.BaseUser[PydanticObjectId], ModelId): """Schema for reading a user""" + username: Annotated[str, Indexed(unique=True)] groups: List[UserGroup] = Field(default=[]) - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = {group.name for group in groups} if len(unique_names) != len(groups): @@ -166,11 +151,12 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserCreateRequest(schemas.BaseUserCreate): """Create user request schema for API router""" + username: Annotated[str, Indexed(unique=True)] groups: List[str] = Field(default=[]) - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = set(groups) if len(unique_names) != len(groups): @@ -180,11 +166,12 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserCreate(schemas.BaseUserCreate): """Schema used for sending create user request to 'fastapi-users' router""" + username: Annotated[str, Indexed(unique=True)] groups: List[UserGroup] = Field(default=[]) - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = {group.name for group in groups} if len(unique_names) != len(groups): @@ -194,12 +181,12 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserUpdateRequest(schemas.BaseUserUpdate): """Update user request schema for API router""" - username: Annotated[Optional[str], Indexed(unique=True), - Field(default=None)] + + username: Annotated[Optional[str], Indexed(unique=True), Field(default=None)] groups: List[str] = Field(default=[]) - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = set(groups) if len(unique_names) != len(groups): @@ -209,12 +196,12 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument class UserUpdate(schemas.BaseUserUpdate): """Schema used for sending update user request to 'fastapi-users' router""" - username: Annotated[Optional[str], Indexed(unique=True), - Field(default=None)] + + username: Annotated[Optional[str], Indexed(unique=True), Field(default=None)] groups: List[UserGroup] = Field(default=[]) - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = {group.name for group in groups} if len(unique_names) != len(groups): @@ -224,6 +211,7 @@ def validate_groups(cls, groups): # pylint: disable=no-self-argument # Invite-only user onboarding models + class UserInviteRequest(BaseModel): """Admin invite request schema for API router""" @@ -235,8 +223,8 @@ class UserInviteRequest(BaseModel): return_token: bool = False resend_if_exists: bool = False - @field_validator('groups') - def validate_groups(cls, groups): # pylint: disable=no-self-argument + @field_validator("groups") + def validate_groups(cls, groups): # pylint: disable=no-self-argument """Unique group constraint""" unique_names = set(groups) if len(unique_names) != len(groups): @@ -271,6 +259,7 @@ class InviteUrlResponse(BaseModel): # Pagination models + class CustomLimitOffsetParams(LimitOffsetParams): """Model to set custom constraint on limit diff --git a/api/pubsub.py b/api/pubsub.py index 83cdfeb7..23cca2a4 100644 --- a/api/pubsub.py +++ b/api/pubsub.py @@ -5,15 +5,16 @@ """Pub/Sub implementation""" -import logging import asyncio - import json +import logging from datetime import datetime, timedelta -from redis import asyncio as aioredis + from cloudevents.http import CloudEvent, to_json -from .models import Subscription, SubscriptionStats +from redis import asyncio as aioredis + from .config import PubSubSettings +from .models import Subscription, SubscriptionStats logger = logging.getLogger(__name__) @@ -27,7 +28,7 @@ class PubSub: available from the `docker-compose` setup. """ - ID_KEY = 'kernelci-api-pubsub-id' + ID_KEY = "kernelci-api-pubsub-id" @classmethod async def create(cls, *args, **kwargs): @@ -42,10 +43,7 @@ def __init__(self, host=None, db_number=None): host = self._settings.redis_host if db_number is None: db_number = self._settings.redis_db_number - self._redis = aioredis.from_url( - 'redis://' + host + '/' + str(db_number), - health_check_interval=30 - ) + self._redis = aioredis.from_url("redis://" + host + "/" + str(db_number), health_check_interval=30) # self._subscriptions is a dict that matches a subscription id # (key) with a Subscription object ('sub') and a redis # PubSub object ('redis_sub'). For instance: @@ -65,8 +63,7 @@ def _start_keep_alive_timer(self): return if not self._keep_alive_timer or self._keep_alive_timer.done(): loop = asyncio.get_running_loop() - self._keep_alive_timer = asyncio.run_coroutine_threadsafe( - self._keep_alive(), loop) + self._keep_alive_timer = asyncio.run_coroutine_threadsafe(self._keep_alive(), loop) async def _keep_alive(self): while True: @@ -81,7 +78,7 @@ async def _keep_alive(self): def _update_channels(self): self._channels = set() for sub in self._subscriptions.values(): - for channel in sub['redis_sub'].channels.keys(): + for channel in sub["redis_sub"].channels.keys(): self._channels.add(channel.decode()) async def subscribe(self, channel, user, options=None): @@ -93,13 +90,15 @@ async def subscribe(self, channel, user, options=None): async with self._lock: redis_sub = self._redis.pubsub() sub = Subscription(id=sub_id, channel=channel, user=user) - if options and options.get('promiscuous'): + if options and options.get("promiscuous"): sub.promiscuous = True await redis_sub.subscribe(channel) - self._subscriptions[sub_id] = {'redis_sub': redis_sub, - 'sub': sub, - 'created': datetime.utcnow(), - 'last_poll': None} + self._subscriptions[sub_id] = { + "redis_sub": redis_sub, + "sub": sub, + "created": datetime.utcnow(), + "last_poll": None, + } self._update_channels() self._start_keep_alive_timer() return sub @@ -115,14 +114,13 @@ async def unsubscribe(self, sub_id, user=None): # Only allow a user to unsubscribe its own # subscriptions. One exception: let an anonymous (internal) # call to this function to unsubscribe any subscription - if user and user != sub['sub'].user: - raise RuntimeError(f"Subscription {sub_id} " - f"not owned by {user}") + if user and user != sub["sub"].user: + raise RuntimeError(f"Subscription {sub_id} not owned by {user}") self._subscriptions.pop(sub_id) self._update_channels() - await sub['redis_sub'].unsubscribe() + await sub["redis_sub"].unsubscribe() # shut down pubsub connection - await sub['redis_sub'].close() + await sub["redis_sub"].close() async def listen(self, sub_id, user=None): """Listen for Pub/Sub messages @@ -136,23 +134,20 @@ async def listen(self, sub_id, user=None): # Only allow a user to listen to its own subscriptions. One # exception: let an anonymous (internal) call to this function # to listen to any subscription - if user and user != sub['sub'].user: - raise RuntimeError(f"Subscription {sub_id} " - f"not owned by {user}") + if user and user != sub["sub"].user: + raise RuntimeError(f"Subscription {sub_id} not owned by {user}") while True: - self._subscriptions[sub_id]['last_poll'] = datetime.utcnow() + self._subscriptions[sub_id]["last_poll"] = datetime.utcnow() msg = None try: - msg = await sub['redis_sub'].get_message( - ignore_subscribe_messages=True, timeout=1.0 - ) + msg = await sub["redis_sub"].get_message(ignore_subscribe_messages=True, timeout=1.0) except aioredis.ConnectionError: async with self._lock: - channel = self._subscriptions[sub_id]['sub'].channel + channel = self._subscriptions[sub_id]["sub"].channel new_redis_sub = self._redis.pubsub() await new_redis_sub.subscribe(channel) - self._subscriptions[sub_id]['redis_sub'] = new_redis_sub - sub['redis_sub'] = new_redis_sub + self._subscriptions[sub_id]["redis_sub"] = new_redis_sub + sub["redis_sub"] = new_redis_sub continue except aioredis.RedisError as exc: # log the error and continue @@ -161,14 +156,14 @@ async def listen(self, sub_id, user=None): if msg is None: continue - msg_data = json.loads(msg['data']) + msg_data = json.loads(msg["data"]) # If the subscription is promiscuous, return the message # without checking the owner - if sub['sub'].promiscuous: + if sub["sub"].promiscuous: return msg # If the subscription is not promiscuous, check the owner of the # message - if 'owner' in msg_data and msg_data['owner'] != sub['sub'].user: + if "owner" in msg_data and msg_data["owner"] != sub["sub"].user: continue return msg @@ -194,7 +189,7 @@ async def pop(self, list_name): """ while True: msg = await self._redis.blpop(list_name, timeout=1.0) - data = json.loads(msg[1].decode('utf-8')) if msg else None + data = json.loads(msg[1].decode("utf-8")) if msg else None if data is not None: return data @@ -209,10 +204,10 @@ async def publish_cloudevent(self, channel, data, attributes=None): """ if not attributes: attributes = {} - if not attributes.get('type'): - attributes['type'] = "api.kernelci.org" - if not attributes.get('source'): - attributes['source'] = self._settings.cloud_events_source + if not attributes.get("type"): + attributes["type"] = "api.kernelci.org" + if not attributes.get("source"): + attributes["source"] = self._settings.cloud_events_source event = CloudEvent(attributes=attributes, data=data) await self.publish(channel, to_json(event)) @@ -236,13 +231,13 @@ async def subscription_stats(self): """Get existing subscription details""" subscriptions = [] for _, subscription in self._subscriptions.items(): - sub = subscription['sub'] + sub = subscription["sub"] stats = SubscriptionStats( id=sub.id, channel=sub.channel, user=sub.user, - created=subscription['created'], - last_poll=subscription['last_poll'] + created=subscription["created"], + last_poll=subscription["last_poll"], ) subscriptions.append(stats) return subscriptions @@ -254,7 +249,7 @@ async def cleanup_stale_subscriptions(self, max_age_minutes=30): async with self._lock: for sub_id, sub_data in self._subscriptions.items(): - last_poll = sub_data.get('last_poll') + last_poll = sub_data.get("last_poll") if last_poll and last_poll < cutoff: stale_ids.append(sub_id) diff --git a/api/pubsub_mongo.py b/api/pubsub_mongo.py index 0188f72b..d842203b 100644 --- a/api/pubsub_mongo.py +++ b/api/pubsub_mongo.py @@ -17,20 +17,20 @@ Events are stored in MongoDB with TTL for automatic cleanup. """ -import logging import asyncio import json +import logging import os from datetime import datetime, timedelta -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional -from redis import asyncio as aioredis from cloudevents.http import CloudEvent, to_json -from pymongo import ASCENDING, WriteConcern from motor import motor_asyncio +from pymongo import ASCENDING, WriteConcern +from redis import asyncio as aioredis -from .models import Subscription, SubscriptionStats, SubscriberState from .config import PubSubSettings +from .models import SubscriberState, Subscription, SubscriptionStats logger = logging.getLogger(__name__) @@ -50,13 +50,13 @@ class PubSub: # pylint: disable=too-many-instance-attributes independently. """ - ID_KEY = 'kernelci-api-pubsub-id' - EVENT_SEQ_KEY = 'kernelci-api-event-seq' + ID_KEY = "kernelci-api-pubsub-id" + EVENT_SEQ_KEY = "kernelci-api-event-seq" # Collection names # Use existing eventhistory collection for unified event storage - EVENT_HISTORY_COLLECTION = 'eventhistory' - SUBSCRIBER_STATE_COLLECTION = 'subscriber_state' + EVENT_HISTORY_COLLECTION = "eventhistory" + SUBSCRIBER_STATE_COLLECTION = "subscriber_state" # Default settings DEFAULT_MAX_CATCHUP_EVENTS = 1000 @@ -68,24 +68,19 @@ async def create(cls, *args, mongo_client=None, **kwargs): await pubsub._init() return pubsub - def __init__(self, mongo_client=None, host=None, db_number=None, - mongo_db_name='kernelci'): + def __init__(self, mongo_client=None, host=None, db_number=None, mongo_db_name="kernelci"): self._settings = PubSubSettings() if host is None: host = self._settings.redis_host if db_number is None: db_number = self._settings.redis_db_number - self._redis = aioredis.from_url( - 'redis://' + host + '/' + str(db_number), - health_check_interval=30 - ) + self._redis = aioredis.from_url("redis://" + host + "/" + str(db_number), health_check_interval=30) # MongoDB setup if mongo_client is None: - mongo_service = os.getenv('MONGO_SERVICE') or 'mongodb://db:27017' - self._mongo_client = motor_asyncio.AsyncIOMotorClient( - mongo_service) + mongo_service = os.getenv("MONGO_SERVICE") or "mongodb://db:27017" + self._mongo_client = motor_asyncio.AsyncIOMotorClient(mongo_service) else: self._mongo_client = mongo_client self._mongo_db = self._mongo_client[mongo_db_name] @@ -93,8 +88,7 @@ def __init__(self, mongo_client=None, host=None, db_number=None, # In-memory subscription tracking (for fire-and-forget mode) # {sub_id: {'sub': Subscription, 'redis_sub': PubSub, # 'subscriber_id': str|None, ...}} - self._subscriptions: Dict[int, Dict[str, Any]] = \ - {} + self._subscriptions: Dict[int, Dict[str, Any]] = {} self._channels = set() self._lock = asyncio.Lock() self._keep_alive_timer = None @@ -121,8 +115,7 @@ async def _migrate_eventhistory_if_needed(self): # Check if collection exists collections = await self._mongo_db.list_collection_names() if self.EVENT_HISTORY_COLLECTION not in collections: - logger.info( - "eventhistory collection does not exist, will be created") + logger.info("eventhistory collection does not exist, will be created") return # Check existing indexes @@ -134,19 +127,16 @@ async def _migrate_eventhistory_if_needed(self): for _, index_info in indexes.items(): # Check for old 24h TTL - if 'expireAfterSeconds' in index_info: - ttl = index_info['expireAfterSeconds'] + if "expireAfterSeconds" in index_info: + ttl = index_info["expireAfterSeconds"] if ttl == 86400: old_format_detected = True - logger.warning( - "Detected old eventhistory format (24h TTL). " - "Migration required." - ) + logger.warning("Detected old eventhistory format (24h TTL). Migration required.") # Check for new sequence_id index - if 'key' in index_info: - keys = [k[0] for k in index_info['key']] - if 'sequence_id' in keys: + if "key" in index_info: + keys = [k[0] for k in index_info["key"]] + if "sequence_id" in keys: has_sequence_index = True if old_format_detected and not has_sequence_index: @@ -164,16 +154,13 @@ async def _migrate_eventhistory(self, col): # Drop all indexes except _id indexes = await col.index_information() for index_name in indexes: - if index_name != '_id_': + if index_name != "_id_": logger.info("Dropping index: %s", index_name) await col.drop_index(index_name) # Drop all documents (they lack required fields) result = await col.delete_many({}) - logger.info( - "Deleted %d old eventhistory documents", - result.deleted_count - ) + logger.info("Deleted %d old eventhistory documents", result.deleted_count) logger.info("eventhistory migration complete") @@ -196,22 +183,14 @@ async def _ensure_indexes(self): # Compound index for filtered event queries (kind + timestamp) await event_col.create_index( - [('data.kind', ASCENDING), ('timestamp', ASCENDING)], - name='kind_timestamp' + [("data.kind", ASCENDING), ("timestamp", ASCENDING)], name="kind_timestamp" ) # Subscriber state indexes # Unique index on subscriber_id - await sub_col.create_index( - 'subscriber_id', - unique=True, - name='unique_subscriber_id' - ) + await sub_col.create_index("subscriber_id", unique=True, name="unique_subscriber_id") # Index for stale cleanup - await sub_col.create_index( - 'last_poll', - name='last_poll' - ) + await sub_col.create_index("last_poll", name="last_poll") def _start_keep_alive_timer(self): """Start keep-alive timer for Redis pub/sub connections""" @@ -219,8 +198,7 @@ def _start_keep_alive_timer(self): return if not self._keep_alive_timer or self._keep_alive_timer.done(): loop = asyncio.get_running_loop() - self._keep_alive_timer = asyncio.run_coroutine_threadsafe( - self._keep_alive(), loop) + self._keep_alive_timer = asyncio.run_coroutine_threadsafe(self._keep_alive(), loop) async def _keep_alive(self): """Send periodic BEEP to keep connections alive""" @@ -237,8 +215,8 @@ async def _keep_alive(self): async def _publish_keepalive(self, channel: str, data: str): """Publish keep-alive message (Redis only, no MongoDB storage)""" attributes = { - 'type': "api.kernelci.org", - 'source': self._settings.cloud_events_source, + "type": "api.kernelci.org", + "source": self._settings.cloud_events_source, } event = CloudEvent(attributes=attributes, data=data) await self._redis.publish(channel, to_json(event)) @@ -247,16 +225,15 @@ def _update_channels(self): """Update tracked channels from active subscriptions""" self._channels = set() for sub in self._subscriptions.values(): - if sub.get('redis_sub'): - for channel in sub['redis_sub'].channels.keys(): + if sub.get("redis_sub"): + for channel in sub["redis_sub"].channels.keys(): self._channels.add(channel.decode()) async def _get_next_event_id(self) -> int: """Get next sequential event ID from Redis""" return await self._redis.incr(self.EVENT_SEQ_KEY) - async def _store_event(self, channel: str, data: Dict[str, Any], - owner: Optional[str] = None) -> int: + async def _store_event(self, channel: str, data: Dict[str, Any], owner: Optional[str] = None) -> int: """Store event in eventhistory collection and return sequence ID Uses the same collection as /events API endpoint (EventHistory model). @@ -264,51 +241,42 @@ async def _store_event(self, channel: str, data: Dict[str, Any], """ sequence_id = await self._get_next_event_id() event_doc = { - 'timestamp': datetime.utcnow(), - 'sequence_id': sequence_id, - 'channel': channel, - 'owner': owner, - 'data': data, + "timestamp": datetime.utcnow(), + "sequence_id": sequence_id, + "channel": channel, + "owner": owner, + "data": data, } col = self._mongo_db[self.EVENT_HISTORY_COLLECTION] # Use w=1 for acknowledged writes (durability) - await col.with_options( - write_concern=WriteConcern(w=1) - ).insert_one(event_doc) + await col.with_options(write_concern=WriteConcern(w=1)).insert_one(event_doc) return sequence_id - async def _get_subscriber_state( - self, subscriber_id: str) -> Optional[Dict]: + async def _get_subscriber_state(self, subscriber_id: str) -> Optional[Dict]: """Get subscriber state from MongoDB""" col = self._mongo_db[self.SUBSCRIBER_STATE_COLLECTION] - return await col.find_one({'subscriber_id': subscriber_id}) + return await col.find_one({"subscriber_id": subscriber_id}) - async def _update_subscriber_state(self, subscriber_id: str, - last_event_id: int, - last_poll: datetime = None): + async def _update_subscriber_state( + self, subscriber_id: str, last_event_id: int, last_poll: datetime = None + ): """Update subscriber's last_event_id and last_poll""" col = self._mongo_db[self.SUBSCRIBER_STATE_COLLECTION] - update = {'last_event_id': last_event_id} + update = {"last_event_id": last_event_id} if last_poll: - update['last_poll'] = last_poll - await col.update_one( - {'subscriber_id': subscriber_id}, - {'$set': update} - ) + update["last_poll"] = last_poll + await col.update_one({"subscriber_id": subscriber_id}, {"$set": update}) @staticmethod def _decode_redis_message(msg: Dict) -> Dict: """Decode Redis message bytes to strings for JSON serialization""" return { - 'type': msg.get('type'), - 'pattern': (msg.get('pattern').decode('utf-8') - if msg.get('pattern') else None), - 'channel': (msg['channel'].decode('utf-8') - if isinstance(msg['channel'], bytes) - else msg['channel']), - 'data': (msg['data'].decode('utf-8') - if isinstance(msg['data'], bytes) - else msg['data']), + "type": msg.get("type"), + "pattern": (msg.get("pattern").decode("utf-8") if msg.get("pattern") else None), + "channel": ( + msg["channel"].decode("utf-8") if isinstance(msg["channel"], bytes) else msg["channel"] + ), + "data": (msg["data"].decode("utf-8") if isinstance(msg["data"], bytes) else msg["data"]), } def _eventhistory_to_cloudevent(self, event: Dict) -> str: @@ -318,20 +286,25 @@ def _eventhistory_to_cloudevent(self, event: Dict) -> str: for consistent delivery format between catch-up and real-time events. """ attributes = { - 'type': 'api.kernelci.org', - 'source': self._settings.cloud_events_source, + "type": "api.kernelci.org", + "source": self._settings.cloud_events_source, } - if event.get('owner'): - attributes['owner'] = event['owner'] + if event.get("owner"): + attributes["owner"] = event["owner"] - ce = CloudEvent(attributes=attributes, data=event.get('data', {})) - return to_json(ce).decode('utf-8') + ce = CloudEvent(attributes=attributes, data=event.get("data", {})) + return to_json(ce).decode("utf-8") # pylint: disable=too-many-arguments - async def _get_missed_events(self, channel: str, after_seq_id: int, *, - owner_filter: Optional[str] = None, - promiscuous: bool = False, - limit: int = None) -> List[Dict]: + async def _get_missed_events( + self, + channel: str, + after_seq_id: int, + *, + owner_filter: Optional[str] = None, + promiscuous: bool = False, + limit: int = None, + ) -> List[Dict]: """Get events after a given sequence ID for catch-up Queries the eventhistory collection used by /events API. @@ -341,24 +314,20 @@ async def _get_missed_events(self, channel: str, after_seq_id: int, *, limit = self.DEFAULT_MAX_CATCHUP_EVENTS col = self._mongo_db[self.EVENT_HISTORY_COLLECTION] - query = { - 'channel': channel, - 'sequence_id': {'$gt': after_seq_id} - } + query = {"channel": channel, "sequence_id": {"$gt": after_seq_id}} # If not promiscuous, filter by owner if not promiscuous and owner_filter: - query['$or'] = [ - {'owner': owner_filter}, - {'owner': None}, - {'owner': {'$exists': False}} + query["$or"] = [ + {"owner": owner_filter}, + {"owner": None}, + {"owner": {"$exists": False}}, ] - cursor = col.find(query).sort('sequence_id', ASCENDING).limit(limit) + cursor = col.find(query).sort("sequence_id", ASCENDING).limit(limit) return await cursor.to_list(length=limit) - async def subscribe(self, channel: str, user: str, - options: Optional[Dict] = None) -> Subscription: + async def subscribe(self, channel: str, user: str, options: Optional[Dict] = None) -> Subscription: """Subscribe to a Pub/Sub channel Args: @@ -372,27 +341,22 @@ async def subscribe(self, channel: str, user: str, Subscription object with id, channel, user, promiscuous fields """ sub_id = await self._redis.incr(self.ID_KEY) - subscriber_id = options.get('subscriber_id') if options else None - promiscuous = options.get('promiscuous', False) if options else False + subscriber_id = options.get("subscriber_id") if options else None + promiscuous = options.get("promiscuous", False) if options else False async with self._lock: redis_sub = self._redis.pubsub() - sub = Subscription( - id=sub_id, - channel=channel, - user=user, - promiscuous=promiscuous - ) + sub = Subscription(id=sub_id, channel=channel, user=user, promiscuous=promiscuous) await redis_sub.subscribe(channel) self._subscriptions[sub_id] = { - 'redis_sub': redis_sub, - 'sub': sub, - 'subscriber_id': subscriber_id, - 'created': datetime.utcnow(), - 'last_poll': None, - 'pending_catchup': [], # Events to deliver before real-time - 'catchup_done': not subscriber_id, + "redis_sub": redis_sub, + "sub": sub, + "subscriber_id": subscriber_id, + "created": datetime.utcnow(), + "last_poll": None, + "pending_catchup": [], # Events to deliver before real-time + "catchup_done": not subscriber_id, } self._update_channels() self._start_keep_alive_timer() @@ -411,32 +375,37 @@ async def subscribe(self, channel: str, user: str, # pylint: disable=too-many-arguments async def _setup_durable_subscription( - self, sub_id: int, subscriber_id: str, - *, channel: str, user: str, promiscuous: bool): + self, + sub_id: int, + subscriber_id: str, + *, + channel: str, + user: str, + promiscuous: bool, + ): """Set up or restore durable subscription state""" col = self._mongo_db[self.SUBSCRIBER_STATE_COLLECTION] - existing = await col.find_one({'subscriber_id': subscriber_id}) + existing = await col.find_one({"subscriber_id": subscriber_id}) if existing: # Existing subscriber - verify ownership - if existing['user'] != user: - raise RuntimeError( - f"Subscriber {subscriber_id} owned by different user" - ) + if existing["user"] != user: + raise RuntimeError(f"Subscriber {subscriber_id} owned by different user") # Load pending catch-up events missed = await self._get_missed_events( - channel=existing['channel'], - after_seq_id=existing['last_event_id'], + channel=existing["channel"], + after_seq_id=existing["last_event_id"], owner_filter=user, - promiscuous=promiscuous + promiscuous=promiscuous, ) async with self._lock: - self._subscriptions[sub_id]['pending_catchup'] = missed + self._subscriptions[sub_id]["pending_catchup"] = missed sub = self._subscriptions[sub_id] - sub['last_acked_id'] = existing['last_event_id'] + sub["last_acked_id"] = existing["last_event_id"] logger.info( "Subscriber %s reconnected, %d missed events", - subscriber_id, len(missed) + subscriber_id, + len(missed), ) else: # New subscriber - get current event ID as starting point @@ -447,14 +416,15 @@ async def _setup_durable_subscription( user=user, promiscuous=promiscuous, last_event_id=current_id, - created_at=datetime.utcnow() + created_at=datetime.utcnow(), ) await col.insert_one(state.model_dump()) async with self._lock: - self._subscriptions[sub_id]['last_acked_id'] = current_id + self._subscriptions[sub_id]["last_acked_id"] = current_id logger.info( "New durable subscriber %s starting at event %d", - subscriber_id, current_id + subscriber_id, + current_id, ) async def unsubscribe(self, sub_id: int, user: Optional[str] = None): @@ -469,82 +439,88 @@ async def unsubscribe(self, sub_id: int, user: Optional[str] = None): raise KeyError(f"Subscription {sub_id} not found") # Only allow user to unsubscribe their own subscriptions - if user and user != sub['sub'].user: - raise RuntimeError( - f"Subscription {sub_id} not owned by {user}" - ) + if user and user != sub["sub"].user: + raise RuntimeError(f"Subscription {sub_id} not owned by {user}") self._subscriptions.pop(sub_id) self._update_channels() - await sub['redis_sub'].unsubscribe() - await sub['redis_sub'].close() - - async def listen(self, sub_id: int, - user: Optional[str] = None) -> Optional[Dict]: - # pylint: disable=too-many-branches - """Listen for Pub/Sub messages + await sub["redis_sub"].unsubscribe() + await sub["redis_sub"].close() - For durable subscriptions (with subscriber_id): - 1. First delivers any missed events from catch-up queue - 2. Then waits for real-time events - 3. Implicitly ACKs previous event when called again - - Returns message dict or None on error. - """ + async def _get_listen_subscription(self, sub_id: int, user: Optional[str] = None): async with self._lock: sub_data = self._subscriptions.get(sub_id) if not sub_data: raise KeyError(f"Subscription {sub_id} not found") - sub = sub_data['sub'] - subscriber_id = sub_data.get('subscriber_id') + sub = sub_data["sub"] - # Ownership check if user and user != sub.user: raise RuntimeError(f"Subscription {sub_id} not owned by {user}") - # For durable subscriptions, handle implicit ACK - if subscriber_id and sub_data.get('last_delivered_id'): + return sub, sub_data + + async def _update_listen_subscription_state(self, sub: Subscription, sub_data: dict): + subscriber_id = sub_data.get("subscriber_id") + if subscriber_id and sub_data.get("last_delivered_id"): await self._update_subscriber_state( - subscriber_id, - sub_data['last_delivered_id'], - datetime.utcnow() + subscriber_id, sub_data["last_delivered_id"], datetime.utcnow() ) - sub_data['last_acked_id'] = sub_data['last_delivered_id'] - - # Check for pending catch-up events first - if sub_data.get('pending_catchup'): - event = sub_data['pending_catchup'].pop(0) - sub_data['last_delivered_id'] = event['sequence_id'] - self._subscriptions[sub_id]['last_poll'] = datetime.utcnow() - # Reconstruct CloudEvent format from eventhistory data - cloudevent_data = self._eventhistory_to_cloudevent(event) - return { - 'channel': sub.channel, - 'data': cloudevent_data, - 'pattern': None, - 'type': 'message' - } + sub_data["last_acked_id"] = sub_data["last_delivered_id"] + + return subscriber_id - # Mark catch-up as complete - if not sub_data.get('catchup_done'): - sub_data['catchup_done'] = True + def _consume_pending_catchup(self, sub_id: int, sub: Subscription, sub_data: dict) -> Optional[Dict]: + if not sub_data.get("pending_catchup"): + return None + + event = sub_data["pending_catchup"].pop(0) + sub_data["last_delivered_id"] = event["sequence_id"] + self._subscriptions[sub_id]["last_poll"] = datetime.utcnow() + + cloudevent_data = self._eventhistory_to_cloudevent(event) + return { + "channel": sub.channel, + "data": cloudevent_data, + "pattern": None, + "type": "message", + } - # Real-time listening via Redis + async def _rebuild_redis_subscription(self, sub_id: int, sub: Subscription, sub_data: dict): + async with self._lock: + channel = sub.channel + new_redis_sub = self._redis.pubsub() + await new_redis_sub.subscribe(channel) + self._subscriptions[sub_id]["redis_sub"] = new_redis_sub + sub_data["redis_sub"] = new_redis_sub + + def _maybe_update_delivery_offset(self, subscriber_id: Optional[str], sub_data: dict, msg_data: Any): + if not subscriber_id or not isinstance(msg_data, dict): + return + + sequence_id = msg_data.get("_sequence_id") + if sequence_id: + sub_data["last_delivered_id"] = sequence_id + + def _should_deliver_to_user(self, sub: Subscription, msg_data: dict) -> bool: + if sub.promiscuous: + return True + + return not ("owner" in msg_data and msg_data["owner"] != sub.user) + + async def _listen_for_message( + self, + sub_id: int, + sub: Subscription, + subscriber_id: Optional[str], + sub_data: dict, + ): while True: - self._subscriptions[sub_id]['last_poll'] = datetime.utcnow() - msg = None + self._subscriptions[sub_id]["last_poll"] = datetime.utcnow() try: - msg = await sub_data['redis_sub'].get_message( - ignore_subscribe_messages=True, timeout=1.0 - ) + msg = await sub_data["redis_sub"].get_message(ignore_subscribe_messages=True, timeout=1.0) except aioredis.ConnectionError: - async with self._lock: - channel = sub.channel - new_redis_sub = self._redis.pubsub() - await new_redis_sub.subscribe(channel) - self._subscriptions[sub_id]['redis_sub'] = new_redis_sub - sub_data['redis_sub'] = new_redis_sub + await self._rebuild_redis_subscription(sub_id, sub, sub_data) continue except aioredis.RedisError as exc: logger.error("Redis error: %s", exc) @@ -553,27 +529,39 @@ async def listen(self, sub_id: int, if msg is None: continue - msg_data = json.loads(msg['data']) - - # For durable subscriptions, track the sequence ID - if subscriber_id and isinstance(msg_data, dict): - sequence_id = msg_data.get('_sequence_id') - if sequence_id: - sub_data['last_delivered_id'] = sequence_id - - # Filter by owner if not promiscuous - if sub.promiscuous: - return self._decode_redis_message(msg) - if 'owner' in msg_data and msg_data['owner'] != sub.user: + msg_data = json.loads(msg["data"]) + self._maybe_update_delivery_offset(subscriber_id, sub_data, msg_data) + if not self._should_deliver_to_user(sub, msg_data): continue + return self._decode_redis_message(msg) + async def listen(self, sub_id: int, user: Optional[str] = None) -> Optional[Dict]: + """Listen for Pub/Sub messages + + For durable subscriptions (with subscriber_id): + 1. First delivers any missed events from catch-up queue + 2. Then waits for real-time events + 3. Implicitly ACKs previous event when called again + + Returns message dict or None on error. + """ + sub, sub_data = await self._get_listen_subscription(sub_id, user) + subscriber_id = await self._update_listen_subscription_state(sub, sub_data) + pending_msg = self._consume_pending_catchup(sub_id, sub, sub_data) + if pending_msg: + return pending_msg + + if not sub_data.get("catchup_done"): + sub_data["catchup_done"] = True + + return await self._listen_for_message(sub_id, sub, subscriber_id, sub_data) + async def publish(self, channel: str, message: str): """Publish a message on a channel (Redis only, no durability)""" await self._redis.publish(channel, message) - async def publish_cloudevent(self, channel: str, data: Any, - attributes: Optional[Dict] = None): + async def publish_cloudevent(self, channel: str, data: Any, attributes: Optional[Dict] = None): """Publish a CloudEvent on a Pub/Sub channel Events are: @@ -586,12 +574,12 @@ async def publish_cloudevent(self, channel: str, data: Any, """ if not attributes: attributes = {} - if not attributes.get('type'): - attributes['type'] = "api.kernelci.org" - if not attributes.get('source'): - attributes['source'] = self._settings.cloud_events_source + if not attributes.get("type"): + attributes["type"] = "api.kernelci.org" + if not attributes.get("source"): + attributes["source"] = self._settings.cloud_events_source - owner = attributes.get('owner') + owner = attributes.get("owner") # Store in MongoDB eventhistory (for durable delivery and /events API) # Store the raw data dict, not CloudEvent JSON @@ -599,11 +587,11 @@ async def publish_cloudevent(self, channel: str, data: Any, # Create CloudEvent for Redis real-time delivery event = CloudEvent(attributes=attributes, data=data) - event_json = to_json(event).decode('utf-8') + event_json = to_json(event).decode("utf-8") # Add sequence_id to message for tracking durable subscriptions msg_with_id = json.loads(event_json) - msg_with_id['_sequence_id'] = sequence_id + msg_with_id["_sequence_id"] = sequence_id await self._redis.publish(channel, json.dumps(msg_with_id)) async def push(self, list_name: str, message: str): @@ -614,12 +602,11 @@ async def pop(self, list_name: str) -> Optional[Dict]: """Pop a message from a list""" while True: msg = await self._redis.blpop(list_name, timeout=1.0) - data = json.loads(msg[1].decode('utf-8')) if msg else None + data = json.loads(msg[1].decode("utf-8")) if msg else None if data is not None: return data - async def push_cloudevent(self, list_name: str, data: Any, - attributes: Optional[Dict] = None): + async def push_cloudevent(self, list_name: str, data: Any, attributes: Optional[Dict] = None): """Push a CloudEvent on a list""" if not attributes: attributes = { @@ -633,19 +620,18 @@ async def subscription_stats(self) -> List[SubscriptionStats]: """Get existing subscription details""" subscriptions = [] for _, subscription in self._subscriptions.items(): - sub = subscription['sub'] + sub = subscription["sub"] stats = SubscriptionStats( id=sub.id, channel=sub.channel, user=sub.user, - created=subscription['created'], - last_poll=subscription['last_poll'] + created=subscription["created"], + last_poll=subscription["last_poll"], ) subscriptions.append(stats) return subscriptions - async def cleanup_stale_subscriptions(self, - max_age_minutes: int = 30) -> int: + async def cleanup_stale_subscriptions(self, max_age_minutes: int = 30) -> int: """Remove subscriptions not polled recently For durable subscriptions, only the in-memory state is cleaned up. @@ -656,7 +642,7 @@ async def cleanup_stale_subscriptions(self, async with self._lock: for sub_id, sub_data in self._subscriptions.items(): - last_poll = sub_data.get('last_poll') + last_poll = sub_data.get("last_poll") if last_poll and last_poll < cutoff: stale_ids.append(sub_id) @@ -668,8 +654,7 @@ async def cleanup_stale_subscriptions(self, return len(stale_ids) - async def cleanup_stale_subscriber_states(self, - max_age_days: int = 30) -> int: + async def cleanup_stale_subscriber_states(self, max_age_days: int = 30) -> int: """Remove subscriber states not used for a long time This is separate from subscription cleanup - it removes the @@ -677,5 +662,5 @@ async def cleanup_stale_subscriber_states(self, """ cutoff = datetime.utcnow() - timedelta(days=max_age_days) col = self._mongo_db[self.SUBSCRIBER_STATE_COLLECTION] - result = await col.delete_many({'last_poll': {'$lt': cutoff}}) + result = await col.delete_many({"last_poll": {"$lt": cutoff}}) return result.deleted_count diff --git a/api/user_manager.py b/api/user_manager.py index 67ded7bc..06237009 100644 --- a/api/user_manager.py +++ b/api/user_manager.py @@ -5,7 +5,10 @@ """User Manager""" -from typing import Optional, Any, Dict +from typing import Any, Dict, Optional + +import jinja2 +from beanie import PydanticObjectId from fastapi import Depends, Request, Response from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager @@ -13,15 +16,15 @@ BeanieUserDatabase, ObjectIDIDMixin, ) -from beanie import PydanticObjectId -import jinja2 -from .models import User + from .config import AuthSettings from .email_sender import EmailSender +from .models import User class UserManager(ObjectIDIDMixin, BaseUserManager[User, PydanticObjectId]): """User management logic""" + settings = AuthSettings() reset_password_token_secret = settings.secret_key verification_token_secret = settings.secret_key @@ -29,9 +32,7 @@ class UserManager(ObjectIDIDMixin, BaseUserManager[User, PydanticObjectId]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._email_sender = None - self._template_env = jinja2.Environment( - loader=jinja2.PackageLoader("api", "templates") - ) + self._template_env = jinja2.Environment(loader=jinja2.PackageLoader("api", "templates")) @property def email_sender(self): @@ -40,8 +41,7 @@ def email_sender(self): self._email_sender = EmailSender() return self._email_sender - async def on_after_register(self, user: User, - request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): """Handler to execute after successful user registration""" print(f"User {user.id} {user.username} has registered.") @@ -56,28 +56,26 @@ async def send_invite_email(self, user: User, token: str, invite_url: str): ) self.email_sender.create_and_send_email(subject, content, user.email) - async def on_after_login(self, user: User, - request: Optional[Request] = None, - response: Optional[Response] = None): + async def on_after_login( + self, + user: User, + request: Optional[Request] = None, + response: Optional[Response] = None, + ): """Handler to execute after successful user login""" print(f"User {user.id} {user.username} logged in.") - async def on_after_forgot_password(self, user: User, token: str, - request: Optional[Request] = None): + async def on_after_forgot_password(self, user: User, token: str, request: Optional[Request] = None): """Handler to execute after successful forgot password request""" template = self._template_env.get_template("reset-password.jinja2") subject = "Reset Password Token for KernelCI API account" - content = template.render( - username=user.username, token=token - ) + content = template.render(username=user.username, token=token) self.email_sender.create_and_send_email(subject, content, user.email) - async def on_after_reset_password(self, user: User, - request: Optional[Request] = None): + async def on_after_reset_password(self, user: User, request: Optional[Request] = None): """Handler to execute after successful password reset""" print(f"User {user.id} {user.username} has reset their password.") - template = self._template_env.get_template( - "reset-password-successful.jinja2") + template = self._template_env.get_template("reset-password-successful.jinja2") subject = "Password reset successful for KernelCI API account" content = template.render( username=user.username, @@ -91,24 +89,21 @@ async def send_invite_accepted_email(self, user: User): content = template.render(username=user.username) self.email_sender.create_and_send_email(subject, content, user.email) - async def on_after_update(self, user: User, update_dict: Dict[str, Any], - request: Optional[Request] = None): + async def on_after_update( + self, user: User, update_dict: Dict[str, Any], request: Optional[Request] = None + ): """Handler to execute after successful user update""" print(f"User {user.id} {user.username} has been updated.") - async def on_before_delete(self, user: User, - request: Optional[Request] = None): + async def on_before_delete(self, user: User, request: Optional[Request] = None): """Handler to execute before user delete.""" print(f"User {user.id} {user.username} is going to be deleted.") - async def on_after_delete(self, user: User, - request: Optional[Request] = None): + async def on_after_delete(self, user: User, request: Optional[Request] = None): """Handler to execute after user delete.""" print(f"User {user.id} {user.username} was successfully deleted.") - async def authenticate( - self, credentials: OAuth2PasswordRequestForm - ) -> User | None: + async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> User | None: """ Overload user authentication method `BaseUserManager.authenticate`. This is to fix login endpoint to receive `username` instead of `email`. @@ -118,17 +113,14 @@ async def authenticate( self.password_helper.hash(credentials.password) return None - verified, updated_password_hash = \ - self.password_helper.verify_and_update( - credentials.password, user.hashed_password - ) + verified, updated_password_hash = self.password_helper.verify_and_update( + credentials.password, user.hashed_password + ) if not verified: return None # Update password hash to a more robust one if needed if updated_password_hash is not None: - await self.user_db.update( - user, {"hashed_password": updated_password_hash} - ) + await self.user_db.update(user, {"hashed_password": updated_password_hash}) return user diff --git a/tests/e2e_tests/conftest.py b/tests/e2e_tests/conftest.py index d290af51..a8c06e76 100644 --- a/tests/e2e_tests/conftest.py +++ b/tests/e2e_tests/conftest.py @@ -8,29 +8,28 @@ import pytest from httpx import AsyncClient -from motor.motor_asyncio import AsyncIOMotorClient - from kernelci.api.models import Node, Regression +from motor.motor_asyncio import AsyncIOMotorClient from api.main import versioned_app -BASE_URL = 'http://api:8000/latest/' -DB_URL = 'mongodb://db:27017' -DB_NAME = 'kernelci' +BASE_URL = "http://api:8000/latest/" +DB_URL = "mongodb://db:27017" +DB_NAME = "kernelci" db_client = AsyncIOMotorClient(DB_URL) db = db_client[DB_NAME] node_model_fields = set(Node.model_fields.keys()) regression_model_fields = set(Regression.model_fields.keys()) paginated_response_keys = { - 'items', - 'total', - 'limit', - 'offset', + "items", + "total", + "limit", + "offset", } -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") async def test_async_client(): """Fixture to get Test client for asynchronous tests""" async with AsyncClient(app=versioned_app, base_url=BASE_URL) as client: @@ -41,7 +40,7 @@ async def test_async_client(): async def db_create(collection, obj): """Database create method""" - delattr(obj, 'id') + delattr(obj, "id") col = db[collection] # res = await col.insert_one(obj.dict(by_alias=True)) res = await col.insert_one(obj.model_dump(by_alias=True)) @@ -49,7 +48,7 @@ async def db_create(collection, obj): return obj -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def event_loop(): """Get an instance of the default event loop using database client. The event loop will be used for all async tests. diff --git a/tests/e2e_tests/listen_handler.py b/tests/e2e_tests/listen_handler.py index 9b5a7b88..f32e410d 100644 --- a/tests/e2e_tests/listen_handler.py +++ b/tests/e2e_tests/listen_handler.py @@ -7,6 +7,7 @@ """Helper function for KernelCI API listen handler""" import asyncio + import pytest @@ -16,13 +17,13 @@ def create_listen_task(test_async_client, subscription_id): API endpoint `/listen`. Returns the task instance. """ - listen_path = '/'.join(['listen', str(subscription_id)]) + listen_path = "/".join(["listen", str(subscription_id)]) task_listen = asyncio.create_task( test_async_client.get( listen_path, headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, ) ) diff --git a/tests/e2e_tests/test_count_handler.py b/tests/e2e_tests/test_count_handler.py index 8508a47f..f86ec914 100644 --- a/tests/e2e_tests/test_count_handler.py +++ b/tests/e2e_tests/test_count_handler.py @@ -6,15 +6,11 @@ """End-to-end test functions for KernelCI API count handler""" - import pytest @pytest.mark.asyncio -@pytest.mark.dependency( - depends=[ - 'tests/e2e_tests/test_pipeline.py::test_node_pipeline'], - scope='session') +@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") async def test_count_nodes(test_async_client): """ Test Case : Test KernelCI API GET /count endpoint @@ -28,10 +24,7 @@ async def test_count_nodes(test_async_client): @pytest.mark.asyncio -@pytest.mark.dependency( - depends=[ - 'tests/e2e_tests/test_pipeline.py::test_node_pipeline'], - scope='session') +@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") async def test_count_nodes_matching_attributes(test_async_client): """ Test Case : Test KernelCI API GET /count endpoint with attributes diff --git a/tests/e2e_tests/test_node_handler.py b/tests/e2e_tests/test_node_handler.py index 2de3ae85..252dcac1 100644 --- a/tests/e2e_tests/test_node_handler.py +++ b/tests/e2e_tests/test_node_handler.py @@ -7,6 +7,7 @@ """Test functions for KernelCI API node handler""" import json + import pytest from .conftest import node_model_fields, paginated_response_keys @@ -23,10 +24,10 @@ async def create_node(test_async_client, node): "node", headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, - data=json.dumps(node) - ) + data=json.dumps(node), + ) assert response.status_code == 200 assert response.json().keys() == node_model_fields return response @@ -43,7 +44,7 @@ async def get_node_by_id(test_async_client, node_id): f"node/{node_id}", headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, ) assert response.status_code == 200 @@ -64,12 +65,12 @@ async def get_node_by_attribute(test_async_client, params): params=params, headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, ) assert response.status_code == 200 assert response.json().keys() == paginated_response_keys - assert response.json()['total'] >= 0 + assert response.json()["total"] >= 0 return response @@ -84,9 +85,9 @@ async def update_node(test_async_client, node): f"node/{node['id']}", headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, - data=json.dumps(node) + data=json.dumps(node), ) assert response.status_code == 200 assert response.json().keys() == node_model_fields diff --git a/tests/e2e_tests/test_password_handler.py b/tests/e2e_tests/test_password_handler.py index 70deb3b9..67bb8bda 100644 --- a/tests/e2e_tests/test_password_handler.py +++ b/tests/e2e_tests/test_password_handler.py @@ -7,6 +7,7 @@ """End-to-end test functions for KernelCI API password reset handler""" import json + import pytest @@ -27,12 +28,8 @@ async def test_password_endpoint(test_async_client): "user/me", headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member }, - data=json.dumps( - { - "password": "foo" - } - ), + data=json.dumps({"password": "foo"}), ) assert response.status_code == 200 diff --git a/tests/e2e_tests/test_pipeline.py b/tests/e2e_tests/test_pipeline.py index 243490ce..7b0c12e4 100644 --- a/tests/e2e_tests/test_pipeline.py +++ b/tests/e2e_tests/test_pipeline.py @@ -14,9 +14,9 @@ @pytest.mark.dependency( - depends=[ - 'tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel'], - scope='session') + depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel"], + scope="session", +) @pytest.mark.order(4) @pytest.mark.asyncio async def test_node_pipeline(test_async_client): @@ -37,9 +37,7 @@ async def test_node_pipeline(test_async_client): """ # Create Task to listen pubsub event on 'node' channel - task_listen = create_listen_task( - test_async_client, - pytest.node_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task(test_async_client, pytest.node_channel_subscription_id) # pylint: disable=no-member # Create a node node = { @@ -49,36 +47,42 @@ async def test_node_pipeline(test_async_client): "data": { "kernel_revision": { "tree": "mainline", - "url": ( - "https://git.kernel.org/pub/scm/" - "linux/kernel/git/torvalds/linux.git" - ), + "url": ("https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git"), "branch": "master", "commit": "2a987e65025e2b79c6d453b78cb5985ac6e5eb28", - "describe": "v5.16-rc4-31-g2a987e65025e" + "describe": "v5.16-rc4-31-g2a987e65025e", } - } + }, } response = await create_node(test_async_client, node) # Get result of pubsub event listen task await task_listen - event_data = from_json(task_listen.result().json().get('data')).data - assert event_data != 'BEEP' - keys = {'op', 'id', 'kind', 'name', 'path', - 'group', 'state', 'result', 'owner', 'data', 'is_hierarchy'} + event_data = from_json(task_listen.result().json().get("data")).data + assert event_data != "BEEP" + keys = { + "op", + "id", + "kind", + "name", + "path", + "group", + "state", + "result", + "owner", + "data", + "is_hierarchy", + } assert keys == event_data.keys() - assert event_data.get('op') == 'created' - assert event_data.get('id') == response.json()['id'] + assert event_data.get("op") == "created" + assert event_data.get("id") == response.json()["id"] # Get node id from event data and get created node by id - response = await get_node_by_id(test_async_client, event_data.get('id')) + response = await get_node_by_id(test_async_client, event_data.get("id")) node = response.json() # Create Task to listen 'updated' event on 'node' channel - task_listen = create_listen_task( - test_async_client, - pytest.node_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task(test_async_client, pytest.node_channel_subscription_id) # pylint: disable=no-member # Update node.state node.update({"state": "done"}) @@ -87,9 +91,20 @@ async def test_node_pipeline(test_async_client): # Get result of pubsub event listen task await task_listen - event_data = from_json(task_listen.result().json().get('data')).data - assert event_data != 'BEEP' - keys = {'op', 'id', 'kind', 'name', 'path', - 'group', 'state', 'result', 'owner', 'data', 'is_hierarchy'} + event_data = from_json(task_listen.result().json().get("data")).data + assert event_data != "BEEP" + keys = { + "op", + "id", + "kind", + "name", + "path", + "group", + "state", + "result", + "owner", + "data", + "is_hierarchy", + } assert keys == event_data.keys() - assert event_data.get('op') == 'updated' + assert event_data.get("op") == "updated" diff --git a/tests/e2e_tests/test_pubsub_handler.py b/tests/e2e_tests/test_pubsub_handler.py index 42c4e327..abc93d8e 100644 --- a/tests/e2e_tests/test_pubsub_handler.py +++ b/tests/e2e_tests/test_pubsub_handler.py @@ -7,14 +7,15 @@ """End-to-end test function for KernelCI API pubsub handler""" import pytest -from cloudevents.http import CloudEvent, to_structured, from_json +from cloudevents.http import CloudEvent, from_json, to_structured from .listen_handler import create_listen_task @pytest.mark.dependency( - depends=['tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel'], - scope='session') + depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel"], + scope="session", +) @pytest.mark.asyncio async def test_pubsub_handler(test_async_client): """ @@ -23,9 +24,7 @@ async def test_pubsub_handler(test_async_client): Use pubsub listener task to verify published event message. """ # Create Task to listen pubsub event on 'test_channel' channel - task_listen = create_listen_task( - test_async_client, - pytest.test_channel_subscription_id) # pylint: disable=no-member + task_listen = create_listen_task(test_async_client, pytest.test_channel_subscription_id) # pylint: disable=no-member # Created and publish CloudEvent attributes = { @@ -35,23 +34,19 @@ async def test_pubsub_handler(test_async_client): data = {"message": "Test message"} event = CloudEvent(attributes, data) headers, body = to_structured(event) - headers['Authorization'] = f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member - response = await test_async_client.post( - "publish/test_channel", - headers=headers, - data=body - ) + headers["Authorization"] = f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member + response = await test_async_client.post("publish/test_channel", headers=headers, data=body) assert response.status_code == 200 # Get result of pubsub event listener await task_listen assert task_listen.result().json().keys() == { - 'channel', - 'data', - 'pattern', - 'type', + "channel", + "data", + "pattern", + "type", } - event_data = from_json(task_listen.result().json().get('data')).data - assert event_data != 'BEEP' - assert ('message',) == tuple(event_data.keys()) - assert event_data.get('message') == 'Test message' + event_data = from_json(task_listen.result().json().get("data")).data + assert event_data != "BEEP" + assert ("message",) == tuple(event_data.keys()) + assert event_data.get("message") == "Test message" diff --git a/tests/e2e_tests/test_regression_handler.py b/tests/e2e_tests/test_regression_handler.py index 82d61c47..13d91594 100644 --- a/tests/e2e_tests/test_regression_handler.py +++ b/tests/e2e_tests/test_regression_handler.py @@ -11,10 +11,7 @@ from .test_node_handler import create_node, get_node_by_attribute -@pytest.mark.dependency( - depends=[ - "tests/e2e_tests/test_pipeline.py::test_node_pipeline"], - scope="session") +@pytest.mark.dependency(depends=["tests/e2e_tests/test_pipeline.py::test_node_pipeline"], scope="session") @pytest.mark.asyncio async def test_regression_handler(test_async_client): """ @@ -27,9 +24,7 @@ async def test_regression_handler(test_async_client): method. """ # Get "checkout" node - response = await get_node_by_attribute( - test_async_client, {"name": "checkout"} - ) + response = await get_node_by_attribute(test_async_client, {"name": "checkout"}) checkout_node = response.json()["items"][0] # Create a 'kver' passed node @@ -52,28 +47,21 @@ async def test_regression_handler(test_async_client): "result": "pass", } - passed_node_obj = ( - await create_node(test_async_client, passed_node) - ).json() + passed_node_obj = (await create_node(test_async_client, passed_node)).json() # Create a 'kver' failed node failed_node = passed_node.copy() failed_node["result"] = "fail" - failed_node_obj = ( - await create_node(test_async_client, failed_node) - ).json() + failed_node_obj = (await create_node(test_async_client, failed_node)).json() # Create a "kver" regression node - regression_fields = ['group', 'name', 'path', 'state'] - regression_node = { - field: failed_node_obj[field] - for field in regression_fields - } + regression_fields = ["group", "name", "path", "state"] + regression_node = {field: failed_node_obj[field] for field in regression_fields} regression_node["kind"] = "regression" regression_node["data"] = { "fail_node": failed_node_obj["id"], - "pass_node": passed_node_obj["id"] + "pass_node": passed_node_obj["id"], } await create_node(test_async_client, regression_node) diff --git a/tests/e2e_tests/test_subscribe_handler.py b/tests/e2e_tests/test_subscribe_handler.py index 84fdf0d7..50012068 100644 --- a/tests/e2e_tests/test_subscribe_handler.py +++ b/tests/e2e_tests/test_subscribe_handler.py @@ -11,8 +11,9 @@ @pytest.mark.asyncio @pytest.mark.dependency( - depends=['tests/e2e_tests/test_user_creation.py::test_create_regular_user'], - scope='session') + depends=["tests/e2e_tests/test_user_creation.py::test_create_regular_user"], + scope="session", +) @pytest.mark.order(3) async def test_subscribe_node_channel(test_async_client): """ @@ -27,19 +28,20 @@ async def test_subscribe_node_channel(test_async_client): "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member }, ) - pytest.node_channel_subscription_id = response.json()['id'] + pytest.node_channel_subscription_id = response.json()["id"] assert response.status_code == 200 # only id, channel, user is mandatory in the response - assert 'id' in response.json() - assert 'channel' in response.json() - assert 'user' in response.json() - assert response.json().get('channel') == 'node' + assert "id" in response.json() + assert "channel" in response.json() + assert "user" in response.json() + assert response.json().get("channel") == "node" @pytest.mark.asyncio @pytest.mark.dependency( - depends=['tests/e2e_tests/test_user_creation.py::test_create_regular_user'], - scope='session') + depends=["tests/e2e_tests/test_user_creation.py::test_create_regular_user"], + scope="session", +) @pytest.mark.order(3) async def test_subscribe_test_channel(test_async_client): """ @@ -54,19 +56,20 @@ async def test_subscribe_test_channel(test_async_client): "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member }, ) - pytest.test_channel_subscription_id = response.json()['id'] + pytest.test_channel_subscription_id = response.json()["id"] assert response.status_code == 200 # only id, channel, user is mandatory in the response - assert 'id' in response.json() - assert 'channel' in response.json() - assert 'user' in response.json() - assert response.json().get('channel') == 'test_channel' + assert "id" in response.json() + assert "channel" in response.json() + assert "user" in response.json() + assert response.json().get("channel") == "test_channel" @pytest.mark.asyncio @pytest.mark.dependency( - depends=['tests/e2e_tests/test_user_creation.py::test_create_regular_user'], - scope='session') + depends=["tests/e2e_tests/test_user_creation.py::test_create_regular_user"], + scope="session", +) @pytest.mark.order(3) async def test_subscribe_user_group_channel(test_async_client): """ @@ -82,10 +85,10 @@ async def test_subscribe_user_group_channel(test_async_client): "Authorization": f"Bearer {pytest.BEARER_TOKEN}" # pylint: disable=no-member }, ) - pytest.user_group_channel_subscription_id = response.json()['id'] + pytest.user_group_channel_subscription_id = response.json()["id"] assert response.status_code == 200 # only id, channel, user is mandatory in the response - assert 'id' in response.json() - assert 'channel' in response.json() - assert 'user' in response.json() - assert response.json().get('channel') == 'user_group' + assert "id" in response.json() + assert "channel" in response.json() + assert "user" in response.json() + assert response.json().get("channel") == "user_group" diff --git a/tests/e2e_tests/test_unsubscribe_handler.py b/tests/e2e_tests/test_unsubscribe_handler.py index 62932b7f..9fc8b2ad 100644 --- a/tests/e2e_tests/test_unsubscribe_handler.py +++ b/tests/e2e_tests/test_unsubscribe_handler.py @@ -11,9 +11,9 @@ @pytest.mark.asyncio @pytest.mark.dependency( - depends=[ - 'tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel'], - scope='session') + depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_node_channel"], + scope="session", +) @pytest.mark.order("last") async def test_unsubscribe_node_channel(test_async_client): """ @@ -32,9 +32,9 @@ async def test_unsubscribe_node_channel(test_async_client): @pytest.mark.asyncio @pytest.mark.dependency( - depends=[ - 'tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel'], - scope='session') + depends=["tests/e2e_tests/test_subscribe_handler.py::test_subscribe_test_channel"], + scope="session", +) @pytest.mark.order("last") async def test_unsubscribe_test_channel(test_async_client): """ diff --git a/tests/e2e_tests/test_user_creation.py b/tests/e2e_tests/test_user_creation.py index e8fd222e..864b0ea3 100644 --- a/tests/e2e_tests/test_user_creation.py +++ b/tests/e2e_tests/test_user_creation.py @@ -6,17 +6,20 @@ """End-to-end test functions for KernelCI API user creation""" import json + import pytest -from api.models import User -from api.db import Database from api.auth import Authentication +from api.db import Database +from api.models import User + from .conftest import db_create @pytest.mark.dependency( depends=["tests/e2e_tests/test_user_group_handler.py::test_create_user_groups"], - scope="session") + scope="session", +) @pytest.mark.dependency() @pytest.mark.order(1) @pytest.mark.asyncio @@ -27,8 +30,8 @@ async def test_create_admin_user(test_async_client): Request authentication token using '/user/login' endpoint for the user and store it in pytest global variable 'ADMIN_BEARER_TOKEN'. """ - username = 'admin' - password = 'test' + username = "admin" + password = "test" hashed_password = Authentication.get_password_hash(password) obj = await db_create( @@ -36,28 +39,29 @@ async def test_create_admin_user(test_async_client): User( username=username, hashed_password=hashed_password, - email='test-admin@kernelci.org', + email="test-admin@kernelci.org", groups=[], is_superuser=1, - is_verified=1 - )) + is_verified=1, + ), + ) assert obj is not None response = await test_async_client.post( "user/login", headers={ "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", }, - data=f"username={username}&password={password}" + data=f"username={username}&password={password}", ) print("response.json()", response.json()) assert response.status_code == 200 assert response.json().keys() == { - 'access_token', - 'token_type', + "access_token", + "token_type", } - pytest.ADMIN_BEARER_TOKEN = response.json()['access_token'] + pytest.ADMIN_BEARER_TOKEN = response.json()["access_token"] @pytest.mark.dependency(depends=["test_create_admin_user"]) @@ -69,36 +73,39 @@ async def test_create_regular_user(test_async_client): user when requested with admin user's bearer token. Request '/user/login' endpoint for the user and store it in pytest global variable 'BEARER_TOKEN'. """ - username = 'test_user' - password = 'test' - email = 'test@kernelci.org' + username = "test_user" + password = "test" + email = "test@kernelci.org" response = await test_async_client.post( "user/register", headers={ - "Accept": "application/json", - "Authorization": f"Bearer {pytest.ADMIN_BEARER_TOKEN}" - }, - data=json.dumps({ - 'username': username, - 'password': password, - 'email': email - }) + "Accept": "application/json", + "Authorization": f"Bearer {pytest.ADMIN_BEARER_TOKEN}", + }, + data=json.dumps({"username": username, "password": password, "email": email}), ) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) # User needs to verified before getting access token # Directly updating user to by pass user verification via email - user_id = response.json()['id'] + user_id = response.json()["id"] response = await test_async_client.patch( f"user/{user_id}", headers={ "Accept": "application/json", "Content-Type": "application/json", - "Authorization": f"Bearer {pytest.ADMIN_BEARER_TOKEN}" + "Authorization": f"Bearer {pytest.ADMIN_BEARER_TOKEN}", }, - data=json.dumps({"is_verified": True}) + data=json.dumps({"is_verified": True}), ) assert response.status_code == 200 @@ -106,16 +113,16 @@ async def test_create_regular_user(test_async_client): "user/login", headers={ "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", }, - data=f"username={username}&password={password}" + data=f"username={username}&password={password}", ) assert response.status_code == 200 assert response.json().keys() == { - 'access_token', - 'token_type', + "access_token", + "token_type", } - pytest.BEARER_TOKEN = response.json()['access_token'] + pytest.BEARER_TOKEN = response.json()["access_token"] @pytest.mark.asyncio @@ -132,13 +139,20 @@ async def test_whoami(test_async_client): "whoami", headers={ "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", }, ) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) - assert response.json()['username'] == 'test_user' + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) + assert response.json()["username"] == "test_user" @pytest.mark.asyncio @@ -154,14 +168,10 @@ async def test_create_user_negative(test_async_client): response = await test_async_client.post( "user/register", headers={ - "Accept": "application/json", - "Authorization": f"Bearer {pytest.BEARER_TOKEN}" - }, - data=json.dumps({ - 'username': 'test', - 'password': 'test', - 'email': 'test@kernelci.org' - }) + "Accept": "application/json", + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", + }, + data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), ) assert response.status_code == 403 - assert response.json() == {'detail': 'Forbidden'} + assert response.json() == {"detail": "Forbidden"} diff --git a/tests/e2e_tests/test_user_invite.py b/tests/e2e_tests/test_user_invite.py index 0bbee960..4932f611 100644 --- a/tests/e2e_tests/test_user_invite.py +++ b/tests/e2e_tests/test_user_invite.py @@ -7,6 +7,7 @@ """End-to-end test functions for KernelCI API invite flow""" import json + import pytest diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index e602e0e6..a48eb3aa 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -11,55 +11,55 @@ """pytest fixtures for KernelCI API""" from unittest.mock import AsyncMock + import fakeredis.aioredis -from fastapi.testclient import TestClient -from fastapi import Request, HTTPException, status import pytest -from mongomock_motor import AsyncMongoMockClient from beanie import init_beanie +from fastapi import HTTPException, Request, status +from fastapi.testclient import TestClient from httpx import AsyncClient +from mongomock_motor import AsyncMongoMockClient from api.main import ( app, - versioned_app, - get_current_user, get_current_superuser, + get_current_user, + versioned_app, ) -from api.models import User, Subscription +from api.models import Subscription, User from api.pubsub import PubSub BEARER_TOKEN = "Bearer \ eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJib2IifQ.\ ci1smeJeuX779PptTkuaG1SEdkp5M1S1AgYvX8VdB20" -ADMIN_BEARER_TOKEN = 'Bearer \ +ADMIN_BEARER_TOKEN = "Bearer \ eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.\ eyJzdWIiOiJib2IiLCJzY29wZXMiOlsiYWRtaW4iXX0.\ -t3bAE-pHSzZaSHp7FMlImqgYvL6f_0xDUD-nQwxEm3k' +t3bAE-pHSzZaSHp7FMlImqgYvL6f_0xDUD-nQwxEm3k" -API_VERSION = 'latest' -BASE_URL = 'http://testserver/' + API_VERSION + '/' +API_VERSION = "latest" +BASE_URL = "http://testserver/" + API_VERSION + "/" def mock_get_current_user(request: Request): """ Get current active user """ - token = request.headers.get('authorization') + token = request.headers.get("authorization") if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token", ) return User( - id='65265305c74695807499037f', - username='bob', - hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.' - 'xCZGmM8jWXUXJZ4L', - email='bob@kernelci.org', + id="65265305c74695807499037f", + username="bob", + hashed_password="$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.xCZGmM8jWXUXJZ4L", + email="bob@kernelci.org", is_active=True, is_superuser=False, - is_verified=True + is_verified=True, ) @@ -67,7 +67,7 @@ def mock_get_current_admin_user(request: Request): """ Get current active admin user """ - token = request.headers.get('authorization') + token = request.headers.get("authorization") if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -79,15 +79,14 @@ def mock_get_current_admin_user(request: Request): detail="Forbidden", ) return User( - id='653a5e1a7e9312c86f8f86e1', - username='admin', - hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.' - 'xCZGmM8jWXUXJZ4K', - email='admin@kernelci.org', + id="653a5e1a7e9312c86f8f86e1", + username="admin", + hashed_password="$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.xCZGmM8jWXUXJZ4K", + email="admin@kernelci.org", groups=[], is_active=True, is_superuser=True, - is_verified=True + is_verified=True, ) @@ -96,7 +95,7 @@ def mock_get_current_admin_user(request: Request): app.dependency_overrides[get_current_superuser] = mock_get_current_admin_user -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def test_client(): """Fixture to get FastAPI Test client instance""" with TestClient(app=versioned_app, base_url=BASE_URL) as client: @@ -116,8 +115,7 @@ async def test_async_client(): def mock_db_create(mocker): """Mocks async call to Database class method used to create object""" async_mock = AsyncMock() - mocker.patch('api.db.Database.create', - side_effect=async_mock) + mocker.patch("api.db.Database.create", side_effect=async_mock) return async_mock @@ -125,8 +123,7 @@ def mock_db_create(mocker): def mock_db_count(mocker): """Mocks async call to Database class method used to count objects""" async_mock = AsyncMock() - mocker.patch('api.db.Database.count', - side_effect=async_mock) + mocker.patch("api.db.Database.count", side_effect=async_mock) return async_mock @@ -137,8 +134,7 @@ def mock_db_find_by_attributes(mocker): used to find a list of objects by attributes """ async_mock = AsyncMock() - mocker.patch('api.db.Database.find_by_attributes', - side_effect=async_mock) + mocker.patch("api.db.Database.find_by_attributes", side_effect=async_mock) return async_mock @@ -149,8 +145,7 @@ def mock_db_find_by_id(mocker): used to find an object by id """ async_mock = AsyncMock() - mocker.patch('api.db.Database.find_by_id', - side_effect=async_mock) + mocker.patch("api.db.Database.find_by_id", side_effect=async_mock) return async_mock @@ -158,8 +153,7 @@ def mock_db_find_by_id(mocker): def mock_db_delete_by_id(mocker): """Mocks async call to Database class method used to delete an object""" async_mock = AsyncMock() - mocker.patch('api.db.Database.delete_by_id', - side_effect=async_mock) + mocker.patch("api.db.Database.delete_by_id", side_effect=async_mock) return async_mock @@ -167,8 +161,7 @@ def mock_db_delete_by_id(mocker): def mock_db_find_one(mocker): """Mocks async call to database method used to find one object""" async_mock = AsyncMock() - mocker.patch('api.db.Database.find_one', - side_effect=async_mock) + mocker.patch("api.db.Database.find_one", side_effect=async_mock) return async_mock @@ -176,8 +169,7 @@ def mock_db_find_one(mocker): def mock_init_sub_id(mocker): """Mocks async call to PubSub method to initialize subscription id""" async_mock = AsyncMock() - mocker.patch('api.pubsub.PubSub._init_sub_id', - side_effect=async_mock) + mocker.patch("api.pubsub.PubSub._init_sub_id", side_effect=async_mock) return async_mock @@ -185,8 +177,7 @@ def mock_init_sub_id(mocker): def mock_listen(mocker): """Mocks async call to listen method of PubSub""" async_mock = AsyncMock() - mocker.patch('api.pubsub.PubSub.listen', - side_effect=async_mock) + mocker.patch("api.pubsub.PubSub.listen", side_effect=async_mock) return async_mock @@ -197,8 +188,7 @@ def mock_publish_cloudevent(mocker): used to publish cloud event """ async_mock = AsyncMock() - mocker.patch('api.pubsub.PubSub.publish_cloudevent', - side_effect=async_mock) + mocker.patch("api.pubsub.PubSub.publish_cloudevent", side_effect=async_mock) return async_mock @@ -207,7 +197,7 @@ def mock_pubsub(mocker): """Mocks `_redis` member of PubSub class instance""" pubsub = PubSub() redis_mock = fakeredis.aioredis.FakeRedis() - mocker.patch.object(pubsub, '_redis', redis_mock) + mocker.patch.object(pubsub, "_redis", redis_mock) return pubsub @@ -216,11 +206,10 @@ def mock_pubsub_subscriptions(mocker): """Mocks `_redis` and `_subscriptions` member of PubSub class instance""" pubsub = PubSub() redis_mock = fakeredis.aioredis.FakeRedis() - sub = Subscription(id=1, channel='test', user='test') - mocker.patch.object(pubsub, '_redis', redis_mock) - subscriptions_mock = dict( - {1: {'sub': sub, 'redis_sub': pubsub._redis.pubsub()}}) - mocker.patch.object(pubsub, '_subscriptions', subscriptions_mock) + sub = Subscription(id=1, channel="test", user="test") + mocker.patch.object(pubsub, "_redis", redis_mock) + subscriptions_mock = dict({1: {"sub": sub, "redis_sub": pubsub._redis.pubsub()}}) + mocker.patch.object(pubsub, "_subscriptions", subscriptions_mock) return pubsub @@ -232,8 +221,8 @@ def mock_pubsub_publish(mocker): """ pubsub = PubSub() redis_mock = fakeredis.aioredis.FakeRedis() - mocker.patch.object(pubsub, '_redis', redis_mock) - mocker.patch.object(pubsub._redis, 'execute_command') + mocker.patch.object(pubsub, "_redis", redis_mock) + mocker.patch.object(pubsub._redis, "execute_command") return pubsub @@ -241,8 +230,7 @@ def mock_pubsub_publish(mocker): def mock_subscribe(mocker): """Mocks async call to subscribe method of PubSub""" async_mock = AsyncMock() - mocker.patch('api.pubsub.PubSub.subscribe', - side_effect=async_mock) + mocker.patch("api.pubsub.PubSub.subscribe", side_effect=async_mock) return async_mock @@ -250,8 +238,7 @@ def mock_subscribe(mocker): def mock_unsubscribe(mocker): """Mocks async call to unsubscribe method of PubSub""" async_mock = AsyncMock() - mocker.patch('api.pubsub.PubSub.unsubscribe', - side_effect=async_mock) + mocker.patch("api.pubsub.PubSub.unsubscribe", side_effect=async_mock) return async_mock @@ -260,10 +247,8 @@ async def mock_init_beanie(mocker): """Mocks async call to Database method to initialize Beanie""" async_mock = AsyncMock() client = AsyncMongoMockClient() - init = await init_beanie( - document_models=[User], database=client.get_database(name="db")) - mocker.patch('api.db.Database.initialize_beanie', - side_effect=async_mock, return_value=init) + init = await init_beanie(document_models=[User], database=client.get_database(name="db")) + mocker.patch("api.db.Database.initialize_beanie", side_effect=async_mock, return_value=init) return async_mock @@ -273,8 +258,7 @@ def mock_db_update(mocker): Mocks async call to Database class method used to update object """ async_mock = AsyncMock() - mocker.patch('api.db.Database.update', - side_effect=async_mock) + mocker.patch("api.db.Database.update", side_effect=async_mock) return async_mock @@ -282,8 +266,7 @@ def mock_db_update(mocker): async def mock_beanie_get_user_by_id(mocker): """Mocks async call to external method to get model by id""" async_mock = AsyncMock() - mocker.patch('fastapi_users_db_beanie.BeanieUserDatabase.get', - side_effect=async_mock) + mocker.patch("fastapi_users_db_beanie.BeanieUserDatabase.get", side_effect=async_mock) return async_mock @@ -291,8 +274,7 @@ async def mock_beanie_get_user_by_id(mocker): async def mock_beanie_user_update(mocker): """Mocks async call to external method to update user""" async_mock = AsyncMock() - mocker.patch('fastapi_users_db_beanie.BeanieUserDatabase.update', - side_effect=async_mock) + mocker.patch("fastapi_users_db_beanie.BeanieUserDatabase.update", side_effect=async_mock) return async_mock @@ -302,8 +284,10 @@ def mock_auth_current_user(mocker): Mocks async call to external method to get authenticated user """ async_mock = AsyncMock() - mocker.patch('fastapi_users.authentication.Authenticator._authenticate', - side_effect=async_mock) + mocker.patch( + "fastapi_users.authentication.Authenticator._authenticate", + side_effect=async_mock, + ) return async_mock @@ -313,6 +297,5 @@ def mock_user_find(mocker): Mocks async call to external method to find user model using Beanie """ async_mock = AsyncMock() - mocker.patch('api.models.User.find_one', - side_effect=async_mock) + mocker.patch("api.models.User.find_one", side_effect=async_mock) return async_mock diff --git a/tests/unit_tests/test_authz_handler.py b/tests/unit_tests/test_authz_handler.py index 28632901..e6df2d73 100644 --- a/tests/unit_tests/test_authz_handler.py +++ b/tests/unit_tests/test_authz_handler.py @@ -9,6 +9,7 @@ import json from kernelci.api.models import Node, Revision + from api.main import _user_can_edit_node from api.models import User, UserGroup from tests.unit_tests.conftest import BEARER_TOKEN diff --git a/tests/unit_tests/test_events_handler.py b/tests/unit_tests/test_events_handler.py index 63cd8d3e..c48d43ce 100644 --- a/tests/unit_tests/test_events_handler.py +++ b/tests/unit_tests/test_events_handler.py @@ -69,7 +69,5 @@ def test_get_events_filter_by_node_id_alias(mock_db_find_by_attributes, test_cli def test_get_events_rejects_node_id_and_data_id(test_client): """GET /events rejects requests with both node_id and data.id parameters.""" - resp = test_client.get( - "events?node_id=693af4f5fee8383e92b6b0eb&data.id=693af4f5fee8383e92b6b0eb" - ) + resp = test_client.get("events?node_id=693af4f5fee8383e92b6b0eb&data.id=693af4f5fee8383e92b6b0eb") assert resp.status_code == 400 diff --git a/tests/unit_tests/test_listen_handler.py b/tests/unit_tests/test_listen_handler.py index 2a493564..43f727ce 100644 --- a/tests/unit_tests/test_listen_handler.py +++ b/tests/unit_tests/test_listen_handler.py @@ -21,13 +21,11 @@ def test_listen_endpoint(mock_listen, test_client): HTTP Response Code 200 OK Listen for events on a channel. """ - mock_listen.return_value = 'Listening for events on channel 1' + mock_listen.return_value = "Listening for events on channel 1" response = test_client.get( "listen/1", - headers={ - "Authorization": BEARER_TOKEN - }, + headers={"Authorization": BEARER_TOKEN}, ) assert response.status_code == 200 @@ -43,12 +41,10 @@ def test_listen_endpoint_not_found(test_client): """ response = test_client.get( "listen/1", - headers={ - "Authorization": BEARER_TOKEN - }, + headers={"Authorization": BEARER_TOKEN}, ) assert response.status_code == 404 - assert 'detail' in response.json() + assert "detail" in response.json() def test_listen_endpoint_without_token(test_client): @@ -61,8 +57,6 @@ def test_listen_endpoint_without_token(test_client): """ response = test_client.get( "listen/1", - headers={ - "Accept": "application/json" - }, + headers={"Accept": "application/json"}, ) assert response.status_code == 401 diff --git a/tests/unit_tests/test_node_handler.py b/tests/unit_tests/test_node_handler.py index 00034deb..1cc685f9 100644 --- a/tests/unit_tests/test_node_handler.py +++ b/tests/unit_tests/test_node_handler.py @@ -18,8 +18,7 @@ from tests.unit_tests.conftest import BEARER_TOKEN -def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, - test_client): +def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, test_client): """ Test Case : Test KernelCI API /node endpoint Expected Result : @@ -41,7 +40,7 @@ def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, name="checkout", path=["checkout"], group="debug", - data={'kernel_revision': revision_obj}, + data={"kernel_revision": revision_obj}, parent=None, state="closing", result=None, @@ -58,43 +57,39 @@ def test_create_node_endpoint(mock_db_create, mock_publish_cloudevent, } response = test_client.post( "node", - headers={ - "Accept": "application/json", - "Authorization": BEARER_TOKEN - }, - data=json.dumps(request_dict) - ) + headers={"Accept": "application/json", "Authorization": BEARER_TOKEN}, + data=json.dumps(request_dict), + ) print("response.json()", response.json()) assert response.status_code == 200 assert response.json().keys() == { - 'id', - 'artifacts', - 'created', - 'data', - 'debug', - 'group', - 'jobfilter', - 'platform_filter', - 'holdoff', - 'kind', - 'name', - 'owner', - 'path', - 'parent', - 'result', - 'submitter', - 'state', - 'timeout', - 'treeid', - 'updated', - 'user_groups', - 'processed_by_kcidb_bridge', - 'retry_counter', + "id", + "artifacts", + "created", + "data", + "debug", + "group", + "jobfilter", + "platform_filter", + "holdoff", + "kind", + "name", + "owner", + "path", + "parent", + "result", + "submitter", + "state", + "timeout", + "treeid", + "updated", + "user_groups", + "processed_by_kcidb_bridge", + "retry_counter", } -def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, - test_client): +def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, test_client): """ Test Case : Test KernelCI API GET /nodes?attribute_name=attribute_value endpoint for the positive path @@ -141,10 +136,7 @@ def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, "treeid": "61bda8f2eb1a63d2b7152414", } mock_db_find_by_attributes.return_value = PageModel( - items=[node_obj_1, node_obj_2], - total=2, - limit=50, - offset=0 + items=[node_obj_1, node_obj_2], total=2, limit=50, offset=0 ) params = { @@ -158,15 +150,13 @@ def test_get_nodes_by_attributes_endpoint(mock_db_find_by_attributes, response = test_client.get( "nodes", params=params, - ) + ) print("response.json()", response.json()) assert response.status_code == 200 - assert len(response.json()['items']) > 0 + assert len(response.json()["items"]) > 0 -def test_get_nodes_by_attributes_endpoint_node_not_found( - mock_db_find_by_attributes, - test_client): +def test_get_nodes_by_attributes_endpoint_node_not_found(mock_db_find_by_attributes, test_client): """ Test Case : Test KernelCI API GET /nodes?attribute_name=attribute_value endpoint for the node not found @@ -175,28 +165,16 @@ def test_get_nodes_by_attributes_endpoint_node_not_found( Empty list """ - mock_db_find_by_attributes.return_value = PageModel( - items=[], - total=0, - limit=50, - offset=0 - ) + mock_db_find_by_attributes.return_value = PageModel(items=[], total=0, limit=50, offset=0) - params = { - "name": "checkout", - "revision.tree": "baseline" - } - response = test_client.get( - "nodes", - params=params - ) + params = {"name": "checkout", "revision.tree": "baseline"} + response = test_client.get("nodes", params=params) print("response.json()", response.json()) assert response.status_code == 200 - assert response.json().get('total') == 0 + assert response.json().get("total") == 0 -def test_get_node_by_id_endpoint(mock_db_find_by_id, - test_client): +def test_get_node_by_id_endpoint(mock_db_find_by_id, test_client): """ Test Case : Test KernelCI API GET /node/{node_id} endpoint for the positive path @@ -209,7 +187,7 @@ def test_get_node_by_id_endpoint(mock_db_find_by_id, url="https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git", branch="master", commit="2a987e65025e2b79c6d453b78cb5985ac6e5eb26", - describe="v5.16-rc4-31-g2a987e65025e" + describe="v5.16-rc4-31-g2a987e65025e", ) node_obj = Node( id="61bda8f2eb1a63d2b7152418", @@ -217,7 +195,7 @@ def test_get_node_by_id_endpoint(mock_db_find_by_id, name="checkout", path=["checkout"], group="blah", - data={'kernel_revision': revision_obj}, + data={"kernel_revision": revision_obj}, parent=None, state="closing", result=None, @@ -229,34 +207,33 @@ def test_get_node_by_id_endpoint(mock_db_find_by_id, print("response.json()", response.json()) assert response.status_code == 200 assert response.json().keys() == { - 'id', - 'artifacts', - 'created', - 'data', - 'debug', - 'group', - 'jobfilter', - 'platform_filter', - 'holdoff', - 'kind', - 'name', - 'owner', - 'path', - 'parent', - 'result', - 'submitter', - 'state', - 'timeout', - 'treeid', - 'updated', - 'user_groups', - 'processed_by_kcidb_bridge', - 'retry_counter', + "id", + "artifacts", + "created", + "data", + "debug", + "group", + "jobfilter", + "platform_filter", + "holdoff", + "kind", + "name", + "owner", + "path", + "parent", + "result", + "submitter", + "state", + "timeout", + "treeid", + "updated", + "user_groups", + "processed_by_kcidb_bridge", + "retry_counter", } -def test_get_node_by_id_endpoint_empty_response(mock_db_find_by_id, - test_client): +def test_get_node_by_id_endpoint_empty_response(mock_db_find_by_id, test_client): """ Test Case : Test KernelCI API GET /node/{node_id} endpoint for negative path @@ -272,8 +249,7 @@ def test_get_node_by_id_endpoint_empty_response(mock_db_find_by_id, assert response.json() is None -def test_get_all_nodes(mock_db_find_by_attributes, - test_client): +def test_get_all_nodes(mock_db_find_by_attributes, test_client): """ Test Case : Test KernelCI API GET /nodes endpoint for the positive path @@ -347,10 +323,7 @@ def test_get_all_nodes(mock_db_find_by_attributes, } mock_db_find_by_attributes.return_value = PageModel( - items=[node_obj_1, node_obj_2, node_obj_3], - total=3, - limit=50, - offset=0 + items=[node_obj_1, node_obj_2, node_obj_3], total=3, limit=50, offset=0 ) response = test_client.get("nodes") @@ -359,8 +332,7 @@ def test_get_all_nodes(mock_db_find_by_attributes, assert len(response.json()) > 0 -def test_get_all_nodes_empty_response(mock_db_find_by_attributes, - test_client): +def test_get_all_nodes_empty_response(mock_db_find_by_attributes, test_client): """ Test Case : Test KernelCI API GET /nodes endpoint for the negative path @@ -368,14 +340,9 @@ def test_get_all_nodes_empty_response(mock_db_find_by_attributes, HTTP Response Code 200 OK Empty list as no Node object is added. """ - mock_db_find_by_attributes.return_value = PageModel( - items=[], - total=0, - limit=50, - offset=0 - ) + mock_db_find_by_attributes.return_value = PageModel(items=[], total=0, limit=50, offset=0) response = test_client.get("nodes") print("response.json()", response.json()) assert response.status_code == 200 - assert response.json().get('total') == 0 + assert response.json().get("total") == 0 diff --git a/tests/unit_tests/test_pubsub.py b/tests/unit_tests/test_pubsub.py index 2c3f0ab2..385ae5a1 100644 --- a/tests/unit_tests/test_pubsub.py +++ b/tests/unit_tests/test_pubsub.py @@ -9,6 +9,7 @@ """Unit test functions for KernelCI API Pub/Sub""" import json + import pytest @@ -23,8 +24,8 @@ async def test_subscribe_single_channel(mock_pubsub): PubSub._subscriptions dict should have one entry. This entry's key should be equal 1. """ - result = await mock_pubsub.subscribe('CHANNEL', 'test') - assert result.channel == 'CHANNEL' + result = await mock_pubsub.subscribe("CHANNEL", "test") + assert result.channel == "CHANNEL" assert result.id == 1 assert len(mock_pubsub._subscriptions) == 1 assert 1 in mock_pubsub._subscriptions @@ -46,9 +47,9 @@ async def test_subscribe_multiple_channels(mock_pubsub): """ # Reset `ID_KEY` value to get subscription ID starting from 1 await mock_pubsub._redis.set(mock_pubsub.ID_KEY, 0) - channels = ((1, 'CHANNEL1'), (2, 'CHANNEL2'), (3, 'CHANNEL3')) + channels = ((1, "CHANNEL1"), (2, "CHANNEL2"), (3, "CHANNEL3")) for expected_id, expected_channel in channels: - result = await mock_pubsub.subscribe(expected_channel, 'test') + result = await mock_pubsub.subscribe(expected_channel, "test") assert result.channel == expected_channel assert result.id == expected_id assert len(mock_pubsub._subscriptions) == 3 @@ -98,7 +99,7 @@ async def test_pubsub_publish_couldevent(mock_pubsub_publish): return value, but a json to be published in a channel. """ - data = 'validate json' + data = "validate json" attributes = { "specversion": "1.0", "id": "6878b661-96dc-4e93-8c92-26eb9ff8db64", @@ -107,7 +108,7 @@ async def test_pubsub_publish_couldevent(mock_pubsub_publish): "time": "2022-01-31T21:29:29.675593+00:00", } - await mock_pubsub_publish.publish_cloudevent('CHANNEL1', data, attributes) + await mock_pubsub_publish.publish_cloudevent("CHANNEL1", data, attributes) expected_json = str.encode( '{"specversion": "1.0", ' diff --git a/tests/unit_tests/test_subscribe_handler.py b/tests/unit_tests/test_subscribe_handler.py index 36906848..808d396a 100644 --- a/tests/unit_tests/test_subscribe_handler.py +++ b/tests/unit_tests/test_subscribe_handler.py @@ -7,8 +7,8 @@ """Unit test function for KernelCI API subscribe handler""" -from tests.unit_tests.conftest import BEARER_TOKEN from api.pubsub import Subscription +from tests.unit_tests.conftest import BEARER_TOKEN def test_subscribe_endpoint(mock_subscribe, test_client): @@ -18,14 +18,12 @@ def test_subscribe_endpoint(mock_subscribe, test_client): HTTP Response Code 200 OK JSON with 'id' and 'channel' keys """ - subscribe = Subscription(id=1, channel='abc', user='test') + subscribe = Subscription(id=1, channel="abc", user="test") mock_subscribe.return_value = subscribe response = test_client.post( "subscribe/abc", - headers={ - "Authorization": BEARER_TOKEN - }, + headers={"Authorization": BEARER_TOKEN}, ) print("response.json()", response.json()) assert response.status_code == 200 diff --git a/tests/unit_tests/test_token_handler.py b/tests/unit_tests/test_token_handler.py index d5753faa..db14b8b6 100644 --- a/tests/unit_tests/test_token_handler.py +++ b/tests/unit_tests/test_token_handler.py @@ -11,12 +11,12 @@ """Unit test function for KernelCI API token handler""" import pytest + from api.models import User @pytest.mark.asyncio -async def test_token_endpoint(test_async_client, mock_user_find, - mock_beanie_user_update): +async def test_token_endpoint(test_async_client, mock_user_find, mock_beanie_user_update): """ Test Case : Test KernelCI API /user/login endpoint Expected Result : @@ -24,15 +24,14 @@ async def test_token_endpoint(test_async_client, mock_user_find, JSON with 'access_token' and 'token_type' key """ user = User( - id='65265305c74695807499037f', - username='bob', - hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.' - 'xCZGmM8jWXUXJZ4K', - email='bob@kernelci.org', + id="65265305c74695807499037f", + username="bob", + hashed_password="$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.xCZGmM8jWXUXJZ4K", + email="bob@kernelci.org", groups=[], is_active=True, is_superuser=False, - is_verified=True + is_verified=True, ) mock_user_find.return_value = user mock_beanie_user_update.return_value = user @@ -41,17 +40,16 @@ async def test_token_endpoint(test_async_client, mock_user_find, "user/login", headers={ "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", }, - data="username=bob&password=hello" + data="username=bob&password=hello", ) assert response.status_code == 200 - assert ('access_token', 'token_type') == tuple(response.json().keys()) + assert ("access_token", "token_type") == tuple(response.json().keys()) @pytest.mark.asyncio -async def test_token_endpoint_incorrect_password(test_async_client, - mock_user_find): +async def test_token_endpoint_incorrect_password(test_async_client, mock_user_find): """ Test Case : Test KernelCI API /user/login endpoint for negative path Incorrect password should be passed to the endpoint @@ -67,10 +65,10 @@ async def test_token_endpoint_incorrect_password(test_async_client, "user/login", headers={ "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", }, - data="username=bob&password=hello1" + data="username=bob&password=hello1", ) print("response json", response.json()) assert response.status_code == 400 - assert response.json() == {'detail': 'LOGIN_BAD_CREDENTIALS'} + assert response.json() == {"detail": "LOGIN_BAD_CREDENTIALS"} diff --git a/tests/unit_tests/test_unsubscribe_handler.py b/tests/unit_tests/test_unsubscribe_handler.py index 947dce6e..afa1a982 100644 --- a/tests/unit_tests/test_unsubscribe_handler.py +++ b/tests/unit_tests/test_unsubscribe_handler.py @@ -18,9 +18,7 @@ def test_unsubscribe_endpoint(mock_unsubscribe, test_client): """ response = test_client.post( "unsubscribe/1", - headers={ - "Authorization": BEARER_TOKEN - }, + headers={"Authorization": BEARER_TOKEN}, ) assert response.status_code == 200 @@ -34,10 +32,8 @@ def test_unsubscribe_endpoint_empty_response(test_client): """ response = test_client.post( "unsubscribe/1", - headers={ - "Authorization": BEARER_TOKEN - }, + headers={"Authorization": BEARER_TOKEN}, ) print("response.json()", response.json()) assert response.status_code == 404 - assert 'detail' in response.json() + assert "detail" in response.json() diff --git a/tests/unit_tests/test_user_group_handler.py b/tests/unit_tests/test_user_group_handler.py index 5aa8db5e..d701eeaa 100644 --- a/tests/unit_tests/test_user_group_handler.py +++ b/tests/unit_tests/test_user_group_handler.py @@ -49,8 +49,7 @@ def test_create_user_group(mock_db_find_one, mock_db_create, test_client): assert response.json()["name"] == "runtime:pull-labs-demo:node-editor" -def test_delete_user_group(mock_db_find_by_id, mock_db_count, - mock_db_delete_by_id, test_client): +def test_delete_user_group(mock_db_find_by_id, mock_db_count, mock_db_delete_by_id, test_client): """DELETE /user-groups/{id} removes an unused user group.""" mock_db_find_by_id.return_value = UserGroup(name="team-a") mock_db_count.return_value = 0 @@ -69,8 +68,7 @@ def test_delete_user_group(mock_db_find_by_id, mock_db_count, ) -def test_delete_user_group_when_assigned(mock_db_find_by_id, mock_db_count, - test_client): +def test_delete_user_group_when_assigned(mock_db_find_by_id, mock_db_count, test_client): """DELETE /user-groups/{id} rejects when group is assigned to users.""" mock_db_find_by_id.return_value = UserGroup(name="team-a") mock_db_count.return_value = 2 diff --git a/tests/unit_tests/test_user_handler.py b/tests/unit_tests/test_user_handler.py index 43e8ce9c..823e1444 100644 --- a/tests/unit_tests/test_user_handler.py +++ b/tests/unit_tests/test_user_handler.py @@ -8,18 +8,18 @@ """Unit test function for KernelCI API user handler""" import json + import pytest +from api.models import UserGroup, UserRead from tests.unit_tests.conftest import ( ADMIN_BEARER_TOKEN, BEARER_TOKEN, ) -from api.models import UserGroup, UserRead @pytest.mark.asyncio -async def test_create_regular_user(mock_db_find_one, mock_db_create, - test_async_client): +async def test_create_regular_user(mock_db_find_one, mock_db_create, test_async_client): """ Test Case : Test KernelCI API /user/register endpoint to create regular user when requested with admin user's bearer token @@ -30,37 +30,36 @@ async def test_create_regular_user(mock_db_find_one, mock_db_create, """ mock_db_find_one.return_value = None user = UserRead( - id='65265305c74695807499037f', - username='test', - email='test@kernelci.org', + id="65265305c74695807499037f", + username="test", + email="test@kernelci.org", groups=[], is_active=True, is_verified=False, - is_superuser=False + is_superuser=False, ) mock_db_create.return_value = user response = await test_async_client.post( "user/register", - headers={ - "Accept": "application/json", - "Authorization": ADMIN_BEARER_TOKEN - }, - data=json.dumps({ - 'username': 'test', - 'password': 'test', - 'email': 'test@kernelci.org' - }) + headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), ) print(response.json()) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) @pytest.mark.asyncio -async def test_create_admin_user(test_async_client, mock_db_find_one, - mock_db_find_by_id, mock_db_update): +async def test_create_admin_user(test_async_client, mock_db_find_one, mock_db_find_by_id, mock_db_update): """ Test Case : Test KernelCI API /user/register endpoint to create admin user when requested with admin user's bearer token @@ -70,13 +69,13 @@ async def test_create_admin_user(test_async_client, mock_db_find_one, 'is_verified' and 'is_superuser' keys """ user = UserRead( - id='61bda8f2eb1a63d2b7152419', - username='test_admin', - email='test-admin@kernelci.org', + id="61bda8f2eb1a63d2b7152419", + username="test_admin", + email="test-admin@kernelci.org", groups=[], is_active=True, is_verified=False, - is_superuser=True + is_superuser=True, ) mock_db_find_one.return_value = None mock_db_find_by_id.return_value = user @@ -84,21 +83,27 @@ async def test_create_admin_user(test_async_client, mock_db_find_one, response = await test_async_client.post( "user/register", - headers={ - "Accept": "application/json", - "Authorization": ADMIN_BEARER_TOKEN - }, - data=json.dumps({ - 'username': 'test_admin', - 'password': 'test', - 'email': 'test-admin@kernelci.org', - 'is_superuser': True - }) + headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + data=json.dumps( + { + "username": "test_admin", + "password": "test", + "email": "test-admin@kernelci.org", + "is_superuser": True, + } + ), ) print(response.json()) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) @pytest.mark.asyncio @@ -112,24 +117,18 @@ async def test_create_user_endpoint_negative(test_async_client): """ response = await test_async_client.post( "user/register", - headers={ - "Accept": "application/json", - "Authorization": BEARER_TOKEN - }, - data=json.dumps({ - 'username': 'test', - 'password': 'test', - 'email': 'test@kernelci.org' - }) + headers={"Accept": "application/json", "Authorization": BEARER_TOKEN}, + data=json.dumps({"username": "test", "password": "test", "email": "test@kernelci.org"}), ) print(response.json()) assert response.status_code == 403 - assert response.json() == {'detail': 'Forbidden'} + assert response.json() == {"detail": "Forbidden"} @pytest.mark.asyncio -async def test_create_user_with_group(test_async_client, mock_db_find_one, - mock_db_update, mock_db_find_by_id): +async def test_create_user_with_group( + test_async_client, mock_db_find_one, mock_db_update, mock_db_find_by_id +): """ Test Case : Test KernelCI API /user/register endpoint to create a user with a user group @@ -139,40 +138,45 @@ async def test_create_user_with_group(test_async_client, mock_db_find_one, 'is_verified' and 'is_superuser' keys """ user = UserRead( - id='61bda8f2eb1a63d2b7152419', - username='test_admin', - email='test-admin@kernelci.org', - groups=[UserGroup(name='kernelci')], + id="61bda8f2eb1a63d2b7152419", + username="test_admin", + email="test-admin@kernelci.org", + groups=[UserGroup(name="kernelci")], is_active=True, is_verified=False, - is_superuser=False + is_superuser=False, ) mock_db_find_by_id.return_value = user mock_db_update.return_value = user - mock_db_find_one.side_effect = [None, UserGroup(name='kernelci')] + mock_db_find_one.side_effect = [None, UserGroup(name="kernelci")] response = await test_async_client.post( "user/register", - headers={ - "Accept": "application/json", - "Authorization": ADMIN_BEARER_TOKEN - }, - data=json.dumps({ - 'username': 'test', - 'password': 'test', - 'email': 'test-admin@kernelci.org', - 'groups': ['kernelci'] - }) + headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + data=json.dumps( + { + "username": "test", + "password": "test", + "email": "test-admin@kernelci.org", + "groups": ["kernelci"], + } + ), ) print(response.json()) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) @pytest.mark.asyncio -async def test_get_user_by_id_endpoint(test_async_client, - mock_beanie_get_user_by_id): +async def test_get_user_by_id_endpoint(test_async_client, mock_beanie_get_user_by_id): """ Test Case : Test KernelCI API GET /user/{user_id} endpoint with admin token @@ -181,23 +185,28 @@ async def test_get_user_by_id_endpoint(test_async_client, JSON with User object attributes """ user_obj = UserRead( - id='61bda8f2eb1a63d2b7152418', - username='test', - email='test@kernelci.org', - groups=[], - is_active=True, - is_verified=False, - is_superuser=False - ) + id="61bda8f2eb1a63d2b7152418", + username="test", + email="test@kernelci.org", + groups=[], + is_active=True, + is_verified=False, + is_superuser=False, + ) mock_beanie_get_user_by_id.return_value = user_obj response = await test_async_client.get( "user/61bda8f2eb1a63d2b7152418", - headers={ - "Accept": "application/json", - "Authorization": ADMIN_BEARER_TOKEN - }) + headers={"Accept": "application/json", "Authorization": ADMIN_BEARER_TOKEN}, + ) print("response.json()", response.json()) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', 'is_verified', - 'username', 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys()) diff --git a/tests/unit_tests/test_whoami_handler.py b/tests/unit_tests/test_whoami_handler.py index 2b4a7b4a..0b579008 100644 --- a/tests/unit_tests/test_whoami_handler.py +++ b/tests/unit_tests/test_whoami_handler.py @@ -11,8 +11,9 @@ """Unit test function for KernelCI API whoami handler""" import pytest -from tests.unit_tests.conftest import BEARER_TOKEN + from api.models import UserRead +from tests.unit_tests.conftest import BEARER_TOKEN @pytest.mark.asyncio @@ -25,24 +26,27 @@ async def test_whoami_endpoint(test_async_client, mock_auth_current_user): and 'active' keys """ user = UserRead( - id='61bda8f2eb1a63d2b7152420', - username='test-user', - email='test-user@kernelci.org', - groups=[], - is_active=True, - is_verified=False, - is_superuser=False - ) + id="61bda8f2eb1a63d2b7152420", + username="test-user", + email="test-user@kernelci.org", + groups=[], + is_active=True, + is_verified=False, + is_superuser=False, + ) mock_auth_current_user.return_value = user, BEARER_TOKEN response = await test_async_client.get( "whoami", - headers={ - "Accept": "application/json", - "Authorization": BEARER_TOKEN - }, + headers={"Accept": "application/json", "Authorization": BEARER_TOKEN}, ) print(response.json(), response.status_code) assert response.status_code == 200 - assert ('id', 'email', 'is_active', 'is_superuser', - 'is_verified', 'username', - 'groups') == tuple(response.json().keys()) + assert ( + "id", + "email", + "is_active", + "is_superuser", + "is_verified", + "username", + "groups", + ) == tuple(response.json().keys())