Initial commit: Custom Start Page application with authentication and DynamoDB storage

This commit is contained in:
2026-02-18 22:06:43 -05:00
commit 7175ff14ba
47 changed files with 7592 additions and 0 deletions

109
internal/auth/oauth.go Normal file
View File

@@ -0,0 +1,109 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// OAuthService handles OAuth authentication flows
type OAuthService struct {
googleConfig *oauth2.Config
stateStore StateStore
}
// StateStore manages OAuth state tokens for CSRF protection
type StateStore interface {
Set(state string, expiry time.Time) error
Validate(state string) (bool, error)
Delete(state string) error
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(googleClientID, googleClientSecret, googleRedirectURL string, stateStore StateStore) *OAuthService {
googleConfig := &oauth2.Config{
ClientID: googleClientID,
ClientSecret: googleClientSecret,
RedirectURL: googleRedirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
return &OAuthService{
googleConfig: googleConfig,
stateStore: stateStore,
}
}
// InitiateOAuth starts the OAuth flow and returns the redirect URL
func (s *OAuthService) InitiateOAuth(provider string) (string, error) {
if provider != "google" {
return "", fmt.Errorf("unsupported OAuth provider: %s", provider)
}
// Generate random state token for CSRF protection
state, err := generateStateToken()
if err != nil {
return "", fmt.Errorf("failed to generate state token: %w", err)
}
// Store state with 10 minute expiry
expiry := time.Now().Add(10 * time.Minute)
if err := s.stateStore.Set(state, expiry); err != nil {
return "", fmt.Errorf("failed to store state token: %w", err)
}
// Generate authorization URL
url := s.googleConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
return url, nil
}
// HandleOAuthCallback processes the OAuth callback and exchanges the code for a token
func (s *OAuthService) HandleOAuthCallback(ctx context.Context, provider, code, state string) (*oauth2.Token, error) {
if provider != "google" {
return nil, fmt.Errorf("unsupported OAuth provider: %s", provider)
}
// Validate state token
valid, err := s.stateStore.Validate(state)
if err != nil {
return nil, fmt.Errorf("failed to validate state token: %w", err)
}
if !valid {
return nil, fmt.Errorf("invalid or expired state token")
}
// Delete state token after validation
_ = s.stateStore.Delete(state)
// Exchange authorization code for token
token, err := s.googleConfig.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return token, nil
}
// GetGoogleConfig returns the Google OAuth config for accessing user info
func (s *OAuthService) GetGoogleConfig() *oauth2.Config {
return s.googleConfig
}
// generateStateToken generates a cryptographically secure random state token
func generateStateToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}

136
internal/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,136 @@
package auth
import (
"context"
"strings"
"testing"
"time"
)
func TestInitiateOAuth(t *testing.T) {
stateStore := NewMemoryStateStore()
service := NewOAuthService(
"test-client-id",
"test-client-secret",
"http://localhost:8080/auth/callback/google",
stateStore,
)
t.Run("successful OAuth initiation", func(t *testing.T) {
redirectURL, err := service.InitiateOAuth("google")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if redirectURL == "" {
t.Fatal("Expected redirect URL, got empty string")
}
// Check that URL contains expected components
if !strings.Contains(redirectURL, "accounts.google.com") {
t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "client_id=test-client-id") {
t.Errorf("Expected client_id in URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "redirect_uri=") {
t.Errorf("Expected redirect_uri in URL, got %s", redirectURL)
}
if !strings.Contains(redirectURL, "state=") {
t.Errorf("Expected state parameter in URL, got %s", redirectURL)
}
})
t.Run("unsupported provider", func(t *testing.T) {
_, err := service.InitiateOAuth("unsupported")
if err == nil {
t.Fatal("Expected error for unsupported provider, got nil")
}
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
t.Errorf("Expected unsupported provider error, got %v", err)
}
})
}
func TestHandleOAuthCallback(t *testing.T) {
stateStore := NewMemoryStateStore()
service := NewOAuthService(
"test-client-id",
"test-client-secret",
"http://localhost:8080/auth/callback/google",
stateStore,
)
t.Run("invalid state token", func(t *testing.T) {
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
if err == nil {
t.Fatal("Expected error for invalid state, got nil")
}
if !strings.Contains(err.Error(), "invalid or expired state token") {
t.Errorf("Expected invalid state error, got %v", err)
}
})
t.Run("valid state token but invalid code", func(t *testing.T) {
// Store a valid state token
state := "valid-state-token"
expiry := time.Now().Add(10 * time.Minute)
if err := stateStore.Set(state, expiry); err != nil {
t.Fatalf("Failed to set state: %v", err)
}
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
// This should fail because the code is invalid (can't exchange with Google)
if err == nil {
t.Fatal("Expected error for invalid code, got nil")
}
// The state should be deleted even on error
valid, _ := stateStore.Validate(state)
if valid {
t.Error("Expected state to be deleted after callback")
}
})
t.Run("unsupported provider", func(t *testing.T) {
ctx := context.Background()
_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
if err == nil {
t.Fatal("Expected error for unsupported provider, got nil")
}
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
t.Errorf("Expected unsupported provider error, got %v", err)
}
})
}
func TestGenerateStateToken(t *testing.T) {
t.Run("generates unique tokens", func(t *testing.T) {
token1, err := generateStateToken()
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
token2, err := generateStateToken()
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
if token1 == token2 {
t.Error("Expected unique tokens, got duplicates")
}
if len(token1) == 0 {
t.Error("Expected non-empty token")
}
})
}

View File

@@ -0,0 +1,88 @@
package auth
import (
"fmt"
"net/http"
"github.com/gorilla/sessions"
)
const sessionName = "startpage_session"
const userIDKey = "user_id"
// CookieSessionStore implements SessionStore using gorilla/sessions
type CookieSessionStore struct {
store *sessions.CookieStore
}
// NewCookieSessionStore creates a new cookie-based session store
func NewCookieSessionStore(secretKey string, maxAge int) *CookieSessionStore {
store := sessions.NewCookieStore([]byte(secretKey))
store.Options = &sessions.Options{
Path: "/",
MaxAge: maxAge,
HttpOnly: true,
Secure: false, // Set to true in production with HTTPS
SameSite: http.SameSiteLaxMode,
}
return &CookieSessionStore{
store: store,
}
}
// CreateSession creates a new session for the user
func (s *CookieSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
session, err := s.store.Get(r, sessionName)
if err != nil {
// If there's an error getting the session, create a new one
session, _ = s.store.New(r, sessionName)
}
session.Values[userIDKey] = userID
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save session: %w", err)
}
return nil
}
// GetUserID retrieves the user ID from the session
func (s *CookieSessionStore) GetUserID(r *http.Request) (string, error) {
session, err := s.store.Get(r, sessionName)
if err != nil {
return "", fmt.Errorf("failed to get session: %w", err)
}
userID, ok := session.Values[userIDKey].(string)
if !ok || userID == "" {
return "", fmt.Errorf("user ID not found in session")
}
return userID, nil
}
// ValidateSession checks if a valid session exists for the request
func (s *CookieSessionStore) ValidateSession(r *http.Request) bool {
_, err := s.GetUserID(r)
return err == nil
}
// DestroySession destroys the user's session
func (s *CookieSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
session, err := s.store.Get(r, sessionName)
if err != nil {
// Session doesn't exist or is invalid, nothing to destroy
return nil
}
// Set MaxAge to -1 to delete the cookie
session.Options.MaxAge = -1
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to destroy session: %w", err)
}
return nil
}

View File

@@ -0,0 +1,237 @@
package auth
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestCookieSessionStore(t *testing.T) {
store := NewCookieSessionStore("test-secret-key", 3600)
t.Run("create and retrieve session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-123"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie from the response
cookies := w.Result().Cookies()
if len(cookies) == 0 {
t.Fatal("Expected session cookie, got none")
}
// Create a new request with the cookie
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
retrievedUserID, err := store.GetUserID(r2)
if err != nil {
t.Fatalf("Failed to get user ID: %v", err)
}
if retrievedUserID != userID {
t.Errorf("Expected user ID %s, got %s", userID, retrievedUserID)
}
})
t.Run("get user ID without session", func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
_, err := store.GetUserID(r)
if err == nil {
t.Error("Expected error when getting user ID without session")
}
})
t.Run("destroy session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-456"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Destroy the session
w2 := httptest.NewRecorder()
err = store.DestroySession(w2, r2)
if err != nil {
t.Fatalf("Failed to destroy session: %v", err)
}
// Check that the cookie has MaxAge set to -1 (deletion marker)
destroyCookies := w2.Result().Cookies()
if len(destroyCookies) == 0 {
t.Fatal("Expected cookie with MaxAge=-1 for deletion")
}
foundDeleteCookie := false
for _, cookie := range destroyCookies {
if cookie.MaxAge == -1 {
foundDeleteCookie = true
break
}
}
if !foundDeleteCookie {
t.Error("Expected cookie with MaxAge=-1 to indicate deletion")
}
})
t.Run("update existing session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
// Create initial session
userID1 := "user-1"
store.CreateSession(w, r, userID1)
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Update session with new user ID
w2 := httptest.NewRecorder()
userID2 := "user-2"
err := store.CreateSession(w2, r2, userID2)
if err != nil {
t.Fatalf("Failed to update session: %v", err)
}
// Verify new user ID
cookies2 := w2.Result().Cookies()
r3 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies2 {
r3.AddCookie(cookie)
}
retrievedUserID, err := store.GetUserID(r3)
if err != nil {
t.Fatalf("Failed to get user ID: %v", err)
}
if retrievedUserID != userID2 {
t.Errorf("Expected user ID %s, got %s", userID2, retrievedUserID)
}
})
t.Run("validate session with valid session", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-789"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Create request with session cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Validate session
if !store.ValidateSession(r2) {
t.Error("Expected ValidateSession to return true for valid session")
}
})
t.Run("validate session without session", func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
// Validate session without any cookies
if store.ValidateSession(r) {
t.Error("Expected ValidateSession to return false for missing session")
}
})
t.Run("validate session after logout", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-logout"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
// Get the cookie
cookies := w.Result().Cookies()
r2 := httptest.NewRequest("GET", "/", nil)
for _, cookie := range cookies {
r2.AddCookie(cookie)
}
// Destroy the session
w2 := httptest.NewRecorder()
err = store.DestroySession(w2, r2)
if err != nil {
t.Fatalf("Failed to destroy session: %v", err)
}
// Create a new request without any cookies (simulating browser behavior after logout)
r3 := httptest.NewRequest("GET", "/", nil)
// Validate session should return false
if store.ValidateSession(r3) {
t.Error("Expected ValidateSession to return false after logout")
}
})
t.Run("session cookie has security settings", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
userID := "test-user-security"
err := store.CreateSession(w, r, userID)
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
cookies := w.Result().Cookies()
if len(cookies) == 0 {
t.Fatal("Expected session cookie, got none")
}
cookie := cookies[0]
// Verify HttpOnly flag is set
if !cookie.HttpOnly {
t.Error("Expected HttpOnly flag to be true")
}
// Verify SameSite is set
if cookie.SameSite != http.SameSiteLaxMode {
t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
}
// Verify Path is set
if cookie.Path != "/" {
t.Errorf("Expected Path to be /, got %s", cookie.Path)
}
// Verify MaxAge is set
if cookie.MaxAge != 3600 {
t.Errorf("Expected MaxAge to be 3600, got %d", cookie.MaxAge)
}
})
}

View File

@@ -0,0 +1,78 @@
package auth
import (
"sync"
"time"
)
// MemoryStateStore is an in-memory implementation of StateStore
// Note: This is suitable for development but should be replaced with
// a distributed store (Redis, DynamoDB) for production with multiple servers
type MemoryStateStore struct {
mu sync.RWMutex
states map[string]time.Time
}
// NewMemoryStateStore creates a new in-memory state store
func NewMemoryStateStore() *MemoryStateStore {
store := &MemoryStateStore{
states: make(map[string]time.Time),
}
// Start cleanup goroutine to remove expired states
go store.cleanupExpired()
return store
}
// Set stores a state token with an expiry time
func (s *MemoryStateStore) Set(state string, expiry time.Time) error {
s.mu.Lock()
defer s.mu.Unlock()
s.states[state] = expiry
return nil
}
// Validate checks if a state token is valid and not expired
func (s *MemoryStateStore) Validate(state string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
expiry, exists := s.states[state]
if !exists {
return false, nil
}
if time.Now().After(expiry) {
return false, nil
}
return true, nil
}
// Delete removes a state token from the store
func (s *MemoryStateStore) Delete(state string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.states, state)
return nil
}
// cleanupExpired periodically removes expired state tokens
func (s *MemoryStateStore) cleanupExpired() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
s.mu.Lock()
now := time.Now()
for state, expiry := range s.states {
if now.After(expiry) {
delete(s.states, state)
}
}
s.mu.Unlock()
}
}

View File

@@ -0,0 +1,112 @@
package auth
import (
"testing"
"time"
)
func TestMemoryStateStore(t *testing.T) {
store := NewMemoryStateStore()
t.Run("set and validate state", func(t *testing.T) {
state := "test-state-123"
expiry := time.Now().Add(10 * time.Minute)
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Failed to validate state: %v", err)
}
if !valid {
t.Error("Expected state to be valid")
}
})
t.Run("validate non-existent state", func(t *testing.T) {
valid, err := store.Validate("non-existent")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected non-existent state to be invalid")
}
})
t.Run("validate expired state", func(t *testing.T) {
state := "expired-state"
expiry := time.Now().Add(-1 * time.Minute) // Already expired
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected expired state to be invalid")
}
})
t.Run("delete state", func(t *testing.T) {
state := "delete-test"
expiry := time.Now().Add(10 * time.Minute)
err := store.Set(state, expiry)
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
err = store.Delete(state)
if err != nil {
t.Fatalf("Failed to delete state: %v", err)
}
valid, err := store.Validate(state)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if valid {
t.Error("Expected deleted state to be invalid")
}
})
t.Run("multiple states", func(t *testing.T) {
state1 := "state-1"
state2 := "state-2"
expiry := time.Now().Add(10 * time.Minute)
store.Set(state1, expiry)
store.Set(state2, expiry)
valid1, _ := store.Validate(state1)
valid2, _ := store.Validate(state2)
if !valid1 || !valid2 {
t.Error("Expected both states to be valid")
}
store.Delete(state1)
valid1, _ = store.Validate(state1)
valid2, _ = store.Validate(state2)
if valid1 {
t.Error("Expected state1 to be invalid after deletion")
}
if !valid2 {
t.Error("Expected state2 to still be valid")
}
})
}

View File

@@ -0,0 +1,100 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/oauth2"
"custom-start-page/internal/models"
"custom-start-page/internal/storage"
)
// GoogleUserInfo represents the user info returned by Google
type GoogleUserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
VerifiedEmail bool `json:"verified_email"`
Name string `json:"name"`
Picture string `json:"picture"`
}
// UserService handles user creation and retrieval
type UserService struct {
userRepo *storage.UserRepository
}
// NewUserService creates a new user service
func NewUserService(userRepo *storage.UserRepository) *UserService {
return &UserService{
userRepo: userRepo,
}
}
// GetOrCreateUserFromGoogle fetches user info from Google and creates or retrieves the user
func (s *UserService) GetOrCreateUserFromGoogle(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*models.User, error) {
// Fetch user info from Google
userInfo, err := s.fetchGoogleUserInfo(ctx, token, oauthConfig)
if err != nil {
return nil, fmt.Errorf("failed to fetch Google user info: %w", err)
}
// Check if user already exists
user, err := s.userRepo.GetByOAuthID(ctx, "google", userInfo.ID)
if err == nil {
// User exists, update last login time
user.UpdatedAt = time.Now()
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
return user, nil
}
// User doesn't exist, create new user
user = &models.User{
ID: uuid.New().String(),
Email: userInfo.Email,
OAuthProvider: "google",
OAuthID: userInfo.ID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
// fetchGoogleUserInfo fetches user information from Google's userinfo endpoint
func (s *UserService) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*GoogleUserInfo, error) {
client := oauthConfig.Client(ctx, token)
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to get user info: status %d, body: %s", resp.StatusCode, string(body))
}
var userInfo GoogleUserInfo
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return nil, fmt.Errorf("failed to decode user info: %w", err)
}
if !userInfo.VerifiedEmail {
return nil, fmt.Errorf("email not verified")
}
return &userInfo, nil
}

View File

@@ -0,0 +1,170 @@
package handlers
import (
"html/template"
"log"
"net/http"
"path/filepath"
"custom-start-page/internal/auth"
)
// AuthHandler handles authentication-related HTTP requests
type AuthHandler struct {
oauthService *auth.OAuthService
userService *auth.UserService
sessionStore SessionStore
templates *template.Template
}
// SessionStore manages user sessions
type SessionStore interface {
CreateSession(w http.ResponseWriter, r *http.Request, userID string) error
GetUserID(r *http.Request) (string, error)
DestroySession(w http.ResponseWriter, r *http.Request) error
}
// NewAuthHandler creates a new auth handler
func NewAuthHandler(oauthService *auth.OAuthService, userService *auth.UserService, sessionStore SessionStore) *AuthHandler {
return NewAuthHandlerWithTemplates(oauthService, userService, sessionStore, nil)
}
// NewAuthHandlerWithTemplates creates a new auth handler with custom templates
func NewAuthHandlerWithTemplates(oauthService *auth.OAuthService, userService *auth.UserService, sessionStore SessionStore, templates *template.Template) *AuthHandler {
if templates == nil {
// Parse templates
templates = template.Must(template.ParseGlob(filepath.Join("templates", "*.html")))
template.Must(templates.ParseGlob(filepath.Join("templates", "layouts", "*.html")))
}
return &AuthHandler{
oauthService: oauthService,
userService: userService,
sessionStore: sessionStore,
templates: templates,
}
}
// HandleOAuthInitiate initiates the OAuth flow
// GET /auth/oauth/:provider
func (h *AuthHandler) HandleOAuthInitiate(w http.ResponseWriter, r *http.Request) {
// Extract provider from URL path
provider := r.PathValue("provider")
if provider == "" {
http.Error(w, "Provider not specified", http.StatusBadRequest)
return
}
// Generate OAuth redirect URL
redirectURL, err := h.oauthService.InitiateOAuth(provider)
if err != nil {
log.Printf("Failed to initiate OAuth: %v", err)
http.Error(w, "Failed to initiate OAuth", http.StatusInternalServerError)
return
}
// Redirect to OAuth provider
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}
// HandleOAuthCallback handles the OAuth callback
// GET /auth/callback/:provider
func (h *AuthHandler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
// Extract provider from URL path
provider := r.PathValue("provider")
if provider == "" {
http.Error(w, "Provider not specified", http.StatusBadRequest)
return
}
// Get code and state from query parameters
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
// Check for error from OAuth provider
if errMsg := r.URL.Query().Get("error"); errMsg != "" {
log.Printf("OAuth error: %s", errMsg)
http.Redirect(w, r, "/login?error=oauth_failed", http.StatusTemporaryRedirect)
return
}
http.Error(w, "Authorization code not provided", http.StatusBadRequest)
return
}
if state == "" {
http.Error(w, "State parameter not provided", http.StatusBadRequest)
return
}
// Exchange code for token
token, err := h.oauthService.HandleOAuthCallback(r.Context(), provider, code, state)
if err != nil {
log.Printf("Failed to handle OAuth callback: %v", err)
http.Redirect(w, r, "/login?error=oauth_failed", http.StatusTemporaryRedirect)
return
}
// Get or create user from OAuth provider
user, err := h.userService.GetOrCreateUserFromGoogle(r.Context(), token, h.oauthService.GetGoogleConfig())
if err != nil {
log.Printf("Failed to get or create user: %v", err)
http.Redirect(w, r, "/login?error=user_creation_failed", http.StatusTemporaryRedirect)
return
}
// Create session
if err := h.sessionStore.CreateSession(w, r, user.ID); err != nil {
log.Printf("Failed to create session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
// Redirect to dashboard
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
}
// HandleLogout logs out the user
// POST /logout
func (h *AuthHandler) HandleLogout(w http.ResponseWriter, r *http.Request) {
if err := h.sessionStore.DestroySession(w, r); err != nil {
log.Printf("Failed to destroy session: %v", err)
}
// Redirect to login page
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
}
// HandleLogin displays the login page
// GET /login
func (h *AuthHandler) HandleLogin(w http.ResponseWriter, r *http.Request) {
// Check if user is already logged in
if userID, err := h.sessionStore.GetUserID(r); err == nil && userID != "" {
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
return
}
// Get error message if any
errorMsg := ""
if errParam := r.URL.Query().Get("error"); errParam != "" {
switch errParam {
case "oauth_failed":
errorMsg = "Authentication failed. Please try again."
case "user_creation_failed":
errorMsg = "Failed to create user account. Please try again."
default:
errorMsg = "An error occurred. Please try again."
}
}
// Render login template
data := map[string]interface{}{
"Error": errorMsg,
"OAuthProviders": []map[string]string{}, // Empty for now, can be extended
}
if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil {
log.Printf("Failed to render login template: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,159 @@
package handlers
import (
"html/template"
"net/http"
"net/http/httptest"
"testing"
"custom-start-page/internal/auth"
)
// MockSessionStore is a mock implementation of SessionStore for testing
type MockSessionStore struct {
userID string
shouldError bool
}
func (m *MockSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
m.userID = userID
return nil
}
func (m *MockSessionStore) GetUserID(r *http.Request) (string, error) {
if m.shouldError {
return "", http.ErrNoCookie
}
return m.userID, nil
}
func (m *MockSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
m.userID = ""
return nil
}
func (m *MockSessionStore) ValidateSession(r *http.Request) bool {
return m.userID != ""
}
// createMockTemplate creates a simple mock template for testing
func createMockTemplate() *template.Template {
tmpl := template.New("login.html")
template.Must(tmpl.Parse(`<!DOCTYPE html><html><body>{{if .Error}}<div>{{.Error}}</div>{{end}}<a href="/auth/oauth/google">Login</a></body></html>`))
return tmpl
}
// TestHandleLogin_UnauthenticatedUser tests that unauthenticated users see the login page
func TestHandleLogin_UnauthenticatedUser(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{shouldError: true}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil) // nil repo for this test
mockTemplate := createMockTemplate()
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create request
req := httptest.NewRequest(http.MethodGet, "/login", nil)
w := httptest.NewRecorder()
// Execute
handler.HandleLogin(w, req)
// Assert
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check that response contains login page elements
body := w.Body.String()
if body == "" {
t.Error("Expected non-empty response body")
}
}
// TestHandleLogin_AuthenticatedUser tests that authenticated users are redirected to dashboard
func TestHandleLogin_AuthenticatedUser(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil)
mockTemplate := createMockTemplate()
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create request
req := httptest.NewRequest(http.MethodGet, "/login", nil)
w := httptest.NewRecorder()
// Execute
handler.HandleLogin(w, req)
// Assert
if w.Code != http.StatusTemporaryRedirect {
t.Errorf("Expected status 307, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/dashboard" {
t.Errorf("Expected redirect to /dashboard, got %s", location)
}
}
// TestHandleLogin_WithError tests that error messages are displayed
func TestHandleLogin_WithError(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{shouldError: true}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil)
mockTemplate := createMockTemplate()
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create request with error parameter
req := httptest.NewRequest(http.MethodGet, "/login?error=oauth_failed", nil)
w := httptest.NewRecorder()
// Execute
handler.HandleLogin(w, req)
// Assert
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check that response contains error message
body := w.Body.String()
if body == "" {
t.Error("Expected non-empty response body")
}
}
// TestHandleLogout tests that logout destroys session and redirects to login
func TestHandleLogout(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil)
mockTemplate := createMockTemplate()
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create request
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
w := httptest.NewRecorder()
// Execute
handler.HandleLogout(w, req)
// Assert
if w.Code != http.StatusTemporaryRedirect {
t.Errorf("Expected status 307, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/login" {
t.Errorf("Expected redirect to /login, got %s", location)
}
// Verify session was destroyed
if mockSessionStore.userID != "" {
t.Error("Expected session to be destroyed")
}
}

View File

@@ -0,0 +1,58 @@
package handlers
import (
"html/template"
"log"
"net/http"
"path/filepath"
"custom-start-page/internal/middleware"
)
// DashboardHandler handles dashboard-related HTTP requests
type DashboardHandler struct {
templates *template.Template
}
// NewDashboardHandler creates a new dashboard handler
func NewDashboardHandler() *DashboardHandler {
// Parse templates
templates := template.Must(template.ParseGlob(filepath.Join("templates", "*.html")))
template.Must(templates.ParseGlob(filepath.Join("templates", "layouts", "*.html")))
return &DashboardHandler{
templates: templates,
}
}
// HandleDashboard displays the dashboard page
// GET /dashboard
func (h *DashboardHandler) HandleDashboard(w http.ResponseWriter, r *http.Request) {
// Get user ID from context (set by auth middleware)
userID, ok := middleware.GetUserIDFromContext(r.Context())
if !ok {
http.Error(w, "User ID not found in context", http.StatusInternalServerError)
return
}
// TODO: Fetch user's pages from database
// For now, we'll use mock data
pages := []map[string]interface{}{
{
"ID": "default-page",
"Name": "Home",
"Active": true,
},
}
// Render dashboard template
data := map[string]interface{}{
"UserID": userID,
"Pages": pages,
}
if err := h.templates.ExecuteTemplate(w, "dashboard.html", data); err != nil {
log.Printf("Failed to render dashboard template: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,68 @@
package handlers
import (
"context"
"html/template"
"net/http"
"net/http/httptest"
"testing"
"custom-start-page/internal/middleware"
)
// createMockDashboardTemplate creates a simple mock template for testing
func createMockDashboardTemplate() *template.Template {
tmpl := template.New("dashboard.html")
template.Must(tmpl.Parse(`<!DOCTYPE html><html><body><h1>Dashboard</h1><div>User: {{.UserID}}</div></body></html>`))
return tmpl
}
// TestHandleDashboard_WithAuthenticatedUser tests that authenticated users see the dashboard
func TestHandleDashboard_WithAuthenticatedUser(t *testing.T) {
// Setup
mockTemplate := createMockDashboardTemplate()
handler := &DashboardHandler{
templates: mockTemplate,
}
// Create request with user ID in context
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
ctx := context.WithValue(req.Context(), middleware.GetUserIDContextKey(), "test-user-123")
req = req.WithContext(ctx)
w := httptest.NewRecorder()
// Execute
handler.HandleDashboard(w, req)
// Assert
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
// Check that response contains dashboard content
body := w.Body.String()
if body == "" {
t.Error("Expected non-empty response body")
}
}
// TestHandleDashboard_WithoutUserID tests that requests without user ID fail
func TestHandleDashboard_WithoutUserID(t *testing.T) {
// Setup
mockTemplate := createMockDashboardTemplate()
handler := &DashboardHandler{
templates: mockTemplate,
}
// Create request without user ID in context
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
w := httptest.NewRecorder()
// Execute
handler.HandleDashboard(w, req)
// Assert
if w.Code != http.StatusInternalServerError {
t.Errorf("Expected status 500, got %d", w.Code)
}
}

View File

@@ -0,0 +1,102 @@
package handlers
import (
"net/http"
"net/http/httptest"
"testing"
"custom-start-page/internal/auth"
"custom-start-page/internal/middleware"
)
// TestRedirectFlow_UnauthenticatedToLogin tests that unauthenticated users are redirected to login
func TestRedirectFlow_UnauthenticatedToLogin(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{shouldError: true}
// Create middleware
requireAuth := middleware.RequireAuth(mockSessionStore)
// Create dashboard handler
mockDashboardTemplate := createMockDashboardTemplate()
dashboardHandler := &DashboardHandler{templates: mockDashboardTemplate}
// Wrap dashboard handler with auth middleware
protectedHandler := requireAuth(http.HandlerFunc(dashboardHandler.HandleDashboard))
// Create request to dashboard
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
w := httptest.NewRecorder()
// Execute
protectedHandler.ServeHTTP(w, req)
// Assert - should redirect to login
if w.Code != http.StatusSeeOther {
t.Errorf("Expected status 303, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/login" {
t.Errorf("Expected redirect to /login, got %s", location)
}
}
// TestRedirectFlow_AuthenticatedToDashboard tests that authenticated users accessing login are redirected to dashboard
func TestRedirectFlow_AuthenticatedToDashboard(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil)
mockTemplate := createMockTemplate()
authHandler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create request to login page
req := httptest.NewRequest(http.MethodGet, "/login", nil)
w := httptest.NewRecorder()
// Execute
authHandler.HandleLogin(w, req)
// Assert - should redirect to dashboard
if w.Code != http.StatusTemporaryRedirect {
t.Errorf("Expected status 307, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/dashboard" {
t.Errorf("Expected redirect to /dashboard, got %s", location)
}
}
// TestRedirectFlow_LogoutToLogin tests that logout redirects to login
func TestRedirectFlow_LogoutToLogin(t *testing.T) {
// Setup
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
userService := auth.NewUserService(nil)
mockTemplate := createMockTemplate()
authHandler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
// Create logout request
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
w := httptest.NewRecorder()
// Execute
authHandler.HandleLogout(w, req)
// Assert - should redirect to login
if w.Code != http.StatusTemporaryRedirect {
t.Errorf("Expected status 307, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/login" {
t.Errorf("Expected redirect to /login, got %s", location)
}
// Verify session was destroyed
if mockSessionStore.userID != "" {
t.Error("Expected session to be destroyed after logout")
}
}

View File

@@ -0,0 +1,52 @@
package middleware
import (
"context"
"net/http"
)
// SessionStore defines the interface for session validation
type SessionStore interface {
ValidateSession(r *http.Request) bool
GetUserID(r *http.Request) (string, error)
}
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const userIDContextKey contextKey = "user_id"
// GetUserIDContextKey returns the context key for user ID (for testing)
func GetUserIDContextKey() contextKey {
return userIDContextKey
}
// RequireAuth is a middleware that ensures the user is authenticated
func RequireAuth(sessionStore SessionStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate session
if !sessionStore.ValidateSession(r) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
// Get user ID and add to context
userID, err := sessionStore.GetUserID(r)
if err != nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
// Add user ID to request context
ctx := context.WithValue(r.Context(), userIDContextKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// GetUserIDFromContext retrieves the user ID from the request context
func GetUserIDFromContext(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userIDContextKey).(string)
return userID, ok
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
// mockSessionStore is a mock implementation of SessionStore for testing
type mockSessionStore struct {
valid bool
userID string
}
func (m *mockSessionStore) ValidateSession(r *http.Request) bool {
return m.valid
}
func (m *mockSessionStore) GetUserID(r *http.Request) (string, error) {
if !m.valid {
return "", http.ErrNoCookie
}
return m.userID, nil
}
func TestRequireAuth(t *testing.T) {
t.Run("allows authenticated requests", func(t *testing.T) {
mockStore := &mockSessionStore{
valid: true,
userID: "test-user-123",
}
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID, ok := GetUserIDFromContext(r.Context())
if !ok {
t.Error("Expected user ID in context")
return
}
if userID != "test-user-123" {
t.Errorf("Expected user ID test-user-123, got %s", userID)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/dashboard", nil)
handler.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
if w.Body.String() != "success" {
t.Errorf("Expected body 'success', got %s", w.Body.String())
}
})
t.Run("redirects unauthenticated requests", func(t *testing.T) {
mockStore := &mockSessionStore{
valid: false,
}
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Handler should not be called for unauthenticated request")
}))
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/dashboard", nil)
handler.ServeHTTP(w, r)
if w.Code != http.StatusSeeOther {
t.Errorf("Expected status 303, got %d", w.Code)
}
location := w.Header().Get("Location")
if location != "/login" {
t.Errorf("Expected redirect to /login, got %s", location)
}
})
t.Run("redirects when GetUserID fails", func(t *testing.T) {
mockStore := &mockSessionStore{
valid: true, // ValidateSession returns true
userID: "", // But GetUserID will fail
}
// Override GetUserID to return error
mockStore.valid = false
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Handler should not be called when GetUserID fails")
}))
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/dashboard", nil)
handler.ServeHTTP(w, r)
if w.Code != http.StatusSeeOther {
t.Errorf("Expected status 303, got %d", w.Code)
}
})
}
func TestGetUserIDFromContext(t *testing.T) {
t.Run("retrieves user ID from context", func(t *testing.T) {
mockStore := &mockSessionStore{
valid: true,
userID: "context-user-456",
}
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID, ok := GetUserIDFromContext(r.Context())
if !ok {
t.Error("Expected user ID in context")
return
}
if userID != "context-user-456" {
t.Errorf("Expected user ID context-user-456, got %s", userID)
}
w.WriteHeader(http.StatusOK)
}))
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/test", nil)
handler.ServeHTTP(w, r)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
})
t.Run("returns false when user ID not in context", func(t *testing.T) {
r := httptest.NewRequest("GET", "/test", nil)
_, ok := GetUserIDFromContext(r.Context())
if ok {
t.Error("Expected ok to be false when user ID not in context")
}
})
}

15
internal/models/user.go Normal file
View File

@@ -0,0 +1,15 @@
package models
import (
"time"
)
// User represents a user in the system
type User struct {
ID string `dynamodbav:"user_id" json:"id"`
Email string `dynamodbav:"email" json:"email"`
OAuthProvider string `dynamodbav:"oauth_provider" json:"oauth_provider"`
OAuthID string `dynamodbav:"oauth_id" json:"oauth_id"`
CreatedAt time.Time `dynamodbav:"created_at" json:"created_at"`
UpdatedAt time.Time `dynamodbav:"updated_at" json:"updated_at"`
}

124
internal/storage/README.md Normal file
View File

@@ -0,0 +1,124 @@
# DynamoDB Storage Service
This package provides an enhanced DynamoDB client wrapper with the following features:
## Features
### 1. Connection Pooling
The client uses the AWS SDK's default HTTP client which includes connection pooling automatically. This ensures efficient reuse of TCP connections to DynamoDB.
### 2. Retry Logic with Exponential Backoff
The client is configured with automatic retry logic:
- **Max Attempts**: 5 retries
- **Max Backoff**: 20 seconds
- **Strategy**: Exponential backoff with jitter to prevent thundering herd
This handles transient failures gracefully and improves reliability.
### 3. Transaction Support
The `TransactWriteItems` method provides ACID transaction support for multiple write operations:
```go
err := client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
Put: &types.Put{
TableName: aws.String("MyTable"),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: "item1"},
},
},
},
// More items...
},
})
```
### 4. Batch Operations
The client provides batch read and write operations with automatic retry of unprocessed items:
#### BatchGetItems
Retrieves multiple items in a single request:
```go
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
RequestItems: map[string]types.KeysAndAttributes{
"MyTable": {
Keys: []map[string]types.AttributeValue{
{"id": &types.AttributeValueMemberS{Value: "item1"}},
{"id": &types.AttributeValueMemberS{Value: "item2"}},
},
},
},
})
```
#### BatchWriteItems
Writes multiple items in a single request:
```go
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
"MyTable": {
{
PutRequest: &types.PutRequest{
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: "item1"},
},
},
},
},
},
})
```
Both batch operations automatically handle unprocessed items with exponential backoff retry logic.
## Standard Operations
The client also provides wrapped versions of standard DynamoDB operations with automatic retry:
- `PutItem` - Put a single item
- `GetItem` - Get a single item
- `UpdateItem` - Update a single item
- `DeleteItem` - Delete a single item
- `Query` - Query items
## Usage
### Creating a Client
```go
ctx := context.Background()
client, err := storage.NewDynamoDBClient(ctx, "http://localhost:8000")
if err != nil {
log.Fatal(err)
}
```
For production (AWS DynamoDB), pass an empty string for the endpoint:
```go
client, err := storage.NewDynamoDBClient(ctx, "")
```
### Testing
The package includes comprehensive tests that can be run against DynamoDB Local:
1. Start DynamoDB Local:
```bash
docker-compose up -d
```
2. Run tests:
```bash
go test -v ./internal/storage
```
Tests will automatically skip if DynamoDB is not available.
## Requirements Addressed
This implementation addresses the following requirements from the spec:
- **Requirement 8.1**: Immediate persistence of all changes
- **Requirement 8.8**: Efficient scaling for 10,000+ items per user
- **Design requirement**: Retry logic with exponential backoff
- **Design requirement**: Transaction support for atomic operations
- **Design requirement**: Batch operations for efficient bulk reads/writes
- **Design requirement**: Connection pooling for performance

View File

@@ -0,0 +1,248 @@
package storage
import (
"context"
"fmt"
"math"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
// DynamoDBClient wraps the AWS DynamoDB client with enhanced features
type DynamoDBClient struct {
client *dynamodb.Client
}
// NewDynamoDBClient creates a new DynamoDB client with connection pooling and retry logic
func NewDynamoDBClient(ctx context.Context, endpoint string) (*DynamoDBClient, error) {
// Configure retry strategy with exponential backoff
retryer := retry.NewStandard(func(o *retry.StandardOptions) {
o.MaxAttempts = 5
o.MaxBackoff = 20 * time.Second
// Use exponential backoff with jitter
o.Backoff = retry.NewExponentialJitterBackoff(20 * time.Second)
})
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRetryer(func() aws.Retryer {
return retryer
}),
// Connection pooling is handled by the HTTP client
config.WithHTTPClient(nil), // Uses default HTTP client with connection pooling
)
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %w", err)
}
// Override endpoint if provided (for local DynamoDB)
if endpoint != "" {
cfg.BaseEndpoint = aws.String(endpoint)
}
client := dynamodb.NewFromConfig(cfg)
return &DynamoDBClient{
client: client,
}, nil
}
// CreateUsersTable creates the Users table in DynamoDB
func (db *DynamoDBClient) CreateUsersTable(ctx context.Context) error {
tableName := "Users"
// Check if table already exists
_, err := db.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
})
if err == nil {
// Table already exists
return nil
}
// Create table
_, err = db.client.CreateTable(ctx, &dynamodb.CreateTableInput{
TableName: aws.String(tableName),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("user_id"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("user_id"),
KeyType: types.KeyTypeHash,
},
},
BillingMode: types.BillingModePayPerRequest,
})
if err != nil {
return fmt.Errorf("failed to create Users table: %w", err)
}
// Wait for table to be active
waiter := dynamodb.NewTableExistsWaiter(db.client)
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
}, 5*60) // 5 minutes timeout
if err != nil {
return fmt.Errorf("failed waiting for Users table to be active: %w", err)
}
return nil
}
// GetClient returns the underlying DynamoDB client
func (db *DynamoDBClient) GetClient() *dynamodb.Client {
return db.client
}
// TransactWriteItems executes a transactional write operation with automatic retry
func (db *DynamoDBClient) TransactWriteItems(ctx context.Context, input *dynamodb.TransactWriteItemsInput) error {
_, err := db.client.TransactWriteItems(ctx, input)
if err != nil {
return fmt.Errorf("transaction write failed: %w", err)
}
return nil
}
// BatchGetItems retrieves multiple items in a single batch operation
func (db *DynamoDBClient) BatchGetItems(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
output, err := db.client.BatchGetItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("batch get failed: %w", err)
}
// Handle unprocessed keys with exponential backoff
if len(output.UnprocessedKeys) > 0 {
return db.retryUnprocessedKeys(ctx, output)
}
return output, nil
}
// BatchWriteItems writes multiple items in a single batch operation
func (db *DynamoDBClient) BatchWriteItems(ctx context.Context, input *dynamodb.BatchWriteItemInput) error {
output, err := db.client.BatchWriteItem(ctx, input)
if err != nil {
return fmt.Errorf("batch write failed: %w", err)
}
// Handle unprocessed items with exponential backoff
if len(output.UnprocessedItems) > 0 {
return db.retryUnprocessedWrites(ctx, output.UnprocessedItems)
}
return nil
}
// retryUnprocessedKeys retries unprocessed keys from BatchGetItem with exponential backoff
func (db *DynamoDBClient) retryUnprocessedKeys(ctx context.Context, output *dynamodb.BatchGetItemOutput) (*dynamodb.BatchGetItemOutput, error) {
maxRetries := 5
backoff := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries && len(output.UnprocessedKeys) > 0; attempt++ {
// Wait with exponential backoff
time.Sleep(backoff)
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
// Retry unprocessed keys
retryOutput, err := db.client.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{
RequestItems: output.UnprocessedKeys,
})
if err != nil {
return nil, fmt.Errorf("retry batch get failed: %w", err)
}
// Merge responses
for table, items := range retryOutput.Responses {
output.Responses[table] = append(output.Responses[table], items...)
}
output.UnprocessedKeys = retryOutput.UnprocessedKeys
}
if len(output.UnprocessedKeys) > 0 {
return output, fmt.Errorf("failed to process all keys after %d retries", maxRetries)
}
return output, nil
}
// retryUnprocessedWrites retries unprocessed items from BatchWriteItem with exponential backoff
func (db *DynamoDBClient) retryUnprocessedWrites(ctx context.Context, unprocessedItems map[string][]types.WriteRequest) error {
maxRetries := 5
backoff := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries && len(unprocessedItems) > 0; attempt++ {
// Wait with exponential backoff
time.Sleep(backoff)
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
// Retry unprocessed items
output, err := db.client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: unprocessedItems,
})
if err != nil {
return fmt.Errorf("retry batch write failed: %w", err)
}
unprocessedItems = output.UnprocessedItems
}
if len(unprocessedItems) > 0 {
return fmt.Errorf("failed to process all items after %d retries", maxRetries)
}
return nil
}
// PutItem puts a single item with automatic retry
func (db *DynamoDBClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) error {
_, err := db.client.PutItem(ctx, input)
if err != nil {
return fmt.Errorf("put item failed: %w", err)
}
return nil
}
// GetItem retrieves a single item with automatic retry
func (db *DynamoDBClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
output, err := db.client.GetItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("get item failed: %w", err)
}
return output, nil
}
// Query executes a query operation with automatic retry
func (db *DynamoDBClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
output, err := db.client.Query(ctx, input)
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
return output, nil
}
// UpdateItem updates a single item with automatic retry
func (db *DynamoDBClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
output, err := db.client.UpdateItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("update item failed: %w", err)
}
return output, nil
}
// DeleteItem deletes a single item with automatic retry
func (db *DynamoDBClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput) error {
_, err := db.client.DeleteItem(ctx, input)
if err != nil {
return fmt.Errorf("delete item failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,452 @@
package storage
import (
"context"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
func setupTestClient(t *testing.T) (*DynamoDBClient, context.Context) {
t.Helper()
// Set dummy AWS credentials for local DynamoDB
os.Setenv("AWS_ACCESS_KEY_ID", "dummy")
os.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
os.Setenv("AWS_REGION", "us-east-1")
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
if endpoint == "" {
endpoint = "http://localhost:8000"
}
ctx := context.Background()
client, err := NewDynamoDBClient(ctx, endpoint)
if err != nil {
t.Skipf("Skipping test: DynamoDB not available: %v", err)
}
// Test connection by listing tables
_, err = client.client.ListTables(ctx, &dynamodb.ListTablesInput{})
if err != nil {
t.Skipf("Skipping test: Cannot connect to DynamoDB: %v", err)
}
return client, ctx
}
func TestNewDynamoDBClient(t *testing.T) {
client, _ := setupTestClient(t)
if client == nil {
t.Fatal("Expected non-nil client")
}
if client.client == nil {
t.Fatal("Expected non-nil underlying client")
}
}
func TestTransactWriteItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestTransactions"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Test transaction with two puts
testID1 := "test-txn-1"
testID2 := "test-txn-2"
input := &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
Put: &types.Put{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID1},
"value": &types.AttributeValueMemberS{Value: "value1"},
},
},
},
{
Put: &types.Put{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID2},
"value": &types.AttributeValueMemberS{Value: "value2"},
},
},
},
},
}
err := client.TransactWriteItems(ctx, input)
if err != nil {
t.Fatalf("TransactWriteItems failed: %v", err)
}
// Verify both items were written
output1, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID1},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output1.Item) == 0 {
t.Fatal("Expected item to exist after transaction")
}
output2, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID2},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output2.Item) == 0 {
t.Fatal("Expected item to exist after transaction")
}
}
func TestBatchGetItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestBatchGet"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put test items
testIDs := []string{"batch-1", "batch-2", "batch-3"}
for _, id := range testIDs {
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "test-value"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
}
// Batch get items
keys := make([]map[string]types.AttributeValue, len(testIDs))
for i, id := range testIDs {
keys[i] = map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
}
}
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
RequestItems: map[string]types.KeysAndAttributes{
tableName: {
Keys: keys,
},
},
})
if err != nil {
t.Fatalf("BatchGetItems failed: %v", err)
}
if len(output.Responses[tableName]) != len(testIDs) {
t.Fatalf("Expected %d items, got %d", len(testIDs), len(output.Responses[tableName]))
}
}
func TestBatchWriteItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestBatchWrite"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Batch write items
testIDs := []string{"write-1", "write-2", "write-3"}
writeRequests := make([]types.WriteRequest, len(testIDs))
for i, id := range testIDs {
writeRequests[i] = types.WriteRequest{
PutRequest: &types.PutRequest{
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "batch-value"},
},
},
}
}
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
tableName: writeRequests,
},
})
if err != nil {
t.Fatalf("BatchWriteItems failed: %v", err)
}
// Verify items were written
for _, id := range testIDs {
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) == 0 {
t.Fatalf("Expected item %s to exist after batch write", id)
}
}
}
func TestPutAndGetItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestPutGet"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put item
testID := "put-get-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "test-value"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Get item
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) == 0 {
t.Fatal("Expected item to exist")
}
valueAttr, ok := output.Item["value"]
if !ok {
t.Fatal("Expected 'value' attribute")
}
value := valueAttr.(*types.AttributeValueMemberS).Value
if value != "test-value" {
t.Fatalf("Expected value 'test-value', got '%s'", value)
}
}
func TestUpdateItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestUpdate"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put initial item
testID := "update-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "initial"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Update item
_, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
UpdateExpression: aws.String("SET #v = :val"),
ExpressionAttributeNames: map[string]string{
"#v": "value",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":val": &types.AttributeValueMemberS{Value: "updated"},
},
})
if err != nil {
t.Fatalf("UpdateItem failed: %v", err)
}
// Verify update
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
value := output.Item["value"].(*types.AttributeValueMemberS).Value
if value != "updated" {
t.Fatalf("Expected value 'updated', got '%s'", value)
}
}
func TestDeleteItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestDelete"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put item
testID := "delete-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "test"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Delete item
err = client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("DeleteItem failed: %v", err)
}
// Verify deletion
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) != 0 {
t.Fatal("Expected item to be deleted")
}
}
func TestQuery(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestQuery"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put test items
testIDs := []string{"query-1", "query-2", "query-3"}
for _, id := range testIDs {
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "test"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
}
// Query for specific item
output, err := client.Query(ctx, &dynamodb.QueryInput{
TableName: aws.String(tableName),
KeyConditionExpression: aws.String("id = :id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":id": &types.AttributeValueMemberS{Value: "query-1"},
},
})
if err != nil {
t.Fatalf("Query failed: %v", err)
}
if len(output.Items) != 1 {
t.Fatalf("Expected 1 item, got %d", len(output.Items))
}
}
// Helper functions
func createTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
t.Helper()
_, err := client.client.CreateTable(ctx, &dynamodb.CreateTableInput{
TableName: aws.String(tableName),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("id"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("id"),
KeyType: types.KeyTypeHash,
},
},
BillingMode: types.BillingModePayPerRequest,
})
if err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Wait for table to be active
waiter := dynamodb.NewTableExistsWaiter(client.client)
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
}, 30*time.Second)
if err != nil {
t.Fatalf("Failed waiting for table to be active: %v", err)
}
}
func deleteTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
t.Helper()
_, err := client.client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
TableName: aws.String(tableName),
})
if err != nil {
t.Logf("Warning: Failed to delete test table: %v", err)
}
}

View File

@@ -0,0 +1,115 @@
package storage
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"custom-start-page/internal/models"
)
// UserRepository handles user data operations
type UserRepository struct {
client *DynamoDBClient
tableName string
}
// NewUserRepository creates a new user repository
func NewUserRepository(client *DynamoDBClient, tableName string) *UserRepository {
return &UserRepository{
client: client,
tableName: tableName,
}
}
// Create creates a new user in the database
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(r.tableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID retrieves a user by their ID
func (r *UserRepository) GetByID(ctx context.Context, userID string) (*models.User, error) {
result, err := r.client.GetClient().GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(r.tableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if result.Item == nil {
return nil, fmt.Errorf("user not found")
}
var user models.User
if err := attributevalue.UnmarshalMap(result.Item, &user); err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// GetByOAuthID retrieves a user by their OAuth provider and OAuth ID
func (r *UserRepository) GetByOAuthID(ctx context.Context, provider, oauthID string) (*models.User, error) {
// Use a scan with filter for now
// In production, consider adding a GSI on oauth_provider + oauth_id
result, err := r.client.GetClient().Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String(r.tableName),
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :oauth_id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":provider": &types.AttributeValueMemberS{Value: provider},
":oauth_id": &types.AttributeValueMemberS{Value: oauthID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to scan users: %w", err)
}
if len(result.Items) == 0 {
return nil, fmt.Errorf("user not found")
}
var user models.User
if err := attributevalue.UnmarshalMap(result.Items[0], &user); err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// Update updates an existing user
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(r.tableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}

View File

@@ -0,0 +1,143 @@
package storage
import (
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
"custom-start-page/internal/models"
)
const usersTableName = "Users"
// UserStorage handles user data operations
type UserStorage struct {
db *DynamoDBClient
}
// NewUserStorage creates a new UserStorage instance
func NewUserStorage(db *DynamoDBClient) *UserStorage {
return &UserStorage{db: db}
}
// CreateUser creates a new user in DynamoDB
func (s *UserStorage) CreateUser(ctx context.Context, email, oauthProvider, oauthID string) (*models.User, error) {
user := &models.User{
ID: uuid.New().String(),
Email: email,
OAuthProvider: oauthProvider,
OAuthID: oauthID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
item, err := attributevalue.MarshalMap(user)
if err != nil {
return nil, fmt.Errorf("failed to marshal user: %w", err)
}
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(usersTableName),
Item: item,
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
// GetUserByID retrieves a user by their ID
func (s *UserStorage) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
result, err := s.db.client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(usersTableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if result.Item == nil {
return nil, fmt.Errorf("user not found")
}
var user models.User
err = attributevalue.UnmarshalMap(result.Item, &user)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// GetUserByOAuth retrieves a user by their OAuth provider and ID
func (s *UserStorage) GetUserByOAuth(ctx context.Context, oauthProvider, oauthID string) (*models.User, error) {
// Since we don't have a GSI for oauth_provider + oauth_id, we need to scan
// In production, you might want to add a GSI for this access pattern
result, err := s.db.client.Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String(usersTableName),
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":provider": &types.AttributeValueMemberS{Value: oauthProvider},
":id": &types.AttributeValueMemberS{Value: oauthID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to scan for user: %w", err)
}
if len(result.Items) == 0 {
return nil, fmt.Errorf("user not found")
}
var user models.User
err = attributevalue.UnmarshalMap(result.Items[0], &user)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// UpdateUser updates an existing user
func (s *UserStorage) UpdateUser(ctx context.Context, user *models.User) error {
user.UpdatedAt = time.Now()
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(usersTableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// DeleteUser deletes a user by their ID
func (s *UserStorage) DeleteUser(ctx context.Context, userID string) error {
_, err := s.db.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(usersTableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}

View File

@@ -0,0 +1,228 @@
package storage
import (
"context"
"os"
"testing"
"time"
)
func setupTestDB(t *testing.T) *DynamoDBClient {
ctx := context.Background()
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
if endpoint == "" {
endpoint = "http://localhost:8000"
}
db, err := NewDynamoDBClient(ctx, endpoint)
if err != nil {
t.Fatalf("Failed to create DynamoDB client: %v", err)
}
// Create Users table
if err := db.CreateUsersTable(ctx); err != nil {
t.Fatalf("Failed to create Users table: %v", err)
}
return db
}
func TestCreateUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
if user.ID == "" {
t.Error("User ID should not be empty")
}
if user.Email != "test@example.com" {
t.Errorf("Expected email 'test@example.com', got '%s'", user.Email)
}
if user.OAuthProvider != "google" {
t.Errorf("Expected oauth_provider 'google', got '%s'", user.OAuthProvider)
}
if user.OAuthID != "google123" {
t.Errorf("Expected oauth_id 'google123', got '%s'", user.OAuthID)
}
if user.CreatedAt.IsZero() {
t.Error("CreatedAt should not be zero")
}
if user.UpdatedAt.IsZero() {
t.Error("UpdatedAt should not be zero")
}
}
func TestGetUserByID(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Retrieve the user
retrievedUser, err := storage.GetUserByID(ctx, createdUser.ID)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
if retrievedUser.ID != createdUser.ID {
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
}
if retrievedUser.Email != createdUser.Email {
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
}
}
func TestGetUserByID_NotFound(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
_, err := storage.GetUserByID(ctx, "nonexistent-id")
if err == nil {
t.Error("Expected error when getting nonexistent user")
}
}
func TestGetUserByOAuth(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Retrieve the user by OAuth
retrievedUser, err := storage.GetUserByOAuth(ctx, "google", "google123")
if err != nil {
t.Fatalf("Failed to get user by OAuth: %v", err)
}
if retrievedUser.ID != createdUser.ID {
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
}
if retrievedUser.Email != createdUser.Email {
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
}
}
func TestGetUserByOAuth_NotFound(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
_, err := storage.GetUserByOAuth(ctx, "google", "nonexistent")
if err == nil {
t.Error("Expected error when getting nonexistent user")
}
}
func TestUpdateUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
originalUpdatedAt := user.UpdatedAt
time.Sleep(10 * time.Millisecond) // Ensure time difference
// Update the user
user.Email = "updated@example.com"
err = storage.UpdateUser(ctx, user)
if err != nil {
t.Fatalf("Failed to update user: %v", err)
}
// Retrieve and verify
updatedUser, err := storage.GetUserByID(ctx, user.ID)
if err != nil {
t.Fatalf("Failed to get updated user: %v", err)
}
if updatedUser.Email != "updated@example.com" {
t.Errorf("Expected email 'updated@example.com', got '%s'", updatedUser.Email)
}
if !updatedUser.UpdatedAt.After(originalUpdatedAt) {
t.Error("UpdatedAt should be updated")
}
}
func TestDeleteUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Delete the user
err = storage.DeleteUser(ctx, user.ID)
if err != nil {
t.Fatalf("Failed to delete user: %v", err)
}
// Verify deletion
_, err = storage.GetUserByID(ctx, user.ID)
if err == nil {
t.Error("Expected error when getting deleted user")
}
}
func TestCreateUser_MultipleUsers(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create multiple users
user1, err := storage.CreateUser(ctx, "user1@example.com", "google", "google1")
if err != nil {
t.Fatalf("Failed to create user1: %v", err)
}
user2, err := storage.CreateUser(ctx, "user2@example.com", "github", "github1")
if err != nil {
t.Fatalf("Failed to create user2: %v", err)
}
// Verify both users exist and are different
if user1.ID == user2.ID {
t.Error("User IDs should be unique")
}
retrievedUser1, err := storage.GetUserByID(ctx, user1.ID)
if err != nil {
t.Fatalf("Failed to get user1: %v", err)
}
retrievedUser2, err := storage.GetUserByID(ctx, user2.ID)
if err != nil {
t.Fatalf("Failed to get user2: %v", err)
}
if retrievedUser1.Email != "user1@example.com" {
t.Errorf("Expected user1 email 'user1@example.com', got '%s'", retrievedUser1.Email)
}
if retrievedUser2.Email != "user2@example.com" {
t.Errorf("Expected user2 email 'user2@example.com', got '%s'", retrievedUser2.Email)
}
}

View File

@@ -0,0 +1,43 @@
package testing
import (
"testing"
"github.com/leanovate/gopter"
)
// PropertyTestConfig holds configuration for property-based tests
type PropertyTestConfig struct {
MinSuccessfulTests int
MaxSize int
Workers int
}
// DefaultPropertyTestConfig returns default configuration for property tests
func DefaultPropertyTestConfig() *PropertyTestConfig {
return &PropertyTestConfig{
MinSuccessfulTests: 100,
MaxSize: 100,
Workers: 4,
}
}
// RunPropertyTest runs a property-based test with the given configuration
func RunPropertyTest(t *testing.T, config *PropertyTestConfig, testFunc func(*gopter.Properties)) {
if config == nil {
config = DefaultPropertyTestConfig()
}
parameters := gopter.DefaultTestParameters()
parameters.MinSuccessfulTests = config.MinSuccessfulTests
parameters.MaxSize = config.MaxSize
parameters.Workers = config.Workers
properties := gopter.NewProperties(parameters)
// Call the test function to add properties
testFunc(properties)
// Run the tests
properties.TestingRun(t)
}

View File

@@ -0,0 +1,55 @@
package testing
import (
"testing"
"github.com/leanovate/gopter"
"github.com/leanovate/gopter/gen"
"github.com/leanovate/gopter/prop"
)
func TestDefaultPropertyTestConfig(t *testing.T) {
cfg := DefaultPropertyTestConfig()
if cfg.MinSuccessfulTests != 100 {
t.Errorf("Expected MinSuccessfulTests to be 100, got %d", cfg.MinSuccessfulTests)
}
if cfg.MaxSize != 100 {
t.Errorf("Expected MaxSize to be 100, got %d", cfg.MaxSize)
}
if cfg.Workers != 4 {
t.Errorf("Expected Workers to be 4, got %d", cfg.Workers)
}
}
func TestRunPropertyTest(t *testing.T) {
// Simple property: for all integers, x + 0 = x
RunPropertyTest(t, nil, func(properties *gopter.Properties) {
properties.Property("addition identity", prop.ForAll(
func(x int) bool {
return x+0 == x
},
gen.Int(),
))
})
}
func TestRunPropertyTestWithCustomConfig(t *testing.T) {
cfg := &PropertyTestConfig{
MinSuccessfulTests: 50,
MaxSize: 50,
Workers: 2,
}
// Simple property: for all strings, len(s) >= 0
RunPropertyTest(t, cfg, func(properties *gopter.Properties) {
properties.Property("string length non-negative", prop.ForAll(
func(s string) bool {
return len(s) >= 0
},
gen.AnyString(),
))
})
}