diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..f95b1a74a 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -222,7 +222,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: lax no cover """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 9dcee67f7..cab122443 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,7 +116,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: lax no cover if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -195,7 +195,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: lax no cover logger.debug("Handling POST message") request = Request(scope, receive) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..62140cd36 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -177,7 +177,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: # pragma: lax no cover """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -205,7 +205,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: # pragma: lax no cover """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -240,10 +240,10 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: # pragma: lax no cover self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: # pragma: lax no cover self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -291,7 +291,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: # pragma: lax no cover response_headers.update(headers) if self.mcp_session_id: @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: # pragma: lax no cover event_data["id"] = event_message.event_id return event_data @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # pragma: lax no cover # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: # pragma: lax no cover await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: # pragma: lax no cover response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): # pragma: lax no cover return # For notifications and responses only, return 202 Accepted @@ -659,7 +659,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: # pragma: lax no cover response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -667,11 +667,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover + if not await self._validate_request_headers(request, send): # pragma: lax no cover return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: lax no cover await self._replay_events(last_event_id, request, send) return @@ -685,7 +685,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: # pragma: lax no cover response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -714,7 +714,7 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE writer") finally: logger.debug("Closing standalone SSE writer") @@ -791,7 +791,7 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: lax no cover """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: # pragma: lax no cover response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: # pragma: lax no cover protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: lax no cover supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -865,7 +865,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: lax no cover """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. @@ -991,7 +991,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( # pragma: lax no cover session_message.metadata is not None and isinstance( session_message.metadata, @@ -1015,10 +1015,10 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: # pragma: lax no cover logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..e3009ae62 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -104,13 +104,13 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res return None # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + host = request.headers.get("host") # pragma: lax no cover + if not self._validate_host(host): # pragma: lax no cover + return Response("Invalid Host header", status_code=421) # pragma: lax no cover # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + origin = request.headers.get("origin") # pragma: lax no cover + if not self._validate_origin(origin): # pragma: lax no cover + return Response("Invalid Origin header", status_code=403) # pragma: lax no cover - return None # pragma: no cover + return None # pragma: lax no cover diff --git a/tests/client/test_http_unicode.py b/tests/client/test_http_unicode.py index cc2e14e46..ee105505f 100644 --- a/tests/client/test_http_unicode.py +++ b/tests/client/test_http_unicode.py @@ -4,11 +4,10 @@ (server→client and client→server) using the streamable HTTP transport. """ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import httpx import pytest from starlette.applications import Starlette from starlette.routing import Mount @@ -19,7 +18,6 @@ from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.types import TextContent, Tool -from tests.test_helpers import wait_for_server # Test constants with various Unicode characters UNICODE_TEST_STRINGS = { @@ -41,197 +39,132 @@ } -def run_unicode_server(port: int) -> None: # pragma: no cover - """Run the Unicode test server in a separate process.""" - import uvicorn - - # Need to recreate the server setup in this process - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - Tool( - name="echo_unicode", - description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", - input_schema={ - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to echo back"}, - }, - "required": ["text"], - }, - ), - ] - ) +async def _handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + Tool( + name="echo_unicode", + description="🔤 Echo Unicode text - Hello 👋 World 🌍 - Testing 🧪 Unicode ✨", + input_schema={ + "type": "object", + "properties": {"text": {"type": "string", "description": "Text to echo back"}}, + "required": ["text"], + }, + ), + ] + ) + + +async def _handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + if params.name == "echo_unicode": + text = params.arguments.get("text", "") if params.arguments else "" + return types.CallToolResult(content=[TextContent(type="text", text=f"Echo: {text}")]) + raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover + - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - if params.name == "echo_unicode": - text = params.arguments.get("text", "") if params.arguments else "" - return types.CallToolResult( - content=[ - TextContent( - type="text", - text=f"Echo: {text}", - ) - ] +async def _handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None +) -> types.ListPromptsResult: + return types.ListPromptsResult( + prompts=[ + types.Prompt( + name="unicode_prompt", + description="Unicode prompt - Слой хранилища, где располагаются", + arguments=[], ) - else: - raise ValueError(f"Unknown tool: {params.name}") - - async def handle_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - return types.ListPromptsResult( - prompts=[ - types.Prompt( - name="unicode_prompt", - description="Unicode prompt - Слой хранилища, где располагаются", - arguments=[], + ] + ) + + +async def _handle_get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> types.GetPromptResult: + if params.name == "unicode_prompt": + return types.GetPromptResult( + messages=[ + types.PromptMessage( + role="user", + content=types.TextContent(type="text", text="Hello世界🌍Привет안녕مرحباשלום"), ) ] ) + raise ValueError(f"Unknown prompt: {params.name}") # pragma: no cover - async def handle_get_prompt( - ctx: ServerRequestContext, params: types.GetPromptRequestParams - ) -> types.GetPromptResult: - if params.name == "unicode_prompt": - return types.GetPromptResult( - messages=[ - types.PromptMessage( - role="user", - content=types.TextContent( - type="text", - text="Hello世界🌍Привет안녕مرحباשלום", - ), - ) - ] - ) - raise ValueError(f"Unknown prompt: {params.name}") +def _make_unicode_app() -> Starlette: server = Server( name="unicode_test_server", - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - on_list_prompts=handle_list_prompts, - on_get_prompt=handle_get_prompt, - ) - - # Create the session manager - session_manager = StreamableHTTPSessionManager( - app=server, - json_response=False, # Use SSE for testing + on_list_tools=_handle_list_tools, + on_call_tool=_handle_call_tool, + on_list_prompts=_handle_list_prompts, + on_get_prompt=_handle_get_prompt, ) + session_manager = StreamableHTTPSessionManager(app=server, json_response=False) @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with session_manager.run(): yield - # Create an ASGI application - app = Starlette( + return Starlette( debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], + routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=lifespan, ) - # Run the server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - uvicorn_server = uvicorn.Server(config) - uvicorn_server.run() - - -@pytest.fixture -def unicode_server_port() -> int: - """Find an available port for the Unicode test server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - @pytest.fixture -def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]: - """Start a Unicode test server in a separate process.""" - proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True) - proc.start() - - # Wait for server to be ready - wait_for_server(unicode_server_port) - - try: - yield f"http://127.0.0.1:{unicode_server_port}" - finally: - # Clean up - try graceful termination first - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - proc.kill() - proc.join(timeout=1) +async def unicode_session() -> AsyncGenerator[ClientSession, None]: + """Create an initialized client session connected to the in-process unicode server.""" + app = _make_unicode_app() + async with app.router.lifespan_context(app): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, follow_redirects=True) as http_client: + async with streamable_http_client("http://testserver/mcp", http_client=http_client) as (rs, ws): + async with ClientSession(rs, ws) as session: + await session.initialize() + yield session @pytest.mark.anyio -async def test_streamable_http_client_unicode_tool_call(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_tool_call(unicode_session: ClientSession) -> None: """Test that Unicode text is correctly handled in tool calls via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List tools (server→client Unicode in descriptions) - tools = await session.list_tools() - assert len(tools.tools) == 1 - - # Check Unicode in tool descriptions - echo_tool = tools.tools[0] - assert echo_tool.name == "echo_unicode" - assert echo_tool.description is not None - assert "🔤" in echo_tool.description - assert "👋" in echo_tool.description + # Test 1: List tools (server→client Unicode in descriptions) + tools = await unicode_session.list_tools() + assert len(tools.tools) == 1 - # Test 2: Send Unicode text in tool call (client→server→client) - for test_name, test_string in UNICODE_TEST_STRINGS.items(): - result = await session.call_tool("echo_unicode", arguments={"text": test_string}) + echo_tool = tools.tools[0] + assert echo_tool.name == "echo_unicode" + assert echo_tool.description is not None + assert "🔤" in echo_tool.description + assert "👋" in echo_tool.description - # Verify server correctly received and echoed back Unicode - assert len(result.content) == 1 - content = result.content[0] - assert content.type == "text" - assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" + # Test 2: Send Unicode text in tool call (client→server→client) + for test_name, test_string in UNICODE_TEST_STRINGS.items(): + result = await unicode_session.call_tool("echo_unicode", arguments={"text": test_string}) + assert len(result.content) == 1 + content = result.content[0] + assert content.type == "text" + assert f"Echo: {test_string}" == content.text, f"Failed for {test_name}" @pytest.mark.anyio -async def test_streamable_http_client_unicode_prompts(running_unicode_server: str) -> None: +async def test_streamable_http_client_unicode_prompts(unicode_session: ClientSession) -> None: """Test that Unicode text is correctly handled in prompts via streamable HTTP.""" - base_url = running_unicode_server - endpoint_url = f"{base_url}/mcp" - - async with streamable_http_client(endpoint_url) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - - # Test 1: List prompts (server→client Unicode in descriptions) - prompts = await session.list_prompts() - assert len(prompts.prompts) == 1 - - prompt = prompts.prompts[0] - assert prompt.name == "unicode_prompt" - assert prompt.description is not None - assert "Слой хранилища, где располагаются" in prompt.description - - # Test 2: Get prompt with Unicode content (server→client) - result = await session.get_prompt("unicode_prompt", arguments={}) - assert len(result.messages) == 1 - - message = result.messages[0] - assert message.role == "user" - assert message.content.type == "text" - assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" + # Test 1: List prompts (server→client Unicode in descriptions) + prompts = await unicode_session.list_prompts() + assert len(prompts.prompts) == 1 + + prompt = prompts.prompts[0] + assert prompt.name == "unicode_prompt" + assert prompt.description is not None + assert "Слой хранилища, где располагаются" in prompt.description + + # Test 2: Get prompt with Unicode content (server→client) + result = await unicode_session.get_prompt("unicode_prompt", arguments={}) + assert len(result.messages) == 1 + + message = result.messages[0] + assert message.role == "user" + assert message.content.type == "text" + assert message.content.text == "Hello世界🌍Привет안녕مرحباשלום" diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a2..bd9a174cd 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,155 +1,127 @@ """Tests for SSE server DNS rebinding protection.""" +import contextlib import logging -import multiprocessing -import socket +from collections.abc import Generator import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_server_in_thread + +# Several tests open an SSE stream, check the status code, then exit without +# consuming the stream. When uvicorn shuts down, it cancels the still-running +# SSE handler mid-operation, and SseServerTransport's internal memory streams +# may be GC'd without their cleanup finalizers running. These ResourceWarnings +# are artifacts of the abrupt-disconnect test pattern, not production bugs. +pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") logger = logging.getLogger(__name__) SERVER_NAME = "test_sse_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: - return [] + return [] # pragma: no cover -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the SSE server with specified security settings.""" +def make_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: + """Build a Starlette app with SSE transport and the given security settings.""" app = SecurityTestServer() sse_transport = SseServerTransport("/messages/", security_settings) - async def handle_sse(request: Request): - try: + async def handle_sse(request: Request) -> Response: + # connect_sse sends responses directly via ASGI `send` (both the SSE stream + # and any validation error responses), so by the time we return here the + # response has already been sent. Starlette will still try to send our + # return value, which fails with "Unexpected ASGI message". We suppress + # ValueError from connect_sse and wrap the final Response() send in a + # no-op so Starlette's machinery doesn't conflict. + with contextlib.suppress(ValueError): async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: - if streams: + if streams: # pragma: no branch await app.run(streams[0], streams[1], app.create_initialization_options()) - except ValueError as e: - # Validation error was already handled inside connect_sse - logger.debug(f"SSE connection failed validation: {e}") - return Response() + return _AlreadySentResponse() - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse_transport.handle_post_message), - ] + return Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + ) - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") +class _AlreadySentResponse(Response): + """No-op Response for handlers that already sent via raw ASGI `send`.""" -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + pass -@pytest.mark.anyio -async def test_sse_security_default_settings(server_port: int): - """Test SSE with default security settings (protection disabled).""" - process = start_server_process(server_port) +@pytest.fixture +def server_url() -> Generator[str, None, None]: + """Default-settings server for tests that don't need custom security config.""" + with run_server_in_thread(make_app(), lifespan="off") as url: + yield url - try: - headers = {"Host": "evil.com", "Origin": "http://evil.com"} - async with httpx.AsyncClient(timeout=5.0) as client: - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - assert response.status_code == 200 - finally: - process.terminate() - process.join() +@pytest.mark.anyio +async def test_sse_security_default_settings(server_url: str): + """Test SSE with default security settings (protection disabled).""" + headers = {"Host": "evil.com", "Origin": "http://evil.com"} + async with httpx.AsyncClient(timeout=5.0) as client: + async with client.stream("GET", f"{server_url}/sse", headers=headers) as response: + assert response.status_code == 200 @pytest.mark.anyio -async def test_sse_security_invalid_host_header(server_port: int): +async def test_sse_security_invalid_host_header(): """Test SSE with invalid Host header.""" - # Enable security by providing settings with an empty allowed_hosts list security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = {"Host": "evil.com"} - + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_invalid_origin_header(server_port: int): +async def test_sse_security_invalid_origin_header(): """Test SSE with invalid Origin header.""" - # Configure security to allow the host but restrict origins security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"] ) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = {"Origin": "http://evil.com"} - + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Origin": "http://evil.com"}) assert response.status_code == 403 assert response.text == "Invalid Origin header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_invalid_content_type(server_port: int): +async def test_sse_security_post_invalid_content_type(): """Test POST endpoint with invalid Content-Type header.""" - # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type fake_session_id = "12345678123456781234567812345678" + # Test POST with invalid content type response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{url}/messages/?session_id={fake_session_id}", headers={"Content-Type": "text/plain"}, content="test", ) @@ -157,137 +129,85 @@ async def test_sse_security_post_invalid_content_type(server_port: int): assert response.text == "Invalid Content-Type header" # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" - ) + response = await client.post(f"{url}/messages/?session_id={fake_session_id}", content="test") assert response.status_code == 400 assert response.text == "Invalid Content-Type header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_disabled(server_port: int): +async def test_sse_security_disabled(): """Test SSE with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = {"Host": "evil.com"} - + with run_server_in_thread(make_app(settings), lifespan="off") as url: async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + async with client.stream("GET", f"{url}/sse", headers={"Host": "evil.com"}) as response: # Should connect successfully even with invalid host assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_custom_allowed_hosts(server_port: int): +async def test_sse_security_custom_allowed_hosts(): """Test SSE with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost", "127.0.0.1", "custom.host"], allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_in_thread(make_app(settings), lifespan="off") as url: # Test with custom allowed host - headers = {"Host": "custom.host"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with custom host + async with client.stream("GET", f"{url}/sse", headers={"Host": "custom.host"}) as response: assert response.status_code == 200 # Test with non-allowed host - headers = {"Host": "evil.com"} - async with httpx.AsyncClient() as client: - response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + response = await client.get(f"{url}/sse", headers={"Host": "evil.com"}) assert response.status_code == 421 assert response.text == "Invalid Host header" - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_wildcard_ports(server_port: int): +async def test_sse_security_wildcard_ports(): """Test SSE with wildcard port patterns.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["localhost:*", "127.0.0.1:*"], allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], ) - process = start_server_process(server_port, settings) - - try: + with run_server_in_thread(make_app(settings), lifespan="off") as url: # Test with various port numbers for test_port in [8080, 3000, 9999]: - headers = {"Host": f"localhost:{test_port}"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port + async with client.stream("GET", f"{url}/sse", headers={"Host": f"localhost:{test_port}"}) as response: assert response.status_code == 200 - headers = {"Origin": f"http://localhost:{test_port}"} - async with httpx.AsyncClient(timeout=5.0) as client: - # For SSE endpoints, we need to use stream to avoid timeout - async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: - # Should connect successfully with any port + headers = {"Origin": f"http://localhost:{test_port}"} + async with client.stream("GET", f"{url}/sse", headers=headers) as response: assert response.status_code == 200 - finally: - process.terminate() - process.join() - @pytest.mark.anyio -async def test_sse_security_post_valid_content_type(server_port: int): +async def test_sse_security_post_valid_content_type(): """Test POST endpoint with valid Content-Type headers.""" - # Configure security to allow the host security_settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] ) - process = start_server_process(server_port, security_settings) - - try: + with run_server_in_thread(make_app(security_settings), lifespan="off") as url: async with httpx.AsyncClient() as client: - # Test with various valid content types valid_content_types = [ "application/json", "application/json; charset=utf-8", "application/json;charset=utf-8", "APPLICATION/JSON", # Case insensitive ] - for content_type in valid_content_types: - # Use a valid UUID format (even though session won't exist) fake_session_id = "12345678123456781234567812345678" response = await client.post( - f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + f"{url}/messages/?session_id={fake_session_id}", headers={"Content-Type": content_type}, json={"test": "data"}, ) - # Will get 404 because session doesn't exist, but that's OK - # We're testing that it passes the content-type check + # Will get 404 because session doesn't exist — that means we passed content-type validation assert response.status_code == 404 assert response.text == "Could not find session" - - finally: - process.terminate() - process.join() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..11f75e9a3 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -1,13 +1,10 @@ """Tests for StreamableHTTP server DNS rebinding protection.""" -import multiprocessing -import socket from collections.abc import AsyncGenerator from contextlib import asynccontextmanager import httpx import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import Mount from starlette.types import Receive, Scope, Send @@ -16,36 +13,21 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool -from tests.test_helpers import wait_for_server SERVER_NAME = "test_streamable_http_security_server" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: # pragma: no cover - return f"http://127.0.0.1:{server_port}" - - -class SecurityTestServer(Server): # pragma: no cover +class SecurityTestServer(Server): def __init__(self): super().__init__(SERVER_NAME) async def on_list_tools(self) -> list[Tool]: - return [] + return [] # pragma: no cover -def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover - """Run the StreamableHTTP server with specified security settings.""" +def make_app(security_settings: TransportSecuritySettings | None = None) -> Starlette: + """Build a Starlette app with the given security settings.""" app = SecurityTestServer() - - # Create session manager with security settings session_manager = StreamableHTTPSessionManager( app=app, json_response=False, @@ -53,239 +35,164 @@ def run_server_with_settings(port: int, security_settings: TransportSecuritySett security_settings=security_settings, ) - # Create the ASGI handler async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await session_manager.handle_request(scope, receive, send) - # Create Starlette app with lifespan @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: async with session_manager.run(): yield - routes = [ - Mount("/", app=handle_streamable_http), - ] + return Starlette(routes=[Mount("/", app=handle_streamable_http)], lifespan=lifespan) - starlette_app = Starlette(routes=routes, lifespan=lifespan) - uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") +@asynccontextmanager +async def make_client( + security_settings: TransportSecuritySettings | None = None, +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create an httpx client wired to an in-process ASGI app via ASGITransport. -def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): - """Start server in a separate process.""" - process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) - process.start() - # Wait for server to be ready to accept connections - wait_for_server(port) - return process + StreamableHTTP POST requests return promptly (SSE body then close), so the + ASGITransport buffering behavior is not an issue here. + """ + app = make_app(security_settings) + async with app.router.lifespan_context(app): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver", timeout=5.0) as client: + yield client @pytest.mark.anyio -async def test_streamable_http_security_default_settings(server_port: int): +async def test_streamable_http_security_default_settings(): """Test StreamableHTTP with default security settings (protection enabled).""" - process = start_server_process(server_port) - - try: - # Test with valid localhost headers - async with httpx.AsyncClient(timeout=5.0) as client: - # POST request to initialize session - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - }, - ) - assert response.status_code == 200 - assert "mcp-session-id" in response.headers - - finally: - process.terminate() - process.join() + async with make_client() as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers @pytest.mark.anyio -async def test_streamable_http_security_invalid_host_header(server_port: int): +async def test_streamable_http_security_invalid_host_header(): """Test StreamableHTTP with invalid Host header.""" security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid host header - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - finally: - process.terminate() - process.join() + async with make_client(security_settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 421 + assert response.text == "Invalid Host header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_origin_header(server_port: int): +async def test_streamable_http_security_invalid_origin_header(): """Test StreamableHTTP with invalid Origin header.""" - security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) - process = start_server_process(server_port, security_settings) - - try: - # Test with invalid origin header - headers = { - "Origin": "http://evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - assert response.status_code == 403 - assert response.text == "Invalid Origin header" - - finally: - process.terminate() - process.join() + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["testserver"]) + async with make_client(security_settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Origin": "http://evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 403 + assert response.text == "Invalid Origin header" @pytest.mark.anyio -async def test_streamable_http_security_invalid_content_type(server_port: int): +async def test_streamable_http_security_invalid_content_type(): """Test StreamableHTTP POST with invalid Content-Type header.""" - process = start_server_process(server_port) - - try: - async with httpx.AsyncClient(timeout=5.0) as client: - # Test POST with invalid content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={ - "Content-Type": "text/plain", - "Accept": "application/json, text/event-stream", - }, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - # Test POST with missing content type - response = await client.post( - f"http://127.0.0.1:{server_port}/", - headers={"Accept": "application/json, text/event-stream"}, - content="test", - ) - assert response.status_code == 400 - assert response.text == "Invalid Content-Type header" - - finally: - process.terminate() - process.join() + async with make_client() as client: + # Test POST with invalid content type + response = await client.post( + "/", + headers={ + "Content-Type": "text/plain", + "Accept": "application/json, text/event-stream", + }, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + "/", + headers={"Accept": "application/json, text/event-stream"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" @pytest.mark.anyio -async def test_streamable_http_security_disabled(server_port: int): +async def test_streamable_http_security_disabled(): """Test StreamableHTTP with security disabled.""" settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) - process = start_server_process(server_port, settings) - - try: - # Test with invalid host header - should still work - headers = { - "Host": "evil.com", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully even with invalid host - assert response.status_code == 200 - - finally: - process.terminate() - process.join() + async with make_client(settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Should connect successfully even with invalid host + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_custom_allowed_hosts(server_port: int): +async def test_streamable_http_security_custom_allowed_hosts(): """Test StreamableHTTP with custom allowed hosts.""" settings = TransportSecuritySettings( enable_dns_rebinding_protection=True, - allowed_hosts=["localhost", "127.0.0.1", "custom.host"], - allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], + allowed_hosts=["localhost", "testserver", "custom.host"], + allowed_origins=["http://localhost", "http://testserver", "http://custom.host"], ) - process = start_server_process(server_port, settings) - - try: - # Test with custom allowed host - headers = { - "Host": "custom.host", - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.post( - f"http://127.0.0.1:{server_port}/", - json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, - headers=headers, - ) - # Should connect successfully with custom host - assert response.status_code == 200 - finally: - process.terminate() - process.join() + async with make_client(settings) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Host": "custom.host", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + # Should connect successfully with custom host + assert response.status_code == 200 @pytest.mark.anyio -async def test_streamable_http_security_get_request(server_port: int): +async def test_streamable_http_security_get_request(): """Test StreamableHTTP GET request with security.""" - security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) - process = start_server_process(server_port, security_settings) - - try: + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["testserver"]) + async with make_client(security_settings) as client: # Test GET request with invalid host header - headers = { - "Host": "evil.com", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - assert response.status_code == 421 - assert response.text == "Invalid Host header" - - # Test GET request with valid host header - headers = { - "Host": "127.0.0.1", - "Accept": "text/event-stream", - } - - async with httpx.AsyncClient(timeout=5.0) as client: - # GET requests need a session ID in StreamableHTTP - # So it will fail with "Missing session ID" not security error - response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) - # This should pass security but fail on session validation - assert response.status_code == 400 - body = response.json() - assert "Missing session ID" in body["error"]["message"] - - finally: - process.terminate() - process.join() + response = await client.get("/", headers={"Host": "evil.com", "Accept": "text/event-stream"}) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + # Test GET request with valid host header but no session ID + # Should pass security but fail on session validation + response = await client.get("/", headers={"Host": "testserver", "Accept": "text/event-stream"}) + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" in body["error"]["message"] diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 890e99733..6f9cecd86 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,6 +1,4 @@ import json -import multiprocessing -import socket from collections.abc import AsyncGenerator, Generator from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -9,7 +7,6 @@ import anyio import httpx import pytest -import uvicorn from httpx_sse import ServerSentEvent from inline_snapshot import snapshot from starlette.applications import Starlette @@ -41,31 +38,24 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server - -SERVER_NAME = "test_server_for_SSE" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +from tests.test_helpers import run_server_in_thread +# When SSE clients disconnect abruptly (exiting sse_client context while the +# server's long-lived SSE stream is open), uvicorn cancels the server handler +# mid-operation and SseServerTransport's internal memory streams may be GC'd +# without their finalizers running. This is a test-lifecycle artifact of +# abrupt disconnect, not a production bug — real clients consume the stream. +pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" +SERVER_NAME = "test_server_for_SSE" -async def _handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: +async def _handle_read_resource(ctx: ServerRequestContext, params: ReadResourceRequestParams) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": + elif parsed.scheme == "slow": # pragma: no cover await anyio.sleep(2.0) text = f"Slow response from {parsed.netloc}" else: @@ -73,39 +63,15 @@ async def _handle_read_resource( # pragma: no cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - - -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover +def _create_server() -> Server: return Server( SERVER_NAME, on_read_resource=_handle_read_resource, - on_list_tools=_handle_list_tools, - on_call_tool=_handle_call_tool, ) -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover +def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" - # Configure security with allowed hosts/origins for testing security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) @@ -117,47 +83,25 @@ async def handle_sse(request: Request) -> Response: await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - @pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") +def server() -> Generator[str, None, None]: + """Run the basic SSE server in a background thread, yielding its URL.""" + with run_server_in_thread(make_server_app(), lifespan="off") as url: + yield url @pytest.fixture() -async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server: str) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url=server_url) as client: + async with httpx.AsyncClient(base_url=server) as client: yield client @@ -188,8 +132,8 @@ async def connection_test() -> None: @pytest.mark.anyio -async def test_sse_client_basic_connection(server: None, server_url: str) -> None: - async with sse_client(server_url + "/sse") as streams: +async def test_sse_client_basic_connection(server: str) -> None: + async with sse_client(server + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -202,10 +146,10 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio -async def test_sse_client_on_session_created(server: None, server_url: str) -> None: +async def test_sse_client_on_session_created(server: str) -> None: captured: list[str] = [] - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client(server + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -231,7 +175,7 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non @pytest.mark.anyio async def test_sse_client_on_session_created_not_called_when_no_session_id( - server: None, server_url: str, monkeypatch: pytest.MonkeyPatch + server: str, monkeypatch: pytest.MonkeyPatch ) -> None: callback_mock = Mock() @@ -240,7 +184,7 @@ def mock_extract(url: str) -> None: monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) - async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: + async with sse_client(server + "/sse", on_session_created=callback_mock) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -250,8 +194,8 @@ def mock_extract(url: str) -> None: @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: +async def initialized_sse_client_session(server: str) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @@ -297,37 +241,18 @@ async def test_sse_client_timeout( # pragma: no cover pytest.fail("the client should have timed out and returned an error already") -def run_mounted_server(server_port: int) -> None: # pragma: no cover +@pytest.fixture() +def mounted_server() -> Generator[str, None, None]: + """Run the SSE server mounted under a sub-path, yielding its base URL.""" app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() - - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) - - yield - - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") + with run_server_in_thread(main_app, lifespan="off") as url: + yield url @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: - async with sse_client(server_url + "/mounted_app/sse") as streams: +async def test_sse_client_basic_connection_mounted_app(mounted_server: str) -> None: + async with sse_client(mounted_server + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -339,26 +264,22 @@ async def test_sse_client_basic_connection_mounted_app(mounted_server: None, ser assert isinstance(ping_result, EmptyResult) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - headers_info: dict[str, Any] = {} - if ctx.request: - headers_info = dict(ctx.request.headers) +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert ctx.request is not None + headers_info = dict(ctx.request.headers) if params.name == "echo_headers": return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) - elif params.name == "echo_context": - context_data = { - "request_id": (params.arguments or {}).get("request_id"), - "headers": headers_info, - } - return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) + assert params.name == "echo_context" + context_data = { + "request_id": (params.arguments or {}).get("request_id"), + "headers": headers_info, + } + return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -381,9 +302,8 @@ async def _handle_context_list_tools( # pragma: no cover ) -def run_context_server(server_port: int) -> None: # pragma: no cover - """Run a server that captures request context""" - # Configure security with allowed hosts/origins for testing +def make_context_server_app() -> Starlette: + """Create a Starlette app with an SSE server that echoes request context.""" security_settings = TransportSecuritySettings( allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] ) @@ -399,73 +319,47 @@ async def handle_sse(request: Request) -> Response: await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) return Response() - app = Starlette( + return Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] ) - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting context server on {server_port}") - server.run() - @pytest.fixture() -def context_server(server_port: int) -> Generator[None, None, None]: - """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) - print("starting context server process") - proc.start() - - # Wait for server to be running - print("waiting for context server to start") - wait_for_server(server_port) - - yield - - print("killing context server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("context server process failed to terminate") +def context_server() -> Generator[str, None, None]: + """Run the context-echoing SSE server in a background thread, yielding its URL.""" + with run_server_in_thread(make_context_server_app(), lifespan="off") as url: + yield url @pytest.mark.anyio -async def test_request_context_propagation(context_server: None, server_url: str) -> None: +async def test_request_context_propagation(context_server: str) -> None: """Test that request context is properly propagated through SSE transport.""" - # Test with custom headers custom_headers = { "Authorization": "Bearer test-token", "X-Custom-Header": "test-value", "X-Trace-Id": "trace-123", } - async with sse_client(server_url + "/sse", headers=custom_headers) as ( - read_stream, - write_stream, - ): + async with sse_client(context_server + "/sse", headers=custom_headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: - # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) - # Call the tool that echoes headers back tool_result = await session.call_tool("echo_headers", {}) - # Parse the JSON response - assert len(tool_result.content) == 1 headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") - # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" assert headers_data.get("x-custom-header") == "test-value" assert headers_data.get("x-trace-id") == "trace-123" @pytest.mark.anyio -async def test_request_context_isolation(context_server: None, server_url: str) -> None: +async def test_request_context_isolation(context_server: str) -> None: """Test that request contexts are isolated between different SSE clients.""" contexts: list[dict[str, Any]] = [] @@ -473,14 +367,10 @@ async def test_request_context_isolation(context_server: None, server_url: str) for i in range(3): headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} - async with sse_client(server_url + "/sse", headers=headers) as ( - read_stream, - write_stream, - ): + async with sse_client(context_server + "/sse", headers=headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - # Call the tool that echoes context tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 @@ -489,7 +379,6 @@ async def test_request_context_isolation(context_server: None, server_url: str) ) contexts.append(context_data) - # Verify each request had its own context assert len(contexts) == 3 for i, ctx in enumerate(contexts): assert ctx["request_id"] == f"request-{i}" @@ -611,7 +500,7 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: @pytest.mark.anyio -async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) -> None: +async def test_sse_session_cleanup_on_disconnect(server: str) -> None: """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1227 When a client disconnects, the server should remove the session from @@ -622,7 +511,7 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) captured: list[str] = [] # Connect a client session, then disconnect - async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: + async with sse_client(server + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: await session.initialize() @@ -630,7 +519,7 @@ async def test_sse_session_cleanup_on_disconnect(server: None, server_url: str) # (not 202 as it did before the fix) async with httpx.AsyncClient() as client: response = await client.post( - f"{server_url}/messages/?session_id={captured[0]}", + f"{server}/messages/?session_id={captured[0]}", json={"jsonrpc": "2.0", "method": "ping", "id": 99}, headers={"Content-Type": "application/json"}, ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..374fc5fbb 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,10 +6,7 @@ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -21,7 +18,6 @@ import httpx import pytest import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -65,7 +61,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_server_in_thread # Test constants SERVER_NAME = "test_streamable_http_server" @@ -108,7 +104,7 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( # pragma: lax no cover self, last_event_id: EventId, send_callback: EventCallback, @@ -144,11 +140,11 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: lax no cover yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) @@ -163,7 +159,7 @@ async def _handle_read_resource( # pragma: no cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -228,7 +224,7 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover +async def _handle_call_tool( # pragma: lax no cover ctx: ServerRequestContext[ServerState], params: CallToolRequestParams ) -> CallToolResult: name = params.name @@ -382,7 +378,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: # pragma: lax no cover return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -396,7 +392,7 @@ def create_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover +) -> Starlette: # pragma: lax no cover """Create a Starlette application for testing using the session manager. Args: @@ -431,74 +427,18 @@ def create_app( return app -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. - """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, - ) - - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests +# Test fixtures — uvicorn in a background thread with port=0 (no port races) @pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] +def basic_server() -> Generator[str, None, None]: + """Start a basic server. Yields the server URL.""" + with run_server_in_thread(create_app()) as url: + yield url @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def basic_server_url(basic_server: str) -> str: + """Alias for basic_server (kept for test signature compatibility).""" + return basic_server @pytest.fixture @@ -508,69 +448,32 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store and retry_interval enabled. - # Clean up - proc.kill() - proc.join(timeout=2) + Yields (event_store, server_url). Unlike the old multiprocessing fixture, the + event_store is now the SAME object used by the server (same process), so tests + can inspect server-side state directly if needed. + """ + with run_server_in_thread(create_app(event_store=event_store, retry_interval=500)) as url: + yield event_store, url @pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +def json_response_server() -> Generator[str, None, None]: + """Start a server with JSON response enabled. Yields the server URL.""" + with run_server_in_thread(create_app(is_json_response_enabled=True)) as url: + yield url @pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +def json_server_url(json_response_server: str) -> str: + """Alias for json_response_server (kept for test signature compatibility).""" + return json_response_server # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): +def test_accept_header_validation(basic_server: str, basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +498,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_wildcard(basic_server: str, basic_server_url: str, accept_header: str): """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +519,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_incompatible(basic_server: str, basic_server_url: str, accept_header: str): """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +533,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str): +def test_content_type_validation(basic_server: str, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +549,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): +def test_json_validation(basic_server: str, basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +564,7 @@ def test_json_validation(basic_server: None, basic_server_url: str): assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str): +def test_json_parsing(basic_server: str, basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +579,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str): assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str): +def test_method_not_allowed(basic_server: str, basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +594,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str): +def test_session_validation(basic_server: str, basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -766,7 +669,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): +def test_session_termination(basic_server: str, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +709,7 @@ def test_session_termination(basic_server: None, basic_server_url: str): assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str): +def test_response(basic_server: str, basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +744,7 @@ def test_response(basic_server: None, basic_server_url: str): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str): +def test_json_response(json_response_server: str, json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +759,7 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): +def test_json_response_accept_json_only(json_response_server: str, json_server_url: str): """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +774,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_missing_accept_header(json_response_server: str, json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +791,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_incorrect_accept_header(json_response_server: str, json_server_url: str): """Test that json_response servers reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,7 +815,7 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): +def test_json_response_wildcard_accept_header(json_response_server: str, json_server_url: str, accept_header: str): """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -927,7 +830,7 @@ def test_json_response_wildcard_accept_header(json_response_server: None, json_s assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str): +def test_get_sse_stream(basic_server: str, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -987,7 +890,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str): +def test_get_validation(basic_server: str, basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1044,14 +947,14 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client(basic_server: str, basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_server: str, basic_server_url: str): """Create initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1060,7 +963,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server: str, basic_server_url: str): """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1105,7 +1008,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server: str, basic_server_url: str): """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1126,7 +1029,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_response_server: str, json_server_url: str): """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1147,7 +1050,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_get_stream(basic_server: str, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1198,7 +1101,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server: str, basic_server_url: str): """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1234,7 +1137,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch + basic_server: str, basic_server_url: str, monkeypatch: pytest.MonkeyPatch ): """Test client session termination functionality with a 204 response. @@ -1412,7 +1315,7 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): +async def test_streamablehttp_server_sampling(basic_server: str, basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1462,7 +1365,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( # pragma: lax no cover ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1487,7 +1390,7 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover +async def _handle_context_call_tool( # pragma: lax no cover ctx: ServerRequestContext, params: CallToolRequestParams ) -> CallToolResult: name = params.name @@ -1516,59 +1419,30 @@ async def _handle_context_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def create_context_aware_app() -> Starlette: + """Create the context-aware test server app.""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, on_call_tool=_handle_context_call_tool, ) - - session_manager = StreamableHTTPSessionManager( - app=server, - event_store=None, - json_response=False, - ) - - app = Starlette( + session_manager = StreamableHTTPSessionManager(app=server, event_store=None, json_response=False) + return Starlette( debug=True, - routes=[ - Mount("/mcp", app=session_manager.handle_request), - ], + routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - @pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") +def context_aware_server() -> Generator[str, None, None]: + """Start the context-aware server. Yields the server URL.""" + with run_server_in_thread(create_context_aware_app()) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1577,7 +1451,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1601,7 +1475,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1614,7 +1488,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1639,9 +1513,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init(context_aware_server: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -1659,7 +1533,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): +def test_server_validates_protocol_version_header(basic_server: str, basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1717,7 +1591,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): +def test_server_backwards_compatibility_no_protocol_version(basic_server: str, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1747,7 +1621,7 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): +async def test_client_crash_handled(basic_server: str, basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init @@ -2219,9 +2093,7 @@ async def message_handler( @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_does_not_mutate_provided_client(basic_server: str, basic_server_url: str) -> None: """Test that streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { @@ -2252,9 +2124,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults(context_aware_server: str) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests @@ -2263,7 +2133,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2283,9 +2156,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_aware_server: str) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", @@ -2294,7 +2165,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_aware_server}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269f..98a901207 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,16 @@ """Common test utilities for MCP server tests.""" +import contextlib +import gc import socket +import threading import time +import warnings +from collections.abc import Generator +from typing import Literal + +import uvicorn +from starlette.types import ASGIApp def wait_for_server(port: int, timeout: float = 20.0) -> None: @@ -29,3 +38,59 @@ def wait_for_server(port: int, timeout: float = 20.0) -> None: # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover + + +@contextlib.contextmanager +def run_server_in_thread(app: ASGIApp, lifespan: Literal["auto", "on", "off"] = "on") -> Generator[str, None, None]: + """Run a Starlette/ASGI app in a uvicorn server on a background thread. + + Uses `port=0` so the kernel atomically assigns an available port, eliminating + the TOCTOU port-allocation race that affects subprocess-based fixtures. The + actual bound port is read back from the server's socket after binding. + + Unlike multiprocessing, this runs in-process so: + - No port race (port=0 is assigned atomically at bind time) + - No pickling of app/state (the app runs in the same process) + - Faster startup (no fork/exec overhead) + - Works with both asyncio and trio test backends (uvicorn runs its own + asyncio loop in the thread; uvicorn skips signal handlers automatically + when not on the main thread) + + Args: + app: The ASGI application to serve. + lifespan: uvicorn lifespan mode — "on" to run app lifespan events, + "off" to skip them (default "on"). + + Yields: + Base URL of the running server (e.g., "http://127.0.0.1:54321"). + """ + config = uvicorn.Config(app=app, host="127.0.0.1", port=0, log_level="error", lifespan=lifespan) + server = uvicorn.Server(config=config) + + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for uvicorn to bind and start accepting connections + start_time = time.time() + while not server.started: + if time.time() - start_time > 20.0: # pragma: no cover + raise TimeoutError("uvicorn server did not start within 20 seconds") + time.sleep(0.01) + + # Read back the kernel-assigned port from the bound socket + port = server.servers[0].sockets[0].getsockname()[1] + try: + yield f"http://127.0.0.1:{port}" + finally: + server.should_exit = True + thread.join(timeout=5) + # When uvicorn shuts down with in-flight SSE connections, the server + # cancels request handlers mid-operation. SseServerTransport's internal + # memory streams may not get their `finally` cleanup run before GC, + # causing ResourceWarnings. These are artifacts of test abrupt-disconnect + # patterns (open SSE stream → check status → exit without consuming), + # not bugs. Force GC here and suppress the warnings so they don't leak + # into the next test's PytestUnraisableExceptionWarning collector. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + gc.collect()