package middleware

import (
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
)

// mockSessionStore is a test double for the session store consumed by
// RequireAuth. ValidateSession reports the configured validity. GetUserID
// fails with http.ErrNoCookie when the session is invalid, fails with
// userIDErr when that is set, and otherwise returns userID.
type mockSessionStore struct {
	valid  bool
	userID string
	// userIDErr, when non-nil, makes GetUserID fail even though
	// ValidateSession reports a valid session. This lets tests exercise the
	// "valid session, but user lookup fails" path independently.
	userIDErr error
}

// ValidateSession reports the configured validity; the request is ignored.
func (m *mockSessionStore) ValidateSession(r *http.Request) bool {
	return m.valid
}

// GetUserID returns the configured user ID. It returns an error when the
// session is invalid or when userIDErr is set.
func (m *mockSessionStore) GetUserID(r *http.Request) (string, error) {
	if !m.valid {
		return "", http.ErrNoCookie
	}
	if m.userIDErr != nil {
		return "", m.userIDErr
	}
	return m.userID, nil
}

// TestRequireAuth checks that RequireAuth passes authenticated requests
// through with the user ID stored in the request context, and redirects
// requests whose session is invalid or whose user ID cannot be resolved.
func TestRequireAuth(t *testing.T) {
	t.Run("allows authenticated requests", func(t *testing.T) {
		mockStore := &mockSessionStore{
			valid:  true,
			userID: "test-user-123",
		}

		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userID, ok := GetUserIDFromContext(r.Context())
			if !ok {
				t.Error("Expected user ID in context")
				return
			}
			if userID != "test-user-123" {
				t.Errorf("Expected user ID test-user-123, got %q", userID)
			}
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("success"))
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
		if w.Body.String() != "success" {
			t.Errorf("Expected body 'success', got %s", w.Body.String())
		}
	})

	t.Run("redirects unauthenticated requests", func(t *testing.T) {
		mockStore := &mockSessionStore{
			valid: false,
		}

		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("Handler should not be called for unauthenticated request")
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusSeeOther {
			t.Errorf("Expected status 303, got %d", w.Code)
		}
		location := w.Header().Get("Location")
		if location != "/login" {
			t.Errorf("Expected redirect to /login, got %s", location)
		}
	})

	t.Run("redirects when GetUserID fails", func(t *testing.T) {
		// ValidateSession succeeds but GetUserID returns an error. This is
		// distinct from the unauthenticated case above: it checks that the
		// middleware does not proceed with an empty user ID after a
		// successful validation.
		mockStore := &mockSessionStore{
			valid:     true,
			userIDErr: errors.New("user lookup failed"),
		}

		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("Handler should not be called when GetUserID fails")
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusSeeOther {
			t.Errorf("Expected status 303, got %d", w.Code)
		}
	})
}

// TestGetUserIDFromContext checks that GetUserIDFromContext returns the ID
// placed in the context by RequireAuth, and reports false for a context
// that never passed through the middleware.
func TestGetUserIDFromContext(t *testing.T) {
	t.Run("retrieves user ID from context", func(t *testing.T) {
		mockStore := &mockSessionStore{
			valid:  true,
			userID: "context-user-456",
		}

		handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userID, ok := GetUserIDFromContext(r.Context())
			if !ok {
				t.Error("Expected user ID in context")
				return
			}
			if userID != "context-user-456" {
				t.Errorf("Expected user ID context-user-456, got %q", userID)
			}
			w.WriteHeader(http.StatusOK)
		}))

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodGet, "/test", nil)
		handler.ServeHTTP(w, r)

		if w.Code != http.StatusOK {
			t.Errorf("Expected status 200, got %d", w.Code)
		}
	})

	t.Run("returns false when user ID not in context", func(t *testing.T) {
		r := httptest.NewRequest(http.MethodGet, "/test", nil)
		_, ok := GetUserIDFromContext(r.Context())
		if ok {
			t.Error("Expected ok to be false when user ID not in context")
		}
	})
}