Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,18 @@ async def _call_streamable_http(self, scope, receive, send):
# Skip rewriting for well-known URIs (RFC 9728 OAuth metadata, etc.)
# These paths may end with /mcp but should not be rewritten to the MCP transport
if not app_path.startswith("/.well-known/"):
# Normalise bare /mcp to /mcp/ so the Starlette Mount at /mcp matches
# directly. Without this rewrite Starlette would emit a 307 redirect
# to /mcp/, which httpx cannot follow for streaming POST bodies
# (chunked Streamable HTTP) and surfaces as httpx.ReadError during
# initialize. See #4275.
if app_path == "/mcp":
new_path = f"{root_path}/mcp/" if root_path else "/mcp/"
scope["path"] = new_path
if "raw_path" in scope:
scope["raw_path"] = new_path.encode("latin-1")
await self.application(scope, receive, send)
return
if (app_path.endswith("/mcp") and app_path != "/mcp") or (app_path.endswith("/mcp/") and app_path != "/mcp/"):
# SECURITY: Only rewrite recognised MCP paths — /servers/{id}/mcp.
# Arbitrary prefixes (e.g. /foo/mcp) must NOT be rewritten to
Expand All @@ -3033,8 +3045,13 @@ async def _call_streamable_http(self, scope, receive, send):
await self.application(scope, receive, send)
return
# Rewrite to /mcp/ and continue through middleware (lets CORSMiddleware handle preflight)
# Preserve root_path prefix when rewriting
scope["path"] = f"{root_path}/mcp/" if root_path else "/mcp/"
# Preserve root_path prefix when rewriting. Keep raw_path aligned so
# downstream URL reconstruction (e.g. Starlette RedirectResponse) stays
# consistent with the rewritten path (#4275).
new_path = f"{root_path}/mcp/" if root_path else "/mcp/"
scope["path"] = new_path
if "raw_path" in scope:
scope["raw_path"] = new_path.encode("latin-1")
await self.application(scope, receive, send)
return
await self.application(scope, receive, send)
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/mcpgateway/test_main_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -2989,6 +2989,79 @@ async def mock_send(msg):
# ORJSONResponse sends http.response.start + http.response.body
assert any(m.get("status") == 404 for m in sent if m.get("type") == "http.response.start")

@pytest.mark.asyncio
async def test_rewrite_bare_mcp_to_trailing_slash(self):
"""Bare /mcp rewrites to /mcp/ so Starlette's Mount matches without a 307 redirect (#4275)."""
app_mock = AsyncMock()
middleware = MCPPathRewriteMiddleware(app_mock)
scope = {"type": "http", "path": "/mcp", "headers": []}
receive, send = AsyncMock(), AsyncMock()

with patch("mcpgateway.main.streamable_http_auth", new=AsyncMock(return_value=True)):
await middleware._call_streamable_http(scope, receive, send)

assert scope["path"] == "/mcp/"
app_mock.assert_called_once_with(scope, receive, send)

@pytest.mark.asyncio
async def test_rewrite_bare_mcp_with_root_path(self):
"""Bare /mcp with reverse-proxy prefix rewrites to /<prefix>/mcp/ (#4275)."""
app_mock = AsyncMock()
middleware = MCPPathRewriteMiddleware(app_mock)
scope = {
"type": "http",
"path": "/gateway/mcp",
"root_path": "/gateway",
"headers": [],
}
receive, send = AsyncMock(), AsyncMock()

with patch("mcpgateway.main.streamable_http_auth", new=AsyncMock(return_value=True)):
await middleware._call_streamable_http(scope, receive, send)

assert scope["path"] == "/gateway/mcp/"
app_mock.assert_called_once()

@pytest.mark.asyncio
async def test_rewrite_updates_raw_path(self):
"""Rewriting keeps scope['raw_path'] aligned with scope['path'] (#4275)."""
app_mock = AsyncMock()
middleware = MCPPathRewriteMiddleware(app_mock)
scope = {
"type": "http",
"path": "/servers/123/mcp",
"raw_path": b"/servers/123/mcp",
"headers": [],
}
receive, send = AsyncMock(), AsyncMock()

with patch("mcpgateway.main.streamable_http_auth", new=AsyncMock(return_value=True)):
await middleware._call_streamable_http(scope, receive, send)

assert scope["path"] == "/mcp/"
assert scope["raw_path"] == b"/mcp/"
app_mock.assert_called_once()

@pytest.mark.asyncio
async def test_bare_mcp_updates_raw_path(self):
"""Bare /mcp rewrite also syncs raw_path so URL reconstruction stays correct (#4275)."""
app_mock = AsyncMock()
middleware = MCPPathRewriteMiddleware(app_mock)
scope = {
"type": "http",
"path": "/mcp",
"raw_path": b"/mcp",
"headers": [],
}
receive, send = AsyncMock(), AsyncMock()

with patch("mcpgateway.main.streamable_http_auth", new=AsyncMock(return_value=True)):
await middleware._call_streamable_http(scope, receive, send)

assert scope["path"] == "/mcp/"
assert scope["raw_path"] == b"/mcp/"
app_mock.assert_called_once()


class TestServerEndpointCoverage:
"""Exercise server endpoints and SSE coverage."""
Expand Down
Loading