Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pkg/nodeauth/jwt/node_jwt_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,18 @@ func (v *NodeJWTAuthenticator) AuthenticateJWT(ctx context.Context, tokenString
// Public Key Validation: Verify node's CSA pubkey against the whitelisted registry via NodeAuthProvider.
isValid, err := v.nodeAuthProvider.IsNodePubKeyTrusted(ctx, publicKey)
if err != nil {
v.logger.Error("Node validation failed",
"csaPubKey", hex.EncodeToString(publicKey),
"error", err,
)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
v.logger.Warn("Node validation skipped: context canceled or deadline exceeded",
"csaPubKey", hex.EncodeToString(publicKey),
"error", err,
"contextErr", ctx.Err(),
)
Comment on lines +81 to +86
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed: added TestNodeJWTAuthenticator_AuthenticateJWT_ProviderDeadlineExceededError (line 545) covering context.DeadlineExceeded path.

} else {
v.logger.Error("Node validation failed",
"csaPubKey", hex.EncodeToString(publicKey),
"error", err,
)
}
return false, claims, fmt.Errorf("node validation failed: %w", err)
}

Expand Down
117 changes: 117 additions & 0 deletions pkg/nodeauth/jwt/node_jwt_authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"errors"
"io"
"log/slog"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -450,3 +452,118 @@ func TestNewNodeJWTAuthenticator_WithAndWithoutLeeway(t *testing.T) {
assert.NotNil(t, authenticator.parser)
})
}

// captureHandler is a minimal slog.Handler that captures log records for test assertions.
// It satisfies the full slog.Handler contract: WithAttrs and WithGroup return a new
// handler so that logger.With(...) / logger.WithGroup(...) calls don't silently discard
// attributes.
type captureHandler struct {
mu sync.Mutex
records []slog.Record
attrs []slog.Attr
groups []string
}

func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
// Prepend any inherited attrs so they appear in captured records.
if len(h.attrs) > 0 {
r2 := slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
r2.AddAttrs(h.attrs...)
r.Attrs(func(a slog.Attr) bool { r2.AddAttrs(a); return true })
r = r2
}
h.mu.Lock()
defer h.mu.Unlock()
h.records = append(h.records, r.Clone())
return nil
}
func (h *captureHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
h.mu.Lock()
defer h.mu.Unlock()
merged := make([]slog.Attr, len(h.attrs)+len(attrs))
copy(merged, h.attrs)
copy(merged[len(h.attrs):], attrs)
return &captureHandler{records: h.records, attrs: merged, groups: h.groups}
}
func (h *captureHandler) WithGroup(name string) slog.Handler {
h.mu.Lock()
defer h.mu.Unlock()
groups := append(append([]string{}, h.groups...), name)
return &captureHandler{records: h.records, attrs: h.attrs, groups: groups}
}

func TestNodeJWTAuthenticator_AuthenticateJWT_ProviderNonContextError(t *testing.T) {
// Non-context provider errors must be logged at ERROR level.
privateKey, csaPubKey := createValidatorTestKeys()
providerErr := errors.New("database unavailable")
mockProvider := &mocks.NodeAuthProvider{}
mockProvider.On("IsNodePubKeyTrusted", mock.Anything, csaPubKey).Return(false, providerErr)

h := &captureHandler{}
authenticator := NewNodeJWTAuthenticator(mockProvider, slog.New(h))

testRequest := testRequest{Field: "test-request"}
valid, claims, err := authenticator.AuthenticateJWT(context.Background(), createValidJWT(privateKey, csaPubKey), testRequest)

require.Error(t, err)
assert.False(t, valid)
assert.NotNil(t, claims)
assert.Contains(t, err.Error(), "node validation failed")

require.Len(t, h.records, 1)
assert.Equal(t, slog.LevelError, h.records[0].Level, "non-context provider errors should log at ERROR")
mockProvider.AssertExpectations(t)
}

func TestNodeJWTAuthenticator_AuthenticateJWT_ProviderContextCancelledError(t *testing.T) {
// Context-cancellation errors from the provider must be logged at WARN, not ERROR,
// because they are caused by the caller cancelling the request — not a system fault.
privateKey, csaPubKey := createValidatorTestKeys()
mockProvider := &mocks.NodeAuthProvider{}
mockProvider.On("IsNodePubKeyTrusted", mock.Anything, csaPubKey).Return(false, context.Canceled)

h := &captureHandler{}
authenticator := NewNodeJWTAuthenticator(mockProvider, slog.New(h))

ctx, cancel := context.WithCancel(context.Background())
cancel() // already cancelled

testRequest := testRequest{Field: "test-request"}
valid, claims, err := authenticator.AuthenticateJWT(ctx, createValidJWT(privateKey, csaPubKey), testRequest)

require.Error(t, err)
assert.False(t, valid)
assert.NotNil(t, claims)
assert.ErrorIs(t, err, context.Canceled)

require.Len(t, h.records, 1)
assert.Equal(t, slog.LevelWarn, h.records[0].Level, "context cancellation from provider should log at WARN not ERROR")
mockProvider.AssertExpectations(t)
}

func TestNodeJWTAuthenticator_AuthenticateJWT_ProviderDeadlineExceededError(t *testing.T) {
// context.DeadlineExceeded from the provider must also be logged at WARN, not ERROR,
// because it is an expected transient condition (e.g. slow upstream), not a system fault.
privateKey, csaPubKey := createValidatorTestKeys()
mockProvider := &mocks.NodeAuthProvider{}
mockProvider.On("IsNodePubKeyTrusted", mock.Anything, csaPubKey).Return(false, context.DeadlineExceeded)

h := &captureHandler{}
authenticator := NewNodeJWTAuthenticator(mockProvider, slog.New(h))

ctx, cancel := context.WithTimeout(context.Background(), 0) // immediately expired
defer cancel()

testRequest := testRequest{Field: "test-request"}
valid, claims, err := authenticator.AuthenticateJWT(ctx, createValidJWT(privateKey, csaPubKey), testRequest)

require.Error(t, err)
assert.False(t, valid)
assert.NotNil(t, claims)
assert.ErrorIs(t, err, context.DeadlineExceeded)

require.Len(t, h.records, 1)
assert.Equal(t, slog.LevelWarn, h.records[0].Level, "deadline exceeded from provider should log at WARN not ERROR")
mockProvider.AssertExpectations(t)
}
Loading