diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 9addb661d..fd24a15ae 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,210 +1,121 @@ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +"""Tests for the WebSocket transport. + +The smoke test (``test_ws_client_basic_connection``) runs the full WS stack +end-to-end over a real TCP connection and is what provides coverage of +``src/mcp/client/websocket.py``. + +The remaining tests verify transport-agnostic MCP semantics (error +propagation, client-side timeouts) and use the in-memory ``Client`` transport +to avoid the cost and flakiness of real network servers. +""" + +from collections.abc import Generator from urllib.parse import urlparse import anyio import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError +from mcp import Client, MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client from mcp.server import Server, ServerRequestContext from mcp.server.websocket import websocket_server from mcp.types import ( - CallToolRequestParams, - CallToolResult, EmptyResult, InitializeResult, - ListToolsResult, - PaginatedRequestParams, ReadResourceRequestParams, ReadResourceResult, - TextContent, TextResourceContents, - Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_WS" +pytestmark = pytest.mark.anyio -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +# --- WebSocket transport smoke test (real TCP) ------------------------------- -@pytest.fixture -def server_url(server_port: int) -> str: - return f"ws://127.0.0.1:{server_port}" +def make_server_app() -> Starlette: + srv = Server(SERVER_NAME) -async def handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: - parsed = urlparse(str(params.uri)) - if parsed.scheme == "foobar": - return ReadResourceResult( - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] - ) - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" - ) - ] - ) - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - -async def handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=handle_read_resource, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with WebSocket transport""" - server = _create_server() - - async def handle_ws(websocket: WebSocket): + async def handle_ws(websocket: WebSocket) -> None: async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) - - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - return app - + await srv.run(streams[0], streams[1], srv.create_initialization_options()) -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +@pytest.fixture +def ws_server_url() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as base_url: + yield base_url.replace("http://", "ws://") + "/ws" -@pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - """Create and initialize a WebSocket client session""" - async with websocket_client(server_url + "/ws") as streams: +async def test_ws_client_basic_connection(ws_server_url: str) -> None: + async with websocket_client(ws_server_url) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) - yield session +# --- In-memory tests (transport-agnostic MCP semantics) ---------------------- -# Tests -@pytest.mark.anyio -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: - """Test the WebSocket connection establishment""" - async with websocket_client(server_url + "/ws") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) +async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: + parsed = urlparse(str(params.uri)) + if parsed.scheme == "foobar": + return ReadResourceResult( + contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] + ) + elif parsed.scheme == "slow": + # Block indefinitely so the client-side fail_after() fires; the pending + # server task is cancelled when the Client context manager exits. + await anyio.sleep_forever() + raise MCPError(code=404, message="OOPS! no resource with that URI was found") + + +@pytest.fixture +def server() -> Server: + return Server(SERVER_NAME, on_read_resource=handle_read_resource) -@pytest.mark.anyio -async def test_ws_client_happy_request_and_response( - initialized_ws_client_session: ClientSession, -) -> None: - """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" - - -@pytest.mark.anyio -async def test_ws_client_exception_handling( - initialized_ws_client_session: ClientSession, -) -> None: - """Test exception handling in WebSocket communication""" - with pytest.raises(MCPError) as exc_info: - await initialized_ws_client_session.read_resource("unknown://example") - assert exc_info.value.error.code == 404 - - -@pytest.mark.anyio -async def test_ws_client_timeout( - initialized_ws_client_session: ClientSession, -) -> None: - """Test timeout handling in WebSocket communication""" - # Set a very short timeout to trigger a timeout exception - with pytest.raises(TimeoutError): - with anyio.fail_after(0.1): # 100ms timeout - await initialized_ws_client_session.read_resource("slow://example") - - # Now test that we can still use the session after a timeout - with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource("foobar://example") +async def test_ws_client_happy_request_and_response(server: Server) -> None: + async with Client(server) as client: + result = await client.read_resource("foobar://example") assert isinstance(result, ReadResourceResult) assert isinstance(result.contents, list) assert len(result.contents) > 0 assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Read example" + + +async def test_ws_client_exception_handling(server: Server) -> None: + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("unknown://example") + assert exc_info.value.error.code == 404 + + +async def test_ws_client_timeout(server: Server) -> None: + async with Client(server) as client: + with pytest.raises(TimeoutError): + with anyio.fail_after(0.1): + await client.read_resource("slow://example") + + # Session remains usable after a client-side timeout abandons a request. + with anyio.fail_after(5): + result = await client.read_resource("foobar://example") + assert isinstance(result, ReadResourceResult) + assert isinstance(result.contents, list) + assert len(result.contents) > 0 + assert isinstance(result.contents[0], TextResourceContents) + assert result.contents[0].text == "Read example" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269f..bcc7e3edf 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,73 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn + +# How long to wait for the uvicorn server thread to reach `started`. +# Generous to absorb CI scheduling delays — actual startup is typically <100ms. +_SERVER_START_TIMEOUT_S = 20.0 +_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread with an ephemeral port. + + This eliminates the TOCTOU race that occurs when a test picks a free port + with ``socket.bind((host, 0))``, releases it, then starts a server hoping + to rebind the same port — between release and rebind, another pytest-xdist + worker may claim it, causing connection errors or cross-test contamination. + + With ``port=0``, the OS atomically assigns a free port at bind time; the + server holds it from that moment until shutdown. We read the actual port + back from uvicorn's bound socket after startup completes. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``, ``limit_concurrency``). ``host`` defaults to + ``127.0.0.1`` and ``port`` is forced to 0. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + + Raises: + TimeoutError: If the server does not start within 20 seconds. + RuntimeError: If the server thread dies during startup. + """ + config_kwargs.setdefault("host", "127.0.0.1") + config_kwargs.setdefault("log_level", "error") + config = uvicorn.Config(app=app, port=0, **config_kwargs) + server = uvicorn.Server(config=config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # uvicorn sets `server.started = True` at the end of `Server.startup()`, + # after sockets are bound and the lifespan startup phase has completed. + start = time.monotonic() + while not server.started: + if time.monotonic() - start > _SERVER_START_TIMEOUT_S: # pragma: no cover + raise TimeoutError(f"uvicorn server failed to start within {_SERVER_START_TIMEOUT_S}s") + if not thread.is_alive(): # pragma: no cover + raise RuntimeError("uvicorn server thread exited during startup") + time.sleep(0.001) + + # server.servers[0] is the asyncio.Server; its bound socket has the real port + port = server.servers[0].sockets[0].getsockname()[1] + host = config.host + + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) def wait_for_server(port: int, timeout: float = 20.0) -> None: