Files
Kiro/internal/middleware/auth_test.go

146 lines
3.5 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// mockSessionStore is a test double standing in for SessionStore. A single
// valid flag drives the outcome of both methods; userID is the identifier
// handed back while that flag is set.
type mockSessionStore struct {
	valid  bool
	userID string
}

// ValidateSession reports the configured valid flag. The request is not
// inspected.
func (s *mockSessionStore) ValidateSession(r *http.Request) bool {
	return s.valid
}

// GetUserID returns the configured user ID while the store is marked valid.
// Otherwise it returns an empty string together with http.ErrNoCookie. The
// request is not inspected.
func (s *mockSessionStore) GetUserID(r *http.Request) (string, error) {
	if s.valid {
		return s.userID, nil
	}
	return "", http.ErrNoCookie
}
// userIDErrStore is a session store double whose session check passes while
// the user ID lookup fails. mockSessionStore cannot express this state because
// its single valid flag drives both methods.
type userIDErrStore struct{}

// ValidateSession always reports the session as valid.
func (userIDErrStore) ValidateSession(r *http.Request) bool {
	return true
}

// GetUserID always fails with http.ErrNoCookie.
func (userIDErrStore) GetUserID(r *http.Request) (string, error) {
	return "", http.ErrNoCookie
}

// TestRequireAuth exercises the RequireAuth middleware against three stores:
// one with a valid session (the wrapped handler must run and see the user ID
// in the request context), one with an invalid session (the request must be
// redirected to /login with 303), and one whose session validates but whose
// user ID lookup fails (the request must be redirected with 303).
func TestRequireAuth(t *testing.T) {
	t.Run("allows authenticated requests", func(t *testing.T) {
		mockStore := &mockSessionStore{
			valid:  true,
			userID: "test-user-123",
		}
		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userID, ok := GetUserIDFromContext(r.Context())
			if !ok {
				t.Error("Expected user ID in context")
				return
			}
			if userID != "test-user-123" {
				t.Errorf("Expected user ID test-user-123, got %s", userID)
			}
			w.WriteHeader(http.StatusOK)
			if _, err := w.Write([]byte("success")); err != nil {
				t.Errorf("writing response body: %v", err)
			}
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		if w.Body.String() != "success" {
			t.Errorf("Expected body 'success', got %s", w.Body.String())
		}
	})

	t.Run("redirects unauthenticated requests", func(t *testing.T) {
		mockStore := &mockSessionStore{
			valid: false,
		}
		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("Handler should not be called for unauthenticated request")
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusSeeOther {
			t.Errorf("Expected status 303, got %d", w.Code)
		}
		location := w.Header().Get("Location")
		if location != "/login" {
			t.Errorf("Expected redirect to /login, got %s", location)
		}
	})

	t.Run("redirects when GetUserID fails", func(t *testing.T) {
		// The session validates but the user ID lookup errors. Using the
		// plain mock with valid=false here would only repeat the
		// unauthenticated case above and never reach this path.
		handler := RequireAuth(userIDErrStore{})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("Handler should not be called when GetUserID fails")
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusSeeOther {
			t.Errorf("Expected status 303, got %d", w.Code)
		}
	})
}
// TestGetUserIDFromContext checks both outcomes of GetUserIDFromContext: the
// ID is found on a request that passed through RequireAuth with a valid
// store, and the lookup reports false on a fresh request's context.
func TestGetUserIDFromContext(t *testing.T) {
	t.Run("retrieves user ID from context", func(t *testing.T) {
		store := &mockSessionStore{valid: true, userID: "context-user-456"}

		// Innermost handler: asserts on the context it receives.
		next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			got, found := GetUserIDFromContext(req.Context())
			switch {
			case !found:
				t.Error("Expected user ID in context")
				return
			case got != "context-user-456":
				t.Errorf("Expected user ID context-user-456, got %s", got)
			}
			rw.WriteHeader(http.StatusOK)
		})
		protected := RequireAuth(store)(next)

		rec := httptest.NewRecorder()
		protected.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

		if status := rec.Code; status != http.StatusOK {
			t.Errorf("Expected status 200, got %d", status)
		}
	})

	t.Run("returns false when user ID not in context", func(t *testing.T) {
		// A fresh request never went through the middleware.
		req := httptest.NewRequest("GET", "/test", nil)
		if _, found := GetUserIDFromContext(req.Context()); found {
			t.Error("Expected ok to be false when user ID not in context")
		}
	})
}