Initial commit: Custom Start Page application with authentication and DynamoDB storage

This commit is contained in:
2026-02-18 22:06:43 -05:00
commit 7175ff14ba
47 changed files with 7592 additions and 0 deletions

109
internal/auth/oauth.go Normal file
View File

@@ -0,0 +1,109 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// OAuthService handles OAuth authentication flows
type OAuthService struct {
googleConfig *oauth2.Config
stateStore StateStore
}
// StateStore manages OAuth state tokens for CSRF protection
type StateStore interface {
Set(state string, expiry time.Time) error
Validate(state string) (bool, error)
Delete(state string) error
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(googleClientID, googleClientSecret, googleRedirectURL string, stateStore StateStore) *OAuthService {
googleConfig := &oauth2.Config{
ClientID: googleClientID,
ClientSecret: googleClientSecret,
RedirectURL: googleRedirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
return &OAuthService{
googleConfig: googleConfig,
stateStore: stateStore,
}
}
// InitiateOAuth starts the OAuth flow and returns the redirect URL
func (s *OAuthService) InitiateOAuth(provider string) (string, error) {
if provider != "google" {
return "", fmt.Errorf("unsupported OAuth provider: %s", provider)
}
// Generate random state token for CSRF protection
state, err := generateStateToken()
if err != nil {
return "", fmt.Errorf("failed to generate state token: %w", err)
}
// Store state with 10 minute expiry
expiry := time.Now().Add(10 * time.Minute)
if err := s.stateStore.Set(state, expiry); err != nil {
return "", fmt.Errorf("failed to store state token: %w", err)
}
// Generate authorization URL
url := s.googleConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
return url, nil
}
// HandleOAuthCallback processes the OAuth callback and exchanges the code for a token
func (s *OAuthService) HandleOAuthCallback(ctx context.Context, provider, code, state string) (*oauth2.Token, error) {
if provider != "google" {
return nil, fmt.Errorf("unsupported OAuth provider: %s", provider)
}
// Validate state token
valid, err := s.stateStore.Validate(state)
if err != nil {
return nil, fmt.Errorf("failed to validate state token: %w", err)
}
if !valid {
return nil, fmt.Errorf("invalid or expired state token")
}
// Delete state token after validation
_ = s.stateStore.Delete(state)
// Exchange authorization code for token
token, err := s.googleConfig.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return token, nil
}
// GetGoogleConfig returns the Google OAuth config for accessing user info
func (s *OAuthService) GetGoogleConfig() *oauth2.Config {
return s.googleConfig
}
// generateStateToken generates a cryptographically secure random state token
func generateStateToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}

136
internal/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,136 @@
package auth
import (
"context"
"strings"
"testing"
"time"
)
func TestInitiateOAuth(t *testing.T) {
stateStore := NewMemoryStateStore()
service := NewOAuthService(
"test-client-id",
"test-client-secret",
"http://localhost:8080/auth/callback/google",
stateStore,
)
t.Run("successful OAuth initiation", func(t *testing.T) {
redirectURL, err := service.InitiateOAuth("google")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if redirectURL == "" {
t.Fatal("Expected redirect URL, got empty string")
}
// Check that URL contains expected components
if !strings.Contains(redirectURL, "accounts.google.com") {
t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "client_id=test-client-id") {
t.Errorf("Expected client_id in URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "redirect_uri=") {
t.Errorf("Expected redirect_uri in URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "state=") {
t.Errorf("Expected state parameter in URL, got %s", redirectURL)
}
})
t.Run("unsupported provider", func(t *testing.T) {
_, err := service.InitiateOAuth("unsupported")
if err == nil {
t.Fatal("Expected error for unsupported provider, got nil")
}
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
t.Errorf("Expected unsupported provider error, got %v", err)
}
})
}
func TestHandleOAuthCallback(t *testing.T) {
stateStore := NewMemoryStateStore()
service := NewOAuthService(
"test-client-id",
"test-client-secret",
"http://localhost:8080/auth/callback/google",
stateStore,
)
t.Run("invalid state token", func(t *testing.T) {
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
if err == nil {
t.Fatal("Expected error for invalid state, got nil")
}
if !strings.Contains(err.Error(), "invalid or expired state token") {
t.Errorf("Expected invalid state error, got %v", err)
}
})
t.Run("valid state token but invalid code", func(t *testing.T) {
// Store a valid state token
state := "valid-state-token"
expiry := time.Now().Add(10 * time.Minute)
if err := stateStore.Set(state, expiry); err != nil {
t.Fatalf("Failed to set state: %v", err)
}
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
// This should fail because the code is invalid (can't exchange with Google)
if err == nil {
t.Fatal("Expected error for invalid code, got nil")
}
// The state should be deleted even on error
valid, _ := stateStore.Validate(state)
if valid {
t.Error("Expected state to be deleted after callback")
}
})
t.Run("unsupported provider", func(t *testing.T) {
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
if err == nil {
t.Fatal("Expected error for unsupported provider, got nil")
}
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
t.Errorf("Expected unsupported provider error, got %v", err)
}
})
}
func TestGenerateStateToken(t *testing.T) {
t.Run("generates unique tokens", func(t *testing.T) {
token1, err := generateStateToken()
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
token2, err := generateStateToken()
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
if token1 == token2 {
t.Error("Expected unique tokens, got duplicates")
}
if len(token1) == 0 {
t.Error("Expected non-empty token")
}
})
}

View File

@@ -0,0 +1,88 @@
package auth
import (
"fmt"
"net/http"
"github.com/gorilla/sessions"
)
const sessionName = "startpage_session"
const userIDKey = "user_id"
// CookieSessionStore implements SessionStore using gorilla/sessions
type CookieSessionStore struct {
store *sessions.CookieStore
}
// NewCookieSessionStore creates a new cookie-based session store
func NewCookieSessionStore(secretKey string, maxAge int) *CookieSessionStore {
store := sessions.NewCookieStore([]byte(secretKey))
store.Options = &sessions.Options{
Path: "/",
MaxAge: maxAge,
HttpOnly: true,
Secure: false, // Set to true in production with HTTPS
SameSite: http.SameSiteLaxMode,
}
return &CookieSessionStore{
store: store,
}
}
// CreateSession creates a new session for the user
func (s *CookieSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
session, err := s.store.Get(r, sessionName)
if err != nil {
// If there's an error getting the session, create a new one
session, _ = s.store.New(r, sessionName)
}
session.Values[userIDKey] = userID
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save session: %w", err)
}
return nil
}
// GetUserID retrieves the user ID from the session
func (s *CookieSessionStore) GetUserID(r *http.Request) (string, error) {
session, err := s.store.Get(r, sessionName)
if err != nil {
return "", fmt.Errorf("failed to get session: %w", err)
}
userID, ok := session.Values[userIDKey].(string)
if !ok || userID == "" {
return "", fmt.Errorf("user ID not found in session")
}
return userID, nil
}
// ValidateSession checks if a valid session exists for the request
func (s *CookieSessionStore) ValidateSession(r *http.Request) bool {
_, err := s.GetUserID(r)
return err == nil
}
// DestroySession destroys the user's session
func (s *CookieSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
session, err := s.store.Get(r, sessionName)
if err != nil {
// Session doesn't exist or is invalid, nothing to destroy
return nil
}
// Set MaxAge to -1 to delete the cookie
session.Options.MaxAge = -1
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to destroy session: %w", err)
}
return nil
}

View File

@@ -0,0 +1,237 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestCookieSessionStore(t *testing.T) {
store := NewCookieSessionStore("test-secret-key", 3600)
t.Run("create and retrieve session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-123"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie from the response
cookies := w.Result().Cookies()
if len(cookies) == 0 {
t.Fatal("Expected session cookie, got none")
}
// Create a new request with the cookie
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
retrievedUserID, err := store.GetUserID(r2)
if err != nil {
t.Fatalf("Failed to get user ID: %v", err)
}
if retrievedUserID != userID {
t.Errorf("Expected user ID %s, got %s", userID, retrievedUserID)
}
})
t.Run("get user ID without session", func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
_, err := store.GetUserID(r)
if err == nil {
t.Error("Expected error when getting user ID without session")
}
})
t.Run("destroy session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-456"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Destroy the session
w2 := httptest.NewRecorder()
err = store.DestroySession(w2, r2)
if err != nil {
t.Fatalf("Failed to destroy session: %v", err)
}
// Check that the cookie has MaxAge set to -1 (deletion marker)
destroyCookies := w2.Result().Cookies()
if len(destroyCookies) == 0 {
t.Fatal("Expected cookie with MaxAge=-1 for deletion")
}
foundDeleteCookie := false
for _, cookie := range destroyCookies {
if cookie.MaxAge == -1 {
foundDeleteCookie = true
break
}
}
if !foundDeleteCookie {
t.Error("Expected cookie with MaxAge=-1 to indicate deletion")
}
})
t.Run("update existing session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
// Create initial session
userID1 := "user-1"
store.CreateSession(w, r, userID1)
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Update session with new user ID
w2 := httptest.NewRecorder()
userID2 := "user-2"
err := store.CreateSession(w2, r2, userID2)
if err != nil {
t.Fatalf("Failed to update session: %v", err)
}
// Verify new user ID
cookies2 := w2.Result().Cookies()
r3 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies2 {
r3.AddCookie(cookie)
}
retrievedUserID, err := store.GetUserID(r3)
if err != nil {
t.Fatalf("Failed to get user ID: %v", err)
}
if retrievedUserID != userID2 {
t.Errorf("Expected user ID %s, got %s", userID2, retrievedUserID)
}
})
t.Run("validate session with valid session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-789"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Create request with session cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Validate session
if !store.ValidateSession(r2) {
t.Error("Expected ValidateSession to return true for valid session")
}
})
t.Run("validate session without session", func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
// Validate session without any cookies
if store.ValidateSession(r) {
t.Error("Expected ValidateSession to return false for missing session")
}
})
t.Run("validate session after logout", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-logout"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Destroy the session
w2 := httptest.NewRecorder()
err = store.DestroySession(w2, r2)
if err != nil {
t.Fatalf("Failed to destroy session: %v", err)
}
// Create a new request without any cookies (simulating browser behavior after logout)
r3 := httptest.NewRequest("GET", "/", nil)
// Validate session should return false
if store.ValidateSession(r3) {
t.Error("Expected ValidateSession to return false after logout")
}
})
t.Run("session cookie has security settings", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-security"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
cookies := w.Result().Cookies()
if len(cookies) == 0 {
t.Fatal("Expected session cookie, got none")
}
cookie := cookies[0]
// Verify HttpOnly flag is set
if !cookie.HttpOnly {
t.Error("Expected HttpOnly flag to be true")
}
// Verify SameSite is set
if cookie.SameSite != http.SameSiteLaxMode {
t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
}
// Verify Path is set
if cookie.Path != "/" {
t.Errorf("Expected Path to be /, got %s", cookie.Path)
}
// Verify MaxAge is set
if cookie.MaxAge != 3600 {
t.Errorf("Expected MaxAge to be 3600, got %d", cookie.MaxAge)
}
})
}

View File

@@ -0,0 +1,78 @@
package auth
import (
"sync"
"time"
)
// MemoryStateStore is an in-memory implementation of StateStore
// Note: This is suitable for development but should be replaced with
// a distributed store (Redis, DynamoDB) for production with multiple servers
type MemoryStateStore struct {
mu sync.RWMutex
states map[string]time.Time
}
// NewMemoryStateStore creates a new in-memory state store
func NewMemoryStateStore() *MemoryStateStore {
store := &MemoryStateStore{
states: make(map[string]time.Time),
}
// Start cleanup goroutine to remove expired states
go store.cleanupExpired()
return store
}
// Set stores a state token with an expiry time
func (s *MemoryStateStore) Set(state string, expiry time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
s.states[state] = expiry
return nil
}
// Validate checks if a state token is valid and not expired
func (s *MemoryStateStore) Validate(state string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
expiry, exists := s.states[state]
if !exists {
return false, nil
}
if time.Now().After(expiry) {
return false, nil
}
return true, nil
}
// Delete removes a state token from the store
func (s *MemoryStateStore) Delete(state string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.states, state)
return nil
}
// cleanupExpired periodically removes expired state tokens
func (s *MemoryStateStore) cleanupExpired() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now()
for state, expiry := range s.states {
if now.After(expiry) {
delete(s.states, state)
}
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,112 @@
package auth
import (
"testing"
"time"
)
func TestMemoryStateStore(t *testing.T) {
store := NewMemoryStateStore()
t.Run("set and validate state", func(t *testing.T) {
state := "test-state-123"
expiry := time.Now().Add(10 * time.Minute)
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Failed to validate state: %v", err)
}
if !valid {
t.Error("Expected state to be valid")
}
})
t.Run("validate non-existent state", func(t *testing.T) {
valid, err := store.Validate("non-existent")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected non-existent state to be invalid")
}
})
t.Run("validate expired state", func(t *testing.T) {
state := "expired-state"
expiry := time.Now().Add(-1 * time.Minute) // Already expired
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected expired state to be invalid")
}
})
t.Run("delete state", func(t *testing.T) {
state := "delete-test"
expiry := time.Now().Add(10 * time.Minute)
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
err = store.Delete(state)
if err != nil {
t.Fatalf("Failed to delete state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected deleted state to be invalid")
}
})
t.Run("multiple states", func(t *testing.T) {
state1 := "state-1"
state2 := "state-2"
expiry := time.Now().Add(10 * time.Minute)
store.Set(state1, expiry)
store.Set(state2, expiry)
valid1, _ := store.Validate(state1)
valid2, _ := store.Validate(state2)
if !valid1 || !valid2 {
t.Error("Expected both states to be valid")
}
store.Delete(state1)
valid1, _ = store.Validate(state1)
valid2, _ = store.Validate(state2)
if valid1 {
t.Error("Expected state1 to be invalid after deletion")
}
if !valid2 {
t.Error("Expected state2 to still be valid")
}
})
}

View File

@@ -0,0 +1,100 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/oauth2"
"custom-start-page/internal/models"
"custom-start-page/internal/storage"
)
// GoogleUserInfo represents the user info returned by Google
type GoogleUserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
VerifiedEmail bool `json:"verified_email"`
Name string `json:"name"`
Picture string `json:"picture"`
}
// UserService handles user creation and retrieval
type UserService struct {
userRepo *storage.UserRepository
}
// NewUserService creates a new user service
func NewUserService(userRepo *storage.UserRepository) *UserService {
return &UserService{
userRepo: userRepo,
}
}
// GetOrCreateUserFromGoogle fetches user info from Google and creates or retrieves the user
func (s *UserService) GetOrCreateUserFromGoogle(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*models.User, error) {
// Fetch user info from Google
userInfo, err := s.fetchGoogleUserInfo(ctx, token, oauthConfig)
if err != nil {
return nil, fmt.Errorf("failed to fetch Google user info: %w", err)
}
// Check if user already exists
user, err := s.userRepo.GetByOAuthID(ctx, "google", userInfo.ID)
if err == nil {
// User exists, update last login time
user.UpdatedAt = time.Now()
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
return user, nil
}
// User doesn't exist, create new user
user = &models.User{
ID: uuid.New().String(),
Email: userInfo.Email,
OAuthProvider: "google",
OAuthID: userInfo.ID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
// fetchGoogleUserInfo fetches user information from Google's userinfo endpoint
func (s *UserService) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*GoogleUserInfo, error) {
client := oauthConfig.Client(ctx, token)
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get user info: status %d, body: %s", resp.StatusCode, string(body))
}
var userInfo GoogleUserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
if !userInfo.VerifiedEmail {
return nil, fmt.Errorf("email not verified")
}
return &userInfo, nil
}