146 lines
3.5 KiB
Go
146 lines
3.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
// mockSessionStore is a test double standing in for the SessionStore that
// RequireAuth consumes. Its answers are fixed when the value is constructed,
// so each subtest picks the authenticated or anonymous path up front.
type mockSessionStore struct {
	userID string // identity reported for a valid session
	valid  bool   // whether the session is treated as authenticated
}
|
|
|
|
func (m *mockSessionStore) ValidateSession(r *http.Request) bool {
|
|
return m.valid
|
|
}
|
|
|
|
func (m *mockSessionStore) GetUserID(r *http.Request) (string, error) {
|
|
if !m.valid {
|
|
return "", http.ErrNoCookie
|
|
}
|
|
return m.userID, nil
|
|
}
|
|
|
|
func TestRequireAuth(t *testing.T) {
|
|
t.Run("allows authenticated requests", func(t *testing.T) {
|
|
mockStore := &mockSessionStore{
|
|
valid: true,
|
|
userID: "test-user-123",
|
|
}
|
|
|
|
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userID, ok := GetUserIDFromContext(r.Context())
|
|
if !ok {
|
|
t.Error("Expected user ID in context")
|
|
return
|
|
}
|
|
if userID != "test-user-123" {
|
|
t.Errorf("Expected user ID test-user-123, got %s", userID)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("success"))
|
|
}))
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/dashboard", nil)
|
|
|
|
handler.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if w.Body.String() != "success" {
|
|
t.Errorf("Expected body 'success', got %s", w.Body.String())
|
|
}
|
|
})
|
|
|
|
t.Run("redirects unauthenticated requests", func(t *testing.T) {
|
|
mockStore := &mockSessionStore{
|
|
valid: false,
|
|
}
|
|
|
|
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("Handler should not be called for unauthenticated request")
|
|
}))
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/dashboard", nil)
|
|
|
|
handler.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusSeeOther {
|
|
t.Errorf("Expected status 303, got %d", w.Code)
|
|
}
|
|
|
|
location := w.Header().Get("Location")
|
|
if location != "/login" {
|
|
t.Errorf("Expected redirect to /login, got %s", location)
|
|
}
|
|
})
|
|
|
|
t.Run("redirects when GetUserID fails", func(t *testing.T) {
|
|
mockStore := &mockSessionStore{
|
|
valid: true, // ValidateSession returns true
|
|
userID: "", // But GetUserID will fail
|
|
}
|
|
|
|
// Override GetUserID to return error
|
|
mockStore.valid = false
|
|
|
|
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("Handler should not be called when GetUserID fails")
|
|
}))
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/dashboard", nil)
|
|
|
|
handler.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusSeeOther {
|
|
t.Errorf("Expected status 303, got %d", w.Code)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGetUserIDFromContext(t *testing.T) {
|
|
t.Run("retrieves user ID from context", func(t *testing.T) {
|
|
mockStore := &mockSessionStore{
|
|
valid: true,
|
|
userID: "context-user-456",
|
|
}
|
|
|
|
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userID, ok := GetUserIDFromContext(r.Context())
|
|
if !ok {
|
|
t.Error("Expected user ID in context")
|
|
return
|
|
}
|
|
if userID != "context-user-456" {
|
|
t.Errorf("Expected user ID context-user-456, got %s", userID)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
w := httptest.NewRecorder()
|
|
r := httptest.NewRequest("GET", "/test", nil)
|
|
|
|
handler.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d", w.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("returns false when user ID not in context", func(t *testing.T) {
|
|
r := httptest.NewRequest("GET", "/test", nil)
|
|
|
|
_, ok := GetUserIDFromContext(r.Context())
|
|
if ok {
|
|
t.Error("Expected ok to be false when user ID not in context")
|
|
}
|
|
})
|
|
}
|