Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 70 additions & 159 deletions tests/shared/test_ws.py
Original file line number Diff line number Diff line change
@@ -1,210 +1,121 @@
import multiprocessing
import socket
from collections.abc import AsyncGenerator, Generator
"""Tests for the WebSocket transport.

The smoke test (``test_ws_client_basic_connection``) runs the full WS stack
end-to-end over a real TCP connection and is what provides coverage of
``src/mcp/client/websocket.py``.

The remaining tests verify transport-agnostic MCP semantics (error
propagation, client-side timeouts) and use the in-memory ``Client`` transport
to avoid the cost and flakiness of real network servers.
"""

from collections.abc import Generator
from urllib.parse import urlparse

import anyio
import pytest
import uvicorn
from starlette.applications import Starlette
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket

from mcp import MCPError
from mcp import Client, MCPError
from mcp.client.session import ClientSession
from mcp.client.websocket import websocket_client
from mcp.server import Server, ServerRequestContext
from mcp.server.websocket import websocket_server
from mcp.types import (
CallToolRequestParams,
CallToolResult,
EmptyResult,
InitializeResult,
ListToolsResult,
PaginatedRequestParams,
ReadResourceRequestParams,
ReadResourceResult,
TextContent,
TextResourceContents,
Tool,
)
from tests.test_helpers import wait_for_server
from tests.test_helpers import run_uvicorn_in_thread

SERVER_NAME = "test_server_for_WS"

pytestmark = pytest.mark.anyio

@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]

# --- WebSocket transport smoke test (real TCP) -------------------------------

@pytest.fixture
def server_url(server_port: int) -> str:
return f"ws://127.0.0.1:{server_port}"

def make_server_app() -> Starlette:
srv = Server(SERVER_NAME)

async def handle_read_resource( # pragma: no cover
ctx: ServerRequestContext, params: ReadResourceRequestParams
) -> ReadResourceResult:
parsed = urlparse(str(params.uri))
if parsed.scheme == "foobar":
return ReadResourceResult(
contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")]
)
elif parsed.scheme == "slow":
await anyio.sleep(2.0)
return ReadResourceResult(
contents=[
TextResourceContents(
uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain"
)
]
)
raise MCPError(code=404, message="OOPS! no resource with that URI was found")


async def handle_list_tools( # pragma: no cover
ctx: ServerRequestContext, params: PaginatedRequestParams | None
) -> ListToolsResult:
return ListToolsResult(
tools=[
Tool(
name="test_tool",
description="A test tool",
input_schema={"type": "object", "properties": {}},
)
]
)


async def handle_call_tool( # pragma: no cover
ctx: ServerRequestContext, params: CallToolRequestParams
) -> CallToolResult:
return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")])


def _create_server() -> Server: # pragma: no cover
return Server(
SERVER_NAME,
on_read_resource=handle_read_resource,
on_list_tools=handle_list_tools,
on_call_tool=handle_call_tool,
)


# Test fixtures
def make_server_app() -> Starlette: # pragma: no cover
"""Create test Starlette app with WebSocket transport"""
server = _create_server()

async def handle_ws(websocket: WebSocket):
async def handle_ws(websocket: WebSocket) -> None:
async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams:
await server.run(streams[0], streams[1], server.create_initialization_options())

app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])
return app

await srv.run(streams[0], streams[1], srv.create_initialization_options())

def run_server(server_port: int) -> None: # pragma: no cover
app = make_server_app()
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error"))
print(f"starting server on {server_port}")
server.run()
return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)])


@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
print("starting process")
proc.start()

# Wait for server to be running
print("waiting for server to start")
wait_for_server(server_port)

yield

print("killing server")
# Signal the server to stop
proc.kill()
proc.join(timeout=2)
if proc.is_alive(): # pragma: no cover
print("server process failed to terminate")
@pytest.fixture
def ws_server_url() -> Generator[str, None, None]:
with run_uvicorn_in_thread(make_server_app()) as base_url:
yield base_url.replace("http://", "ws://") + "/ws"


@pytest.fixture()
async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
"""Create and initialize a WebSocket client session"""
async with websocket_client(server_url + "/ws") as streams:
async def test_ws_client_basic_connection(ws_server_url: str) -> None:
async with websocket_client(ws_server_url) as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.server_info.name == SERVER_NAME

# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)

yield session

# --- In-memory tests (transport-agnostic MCP semantics) ----------------------

# Tests
@pytest.mark.anyio
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
"""Test the WebSocket connection establishment"""
async with websocket_client(server_url + "/ws") as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.server_info.name == SERVER_NAME

# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
async def handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult:
parsed = urlparse(str(params.uri))
if parsed.scheme == "foobar":
return ReadResourceResult(
contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")]
)
elif parsed.scheme == "slow":
# Block indefinitely so the client-side fail_after() fires; the pending
# server task is cancelled when the Client context manager exits.
await anyio.sleep_forever()
raise MCPError(code=404, message="OOPS! no resource with that URI was found")


@pytest.fixture
def server() -> Server:
return Server(SERVER_NAME, on_read_resource=handle_read_resource)


@pytest.mark.anyio
async def test_ws_client_happy_request_and_response(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test a successful request and response via WebSocket"""
result = await initialized_ws_client_session.read_resource("foobar://example")
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Read example"


@pytest.mark.anyio
async def test_ws_client_exception_handling(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test exception handling in WebSocket communication"""
with pytest.raises(MCPError) as exc_info:
await initialized_ws_client_session.read_resource("unknown://example")
assert exc_info.value.error.code == 404


@pytest.mark.anyio
async def test_ws_client_timeout(
initialized_ws_client_session: ClientSession,
) -> None:
"""Test timeout handling in WebSocket communication"""
# Set a very short timeout to trigger a timeout exception
with pytest.raises(TimeoutError):
with anyio.fail_after(0.1): # 100ms timeout
await initialized_ws_client_session.read_resource("slow://example")

# Now test that we can still use the session after a timeout
with anyio.fail_after(5): # Longer timeout to allow completion
result = await initialized_ws_client_session.read_resource("foobar://example")
async def test_ws_client_happy_request_and_response(server: Server) -> None:
async with Client(server) as client:
result = await client.read_resource("foobar://example")
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Read example"


async def test_ws_client_exception_handling(server: Server) -> None:
async with Client(server) as client:
with pytest.raises(MCPError) as exc_info:
await client.read_resource("unknown://example")
assert exc_info.value.error.code == 404


async def test_ws_client_timeout(server: Server) -> None:
async with Client(server) as client:
with pytest.raises(TimeoutError):
with anyio.fail_after(0.1):
await client.read_resource("slow://example")

# Session remains usable after a client-side timeout abandons a request.
with anyio.fail_after(5):
result = await client.read_resource("foobar://example")
assert isinstance(result, ReadResourceResult)
assert isinstance(result.contents, list)
assert len(result.contents) > 0
assert isinstance(result.contents[0], TextResourceContents)
assert result.contents[0].text == "Read example"
66 changes: 66 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,73 @@
"""Common test utilities for MCP server tests."""

import socket
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

import uvicorn

# How long to wait for the uvicorn server thread to reach `started`.
# Generous to absorb CI scheduling delays — actual startup is typically <100ms.
_SERVER_START_TIMEOUT_S = 20.0
_SERVER_SHUTDOWN_TIMEOUT_S = 5.0


@contextmanager
def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]:
"""Run a uvicorn server in a background thread with an ephemeral port.

This eliminates the TOCTOU race that occurs when a test picks a free port
with ``socket.bind((host, 0))``, releases it, then starts a server hoping
to rebind the same port — between release and rebind, another pytest-xdist
worker may claim it, causing connection errors or cross-test contamination.

With ``port=0``, the OS atomically assigns a free port at bind time; the
server holds it from that moment until shutdown. We read the actual port
back from uvicorn's bound socket after startup completes.

Args:
app: ASGI application to serve.
**config_kwargs: Additional keyword arguments for :class:`uvicorn.Config`
(e.g. ``log_level``, ``limit_concurrency``). ``host`` defaults to
``127.0.0.1`` and ``port`` is forced to 0.

Yields:
The base URL of the running server, e.g. ``http://127.0.0.1:54321``.

Raises:
TimeoutError: If the server does not start within 20 seconds.
RuntimeError: If the server thread dies during startup.
"""
config_kwargs.setdefault("host", "127.0.0.1")
config_kwargs.setdefault("log_level", "error")
config = uvicorn.Config(app=app, port=0, **config_kwargs)
server = uvicorn.Server(config=config)

thread = threading.Thread(target=server.run, daemon=True)
thread.start()

# uvicorn sets `server.started = True` at the end of `Server.startup()`,
# after sockets are bound and the lifespan startup phase has completed.
start = time.monotonic()
while not server.started:
if time.monotonic() - start > _SERVER_START_TIMEOUT_S: # pragma: no cover
raise TimeoutError(f"uvicorn server failed to start within {_SERVER_START_TIMEOUT_S}s")
if not thread.is_alive(): # pragma: no cover
raise RuntimeError("uvicorn server thread exited during startup")
time.sleep(0.001)

# server.servers[0] is the asyncio.Server; its bound socket has the real port
port = server.servers[0].sockets[0].getsockname()[1]
host = config.host

try:
yield f"http://{host}:{port}"
finally:
server.should_exit = True
thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S)


def wait_for_server(port: int, timeout: float = 20.0) -> None:
Expand Down
Loading