diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py index 0857e38bd..68dfe0803 100644 --- a/src/mcp/server/mcpserver/__init__.py +++ b/src/mcp/server/mcpserver/__init__.py @@ -3,7 +3,8 @@ from mcp.types import Icon from .context import Context +from .exceptions import PromptError, ResourceError, ToolError from .server import MCPServer from .utilities.types import Audio, Image -__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon"] +__all__ = ["MCPServer", "Context", "Image", "Audio", "Icon", "ToolError", "ResourceError", "PromptError"] diff --git a/src/mcp/server/mcpserver/exceptions.py b/src/mcp/server/mcpserver/exceptions.py index dd1b75e82..a6860b96b 100644 --- a/src/mcp/server/mcpserver/exceptions.py +++ b/src/mcp/server/mcpserver/exceptions.py @@ -17,5 +17,9 @@ class ToolError(MCPServerError): """Error in tool operations.""" +class PromptError(MCPServerError): + """Error in prompt operations.""" + + class InvalidSignature(Exception): """Invalid signature for use with MCPServer.""" diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53c..ce0cc604a 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -9,8 +9,10 @@ import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call +from mcp.server.mcpserver.exceptions import PromptError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared.exceptions import MCPError from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: @@ -141,7 +143,7 @@ async def render( """Render the prompt with arguments. Raises: - ValueError: If required arguments are missing, or if rendering fails. + PromptError: If required arguments are missing, or if rendering fails. """ # Validate required arguments if self.arguments: @@ -149,7 +151,7 @@ async def render( provided = set(arguments or {}) missing = required - provided if missing: - raise ValueError(f"Missing required arguments: {missing}") + raise PromptError(f"Missing required arguments: {missing}") try: # Add context to arguments if needed @@ -182,5 +184,7 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages + except (PromptError, MCPError): # pragma: no cover + raise except Exception as e: # pragma: no cover raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/server/mcpserver/prompts/manager.py b/src/mcp/server/mcpserver/prompts/manager.py index 28a7a6e98..a07b2dda1 100644 --- a/src/mcp/server/mcpserver/prompts/manager.py +++ b/src/mcp/server/mcpserver/prompts/manager.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any +from mcp.server.mcpserver.exceptions import PromptError from mcp.server.mcpserver.prompts.base import Message, Prompt from mcp.server.mcpserver.utilities.logging import get_logger @@ -54,6 +55,6 @@ async def render_prompt( """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: - raise ValueError(f"Unknown prompt: {name}") + raise PromptError(f"Unknown prompt: {name}") return await prompt.render(arguments, context) diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 42aecd6e3..e17557fb8 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -13,7 +13,9 @@ import pydantic_core from pydantic import Field, ValidationInfo, validate_call +from mcp.server.mcpserver.exceptions import ResourceError from mcp.server.mcpserver.resources.base import Resource +from mcp.shared.exceptions import MCPError from mcp.types import Annotations, Icon @@ -69,6 +71,8 @@ async def read(self) -> str | bytes: return result else: return pydantic_core.to_json(result, fallback=str, indent=2).decode() + except (ResourceError, MCPError): + raise except Exception as e: raise ValueError(f"Error reading resource {self.uri}: {e}") diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..16311fe37 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -31,7 +31,7 @@ from mcp.server.lowlevel.server import LifespanResultT, Server from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.mcpserver.context import Context -from mcp.server.mcpserver.exceptions import ResourceError +from mcp.server.mcpserver.exceptions import PromptError, ResourceError, ToolError from mcp.server.mcpserver.prompts import Prompt, PromptManager from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager from mcp.server.mcpserver.tools import Tool, ToolManager @@ -44,6 +44,8 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, Annotations, BlobResourceContents, CallToolRequestParams, @@ -303,8 +305,14 @@ async def _handle_call_tool( result = await self.call_tool(params.name, params.arguments or {}, context) except MCPError: raise - except Exception as e: + except ToolError as e: return CallToolResult(content=[TextContent(type="text", text=str(e))], is_error=True) + except Exception: + logger.exception(f"Unhandled error in tool {params.name}") + return CallToolResult( + content=[TextContent(type="text", text=f"Internal error executing tool {params.name}")], + is_error=True, + ) if isinstance(result, CallToolResult): return result if isinstance(result, tuple) and len(result) == 2: @@ -332,7 +340,16 @@ async def _handle_read_resource( self, ctx: ServerRequestContext[LifespanResultT], params: ReadResourceRequestParams ) -> ReadResourceResult: context = Context(request_context=ctx, mcp_server=self) - results = await self.read_resource(params.uri, context) + try: + results = await self.read_resource(params.uri, context) + except MCPError: + raise + except ResourceError as e: + raise MCPError(code=INVALID_PARAMS, message=str(e)) + except Exception: + logger.exception(f"Unhandled error reading resource {params.uri}") + raise MCPError(code=INTERNAL_ERROR, message=f"Internal error reading resource {params.uri}") + contents: list[TextResourceContents | BlobResourceContents] = [] for item in results: if isinstance(item.content, bytes): @@ -369,7 +386,15 @@ async def _handle_get_prompt( self, ctx: ServerRequestContext[LifespanResultT], params: GetPromptRequestParams ) -> GetPromptResult: context = Context(request_context=ctx, mcp_server=self) - return await self.get_prompt(params.name, params.arguments, context) + try: + return await self.get_prompt(params.name, params.arguments, context) + except MCPError: + raise + except PromptError as e: + raise MCPError(code=INVALID_PARAMS, message=str(e)) + except Exception: + logger.exception(f"Unhandled error in prompt {params.name}") + raise MCPError(code=INTERNAL_ERROR, message=f"Internal error getting prompt {params.name}") async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -444,9 +469,10 @@ async def read_resource( try: content = await resource.read() return [ReadResourceContents(content=content, mime_type=resource.mime_type, meta=resource.meta)] + except (ResourceError, MCPError): + raise except Exception as exc: - logger.exception(f"Error getting resource {uri}") - # If an exception happens when reading the resource, we should not leak the exception to the client. + logger.exception(f"Error reading resource {uri}") raise ResourceError(f"Error reading resource {uri}") from exc def add_tool( @@ -1090,7 +1116,7 @@ async def get_prompt( try: prompt = self._prompt_manager.get_prompt(name) if not prompt: - raise ValueError(f"Unknown prompt: {name}") + raise PromptError(f"Unknown prompt: {name}") messages = await prompt.render(arguments, context) @@ -1098,6 +1124,8 @@ async def get_prompt( description=prompt.description, messages=pydantic_core.to_jsonable_python(messages), ) + except (PromptError, MCPError): + raise except Exception as e: logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) + raise PromptError(f"Error getting prompt {name}") from e diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be988..258c8effe 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -11,7 +11,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata -from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -112,12 +112,14 @@ async def run( result = self.fn_metadata.convert_result(result) return result - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # as an MCP error response with code -32042 + except (UrlElicitationRequiredError, MCPError, ToolError): + # Re-raise framework and user-raised exceptions without wrapping. + # - UrlElicitationRequiredError → MCP error response (code -32042) + # - MCPError → JSON-RPC error response + # - ToolError → CallToolResult(is_error=True) raise except Exception as e: - raise ToolError(f"Error executing tool {self.name}: {e}") from e + raise ToolError(f"Error executing tool {self.name}") from e def _is_async_callable(obj: Any) -> bool: diff --git a/tests/server/mcpserver/prompts/test_base.py b/tests/server/mcpserver/prompts/test_base.py index fe18e91bd..636fa14e1 100644 --- a/tests/server/mcpserver/prompts/test_base.py +++ b/tests/server/mcpserver/prompts/test_base.py @@ -3,6 +3,7 @@ import pytest from mcp.server.mcpserver import Context +from mcp.server.mcpserver.exceptions import PromptError from mcp.server.mcpserver.prompts.base import AssistantMessage, Message, Prompt, UserMessage from mcp.types import EmbeddedResource, TextContent, TextResourceContents @@ -44,7 +45,7 @@ async def fn(name: str, age: int = 30) -> str: # pragma: no cover return f"Hello, {name}! You're {age} years old." prompt = Prompt.from_function(fn) - with pytest.raises(ValueError): + with pytest.raises(PromptError): await prompt.render({"age": 40}, Context()) @pytest.mark.anyio diff --git a/tests/server/mcpserver/prompts/test_manager.py b/tests/server/mcpserver/prompts/test_manager.py index 99a03db56..96d3ae09f 100644 --- a/tests/server/mcpserver/prompts/test_manager.py +++ b/tests/server/mcpserver/prompts/test_manager.py @@ -1,6 +1,7 @@ import pytest from mcp.server.mcpserver import Context +from mcp.server.mcpserver.exceptions import PromptError from mcp.server.mcpserver.prompts.base import Prompt, UserMessage from mcp.server.mcpserver.prompts.manager import PromptManager from mcp.types import TextContent @@ -93,7 +94,7 @@ def fn(name: str) -> str: async def test_render_unknown_prompt(self): """Test rendering a non-existent prompt.""" manager = PromptManager() - with pytest.raises(ValueError, match="Unknown prompt: unknown"): + with pytest.raises(PromptError, match="Unknown prompt: unknown"): await manager.render_prompt("unknown", None, Context()) @pytest.mark.anyio @@ -106,5 +107,5 @@ def fn(name: str) -> str: # pragma: no cover manager = PromptManager() prompt = Prompt.from_function(fn) manager.add_prompt(prompt) - with pytest.raises(ValueError, match="Missing required arguments"): + with pytest.raises(PromptError, match="Missing required arguments"): await manager.render_prompt("fn", None, Context()) diff --git a/tests/server/mcpserver/test_error_handling.py b/tests/server/mcpserver/test_error_handling.py new file mode 100644 index 000000000..2a0a3bde3 --- /dev/null +++ b/tests/server/mcpserver/test_error_handling.py @@ -0,0 +1,150 @@ +"""Tests for MCPServer raise-based error handling. + +Validates that MCPServer handlers support a consistent raise-based error pattern: +- ToolError → CallToolResult(is_error=True) with the user's message +- ResourceError / PromptError → MCPError (JSON-RPC error) with the user's message +- MCPError → re-raised as-is (protocol-level error) +- Unexpected exceptions → sanitized message (no internal detail leakage) +""" + +import pytest + +from mcp.client import Client +from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver.exceptions import PromptError, ResourceError, ToolError +from mcp.shared.exceptions import MCPError +from mcp.types import INVALID_PARAMS, TextContent + +pytestmark = pytest.mark.anyio + + +# --------------------------------------------------------------------------- +# Tool error handling +# --------------------------------------------------------------------------- + + +class TestToolErrorHandling: + async def test_tool_error_reaches_client(self) -> None: + """User raises ToolError → client sees CallToolResult(is_error=True) with exact message.""" + mcp = MCPServer() + + @mcp.tool() + def fail_tool() -> str: + raise ToolError("invalid input") + + async with Client(mcp) as client: + result = await client.call_tool("fail_tool", {}) + assert result.is_error is True + content = result.content[0] + assert isinstance(content, TextContent) + assert "invalid input" in content.text + + async def test_unexpected_exception_does_not_leak(self) -> None: + """Plain exception should NOT leak internal details to client.""" + mcp = MCPServer() + + @mcp.tool() + def secret_fail() -> str: + raise RuntimeError("secret database password is hunter2") + + async with Client(mcp) as client: + result = await client.call_tool("secret_fail", {}) + assert result.is_error is True + content = result.content[0] + assert isinstance(content, TextContent) + # Internal details must not reach the client + assert "hunter2" not in content.text + assert "secret_fail" in content.text + + async def test_mcp_error_from_tool_becomes_jsonrpc_error(self) -> None: + """MCPError raised in a tool → JSON-RPC error (not CallToolResult).""" + mcp = MCPServer() + + @mcp.tool() + def protocol_fail() -> str: + raise MCPError(code=INVALID_PARAMS, message="bad params") + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="bad params"): + await client.call_tool("protocol_fail", {}) + + +# --------------------------------------------------------------------------- +# Resource error handling +# --------------------------------------------------------------------------- + + +class TestResourceErrorHandling: + async def test_resource_error_reaches_client(self) -> None: + """User raises ResourceError → client sees MCPError with the user's message.""" + mcp = MCPServer() + + @mcp.resource("resource://guarded") + def guarded_resource() -> str: + raise ResourceError("access denied") + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="access denied"): + await client.read_resource("resource://guarded") + + async def test_unexpected_resource_error_does_not_leak(self) -> None: + """Plain exception from resource should NOT leak internal details.""" + mcp = MCPServer() + + @mcp.resource("resource://broken") + def broken_resource() -> str: + raise RuntimeError("secret internal state") + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("resource://broken") + # Internal details must not reach the client + assert "secret internal state" not in exc_info.value.message + assert "resource://broken" in exc_info.value.message + + async def test_unknown_resource_error(self) -> None: + """Reading a non-existent resource → MCPError.""" + mcp = MCPServer() + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="Unknown resource"): + await client.read_resource("resource://nonexistent") + + +# --------------------------------------------------------------------------- +# Prompt error handling +# --------------------------------------------------------------------------- + + +class TestPromptErrorHandling: + async def test_prompt_error_reaches_client(self) -> None: + """User raises PromptError → client sees MCPError with the user's message.""" + mcp = MCPServer() + + @mcp.prompt() + def bad_prompt() -> str: + raise PromptError("invalid context") + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="invalid context"): + await client.get_prompt("bad_prompt") + + async def test_unknown_prompt_error(self) -> None: + """Getting a non-existent prompt → MCPError.""" + mcp = MCPServer() + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="Unknown prompt"): + await client.get_prompt("nonexistent") + + async def test_missing_prompt_args_error(self) -> None: + """Missing required prompt arguments → MCPError.""" + mcp = MCPServer() + + @mcp.prompt() + def greeting(name: str) -> str: # pragma: no cover + return f"Hello, {name}!" + + async with Client(mcp) as client: + with pytest.raises(MCPError, match="Missing required arguments"): + await client.get_prompt("greeting") diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 3d130bfc3..436c86878 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -258,7 +258,8 @@ async def test_tool_exception_handling(self): assert len(result.content) == 1 content = result.content[0] assert isinstance(content, TextContent) - assert "Test error" in content.text + # Unexpected exceptions don't leak internal details to client + assert "error_tool_fn" in content.text assert result.is_error is True async def test_tool_error_handling(self): @@ -269,7 +270,7 @@ async def test_tool_error_handling(self): assert len(result.content) == 1 content = result.content[0] assert isinstance(content, TextContent) - assert "Test error" in content.text + assert "error_tool_fn" in content.text assert result.is_error is True async def test_tool_error_details(self): @@ -281,7 +282,7 @@ async def test_tool_error_details(self): content = result.content[0] assert isinstance(content, TextContent) assert isinstance(content.text, str) - assert "Test error" in content.text + assert "error_tool_fn" in content.text assert result.is_error is True async def test_tool_return_value_conversion(self): diff --git a/tests/server/mcpserver/test_url_elicitation_error_throw.py b/tests/server/mcpserver/test_url_elicitation_error_throw.py index 2d2993799..9b4dc0ab2 100644 --- a/tests/server/mcpserver/test_url_elicitation_error_throw.py +++ b/tests/server/mcpserver/test_url_elicitation_error_throw.py @@ -93,7 +93,11 @@ async def multi_auth(ctx: Context[ServerSession, None]) -> str: @pytest.mark.anyio async def test_normal_exceptions_still_return_error_result(): - """Test that normal exceptions still return CallToolResult with is_error=True.""" + """Test that normal exceptions still return CallToolResult with is_error=True. + + Unexpected exceptions are sanitized: the client sees a generic message + containing the tool name, not the raw exception details. + """ mcp = MCPServer(name="NormalErrorServer") @mcp.tool(description="A tool that raises a normal exception") @@ -106,4 +110,5 @@ async def failing_tool(ctx: Context[ServerSession, None]) -> str: assert result.is_error is True assert len(result.content) == 1 assert isinstance(result.content[0], types.TextContent) - assert "Something went wrong" in result.content[0].text + # Unexpected exceptions are sanitized — internal details not leaked + assert "failing_tool" in result.content[0].text