Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions internal/auth/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,19 @@ func TruncateSessionID(sessionID string) string {
return strutil.Truncate(sessionID, 8)
}

// IsMalformedHeader returns true if the header value contains characters
// that are not valid in HTTP header values per RFC 7230: null bytes, control
// characters below 0x20 (except horizontal tab 0x09), or DEL (0x7F).
// Per spec 7.2 item 3, such headers must be rejected with HTTP 400.
func IsMalformedHeader(header string) bool {
for _, c := range header {
if c == 0x00 || (c < 0x20 && c != 0x09) || c == 0x7F {
return true
}
}
return false
}

// GenerateRandomAPIKey generates a cryptographically random API key.
// Per spec §7.3, the gateway SHOULD generate a random API key on startup
// if none is provided. Returns a 32-byte hex-encoded string (64 chars).
Expand Down
78 changes: 78 additions & 0 deletions internal/auth/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,84 @@ import (
"github.com/github/gh-aw-mcpg/internal/logger/sanitize"
)

func TestIsMalformedHeader(t *testing.T) {
assert := assert.New(t)

tests := []struct {
name string
header string
want bool
}{
{
name: "Empty string is valid",
header: "",
want: false,
},
{
name: "Normal API key is valid",
header: "my-secret-api-key",
want: false,
},
{
name: "Bearer token is valid",
header: "Bearer my-token-123",
want: false,
},
{
name: "Horizontal tab (0x09) is valid per RFC 7230",
header: "key\twith\ttabs",
want: false,
},
{
name: "Printable ASCII is valid",
header: "!#$%&'*+-.0123456789ABCDEFabcdef~",
want: false,
},
{
name: "Null byte (0x00) is malformed",
header: "key\x00value",
want: true,
},
{
name: "DEL (0x7F) is malformed",
header: "key\x7fvalue",
want: true,
},
{
name: "Control char 0x01 is malformed",
header: "key\x01value",
want: true,
},
{
name: "Newline (0x0A) is malformed",
header: "key\nvalue",
want: true,
},
{
name: "Carriage return (0x0D) is malformed",
header: "key\rvalue",
want: true,
},
{
name: "Leading null byte",
header: "\x00key",
want: true,
},
{
name: "Trailing null byte",
header: "key\x00",
want: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsMalformedHeader(tt.header)
assert.Equal(tt.want, got)
})
}
}

func TestTruncateSecret(t *testing.T) {
assert := assert.New(t)

Expand Down
16 changes: 2 additions & 14 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,12 @@ package server
import (
"net/http"

"github.com/github/gh-aw-mcpg/internal/auth"
"github.com/github/gh-aw-mcpg/internal/logger"
)

var logAuth = logger.New("server:auth")

// isMalformedAuthHeader returns true if the header value contains characters
// that are not valid in HTTP header values per RFC 7230: null bytes, control
// characters below 0x20 (except horizontal tab 0x09), or DEL (0x7F).
// Per spec 7.2 item 3, such headers must be rejected with HTTP 400.
func isMalformedAuthHeader(header string) bool {
for _, c := range header {
if c == 0x00 || (c < 0x20 && c != 0x09) || c == 0x7F {
return true
}
}
return false
}

// authMiddleware implements API key authentication per spec section 7.1
// Per spec: Authorization header MUST contain the API key directly (NOT Bearer scheme)
//
Expand All @@ -43,7 +31,7 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {

// Spec 7.2 item 3: Malformed Authorization headers (null bytes, non-printable
// control characters) must return 400 Bad Request, not 401.
if isMalformedAuthHeader(authHeader) {
if auth.IsMalformedHeader(authHeader) {
rejectRequest(w, r, http.StatusBadRequest, "bad_request", "malformed Authorization header", "auth", "authentication_failed", "malformed_auth_header")
return
}
Expand Down
8 changes: 5 additions & 3 deletions internal/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/github/gh-aw-mcpg/internal/auth"
)

// TestAuthMiddleware tests the authMiddleware function with various scenarios
Expand Down Expand Up @@ -292,7 +294,7 @@ func TestAuthMiddleware_ConcurrentRequests(t *testing.T) {
}
}

// TestIsMalformedAuthHeader tests the isMalformedAuthHeader helper.
// TestIsMalformedAuthHeader tests auth.IsMalformedHeader via the server package.
func TestIsMalformedAuthHeader(t *testing.T) {
tests := []struct {
name string
Expand All @@ -313,8 +315,8 @@ func TestIsMalformedAuthHeader(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isMalformedAuthHeader(tt.header)
assert.Equal(t, tt.malformed, got, "isMalformedAuthHeader(%q) should return %v", tt.header, tt.malformed)
got := auth.IsMalformedHeader(tt.header)
assert.Equal(t, tt.malformed, got, "auth.IsMalformedHeader(%q) should return %v", tt.header, tt.malformed)
})
}
}
Loading