Initial commit: Custom Start Page application with authentication and DynamoDB storage
This commit is contained in:
109
internal/auth/oauth.go
Normal file
109
internal/auth/oauth.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
googleConfig *oauth2.Config
|
||||
stateStore StateStore
|
||||
}
|
||||
|
||||
// StateStore manages OAuth state tokens for CSRF protection
|
||||
type StateStore interface {
|
||||
Set(state string, expiry time.Time) error
|
||||
Validate(state string) (bool, error)
|
||||
Delete(state string) error
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(googleClientID, googleClientSecret, googleRedirectURL string, stateStore StateStore) *OAuthService {
|
||||
googleConfig := &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: googleRedirectURL,
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
return &OAuthService{
|
||||
googleConfig: googleConfig,
|
||||
stateStore: stateStore,
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateOAuth starts the OAuth flow and returns the redirect URL
|
||||
func (s *OAuthService) InitiateOAuth(provider string) (string, error) {
|
||||
if provider != "google" {
|
||||
return "", fmt.Errorf("unsupported OAuth provider: %s", provider)
|
||||
}
|
||||
|
||||
// Generate random state token for CSRF protection
|
||||
state, err := generateStateToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate state token: %w", err)
|
||||
}
|
||||
|
||||
// Store state with 10 minute expiry
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
if err := s.stateStore.Set(state, expiry); err != nil {
|
||||
return "", fmt.Errorf("failed to store state token: %w", err)
|
||||
}
|
||||
|
||||
// Generate authorization URL
|
||||
url := s.googleConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// HandleOAuthCallback processes the OAuth callback and exchanges the code for a token
|
||||
func (s *OAuthService) HandleOAuthCallback(ctx context.Context, provider, code, state string) (*oauth2.Token, error) {
|
||||
if provider != "google" {
|
||||
return nil, fmt.Errorf("unsupported OAuth provider: %s", provider)
|
||||
}
|
||||
|
||||
// Validate state token
|
||||
valid, err := s.stateStore.Validate(state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate state token: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("invalid or expired state token")
|
||||
}
|
||||
|
||||
// Delete state token after validation
|
||||
_ = s.stateStore.Delete(state)
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := s.googleConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetGoogleConfig returns the Google OAuth config for accessing user info
|
||||
func (s *OAuthService) GetGoogleConfig() *oauth2.Config {
|
||||
return s.googleConfig
|
||||
}
|
||||
|
||||
// generateStateToken generates a cryptographically secure random state token
|
||||
func generateStateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
136
internal/auth/oauth_test.go
Normal file
136
internal/auth/oauth_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInitiateOAuth(t *testing.T) {
|
||||
stateStore := NewMemoryStateStore()
|
||||
service := NewOAuthService(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost:8080/auth/callback/google",
|
||||
stateStore,
|
||||
)
|
||||
|
||||
t.Run("successful OAuth initiation", func(t *testing.T) {
|
||||
redirectURL, err := service.InitiateOAuth("google")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if redirectURL == "" {
|
||||
t.Fatal("Expected redirect URL, got empty string")
|
||||
}
|
||||
|
||||
// Check that URL contains expected components
|
||||
if !strings.Contains(redirectURL, "accounts.google.com") {
|
||||
t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "client_id=test-client-id") {
|
||||
t.Errorf("Expected client_id in URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "redirect_uri=") {
|
||||
t.Errorf("Expected redirect_uri in URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "state=") {
|
||||
t.Errorf("Expected state parameter in URL, got %s", redirectURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported provider", func(t *testing.T) {
|
||||
_, err := service.InitiateOAuth("unsupported")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for unsupported provider, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
||||
t.Errorf("Expected unsupported provider error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleOAuthCallback(t *testing.T) {
|
||||
stateStore := NewMemoryStateStore()
|
||||
service := NewOAuthService(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost:8080/auth/callback/google",
|
||||
stateStore,
|
||||
)
|
||||
|
||||
t.Run("invalid state token", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid state, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid or expired state token") {
|
||||
t.Errorf("Expected invalid state error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid state token but invalid code", func(t *testing.T) {
|
||||
// Store a valid state token
|
||||
state := "valid-state-token"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
if err := stateStore.Set(state, expiry); err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
|
||||
|
||||
// This should fail because the code is invalid (can't exchange with Google)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid code, got nil")
|
||||
}
|
||||
|
||||
// The state should be deleted even on error
|
||||
valid, _ := stateStore.Validate(state)
|
||||
if valid {
|
||||
t.Error("Expected state to be deleted after callback")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported provider", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for unsupported provider, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
||||
t.Errorf("Expected unsupported provider error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateStateToken(t *testing.T) {
|
||||
t.Run("generates unique tokens", func(t *testing.T) {
|
||||
token1, err := generateStateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
token2, err := generateStateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
if token1 == token2 {
|
||||
t.Error("Expected unique tokens, got duplicates")
|
||||
}
|
||||
|
||||
if len(token1) == 0 {
|
||||
t.Error("Expected non-empty token")
|
||||
}
|
||||
})
|
||||
}
|
||||
88
internal/auth/session_store.go
Normal file
88
internal/auth/session_store.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const sessionName = "startpage_session"
|
||||
const userIDKey = "user_id"
|
||||
|
||||
// CookieSessionStore implements SessionStore using gorilla/sessions
|
||||
type CookieSessionStore struct {
|
||||
store *sessions.CookieStore
|
||||
}
|
||||
|
||||
// NewCookieSessionStore creates a new cookie-based session store
|
||||
func NewCookieSessionStore(secretKey string, maxAge int) *CookieSessionStore {
|
||||
store := sessions.NewCookieStore([]byte(secretKey))
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to true in production with HTTPS
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
return &CookieSessionStore{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new session for the user
|
||||
func (s *CookieSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
// If there's an error getting the session, create a new one
|
||||
session, _ = s.store.New(r, sessionName)
|
||||
}
|
||||
|
||||
session.Values[userIDKey] = userID
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserID retrieves the user ID from the session
|
||||
func (s *CookieSessionStore) GetUserID(r *http.Request) (string, error) {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
userID, ok := session.Values[userIDKey].(string)
|
||||
if !ok || userID == "" {
|
||||
return "", fmt.Errorf("user ID not found in session")
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// ValidateSession checks if a valid session exists for the request
|
||||
func (s *CookieSessionStore) ValidateSession(r *http.Request) bool {
|
||||
_, err := s.GetUserID(r)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// DestroySession destroys the user's session
|
||||
func (s *CookieSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
// Session doesn't exist or is invalid, nothing to destroy
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set MaxAge to -1 to delete the cookie
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to destroy session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
237
internal/auth/session_store_test.go
Normal file
237
internal/auth/session_store_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCookieSessionStore(t *testing.T) {
|
||||
store := NewCookieSessionStore("test-secret-key", 3600)
|
||||
|
||||
t.Run("create and retrieve session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-123"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie from the response
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("Expected session cookie, got none")
|
||||
}
|
||||
|
||||
// Create a new request with the cookie
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
retrievedUserID, err := store.GetUserID(r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user ID: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUserID != userID {
|
||||
t.Errorf("Expected user ID %s, got %s", userID, retrievedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get user ID without session", func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
_, err := store.GetUserID(r)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting user ID without session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("destroy session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-456"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
w2 := httptest.NewRecorder()
|
||||
err = store.DestroySession(w2, r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to destroy session: %v", err)
|
||||
}
|
||||
|
||||
// Check that the cookie has MaxAge set to -1 (deletion marker)
|
||||
destroyCookies := w2.Result().Cookies()
|
||||
if len(destroyCookies) == 0 {
|
||||
t.Fatal("Expected cookie with MaxAge=-1 for deletion")
|
||||
}
|
||||
|
||||
foundDeleteCookie := false
|
||||
for _, cookie := range destroyCookies {
|
||||
if cookie.MaxAge == -1 {
|
||||
foundDeleteCookie = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundDeleteCookie {
|
||||
t.Error("Expected cookie with MaxAge=-1 to indicate deletion")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("update existing session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Create initial session
|
||||
userID1 := "user-1"
|
||||
store.CreateSession(w, r, userID1)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Update session with new user ID
|
||||
w2 := httptest.NewRecorder()
|
||||
userID2 := "user-2"
|
||||
err := store.CreateSession(w2, r2, userID2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update session: %v", err)
|
||||
}
|
||||
|
||||
// Verify new user ID
|
||||
cookies2 := w2.Result().Cookies()
|
||||
r3 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies2 {
|
||||
r3.AddCookie(cookie)
|
||||
}
|
||||
|
||||
retrievedUserID, err := store.GetUserID(r3)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user ID: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUserID != userID2 {
|
||||
t.Errorf("Expected user ID %s, got %s", userID2, retrievedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session with valid session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-789"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Create request with session cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Validate session
|
||||
if !store.ValidateSession(r2) {
|
||||
t.Error("Expected ValidateSession to return true for valid session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session without session", func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Validate session without any cookies
|
||||
if store.ValidateSession(r) {
|
||||
t.Error("Expected ValidateSession to return false for missing session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session after logout", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-logout"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
w2 := httptest.NewRecorder()
|
||||
err = store.DestroySession(w2, r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to destroy session: %v", err)
|
||||
}
|
||||
|
||||
// Create a new request without any cookies (simulating browser behavior after logout)
|
||||
r3 := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Validate session should return false
|
||||
if store.ValidateSession(r3) {
|
||||
t.Error("Expected ValidateSession to return false after logout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session cookie has security settings", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-security"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("Expected session cookie, got none")
|
||||
}
|
||||
|
||||
cookie := cookies[0]
|
||||
|
||||
// Verify HttpOnly flag is set
|
||||
if !cookie.HttpOnly {
|
||||
t.Error("Expected HttpOnly flag to be true")
|
||||
}
|
||||
|
||||
// Verify SameSite is set
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
|
||||
}
|
||||
|
||||
// Verify Path is set
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Expected Path to be /, got %s", cookie.Path)
|
||||
}
|
||||
|
||||
// Verify MaxAge is set
|
||||
if cookie.MaxAge != 3600 {
|
||||
t.Errorf("Expected MaxAge to be 3600, got %d", cookie.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
78
internal/auth/state_store.go
Normal file
78
internal/auth/state_store.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryStateStore is an in-memory implementation of StateStore
|
||||
// Note: This is suitable for development but should be replaced with
|
||||
// a distributed store (Redis, DynamoDB) for production with multiple servers
|
||||
type MemoryStateStore struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]time.Time
|
||||
}
|
||||
|
||||
// NewMemoryStateStore creates a new in-memory state store
|
||||
func NewMemoryStateStore() *MemoryStateStore {
|
||||
store := &MemoryStateStore{
|
||||
states: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine to remove expired states
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a state token with an expiry time
|
||||
func (s *MemoryStateStore) Set(state string, expiry time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.states[state] = expiry
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if a state token is valid and not expired
|
||||
func (s *MemoryStateStore) Validate(state string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
expiry, exists := s.states[state]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(expiry) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Delete removes a state token from the store
|
||||
func (s *MemoryStateStore) Delete(state string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.states, state)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired periodically removes expired state tokens
|
||||
func (s *MemoryStateStore) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
for state, expiry := range s.states {
|
||||
if now.After(expiry) {
|
||||
delete(s.states, state)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
112
internal/auth/state_store_test.go
Normal file
112
internal/auth/state_store_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMemoryStateStore(t *testing.T) {
|
||||
store := NewMemoryStateStore()
|
||||
|
||||
t.Run("set and validate state", func(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to validate state: %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("Expected state to be valid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate non-existent state", func(t *testing.T) {
|
||||
valid, err := store.Validate("non-existent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected non-existent state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate expired state", func(t *testing.T) {
|
||||
state := "expired-state"
|
||||
expiry := time.Now().Add(-1 * time.Minute) // Already expired
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected expired state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete state", func(t *testing.T) {
|
||||
state := "delete-test"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
err = store.Delete(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected deleted state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple states", func(t *testing.T) {
|
||||
state1 := "state-1"
|
||||
state2 := "state-2"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
store.Set(state1, expiry)
|
||||
store.Set(state2, expiry)
|
||||
|
||||
valid1, _ := store.Validate(state1)
|
||||
valid2, _ := store.Validate(state2)
|
||||
|
||||
if !valid1 || !valid2 {
|
||||
t.Error("Expected both states to be valid")
|
||||
}
|
||||
|
||||
store.Delete(state1)
|
||||
|
||||
valid1, _ = store.Validate(state1)
|
||||
valid2, _ = store.Validate(state2)
|
||||
|
||||
if valid1 {
|
||||
t.Error("Expected state1 to be invalid after deletion")
|
||||
}
|
||||
|
||||
if !valid2 {
|
||||
t.Error("Expected state2 to still be valid")
|
||||
}
|
||||
})
|
||||
}
|
||||
100
internal/auth/user_service.go
Normal file
100
internal/auth/user_service.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
"custom-start-page/internal/storage"
|
||||
)
|
||||
|
||||
// GoogleUserInfo represents the user info returned by Google
|
||||
type GoogleUserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
VerifiedEmail bool `json:"verified_email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
// UserService handles user creation and retrieval
|
||||
type UserService struct {
|
||||
userRepo *storage.UserRepository
|
||||
}
|
||||
|
||||
// NewUserService creates a new user service
|
||||
func NewUserService(userRepo *storage.UserRepository) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreateUserFromGoogle fetches user info from Google and creates or retrieves the user
|
||||
func (s *UserService) GetOrCreateUserFromGoogle(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*models.User, error) {
|
||||
// Fetch user info from Google
|
||||
userInfo, err := s.fetchGoogleUserInfo(ctx, token, oauthConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google user info: %w", err)
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
user, err := s.userRepo.GetByOAuthID(ctx, "google", userInfo.ID)
|
||||
if err == nil {
|
||||
// User exists, update last login time
|
||||
user.UpdatedAt = time.Now()
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// User doesn't exist, create new user
|
||||
user = &models.User{
|
||||
ID: uuid.New().String(),
|
||||
Email: userInfo.Email,
|
||||
OAuthProvider: "google",
|
||||
OAuthID: userInfo.ID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// fetchGoogleUserInfo fetches user information from Google's userinfo endpoint
|
||||
func (s *UserService) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*GoogleUserInfo, error) {
|
||||
client := oauthConfig.Client(ctx, token)
|
||||
|
||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get user info: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var userInfo GoogleUserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
if !userInfo.VerifiedEmail {
|
||||
return nil, fmt.Errorf("email not verified")
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
170
internal/handlers/auth_handler.go
Normal file
170
internal/handlers/auth_handler.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"custom-start-page/internal/auth"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related HTTP requests
|
||||
type AuthHandler struct {
|
||||
oauthService *auth.OAuthService
|
||||
userService *auth.UserService
|
||||
sessionStore SessionStore
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// SessionStore manages user sessions
|
||||
type SessionStore interface {
|
||||
CreateSession(w http.ResponseWriter, r *http.Request, userID string) error
|
||||
GetUserID(r *http.Request) (string, error)
|
||||
DestroySession(w http.ResponseWriter, r *http.Request) error
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
func NewAuthHandler(oauthService *auth.OAuthService, userService *auth.UserService, sessionStore SessionStore) *AuthHandler {
|
||||
return NewAuthHandlerWithTemplates(oauthService, userService, sessionStore, nil)
|
||||
}
|
||||
|
||||
// NewAuthHandlerWithTemplates creates a new auth handler with custom templates
|
||||
func NewAuthHandlerWithTemplates(oauthService *auth.OAuthService, userService *auth.UserService, sessionStore SessionStore, templates *template.Template) *AuthHandler {
|
||||
if templates == nil {
|
||||
// Parse templates
|
||||
templates = template.Must(template.ParseGlob(filepath.Join("templates", "*.html")))
|
||||
template.Must(templates.ParseGlob(filepath.Join("templates", "layouts", "*.html")))
|
||||
}
|
||||
|
||||
return &AuthHandler{
|
||||
oauthService: oauthService,
|
||||
userService: userService,
|
||||
sessionStore: sessionStore,
|
||||
templates: templates,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleOAuthInitiate initiates the OAuth flow
|
||||
// GET /auth/oauth/:provider
|
||||
func (h *AuthHandler) HandleOAuthInitiate(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract provider from URL path
|
||||
provider := r.PathValue("provider")
|
||||
if provider == "" {
|
||||
http.Error(w, "Provider not specified", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate OAuth redirect URL
|
||||
redirectURL, err := h.oauthService.InitiateOAuth(provider)
|
||||
if err != nil {
|
||||
log.Printf("Failed to initiate OAuth: %v", err)
|
||||
http.Error(w, "Failed to initiate OAuth", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to OAuth provider
|
||||
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// HandleOAuthCallback handles the OAuth callback
|
||||
// GET /auth/callback/:provider
|
||||
func (h *AuthHandler) HandleOAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract provider from URL path
|
||||
provider := r.PathValue("provider")
|
||||
if provider == "" {
|
||||
http.Error(w, "Provider not specified", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Get code and state from query parameters
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
// Check for error from OAuth provider
|
||||
if errMsg := r.URL.Query().Get("error"); errMsg != "" {
|
||||
log.Printf("OAuth error: %s", errMsg)
|
||||
http.Redirect(w, r, "/login?error=oauth_failed", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
http.Error(w, "Authorization code not provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
http.Error(w, "State parameter not provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := h.oauthService.HandleOAuthCallback(r.Context(), provider, code, state)
|
||||
if err != nil {
|
||||
log.Printf("Failed to handle OAuth callback: %v", err)
|
||||
http.Redirect(w, r, "/login?error=oauth_failed", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// Get or create user from OAuth provider
|
||||
user, err := h.userService.GetOrCreateUserFromGoogle(r.Context(), token, h.oauthService.GetGoogleConfig())
|
||||
if err != nil {
|
||||
log.Printf("Failed to get or create user: %v", err)
|
||||
http.Redirect(w, r, "/login?error=user_creation_failed", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// Create session
|
||||
if err := h.sessionStore.CreateSession(w, r, user.ID); err != nil {
|
||||
log.Printf("Failed to create session: %v", err)
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to dashboard
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// HandleLogout logs out the user
|
||||
// POST /logout
|
||||
func (h *AuthHandler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if err := h.sessionStore.DestroySession(w, r); err != nil {
|
||||
log.Printf("Failed to destroy session: %v", err)
|
||||
}
|
||||
|
||||
// Redirect to login page
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// HandleLogin displays the login page
|
||||
// GET /login
|
||||
func (h *AuthHandler) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if user is already logged in
|
||||
if userID, err := h.sessionStore.GetUserID(r); err == nil && userID != "" {
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
// Get error message if any
|
||||
errorMsg := ""
|
||||
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||
switch errParam {
|
||||
case "oauth_failed":
|
||||
errorMsg = "Authentication failed. Please try again."
|
||||
case "user_creation_failed":
|
||||
errorMsg = "Failed to create user account. Please try again."
|
||||
default:
|
||||
errorMsg = "An error occurred. Please try again."
|
||||
}
|
||||
}
|
||||
|
||||
// Render login template
|
||||
data := map[string]interface{}{
|
||||
"Error": errorMsg,
|
||||
"OAuthProviders": []map[string]string{}, // Empty for now, can be extended
|
||||
}
|
||||
|
||||
if err := h.templates.ExecuteTemplate(w, "login.html", data); err != nil {
|
||||
log.Printf("Failed to render login template: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
159
internal/handlers/auth_handler_test.go
Normal file
159
internal/handlers/auth_handler_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"custom-start-page/internal/auth"
|
||||
)
|
||||
|
||||
// MockSessionStore is a mock implementation of SessionStore for testing
|
||||
type MockSessionStore struct {
|
||||
userID string
|
||||
shouldError bool
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
|
||||
m.userID = userID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) GetUserID(r *http.Request) (string, error) {
|
||||
if m.shouldError {
|
||||
return "", http.ErrNoCookie
|
||||
}
|
||||
return m.userID, nil
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
|
||||
m.userID = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSessionStore) ValidateSession(r *http.Request) bool {
|
||||
return m.userID != ""
|
||||
}
|
||||
|
||||
// createMockTemplate creates a simple mock template for testing
|
||||
func createMockTemplate() *template.Template {
|
||||
tmpl := template.New("login.html")
|
||||
template.Must(tmpl.Parse(`<!DOCTYPE html><html><body>{{if .Error}}<div>{{.Error}}</div>{{end}}<a href="/auth/oauth/google">Login</a></body></html>`))
|
||||
return tmpl
|
||||
}
|
||||
|
||||
// TestHandleLogin_UnauthenticatedUser tests that unauthenticated users see the login page
|
||||
func TestHandleLogin_UnauthenticatedUser(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{shouldError: true}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil) // nil repo for this test
|
||||
mockTemplate := createMockTemplate()
|
||||
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleLogin(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check that response contains login page elements
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("Expected non-empty response body")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleLogin_AuthenticatedUser tests that authenticated users are redirected to dashboard
|
||||
func TestHandleLogin_AuthenticatedUser(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil)
|
||||
mockTemplate := createMockTemplate()
|
||||
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleLogin(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Errorf("Expected status 307, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to /dashboard, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleLogin_WithError tests that error messages are displayed
|
||||
func TestHandleLogin_WithError(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{shouldError: true}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil)
|
||||
mockTemplate := createMockTemplate()
|
||||
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create request with error parameter
|
||||
req := httptest.NewRequest(http.MethodGet, "/login?error=oauth_failed", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleLogin(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check that response contains error message
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("Expected non-empty response body")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleLogout tests that logout destroys session and redirects to login
|
||||
func TestHandleLogout(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil)
|
||||
mockTemplate := createMockTemplate()
|
||||
handler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleLogout(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Errorf("Expected status 307, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/login" {
|
||||
t.Errorf("Expected redirect to /login, got %s", location)
|
||||
}
|
||||
|
||||
// Verify session was destroyed
|
||||
if mockSessionStore.userID != "" {
|
||||
t.Error("Expected session to be destroyed")
|
||||
}
|
||||
}
|
||||
58
internal/handlers/dashboard_handler.go
Normal file
58
internal/handlers/dashboard_handler.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"custom-start-page/internal/middleware"
|
||||
)
|
||||
|
||||
// DashboardHandler handles dashboard-related HTTP requests
|
||||
type DashboardHandler struct {
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
// NewDashboardHandler creates a new dashboard handler
|
||||
func NewDashboardHandler() *DashboardHandler {
|
||||
// Parse templates
|
||||
templates := template.Must(template.ParseGlob(filepath.Join("templates", "*.html")))
|
||||
template.Must(templates.ParseGlob(filepath.Join("templates", "layouts", "*.html")))
|
||||
|
||||
return &DashboardHandler{
|
||||
templates: templates,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDashboard displays the dashboard page
|
||||
// GET /dashboard
|
||||
func (h *DashboardHandler) HandleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user ID from context (set by auth middleware)
|
||||
userID, ok := middleware.GetUserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
http.Error(w, "User ID not found in context", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Fetch user's pages from database
|
||||
// For now, we'll use mock data
|
||||
pages := []map[string]interface{}{
|
||||
{
|
||||
"ID": "default-page",
|
||||
"Name": "Home",
|
||||
"Active": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Render dashboard template
|
||||
data := map[string]interface{}{
|
||||
"UserID": userID,
|
||||
"Pages": pages,
|
||||
}
|
||||
|
||||
if err := h.templates.ExecuteTemplate(w, "dashboard.html", data); err != nil {
|
||||
log.Printf("Failed to render dashboard template: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
68
internal/handlers/dashboard_handler_test.go
Normal file
68
internal/handlers/dashboard_handler_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"custom-start-page/internal/middleware"
|
||||
)
|
||||
|
||||
// createMockDashboardTemplate creates a simple mock template for testing
|
||||
func createMockDashboardTemplate() *template.Template {
|
||||
tmpl := template.New("dashboard.html")
|
||||
template.Must(tmpl.Parse(`<!DOCTYPE html><html><body><h1>Dashboard</h1><div>User: {{.UserID}}</div></body></html>`))
|
||||
return tmpl
|
||||
}
|
||||
|
||||
// TestHandleDashboard_WithAuthenticatedUser tests that authenticated users see the dashboard
|
||||
func TestHandleDashboard_WithAuthenticatedUser(t *testing.T) {
|
||||
// Setup
|
||||
mockTemplate := createMockDashboardTemplate()
|
||||
handler := &DashboardHandler{
|
||||
templates: mockTemplate,
|
||||
}
|
||||
|
||||
// Create request with user ID in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
ctx := context.WithValue(req.Context(), middleware.GetUserIDContextKey(), "test-user-123")
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleDashboard(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Check that response contains dashboard content
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("Expected non-empty response body")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleDashboard_WithoutUserID tests that requests without user ID fail
|
||||
func TestHandleDashboard_WithoutUserID(t *testing.T) {
|
||||
// Setup
|
||||
mockTemplate := createMockDashboardTemplate()
|
||||
handler := &DashboardHandler{
|
||||
templates: mockTemplate,
|
||||
}
|
||||
|
||||
// Create request without user ID in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
handler.HandleDashboard(w, req)
|
||||
|
||||
// Assert
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
102
internal/handlers/integration_test.go
Normal file
102
internal/handlers/integration_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"custom-start-page/internal/auth"
|
||||
"custom-start-page/internal/middleware"
|
||||
)
|
||||
|
||||
// TestRedirectFlow_UnauthenticatedToLogin tests that unauthenticated users are redirected to login
|
||||
func TestRedirectFlow_UnauthenticatedToLogin(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{shouldError: true}
|
||||
|
||||
// Create middleware
|
||||
requireAuth := middleware.RequireAuth(mockSessionStore)
|
||||
|
||||
// Create dashboard handler
|
||||
mockDashboardTemplate := createMockDashboardTemplate()
|
||||
dashboardHandler := &DashboardHandler{templates: mockDashboardTemplate}
|
||||
|
||||
// Wrap dashboard handler with auth middleware
|
||||
protectedHandler := requireAuth(http.HandlerFunc(dashboardHandler.HandleDashboard))
|
||||
|
||||
// Create request to dashboard
|
||||
req := httptest.NewRequest(http.MethodGet, "/dashboard", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
protectedHandler.ServeHTTP(w, req)
|
||||
|
||||
// Assert - should redirect to login
|
||||
if w.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected status 303, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/login" {
|
||||
t.Errorf("Expected redirect to /login, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectFlow_AuthenticatedToDashboard tests that authenticated users accessing login are redirected to dashboard
|
||||
func TestRedirectFlow_AuthenticatedToDashboard(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil)
|
||||
mockTemplate := createMockTemplate()
|
||||
authHandler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create request to login page
|
||||
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authHandler.HandleLogin(w, req)
|
||||
|
||||
// Assert - should redirect to dashboard
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Errorf("Expected status 307, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to /dashboard, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedirectFlow_LogoutToLogin tests that logout redirects to login
|
||||
func TestRedirectFlow_LogoutToLogin(t *testing.T) {
|
||||
// Setup
|
||||
mockSessionStore := &MockSessionStore{userID: "test-user-123"}
|
||||
oauthService := auth.NewOAuthService("test-client-id", "test-secret", "http://localhost/callback", auth.NewMemoryStateStore())
|
||||
userService := auth.NewUserService(nil)
|
||||
mockTemplate := createMockTemplate()
|
||||
authHandler := NewAuthHandlerWithTemplates(oauthService, userService, mockSessionStore, mockTemplate)
|
||||
|
||||
// Create logout request
|
||||
req := httptest.NewRequest(http.MethodPost, "/logout", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authHandler.HandleLogout(w, req)
|
||||
|
||||
// Assert - should redirect to login
|
||||
if w.Code != http.StatusTemporaryRedirect {
|
||||
t.Errorf("Expected status 307, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/login" {
|
||||
t.Errorf("Expected redirect to /login, got %s", location)
|
||||
}
|
||||
|
||||
// Verify session was destroyed
|
||||
if mockSessionStore.userID != "" {
|
||||
t.Error("Expected session to be destroyed after logout")
|
||||
}
|
||||
}
|
||||
52
internal/middleware/auth.go
Normal file
52
internal/middleware/auth.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SessionStore defines the interface for session validation
|
||||
type SessionStore interface {
|
||||
ValidateSession(r *http.Request) bool
|
||||
GetUserID(r *http.Request) (string, error)
|
||||
}
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const userIDContextKey contextKey = "user_id"
|
||||
|
||||
// GetUserIDContextKey returns the context key for user ID (for testing)
|
||||
func GetUserIDContextKey() contextKey {
|
||||
return userIDContextKey
|
||||
}
|
||||
|
||||
// RequireAuth is a middleware that ensures the user is authenticated
|
||||
func RequireAuth(sessionStore SessionStore) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate session
|
||||
if !sessionStore.ValidateSession(r) {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID and add to context
|
||||
userID, err := sessionStore.GetUserID(r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Add user ID to request context
|
||||
ctx := context.WithValue(r.Context(), userIDContextKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserIDFromContext retrieves the user ID from the request context
|
||||
func GetUserIDFromContext(ctx context.Context) (string, bool) {
|
||||
userID, ok := ctx.Value(userIDContextKey).(string)
|
||||
return userID, ok
|
||||
}
|
||||
145
internal/middleware/auth_test.go
Normal file
145
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockSessionStore is a mock implementation of SessionStore for testing
|
||||
type mockSessionStore struct {
|
||||
valid bool
|
||||
userID string
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) ValidateSession(r *http.Request) bool {
|
||||
return m.valid
|
||||
}
|
||||
|
||||
func (m *mockSessionStore) GetUserID(r *http.Request) (string, error) {
|
||||
if !m.valid {
|
||||
return "", http.ErrNoCookie
|
||||
}
|
||||
return m.userID, nil
|
||||
}
|
||||
|
||||
func TestRequireAuth(t *testing.T) {
|
||||
t.Run("allows authenticated requests", func(t *testing.T) {
|
||||
mockStore := &mockSessionStore{
|
||||
valid: true,
|
||||
userID: "test-user-123",
|
||||
}
|
||||
|
||||
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := GetUserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("Expected user ID in context")
|
||||
return
|
||||
}
|
||||
if userID != "test-user-123" {
|
||||
t.Errorf("Expected user ID test-user-123, got %s", userID)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/dashboard", nil)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if w.Body.String() != "success" {
|
||||
t.Errorf("Expected body 'success', got %s", w.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redirects unauthenticated requests", func(t *testing.T) {
|
||||
mockStore := &mockSessionStore{
|
||||
valid: false,
|
||||
}
|
||||
|
||||
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("Handler should not be called for unauthenticated request")
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/dashboard", nil)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected status 303, got %d", w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != "/login" {
|
||||
t.Errorf("Expected redirect to /login, got %s", location)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("redirects when GetUserID fails", func(t *testing.T) {
|
||||
mockStore := &mockSessionStore{
|
||||
valid: true, // ValidateSession returns true
|
||||
userID: "", // But GetUserID will fail
|
||||
}
|
||||
|
||||
// Override GetUserID to return error
|
||||
mockStore.valid = false
|
||||
|
||||
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("Handler should not be called when GetUserID fails")
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/dashboard", nil)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected status 303, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserIDFromContext(t *testing.T) {
|
||||
t.Run("retrieves user ID from context", func(t *testing.T) {
|
||||
mockStore := &mockSessionStore{
|
||||
valid: true,
|
||||
userID: "context-user-456",
|
||||
}
|
||||
|
||||
handler := RequireAuth(mockStore)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := GetUserIDFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("Expected user ID in context")
|
||||
return
|
||||
}
|
||||
if userID != "context-user-456" {
|
||||
t.Errorf("Expected user ID context-user-456, got %s", userID)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns false when user ID not in context", func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
_, ok := GetUserIDFromContext(r.Context())
|
||||
if ok {
|
||||
t.Error("Expected ok to be false when user ID not in context")
|
||||
}
|
||||
})
|
||||
}
|
||||
15
internal/models/user.go
Normal file
15
internal/models/user.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
ID string `dynamodbav:"user_id" json:"id"`
|
||||
Email string `dynamodbav:"email" json:"email"`
|
||||
OAuthProvider string `dynamodbav:"oauth_provider" json:"oauth_provider"`
|
||||
OAuthID string `dynamodbav:"oauth_id" json:"oauth_id"`
|
||||
CreatedAt time.Time `dynamodbav:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `dynamodbav:"updated_at" json:"updated_at"`
|
||||
}
|
||||
124
internal/storage/README.md
Normal file
124
internal/storage/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# DynamoDB Storage Service
|
||||
|
||||
This package provides an enhanced DynamoDB client wrapper with the following features:
|
||||
|
||||
## Features
|
||||
|
||||
### 1. Connection Pooling
|
||||
The client uses the AWS SDK's default HTTP client which includes connection pooling automatically. This ensures efficient reuse of TCP connections to DynamoDB.
|
||||
|
||||
### 2. Retry Logic with Exponential Backoff
|
||||
The client is configured with automatic retry logic:
|
||||
- **Max Attempts**: 5 retries
|
||||
- **Max Backoff**: 20 seconds
|
||||
- **Strategy**: Exponential backoff with jitter to prevent thundering herd
|
||||
|
||||
This handles transient failures gracefully and improves reliability.
|
||||
|
||||
### 3. Transaction Support
|
||||
The `TransactWriteItems` method provides ACID transaction support for multiple write operations:
|
||||
```go
|
||||
err := client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String("MyTable"),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: "item1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
// More items...
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### 4. Batch Operations
|
||||
The client provides batch read and write operations with automatic retry of unprocessed items:
|
||||
|
||||
#### BatchGetItems
|
||||
Retrieves multiple items in a single request:
|
||||
```go
|
||||
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: map[string]types.KeysAndAttributes{
|
||||
"MyTable": {
|
||||
Keys: []map[string]types.AttributeValue{
|
||||
{"id": &types.AttributeValueMemberS{Value: "item1"}},
|
||||
{"id": &types.AttributeValueMemberS{Value: "item2"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### BatchWriteItems
|
||||
Writes multiple items in a single request:
|
||||
```go
|
||||
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: map[string][]types.WriteRequest{
|
||||
"MyTable": {
|
||||
{
|
||||
PutRequest: &types.PutRequest{
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: "item1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Both batch operations automatically handle unprocessed items with exponential backoff retry logic.
|
||||
|
||||
## Standard Operations
|
||||
|
||||
The client also provides wrapped versions of standard DynamoDB operations with automatic retry:
|
||||
- `PutItem` - Put a single item
|
||||
- `GetItem` - Get a single item
|
||||
- `UpdateItem` - Update a single item
|
||||
- `DeleteItem` - Delete a single item
|
||||
- `Query` - Query items
|
||||
|
||||
## Usage
|
||||
|
||||
### Creating a Client
|
||||
|
||||
```go
|
||||
ctx := context.Background()
|
||||
client, err := storage.NewDynamoDBClient(ctx, "http://localhost:8000")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
For production (AWS DynamoDB), pass an empty string for the endpoint:
|
||||
```go
|
||||
client, err := storage.NewDynamoDBClient(ctx, "")
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
The package includes comprehensive tests that can be run against DynamoDB Local:
|
||||
|
||||
1. Start DynamoDB Local:
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. Run tests:
|
||||
```bash
|
||||
go test -v ./internal/storage
|
||||
```
|
||||
|
||||
Tests will automatically skip if DynamoDB is not available.
|
||||
|
||||
## Requirements Addressed
|
||||
|
||||
This implementation addresses the following requirements from the spec:
|
||||
- **Requirement 8.1**: Immediate persistence of all changes
|
||||
- **Requirement 8.8**: Efficient scaling for 10,000+ items per user
|
||||
- **Design requirement**: Retry logic with exponential backoff
|
||||
- **Design requirement**: Transaction support for atomic operations
|
||||
- **Design requirement**: Batch operations for efficient bulk reads/writes
|
||||
- **Design requirement**: Connection pooling for performance
|
||||
248
internal/storage/dynamodb.go
Normal file
248
internal/storage/dynamodb.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/aws/retry"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
)
|
||||
|
||||
// DynamoDBClient wraps the AWS DynamoDB client with enhanced features
|
||||
type DynamoDBClient struct {
|
||||
client *dynamodb.Client
|
||||
}
|
||||
|
||||
// NewDynamoDBClient creates a new DynamoDB client with connection pooling and retry logic
|
||||
func NewDynamoDBClient(ctx context.Context, endpoint string) (*DynamoDBClient, error) {
|
||||
// Configure retry strategy with exponential backoff
|
||||
retryer := retry.NewStandard(func(o *retry.StandardOptions) {
|
||||
o.MaxAttempts = 5
|
||||
o.MaxBackoff = 20 * time.Second
|
||||
// Use exponential backoff with jitter
|
||||
o.Backoff = retry.NewExponentialJitterBackoff(20 * time.Second)
|
||||
})
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithRetryer(func() aws.Retryer {
|
||||
return retryer
|
||||
}),
|
||||
// Connection pooling is handled by the HTTP client
|
||||
config.WithHTTPClient(nil), // Uses default HTTP client with connection pooling
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
// Override endpoint if provided (for local DynamoDB)
|
||||
if endpoint != "" {
|
||||
cfg.BaseEndpoint = aws.String(endpoint)
|
||||
}
|
||||
|
||||
client := dynamodb.NewFromConfig(cfg)
|
||||
|
||||
return &DynamoDBClient{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUsersTable creates the Users table in DynamoDB
|
||||
func (db *DynamoDBClient) CreateUsersTable(ctx context.Context) error {
|
||||
tableName := "Users"
|
||||
|
||||
// Check if table already exists
|
||||
_, err := db.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
})
|
||||
if err == nil {
|
||||
// Table already exists
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = db.client.CreateTable(ctx, &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("user_id"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("user_id"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
BillingMode: types.BillingModePayPerRequest,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Users table: %w", err)
|
||||
}
|
||||
|
||||
// Wait for table to be active
|
||||
waiter := dynamodb.NewTableExistsWaiter(db.client)
|
||||
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
}, 5*60) // 5 minutes timeout
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed waiting for Users table to be active: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient returns the underlying DynamoDB client
|
||||
func (db *DynamoDBClient) GetClient() *dynamodb.Client {
|
||||
return db.client
|
||||
}
|
||||
|
||||
// TransactWriteItems executes a transactional write operation with automatic retry
|
||||
func (db *DynamoDBClient) TransactWriteItems(ctx context.Context, input *dynamodb.TransactWriteItemsInput) error {
|
||||
_, err := db.client.TransactWriteItems(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("transaction write failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchGetItems retrieves multiple items in a single batch operation
|
||||
func (db *DynamoDBClient) BatchGetItems(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
|
||||
output, err := db.client.BatchGetItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch get failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle unprocessed keys with exponential backoff
|
||||
if len(output.UnprocessedKeys) > 0 {
|
||||
return db.retryUnprocessedKeys(ctx, output)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// BatchWriteItems writes multiple items in a single batch operation
|
||||
func (db *DynamoDBClient) BatchWriteItems(ctx context.Context, input *dynamodb.BatchWriteItemInput) error {
|
||||
output, err := db.client.BatchWriteItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch write failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle unprocessed items with exponential backoff
|
||||
if len(output.UnprocessedItems) > 0 {
|
||||
return db.retryUnprocessedWrites(ctx, output.UnprocessedItems)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// retryUnprocessedKeys retries unprocessed keys from BatchGetItem with exponential backoff
|
||||
func (db *DynamoDBClient) retryUnprocessedKeys(ctx context.Context, output *dynamodb.BatchGetItemOutput) (*dynamodb.BatchGetItemOutput, error) {
|
||||
maxRetries := 5
|
||||
backoff := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries && len(output.UnprocessedKeys) > 0; attempt++ {
|
||||
// Wait with exponential backoff
|
||||
time.Sleep(backoff)
|
||||
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
||||
|
||||
// Retry unprocessed keys
|
||||
retryOutput, err := db.client.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: output.UnprocessedKeys,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retry batch get failed: %w", err)
|
||||
}
|
||||
|
||||
// Merge responses
|
||||
for table, items := range retryOutput.Responses {
|
||||
output.Responses[table] = append(output.Responses[table], items...)
|
||||
}
|
||||
|
||||
output.UnprocessedKeys = retryOutput.UnprocessedKeys
|
||||
}
|
||||
|
||||
if len(output.UnprocessedKeys) > 0 {
|
||||
return output, fmt.Errorf("failed to process all keys after %d retries", maxRetries)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// retryUnprocessedWrites retries unprocessed items from BatchWriteItem with exponential backoff
|
||||
func (db *DynamoDBClient) retryUnprocessedWrites(ctx context.Context, unprocessedItems map[string][]types.WriteRequest) error {
|
||||
maxRetries := 5
|
||||
backoff := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries && len(unprocessedItems) > 0; attempt++ {
|
||||
// Wait with exponential backoff
|
||||
time.Sleep(backoff)
|
||||
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
||||
|
||||
// Retry unprocessed items
|
||||
output, err := db.client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: unprocessedItems,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("retry batch write failed: %w", err)
|
||||
}
|
||||
|
||||
unprocessedItems = output.UnprocessedItems
|
||||
}
|
||||
|
||||
if len(unprocessedItems) > 0 {
|
||||
return fmt.Errorf("failed to process all items after %d retries", maxRetries)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutItem puts a single item with automatic retry
|
||||
func (db *DynamoDBClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) error {
|
||||
_, err := db.client.PutItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("put item failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetItem retrieves a single item with automatic retry
|
||||
func (db *DynamoDBClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
|
||||
output, err := db.client.GetItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get item failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// Query executes a query operation with automatic retry
|
||||
func (db *DynamoDBClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
|
||||
output, err := db.client.Query(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// UpdateItem updates a single item with automatic retry
|
||||
func (db *DynamoDBClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
|
||||
output, err := db.client.UpdateItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update item failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// DeleteItem deletes a single item with automatic retry
|
||||
func (db *DynamoDBClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput) error {
|
||||
_, err := db.client.DeleteItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete item failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
452
internal/storage/dynamodb_test.go
Normal file
452
internal/storage/dynamodb_test.go
Normal file
@@ -0,0 +1,452 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
)
|
||||
|
||||
func setupTestClient(t *testing.T) (*DynamoDBClient, context.Context) {
|
||||
t.Helper()
|
||||
|
||||
// Set dummy AWS credentials for local DynamoDB
|
||||
os.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
os.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
os.Setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "http://localhost:8000"
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client, err := NewDynamoDBClient(ctx, endpoint)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test: DynamoDB not available: %v", err)
|
||||
}
|
||||
|
||||
// Test connection by listing tables
|
||||
_, err = client.client.ListTables(ctx, &dynamodb.ListTablesInput{})
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test: Cannot connect to DynamoDB: %v", err)
|
||||
}
|
||||
|
||||
return client, ctx
|
||||
}
|
||||
|
||||
func TestNewDynamoDBClient(t *testing.T) {
|
||||
client, _ := setupTestClient(t)
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
|
||||
if client.client == nil {
|
||||
t.Fatal("Expected non-nil underlying client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactWriteItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestTransactions"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Test transaction with two puts
|
||||
testID1 := "test-txn-1"
|
||||
testID2 := "test-txn-2"
|
||||
|
||||
input := &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID1},
|
||||
"value": &types.AttributeValueMemberS{Value: "value1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID2},
|
||||
"value": &types.AttributeValueMemberS{Value: "value2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := client.TransactWriteItems(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatalf("TransactWriteItems failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify both items were written
|
||||
output1, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID1},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output1.Item) == 0 {
|
||||
t.Fatal("Expected item to exist after transaction")
|
||||
}
|
||||
|
||||
output2, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID2},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output2.Item) == 0 {
|
||||
t.Fatal("Expected item to exist after transaction")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchGetItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestBatchGet"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put test items
|
||||
testIDs := []string{"batch-1", "batch-2", "batch-3"}
|
||||
for _, id := range testIDs {
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "test-value"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch get items
|
||||
keys := make([]map[string]types.AttributeValue, len(testIDs))
|
||||
for i, id := range testIDs {
|
||||
keys[i] = map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
}
|
||||
}
|
||||
|
||||
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: map[string]types.KeysAndAttributes{
|
||||
tableName: {
|
||||
Keys: keys,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchGetItems failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Responses[tableName]) != len(testIDs) {
|
||||
t.Fatalf("Expected %d items, got %d", len(testIDs), len(output.Responses[tableName]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchWriteItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestBatchWrite"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Batch write items
|
||||
testIDs := []string{"write-1", "write-2", "write-3"}
|
||||
writeRequests := make([]types.WriteRequest, len(testIDs))
|
||||
for i, id := range testIDs {
|
||||
writeRequests[i] = types.WriteRequest{
|
||||
PutRequest: &types.PutRequest{
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "batch-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: map[string][]types.WriteRequest{
|
||||
tableName: writeRequests,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchWriteItems failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify items were written
|
||||
for _, id := range testIDs {
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output.Item) == 0 {
|
||||
t.Fatalf("Expected item %s to exist after batch write", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutAndGetItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestPutGet"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put item
|
||||
testID := "put-get-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "test-value"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Get item
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Item) == 0 {
|
||||
t.Fatal("Expected item to exist")
|
||||
}
|
||||
|
||||
valueAttr, ok := output.Item["value"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'value' attribute")
|
||||
}
|
||||
|
||||
value := valueAttr.(*types.AttributeValueMemberS).Value
|
||||
if value != "test-value" {
|
||||
t.Fatalf("Expected value 'test-value', got '%s'", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestUpdate"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put initial item
|
||||
testID := "update-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "initial"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Update item
|
||||
_, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
UpdateExpression: aws.String("SET #v = :val"),
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#v": "value",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":val": &types.AttributeValueMemberS{Value: "updated"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
value := output.Item["value"].(*types.AttributeValueMemberS).Value
|
||||
if value != "updated" {
|
||||
t.Fatalf("Expected value 'updated', got '%s'", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestDelete"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put item
|
||||
testID := "delete-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "test"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete item
|
||||
err = client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Item) != 0 {
|
||||
t.Fatal("Expected item to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestQuery"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put test items
|
||||
testIDs := []string{"query-1", "query-2", "query-3"}
|
||||
for _, id := range testIDs {
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "test"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Query for specific item
|
||||
output, err := client.Query(ctx, &dynamodb.QueryInput{
|
||||
TableName: aws.String(tableName),
|
||||
KeyConditionExpression: aws.String("id = :id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":id": &types.AttributeValueMemberS{Value: "query-1"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Items) != 1 {
|
||||
t.Fatalf("Expected 1 item, got %d", len(output.Items))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := client.client.CreateTable(ctx, &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("id"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("id"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
BillingMode: types.BillingModePayPerRequest,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Wait for table to be active
|
||||
waiter := dynamodb.NewTableExistsWaiter(client.client)
|
||||
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
}, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed waiting for table to be active: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := client.client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete test table: %v", err)
|
||||
}
|
||||
}
|
||||
115
internal/storage/user_repository.go
Normal file
115
internal/storage/user_repository.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
)
|
||||
|
||||
// UserRepository handles user data operations
|
||||
type UserRepository struct {
|
||||
client *DynamoDBClient
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new user repository
|
||||
func NewUserRepository(client *DynamoDBClient, tableName string) *UserRepository {
|
||||
return &UserRepository{
|
||||
client: client,
|
||||
tableName: tableName,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new user in the database
|
||||
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a user by their ID
|
||||
func (r *UserRepository) GetByID(ctx context.Context, userID string) (*models.User, error) {
|
||||
result, err := r.client.GetClient().GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if result.Item == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := attributevalue.UnmarshalMap(result.Item, &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByOAuthID retrieves a user by their OAuth provider and OAuth ID
|
||||
func (r *UserRepository) GetByOAuthID(ctx context.Context, provider, oauthID string) (*models.User, error) {
|
||||
// Use a scan with filter for now
|
||||
// In production, consider adding a GSI on oauth_provider + oauth_id
|
||||
result, err := r.client.GetClient().Scan(ctx, &dynamodb.ScanInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :oauth_id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":provider": &types.AttributeValueMemberS{Value: provider},
|
||||
":oauth_id": &types.AttributeValueMemberS{Value: oauthID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan users: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Items) == 0 {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := attributevalue.UnmarshalMap(result.Items[0], &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Update updates an existing user
|
||||
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
143
internal/storage/user_storage.go
Normal file
143
internal/storage/user_storage.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
)
|
||||
|
||||
const usersTableName = "Users"
|
||||
|
||||
// UserStorage handles user data operations
|
||||
type UserStorage struct {
|
||||
db *DynamoDBClient
|
||||
}
|
||||
|
||||
// NewUserStorage creates a new UserStorage instance
|
||||
func NewUserStorage(db *DynamoDBClient) *UserStorage {
|
||||
return &UserStorage{db: db}
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in DynamoDB
|
||||
func (s *UserStorage) CreateUser(ctx context.Context, email, oauthProvider, oauthID string) (*models.User, error) {
|
||||
user := &models.User{
|
||||
ID: uuid.New().String(),
|
||||
Email: email,
|
||||
OAuthProvider: oauthProvider,
|
||||
OAuthID: oauthID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by their ID
|
||||
func (s *UserStorage) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
|
||||
result, err := s.db.client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if result.Item == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
err = attributevalue.UnmarshalMap(result.Item, &user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByOAuth retrieves a user by their OAuth provider and ID
|
||||
func (s *UserStorage) GetUserByOAuth(ctx context.Context, oauthProvider, oauthID string) (*models.User, error) {
|
||||
// Since we don't have a GSI for oauth_provider + oauth_id, we need to scan
|
||||
// In production, you might want to add a GSI for this access pattern
|
||||
result, err := s.db.client.Scan(ctx, &dynamodb.ScanInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":provider": &types.AttributeValueMemberS{Value: oauthProvider},
|
||||
":id": &types.AttributeValueMemberS{Value: oauthID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan for user: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Items) == 0 {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
err = attributevalue.UnmarshalMap(result.Items[0], &user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates an existing user
|
||||
func (s *UserStorage) UpdateUser(ctx context.Context, user *models.User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user by their ID
|
||||
func (s *UserStorage) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := s.db.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
228
internal/storage/user_storage_test.go
Normal file
228
internal/storage/user_storage_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *DynamoDBClient {
|
||||
ctx := context.Background()
|
||||
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "http://localhost:8000"
|
||||
}
|
||||
|
||||
db, err := NewDynamoDBClient(ctx, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DynamoDB client: %v", err)
|
||||
}
|
||||
|
||||
// Create Users table
|
||||
if err := db.CreateUsersTable(ctx); err != nil {
|
||||
t.Fatalf("Failed to create Users table: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == "" {
|
||||
t.Error("User ID should not be empty")
|
||||
}
|
||||
if user.Email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", user.Email)
|
||||
}
|
||||
if user.OAuthProvider != "google" {
|
||||
t.Errorf("Expected oauth_provider 'google', got '%s'", user.OAuthProvider)
|
||||
}
|
||||
if user.OAuthID != "google123" {
|
||||
t.Errorf("Expected oauth_id 'google123', got '%s'", user.OAuthID)
|
||||
}
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
if user.UpdatedAt.IsZero() {
|
||||
t.Error("UpdatedAt should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve the user
|
||||
retrievedUser, err := storage.GetUserByID(ctx, createdUser.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser.ID != createdUser.ID {
|
||||
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
|
||||
}
|
||||
if retrievedUser.Email != createdUser.Email {
|
||||
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByID_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := storage.GetUserByID(ctx, "nonexistent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting nonexistent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByOAuth(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve the user by OAuth
|
||||
retrievedUser, err := storage.GetUserByOAuth(ctx, "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user by OAuth: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser.ID != createdUser.ID {
|
||||
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
|
||||
}
|
||||
if retrievedUser.Email != createdUser.Email {
|
||||
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByOAuth_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := storage.GetUserByOAuth(ctx, "google", "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting nonexistent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
originalUpdatedAt := user.UpdatedAt
|
||||
time.Sleep(10 * time.Millisecond) // Ensure time difference
|
||||
|
||||
// Update the user
|
||||
user.Email = "updated@example.com"
|
||||
err = storage.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify
|
||||
updatedUser, err := storage.GetUserByID(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated user: %v", err)
|
||||
}
|
||||
|
||||
if updatedUser.Email != "updated@example.com" {
|
||||
t.Errorf("Expected email 'updated@example.com', got '%s'", updatedUser.Email)
|
||||
}
|
||||
if !updatedUser.UpdatedAt.After(originalUpdatedAt) {
|
||||
t.Error("UpdatedAt should be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
err = storage.DeleteUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
_, err = storage.GetUserByID(ctx, user.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting deleted user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser_MultipleUsers(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple users
|
||||
user1, err := storage.CreateUser(ctx, "user1@example.com", "google", "google1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user1: %v", err)
|
||||
}
|
||||
|
||||
user2, err := storage.CreateUser(ctx, "user2@example.com", "github", "github1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user2: %v", err)
|
||||
}
|
||||
|
||||
// Verify both users exist and are different
|
||||
if user1.ID == user2.ID {
|
||||
t.Error("User IDs should be unique")
|
||||
}
|
||||
|
||||
retrievedUser1, err := storage.GetUserByID(ctx, user1.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user1: %v", err)
|
||||
}
|
||||
|
||||
retrievedUser2, err := storage.GetUserByID(ctx, user2.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user2: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser1.Email != "user1@example.com" {
|
||||
t.Errorf("Expected user1 email 'user1@example.com', got '%s'", retrievedUser1.Email)
|
||||
}
|
||||
if retrievedUser2.Email != "user2@example.com" {
|
||||
t.Errorf("Expected user2 email 'user2@example.com', got '%s'", retrievedUser2.Email)
|
||||
}
|
||||
}
|
||||
43
internal/testing/helpers.go
Normal file
43
internal/testing/helpers.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package testing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/leanovate/gopter"
|
||||
)
|
||||
|
||||
// PropertyTestConfig holds configuration for property-based tests
|
||||
type PropertyTestConfig struct {
|
||||
MinSuccessfulTests int
|
||||
MaxSize int
|
||||
Workers int
|
||||
}
|
||||
|
||||
// DefaultPropertyTestConfig returns default configuration for property tests
|
||||
func DefaultPropertyTestConfig() *PropertyTestConfig {
|
||||
return &PropertyTestConfig{
|
||||
MinSuccessfulTests: 100,
|
||||
MaxSize: 100,
|
||||
Workers: 4,
|
||||
}
|
||||
}
|
||||
|
||||
// RunPropertyTest runs a property-based test with the given configuration
|
||||
func RunPropertyTest(t *testing.T, config *PropertyTestConfig, testFunc func(*gopter.Properties)) {
|
||||
if config == nil {
|
||||
config = DefaultPropertyTestConfig()
|
||||
}
|
||||
|
||||
parameters := gopter.DefaultTestParameters()
|
||||
parameters.MinSuccessfulTests = config.MinSuccessfulTests
|
||||
parameters.MaxSize = config.MaxSize
|
||||
parameters.Workers = config.Workers
|
||||
|
||||
properties := gopter.NewProperties(parameters)
|
||||
|
||||
// Call the test function to add properties
|
||||
testFunc(properties)
|
||||
|
||||
// Run the tests
|
||||
properties.TestingRun(t)
|
||||
}
|
||||
55
internal/testing/helpers_test.go
Normal file
55
internal/testing/helpers_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package testing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/leanovate/gopter"
|
||||
"github.com/leanovate/gopter/gen"
|
||||
"github.com/leanovate/gopter/prop"
|
||||
)
|
||||
|
||||
func TestDefaultPropertyTestConfig(t *testing.T) {
|
||||
cfg := DefaultPropertyTestConfig()
|
||||
|
||||
if cfg.MinSuccessfulTests != 100 {
|
||||
t.Errorf("Expected MinSuccessfulTests to be 100, got %d", cfg.MinSuccessfulTests)
|
||||
}
|
||||
|
||||
if cfg.MaxSize != 100 {
|
||||
t.Errorf("Expected MaxSize to be 100, got %d", cfg.MaxSize)
|
||||
}
|
||||
|
||||
if cfg.Workers != 4 {
|
||||
t.Errorf("Expected Workers to be 4, got %d", cfg.Workers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunPropertyTest(t *testing.T) {
|
||||
// Simple property: for all integers, x + 0 = x
|
||||
RunPropertyTest(t, nil, func(properties *gopter.Properties) {
|
||||
properties.Property("addition identity", prop.ForAll(
|
||||
func(x int) bool {
|
||||
return x+0 == x
|
||||
},
|
||||
gen.Int(),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRunPropertyTestWithCustomConfig(t *testing.T) {
|
||||
cfg := &PropertyTestConfig{
|
||||
MinSuccessfulTests: 50,
|
||||
MaxSize: 50,
|
||||
Workers: 2,
|
||||
}
|
||||
|
||||
// Simple property: for all strings, len(s) >= 0
|
||||
RunPropertyTest(t, cfg, func(properties *gopter.Properties) {
|
||||
properties.Property("string length non-negative", prop.ForAll(
|
||||
func(s string) bool {
|
||||
return len(s) >= 0
|
||||
},
|
||||
gen.AnyString(),
|
||||
))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user