diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index e9156f7ba..387c4d3f6 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -91,7 +91,7 @@ async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) Returns: The generated event ID for the stored event """ - pass # pragma: no cover + pass # pragma: lax no cover @abstractmethod async def replay_events_after( @@ -108,7 +108,7 @@ async def replay_events_after( Returns: The stream ID of the replayed events """ - pass # pragma: no cover + pass # pragma: lax no cover class StreamableHTTPServerTransport: @@ -175,7 +175,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: # pragma: lax no cover """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -203,7 +203,7 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: # pragma: lax no cover """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -238,10 +238,10 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: # pragma: lax no cover self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: # pragma: lax no cover self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -289,7 +289,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: # pragma: lax no cover response_headers.update(headers) if self.mcp_session_id: @@ -328,11 +328,11 @@ def _create_json_response( headers=response_headers, ) - def _get_session_id(self, request: Request) -> str | None: # pragma: no cover + def _get_session_id(self, request: Request) -> str | None: # pragma: lax no cover """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: lax no cover """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", @@ -352,7 +352,7 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # Close the request stream await self._request_streams[request_id][0].aclose() await self._request_streams[request_id][1].aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # During cleanup, we catch all exceptions since streams might be in various states logger.debug("Error closing memory streams - may already be closed") finally: @@ -370,7 +370,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # pragma: lax no cover # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -381,20 +381,32 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No if request.method == "POST": await self._handle_post_request(scope, request, receive, send) - elif request.method == "GET": # pragma: no cover + elif request.method == "GET": # pragma: lax no cover await self._handle_get_request(request, send) - elif request.method == "DELETE": # pragma: no cover + elif request.method == "DELETE": # pragma: lax no cover await self._handle_delete_request(request, send) - else: # pragma: no cover + else: # pragma: lax no cover await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" + """Check if the request accepts the required media types. + + Supports wildcard media types per RFC 9110 Section 12.5.1: + - */* matches any media type + - application/* matches any application subtype (e.g., application/json) + - text/* matches any text subtype (e.g., text/event-stream) + """ accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] + # Strip quality parameters (e.g., ";q=0.9") before matching + accept_types = [media_type.strip().split(";")[0].strip() for media_type in accept_header.split(",")] - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) or media_type in {"*/*", "application/*"} + for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) or media_type in {"*/*", "text/*"} for media_type in accept_types + ) return has_json, has_sse @@ -430,7 +442,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer - if writer is None: # pragma: no cover + if writer is None: # pragma: lax no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") try: # Validate Accept header @@ -438,7 +450,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # Validate Content-Type - if not self._check_content_type(request): # pragma: no cover + if not self._check_content_type(request): # pragma: lax no cover response = self._create_error_response( "Unsupported Media Type: Content-Type must be application/json", HTTPStatus.UNSUPPORTED_MEDIA_TYPE, @@ -458,7 +470,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: # pragma: lax no cover response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -470,7 +482,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Check if this is an initialization request is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" - if is_initialization_request: # pragma: no cover + if is_initialization_request: # pragma: lax no cover # Check if the server already has an established session if self.mcp_session_id: # Check if request has a session ID @@ -484,11 +496,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): # pragma: lax no cover return # For notifications and responses only, return 202 Accepted - if not isinstance(message, JSONRPCRequest): # pragma: no cover + if not isinstance(message, JSONRPCRequest): # pragma: lax no cover # Create response object and send it response = self._create_json_response( None, @@ -535,7 +547,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re response_message = event_message.message break # For notifications and request, keep waiting - else: # pragma: no cover + else: # pragma: lax no cover logger.debug(f"received: {event_message.message.method}") # At this point we should have a response @@ -543,7 +555,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Create JSON response response = self._create_json_response(response_message) await response(scope, receive, send) - else: # pragma: no cover + else: # pragma: lax no cover # This shouldn't happen in normal operation logger.error("No response message received before stream closed") response = self._create_error_response( @@ -551,7 +563,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re HTTPStatus.INTERNAL_SERVER_ERROR, ) await response(scope, receive, send) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error processing JSON response") response = self._create_error_response( "Error processing request", @@ -561,7 +573,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) finally: await self._clean_up_memory_streams(request_id) - else: # pragma: no cover + else: # pragma: lax no cover # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) @@ -623,7 +635,7 @@ async def sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -635,7 +647,7 @@ async def sse_writer(): await writer.send(Exception(err)) return - async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: lax no cover """Handle GET request to establish SSE. This allows the server to communicate to the client without the client @@ -726,7 +738,7 @@ async def standalone_sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(GET_STREAM_KEY) - async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: lax no cover """Handle DELETE requests for explicit session termination.""" # Validate session ID if not self.mcp_session_id: @@ -776,11 +788,11 @@ async def terminate(self) -> None: await self._write_stream_reader.aclose() if self._write_stream is not None: # pragma: no branch await self._write_stream.aclose() - except Exception as e: # pragma: no cover + except Exception as e: # pragma: lax no cover # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: lax no cover """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, @@ -796,14 +808,14 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): return False return True - async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: lax no cover """Validate the session ID in the request.""" if not self.mcp_session_id: # If we're not using session IDs, return True @@ -832,7 +844,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag return True - async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: lax no cover """Validate the protocol version header in the request.""" # Get the protocol version from the request headers protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) @@ -854,7 +866,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: lax no cover """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. """ @@ -980,7 +992,7 @@ async def message_router(): # send it there target_request_id = response_id # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( # pragma: lax no cover session_message.metadata is not None and isinstance( session_message.metadata, @@ -1004,13 +1016,13 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except ( # pragma: no cover + except ( # pragma: lax no cover anyio.BrokenResourceError, anyio.ClosedResourceError, ): # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: # pragma: lax no cover logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client @@ -1041,6 +1053,6 @@ async def message_router(): await read_stream.aclose() await write_stream_reader.aclose() await write_stream.aclose() - except Exception as e: # pragma: no cover + except Exception as e: # pragma: lax no cover # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 964c52b6f..bec7d4bd5 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -182,7 +182,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking @@ -213,7 +213,9 @@ async def _handle_stateful_request( request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover + if ( + request_mcp_session_id is not None and request_mcp_session_id in self._server_instances + ): # pragma: lax no cover transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") await transport.handle_request(scope, receive, send) @@ -297,5 +299,5 @@ class StreamableHTTPASGIApp: def __init__(self, session_manager: StreamableHTTPSessionManager): self.session_manager = session_manager - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: lax no cover await self.session_manager.handle_request(scope, receive, send) diff --git a/tests/issues/test_1641_accept_header_wildcard.py b/tests/issues/test_1641_accept_header_wildcard.py new file mode 100644 index 000000000..fb2e2a1bf --- /dev/null +++ b/tests/issues/test_1641_accept_header_wildcard.py @@ -0,0 +1,155 @@ +"""Test for issue #1641 - Accept header wildcard support. + +The MCP server was rejecting requests with wildcard Accept headers like `*/*` +or `application/*`, returning 406 Not Acceptable. Per RFC 9110 Section 12.5.1, +wildcard media types are valid and should match the required content types. + +These tests verify the `_check_accept_headers` method directly, ensuring +wildcard media types are properly matched against the required content types +(application/json and text/event-stream). +""" + +import pytest +from starlette.requests import Request + +from mcp.server.streamable_http import StreamableHTTPServerTransport + + +def _make_request(accept: str) -> Request: + """Create a minimal Request with the given Accept header.""" + scope = { + "type": "http", + "method": "POST", + "headers": [(b"accept", accept.encode())], + } + return Request(scope) + + +@pytest.mark.anyio +async def test_accept_wildcard_star_star_json_mode(): + """Accept: */* should satisfy application/json requirement.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=True, + ) + request = _make_request("*/*") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "*/* should match application/json" + assert has_sse, "*/* should match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_wildcard_star_star_sse_mode(): + """Accept: */* should satisfy both JSON and SSE requirements.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("*/*") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "*/* should match application/json" + assert has_sse, "*/* should match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_application_wildcard(): + """Accept: application/* should satisfy application/json but not text/event-stream.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=True, + ) + request = _make_request("application/*") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "application/* should match application/json" + assert not has_sse, "application/* should NOT match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_text_wildcard_with_json(): + """Accept: application/json, text/* should satisfy both requirements in SSE mode.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("application/json, text/*") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "application/json should match JSON content type" + assert has_sse, "text/* should match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_wildcard_with_quality_parameter(): + """Accept: */*;q=0.8 should be accepted (quality parameters stripped before matching).""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=True, + ) + request = _make_request("*/*;q=0.8") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "*/*;q=0.8 should match application/json after stripping quality" + assert has_sse, "*/*;q=0.8 should match text/event-stream after stripping quality" + + +@pytest.mark.anyio +async def test_accept_invalid_still_rejected(): + """Accept: text/plain should not match JSON or SSE content types.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=True, + ) + request = _make_request("text/plain") + has_json, has_sse = transport._check_accept_headers(request) + assert not has_json, "text/plain should NOT match application/json" + assert not has_sse, "text/plain should NOT match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_partial_wildcard_sse_mode(): + """Accept: application/* alone should not satisfy SSE requirement.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("application/*") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "application/* should match application/json" + assert not has_sse, "application/* should NOT match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_explicit_types(): + """Accept: application/json, text/event-stream should match both explicitly.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("application/json, text/event-stream") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "application/json should match" + assert has_sse, "text/event-stream should match" + + +@pytest.mark.anyio +async def test_accept_text_wildcard_alone(): + """Accept: text/* alone should match SSE but not JSON.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("text/*") + has_json, has_sse = transport._check_accept_headers(request) + assert not has_json, "text/* should NOT match application/json" + assert has_sse, "text/* should match text/event-stream" + + +@pytest.mark.anyio +async def test_accept_multiple_quality_parameters(): + """Multiple types with quality parameters should all be parsed correctly.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + ) + request = _make_request("application/json;q=1.0, text/event-stream;q=0.9") + has_json, has_sse = transport._check_accept_headers(request) + assert has_json, "application/json;q=1.0 should match after stripping quality" + assert has_sse, "text/event-stream;q=0.9 should match after stripping quality" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b1332772a..a35a75be9 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -566,10 +566,10 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" - # Test without Accept header + # Test with non-matching Accept header (text/html doesn't match json or sse) response = requests.post( f"{basic_server_url}/mcp", - headers={"Content-Type": "application/json"}, + headers={"Content-Type": "application/json", "Accept": "text/html"}, json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, ) assert response.status_code == 406 @@ -818,12 +818,13 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): - """Test that json_response servers reject requests without Accept header.""" + """Test that json_response servers reject requests with non-matching Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( mcp_url, headers={ "Content-Type": "application/json", + "Accept": "text/html", }, json=INIT_REQUEST, ) @@ -935,12 +936,13 @@ def test_get_validation(basic_server: None, basic_server_url: str): assert init_data is not None negotiated_version = init_data["result"]["protocolVersion"] - # Test without Accept header + # Test with non-matching Accept header response = requests.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + "Accept": "text/html", }, stream=True, )