diff --git a/.github/workflows/publish-packages.yml b/.github/workflows/publish-packages.yml index d0a14280..0a14ccb5 100644 --- a/.github/workflows/publish-packages.yml +++ b/.github/workflows/publish-packages.yml @@ -10,6 +10,8 @@ permissions: env: # Version is read from pyproject.toml - update with: python scripts/update_version.py DRY_RUN: 'false' + # gopher-orch version to download native binaries from (may differ from SDK version) + GOPHER_ORCH_VERSION: 'v0.1.2' jobs: download-binaries: @@ -57,7 +59,8 @@ jobs: GH_TOKEN: ${{ secrets.GOPHER_ORCH_TOKEN }} run: | # Download all assets from the private gopher-orch repo - gh release download ${{ steps.version.outputs.version_tag }} \ + # Uses GOPHER_ORCH_VERSION which may differ from SDK version + gh release download ${{ env.GOPHER_ORCH_VERSION }} \ -R GopherSecurity/gopher-orch \ -D downloads diff --git a/.gitmodules b/.gitmodules index 18a3014b..fd23aded 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "third_party/gopher-orch"] path = third_party/gopher-orch url = https://github.com/GopherSecurity/gopher-orch.git - branch = br_release + branch = dev_auth diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e213e19..7dc4f370 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] + +## [0.1.2] - 2026-03-12 + ## [0.1.1] - 2026-02-28 ## [0.1.0-20260227-124047] - 2026-02-27 @@ -65,5 +68,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- [Unreleased]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0-20260227-124047...HEAD -[0.1.1]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0-20260227-124047...v0.1.1[0.1.0-20260227-124047]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0...v0.1.0-20260227-124047 +[0.1.2]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0-20260227-124047...v0.1.2[0.1.1]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0-20260227-124047...v0.1.1[0.1.0-20260227-124047]: https://github.com/GopherSecurity/gopher-mcp-python/compare/v0.1.0...v0.1.0-20260227-124047 [0.1.0]: https://github.com/GopherSecurity/gopher-mcp-python/releases/tag/v0.1.0 diff --git a/build.sh b/build.sh index 97ea4a7e..403e6cbd 100755 --- a/build.sh +++ b/build.sh @@ -236,5 +236,18 @@ echo -e "${GREEN}======================================${NC}" echo "" echo -e "Native libraries: ${YELLOW}${NATIVE_LIB_DIR}${NC}" echo -e "Native headers: ${YELLOW}${NATIVE_INCLUDE_DIR}${NC}" -echo -e "Run tests: ${YELLOW}python3 -m pytest tests/${NC}" -echo -e "Run example: ${YELLOW}python3 examples/client_example.py${NC}" +echo "" +echo -e "${GREEN}Run tests:${NC}" +echo -e " ${YELLOW}python3 -m pytest tests/${NC}" +echo "" +echo -e "${GREEN}Run examples:${NC}" +echo -e " ${YELLOW}python3 examples/client_example.py${NC}" +echo "" +echo -e "${GREEN}Run Auth MCP Server example:${NC}" +echo -e " ${YELLOW}cd examples/auth && ./run_example.sh${NC} # Uses server.config settings" +echo -e " ${YELLOW}cd examples/auth && ./run_example.sh --no-auth${NC} # Override to disable auth" +echo -e " ${YELLOW}cd examples/auth && ./run_example.sh --help${NC} # Show all options" +echo "" +echo -e " Test endpoints:" +echo -e " ${YELLOW}curl http://localhost:3001/health${NC}" +echo -e " ${YELLOW}curl -X POST http://localhost:3001/mcp -H 'Content-Type: application/json' -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/list\",\"params\":{}}'${NC}" diff --git a/examples/auth/README.md b/examples/auth/README.md new file mode 100644 index 00000000..5ca75bdd --- /dev/null +++ b/examples/auth/README.md @@ -0,0 +1,353 @@ +# Python Auth MCP Server Example + +OAuth-protected MCP (Model Context Protocol) server implementation in Python using gopher-auth FFI bindings for JWT token validation. + +## Overview + +This example demonstrates: +- OAuth 2.0 protected MCP server using JSON-RPC 2.0 +- JWT token validation via gopher-auth native library (FFI) +- OAuth discovery endpoints (RFC 9728, RFC 8414, OIDC) +- Scope-based access control for MCP tools +- Integration with Keycloak or compatible OAuth providers + +## Prerequisites + +- Python 3.10+ +- pip +- Compiled `libgopher-orch` from gopher-orch (for production use) +- Keycloak or compatible OAuth 2.0 server (optional, for auth testing) + +## Installation + +```bash +# Install dependencies +pip install -r requirements.txt + +# Or install with development dependencies +pip install -e ".[dev]" + +# Copy libgopher-orch to lib/ (from gopher-orch build) +# macOS: +cp /path/to/gopher-orch/build/lib/libgopher-orch.dylib ./lib/ + +# Linux: +cp /path/to/gopher-orch/build/lib/libgopher-orch.so ./lib/ +``` + +## Building the Native Library + +Build libgopher-orch from gopher-orch: + +```bash +cd /path/to/gopher-orch +mkdir -p build && cd build +cmake -DBUILD_SHARED_LIBS=ON .. +make gopher-orch + +# Copy to this example +cp lib/libgopher-orch.* /path/to/gopher-mcp-python/examples/auth/lib/ +``` + +## Configuration + +Create or modify `server.config`: + +### Auth Disabled Mode (Development/Testing) + +```ini +# Server settings +host=0.0.0.0 +port=3001 +server_url=http://localhost:3001 + +# Disable auth for development +auth_disabled=true +``` + +### Auth Enabled Mode (Production) + +```ini +# Server settings +host=0.0.0.0 +port=3001 +server_url=http://localhost:3001 + +# OAuth/IDP settings (Keycloak example) +auth_server_url=https://keycloak.example.com/realms/mcp +client_id=mcp-client +client_secret=your-client-secret + +# Optional: Override derived endpoints +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs +# issuer=https://keycloak.example.com/realms/mcp +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# Token validation settings +allowed_scopes=openid profile email mcp:read mcp:admin +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 +``` + +### Configuration Options + +| Option | Description | Default | +|--------|-------------|---------| +| `host` | Server bind address | `0.0.0.0` | +| `port` | Server port | `3001` | +| `server_url` | Public server URL | `http://localhost:{port}` | +| `auth_server_url` | OAuth server base URL | - | +| `jwks_uri` | JWKS endpoint URL | Derived from auth_server_url | +| `issuer` | Expected token issuer | Derived from auth_server_url | +| `client_id` | OAuth client ID | - | +| `client_secret` | OAuth client secret | - | +| `allowed_scopes` | Space-separated allowed scopes | `openid profile email mcp:read mcp:admin` | +| `jwks_cache_duration` | JWKS cache TTL in seconds | `3600` | +| `jwks_auto_refresh` | Auto-refresh JWKS before expiry | `true` | +| `request_timeout` | HTTP request timeout in seconds | `30` | +| `auth_disabled` | Disable authentication entirely | `false` | + +## Running the Server + +### Development Mode + +```bash +# Run directly with Python +python -m py_auth_mcp_server + +# Or with custom config +python -m py_auth_mcp_server /path/to/custom.config + +# Override options +python -m py_auth_mcp_server --no-auth +python -m py_auth_mcp_server --host 127.0.0.1 --port 8080 +``` + +### Using Flask Development Server + +```bash +# Set environment variables +export FLASK_APP=py_auth_mcp_server.app:create_app +export FLASK_ENV=development + +# Run with Flask +flask run --port 3001 +``` + +## Testing + +```bash +# Run all tests +pytest + +# Run with coverage +pytest --cov=py_auth_mcp_server + +# Run specific test file +pytest tests/test_integration.py + +# Run with verbose output +pytest -v +``` + +## API Endpoints + +### Health Check + +```bash +curl http://localhost:3001/health +``` + +Response: +```json +{ + "status": "ok", + "timestamp": "2024-01-15T10:30:00.000000+00:00", + "version": "1.0.0", + "uptime": 123 +} +``` + +### OAuth Discovery + +```bash +# Protected Resource Metadata (RFC 9728) +curl http://localhost:3001/.well-known/oauth-protected-resource + +# Authorization Server Metadata (RFC 8414) +curl http://localhost:3001/.well-known/oauth-authorization-server + +# OpenID Configuration +curl http://localhost:3001/.well-known/openid-configuration +``` + +### MCP Tools + +#### Without Authentication (auth_disabled=true) + +```bash +# Initialize session +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}}' + +# List available tools +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}' + +# Get weather (no auth required) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc": "2.0", "id": 3, "method": "tools/call", "params": {"name": "get-weather", "arguments": {"city": "Seattle"}}}' +``` + +#### With Authentication + +```bash +# Get an access token from your OAuth provider first +TOKEN="your-jwt-token" + +# Get forecast (requires mcp:read scope) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "get-forecast", "arguments": {"city": "Portland"}}}' + +# Get weather alerts (requires mcp:admin scope) +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "get-weather-alerts", "arguments": {"region": "Pacific Northwest"}}}' +``` + +## Available Tools + +| Tool | Description | Required Scope | +|------|-------------|----------------| +| `get-weather` | Current weather for a city | None | +| `get-forecast` | 5-day forecast for a city | `mcp:read` | +| `get-weather-alerts` | Weather alerts for a region | `mcp:admin` | + +## Authentication Flow + +``` ++----------+ +--------------+ +-------------+ +| Client | | MCP Server | | Keycloak | ++----+-----+ +------+-------+ +------+------+ + | | | + | GET /.well-known/oauth-protected-resource + |----------------->| | + | { authorization_servers: [...] } | + |<-----------------| | + | | | + | GET /.well-known/oauth-authorization-server + |----------------->| | + | { authorization_endpoint, ... } | + |<-----------------| | + | | | + | Redirect to authorization_endpoint | + |-------------------------------------->| + | | User authenticates + | Redirect with auth code | + |<--------------------------------------| + | | | + | POST token_endpoint (exchange code) | + |-------------------------------------->| + | | Access token | + |<--------------------------------------| + | | | + | POST /mcp with Bearer token | + |----------------->| | + | | Validate JWT | + | |------------------->| + | | Token valid | + | |<-------------------| + | Tool response | | + |<-----------------| | +``` + +## Troubleshooting + +### Library Loading Errors + +``` +RuntimeError: Auth functions not available +``` + +**Solution:** Ensure the native library is compiled and accessible: +- Copy `libgopher-orch.dylib` (macOS) or `libgopher-orch.so` (Linux) to `./lib/` +- Or set `GOPHER_ORCH_LIBRARY_PATH` environment variable + +### Token Validation Failures + +``` +401 Unauthorized: Token validation failed +``` + +**Causes:** +- Token expired - obtain a new token +- Invalid issuer - check `issuer` in config matches token +- JWKS fetch failed - verify `jwks_uri` is accessible +- Invalid audience - ensure token has correct audience claim + +### JWKS Fetch Errors + +``` +Error: JWKS fetch failed +``` + +**Solutions:** +- Verify `jwks_uri` is correct and accessible +- Check network connectivity to OAuth server +- Increase `request_timeout` if needed + +### Scope Access Denied + +```json +{"error": "access_denied", "message": "Required scope: mcp:read"} +``` + +**Solution:** Ensure your token includes the required scope. Request additional scopes during token acquisition. + +## Project Structure + +``` +examples/auth/ +├── lib/ # Native library (libgopher-orch) +├── py_auth_mcp_server/ +│ ├── __init__.py # Package metadata +│ ├── __main__.py # Entry point +│ ├── app.py # Flask application factory +│ ├── config.py # Configuration loader +│ ├── middleware/ +│ │ ├── __init__.py +│ │ └── oauth_auth.py # OAuth middleware +│ ├── routes/ +│ │ ├── __init__.py +│ │ ├── health.py # Health endpoint +│ │ ├── oauth_endpoints.py # Discovery endpoints +│ │ └── mcp_handler.py # JSON-RPC handler +│ └── tools/ +│ ├── __init__.py +│ └── weather_tools.py # Example tools +├── tests/ +│ ├── __init__.py +│ ├── conftest.py # Pytest fixtures +│ ├── test_config.py +│ ├── test_oauth_auth.py +│ ├── test_mcp_handler.py +│ ├── test_weather_tools.py +│ └── test_integration.py +├── pyproject.toml # Project configuration +├── requirements.txt # Dependencies +├── server.config # Server configuration +├── server.config.example # Configuration template +└── README.md +``` + +## License + +See the main gopher-mcp-python repository for license information. diff --git a/examples/auth/lib/.gitkeep b/examples/auth/lib/.gitkeep new file mode 100644 index 00000000..f128d47c --- /dev/null +++ b/examples/auth/lib/.gitkeep @@ -0,0 +1,5 @@ +# This directory is for the native gopher-orch library. +# Place the compiled library file here: +# - macOS: libgopher-orch.dylib +# - Linux: libgopher-orch.so +# - Windows: gopher-orch.dll diff --git a/examples/auth/py_auth_mcp_server/__init__.py b/examples/auth/py_auth_mcp_server/__init__.py new file mode 100644 index 00000000..d2556f6e --- /dev/null +++ b/examples/auth/py_auth_mcp_server/__init__.py @@ -0,0 +1,22 @@ +"""Python MCP server with OAuth authentication. + +OAuth-protected MCP server example using gopher-auth FFI bindings. +Demonstrates JWT token validation and scope-based access control +for MCP tools. +""" + +__version__ = "1.0.0" +__author__ = "Gopher Security" + +from .app import cleanup_app, create_app +from .config import AuthServerConfig, create_default_config, load_config_from_file + +__all__ = [ + "__version__", + "__author__", + "create_app", + "cleanup_app", + "AuthServerConfig", + "create_default_config", + "load_config_from_file", +] diff --git a/examples/auth/py_auth_mcp_server/__main__.py b/examples/auth/py_auth_mcp_server/__main__.py new file mode 100644 index 00000000..7b00ebb0 --- /dev/null +++ b/examples/auth/py_auth_mcp_server/__main__.py @@ -0,0 +1,167 @@ +"""Entry point for py_auth_mcp_server. + +OAuth-protected MCP server example using gopher-auth FFI bindings. +Demonstrates JWT token validation with Keycloak and scope-based +access control for MCP tools. +""" + +import argparse +import signal +import sys +from pathlib import Path +from typing import NoReturn + +from .app import cleanup_app, create_app +from .config import AuthServerConfig + + +def print_banner() -> None: + """Print startup banner.""" + print("") + print("========================================") + print(" Python Auth MCP Server") + print(" OAuth-Protected MCP Example") + print("========================================") + print("") + + +def print_endpoints(config: AuthServerConfig) -> None: + """Print endpoint information. + + Args: + config: Server configuration. + """ + base_url = config.server_url + + print("Endpoints:") + print(f" Health: GET {base_url}/health") + print(f" OAuth Meta: GET {base_url}/.well-known/oauth-protected-resource") + print(f" Auth Server: GET {base_url}/.well-known/oauth-authorization-server") + print(f" OIDC Config: GET {base_url}/.well-known/openid-configuration") + print(f" OAuth Auth: GET {base_url}/oauth/authorize") + print(f" MCP: POST {base_url}/mcp") + print(f" RPC: POST {base_url}/rpc") + print("") + + if config.auth_disabled: + print("Authentication: DISABLED") + else: + print("Authentication: ENABLED") + print(f" JWKS URI: {config.jwks_uri}") + print(f" Issuer: {config.issuer}") + print("") + + +def main() -> NoReturn: + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Python Auth MCP Server - OAuth-Protected MCP Example" + ) + parser.add_argument( + "config", + nargs="?", + default=None, + help="Path to configuration file (default: server.config in current directory)", + ) + parser.add_argument( + "--host", + default=None, + help="Override host from config", + ) + parser.add_argument( + "--port", + type=int, + default=None, + help="Override port from config", + ) + parser.add_argument( + "--no-auth", + action="store_true", + help="Disable authentication", + ) + args = parser.parse_args() + + print_banner() + + # Determine config path + if args.config: + config_path = Path(args.config) + else: + # Try current directory first, then package directory + cwd_config = Path.cwd() / "server.config" + pkg_config = Path(__file__).parent.parent / "server.config" + + if cwd_config.exists(): + config_path = cwd_config + elif pkg_config.exists(): + config_path = pkg_config + else: + config_path = None + + # Create app + try: + if config_path and config_path.exists(): + print(f"Loading configuration from: {config_path}") + app = create_app(config_path=config_path) + print("Configuration loaded successfully") + else: + print("No configuration file found, using defaults") + app = create_app() + print("") + except Exception as e: + print(f"Failed to create application: {e}") + sys.exit(1) + + # Get config + config: AuthServerConfig = app.config["AUTH_SERVER_CONFIG"] + + # Apply command-line overrides + if args.host: + config.host = args.host + if args.port: + config.port = args.port + if args.no_auth: + config.auth_disabled = True + + # Print auth info + if not config.auth_disabled: + version = app.config.get("AUTH_LIBRARY_VERSION", "unknown") + print(f"Auth library version: {version}") + print("") + + # Setup signal handlers + def signal_handler(signum: int, frame: object) -> None: + print("") + print("Shutting down...") + cleanup_app(app) + print("Goodbye!") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # Print startup info + print(f"Starting server on {config.host}:{config.port}") + print("") + print_endpoints(config) + print("Press Ctrl+C to shutdown") + print("") + + # Run server + try: + app.run( + host=config.host, + port=config.port, + debug=False, + use_reloader=False, + ) + except Exception as e: + print(f"Server error: {e}") + cleanup_app(app) + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/examples/auth/py_auth_mcp_server/app.py b/examples/auth/py_auth_mcp_server/app.py new file mode 100644 index 00000000..74114e0a --- /dev/null +++ b/examples/auth/py_auth_mcp_server/app.py @@ -0,0 +1,157 @@ +"""Flask application factory. + +Creates and configures the Flask application for the Auth MCP Server. +""" + +from __future__ import annotations + +from pathlib import Path + +from flask import Flask +from flask_cors import CORS + +from gopher_mcp_python.ffi.auth import ( + GopherAuthClient, + gopher_get_auth_library_version, + gopher_init_auth_library, + is_auth_available, + gopher_shutdown_auth_library, +) + +from .config import ( + AuthServerConfig, + create_default_config, + load_config_from_file, +) +from .middleware import OAuthAuthMiddleware +from .routes import register_health_routes, register_mcp_routes, register_oauth_routes +from .routes.mcp_handler import McpHandler +from .tools import register_weather_tools + + +def create_app( + config: AuthServerConfig | None = None, + config_path: str | Path | None = None, +) -> Flask: + """Create and configure the Flask application. + + Args: + config: Optional pre-built configuration. + config_path: Optional path to configuration file. + + Returns: + Configured Flask application. + """ + # Load configuration + if config is None: + if config_path is not None: + config = load_config_from_file(config_path) + else: + config = create_default_config() + + # Create Flask app + app = Flask(__name__) + app.config["JSON_SORT_KEYS"] = False + + # Enable CORS + CORS(app, resources={r"/*": {"origins": "*"}}) + + # Store config in app context + app.config["AUTH_SERVER_CONFIG"] = config + + # Initialize auth library if auth is enabled + auth_client: GopherAuthClient | None = None + + if not config.auth_disabled: + if is_auth_available(): + try: + gopher_init_auth_library() + version = gopher_get_auth_library_version() + app.config["AUTH_LIBRARY_VERSION"] = version + + # Create auth client + auth_client = GopherAuthClient(config.jwks_uri, config.issuer) + + # Set client options + if config.jwks_cache_duration > 0: + auth_client.set_option( + "cache_duration", str(config.jwks_cache_duration) + ) + if config.jwks_auto_refresh: + auth_client.set_option("auto_refresh", "true") + if config.request_timeout > 0: + auth_client.set_option( + "request_timeout", str(config.request_timeout) + ) + + app.config["AUTH_CLIENT"] = auth_client + except Exception as e: + app.logger.error(f"Failed to initialize auth library: {e}") + # Continue without auth + config.auth_disabled = True + else: + app.logger.warning( + "Auth functions not available - running without authentication" + ) + config.auth_disabled = True + + # Create auth middleware + auth_middleware = OAuthAuthMiddleware(auth_client, config) + app.config["AUTH_MIDDLEWARE"] = auth_middleware + + # Register before_request handler for auth + app.before_request(auth_middleware.before_request) + + # Get server version + server_version = ( + app.config.get("AUTH_LIBRARY_VERSION", "1.0.0") + if not config.auth_disabled + else "1.0.0" + ) + + # Register health endpoint + register_health_routes(app, version=server_version) + + # Register OAuth discovery endpoints + register_oauth_routes(app, config) + + # Create and register MCP handler + mcp_handler = McpHandler() + app.config["MCP_HANDLER"] = mcp_handler + register_mcp_routes(app, mcp_handler) + + # Register weather tools + register_weather_tools(mcp_handler, auth_middleware) + + # Register teardown handler for cleanup + @app.teardown_appcontext + def cleanup(exception: BaseException | None = None) -> None: + """Clean up resources on app context teardown.""" + pass # Per-request cleanup if needed + + return app + + +def cleanup_app(app: Flask) -> None: + """Clean up application resources. + + Call this when shutting down the application. + + Args: + app: Flask application to clean up. + """ + # Destroy auth client + auth_client = app.config.get("AUTH_CLIENT") + if auth_client is not None: + try: + auth_client.destroy() + except Exception: + pass + + # Shutdown auth library + config = app.config.get("AUTH_SERVER_CONFIG") + if config and not config.auth_disabled: + try: + gopher_shutdown_auth_library() + except Exception: + pass diff --git a/examples/auth/py_auth_mcp_server/config.py b/examples/auth/py_auth_mcp_server/config.py new file mode 100644 index 00000000..adc949cb --- /dev/null +++ b/examples/auth/py_auth_mcp_server/config.py @@ -0,0 +1,237 @@ +"""Configuration loader for Auth MCP Server. + +Mirrors AuthServerConfig from the C++ example: +/gopher-orch/examples/auth/auth_server_config.h +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +@dataclass +class AuthServerConfig: + """Server configuration for the Auth MCP Server.""" + + # Server settings + host: str = "0.0.0.0" + port: int = 3001 + server_url: str = "" + + # OAuth/IDP settings + auth_server_url: str = "" + jwks_uri: str = "" + issuer: str = "" + client_id: str = "" + client_secret: str = "" + token_endpoint: str = "" + + # Direct OAuth endpoint URLs + oauth_authorize_url: str = "" + oauth_token_url: str = "" + + # Scopes + allowed_scopes: str = "openid profile email mcp:read mcp:admin" + + # Cache settings + jwks_cache_duration: int = 3600 + jwks_auto_refresh: bool = True + request_timeout: int = 30 + + # Auth bypass mode + auth_disabled: bool = False + + def __post_init__(self) -> None: + """Set derived defaults after initialization.""" + if not self.server_url: + self.server_url = f"http://localhost:{self.port}" + + +def parse_config_file(content: str) -> dict[str, str]: + """Parse a configuration file in key=value format. + + Args: + content: Raw file content. + + Returns: + Parsed key-value map. + """ + result: dict[str, str] = {} + + for line in content.split("\n"): + trimmed = line.strip() + + # Skip empty lines and comments + if not trimmed or trimmed.startswith("#"): + continue + + eq_index = trimmed.find("=") + if eq_index == -1: + continue + + key = trimmed[:eq_index].strip() + value = trimmed[eq_index + 1 :].strip() + + if key: + result[key] = value + + return result + + +def load_config_from_file(config_path: str | Path) -> AuthServerConfig: + """Load and parse configuration from a file. + + Args: + config_path: Path to the configuration file. + + Returns: + Parsed AuthServerConfig. + + Raises: + FileNotFoundError: If the config file does not exist. + """ + path = Path(config_path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + content = path.read_text(encoding="utf-8") + config_map = parse_config_file(content) + + return build_config(config_map) + + +def load_config_from_default_location(base_path: str | Path) -> AuthServerConfig: + """Load configuration from the default location relative to a base path. + + Args: + base_path: Base path (typically __file__ directory or executable path). + + Returns: + Parsed AuthServerConfig. + """ + config_path = Path(base_path) / "server.config" + return load_config_from_file(config_path) + + +def build_config(config_map: dict[str, str]) -> AuthServerConfig: + """Build AuthServerConfig from a parsed key-value map. + + Args: + config_map: Parsed configuration map. + + Returns: + AuthServerConfig object. + + Raises: + ValueError: If required fields are missing when auth is enabled. + """ + port = int(config_map.get("port", "3001")) + + config = AuthServerConfig( + # Server settings + host=config_map.get("host", "0.0.0.0"), + port=port, + server_url=config_map.get("server_url", f"http://localhost:{port}"), + # OAuth/IDP settings + auth_server_url=config_map.get("auth_server_url", ""), + jwks_uri=config_map.get("jwks_uri", ""), + issuer=config_map.get("issuer", ""), + client_id=config_map.get("client_id", ""), + client_secret=config_map.get("client_secret", ""), + token_endpoint=config_map.get("token_endpoint", ""), + # Direct OAuth endpoint URLs + oauth_authorize_url=config_map.get("oauth_authorize_url", ""), + oauth_token_url=config_map.get("oauth_token_url", ""), + # Scopes + allowed_scopes=config_map.get( + "allowed_scopes", "openid profile email mcp:read mcp:admin" + ), + # Cache settings + jwks_cache_duration=int(config_map.get("jwks_cache_duration", "3600")), + jwks_auto_refresh=config_map.get("jwks_auto_refresh", "true") != "false", + request_timeout=int(config_map.get("request_timeout", "30")), + # Auth bypass mode + auth_disabled=config_map.get("auth_disabled", "false") == "true", + ) + + # Derive endpoints from auth_server_url if not explicitly set + if config.auth_server_url: + if not config.jwks_uri: + config.jwks_uri = f"{config.auth_server_url}/protocol/openid-connect/certs" + if not config.token_endpoint: + config.token_endpoint = ( + f"{config.auth_server_url}/protocol/openid-connect/token" + ) + if not config.issuer: + config.issuer = config.auth_server_url + if not config.oauth_authorize_url: + config.oauth_authorize_url = ( + f"{config.auth_server_url}/protocol/openid-connect/auth" + ) + if not config.oauth_token_url: + config.oauth_token_url = ( + f"{config.auth_server_url}/protocol/openid-connect/token" + ) + + # Validate required fields when auth is enabled + if not config.auth_disabled: + validate_required_fields(config) + + return config + + +def validate_required_fields(config: AuthServerConfig) -> None: + """Validate that required configuration fields are present. + + Args: + config: Configuration to validate. + + Raises: + ValueError: If required fields are missing. + """ + errors: list[str] = [] + + if not config.client_id: + errors.append("client_id is required when auth is enabled") + if not config.client_secret: + errors.append("client_secret is required when auth is enabled") + if not config.jwks_uri and not config.auth_server_url: + errors.append("jwks_uri or auth_server_url is required when auth is enabled") + + if errors: + raise ValueError( + "Configuration validation failed:\n - " + "\n - ".join(errors) + ) + + +def create_default_config(**overrides: Any) -> AuthServerConfig: + """Create a default configuration for testing or development. + + Args: + **overrides: Optional overrides for default values. + + Returns: + AuthServerConfig with defaults. + """ + defaults = { + "host": "0.0.0.0", + "port": 3001, + "server_url": "http://localhost:3001", + "auth_server_url": "", + "jwks_uri": "", + "issuer": "", + "client_id": "", + "client_secret": "", + "token_endpoint": "", + "oauth_authorize_url": "", + "oauth_token_url": "", + "allowed_scopes": "openid profile email mcp:read mcp:admin", + "jwks_cache_duration": 3600, + "jwks_auto_refresh": True, + "request_timeout": 30, + "auth_disabled": True, + } + defaults.update(overrides) + return AuthServerConfig(**defaults) diff --git a/examples/auth/py_auth_mcp_server/middleware/__init__.py b/examples/auth/py_auth_mcp_server/middleware/__init__.py new file mode 100644 index 00000000..da0e0206 --- /dev/null +++ b/examples/auth/py_auth_mcp_server/middleware/__init__.py @@ -0,0 +1,8 @@ +"""OAuth middleware for request authentication.""" + +from .oauth_auth import OAuthAuthMiddleware, require_auth + +__all__ = [ + "OAuthAuthMiddleware", + "require_auth", +] diff --git a/examples/auth/py_auth_mcp_server/middleware/oauth_auth.py b/examples/auth/py_auth_mcp_server/middleware/oauth_auth.py new file mode 100644 index 00000000..3d2040ae --- /dev/null +++ b/examples/auth/py_auth_mcp_server/middleware/oauth_auth.py @@ -0,0 +1,295 @@ +"""OAuth Authentication Middleware. + +Flask middleware for JWT token validation using gopher-auth FFI. +Mirrors OAuthAuthFilter from the C++ example. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from flask import Response, g, jsonify, request + +from gopher_mcp_python.ffi.auth import ( + GopherAuthClient, + GopherAuthContext, + GopherValidationOptions, + gopher_create_empty_auth_context, + gopher_generate_www_authenticate_header_v2, +) + +from ..config import AuthServerConfig + + +class OAuthAuthMiddleware: + """OAuth authentication middleware. + + Validates JWT tokens on protected endpoints and attaches + the auth context to Flask's g object. + """ + + def __init__( + self, auth_client: GopherAuthClient | None, config: AuthServerConfig + ) -> None: + """Create new OAuth middleware. + + Args: + auth_client: GopherAuthClient instance for token validation + (None if auth disabled). + config: Server configuration. + """ + self.auth_client = auth_client + self.config = config + self._current_auth_context: GopherAuthContext = gopher_create_empty_auth_context() + + def before_request(self) -> Response | None: + """Flask before_request handler. + + Returns: + Response if auth fails, None to continue processing. + """ + # Handle CORS preflight + if request.method == "OPTIONS": + return self._send_cors_preflight_response() + + path = request.path + + # Check if path requires authentication + if not self.requires_auth(path): + return None + + # Extract bearer token + token = self.extract_token() + if not token: + return self._send_unauthorized("invalid_request", "Missing bearer token") + + # Validate token + valid, error_message = self._validate_token(token) + if not valid: + return self._send_unauthorized( + "invalid_token", error_message or "Token validation failed" + ) + + # Attach auth context to Flask g object + g.auth_context = self._current_auth_context + return None + + def extract_token(self) -> str | None: + """Extract bearer token from request. + + Checks Authorization header first, then query parameter. + + Returns: + Token string or None if not found. + """ + # Try Authorization header first + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] + + # Try query parameter + query_token = request.args.get("access_token", "") + if query_token: + return query_token + + return None + + def requires_auth(self, path: str) -> bool: + """Check if path requires authentication. + + Args: + path: Request path. + + Returns: + True if authentication is required. + """ + # Auth is globally disabled + if self.config.auth_disabled: + return False + + # No auth client available + if self.auth_client is None: + return False + + # Public paths (no auth required) + if path.startswith("/.well-known/"): + return False + if path.startswith("/oauth/"): + return False + if path == "/authorize": + return False + if path == "/health": + return False + if path == "/favicon.ico": + return False + + # Protected paths (auth required) + if path == "/mcp" or path.startswith("/mcp/"): + return True + if path == "/rpc" or path.startswith("/rpc/"): + return True + if path == "/events" or path.startswith("/events/"): + return True + if path == "/sse" or path.startswith("/sse/"): + return True + + # Default: require auth for unknown paths + return True + + def _validate_token(self, token: str) -> tuple[bool, str | None]: + """Validate JWT token using gopher-auth. + + Args: + token: JWT token string. + + Returns: + Tuple of (valid, error_message). + """ + if self.auth_client is None: + return False, "Auth client not initialized" + + with GopherValidationOptions() as options: + options.set_clock_skew(30) + + result = self.auth_client.validate_token(token, options) + + if not result.valid: + return False, result.error_message + + # Extract payload to populate auth context + try: + payload = self.auth_client.extract_payload(token) + + self._current_auth_context = GopherAuthContext( + user_id=payload.subject, + scopes=payload.scopes, + audience=payload.audience or "", + token_expiry=payload.expiration or 0, + authenticated=True, + ) + except Exception: + # Payload extraction failed, but token is valid + self._current_auth_context = GopherAuthContext( + user_id="", + scopes="", + audience="", + token_expiry=0, + authenticated=True, + ) + + return True, None + + def _send_unauthorized(self, error: str, description: str) -> Response: + """Send 401 Unauthorized response with WWW-Authenticate header. + + Args: + error: OAuth error code. + description: Human-readable error description. + + Returns: + Flask Response object. + """ + try: + www_authenticate = gopher_generate_www_authenticate_header_v2( + realm=self.config.server_url, + resource_metadata_url=( + f"{self.config.server_url}/.well-known/oauth-protected-resource" + ), + scopes=self.config.allowed_scopes.split(), + error=error, + error_description=description, + ) + except Exception: + # Fallback to basic Bearer header + www_authenticate = ( + f'Bearer realm="{self.config.server_url}", ' + f'error="{error}", ' + f'error_description="{description}"' + ) + + response = jsonify({"error": error, "error_description": description}) + response.status_code = 401 + response.headers["WWW-Authenticate"] = www_authenticate + response.headers["Content-Type"] = "application/json" + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Expose-Headers"] = "WWW-Authenticate" + return response + + def _send_cors_preflight_response(self) -> Response: + """Send CORS preflight response. + + Returns: + Flask Response object. + """ + response = Response("", status=204) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, " + "X-Requested-With, Origin, Cache-Control, Pragma, Mcp-Session-Id, " + "Mcp-Protocol-Version" + ) + response.headers["Access-Control-Max-Age"] = "86400" + response.headers["Access-Control-Expose-Headers"] = ( + "WWW-Authenticate, Content-Length, Content-Type" + ) + return response + + def get_auth_context(self) -> GopherAuthContext: + """Get the current authentication context.""" + return self._current_auth_context + + def is_auth_disabled(self) -> bool: + """Check if authentication is disabled.""" + return self.config.auth_disabled + + def has_scope(self, scope: str) -> bool: + """Check if a scope is present in the current auth context. + + Args: + scope: Scope to check. + + Returns: + True if scope is present. + """ + if not self._current_auth_context.authenticated: + return False + return scope in self._current_auth_context.scopes.split() + + +def require_auth(f: Callable[..., Any]) -> Callable[..., Any]: + """Flask decorator to require authentication on a route. + + Usage: + @app.route('/protected') + @require_auth + def protected(): + auth_context = g.auth_context + return f"Hello, {auth_context.user_id}" + + Args: + f: The route function to wrap. + + Returns: + Wrapped function that checks for auth context. + """ + + @wraps(f) + def decorated_function(*args: Any, **kwargs: Any) -> Any: + auth_context = getattr(g, "auth_context", None) + if auth_context is None or not auth_context.authenticated: + response = jsonify( + { + "error": "unauthorized", + "error_description": "Authentication required", + } + ) + response.status_code = 401 + return response + return f(*args, **kwargs) + + return decorated_function diff --git a/examples/auth/py_auth_mcp_server/routes/__init__.py b/examples/auth/py_auth_mcp_server/routes/__init__.py new file mode 100644 index 00000000..4e695afd --- /dev/null +++ b/examples/auth/py_auth_mcp_server/routes/__init__.py @@ -0,0 +1,29 @@ +"""Route handlers for the MCP server.""" + +from .health import register_health_routes +from .mcp_handler import ( + JsonRpcError, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + McpHandler, + ToolContentItem, + ToolResult, + ToolSpec, + register_mcp_routes, +) +from .oauth_endpoints import register_oauth_routes + +__all__ = [ + "register_health_routes", + "register_oauth_routes", + "register_mcp_routes", + "McpHandler", + "JsonRpcError", + "JsonRpcErrorCode", + "JsonRpcRequest", + "JsonRpcResponse", + "ToolSpec", + "ToolContentItem", + "ToolResult", +] diff --git a/examples/auth/py_auth_mcp_server/routes/health.py b/examples/auth/py_auth_mcp_server/routes/health.py new file mode 100644 index 00000000..2180987a --- /dev/null +++ b/examples/auth/py_auth_mcp_server/routes/health.py @@ -0,0 +1,75 @@ +"""Health check endpoint. + +Provides a simple health check endpoint for monitoring +and load balancer health checks. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from flask import Blueprint, Flask, jsonify + +# Module-level start time +_start_time: float = time.time() + + +@dataclass +class HealthResponse: + """Health check response structure.""" + + status: str # 'ok' or 'error' + timestamp: str + version: str | None = None + uptime: int | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: dict[str, Any] = { + "status": self.status, + "timestamp": self.timestamp, + } + if self.version is not None: + result["version"] = self.version + if self.uptime is not None: + result["uptime"] = self.uptime + return result + + +def create_health_blueprint(version: str | None = None) -> Blueprint: + """Create health check blueprint. + + Args: + version: Optional version string to include in response. + + Returns: + Flask Blueprint with health endpoint. + """ + bp = Blueprint("health", __name__) + + @bp.route("/health", methods=["GET"]) + def health() -> Any: + """Health check endpoint.""" + response = HealthResponse( + status="ok", + timestamp=datetime.now(timezone.utc).isoformat(), + version=version, + uptime=int(time.time() - _start_time), + ) + return jsonify(response.to_dict()), 200 + + return bp + + +def register_health_routes(app: Flask, version: str | None = None) -> None: + """Register the health check endpoint. + + Args: + app: Flask application instance. + version: Optional version string to include in response. + """ + bp = create_health_blueprint(version) + app.register_blueprint(bp) diff --git a/examples/auth/py_auth_mcp_server/routes/mcp_handler.py b/examples/auth/py_auth_mcp_server/routes/mcp_handler.py new file mode 100644 index 00000000..841bd34d --- /dev/null +++ b/examples/auth/py_auth_mcp_server/routes/mcp_handler.py @@ -0,0 +1,416 @@ +"""MCP Handler - JSON-RPC 2.0 Implementation. + +Implements the Model Context Protocol (MCP) JSON-RPC handler +for tool registration and invocation. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, TypedDict + +from flask import Blueprint, Flask, Response, jsonify, request + + +class JsonRpcErrorCode: + """Standard JSON-RPC error codes.""" + + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + + +class JsonRpcError(Exception): + """JSON-RPC 2.0 Error structure. + + Inherits from Exception to allow raising as an exception. + """ + + def __init__(self, code: int, message: str, data: Any = None) -> None: + """Initialize JSON-RPC error. + + Args: + code: JSON-RPC error code. + message: Error message. + data: Optional additional error data. + """ + super().__init__(message) + self.code = code + self.message = message + self.data = data + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + result: dict[str, Any] = {"code": self.code, "message": self.message} + if self.data is not None: + result["data"] = self.data + return result + + +@dataclass +class JsonRpcRequest: + """JSON-RPC 2.0 Request structure.""" + + jsonrpc: str + method: str + id: str | int | None = None + params: dict[str, Any] | None = None + + +@dataclass +class JsonRpcResponse: + """JSON-RPC 2.0 Response structure.""" + + jsonrpc: str + id: str | int | None + result: Any = None + error: JsonRpcError | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + response: dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + response["error"] = self.error.to_dict() + else: + response["result"] = self.result + return response + + +class ToolInputProperty(TypedDict, total=False): + """Tool input property schema.""" + + type: str + description: str + enum: list[str] + + +class ToolInputSchema(TypedDict, total=False): + """Tool input schema.""" + + type: str + properties: dict[str, ToolInputProperty] + required: list[str] + + +@dataclass +class ToolSpec: + """Tool specification for MCP.""" + + name: str + description: str + input_schema: ToolInputSchema + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "name": self.name, + "description": self.description, + "inputSchema": self.input_schema, + } + + +@dataclass +class ToolContentItem: + """Tool result content item.""" + + type: str # 'text', 'image', or 'resource' + text: str | None = None + data: str | None = None + mime_type: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + result: dict[str, Any] = {"type": self.type} + if self.text is not None: + result["text"] = self.text + if self.data is not None: + result["data"] = self.data + if self.mime_type is not None: + result["mimeType"] = self.mime_type + return result + + +@dataclass +class ToolResult: + """Tool execution result.""" + + content: list[ToolContentItem] = field(default_factory=list) + is_error: bool = False + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + result: dict[str, Any] = {"content": [c.to_dict() for c in self.content]} + if self.is_error: + result["isError"] = self.is_error + return result + + +# Type alias for tool handler function +ToolHandler = Callable[[dict[str, Any]], ToolResult] + + +@dataclass +class RegisteredTool: + """Registered tool with spec and handler.""" + + spec: ToolSpec + handler: ToolHandler + + +def _set_cors_headers(response: Response) -> Response: + """Set common CORS headers on response.""" + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, " + "X-Requested-With, Origin, Cache-Control, Pragma, Mcp-Session-Id, " + "Mcp-Protocol-Version" + ) + response.headers["Access-Control-Expose-Headers"] = ( + "WWW-Authenticate, Content-Length, Content-Type" + ) + response.headers["Access-Control-Max-Age"] = "86400" + return response + + +class McpHandler: + """MCP Handler class. + + Manages tool registration and JSON-RPC request handling. + """ + + def __init__(self) -> None: + """Initialize the MCP handler.""" + self._tools: dict[str, RegisteredTool] = {} + self._server_info = { + "name": "py-auth-mcp-server", + "version": "1.0.0", + } + + def register_tool( + self, + name: str, + description: str, + input_schema: ToolInputSchema, + handler: ToolHandler, + ) -> None: + """Register a tool with the handler. + + Args: + name: Unique tool name. + description: Tool description. + input_schema: JSON schema for tool input. + handler: Function to execute when tool is called. + """ + spec = ToolSpec(name=name, description=description, input_schema=input_schema) + self._tools[name] = RegisteredTool(spec=spec, handler=handler) + + def tool( + self, name: str, description: str, input_schema: ToolInputSchema + ) -> Callable[[ToolHandler], ToolHandler]: + """Decorator to register a tool. + + Args: + name: Unique tool name. + description: Tool description. + input_schema: JSON schema for tool input. + + Returns: + Decorator function. + """ + + def decorator(handler: ToolHandler) -> ToolHandler: + self.register_tool(name, description, input_schema, handler) + return handler + + return decorator + + def get_tools(self) -> list[ToolSpec]: + """Get list of registered tools.""" + return [t.spec for t in self._tools.values()] + + def handle_request(self, body: Any) -> JsonRpcResponse: + """Handle a JSON-RPC request. + + Args: + body: Request body. + + Returns: + JSON-RPC response. + """ + # Parse and validate request + parse_result = self._parse_request(body) + if isinstance(parse_result, JsonRpcError): + return JsonRpcResponse(jsonrpc="2.0", id=None, error=parse_result) + + rpc_request = parse_result + request_id = rpc_request.id + + try: + result = self._dispatch_method(rpc_request.method, rpc_request.params or {}) + return JsonRpcResponse(jsonrpc="2.0", id=request_id, result=result) + except JsonRpcError as e: + return JsonRpcResponse(jsonrpc="2.0", id=request_id, error=e) + except Exception as e: + error = JsonRpcError(code=JsonRpcErrorCode.INTERNAL_ERROR, message=str(e)) + return JsonRpcResponse(jsonrpc="2.0", id=request_id, error=error) + + def _parse_request(self, body: Any) -> JsonRpcRequest | JsonRpcError: + """Parse and validate a JSON-RPC request.""" + if not body or not isinstance(body, dict): + return JsonRpcError( + code=JsonRpcErrorCode.INVALID_REQUEST, + message="Invalid request: expected object", + ) + + if body.get("jsonrpc") != "2.0": + return JsonRpcError( + code=JsonRpcErrorCode.INVALID_REQUEST, + message='Invalid request: jsonrpc must be "2.0"', + ) + + method = body.get("method") + if not isinstance(method, str): + return JsonRpcError( + code=JsonRpcErrorCode.INVALID_REQUEST, + message="Invalid request: method must be a string", + ) + + params = body.get("params") + if params is not None and not isinstance(params, dict): + return JsonRpcError( + code=JsonRpcErrorCode.INVALID_PARAMS, + message="Invalid params: must be an object", + ) + + return JsonRpcRequest( + jsonrpc="2.0", + id=body.get("id"), + method=method, + params=params, + ) + + def _dispatch_method(self, method: str, params: dict[str, Any]) -> Any: + """Dispatch a method call to the appropriate handler.""" + if method == "initialize": + return self._handle_initialize(params) + elif method == "tools/list": + return self._handle_tools_list() + elif method == "tools/call": + return self._handle_tools_call(params) + elif method == "ping": + return {} + else: + raise JsonRpcError( + code=JsonRpcErrorCode.METHOD_NOT_FOUND, + message=f"Method not found: {method}", + ) + + def _handle_initialize(self, params: dict[str, Any]) -> dict[str, Any]: + """Handle initialize method.""" + return { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "serverInfo": self._server_info, + } + + def _handle_tools_list(self) -> dict[str, Any]: + """Handle tools/list method.""" + return {"tools": [t.to_dict() for t in self.get_tools()]} + + def _handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]: + """Handle tools/call method.""" + name = params.get("name") + args = params.get("arguments", {}) + + if not isinstance(name, str): + raise JsonRpcError( + code=JsonRpcErrorCode.INVALID_PARAMS, + message="Invalid params: name must be a string", + ) + + tool = self._tools.get(name) + if not tool: + raise JsonRpcError( + code=JsonRpcErrorCode.METHOD_NOT_FOUND, + message=f"Tool not found: {name}", + ) + + result = tool.handler(args) + return result.to_dict() + + +def create_mcp_blueprint(handler: McpHandler) -> Blueprint: + """Create MCP handler blueprint. + + Args: + handler: McpHandler instance. + + Returns: + Flask Blueprint with MCP endpoints. + """ + bp = Blueprint("mcp", __name__) + + @bp.route("/mcp", methods=["OPTIONS"]) + def cors_mcp() -> Response: + response = Response("", status=204) + response.headers["Content-Length"] = "0" + return _set_cors_headers(response) + + @bp.route("/rpc", methods=["OPTIONS"]) + def cors_rpc() -> Response: + response = Response("", status=204) + response.headers["Content-Length"] = "0" + return _set_cors_headers(response) + + @bp.route("/mcp", methods=["POST"]) + def handle_mcp() -> tuple[Response, int]: + body = request.get_json(silent=True) + if body is None: + response = JsonRpcResponse( + jsonrpc="2.0", + id=None, + error=JsonRpcError( + code=JsonRpcErrorCode.PARSE_ERROR, + message="Parse error: invalid JSON", + ), + ) + else: + response = handler.handle_request(body) + json_response = jsonify(response.to_dict()) + return _set_cors_headers(json_response), 200 + + @bp.route("/rpc", methods=["POST"]) + def handle_rpc() -> tuple[Response, int]: + body = request.get_json(silent=True) + if body is None: + response = JsonRpcResponse( + jsonrpc="2.0", + id=None, + error=JsonRpcError( + code=JsonRpcErrorCode.PARSE_ERROR, + message="Parse error: invalid JSON", + ), + ) + else: + response = handler.handle_request(body) + json_response = jsonify(response.to_dict()) + return _set_cors_headers(json_response), 200 + + return bp + + +def register_mcp_routes(app: Flask, handler: McpHandler) -> None: + """Register MCP handler with Flask app. + + Args: + app: Flask application. + handler: McpHandler instance. + """ + bp = create_mcp_blueprint(handler) + app.register_blueprint(bp) diff --git a/examples/auth/py_auth_mcp_server/routes/oauth_endpoints.py b/examples/auth/py_auth_mcp_server/routes/oauth_endpoints.py new file mode 100644 index 00000000..49cc78a5 --- /dev/null +++ b/examples/auth/py_auth_mcp_server/routes/oauth_endpoints.py @@ -0,0 +1,283 @@ +"""OAuth Discovery Endpoints. + +Implements OAuth 2.0 discovery endpoints: +- /.well-known/oauth-protected-resource (RFC 9728) +- /.well-known/oauth-authorization-server (RFC 8414) +- /.well-known/openid-configuration +- /oauth/authorize (redirect to IdP) +- /oauth/register (RFC 7591 dynamic registration) +""" + +from __future__ import annotations + +import math +import time +from typing import Any +from urllib.parse import urlencode, urlparse, urlunparse + +from flask import Blueprint, Flask, Response, jsonify, redirect, request + +from ..config import AuthServerConfig + + +def _set_cors_headers(response: Response) -> Response: + """Set common CORS headers on response. + + Args: + response: Flask response object. + + Returns: + Response with CORS headers set. + """ + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = ( + "Authorization, Content-Type, Accept, Origin, X-Requested-With" + ) + response.headers["Access-Control-Expose-Headers"] = ( + "WWW-Authenticate, Content-Length" + ) + response.headers["Access-Control-Max-Age"] = "86400" + return response + + +def _cors_preflight() -> Response: + """Create CORS preflight response.""" + response = Response("", status=204) + response.headers["Content-Length"] = "0" + return _set_cors_headers(response) + + +def _build_protected_resource_metadata(config: AuthServerConfig) -> dict[str, Any]: + """Build protected resource metadata (RFC 9728). + + Args: + config: Server configuration. + + Returns: + Protected resource metadata dictionary. + """ + scopes = [s for s in config.allowed_scopes.split() if s] + return { + "resource": f"{config.server_url}/mcp", + "authorization_servers": [config.server_url], + "scopes_supported": scopes, + "bearer_methods_supported": ["header", "query"], + "resource_documentation": f"{config.server_url}/docs", + } + + +def create_oauth_blueprint(config: AuthServerConfig) -> Blueprint: + """Create OAuth discovery blueprint. + + Args: + config: Server configuration. + + Returns: + Flask Blueprint with OAuth endpoints. + """ + bp = Blueprint("oauth", __name__) + + # OPTIONS handlers for CORS preflight + @bp.route("/.well-known/oauth-protected-resource", methods=["OPTIONS"]) + def cors_protected_resource() -> Response: + return _cors_preflight() + + @bp.route("/.well-known/oauth-protected-resource/mcp", methods=["OPTIONS"]) + def cors_protected_resource_mcp() -> Response: + return _cors_preflight() + + @bp.route("/.well-known/oauth-authorization-server", methods=["OPTIONS"]) + def cors_authorization_server() -> Response: + return _cors_preflight() + + @bp.route("/.well-known/openid-configuration", methods=["OPTIONS"]) + def cors_openid_configuration() -> Response: + return _cors_preflight() + + @bp.route("/oauth/authorize", methods=["OPTIONS"]) + def cors_authorize() -> Response: + return _cors_preflight() + + @bp.route("/oauth/register", methods=["OPTIONS"]) + def cors_register() -> Response: + return _cors_preflight() + + # RFC 9728 - Protected Resource Metadata (root) + @bp.route("/.well-known/oauth-protected-resource", methods=["GET"]) + def protected_resource() -> tuple[Response, int]: + metadata = _build_protected_resource_metadata(config) + response = jsonify(metadata) + return _set_cors_headers(response), 200 + + # RFC 9728: Resource-specific discovery URL for /mcp endpoint + @bp.route("/.well-known/oauth-protected-resource/mcp", methods=["GET"]) + def protected_resource_mcp() -> tuple[Response, int]: + metadata = _build_protected_resource_metadata(config) + response = jsonify(metadata) + return _set_cors_headers(response), 200 + + # RFC 8414 - OAuth Authorization Server Metadata + @bp.route("/.well-known/oauth-authorization-server", methods=["GET"]) + def authorization_server() -> tuple[Response, int]: + auth_endpoint = ( + config.oauth_authorize_url + or f"{config.auth_server_url}/protocol/openid-connect/auth" + ) + token_endpoint = ( + config.oauth_token_url + or f"{config.auth_server_url}/protocol/openid-connect/token" + ) + + scopes = [s for s in config.allowed_scopes.split() if s] + + metadata: dict[str, Any] = { + "issuer": config.issuer or config.server_url, + "authorization_endpoint": auth_endpoint, + "token_endpoint": token_endpoint, + "registration_endpoint": f"{config.server_url}/oauth/register", + "scopes_supported": scopes, + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + "none", + ], + "code_challenge_methods_supported": ["S256"], + } + + if config.jwks_uri: + metadata["jwks_uri"] = config.jwks_uri + + response = jsonify(metadata) + return _set_cors_headers(response), 200 + + # OpenID Connect Discovery + @bp.route("/.well-known/openid-configuration", methods=["GET"]) + def openid_configuration() -> tuple[Response, int]: + auth_endpoint = ( + config.oauth_authorize_url + or f"{config.auth_server_url}/protocol/openid-connect/auth" + ) + token_endpoint = ( + config.oauth_token_url + or f"{config.auth_server_url}/protocol/openid-connect/token" + ) + + base_scopes = ["openid", "profile", "email"] + custom_scopes = [s for s in config.allowed_scopes.split() if s] + all_scopes = list(dict.fromkeys(base_scopes + custom_scopes)) + + metadata: dict[str, Any] = { + "issuer": config.issuer or config.server_url, + "authorization_endpoint": auth_endpoint, + "token_endpoint": token_endpoint, + "scopes_supported": all_scopes, + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "client_secret_post", + ], + "code_challenge_methods_supported": ["S256"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + + if config.jwks_uri: + metadata["jwks_uri"] = config.jwks_uri + + if config.auth_server_url: + metadata["userinfo_endpoint"] = ( + f"{config.auth_server_url}/protocol/openid-connect/userinfo" + ) + + response = jsonify(metadata) + return _set_cors_headers(response), 200 + + # OAuth Authorization redirect + @bp.route("/oauth/authorize", methods=["GET"]) + def authorize() -> Response | tuple[Response, int]: + auth_endpoint = ( + config.oauth_authorize_url + or f"{config.auth_server_url}/protocol/openid-connect/auth" + ) + + try: + # Parse the base authorization endpoint + parsed = urlparse(auth_endpoint) + + # Build query parameters from request + params: dict[str, str] = {} + for key, value in request.args.items(): + if isinstance(value, str): + params[key] = value + + # Construct new URL with parameters + new_query = urlencode(params) + auth_url = urlunparse( + ( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + new_query, + parsed.fragment, + ) + ) + + response = redirect(auth_url, code=302) + return _set_cors_headers(response) + except Exception: + response = jsonify( + { + "error": "server_error", + "error_description": "Failed to construct authorization URL", + } + ) + return _set_cors_headers(response), 500 + + # POST /oauth/register - Dynamic Client Registration (RFC 7591) + @bp.route("/oauth/register", methods=["POST"]) + def register() -> tuple[Response, int]: + body = request.get_json(silent=True) or {} + + # Extract redirect_uris from request + redirect_uris: list[str] = [] + if isinstance(body.get("redirect_uris"), list): + redirect_uris = [ + uri for uri in body["redirect_uris"] if isinstance(uri, str) + ] + + # Return pre-configured credentials (stateless mode) + registration: dict[str, Any] = { + "client_id": config.client_id, + "client_id_issued_at": math.floor(time.time()), + "client_secret_expires_at": 0, # Never expires + "redirect_uris": redirect_uris, + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": ( + "client_secret_post" if config.client_secret else "none" + ), + } + + if config.client_secret: + registration["client_secret"] = config.client_secret + + response = jsonify(registration) + return _set_cors_headers(response), 201 + + return bp + + +def register_oauth_routes(app: Flask, config: AuthServerConfig) -> None: + """Register OAuth discovery endpoints. + + Args: + app: Flask application instance. + config: Server configuration. + """ + bp = create_oauth_blueprint(config) + app.register_blueprint(bp) diff --git a/examples/auth/py_auth_mcp_server/tools/__init__.py b/examples/auth/py_auth_mcp_server/tools/__init__.py new file mode 100644 index 00000000..9528e03c --- /dev/null +++ b/examples/auth/py_auth_mcp_server/tools/__init__.py @@ -0,0 +1,23 @@ +"""MCP tools for the server.""" + +from .weather_tools import ( + access_denied, + get_condition_for_city, + get_simulated_alerts, + get_simulated_forecast, + get_simulated_weather, + get_temp_for_city, + has_scope, + register_weather_tools, +) + +__all__ = [ + "register_weather_tools", + "has_scope", + "access_denied", + "get_simulated_weather", + "get_simulated_forecast", + "get_simulated_alerts", + "get_condition_for_city", + "get_temp_for_city", +] diff --git a/examples/auth/py_auth_mcp_server/tools/weather_tools.py b/examples/auth/py_auth_mcp_server/tools/weather_tools.py new file mode 100644 index 00000000..1f89663c --- /dev/null +++ b/examples/auth/py_auth_mcp_server/tools/weather_tools.py @@ -0,0 +1,290 @@ +"""Weather Tools. + +Example MCP tools demonstrating OAuth scope-based access control. +Mirrors the weather tools from the C++ auth example. +""" + +from __future__ import annotations + +import json +from typing import Any + +from ..middleware.oauth_auth import OAuthAuthMiddleware +from ..routes.mcp_handler import McpHandler, ToolContentItem, ToolResult + + +def has_scope(scopes: str, required: str) -> bool: + """Check if a scope is present in a space-separated scope string. + + Args: + scopes: Space-separated scope string. + required: Required scope to check for. + + Returns: + True if the required scope is present. + """ + if not scopes or not required: + return False + return required in scopes.split() + + +# Weather conditions for simulation +CONDITIONS = [ + "Sunny", + "Cloudy", + "Rainy", + "Partly Cloudy", + "Windy", + "Stormy", +] + + +def _hash_string(s: str) -> int: + """Get a deterministic hash for a string.""" + return sum(ord(c) for c in s) + + +def get_condition_for_city(city: str, offset: int = 0) -> str: + """Get a deterministic but varying condition based on city name. + + Args: + city: City name. + offset: Day offset for variation. + + Returns: + Weather condition string. + """ + h = _hash_string(city) + return CONDITIONS[(h + offset) % len(CONDITIONS)] + + +def get_temp_for_city(city: str, offset: int = 0) -> int: + """Get a deterministic but varying temperature based on city name. + + Args: + city: City name. + offset: Day offset for variation. + + Returns: + Temperature in Celsius. + """ + h = _hash_string(city) + # Temperature between 10-35 Celsius + return 10 + ((h + offset * 7) % 26) + + +def get_simulated_weather(city: str) -> dict[str, Any]: + """Get simulated current weather for a city. + + Args: + city: City name. + + Returns: + Weather data dictionary. + """ + h = _hash_string(city) + + return { + "city": city, + "temperature": get_temp_for_city(city), + "condition": get_condition_for_city(city), + "humidity": 40 + (h % 40), # 40-80% + "windSpeed": 5 + (h % 25), # 5-30 km/h + } + + +def get_simulated_forecast(city: str) -> list[dict[str, Any]]: + """Get simulated 5-day forecast for a city. + + Args: + city: City name. + + Returns: + List of daily forecasts. + """ + days = ["Today", "Tomorrow", "Day 3", "Day 4", "Day 5"] + + return [ + { + "day": day, + "high": get_temp_for_city(city, index) + 5, + "low": get_temp_for_city(city, index) - 5, + "condition": get_condition_for_city(city, index), + } + for index, day in enumerate(days) + ] + + +def get_simulated_alerts(region: str) -> list[dict[str, Any]]: + """Get simulated weather alerts for a region. + + Args: + region: Region name. + + Returns: + List of weather alerts. + """ + h = _hash_string(region) + + # Return different alerts based on region + if h % 3 == 0: + return [ + { + "type": "Heat Warning", + "severity": "moderate", + "message": f"High temperatures expected in {region}. Stay hydrated.", + }, + ] + elif h % 3 == 1: + return [ + { + "type": "Storm Watch", + "severity": "high", + "message": ( + f"Severe thunderstorms possible in {region}. " + "Seek shelter if needed." + ), + }, + { + "type": "Wind Advisory", + "severity": "low", + "message": ( + f"Strong winds expected in {region}. Secure loose objects." + ), + }, + ] + else: + return [] # No alerts + + +def access_denied(scope: str) -> ToolResult: + """Create an access denied error result. + + Args: + scope: Required scope that was missing. + + Returns: + ToolResult with error. + """ + return ToolResult( + content=[ + ToolContentItem( + type="text", + text=json.dumps( + { + "error": "access_denied", + "message": f"Access denied. Required scope: {scope}", + } + ), + ), + ], + is_error=True, + ) + + +def register_weather_tools( + mcp: McpHandler, + auth_middleware: OAuthAuthMiddleware | None, +) -> None: + """Register weather tools with the MCP handler. + + Args: + mcp: MCP handler instance. + auth_middleware: OAuth auth middleware instance (None if auth disabled). + """ + + # get-weather - No authentication required + @mcp.tool( + name="get-weather", + description="Get current weather for a city. No authentication required.", + input_schema={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name to get weather for", + }, + }, + "required": ["city"], + }, + ) + def get_weather(args: dict[str, Any]) -> ToolResult: + city = str(args.get("city", "Unknown")) + weather = get_simulated_weather(city) + + return ToolResult( + content=[ + ToolContentItem( + type="text", + text=json.dumps(weather, indent=2), + ), + ], + ) + + # get-forecast - Requires mcp:read scope + @mcp.tool( + name="get-forecast", + description="Get 5-day weather forecast for a city. Requires mcp:read scope.", + input_schema={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name to get forecast for", + }, + }, + "required": ["city"], + }, + ) + def get_forecast(args: dict[str, Any]) -> ToolResult: + # Check scope if auth is enabled + if auth_middleware and not auth_middleware.is_auth_disabled(): + context = auth_middleware.get_auth_context() + if not has_scope(context.scopes, "mcp:read"): + return access_denied("mcp:read") + + city = str(args.get("city", "Unknown")) + forecast = get_simulated_forecast(city) + + return ToolResult( + content=[ + ToolContentItem( + type="text", + text=json.dumps({"city": city, "forecast": forecast}, indent=2), + ), + ], + ) + + # get-weather-alerts - Requires mcp:admin scope + @mcp.tool( + name="get-weather-alerts", + description="Get weather alerts for a region. Requires mcp:admin scope.", + input_schema={ + "type": "object", + "properties": { + "region": { + "type": "string", + "description": "Region name to get alerts for", + }, + }, + "required": ["region"], + }, + ) + def get_weather_alerts(args: dict[str, Any]) -> ToolResult: + # Check scope if auth is enabled + if auth_middleware and not auth_middleware.is_auth_disabled(): + context = auth_middleware.get_auth_context() + if not has_scope(context.scopes, "mcp:admin"): + return access_denied("mcp:admin") + + region = str(args.get("region", "Unknown")) + alerts = get_simulated_alerts(region) + + return ToolResult( + content=[ + ToolContentItem( + type="text", + text=json.dumps({"region": region, "alerts": alerts}, indent=2), + ), + ], + ) diff --git a/examples/auth/pyproject.toml b/examples/auth/pyproject.toml new file mode 100644 index 00000000..c8aa8270 --- /dev/null +++ b/examples/auth/pyproject.toml @@ -0,0 +1,107 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "py-auth-mcp-server" +version = "1.0.0" +description = "Python MCP server with OAuth authentication using gopher-orch auth via FFI" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} +authors = [ + {name = "Gopher Security"} +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "flask>=3.0.0", + "flask-cors>=4.0.0", + "gopher-mcp-python", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-cov>=4.1.0", + "black>=23.0.0", + "ruff>=0.1.0", + "mypy>=1.7.0", +] +# Platform-specific native library packages +darwin-arm64 = ["gopher-mcp-python-native-darwin-arm64"] +darwin-x64 = ["gopher-mcp-python-native-darwin-x64"] +linux-arm64 = ["gopher-mcp-python-native-linux-arm64"] +linux-x64 = ["gopher-mcp-python-native-linux-x64"] +win32-arm64 = ["gopher-mcp-python-native-win32-arm64"] +win32-x64 = ["gopher-mcp-python-native-win32-x64"] + +[project.scripts] +py-auth-mcp-server = "py_auth_mcp_server.__main__:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["py_auth_mcp_server*"] + +[tool.black] +line-length = 88 +target-version = ["py310", "py311", "py312"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.mypy_cache + | \.ruff_cache + | \.venv + | dist + | build +)/ +''' + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) +] + +[tool.ruff.lint.isort] +known-first-party = ["py_auth_mcp_server", "gopher_mcp_python"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +strict_optional = true + +[[tool.mypy.overrides]] +module = ["flask.*", "flask_cors.*"] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" diff --git a/examples/auth/requirements.txt b/examples/auth/requirements.txt new file mode 100644 index 00000000..5652bac8 --- /dev/null +++ b/examples/auth/requirements.txt @@ -0,0 +1,25 @@ +# Production dependencies +flask>=3.0.0 +flask-cors>=4.0.0 +gopher-mcp-python + +# Platform-specific native libraries (install one based on your platform) +# macOS Apple Silicon: +# pip install gopher-mcp-python-native-darwin-arm64 +# macOS Intel: +# pip install gopher-mcp-python-native-darwin-x64 +# Linux ARM64: +# pip install gopher-mcp-python-native-linux-arm64 +# Linux x64: +# pip install gopher-mcp-python-native-linux-x64 +# Windows ARM64: +# pip install gopher-mcp-python-native-win32-arm64 +# Windows x64: +# pip install gopher-mcp-python-native-win32-x64 + +# Development dependencies +pytest>=7.4.0 +pytest-cov>=4.1.0 +black>=23.0.0 +ruff>=0.1.0 +mypy>=1.7.0 diff --git a/examples/auth/run_example.sh b/examples/auth/run_example.sh new file mode 100755 index 00000000..7df09ddc --- /dev/null +++ b/examples/auth/run_example.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Run the Auth MCP Server example +# Usage: +# ./run_example.sh # Run with auth enabled (requires OAuth config) +# ./run_example.sh --no-auth # Run without auth (development mode) +# ./run_example.sh --help # Show help + +set -e + +# Colors for output +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +# Get the script directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${SCRIPT_DIR}" + +# Check for help flag +if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then + echo "Auth MCP Server Example" + echo "" + echo "Usage:" + echo " ./run_example.sh Run using server.config settings" + echo " ./run_example.sh --no-auth Override config to disable auth" + echo " ./run_example.sh --help Show this help" + echo "" + echo "Options:" + echo " --no-auth Disable OAuth authentication (overrides server.config)" + echo " --host HOST Bind to specific host (default: 0.0.0.0)" + echo " --port PORT Listen on specific port (default: 3001)" + echo "" + echo "Configuration:" + echo " Edit server.config to configure OAuth settings (auth_disabled=true/false)" + echo " See server.config.example for all available options" + echo "" + echo "Test endpoints:" + echo " curl http://localhost:3001/health" + echo " curl -X POST http://localhost:3001/mcp \\" + echo " -H 'Content-Type: application/json' \\" + echo " -d '{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"tools/list\",\"params\":{}}'" + exit 0 +fi + +# Check Python +if ! command -v python3 &> /dev/null; then + echo -e "${RED}Error: Python 3 not found${NC}" + exit 1 +fi + +# Check if dependencies are installed +if ! python3 -c "import flask" 2>/dev/null; then + echo -e "${YELLOW}Installing dependencies...${NC}" + pip3 install -r requirements.txt --quiet 2>/dev/null || pip3 install -r requirements.txt +fi + +# Check if gopher-mcp-python is installed +if ! python3 -c "import gopher_mcp_python" 2>/dev/null; then + echo -e "${YELLOW}Installing gopher-mcp-python...${NC}" + pip3 install gopher-mcp-python --quiet 2>/dev/null || pip3 install gopher-mcp-python +fi + +# Check if native library package is installed and install platform-specific one if not +if ! python3 -c "from gopher_mcp_python.ffi import get_library_path; get_library_path()" 2>/dev/null; then + echo -e "${YELLOW}Installing platform-specific native library...${NC}" + + # Detect platform and architecture + PLATFORM="$(uname -s | tr '[:upper:]' '[:lower:]')" + ARCH="$(uname -m)" + + case "${PLATFORM}" in + darwin) + if [ "${ARCH}" = "arm64" ]; then + pip3 install gopher-mcp-python-native-darwin-arm64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-darwin-arm64 + else + pip3 install gopher-mcp-python-native-darwin-x64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-darwin-x64 + fi + ;; + linux) + if [ "${ARCH}" = "aarch64" ] || [ "${ARCH}" = "arm64" ]; then + pip3 install gopher-mcp-python-native-linux-arm64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-linux-arm64 + else + pip3 install gopher-mcp-python-native-linux-x64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-linux-x64 + fi + ;; + msys*|mingw*|cygwin*) + if [ "${ARCH}" = "aarch64" ] || [ "${ARCH}" = "arm64" ]; then + pip3 install gopher-mcp-python-native-win32-arm64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-win32-arm64 + else + pip3 install gopher-mcp-python-native-win32-x64 --quiet 2>/dev/null || \ + pip3 install gopher-mcp-python-native-win32-x64 + fi + ;; + *) + echo -e "${RED}Unsupported platform: ${PLATFORM}${NC}" + echo "Please manually install the appropriate native library package." + exit 1 + ;; + esac +fi + +echo -e "${GREEN}Starting Auth MCP Server...${NC}" +echo -e "Configuration: ${YELLOW}server.config${NC}" +echo "" + +# Run server with config file and any additional arguments +exec python3 -m py_auth_mcp_server "${SCRIPT_DIR}/server.config" "$@" diff --git a/examples/auth/server.config b/examples/auth/server.config new file mode 100644 index 00000000..c2c814b0 --- /dev/null +++ b/examples/auth/server.config @@ -0,0 +1,33 @@ +# Auth MCP Server Configuration +# This file follows the same format as the C++ auth example + +# Server settings +host=0.0.0.0 +port=3001 +server_url=https://marni-nightcapped-nonmeditatively.ngrok-free.dev + +# OAuth/IDP settings +# Uncomment and configure for Keycloak or other OAuth provider +client_id=oauth_0a650b79c5a64c3b920ae8c2b20599d9 +client_secret=6BiU2beUi2wIBxY3MUBLyYqoWKa4t0U_kJVm9mvSOKw +auth_server_url=https://auth-test.gopher.security/realms/gopher-mcp-auth +oauth_authorize_url=https://api-test.gopher.security/oauth/authorize + +# Direct OAuth endpoint URLs (optional, derived from auth_server_url if not set) +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs +# issuer=https://keycloak.example.com/realms/mcp +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# Scopes +exchange_idps=oauth-idp-714982830194556929-google +allowed_scopes=openid profile email scope-001 + +# Cache settings +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 + +# Auth bypass mode (for development/testing) +# Set to true to disable authentication +auth_disabled=false diff --git a/examples/auth/server.config.example b/examples/auth/server.config.example new file mode 100644 index 00000000..483057cb --- /dev/null +++ b/examples/auth/server.config.example @@ -0,0 +1,75 @@ +# Auth MCP Server Configuration +# Copy this file to server.config and customize as needed + +# ============================================ +# Server Settings +# ============================================ + +# Host to bind to (default: 0.0.0.0) +host=0.0.0.0 + +# Port to listen on (default: 3001) +port=3001 + +# Server URL (used for OAuth metadata and redirects) +# This should be the externally accessible URL of this server +server_url=http://localhost:3001 + +# ============================================ +# OAuth/IDP Settings +# ============================================ + +# OAuth client ID (required when auth is enabled) +# client_id=your-client-id + +# OAuth client secret (required when auth is enabled) +# client_secret=your-client-secret + +# Base URL for the OAuth/OIDC server +# Many endpoints will be derived from this URL if not explicitly set +# auth_server_url=https://keycloak.example.com/realms/mcp + +# ============================================ +# Direct OAuth Endpoint URLs +# Optional - derived from auth_server_url if not set +# ============================================ + +# JWKS URI for fetching public keys +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs + +# Token issuer (usually same as auth_server_url) +# issuer=https://keycloak.example.com/realms/mcp + +# OAuth authorization endpoint +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth + +# OAuth token endpoint +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# ============================================ +# Scopes +# ============================================ + +# Space-separated list of allowed scopes +allowed_scopes=openid profile email mcp:read mcp:admin + +# ============================================ +# Cache Settings +# ============================================ + +# JWKS cache duration in seconds (default: 3600) +jwks_cache_duration=3600 + +# Enable automatic JWKS refresh (default: true) +jwks_auto_refresh=true + +# HTTP request timeout in seconds (default: 30) +request_timeout=30 + +# ============================================ +# Authentication Mode +# ============================================ + +# Set to true to disable authentication (for development/testing) +# When disabled, all endpoints are accessible without a token +auth_disabled=true diff --git a/examples/auth/tests/__init__.py b/examples/auth/tests/__init__.py new file mode 100644 index 00000000..4eb485ea --- /dev/null +++ b/examples/auth/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the py_auth_mcp_server.""" diff --git a/examples/auth/tests/conftest.py b/examples/auth/tests/conftest.py new file mode 100644 index 00000000..60c12a23 --- /dev/null +++ b/examples/auth/tests/conftest.py @@ -0,0 +1,170 @@ +"""Pytest fixtures for py_auth_mcp_server tests.""" + +from __future__ import annotations + +import pytest +from flask import Flask +from flask.testing import FlaskClient + +from gopher_mcp_python.ffi.auth import GopherAuthContext +from py_auth_mcp_server import AuthServerConfig, create_app, create_default_config +from py_auth_mcp_server.routes.mcp_handler import McpHandler + + +@pytest.fixture +def sample_config() -> AuthServerConfig: + """Create a sample configuration for testing. + + Returns: + AuthServerConfig with auth_disabled=True. + """ + return create_default_config( + host="127.0.0.1", + port=5000, + server_url="http://localhost:5000", + auth_disabled=True, + allowed_scopes="openid profile email mcp:read mcp:admin", + ) + + +@pytest.fixture +def app(sample_config: AuthServerConfig) -> Flask: + """Create a Flask application for testing. + + Args: + sample_config: Test configuration. + + Returns: + Configured Flask application. + """ + application = create_app(config=sample_config) + application.config["TESTING"] = True + return application + + +@pytest.fixture +def client(app: Flask) -> FlaskClient: + """Create a test client for the Flask application. + + Args: + app: Flask application. + + Returns: + Flask test client. + """ + return app.test_client() + + +@pytest.fixture +def mcp_handler() -> McpHandler: + """Create a fresh McpHandler instance for testing. + + Returns: + New McpHandler instance. + """ + return McpHandler() + + +@pytest.fixture +def mock_auth_context() -> GopherAuthContext: + """Create a mock GopherAuthContext with test data. + + Returns: + GopherAuthContext with sample values. + """ + return GopherAuthContext( + user_id="test-user-123", + scopes="openid profile email mcp:read", + audience="test-audience", + token_expiry=9999999999, + authenticated=True, + ) + + +@pytest.fixture +def mock_admin_auth_context() -> GopherAuthContext: + """Create a mock GopherAuthContext with admin scopes. + + Returns: + GopherAuthContext with mcp:admin scope. + """ + return GopherAuthContext( + user_id="admin-user-456", + scopes="openid profile email mcp:read mcp:admin", + audience="test-audience", + token_expiry=9999999999, + authenticated=True, + ) + + +@pytest.fixture +def mock_unauthenticated_context() -> GopherAuthContext: + """Create a mock unauthenticated GopherAuthContext. + + Returns: + GopherAuthContext with authenticated=False. + """ + return GopherAuthContext( + user_id="", + scopes="", + audience="", + token_expiry=0, + authenticated=False, + ) + + +@pytest.fixture +def json_rpc_initialize_request() -> dict: + """Create a JSON-RPC initialize request. + + Returns: + JSON-RPC request for initialize method. + """ + return { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + +@pytest.fixture +def json_rpc_tools_list_request() -> dict: + """Create a JSON-RPC tools/list request. + + Returns: + JSON-RPC request for tools/list method. + """ + return { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {}, + } + + +@pytest.fixture +def json_rpc_tools_call_request() -> dict: + """Create a JSON-RPC tools/call request for get-weather. + + Returns: + JSON-RPC request for tools/call method. + """ + return { + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get-weather", + "arguments": { + "city": "London", + }, + }, + } diff --git a/examples/auth/tests/test_config.py b/examples/auth/tests/test_config.py new file mode 100644 index 00000000..f7061a6e --- /dev/null +++ b/examples/auth/tests/test_config.py @@ -0,0 +1,298 @@ +"""Tests for configuration module.""" + +import tempfile + +import pytest + +from py_auth_mcp_server.config import ( + AuthServerConfig, + build_config, + create_default_config, + load_config_from_file, + parse_config_file, + validate_required_fields, +) + + +class TestParseConfigFile: + """Tests for parse_config_file function.""" + + def test_parses_key_value_pairs(self): + """Test parsing simple key=value pairs.""" + content = "host=localhost\nport=3001" + result = parse_config_file(content) + assert result == {"host": "localhost", "port": "3001"} + + def test_ignores_comments(self): + """Test that comments are ignored.""" + content = "# This is a comment\nhost=localhost\n# Another comment" + result = parse_config_file(content) + assert result == {"host": "localhost"} + + def test_ignores_empty_lines(self): + """Test that empty lines are ignored.""" + content = "host=localhost\n\n\nport=3001\n" + result = parse_config_file(content) + assert result == {"host": "localhost", "port": "3001"} + + def test_ignores_whitespace_only_lines(self): + """Test that whitespace-only lines are ignored.""" + content = "host=localhost\n \n\t\nport=3001" + result = parse_config_file(content) + assert result == {"host": "localhost", "port": "3001"} + + def test_trims_whitespace(self): + """Test that whitespace is trimmed from keys and values.""" + content = " host = localhost \n port = 3001 " + result = parse_config_file(content) + assert result == {"host": "localhost", "port": "3001"} + + def test_ignores_lines_without_equals(self): + """Test that lines without '=' are ignored.""" + content = "host=localhost\ninvalid line\nport=3001" + result = parse_config_file(content) + assert result == {"host": "localhost", "port": "3001"} + + def test_handles_values_with_equals(self): + """Test values containing '=' are preserved.""" + content = "url=http://example.com?foo=bar" + result = parse_config_file(content) + assert result == {"url": "http://example.com?foo=bar"} + + def test_empty_content(self): + """Test parsing empty content.""" + result = parse_config_file("") + assert result == {} + + def test_only_comments(self): + """Test parsing content with only comments.""" + content = "# comment 1\n# comment 2" + result = parse_config_file(content) + assert result == {} + + +class TestBuildConfig: + """Tests for build_config function.""" + + def test_default_values(self): + """Test default values are applied.""" + config = build_config({"auth_disabled": "true"}) + assert config.host == "0.0.0.0" + assert config.port == 3001 + assert config.jwks_cache_duration == 3600 + assert config.jwks_auto_refresh is True + assert config.request_timeout == 30 + + def test_custom_values(self): + """Test custom values override defaults.""" + config = build_config( + { + "host": "127.0.0.1", + "port": "8080", + "auth_disabled": "true", + } + ) + assert config.host == "127.0.0.1" + assert config.port == 8080 + + def test_port_type_conversion(self): + """Test port is converted to integer.""" + config = build_config({"port": "9000", "auth_disabled": "true"}) + assert config.port == 9000 + assert isinstance(config.port, int) + + def test_boolean_type_conversion(self): + """Test boolean values are converted correctly.""" + config = build_config( + { + "auth_disabled": "true", + "jwks_auto_refresh": "false", + } + ) + assert config.auth_disabled is True + assert config.jwks_auto_refresh is False + + def test_server_url_default(self): + """Test server_url defaults to localhost with port.""" + config = build_config({"port": "5000", "auth_disabled": "true"}) + assert config.server_url == "http://localhost:5000" + + def test_server_url_custom(self): + """Test custom server_url.""" + config = build_config( + { + "server_url": "https://api.example.com", + "auth_disabled": "true", + } + ) + assert config.server_url == "https://api.example.com" + + +class TestEndpointDerivation: + """Tests for endpoint derivation from auth_server_url.""" + + def test_derives_jwks_uri(self): + """Test jwks_uri is derived from auth_server_url.""" + config = build_config( + { + "auth_server_url": "https://auth.example.com/realms/test", + "client_id": "test-client", + "client_secret": "secret", + } + ) + assert config.jwks_uri == ( + "https://auth.example.com/realms/test/protocol/openid-connect/certs" + ) + + def test_derives_token_endpoint(self): + """Test token_endpoint is derived from auth_server_url.""" + config = build_config( + { + "auth_server_url": "https://auth.example.com/realms/test", + "client_id": "test-client", + "client_secret": "secret", + } + ) + assert config.token_endpoint == ( + "https://auth.example.com/realms/test/protocol/openid-connect/token" + ) + + def test_derives_issuer(self): + """Test issuer is derived from auth_server_url.""" + config = build_config( + { + "auth_server_url": "https://auth.example.com/realms/test", + "client_id": "test-client", + "client_secret": "secret", + } + ) + assert config.issuer == "https://auth.example.com/realms/test" + + def test_explicit_values_not_overwritten(self): + """Test explicit values are not overwritten by derivation.""" + config = build_config( + { + "auth_server_url": "https://auth.example.com/realms/test", + "jwks_uri": "https://custom.example.com/jwks", + "issuer": "https://custom-issuer.example.com", + "client_id": "test-client", + "client_secret": "secret", + } + ) + assert config.jwks_uri == "https://custom.example.com/jwks" + assert config.issuer == "https://custom-issuer.example.com" + + +class TestValidateRequiredFields: + """Tests for validate_required_fields function.""" + + def test_auth_disabled_skips_validation(self): + """Test validation is skipped when auth is disabled.""" + config = create_default_config(auth_disabled=True) + # Should not raise + validate_required_fields(config) + + def test_missing_client_id_raises(self): + """Test missing client_id raises ValueError.""" + config = create_default_config( + auth_disabled=False, + client_secret="secret", + jwks_uri="https://example.com/jwks", + ) + with pytest.raises(ValueError, match="client_id"): + validate_required_fields(config) + + def test_missing_client_secret_raises(self): + """Test missing client_secret raises ValueError.""" + config = create_default_config( + auth_disabled=False, + client_id="test-client", + jwks_uri="https://example.com/jwks", + ) + with pytest.raises(ValueError, match="client_secret"): + validate_required_fields(config) + + def test_missing_jwks_uri_and_auth_server_url_raises(self): + """Test missing jwks_uri and auth_server_url raises ValueError.""" + config = create_default_config( + auth_disabled=False, + client_id="test-client", + client_secret="secret", + ) + with pytest.raises(ValueError, match="jwks_uri"): + validate_required_fields(config) + + def test_valid_config_passes(self): + """Test valid configuration passes validation.""" + config = create_default_config( + auth_disabled=False, + client_id="test-client", + client_secret="secret", + jwks_uri="https://example.com/jwks", + ) + # Should not raise + validate_required_fields(config) + + +class TestLoadConfigFromFile: + """Tests for load_config_from_file function.""" + + def test_loads_config_from_file(self): + """Test loading configuration from file.""" + content = """ +host=127.0.0.1 +port=8080 +auth_disabled=true +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f: + f.write(content) + f.flush() + config = load_config_from_file(f.name) + + assert config.host == "127.0.0.1" + assert config.port == 8080 + assert config.auth_disabled is True + + def test_file_not_found_raises(self): + """Test FileNotFoundError is raised for missing file.""" + with pytest.raises(FileNotFoundError): + load_config_from_file("/nonexistent/path/config.txt") + + +class TestCreateDefaultConfig: + """Tests for create_default_config function.""" + + def test_creates_default_config(self): + """Test creating default configuration.""" + config = create_default_config() + assert config.host == "0.0.0.0" + assert config.port == 3001 + assert config.auth_disabled is True + + def test_accepts_overrides(self): + """Test overrides are applied.""" + config = create_default_config( + host="localhost", + port=5000, + auth_disabled=False, + ) + assert config.host == "localhost" + assert config.port == 5000 + assert config.auth_disabled is False + + +class TestAuthServerConfigDataclass: + """Tests for AuthServerConfig dataclass.""" + + def test_post_init_sets_server_url(self): + """Test __post_init__ sets server_url default.""" + config = AuthServerConfig(port=9000) + assert config.server_url == "http://localhost:9000" + + def test_explicit_server_url_preserved(self): + """Test explicit server_url is not overwritten.""" + config = AuthServerConfig( + port=9000, + server_url="https://api.example.com", + ) + assert config.server_url == "https://api.example.com" diff --git a/examples/auth/tests/test_integration.py b/examples/auth/tests/test_integration.py new file mode 100644 index 00000000..6bd52a03 --- /dev/null +++ b/examples/auth/tests/test_integration.py @@ -0,0 +1,410 @@ +"""Integration tests for the Auth MCP Server.""" + +import json + +from flask.testing import FlaskClient + + +class TestHealthEndpoint: + """Tests for GET /health endpoint.""" + + def test_returns_200(self, client: FlaskClient): + """Test health endpoint returns 200.""" + response = client.get("/health") + assert response.status_code == 200 + + def test_returns_json(self, client: FlaskClient): + """Test health endpoint returns JSON.""" + response = client.get("/health") + assert response.content_type == "application/json" + + def test_status_is_ok(self, client: FlaskClient): + """Test health response has status 'ok'.""" + response = client.get("/health") + data = response.get_json() + assert data["status"] == "ok" + + def test_has_timestamp(self, client: FlaskClient): + """Test health response has timestamp.""" + response = client.get("/health") + data = response.get_json() + assert "timestamp" in data + assert isinstance(data["timestamp"], str) + + def test_has_uptime(self, client: FlaskClient): + """Test health response has uptime.""" + response = client.get("/health") + data = response.get_json() + assert "uptime" in data + assert isinstance(data["uptime"], int) + + +class TestOAuthProtectedResource: + """Tests for /.well-known/oauth-protected-resource endpoint.""" + + def test_returns_200(self, client: FlaskClient): + """Test endpoint returns 200.""" + response = client.get("/.well-known/oauth-protected-resource") + assert response.status_code == 200 + + def test_returns_json(self, client: FlaskClient): + """Test endpoint returns JSON.""" + response = client.get("/.well-known/oauth-protected-resource") + assert response.content_type == "application/json" + + def test_has_resource(self, client: FlaskClient): + """Test response has resource field.""" + response = client.get("/.well-known/oauth-protected-resource") + data = response.get_json() + assert "resource" in data + assert "/mcp" in data["resource"] + + def test_has_authorization_servers(self, client: FlaskClient): + """Test response has authorization_servers.""" + response = client.get("/.well-known/oauth-protected-resource") + data = response.get_json() + assert "authorization_servers" in data + assert isinstance(data["authorization_servers"], list) + + def test_has_scopes_supported(self, client: FlaskClient): + """Test response has scopes_supported.""" + response = client.get("/.well-known/oauth-protected-resource") + data = response.get_json() + assert "scopes_supported" in data + assert isinstance(data["scopes_supported"], list) + + def test_has_cors_headers(self, client: FlaskClient): + """Test response has CORS headers.""" + response = client.get("/.well-known/oauth-protected-resource") + assert "Access-Control-Allow-Origin" in response.headers + + def test_mcp_specific_endpoint(self, client: FlaskClient): + """Test MCP-specific protected resource endpoint.""" + response = client.get("/.well-known/oauth-protected-resource/mcp") + assert response.status_code == 200 + data = response.get_json() + assert "resource" in data + + +class TestOpenIDConfiguration: + """Tests for /.well-known/openid-configuration endpoint.""" + + def test_returns_200(self, client: FlaskClient): + """Test endpoint returns 200.""" + response = client.get("/.well-known/openid-configuration") + assert response.status_code == 200 + + def test_has_issuer(self, client: FlaskClient): + """Test response has issuer.""" + response = client.get("/.well-known/openid-configuration") + data = response.get_json() + assert "issuer" in data + + def test_has_authorization_endpoint(self, client: FlaskClient): + """Test response has authorization_endpoint.""" + response = client.get("/.well-known/openid-configuration") + data = response.get_json() + assert "authorization_endpoint" in data + + def test_has_token_endpoint(self, client: FlaskClient): + """Test response has token_endpoint.""" + response = client.get("/.well-known/openid-configuration") + data = response.get_json() + assert "token_endpoint" in data + + def test_has_scopes_supported(self, client: FlaskClient): + """Test response has scopes_supported.""" + response = client.get("/.well-known/openid-configuration") + data = response.get_json() + assert "scopes_supported" in data + assert "openid" in data["scopes_supported"] + + +class TestOAuthAuthorizationServer: + """Tests for /.well-known/oauth-authorization-server endpoint.""" + + def test_returns_200(self, client: FlaskClient): + """Test endpoint returns 200.""" + response = client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + + def test_has_response_types_supported(self, client: FlaskClient): + """Test response has response_types_supported.""" + response = client.get("/.well-known/oauth-authorization-server") + data = response.get_json() + assert "response_types_supported" in data + assert "code" in data["response_types_supported"] + + def test_has_grant_types_supported(self, client: FlaskClient): + """Test response has grant_types_supported.""" + response = client.get("/.well-known/oauth-authorization-server") + data = response.get_json() + assert "grant_types_supported" in data + assert "authorization_code" in data["grant_types_supported"] + + +class TestMcpEndpoint: + """Tests for POST /mcp endpoint.""" + + def test_initialize(self, client: FlaskClient, json_rpc_initialize_request): + """Test initialize method.""" + response = client.post( + "/mcp", + json=json_rpc_initialize_request, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == 1 + assert "result" in data + assert data["result"]["protocolVersion"] == "2024-11-05" + + def test_tools_list(self, client: FlaskClient, json_rpc_tools_list_request): + """Test tools/list method.""" + response = client.post( + "/mcp", + json=json_rpc_tools_list_request, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert "result" in data + assert "tools" in data["result"] + # Should have weather tools registered + tool_names = [t["name"] for t in data["result"]["tools"]] + assert "get-weather" in tool_names + + def test_tools_call_get_weather( + self, client: FlaskClient, json_rpc_tools_call_request + ): + """Test tools/call for get-weather.""" + response = client.post( + "/mcp", + json=json_rpc_tools_call_request, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert "result" in data + assert "content" in data["result"] + content = json.loads(data["result"]["content"][0]["text"]) + assert content["city"] == "London" + assert "temperature" in content + assert "condition" in content + + def test_invalid_json_returns_parse_error(self, client: FlaskClient): + """Test invalid JSON returns parse error.""" + response = client.post( + "/mcp", + data="not valid json", + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert "error" in data + assert data["error"]["code"] == -32700 # PARSE_ERROR + + def test_method_not_found(self, client: FlaskClient): + """Test unknown method returns method not found error.""" + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "unknown/method", + "params": {}, + }, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert "error" in data + assert data["error"]["code"] == -32601 # METHOD_NOT_FOUND + + def test_has_cors_headers(self, client: FlaskClient, json_rpc_initialize_request): + """Test response has CORS headers.""" + response = client.post( + "/mcp", + json=json_rpc_initialize_request, + content_type="application/json", + ) + assert "Access-Control-Allow-Origin" in response.headers + + +class TestRpcEndpoint: + """Tests for POST /rpc endpoint (alternative to /mcp).""" + + def test_rpc_endpoint_works(self, client: FlaskClient, json_rpc_initialize_request): + """Test /rpc endpoint works same as /mcp.""" + response = client.post( + "/rpc", + json=json_rpc_initialize_request, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + assert "result" in data + assert data["result"]["protocolVersion"] == "2024-11-05" + + +class TestCorsPreflightHandlers: + """Tests for CORS preflight (OPTIONS) handlers.""" + + def test_mcp_options(self, client: FlaskClient): + """Test OPTIONS /mcp returns correct headers.""" + response = client.options("/mcp") + assert response.status_code == 204 + assert "Access-Control-Allow-Methods" in response.headers + + def test_rpc_options(self, client: FlaskClient): + """Test OPTIONS /rpc returns correct headers.""" + response = client.options("/rpc") + assert response.status_code == 204 + + def test_oauth_protected_resource_options(self, client: FlaskClient): + """Test OPTIONS /.well-known/oauth-protected-resource.""" + response = client.options("/.well-known/oauth-protected-resource") + assert response.status_code == 204 + + def test_oauth_register_options(self, client: FlaskClient): + """Test OPTIONS /oauth/register.""" + response = client.options("/oauth/register") + assert response.status_code == 204 + + +class TestOAuthRegisterEndpoint: + """Tests for POST /oauth/register endpoint.""" + + def test_returns_201(self, client: FlaskClient): + """Test endpoint returns 201.""" + response = client.post( + "/oauth/register", + json={"redirect_uris": ["http://localhost:8080/callback"]}, + content_type="application/json", + ) + assert response.status_code == 201 + + def test_returns_client_id(self, client: FlaskClient): + """Test response has client_id.""" + response = client.post( + "/oauth/register", + json={}, + content_type="application/json", + ) + data = response.get_json() + assert "client_id" in data + + def test_preserves_redirect_uris(self, client: FlaskClient): + """Test redirect_uris are preserved.""" + uris = ["http://localhost:8080/callback", "http://localhost:3000/auth"] + response = client.post( + "/oauth/register", + json={"redirect_uris": uris}, + content_type="application/json", + ) + data = response.get_json() + assert data["redirect_uris"] == uris + + +class TestFullFlow: + """Tests for complete MCP flow.""" + + def test_initialize_then_tools_list_then_call(self, client: FlaskClient): + """Test complete flow: initialize -> tools/list -> tools/call.""" + # Initialize + init_response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {}, + }, + content_type="application/json", + ) + assert init_response.status_code == 200 + init_data = init_response.get_json() + assert "result" in init_data + + # List tools + list_response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {}, + }, + content_type="application/json", + ) + assert list_response.status_code == 200 + list_data = list_response.get_json() + assert "tools" in list_data["result"] + assert len(list_data["result"]["tools"]) >= 3 # At least 3 weather tools + + # Call get-weather tool + call_response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get-weather", + "arguments": {"city": "New York"}, + }, + }, + content_type="application/json", + ) + assert call_response.status_code == 200 + call_data = call_response.get_json() + assert "result" in call_data + weather = json.loads(call_data["result"]["content"][0]["text"]) + assert weather["city"] == "New York" + + def test_get_forecast_without_auth(self, client: FlaskClient): + """Test get-forecast works when auth is disabled.""" + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-forecast", + "arguments": {"city": "Tokyo"}, + }, + }, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + # Auth is disabled, so should succeed + assert "result" in data + forecast = json.loads(data["result"]["content"][0]["text"]) + assert forecast["city"] == "Tokyo" + assert len(forecast["forecast"]) == 5 + + def test_get_weather_alerts_without_auth(self, client: FlaskClient): + """Test get-weather-alerts works when auth is disabled.""" + response = client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get-weather-alerts", + "arguments": {"region": "California"}, + }, + }, + content_type="application/json", + ) + assert response.status_code == 200 + data = response.get_json() + # Auth is disabled, so should succeed + assert "result" in data + alerts = json.loads(data["result"]["content"][0]["text"]) + assert "region" in alerts + assert "alerts" in alerts diff --git a/examples/auth/tests/test_mcp_handler.py b/examples/auth/tests/test_mcp_handler.py new file mode 100644 index 00000000..33a455fc --- /dev/null +++ b/examples/auth/tests/test_mcp_handler.py @@ -0,0 +1,335 @@ +"""Tests for MCP JSON-RPC handler.""" + + +from py_auth_mcp_server.routes.mcp_handler import ( + JsonRpcError, + JsonRpcErrorCode, + JsonRpcResponse, + ToolContentItem, + ToolResult, + ToolSpec, +) + + +class TestJsonRpcTypes: + """Tests for JSON-RPC type classes.""" + + def test_json_rpc_error_to_dict(self): + """Test JsonRpcError.to_dict().""" + error = JsonRpcError( + code=JsonRpcErrorCode.INVALID_REQUEST, + message="Invalid request", + data={"details": "test"}, + ) + result = error.to_dict() + assert result["code"] == -32600 + assert result["message"] == "Invalid request" + assert result["data"] == {"details": "test"} + + def test_json_rpc_error_without_data(self): + """Test JsonRpcError.to_dict() without data.""" + error = JsonRpcError(code=-32600, message="Error") + result = error.to_dict() + assert "data" not in result + + def test_json_rpc_response_with_result(self): + """Test JsonRpcResponse.to_dict() with result.""" + response = JsonRpcResponse( + jsonrpc="2.0", + id=1, + result={"status": "ok"}, + ) + result = response.to_dict() + assert result["jsonrpc"] == "2.0" + assert result["id"] == 1 + assert result["result"] == {"status": "ok"} + assert "error" not in result + + def test_json_rpc_response_with_error(self): + """Test JsonRpcResponse.to_dict() with error.""" + response = JsonRpcResponse( + jsonrpc="2.0", + id=1, + error=JsonRpcError(code=-32600, message="Error"), + ) + result = response.to_dict() + assert result["jsonrpc"] == "2.0" + assert result["id"] == 1 + assert result["error"]["code"] == -32600 + assert "result" not in result + + +class TestToolTypes: + """Tests for tool type classes.""" + + def test_tool_spec_to_dict(self): + """Test ToolSpec.to_dict().""" + spec = ToolSpec( + name="test-tool", + description="A test tool", + input_schema={ + "type": "object", + "properties": { + "param": {"type": "string"}, + }, + }, + ) + result = spec.to_dict() + assert result["name"] == "test-tool" + assert result["description"] == "A test tool" + assert result["inputSchema"]["type"] == "object" + + def test_tool_content_item_to_dict(self): + """Test ToolContentItem.to_dict().""" + item = ToolContentItem( + type="text", + text="Hello, world!", + ) + result = item.to_dict() + assert result["type"] == "text" + assert result["text"] == "Hello, world!" + assert "data" not in result + assert "mimeType" not in result + + def test_tool_result_to_dict(self): + """Test ToolResult.to_dict().""" + tool_result = ToolResult( + content=[ + ToolContentItem(type="text", text="Result"), + ], + is_error=False, + ) + result = tool_result.to_dict() + assert len(result["content"]) == 1 + assert result["content"][0]["text"] == "Result" + assert "isError" not in result + + def test_tool_result_with_error(self): + """Test ToolResult.to_dict() with error.""" + tool_result = ToolResult( + content=[ToolContentItem(type="text", text="Error")], + is_error=True, + ) + result = tool_result.to_dict() + assert result["isError"] is True + + +class TestMcpHandlerRegisterTool: + """Tests for McpHandler.register_tool().""" + + def test_register_tool(self, mcp_handler): + """Test registering a tool.""" + mcp_handler.register_tool( + name="test-tool", + description="Test tool", + input_schema={"type": "object", "properties": {}}, + handler=lambda args: ToolResult( + content=[ToolContentItem(type="text", text="ok")] + ), + ) + tools = mcp_handler.get_tools() + assert len(tools) == 1 + assert tools[0].name == "test-tool" + + def test_tool_decorator(self, mcp_handler): + """Test @tool decorator.""" + + @mcp_handler.tool( + name="decorated-tool", + description="Decorated tool", + input_schema={"type": "object", "properties": {}}, + ) + def my_tool(args): + return ToolResult(content=[ToolContentItem(type="text", text="ok")]) + + tools = mcp_handler.get_tools() + assert len(tools) == 1 + assert tools[0].name == "decorated-tool" + + +class TestMcpHandlerGetTools: + """Tests for McpHandler.get_tools().""" + + def test_empty_initially(self, mcp_handler): + """Test handler has no tools initially.""" + tools = mcp_handler.get_tools() + assert len(tools) == 0 + + def test_returns_registered_tools(self, mcp_handler): + """Test get_tools returns registered tools.""" + mcp_handler.register_tool( + name="tool1", + description="Tool 1", + input_schema={"type": "object", "properties": {}}, + handler=lambda args: ToolResult(content=[]), + ) + mcp_handler.register_tool( + name="tool2", + description="Tool 2", + input_schema={"type": "object", "properties": {}}, + handler=lambda args: ToolResult(content=[]), + ) + tools = mcp_handler.get_tools() + assert len(tools) == 2 + names = [t.name for t in tools] + assert "tool1" in names + assert "tool2" in names + + +class TestMcpHandlerHandleRequest: + """Tests for McpHandler.handle_request().""" + + def test_initialize_method(self, mcp_handler): + """Test handling initialize method.""" + request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {}, + } + response = mcp_handler.handle_request(request) + assert response.id == 1 + assert response.error is None + assert response.result["protocolVersion"] == "2024-11-05" + assert "capabilities" in response.result + assert "serverInfo" in response.result + + def test_tools_list_method(self, mcp_handler): + """Test handling tools/list method.""" + mcp_handler.register_tool( + name="test-tool", + description="Test", + input_schema={"type": "object", "properties": {}}, + handler=lambda args: ToolResult(content=[]), + ) + request = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {}, + } + response = mcp_handler.handle_request(request) + assert response.error is None + assert len(response.result["tools"]) == 1 + assert response.result["tools"][0]["name"] == "test-tool" + + def test_tools_call_method(self, mcp_handler): + """Test handling tools/call method.""" + mcp_handler.register_tool( + name="echo", + description="Echo tool", + input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + handler=lambda args: ToolResult( + content=[ToolContentItem(type="text", text=args.get("message", ""))] + ), + ) + request = { + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "echo", + "arguments": {"message": "Hello"}, + }, + } + response = mcp_handler.handle_request(request) + assert response.error is None + assert response.result["content"][0]["text"] == "Hello" + + def test_ping_method(self, mcp_handler): + """Test handling ping method.""" + request = { + "jsonrpc": "2.0", + "id": 4, + "method": "ping", + "params": {}, + } + response = mcp_handler.handle_request(request) + assert response.error is None + assert response.result == {} + + def test_method_not_found(self, mcp_handler): + """Test method not found error.""" + request = { + "jsonrpc": "2.0", + "id": 5, + "method": "unknown/method", + "params": {}, + } + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.METHOD_NOT_FOUND + assert "unknown/method" in response.error.message + + def test_invalid_request_not_object(self, mcp_handler): + """Test invalid request (not an object).""" + response = mcp_handler.handle_request("not an object") + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_REQUEST + + def test_invalid_request_missing_jsonrpc(self, mcp_handler): + """Test invalid request (missing jsonrpc).""" + request = {"id": 1, "method": "ping"} + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_REQUEST + + def test_invalid_request_wrong_jsonrpc(self, mcp_handler): + """Test invalid request (wrong jsonrpc version).""" + request = {"jsonrpc": "1.0", "id": 1, "method": "ping"} + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_REQUEST + + def test_invalid_request_missing_method(self, mcp_handler): + """Test invalid request (missing method).""" + request = {"jsonrpc": "2.0", "id": 1} + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_REQUEST + + def test_invalid_params(self, mcp_handler): + """Test invalid params (not an object).""" + request = {"jsonrpc": "2.0", "id": 1, "method": "ping", "params": "string"} + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_PARAMS + + def test_tool_not_found(self, mcp_handler): + """Test tool not found error.""" + request = { + "jsonrpc": "2.0", + "id": 6, + "method": "tools/call", + "params": {"name": "nonexistent-tool", "arguments": {}}, + } + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.METHOD_NOT_FOUND + assert "nonexistent-tool" in response.error.message + + def test_tools_call_invalid_name(self, mcp_handler): + """Test tools/call with invalid name parameter.""" + request = { + "jsonrpc": "2.0", + "id": 7, + "method": "tools/call", + "params": {"name": 123, "arguments": {}}, + } + response = mcp_handler.handle_request(request) + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INVALID_PARAMS + + def test_notification_no_id(self, mcp_handler): + """Test notification (no id) is handled.""" + request = { + "jsonrpc": "2.0", + "method": "ping", + "params": {}, + } + response = mcp_handler.handle_request(request) + assert response.id is None + assert response.error is None diff --git a/examples/auth/tests/test_oauth_auth.py b/examples/auth/tests/test_oauth_auth.py new file mode 100644 index 00000000..ef40e701 --- /dev/null +++ b/examples/auth/tests/test_oauth_auth.py @@ -0,0 +1,208 @@ +"""Tests for OAuth authentication middleware.""" + +from flask import Flask + +from py_auth_mcp_server.config import create_default_config +from py_auth_mcp_server.middleware.oauth_auth import OAuthAuthMiddleware, require_auth + + +class TestExtractToken: + """Tests for extract_token method.""" + + def test_extracts_from_authorization_header(self): + """Test token extraction from Authorization header.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context( + "/test", + headers={"Authorization": "Bearer test-token-123"}, + ): + token = middleware.extract_token() + assert token == "test-token-123" + + def test_extracts_from_query_param(self): + """Test token extraction from query parameter.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context("/test?access_token=query-token-456"): + token = middleware.extract_token() + assert token == "query-token-456" + + def test_header_takes_precedence(self): + """Test Authorization header takes precedence over query param.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context( + "/test?access_token=query-token", + headers={"Authorization": "Bearer header-token"}, + ): + token = middleware.extract_token() + assert token == "header-token" + + def test_returns_none_when_missing(self): + """Test returns None when no token present.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context("/test"): + token = middleware.extract_token() + assert token is None + + def test_returns_none_for_non_bearer(self): + """Test returns None for non-Bearer auth.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context( + "/test", + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ): + token = middleware.extract_token() + assert token is None + + def test_returns_none_for_empty_query_param(self): + """Test returns None for empty query parameter.""" + app = Flask(__name__) + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + with app.test_request_context("/test?access_token="): + token = middleware.extract_token() + assert token is None + + +class TestRequiresAuth: + """Tests for requires_auth method.""" + + def test_returns_false_when_auth_disabled(self): + """Test returns False when auth is disabled.""" + config = create_default_config(auth_disabled=True) + middleware = OAuthAuthMiddleware(None, config) + + assert middleware.requires_auth("/mcp") is False + assert middleware.requires_auth("/rpc") is False + + def test_returns_false_when_no_auth_client(self): + """Test returns False when no auth client.""" + config = create_default_config(auth_disabled=False) + middleware = OAuthAuthMiddleware(None, config) + + assert middleware.requires_auth("/mcp") is False + + def test_public_paths_no_auth(self): + """Test public paths don't require auth.""" + config = create_default_config(auth_disabled=False) + # We can't easily test with a real auth client, so we test the path logic + middleware = OAuthAuthMiddleware(None, config) + + # These would return False due to no auth client, but test the path logic + public_paths = [ + "/.well-known/oauth-protected-resource", + "/.well-known/openid-configuration", + "/oauth/authorize", + "/oauth/register", + "/health", + "/favicon.ico", + ] + # Can't fully test without mock, but verify no exceptions + for path in public_paths: + middleware.requires_auth(path) + + def test_protected_paths_listed(self): + """Test protected paths are correctly identified.""" + # This test documents expected protected paths + # Path checking logic is internal, but we verify the class has the method + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + assert hasattr(middleware, "requires_auth") + # Protected paths: /mcp, /rpc, /events, /sse (documented for reference) + + +class TestSendUnauthorized: + """Tests for _send_unauthorized response format.""" + + def test_unauthorized_response_format(self, client): + """Test 401 response has correct format.""" + # Make request to protected endpoint without token + # Since auth is disabled in test config, we need a different approach + # This test verifies the client fixture works + response = client.get("/health") + assert response.status_code == 200 + + +class TestHasScope: + """Tests for has_scope method.""" + + def test_returns_false_when_not_authenticated(self): + """Test returns False when not authenticated.""" + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + + assert middleware.has_scope("mcp:read") is False + + def test_returns_false_for_empty_scopes(self): + """Test returns False when scopes are empty.""" + config = create_default_config() + middleware = OAuthAuthMiddleware(None, config) + # Auth context is not authenticated by default + assert middleware.has_scope("mcp:read") is False + + +class TestIsAuthDisabled: + """Tests for is_auth_disabled method.""" + + def test_returns_true_when_disabled(self): + """Test returns True when auth is disabled.""" + config = create_default_config(auth_disabled=True) + middleware = OAuthAuthMiddleware(None, config) + + assert middleware.is_auth_disabled() is True + + def test_returns_false_when_enabled(self): + """Test returns False when auth is enabled.""" + config = create_default_config(auth_disabled=False) + middleware = OAuthAuthMiddleware(None, config) + + assert middleware.is_auth_disabled() is False + + +class TestRequireAuthDecorator: + """Tests for require_auth decorator.""" + + def test_decorator_allows_authenticated(self): + """Test decorator allows authenticated requests.""" + app = Flask(__name__) + + @app.route("/protected") + @require_auth + def protected(): + return {"status": "ok"} + + with app.test_client() as client: + # Without auth context, should return 401 + with app.test_request_context(): + response = client.get("/protected") + assert response.status_code == 401 + + def test_decorator_returns_401_without_context(self): + """Test decorator returns 401 without auth context.""" + app = Flask(__name__) + + @app.route("/protected") + @require_auth + def protected(): + return {"status": "ok"} + + with app.test_client() as client: + response = client.get("/protected") + assert response.status_code == 401 + data = response.get_json() + assert data["error"] == "unauthorized" diff --git a/examples/auth/tests/test_weather_tools.py b/examples/auth/tests/test_weather_tools.py new file mode 100644 index 00000000..9f9885e0 --- /dev/null +++ b/examples/auth/tests/test_weather_tools.py @@ -0,0 +1,245 @@ +"""Tests for weather tools.""" + +import json + +from py_auth_mcp_server.tools.weather_tools import ( + CONDITIONS, + access_denied, + get_condition_for_city, + get_simulated_alerts, + get_simulated_forecast, + get_simulated_weather, + get_temp_for_city, + has_scope, +) + + +class TestHasScope: + """Tests for has_scope function.""" + + def test_scope_present(self): + """Test returns True when scope is present.""" + assert has_scope("openid profile email mcp:read", "mcp:read") is True + assert has_scope("openid profile", "openid") is True + assert has_scope("mcp:admin", "mcp:admin") is True + + def test_scope_not_present(self): + """Test returns False when scope is not present.""" + assert has_scope("openid profile email", "mcp:read") is False + assert has_scope("mcp:read", "mcp:admin") is False + + def test_empty_scopes(self): + """Test returns False for empty scopes.""" + assert has_scope("", "mcp:read") is False + assert has_scope(" ", "mcp:read") is False + + def test_empty_required(self): + """Test returns False for empty required scope.""" + assert has_scope("openid profile", "") is False + + def test_none_values(self): + """Test returns False for None values.""" + assert has_scope(None, "mcp:read") is False + assert has_scope("openid", None) is False + + def test_partial_match_not_accepted(self): + """Test partial matches are not accepted.""" + assert has_scope("mcp:readonly", "mcp:read") is False + assert has_scope("admin", "mcp:admin") is False + + +class TestGetConditionForCity: + """Tests for get_condition_for_city function.""" + + def test_returns_valid_condition(self): + """Test returns a valid weather condition.""" + condition = get_condition_for_city("London") + assert condition in CONDITIONS + + def test_deterministic(self): + """Test same city always returns same condition.""" + condition1 = get_condition_for_city("Paris") + condition2 = get_condition_for_city("Paris") + assert condition1 == condition2 + + def test_different_cities_may_differ(self): + """Test different cities may have different conditions.""" + cities = ["London", "Paris", "Tokyo", "Sydney", "Cairo"] + conditions = [get_condition_for_city(city) for city in cities] + # Not all should be the same (statistically unlikely) + assert len(set(conditions)) > 1 + + def test_offset_changes_condition(self): + """Test offset changes the condition.""" + conditions = [get_condition_for_city("London", offset=i) for i in range(6)] + # With 6 conditions and 6 offsets, we should cycle through + assert len(set(conditions)) > 1 + + +class TestGetTempForCity: + """Tests for get_temp_for_city function.""" + + def test_returns_valid_temperature(self): + """Test returns temperature in valid range (10-35).""" + temp = get_temp_for_city("London") + assert 10 <= temp <= 35 + + def test_deterministic(self): + """Test same city always returns same temperature.""" + temp1 = get_temp_for_city("Berlin") + temp2 = get_temp_for_city("Berlin") + assert temp1 == temp2 + + def test_offset_changes_temperature(self): + """Test offset changes the temperature.""" + temps = [get_temp_for_city("London", offset=i) for i in range(5)] + # Different offsets should give different temperatures + assert len(set(temps)) > 1 + + +class TestGetSimulatedWeather: + """Tests for get_simulated_weather function.""" + + def test_returns_all_fields(self): + """Test returns all expected fields.""" + weather = get_simulated_weather("London") + assert "city" in weather + assert "temperature" in weather + assert "condition" in weather + assert "humidity" in weather + assert "windSpeed" in weather + + def test_city_is_correct(self): + """Test city name is correct.""" + weather = get_simulated_weather("Tokyo") + assert weather["city"] == "Tokyo" + + def test_humidity_in_range(self): + """Test humidity is in valid range (40-80).""" + weather = get_simulated_weather("Cairo") + assert 40 <= weather["humidity"] <= 79 + + def test_wind_speed_in_range(self): + """Test wind speed is in valid range (5-30).""" + weather = get_simulated_weather("Sydney") + assert 5 <= weather["windSpeed"] <= 29 + + def test_deterministic(self): + """Test same city returns same weather.""" + weather1 = get_simulated_weather("Madrid") + weather2 = get_simulated_weather("Madrid") + assert weather1 == weather2 + + +class TestGetSimulatedForecast: + """Tests for get_simulated_forecast function.""" + + def test_returns_five_days(self): + """Test returns 5-day forecast.""" + forecast = get_simulated_forecast("London") + assert len(forecast) == 5 + + def test_day_names(self): + """Test forecast has correct day names.""" + forecast = get_simulated_forecast("London") + days = [day["day"] for day in forecast] + assert days == ["Today", "Tomorrow", "Day 3", "Day 4", "Day 5"] + + def test_each_day_has_fields(self): + """Test each day has required fields.""" + forecast = get_simulated_forecast("Paris") + for day in forecast: + assert "day" in day + assert "high" in day + assert "low" in day + assert "condition" in day + + def test_high_greater_than_low(self): + """Test high temperature is greater than low.""" + forecast = get_simulated_forecast("Berlin") + for day in forecast: + assert day["high"] > day["low"] + assert day["high"] - day["low"] == 10 # Offset is +5 and -5 + + def test_deterministic(self): + """Test same city returns same forecast.""" + forecast1 = get_simulated_forecast("Rome") + forecast2 = get_simulated_forecast("Rome") + assert forecast1 == forecast2 + + +class TestGetSimulatedAlerts: + """Tests for get_simulated_alerts function.""" + + def test_returns_list(self): + """Test returns a list.""" + alerts = get_simulated_alerts("California") + assert isinstance(alerts, list) + + def test_alert_structure(self): + """Test alerts have correct structure.""" + # Try multiple regions to get one with alerts + for region in ["California", "Texas", "Florida", "Alaska", "Hawaii"]: + alerts = get_simulated_alerts(region) + if alerts: + for alert in alerts: + assert "type" in alert + assert "severity" in alert + assert "message" in alert + break + + def test_region_in_message(self): + """Test region name appears in alert message.""" + for region in ["California", "Texas", "Florida", "Alaska", "Hawaii"]: + alerts = get_simulated_alerts(region) + if alerts: + for alert in alerts: + assert region in alert["message"] + break + + def test_deterministic(self): + """Test same region returns same alerts.""" + alerts1 = get_simulated_alerts("Oregon") + alerts2 = get_simulated_alerts("Oregon") + assert alerts1 == alerts2 + + def test_some_regions_have_no_alerts(self): + """Test some regions have no alerts.""" + # Test many regions to find one with no alerts + no_alerts_found = False + for i in range(100): + region = f"Region{i}" + alerts = get_simulated_alerts(region) + if not alerts: + no_alerts_found = True + break + assert no_alerts_found + + +class TestAccessDenied: + """Tests for access_denied function.""" + + def test_returns_tool_result(self): + """Test returns a ToolResult.""" + result = access_denied("mcp:read") + assert hasattr(result, "content") + assert hasattr(result, "is_error") + + def test_is_error_true(self): + """Test is_error is True.""" + result = access_denied("mcp:read") + assert result.is_error is True + + def test_content_has_error(self): + """Test content contains error message.""" + result = access_denied("mcp:admin") + assert len(result.content) == 1 + content = json.loads(result.content[0].text) + assert content["error"] == "access_denied" + assert "mcp:admin" in content["message"] + + def test_scope_in_message(self): + """Test scope appears in error message.""" + result = access_denied("custom:scope") + content = json.loads(result.content[0].text) + assert "custom:scope" in content["message"] diff --git a/gopher_mcp_python/__init__.py b/gopher_mcp_python/__init__.py index 90d6c441..688ddf6a 100644 --- a/gopher_mcp_python/__init__.py +++ b/gopher_mcp_python/__init__.py @@ -29,7 +29,23 @@ from gopher_mcp_python.server_config import ServerConfig from gopher_mcp_python.ffi import GopherOrchLibrary -__version__ = "0.1.1" +# Auth module re-exports +from gopher_mcp_python.ffi.auth import ( + # Types + GopherAuthError, + ValidationResult, + TokenPayload, + GopherAuthContext, + # Classes + GopherAuthClient, + GopherValidationOptions, + # Functions + gopher_init_auth_library, + gopher_shutdown_auth_library, + is_auth_available, +) + +__version__ = "0.1.2" __all__ = [ # Main classes @@ -47,6 +63,16 @@ "TimeoutError", # FFI "GopherOrchLibrary", + # Auth + "GopherAuthError", + "ValidationResult", + "TokenPayload", + "GopherAuthContext", + "GopherAuthClient", + "GopherValidationOptions", + "gopher_init_auth_library", + "gopher_shutdown_auth_library", + "is_auth_available", # Version "__version__", ] diff --git a/gopher_mcp_python/ffi/__init__.py b/gopher_mcp_python/ffi/__init__.py index 13b058ec..490f35f5 100644 --- a/gopher_mcp_python/ffi/__init__.py +++ b/gopher_mcp_python/ffi/__init__.py @@ -4,4 +4,11 @@ from gopher_mcp_python.ffi.library import GopherOrchLibrary, GopherOrchHandle -__all__ = ["GopherOrchLibrary", "GopherOrchHandle"] +# Auth module +from gopher_mcp_python.ffi import auth + +__all__ = [ + "GopherOrchLibrary", + "GopherOrchHandle", + "auth", +] diff --git a/gopher_mcp_python/ffi/auth/__init__.py b/gopher_mcp_python/ffi/auth/__init__.py new file mode 100644 index 00000000..76b55724 --- /dev/null +++ b/gopher_mcp_python/ffi/auth/__init__.py @@ -0,0 +1,82 @@ +""" +Auth FFI bindings for gopher-auth native library. + +Provides Python bindings for JWT token validation and OAuth support +via the gopher-orch native library. +""" + +from gopher_mcp_python.ffi.auth.types import ( + GopherAuthError, + ERROR_DESCRIPTIONS, + ValidationResult, + TokenPayload, + GopherAuthContext, + is_gopher_auth_error, + get_error_description, + gopher_create_empty_auth_context, +) + +from gopher_mcp_python.ffi.auth.loader import ( + GopherAuthClientPtr, + GopherAuthPayloadPtr, + GopherAuthOptionsPtr, + GopherAuthValidationResult, + load_library, + is_library_loaded, + get_library, + is_auth_available, + get_auth_functions, +) + +from gopher_mcp_python.ffi.auth.validation_options import ( + GopherValidationOptions, + gopher_create_validation_options, +) + +from gopher_mcp_python.ffi.auth.auth_client import ( + GopherAuthClient, + gopher_init_auth_library, + gopher_shutdown_auth_library, + gopher_get_auth_library_version, + gopher_is_auth_library_initialized, + gopher_generate_www_authenticate_header, + gopher_generate_www_authenticate_header_v2, +) + +__all__ = [ + # Enums + "GopherAuthError", + # Constants + "ERROR_DESCRIPTIONS", + # Dataclasses + "ValidationResult", + "TokenPayload", + "GopherAuthContext", + # Type functions + "is_gopher_auth_error", + "get_error_description", + "gopher_create_empty_auth_context", + # Pointer types + "GopherAuthClientPtr", + "GopherAuthPayloadPtr", + "GopherAuthOptionsPtr", + # Structures + "GopherAuthValidationResult", + # Loader functions + "load_library", + "is_library_loaded", + "get_library", + "is_auth_available", + "get_auth_functions", + # Validation options + "GopherValidationOptions", + "gopher_create_validation_options", + # Auth client + "GopherAuthClient", + "gopher_init_auth_library", + "gopher_shutdown_auth_library", + "gopher_get_auth_library_version", + "gopher_is_auth_library_initialized", + "gopher_generate_www_authenticate_header", + "gopher_generate_www_authenticate_header_v2", +] diff --git a/gopher_mcp_python/ffi/auth/auth_client.py b/gopher_mcp_python/ffi/auth/auth_client.py new file mode 100644 index 00000000..a9cd66a4 --- /dev/null +++ b/gopher_mcp_python/ffi/auth/auth_client.py @@ -0,0 +1,547 @@ +""" +AuthClient - JWT token validation client. + +Provides high-level API for validating JWT tokens and extracting +payload information using the gopher-auth native library. +""" + +from ctypes import byref, c_char_p, c_int64, c_void_p +from typing import List, Optional, Tuple + +from gopher_mcp_python.ffi.auth.loader import ( + load_library, + get_auth_functions, + is_auth_available, + GopherAuthValidationResult, +) +from gopher_mcp_python.ffi.auth.types import ( + GopherAuthError, + ValidationResult, + TokenPayload, + get_error_description, +) +from gopher_mcp_python.ffi.auth.validation_options import GopherValidationOptions + + +# ============================================================================ +# Library Lifecycle Functions +# ============================================================================ + +_auth_initialized: bool = False + + +def gopher_init_auth_library() -> bool: + """ + Initialize the auth library. + + This must be called before using any auth functions. + + Returns: + True if initialization successful, False otherwise. + """ + global _auth_initialized + + if _auth_initialized: + return True + + if not load_library(): + return False + + if not is_auth_available(): + return False + + funcs = get_auth_functions() + auth_init = funcs.get("auth_init") + if auth_init is None: + return False + + result = auth_init() + if result == GopherAuthError.SUCCESS: + _auth_initialized = True + return True + + return False + + +def gopher_shutdown_auth_library() -> bool: + """ + Shutdown the auth library and release resources. + + Returns: + True if shutdown successful, False otherwise. + """ + global _auth_initialized + + if not _auth_initialized: + return True + + funcs = get_auth_functions() + auth_shutdown = funcs.get("auth_shutdown") + if auth_shutdown is None: + return False + + result = auth_shutdown() + if result == GopherAuthError.SUCCESS: + _auth_initialized = False + return True + + return False + + +def gopher_get_auth_library_version() -> Optional[str]: + """ + Get the auth library version string. + + Returns: + Version string, or None if unavailable. + """ + if not load_library(): + return None + + funcs = get_auth_functions() + auth_version = funcs.get("auth_version") + if auth_version is None: + return None + + result = auth_version() + if result: + return result.decode("utf-8") + return None + + +def gopher_is_auth_library_initialized() -> bool: + """ + Check if the auth library has been initialized. + + Returns: + True if initialized, False otherwise. + """ + return _auth_initialized + + +# ============================================================================ +# WWW-Authenticate Header Generation +# ============================================================================ + + +def gopher_generate_www_authenticate_header( + realm: str, + error: str = "", + description: str = "", +) -> Optional[str]: + """ + Generate a WWW-Authenticate header for 401 responses. + + Args: + realm: The authentication realm. + error: Optional error code (e.g., "invalid_token"). + description: Optional error description. + + Returns: + The WWW-Authenticate header value, or None on error. + """ + if not load_library(): + return None + + funcs = get_auth_functions() + generate_func = funcs.get("generate_www_authenticate") + if generate_func is None: + return None + + header_out = c_char_p() + result = generate_func( + realm.encode("utf-8"), + error.encode("utf-8"), + description.encode("utf-8"), + byref(header_out), + ) + + if result != GopherAuthError.SUCCESS: + return None + + if header_out.value: + header = header_out.value.decode("utf-8") + # Free the allocated string + free_string = funcs.get("free_string") + if free_string: + free_string(header_out) + return header + + return None + + +def gopher_generate_www_authenticate_header_v2( + realm: str, + resource_metadata_url: str, + scopes: List[str], + error: str = "", + error_description: str = "", +) -> Optional[str]: + """ + Generate a WWW-Authenticate header with RFC 9728 support. + + Args: + realm: The authentication realm. + resource_metadata_url: URL to the OAuth protected resource metadata. + scopes: List of required scopes. + error: Optional error code. + error_description: Optional error description. + + Returns: + The WWW-Authenticate header value, or None on error. + """ + if not load_library(): + return None + + funcs = get_auth_functions() + generate_func = funcs.get("generate_www_authenticate_v2") + if generate_func is None: + return None + + scopes_str = " ".join(scopes) + header_out = c_char_p() + result = generate_func( + realm.encode("utf-8"), + resource_metadata_url.encode("utf-8"), + scopes_str.encode("utf-8"), + error.encode("utf-8"), + error_description.encode("utf-8"), + byref(header_out), + ) + + if result != GopherAuthError.SUCCESS: + return None + + if header_out.value: + header = header_out.value.decode("utf-8") + # Free the allocated string + free_string = funcs.get("free_string") + if free_string: + free_string(header_out) + return header + + return None + + +# ============================================================================ +# AuthClient Class +# ============================================================================ + + +class GopherAuthClient: + """ + JWT token validation client. + + Provides methods for validating JWT tokens and extracting payload + information using JWKS-based signature verification. + + Usage: + # Using context manager (recommended) + with GopherAuthClient("https://auth.example.com/.well-known/jwks.json", + "https://auth.example.com") as client: + result = client.validate_token(token) + if result.valid: + payload = client.extract_payload(token) + + # Manual lifecycle management + client = GopherAuthClient(jwks_uri, issuer) + try: + result, payload = client.validate_and_extract(token) + finally: + client.destroy() + """ + + def __init__(self, jwks_uri: str, issuer: str) -> None: + """ + Create a new AuthClient instance. + + Args: + jwks_uri: URL to the JWKS endpoint. + issuer: Expected token issuer. + + Raises: + RuntimeError: If the library is not loaded or client creation fails. + """ + self._handle: Optional[c_void_p] = None + self._destroyed: bool = False + + if not load_library(): + raise RuntimeError("Failed to load gopher-orch library") + + if not is_auth_available(): + raise RuntimeError("Auth functions not available in library") + + # Initialize library if needed + if not gopher_is_auth_library_initialized(): + if not gopher_init_auth_library(): + raise RuntimeError("Failed to initialize auth library") + + funcs = get_auth_functions() + client_create = funcs.get("client_create") + if client_create is None: + raise RuntimeError("client_create function not available") + + handle = c_void_p() + result = client_create( + byref(handle), + jwks_uri.encode("utf-8"), + issuer.encode("utf-8"), + ) + + if result != GopherAuthError.SUCCESS: + raise RuntimeError( + f"Failed to create auth client: {get_error_description(result)}" + ) + + self._handle = handle + + def set_option(self, key: str, value: str) -> bool: + """ + Set a client option. + + Args: + key: Option name (e.g., "cache_duration", "auto_refresh", "request_timeout"). + value: Option value as string. + + Returns: + True if option was set successfully, False otherwise. + + Raises: + RuntimeError: If the client has been destroyed. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + client_set_option = funcs.get("client_set_option") + if client_set_option is None: + return False + + result = client_set_option( + self._handle, + key.encode("utf-8"), + value.encode("utf-8"), + ) + + return result == GopherAuthError.SUCCESS + + def validate_token( + self, + token: str, + options: Optional[GopherValidationOptions] = None, + ) -> ValidationResult: + """ + Validate a JWT token. + + Args: + token: The JWT token string to validate. + options: Optional validation options (scopes, audience, etc.). + + Returns: + ValidationResult with valid flag, error code, and message. + + Raises: + RuntimeError: If the client has been destroyed or function unavailable. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + validate_token = funcs.get("validate_token") + if validate_token is None: + return ValidationResult( + valid=False, + error_code=GopherAuthError.NOT_INITIALIZED, + error_message="validate_token function not available", + ) + + result_struct = GopherAuthValidationResult() + options_handle = options.get_handle() if options else None + + err = validate_token( + self._handle, + token.encode("utf-8"), + options_handle, + byref(result_struct), + ) + + if err != GopherAuthError.SUCCESS: + return ValidationResult( + valid=False, + error_code=err, + error_message=get_error_description(err), + ) + + error_message = None + if result_struct.error_message: + error_message = result_struct.error_message.decode("utf-8") + + return ValidationResult( + valid=result_struct.valid, + error_code=result_struct.error_code, + error_message=error_message, + ) + + def extract_payload(self, token: str) -> Optional[TokenPayload]: + """ + Extract payload from a JWT token without validation. + + Args: + token: The JWT token string. + + Returns: + TokenPayload with extracted claims, or None on error. + + Raises: + RuntimeError: If the client has been destroyed. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + extract_payload = funcs.get("extract_payload") + if extract_payload is None: + return None + + payload_handle = c_void_p() + result = extract_payload( + token.encode("utf-8"), + byref(payload_handle), + ) + + if result != GopherAuthError.SUCCESS: + return None + + try: + subject = self._get_payload_string(payload_handle, "payload_get_subject") + scopes = self._get_payload_string(payload_handle, "payload_get_scopes") + audience = self._get_payload_string(payload_handle, "payload_get_audience") + issuer = self._get_payload_string(payload_handle, "payload_get_issuer") + + # Get expiration + expiration = None + get_exp = funcs.get("payload_get_expiration") + if get_exp: + exp_value = c_int64() + if get_exp(payload_handle, byref(exp_value)) == GopherAuthError.SUCCESS: + expiration = exp_value.value + + return TokenPayload( + subject=subject or "", + scopes=scopes or "", + audience=audience, + expiration=expiration, + issuer=issuer, + ) + finally: + # Clean up payload handle + payload_destroy = funcs.get("payload_destroy") + if payload_destroy: + payload_destroy(payload_handle) + + def _get_payload_string( + self, + payload_handle: c_void_p, + func_name: str, + ) -> Optional[str]: + """ + Helper to get a string field from payload. + + Args: + payload_handle: The payload handle. + func_name: Name of the getter function. + + Returns: + The string value, or None if unavailable. + """ + funcs = get_auth_functions() + getter = funcs.get(func_name) + if getter is None: + return None + + value_out = c_char_p() + result = getter(payload_handle, byref(value_out)) + + if result != GopherAuthError.SUCCESS: + return None + + if value_out.value: + value = value_out.value.decode("utf-8") + # Free the allocated string + free_string = funcs.get("free_string") + if free_string: + free_string(value_out) + return value + + return None + + def validate_and_extract( + self, + token: str, + options: Optional[GopherValidationOptions] = None, + ) -> Tuple[ValidationResult, Optional[TokenPayload]]: + """ + Validate a token and extract its payload in one call. + + Args: + token: The JWT token string. + options: Optional validation options. + + Returns: + Tuple of (ValidationResult, TokenPayload or None). + + Raises: + RuntimeError: If the client has been destroyed. + """ + result = self.validate_token(token, options) + + if not result.valid: + return result, None + + payload = self.extract_payload(token) + return result, payload + + def destroy(self) -> None: + """ + Destroy the client and release resources. + + This method is idempotent - calling it multiple times is safe. + """ + if self._destroyed or self._handle is None: + return + + funcs = get_auth_functions() + client_destroy = funcs.get("client_destroy") + if client_destroy is not None: + client_destroy(self._handle) + + self._handle = None + self._destroyed = True + + def is_destroyed(self) -> bool: + """ + Check if this client has been destroyed. + + Returns: + True if destroyed, False otherwise. + """ + return self._destroyed + + def _ensure_not_destroyed(self) -> None: + """ + Ensure the client has not been destroyed. + + Raises: + RuntimeError: If the client has been destroyed. + """ + if self._destroyed: + raise RuntimeError("GopherAuthClient has been destroyed") + + def __enter__(self) -> "GopherAuthClient": + """Enter context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager and destroy resources.""" + self.destroy() + + def __del__(self) -> None: + """Destructor - ensure resources are released.""" + self.destroy() diff --git a/gopher_mcp_python/ffi/auth/loader.py b/gopher_mcp_python/ffi/auth/loader.py new file mode 100644 index 00000000..57254570 --- /dev/null +++ b/gopher_mcp_python/ffi/auth/loader.py @@ -0,0 +1,404 @@ +""" +Auth Library Loader - ctypes bindings for libgopher-orch. + +Provides FFI bindings to the gopher-auth native library for +JWT token validation and OAuth support. + +Note: The gopher_auth_* functions are part of libgopher-orch, +not a separate library. +""" + +import ctypes +import os +import platform +import sys +from ctypes import ( + POINTER, + Structure, + c_bool, + c_char_p, + c_int, + c_int32, + c_int64, + c_void_p, +) +from pathlib import Path +from typing import Dict, List, Optional, Any + +# Track if library is loaded +_lib: Optional[ctypes.CDLL] = None +_lib_available: bool = False +_debug: bool = False + +# Opaque pointer types +GopherAuthClientPtr = c_void_p +GopherAuthPayloadPtr = c_void_p +GopherAuthOptionsPtr = c_void_p + + +class GopherAuthValidationResult(Structure): + """Validation result structure from C API.""" + + _fields_ = [ + ("valid", c_bool), + ("error_code", c_int32), + ("error_message", c_char_p), + ] + + +def _get_library_name() -> str: + """ + Get the library name for the current platform. + + Returns: + Library filename appropriate for the current OS. + """ + system = platform.system().lower() + if system == "darwin": + return "libgopher-orch.dylib" + elif system == "windows": + return "gopher-orch.dll" + else: + return "libgopher-orch.so" + + +def _get_platform_package_path() -> Optional[str]: + """ + Get the path to the platform-specific optional dependency package. + + This supports pip-distributed native packages with architecture-specific + binaries. + + Returns: + Path to the library directory if found, None otherwise. + """ + system = platform.system().lower() + arch = platform.machine().lower() + + # Normalize architecture names + arch_map = { + "x86_64": "x64", + "amd64": "x64", + "arm64": "arm64", + "aarch64": "arm64", + } + normalized_arch = arch_map.get(arch, arch) + + # Platform names + platform_map = { + "darwin": "darwin", + "linux": "linux", + "windows": "win32", + } + platform_name = platform_map.get(system) + if not platform_name: + return None + + package_name = f"gopher_mcp_python_native_{platform_name}_{normalized_arch}" + + # Try to find the package in site-packages + for site_path in sys.path: + package_dir = Path(site_path) / package_name + if package_dir.exists(): + lib_path = package_dir / "lib" + if lib_path.exists(): + return str(lib_path) + + return None + + +def _get_search_paths() -> List[str]: + """ + Get search paths for the native library. + + Returns: + List of directory paths to search for the library. + """ + paths: List[str] = [] + + # Platform-specific optional dependency package + platform_package_path = _get_platform_package_path() + if platform_package_path: + paths.append(platform_package_path) + + # Get the directory containing this module + module_dir = Path(__file__).parent.parent.parent.parent + + # Development paths + paths.extend([ + str(Path.cwd() / "native" / "lib"), + str(Path.cwd() / "lib"), + str(module_dir / "native" / "lib"), + str(module_dir.parent / "native" / "lib"), + ]) + + # System paths + system = platform.system().lower() + if system == "darwin": + paths.extend(["/usr/local/lib", "/opt/homebrew/lib"]) + paths.append("/usr/lib") + + return paths + + +def load_library() -> bool: + """ + Load the gopher-orch native library. + + Searches for the library in the following order: + 1. Path specified by GOPHER_ORCH_LIBRARY_PATH environment variable + 2. Platform-specific pip package + 3. Development paths (native/lib, lib) + 4. System paths (/usr/local/lib, /usr/lib, etc.) + + Returns: + True if library loaded successfully, False otherwise. + """ + global _lib, _lib_available, _debug + + if _lib is not None: + return _lib_available + + _debug = os.environ.get("DEBUG") is not None + library_name = _get_library_name() + search_paths = _get_search_paths() + + # Try environment variable path first + env_path = os.environ.get("GOPHER_ORCH_LIBRARY_PATH") or os.environ.get( + "GOPHER_AUTH_LIBRARY_PATH" + ) + if env_path and os.path.exists(env_path): + try: + _lib = ctypes.CDLL(env_path) + _lib_available = True + return True + except OSError as e: + if _debug: + print(f"Failed to load from environment path: {e}", file=sys.stderr) + + # Try search paths + for search_path in search_paths: + lib_file = os.path.join(search_path, library_name) + if os.path.exists(lib_file): + try: + _lib = ctypes.CDLL(lib_file) + _lib_available = True + return True + except OSError as e: + if _debug: + print(f"Failed to load from {search_path}: {e}", file=sys.stderr) + + # Try system paths (let the OS find it) + try: + _lib = ctypes.CDLL(library_name) + _lib_available = True + return True + except OSError as e: + if _debug: + print(f"Failed to load gopher-orch library: {e}", file=sys.stderr) + print("Searched paths:", file=sys.stderr) + for p in search_paths: + print(f" - {p}", file=sys.stderr) + + _lib_available = False + return False + + +def is_library_loaded() -> bool: + """ + Check if the library is loaded and available. + + Returns: + True if the library is loaded, False otherwise. + """ + return _lib_available + + +def get_library() -> Optional[ctypes.CDLL]: + """ + Get the loaded library instance. + + Returns: + The loaded CDLL instance, or None if not loaded. + """ + return _lib + + +# ============================================================================ +# FFI Function Bindings +# ============================================================================ + +# Track if functions are set up +_functions_setup: bool = False +_auth_functions_available: bool = False + +# Function binding specifications: (name, argtypes, restype) +_FUNCTION_SPECS = [ + # Library lifecycle + ("gopher_auth_init", [], c_int32), + ("gopher_auth_shutdown", [], c_int32), + ("gopher_auth_version", [], c_char_p), + # Client + ("gopher_auth_client_create", [POINTER(c_void_p), c_char_p, c_char_p], c_int32), + ("gopher_auth_client_destroy", [c_void_p], c_int32), + ("gopher_auth_client_set_option", [c_void_p, c_char_p, c_char_p], c_int32), + # Options + ("gopher_auth_validation_options_create", [POINTER(c_void_p)], c_int32), + ("gopher_auth_validation_options_destroy", [c_void_p], c_int32), + ("gopher_auth_validation_options_set_scopes", [c_void_p, c_char_p], c_int32), + ("gopher_auth_validation_options_set_audience", [c_void_p, c_char_p], c_int32), + ("gopher_auth_validation_options_set_clock_skew", [c_void_p, c_int64], c_int32), + # Validation + ("gopher_auth_validate_token", [c_void_p, c_char_p, c_void_p, POINTER(GopherAuthValidationResult)], c_int32), + ("gopher_auth_extract_payload", [c_char_p, POINTER(c_void_p)], c_int32), + # Payload + ("gopher_auth_payload_get_subject", [c_void_p, POINTER(c_char_p)], c_int32), + ("gopher_auth_payload_get_scopes", [c_void_p, POINTER(c_char_p)], c_int32), + ("gopher_auth_payload_get_audience", [c_void_p, POINTER(c_char_p)], c_int32), + ("gopher_auth_payload_get_expiration", [c_void_p, POINTER(c_int64)], c_int32), + ("gopher_auth_payload_get_issuer", [c_void_p, POINTER(c_char_p)], c_int32), + ("gopher_auth_payload_destroy", [c_void_p], c_int32), + # Utility + ("gopher_auth_free_string", [c_char_p], None), + ("gopher_auth_generate_www_authenticate", [c_char_p, c_char_p, c_char_p, POINTER(c_char_p)], c_int32), + ("gopher_auth_generate_www_authenticate_v2", [c_char_p, c_char_p, c_char_p, c_char_p, c_char_p, POINTER(c_char_p)], c_int32), +] + + +def _setup_functions() -> bool: + """ + Setup FFI function bindings for all gopher_auth_* functions. + + This configures argument types and return types for all exported + C functions in the library. Functions that don't exist in the + library are skipped. + + Returns: + True if at least gopher_auth_init is available, False otherwise. + """ + global _functions_setup, _auth_functions_available + + if _functions_setup: + return _auth_functions_available + + if _lib is None: + return False + + bound_count = 0 + for name, argtypes, restype in _FUNCTION_SPECS: + try: + func = getattr(_lib, name) + func.argtypes = argtypes + func.restype = restype + bound_count += 1 + except AttributeError: + if _debug: + print(f"Function not found: {name}", file=sys.stderr) + + _functions_setup = True + # Consider auth available if we bound at least the init function + _auth_functions_available = bound_count > 0 and _has_function("gopher_auth_init") + + if _debug: + print(f"Bound {bound_count}/{len(_FUNCTION_SPECS)} auth functions", file=sys.stderr) + + return _auth_functions_available + + +def _has_function(name: str) -> bool: + """ + Check if a specific function exists in the library. + + Args: + name: The function name to check. + + Returns: + True if function exists, False otherwise. + """ + if _lib is None: + return False + try: + getattr(_lib, name) + return True + except AttributeError: + return False + + +def _get_function(name: str) -> Optional[Any]: + """ + Get a function from the library if it exists. + + Args: + name: The function name to get. + + Returns: + The function object if found, None otherwise. + """ + if _lib is None: + return None + try: + return getattr(_lib, name) + except AttributeError: + return None + + +def is_auth_available() -> bool: + """ + Check if auth functions are available in the loaded library. + + Returns: + True if auth functions are available, False otherwise. + """ + if not load_library(): + return False + _setup_functions() + return _auth_functions_available + + +def get_auth_functions() -> Dict[str, Any]: + """ + Get dictionary of all auth function references. + + This ensures the library is loaded and functions are set up, + then returns references to all the C functions. Functions that + don't exist in the library will have None as their value. + + Returns: + Dictionary mapping function names to ctypes function objects + or None for unavailable functions. + """ + if not load_library(): + return {} + + _setup_functions() + + return { + # Library lifecycle + "auth_init": _get_function("gopher_auth_init"), + "auth_shutdown": _get_function("gopher_auth_shutdown"), + "auth_version": _get_function("gopher_auth_version"), + # Client + "client_create": _get_function("gopher_auth_client_create"), + "client_destroy": _get_function("gopher_auth_client_destroy"), + "client_set_option": _get_function("gopher_auth_client_set_option"), + # Options + "options_create": _get_function("gopher_auth_validation_options_create"), + "options_destroy": _get_function("gopher_auth_validation_options_destroy"), + "options_set_scopes": _get_function("gopher_auth_validation_options_set_scopes"), + "options_set_audience": _get_function("gopher_auth_validation_options_set_audience"), + "options_set_clock_skew": _get_function("gopher_auth_validation_options_set_clock_skew"), + # Validation + "validate_token": _get_function("gopher_auth_validate_token"), + "extract_payload": _get_function("gopher_auth_extract_payload"), + # Payload + "payload_get_subject": _get_function("gopher_auth_payload_get_subject"), + "payload_get_scopes": _get_function("gopher_auth_payload_get_scopes"), + "payload_get_audience": _get_function("gopher_auth_payload_get_audience"), + "payload_get_expiration": _get_function("gopher_auth_payload_get_expiration"), + "payload_get_issuer": _get_function("gopher_auth_payload_get_issuer"), + "payload_destroy": _get_function("gopher_auth_payload_destroy"), + # Utility + "free_string": _get_function("gopher_auth_free_string"), + "generate_www_authenticate": _get_function("gopher_auth_generate_www_authenticate"), + "generate_www_authenticate_v2": _get_function("gopher_auth_generate_www_authenticate_v2"), + } diff --git a/gopher_mcp_python/ffi/auth/types.py b/gopher_mcp_python/ffi/auth/types.py new file mode 100644 index 00000000..43545213 --- /dev/null +++ b/gopher_mcp_python/ffi/auth/types.py @@ -0,0 +1,135 @@ +""" +Auth Types - Type definitions for gopher-auth FFI bindings. + +These types mirror the C API from gopher-orch/include/gopher/orch/auth/auth_c_api.h +""" + +from dataclasses import dataclass, field +from enum import IntEnum +from typing import List, Optional + + +class GopherAuthError(IntEnum): + """Error codes from gopher_auth_error_t enum.""" + + SUCCESS = 0 + INVALID_TOKEN = -1000 + EXPIRED_TOKEN = -1001 + INVALID_SIGNATURE = -1002 + INVALID_ISSUER = -1003 + INVALID_AUDIENCE = -1004 + INSUFFICIENT_SCOPE = -1005 + JWKS_FETCH_FAILED = -1006 + INVALID_KEY = -1007 + NETWORK_ERROR = -1008 + INVALID_CONFIG = -1009 + OUT_OF_MEMORY = -1010 + INVALID_PARAMETER = -1011 + NOT_INITIALIZED = -1012 + INTERNAL_ERROR = -1013 + TOKEN_EXCHANGE_FAILED = -1014 + IDP_NOT_LINKED = -1015 + INVALID_IDP_ALIAS = -1016 + + +# Human-readable descriptions for each error code +ERROR_DESCRIPTIONS = { + GopherAuthError.SUCCESS: "Success", + GopherAuthError.INVALID_TOKEN: "Invalid token format or structure", + GopherAuthError.EXPIRED_TOKEN: "Token has expired", + GopherAuthError.INVALID_SIGNATURE: "Token signature verification failed", + GopherAuthError.INVALID_ISSUER: "Token issuer does not match expected value", + GopherAuthError.INVALID_AUDIENCE: "Token audience does not match expected value", + GopherAuthError.INSUFFICIENT_SCOPE: "Token does not have required scopes", + GopherAuthError.JWKS_FETCH_FAILED: "Failed to fetch JWKS from server", + GopherAuthError.INVALID_KEY: "Invalid or unsupported key in JWKS", + GopherAuthError.NETWORK_ERROR: "Network error during authentication", + GopherAuthError.INVALID_CONFIG: "Invalid configuration", + GopherAuthError.OUT_OF_MEMORY: "Out of memory", + GopherAuthError.INVALID_PARAMETER: "Invalid parameter provided", + GopherAuthError.NOT_INITIALIZED: "Auth library not initialized", + GopherAuthError.INTERNAL_ERROR: "Internal error", + GopherAuthError.TOKEN_EXCHANGE_FAILED: "Token exchange failed", + GopherAuthError.IDP_NOT_LINKED: "Identity provider not linked", + GopherAuthError.INVALID_IDP_ALIAS: "Invalid identity provider alias", +} + + +def is_gopher_auth_error(code: int) -> bool: + """ + Check if a value is a valid GopherAuthError code. + + Args: + code: The error code to check. + + Returns: + True if the code is a valid GopherAuthError, False otherwise. + """ + if code == GopherAuthError.SUCCESS: + return True + # Error codes range from -1000 to -1016 + return GopherAuthError.INVALID_TOKEN >= code >= GopherAuthError.INVALID_IDP_ALIAS + + +def get_error_description(code: int) -> str: + """ + Get human-readable description for an error code. + + Args: + code: The error code to get description for. + + Returns: + Human-readable error description. + """ + try: + error = GopherAuthError(code) + return ERROR_DESCRIPTIONS.get(error, f"Unknown error code: {code}") + except ValueError: + return f"Unknown error code: {code}" + + +@dataclass +class ValidationResult: + """Token validation result.""" + + valid: bool + error_code: int + error_message: Optional[str] = None + + +@dataclass +class TokenPayload: + """Decoded JWT token payload.""" + + subject: str + scopes: str + audience: Optional[str] = None + expiration: Optional[int] = None + issuer: Optional[str] = None + + +@dataclass +class GopherAuthContext: + """Authentication context for the current request.""" + + user_id: str + scopes: str + audience: str + token_expiry: int + authenticated: bool + + +def gopher_create_empty_auth_context() -> GopherAuthContext: + """ + Create an empty auth context (unauthenticated). + + Returns: + A GopherAuthContext with default empty values and authenticated=False. + """ + return GopherAuthContext( + user_id="", + scopes="", + audience="", + token_expiry=0, + authenticated=False, + ) diff --git a/gopher_mcp_python/ffi/auth/validation_options.py b/gopher_mcp_python/ffi/auth/validation_options.py new file mode 100644 index 00000000..eb650b00 --- /dev/null +++ b/gopher_mcp_python/ffi/auth/validation_options.py @@ -0,0 +1,239 @@ +""" +ValidationOptions - Configuration for token validation. + +Provides a fluent API for configuring JWT token validation options +such as required scopes, audience, and clock skew tolerance. +""" + +from ctypes import byref, c_void_p +from typing import List, Optional, TYPE_CHECKING + +from gopher_mcp_python.ffi.auth.loader import ( + load_library, + get_auth_functions, + is_auth_available, + GopherAuthOptionsPtr, +) +from gopher_mcp_python.ffi.auth.types import GopherAuthError + + +class GopherValidationOptions: + """ + Configuration options for JWT token validation. + + Provides a fluent API for setting validation parameters like + required scopes, audience, and clock skew tolerance. + + Usage: + # Using context manager (recommended) + with GopherValidationOptions() as options: + options.set_scopes(["read", "write"]).set_audience("my-api") + result = client.validate_token(token, options) + + # Manual lifecycle management + options = GopherValidationOptions() + try: + options.set_scopes(["read"]) + result = client.validate_token(token, options) + finally: + options.destroy() + + # Using factory function with defaults + options = gopher_create_validation_options(clock_skew=60) + """ + + def __init__(self) -> None: + """ + Create a new ValidationOptions instance. + + Raises: + RuntimeError: If the library is not loaded or auth is not available. + """ + self._handle: Optional[c_void_p] = None + self._destroyed: bool = False + + if not load_library(): + raise RuntimeError("Failed to load gopher-orch library") + + if not is_auth_available(): + raise RuntimeError("Auth functions not available in library") + + funcs = get_auth_functions() + options_create = funcs.get("options_create") + if options_create is None: + raise RuntimeError("options_create function not available") + + handle = c_void_p() + result = options_create(byref(handle)) + + if result != GopherAuthError.SUCCESS: + raise RuntimeError(f"Failed to create validation options: error {result}") + + self._handle = handle + + def set_scopes(self, scopes: List[str]) -> "GopherValidationOptions": + """ + Set the required scopes for token validation. + + Args: + scopes: List of scope strings that the token must have. + + Returns: + Self for method chaining. + + Raises: + RuntimeError: If the options have been destroyed or function unavailable. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + options_set_scopes = funcs.get("options_set_scopes") + if options_set_scopes is None: + raise RuntimeError("options_set_scopes function not available") + + # Join scopes with space separator + scopes_str = " ".join(scopes) + options_set_scopes(self._handle, scopes_str.encode("utf-8")) + + return self + + def set_audience(self, audience: str) -> "GopherValidationOptions": + """ + Set the required audience for token validation. + + Args: + audience: The expected audience value in the token. + + Returns: + Self for method chaining. + + Raises: + RuntimeError: If the options have been destroyed or function unavailable. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + options_set_audience = funcs.get("options_set_audience") + if options_set_audience is None: + raise RuntimeError("options_set_audience function not available") + + options_set_audience(self._handle, audience.encode("utf-8")) + + return self + + def set_clock_skew(self, seconds: int) -> "GopherValidationOptions": + """ + Set the clock skew tolerance for token expiration checks. + + Args: + seconds: Number of seconds of clock skew to allow. + + Returns: + Self for method chaining. + + Raises: + RuntimeError: If the options have been destroyed or function unavailable. + """ + self._ensure_not_destroyed() + + funcs = get_auth_functions() + options_set_clock_skew = funcs.get("options_set_clock_skew") + if options_set_clock_skew is None: + raise RuntimeError("options_set_clock_skew function not available") + + options_set_clock_skew(self._handle, seconds) + + return self + + def get_handle(self) -> c_void_p: + """ + Get the native handle for this options instance. + + Returns: + The ctypes void pointer handle. + + Raises: + RuntimeError: If the options have been destroyed. + """ + self._ensure_not_destroyed() + return self._handle + + def destroy(self) -> None: + """ + Destroy the native options handle and release resources. + + This method is idempotent - calling it multiple times is safe. + """ + if self._destroyed or self._handle is None: + return + + funcs = get_auth_functions() + options_destroy = funcs.get("options_destroy") + if options_destroy is not None: + options_destroy(self._handle) + + self._handle = None + self._destroyed = True + + def is_destroyed(self) -> bool: + """ + Check if this options instance has been destroyed. + + Returns: + True if destroyed, False otherwise. + """ + return self._destroyed + + def _ensure_not_destroyed(self) -> None: + """ + Ensure the options have not been destroyed. + + Raises: + RuntimeError: If the options have been destroyed. + """ + if self._destroyed: + raise RuntimeError("GopherValidationOptions has been destroyed") + + def __enter__(self) -> "GopherValidationOptions": + """Enter context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager and destroy resources.""" + self.destroy() + + def __del__(self) -> None: + """Destructor - ensure resources are released.""" + self.destroy() + + +def gopher_create_validation_options( + scopes: Optional[List[str]] = None, + audience: Optional[str] = None, + clock_skew: int = 30, +) -> GopherValidationOptions: + """ + Factory function to create GopherValidationOptions with common defaults. + + Args: + scopes: Optional list of required scopes. + audience: Optional required audience. + clock_skew: Clock skew tolerance in seconds (default: 30). + + Returns: + A configured GopherValidationOptions instance. + + Raises: + RuntimeError: If the library is not loaded or auth is not available. + """ + options = GopherValidationOptions() + + if scopes is not None: + options.set_scopes(scopes) + + if audience is not None: + options.set_audience(audience) + + options.set_clock_skew(clock_skew) + + return options diff --git a/packages/darwin-arm64/gopher_mcp_python_native_darwin_arm64/__init__.py b/packages/darwin-arm64/gopher_mcp_python_native_darwin_arm64/__init__.py index de3a310c..2e9569c1 100644 --- a/packages/darwin-arm64/gopher_mcp_python_native_darwin_arm64/__init__.py +++ b/packages/darwin-arm64/gopher_mcp_python_native_darwin_arm64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "darwin" diff --git a/packages/darwin-arm64/pyproject.toml b/packages/darwin-arm64/pyproject.toml index d3376057..9415f5af 100644 --- a/packages/darwin-arm64/pyproject.toml +++ b/packages/darwin-arm64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-darwin-arm64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (macOS ARM64)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/packages/darwin-x64/gopher_mcp_python_native_darwin_x64/__init__.py b/packages/darwin-x64/gopher_mcp_python_native_darwin_x64/__init__.py index a20dbbcb..8ecb7914 100644 --- a/packages/darwin-x64/gopher_mcp_python_native_darwin_x64/__init__.py +++ b/packages/darwin-x64/gopher_mcp_python_native_darwin_x64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "darwin" diff --git a/packages/darwin-x64/pyproject.toml b/packages/darwin-x64/pyproject.toml index da2afd49..d7ba46f1 100644 --- a/packages/darwin-x64/pyproject.toml +++ b/packages/darwin-x64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-darwin-x64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (macOS Intel)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/packages/linux-arm64/gopher_mcp_python_native_linux_arm64/__init__.py b/packages/linux-arm64/gopher_mcp_python_native_linux_arm64/__init__.py index 75c9d273..f03a7715 100644 --- a/packages/linux-arm64/gopher_mcp_python_native_linux_arm64/__init__.py +++ b/packages/linux-arm64/gopher_mcp_python_native_linux_arm64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "linux" diff --git a/packages/linux-arm64/pyproject.toml b/packages/linux-arm64/pyproject.toml index 815f3f02..6926c3af 100644 --- a/packages/linux-arm64/pyproject.toml +++ b/packages/linux-arm64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-linux-arm64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (Linux ARM64)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/packages/linux-x64/gopher_mcp_python_native_linux_x64/__init__.py b/packages/linux-x64/gopher_mcp_python_native_linux_x64/__init__.py index 8320627b..679cd027 100644 --- a/packages/linux-x64/gopher_mcp_python_native_linux_x64/__init__.py +++ b/packages/linux-x64/gopher_mcp_python_native_linux_x64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "linux" diff --git a/packages/linux-x64/pyproject.toml b/packages/linux-x64/pyproject.toml index 48c67300..db60a7e9 100644 --- a/packages/linux-x64/pyproject.toml +++ b/packages/linux-x64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-linux-x64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (Linux x64)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/packages/win32-arm64/gopher_mcp_python_native_win32_arm64/__init__.py b/packages/win32-arm64/gopher_mcp_python_native_win32_arm64/__init__.py index e9dc8659..7aa87233 100644 --- a/packages/win32-arm64/gopher_mcp_python_native_win32_arm64/__init__.py +++ b/packages/win32-arm64/gopher_mcp_python_native_win32_arm64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "win32" diff --git a/packages/win32-arm64/pyproject.toml b/packages/win32-arm64/pyproject.toml index c4e638a9..d078ba08 100644 --- a/packages/win32-arm64/pyproject.toml +++ b/packages/win32-arm64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-win32-arm64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (Windows ARM64)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/packages/win32-x64/gopher_mcp_python_native_win32_x64/__init__.py b/packages/win32-x64/gopher_mcp_python_native_win32_x64/__init__.py index 3a2e9719..bf9cb6dc 100644 --- a/packages/win32-x64/gopher_mcp_python_native_win32_x64/__init__.py +++ b/packages/win32-x64/gopher_mcp_python_native_win32_x64/__init__.py @@ -7,7 +7,7 @@ import os from pathlib import Path -__version__ = "0.1.1" +__version__ = "0.1.2" # Platform identifier PLATFORM = "win32" diff --git a/packages/win32-x64/pyproject.toml b/packages/win32-x64/pyproject.toml index 1d413035..944c75bd 100644 --- a/packages/win32-x64/pyproject.toml +++ b/packages/win32-x64/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python-native-win32-x64" -version = "0.1.1" +version = "0.1.2" description = "Native library for gopher-mcp-python (Windows x64)" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/pyproject.toml b/pyproject.toml index cfe2d5fa..fb33bf68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gopher-mcp-python" -version = "0.1.1" +version = "0.1.2.1" description = "Python SDK for Gopher MCP - AI Agent orchestration framework with native performance" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/tests/ffi/__init__.py b/tests/ffi/__init__.py new file mode 100644 index 00000000..54590edd --- /dev/null +++ b/tests/ffi/__init__.py @@ -0,0 +1 @@ +"""Tests for FFI modules.""" diff --git a/tests/ffi/auth/__init__.py b/tests/ffi/auth/__init__.py new file mode 100644 index 00000000..531e0803 --- /dev/null +++ b/tests/ffi/auth/__init__.py @@ -0,0 +1 @@ +"""Tests for auth FFI module.""" diff --git a/tests/ffi/auth/test_auth_client.py b/tests/ffi/auth/test_auth_client.py new file mode 100644 index 00000000..109bd836 --- /dev/null +++ b/tests/ffi/auth/test_auth_client.py @@ -0,0 +1,200 @@ +"""Tests for GopherAuthClient class and module functions.""" + +import pytest + +from gopher_mcp_python.ffi.auth.loader import is_auth_available, load_library +from gopher_mcp_python.ffi.auth.auth_client import ( + GopherAuthClient, + gopher_init_auth_library, + gopher_shutdown_auth_library, + gopher_get_auth_library_version, + gopher_is_auth_library_initialized, + gopher_generate_www_authenticate_header, + gopher_generate_www_authenticate_header_v2, +) + + +class TestLibraryLifecycle: + """Tests for library lifecycle functions.""" + + def test_init_returns_bool(self): + """Test gopher_init_auth_library returns boolean.""" + result = gopher_init_auth_library() + assert isinstance(result, bool) + + def test_is_initialized_returns_bool(self): + """Test gopher_is_auth_library_initialized returns boolean.""" + result = gopher_is_auth_library_initialized() + assert isinstance(result, bool) + + def test_shutdown_returns_bool(self): + """Test gopher_shutdown_auth_library returns boolean.""" + result = gopher_shutdown_auth_library() + assert isinstance(result, bool) + + +class TestGetAuthLibraryVersion: + """Tests for gopher_get_auth_library_version function.""" + + def test_returns_string_or_none(self): + """Test function returns string or None.""" + version = gopher_get_auth_library_version() + assert version is None or isinstance(version, str) + + +class TestGenerateWwwAuthenticateHeader: + """Tests for gopher_generate_www_authenticate_header function.""" + + @pytest.mark.skipif( + not is_auth_available(), + reason="Auth functions not available", + ) + def test_returns_string_or_none(self): + """Test function returns string or None.""" + header = gopher_generate_www_authenticate_header("test-realm") + assert header is None or isinstance(header, str) + + def test_without_auth_returns_none(self): + """Test returns None when auth not available.""" + if not is_auth_available(): + header = gopher_generate_www_authenticate_header("test-realm") + assert header is None + + +class TestGenerateWwwAuthenticateHeaderV2: + """Tests for gopher_generate_www_authenticate_header_v2 function.""" + + @pytest.mark.skipif( + not is_auth_available(), + reason="Auth functions not available", + ) + def test_returns_string_or_none(self): + """Test function returns string or None.""" + header = gopher_generate_www_authenticate_header_v2( + realm="test-realm", + resource_metadata_url="https://example.com/.well-known/oauth-protected-resource", + scopes=["read", "write"], + ) + assert header is None or isinstance(header, str) + + +class TestGopherAuthClientWhenNotAvailable: + """Tests for GopherAuthClient when auth not available.""" + + @pytest.mark.skipif( + is_auth_available(), + reason="Auth functions are available", + ) + def test_raises_when_auth_not_available(self): + """Test GopherAuthClient raises when auth functions not available.""" + with pytest.raises(RuntimeError, match="not available"): + GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + + +# Skip remaining tests if auth functions not available +pytestmark_auth = pytest.mark.skipif( + not is_auth_available(), + reason="Auth functions not available in library", +) + + +@pytestmark_auth +class TestGopherAuthClientInit: + """Tests for GopherAuthClient initialization.""" + + def test_creates_instance(self): + """Test creating GopherAuthClient instance.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + assert client is not None + client.destroy() + + def test_not_destroyed_initially(self): + """Test client is not destroyed after creation.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + assert client.is_destroyed() is False + client.destroy() + + +@pytestmark_auth +class TestGopherAuthClientSetOption: + """Tests for set_option method.""" + + def test_set_option_returns_bool(self): + """Test set_option returns boolean.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + result = client.set_option("cache_duration", "3600") + assert isinstance(result, bool) + client.destroy() + + +@pytestmark_auth +class TestGopherAuthClientLifecycle: + """Tests for lifecycle methods.""" + + def test_destroy_marks_destroyed(self): + """Test destroy() marks client as destroyed.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + client.destroy() + assert client.is_destroyed() is True + + def test_destroy_is_idempotent(self): + """Test destroy() can be called multiple times.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + client.destroy() + client.destroy() # Should not raise + assert client.is_destroyed() is True + + def test_validate_token_raises_when_destroyed(self): + """Test validate_token() raises after destroy.""" + client = GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) + client.destroy() + with pytest.raises(RuntimeError, match="destroyed"): + client.validate_token("some-token") + + +@pytestmark_auth +class TestGopherAuthClientContextManager: + """Tests for context manager protocol.""" + + def test_with_statement(self): + """Test using GopherAuthClient with 'with' statement.""" + with GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) as client: + assert client is not None + assert client.is_destroyed() is False + assert client.is_destroyed() is True + + def test_with_statement_exception_cleanup(self): + """Test resources cleaned up even on exception.""" + try: + with GopherAuthClient( + "https://example.com/.well-known/jwks.json", + "https://example.com", + ) as client: + raise ValueError("Test error") + except ValueError: + pass + assert client.is_destroyed() is True diff --git a/tests/ffi/auth/test_loader.py b/tests/ffi/auth/test_loader.py new file mode 100644 index 00000000..2f5453aa --- /dev/null +++ b/tests/ffi/auth/test_loader.py @@ -0,0 +1,210 @@ +"""Tests for auth loader module.""" + +import platform +import pytest +from unittest.mock import patch + +from gopher_mcp_python.ffi.auth.loader import ( + _get_library_name, + _get_search_paths, + load_library, + is_library_loaded, + get_library, + is_auth_available, + get_auth_functions, + GopherAuthClientPtr, + GopherAuthPayloadPtr, + GopherAuthOptionsPtr, + GopherAuthValidationResult, +) + + +class TestGetLibraryName: + """Tests for _get_library_name function.""" + + def test_returns_string(self): + """Test function returns a string.""" + name = _get_library_name() + assert isinstance(name, str) + assert len(name) > 0 + + @patch("platform.system") + def test_darwin_library_name(self, mock_system): + """Test macOS library name.""" + mock_system.return_value = "Darwin" + name = _get_library_name() + assert name == "libgopher-orch.dylib" + + @patch("platform.system") + def test_windows_library_name(self, mock_system): + """Test Windows library name.""" + mock_system.return_value = "Windows" + name = _get_library_name() + assert name == "gopher-orch.dll" + + @patch("platform.system") + def test_linux_library_name(self, mock_system): + """Test Linux library name.""" + mock_system.return_value = "Linux" + name = _get_library_name() + assert name == "libgopher-orch.so" + + +class TestGetSearchPaths: + """Tests for _get_search_paths function.""" + + def test_returns_list(self): + """Test function returns a list.""" + paths = _get_search_paths() + assert isinstance(paths, list) + + def test_paths_are_strings(self): + """Test all paths are strings.""" + paths = _get_search_paths() + for path in paths: + assert isinstance(path, str) + + def test_includes_native_lib(self): + """Test includes native/lib path.""" + paths = _get_search_paths() + native_lib_found = any("native" in p and "lib" in p for p in paths) + assert native_lib_found + + @patch("platform.system") + def test_darwin_includes_homebrew(self, mock_system): + """Test macOS includes homebrew paths.""" + mock_system.return_value = "Darwin" + paths = _get_search_paths() + homebrew_found = any("/opt/homebrew/lib" in p for p in paths) + assert homebrew_found + + def test_includes_usr_lib(self): + """Test includes /usr/lib path.""" + paths = _get_search_paths() + usr_lib_found = any("/usr/lib" in p for p in paths) + assert usr_lib_found + + +class TestPointerTypes: + """Tests for opaque pointer types.""" + + def test_client_ptr_is_void_p(self): + """Test GopherAuthClientPtr is c_void_p.""" + from ctypes import c_void_p + assert GopherAuthClientPtr == c_void_p + + def test_payload_ptr_is_void_p(self): + """Test GopherAuthPayloadPtr is c_void_p.""" + from ctypes import c_void_p + assert GopherAuthPayloadPtr == c_void_p + + def test_options_ptr_is_void_p(self): + """Test GopherAuthOptionsPtr is c_void_p.""" + from ctypes import c_void_p + assert GopherAuthOptionsPtr == c_void_p + + +class TestValidationResultStructure: + """Tests for GopherAuthValidationResult structure.""" + + def test_has_valid_field(self): + """Test structure has valid field.""" + result = GopherAuthValidationResult() + assert hasattr(result, "valid") + + def test_has_error_code_field(self): + """Test structure has error_code field.""" + result = GopherAuthValidationResult() + assert hasattr(result, "error_code") + + def test_has_error_message_field(self): + """Test structure has error_message field.""" + result = GopherAuthValidationResult() + assert hasattr(result, "error_message") + + def test_default_values(self): + """Test default field values.""" + result = GopherAuthValidationResult() + assert result.valid is False + assert result.error_code == 0 + + +class TestLoadLibrary: + """Tests for load_library function.""" + + def test_returns_bool(self): + """Test function returns boolean.""" + result = load_library() + assert isinstance(result, bool) + + def test_idempotent(self): + """Test multiple calls return same result.""" + result1 = load_library() + result2 = load_library() + assert result1 == result2 + + +class TestIsLibraryLoaded: + """Tests for is_library_loaded function.""" + + def test_returns_bool(self): + """Test function returns boolean.""" + result = is_library_loaded() + assert isinstance(result, bool) + + def test_consistent_with_load(self): + """Test consistency with load_library.""" + load_result = load_library() + is_loaded = is_library_loaded() + assert load_result == is_loaded + + +class TestGetLibrary: + """Tests for get_library function.""" + + def test_returns_none_or_cdll(self): + """Test function returns None or CDLL instance.""" + lib = get_library() + if lib is not None: + import ctypes + assert isinstance(lib, ctypes.CDLL) + + +class TestIsAuthAvailable: + """Tests for is_auth_available function.""" + + def test_returns_bool(self): + """Test function returns boolean.""" + result = is_auth_available() + assert isinstance(result, bool) + + +class TestGetAuthFunctions: + """Tests for get_auth_functions function.""" + + def test_returns_dict(self): + """Test function returns dictionary.""" + funcs = get_auth_functions() + assert isinstance(funcs, dict) + + def test_expected_keys(self): + """Test dictionary has expected keys.""" + funcs = get_auth_functions() + expected_keys = [ + "auth_init", + "auth_shutdown", + "auth_version", + "client_create", + "client_destroy", + "validate_token", + "extract_payload", + ] + for key in expected_keys: + assert key in funcs + + def test_values_none_or_callable(self): + """Test values are None or callable.""" + funcs = get_auth_functions() + for name, func in funcs.items(): + if func is not None: + assert callable(func), f"{name} is not callable" diff --git a/tests/ffi/auth/test_types.py b/tests/ffi/auth/test_types.py new file mode 100644 index 00000000..8bce27ec --- /dev/null +++ b/tests/ffi/auth/test_types.py @@ -0,0 +1,227 @@ +"""Tests for auth types module.""" + +import pytest + +from gopher_mcp_python.ffi.auth.types import ( + GopherAuthError, + ERROR_DESCRIPTIONS, + ValidationResult, + TokenPayload, + GopherAuthContext, + is_gopher_auth_error, + get_error_description, + gopher_create_empty_auth_context, +) + + +class TestGopherAuthError: + """Tests for GopherAuthError enum.""" + + def test_success_value(self): + """Test SUCCESS has value 0.""" + assert GopherAuthError.SUCCESS == 0 + + def test_invalid_token_value(self): + """Test INVALID_TOKEN has value -1000.""" + assert GopherAuthError.INVALID_TOKEN == -1000 + + def test_expired_token_value(self): + """Test EXPIRED_TOKEN has value -1001.""" + assert GopherAuthError.EXPIRED_TOKEN == -1001 + + def test_invalid_idp_alias_value(self): + """Test INVALID_IDP_ALIAS has value -1016.""" + assert GopherAuthError.INVALID_IDP_ALIAS == -1016 + + def test_all_error_codes_negative(self): + """Test all error codes except SUCCESS are negative.""" + for error in GopherAuthError: + if error != GopherAuthError.SUCCESS: + assert error < 0 + + +class TestIsGopherAuthError: + """Tests for is_gopher_auth_error function.""" + + def test_success_is_valid(self): + """Test SUCCESS (0) is valid.""" + assert is_gopher_auth_error(0) is True + + def test_invalid_token_is_valid(self): + """Test INVALID_TOKEN (-1000) is valid.""" + assert is_gopher_auth_error(-1000) is True + + def test_invalid_idp_alias_is_valid(self): + """Test INVALID_IDP_ALIAS (-1016) is valid.""" + assert is_gopher_auth_error(-1016) is True + + def test_positive_number_is_invalid(self): + """Test positive numbers are invalid.""" + assert is_gopher_auth_error(1) is False + assert is_gopher_auth_error(100) is False + + def test_out_of_range_negative_is_invalid(self): + """Test out of range negative numbers are invalid.""" + assert is_gopher_auth_error(-999) is False + assert is_gopher_auth_error(-1017) is False + assert is_gopher_auth_error(-2000) is False + + +class TestGetErrorDescription: + """Tests for get_error_description function.""" + + def test_success_description(self): + """Test SUCCESS description.""" + assert get_error_description(0) == "Success" + + def test_invalid_token_description(self): + """Test INVALID_TOKEN description.""" + desc = get_error_description(-1000) + assert "Invalid token" in desc + + def test_expired_token_description(self): + """Test EXPIRED_TOKEN description.""" + desc = get_error_description(-1001) + assert "expired" in desc.lower() + + def test_unknown_error_code(self): + """Test unknown error code returns generic message.""" + desc = get_error_description(-9999) + assert "Unknown error code" in desc + assert "-9999" in desc + + def test_all_errors_have_descriptions(self): + """Test all error codes have descriptions.""" + for error in GopherAuthError: + desc = get_error_description(error) + assert desc is not None + assert len(desc) > 0 + + +class TestErrorDescriptions: + """Tests for ERROR_DESCRIPTIONS dict.""" + + def test_all_errors_in_dict(self): + """Test all error codes are in ERROR_DESCRIPTIONS.""" + for error in GopherAuthError: + assert error in ERROR_DESCRIPTIONS + + def test_descriptions_are_strings(self): + """Test all descriptions are non-empty strings.""" + for error, desc in ERROR_DESCRIPTIONS.items(): + assert isinstance(desc, str) + assert len(desc) > 0 + + +class TestValidationResult: + """Tests for ValidationResult dataclass.""" + + def test_create_valid_result(self): + """Test creating a valid result.""" + result = ValidationResult(valid=True, error_code=0) + assert result.valid is True + assert result.error_code == 0 + assert result.error_message is None + + def test_create_invalid_result(self): + """Test creating an invalid result with error.""" + result = ValidationResult( + valid=False, + error_code=-1000, + error_message="Invalid token", + ) + assert result.valid is False + assert result.error_code == -1000 + assert result.error_message == "Invalid token" + + +class TestTokenPayload: + """Tests for TokenPayload dataclass.""" + + def test_create_minimal_payload(self): + """Test creating payload with required fields only.""" + payload = TokenPayload(subject="user123", scopes="read") + assert payload.subject == "user123" + assert payload.scopes == "read" + assert payload.audience is None + assert payload.expiration is None + assert payload.issuer is None + + def test_create_full_payload(self): + """Test creating payload with all fields.""" + payload = TokenPayload( + subject="user123", + scopes="read write", + audience="my-api", + expiration=1234567890, + issuer="https://auth.example.com", + ) + assert payload.subject == "user123" + assert payload.scopes == "read write" + assert payload.audience == "my-api" + assert payload.expiration == 1234567890 + assert payload.issuer == "https://auth.example.com" + + +class TestGopherAuthContext: + """Tests for GopherAuthContext dataclass.""" + + def test_create_authenticated_context(self): + """Test creating an authenticated context.""" + context = GopherAuthContext( + user_id="user123", + scopes="read write", + audience="my-api", + token_expiry=1234567890, + authenticated=True, + ) + assert context.user_id == "user123" + assert context.scopes == "read write" + assert context.audience == "my-api" + assert context.token_expiry == 1234567890 + assert context.authenticated is True + + def test_create_unauthenticated_context(self): + """Test creating an unauthenticated context.""" + context = GopherAuthContext( + user_id="", + scopes="", + audience="", + token_expiry=0, + authenticated=False, + ) + assert context.authenticated is False + + +class TestGopherCreateEmptyAuthContext: + """Tests for gopher_create_empty_auth_context function.""" + + def test_returns_auth_context(self): + """Test function returns GopherAuthContext instance.""" + context = gopher_create_empty_auth_context() + assert isinstance(context, GopherAuthContext) + + def test_empty_user_id(self): + """Test empty auth context has empty user_id.""" + context = gopher_create_empty_auth_context() + assert context.user_id == "" + + def test_empty_scopes(self): + """Test empty auth context has empty scopes.""" + context = gopher_create_empty_auth_context() + assert context.scopes == "" + + def test_empty_audience(self): + """Test empty auth context has empty audience.""" + context = gopher_create_empty_auth_context() + assert context.audience == "" + + def test_zero_token_expiry(self): + """Test empty auth context has zero token_expiry.""" + context = gopher_create_empty_auth_context() + assert context.token_expiry == 0 + + def test_not_authenticated(self): + """Test empty auth context is not authenticated.""" + context = gopher_create_empty_auth_context() + assert context.authenticated is False diff --git a/tests/ffi/auth/test_validation_options.py b/tests/ffi/auth/test_validation_options.py new file mode 100644 index 00000000..1ca5edc1 --- /dev/null +++ b/tests/ffi/auth/test_validation_options.py @@ -0,0 +1,166 @@ +"""Tests for GopherValidationOptions class.""" + +import pytest + +from gopher_mcp_python.ffi.auth.loader import is_auth_available +from gopher_mcp_python.ffi.auth.validation_options import ( + GopherValidationOptions, + gopher_create_validation_options, +) + + +# Skip all tests if auth functions not available +pytestmark = pytest.mark.skipif( + not is_auth_available(), + reason="Auth functions not available in library", +) + + +class TestGopherValidationOptionsInit: + """Tests for GopherValidationOptions initialization.""" + + def test_creates_instance(self): + """Test creating GopherValidationOptions instance.""" + options = GopherValidationOptions() + assert options is not None + options.destroy() + + def test_not_destroyed_initially(self): + """Test options are not destroyed after creation.""" + options = GopherValidationOptions() + assert options.is_destroyed() is False + options.destroy() + + +class TestGopherValidationOptionsFluentApi: + """Tests for fluent API methods.""" + + def test_set_scopes_returns_self(self): + """Test set_scopes returns self for chaining.""" + options = GopherValidationOptions() + result = options.set_scopes(["read", "write"]) + assert result is options + options.destroy() + + def test_set_audience_returns_self(self): + """Test set_audience returns self for chaining.""" + options = GopherValidationOptions() + result = options.set_audience("my-api") + assert result is options + options.destroy() + + def test_set_clock_skew_returns_self(self): + """Test set_clock_skew returns self for chaining.""" + options = GopherValidationOptions() + result = options.set_clock_skew(60) + assert result is options + options.destroy() + + def test_method_chaining(self): + """Test multiple methods can be chained.""" + options = GopherValidationOptions() + result = ( + options + .set_scopes(["read"]) + .set_audience("my-api") + .set_clock_skew(30) + ) + assert result is options + options.destroy() + + +class TestGopherValidationOptionsLifecycle: + """Tests for lifecycle methods.""" + + def test_destroy_marks_destroyed(self): + """Test destroy() marks options as destroyed.""" + options = GopherValidationOptions() + options.destroy() + assert options.is_destroyed() is True + + def test_destroy_is_idempotent(self): + """Test destroy() can be called multiple times.""" + options = GopherValidationOptions() + options.destroy() + options.destroy() # Should not raise + assert options.is_destroyed() is True + + def test_get_handle_raises_when_destroyed(self): + """Test get_handle() raises after destroy.""" + options = GopherValidationOptions() + options.destroy() + with pytest.raises(RuntimeError, match="destroyed"): + options.get_handle() + + def test_set_scopes_raises_when_destroyed(self): + """Test set_scopes() raises after destroy.""" + options = GopherValidationOptions() + options.destroy() + with pytest.raises(RuntimeError, match="destroyed"): + options.set_scopes(["read"]) + + +class TestGopherValidationOptionsContextManager: + """Tests for context manager protocol.""" + + def test_with_statement(self): + """Test using GopherValidationOptions with 'with' statement.""" + with GopherValidationOptions() as options: + assert options is not None + assert options.is_destroyed() is False + assert options.is_destroyed() is True + + def test_with_statement_with_config(self): + """Test configuring options within 'with' statement.""" + with GopherValidationOptions() as options: + options.set_scopes(["read", "write"]) + options.set_audience("my-api") + assert options.is_destroyed() is False + + def test_with_statement_exception_cleanup(self): + """Test resources cleaned up even on exception.""" + try: + with GopherValidationOptions() as options: + raise ValueError("Test error") + except ValueError: + pass + assert options.is_destroyed() is True + + +class TestGopherCreateValidationOptions: + """Tests for gopher_create_validation_options factory function.""" + + def test_creates_instance(self): + """Test factory creates GopherValidationOptions instance.""" + options = gopher_create_validation_options() + assert isinstance(options, GopherValidationOptions) + options.destroy() + + def test_default_clock_skew(self): + """Test default clock_skew is applied.""" + options = gopher_create_validation_options() + # Can't directly verify, but should not raise + assert options is not None + options.destroy() + + def test_with_scopes(self): + """Test factory with scopes parameter.""" + options = gopher_create_validation_options(scopes=["read", "write"]) + assert options is not None + options.destroy() + + def test_with_audience(self): + """Test factory with audience parameter.""" + options = gopher_create_validation_options(audience="my-api") + assert options is not None + options.destroy() + + def test_with_all_params(self): + """Test factory with all parameters.""" + options = gopher_create_validation_options( + scopes=["read"], + audience="my-api", + clock_skew=60, + ) + assert options is not None + options.destroy() diff --git a/third_party/gopher-orch b/third_party/gopher-orch index 63ac8c66..48ddecc4 160000 --- a/third_party/gopher-orch +++ b/third_party/gopher-orch @@ -1 +1 @@ -Subproject commit 63ac8c66f86a3113d6e13202404f46a6bf82b065 +Subproject commit 48ddecc475f486ed675fd22128168c6229f65f2a