Initial commit: Custom Start Page application with authentication and DynamoDB storage
This commit is contained in:
109
internal/auth/oauth.go
Normal file
109
internal/auth/oauth.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
googleConfig *oauth2.Config
|
||||
stateStore StateStore
|
||||
}
|
||||
|
||||
// StateStore manages OAuth state tokens for CSRF protection
|
||||
type StateStore interface {
|
||||
Set(state string, expiry time.Time) error
|
||||
Validate(state string) (bool, error)
|
||||
Delete(state string) error
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(googleClientID, googleClientSecret, googleRedirectURL string, stateStore StateStore) *OAuthService {
|
||||
googleConfig := &oauth2.Config{
|
||||
ClientID: googleClientID,
|
||||
ClientSecret: googleClientSecret,
|
||||
RedirectURL: googleRedirectURL,
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
return &OAuthService{
|
||||
googleConfig: googleConfig,
|
||||
stateStore: stateStore,
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateOAuth starts the OAuth flow and returns the redirect URL
|
||||
func (s *OAuthService) InitiateOAuth(provider string) (string, error) {
|
||||
if provider != "google" {
|
||||
return "", fmt.Errorf("unsupported OAuth provider: %s", provider)
|
||||
}
|
||||
|
||||
// Generate random state token for CSRF protection
|
||||
state, err := generateStateToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate state token: %w", err)
|
||||
}
|
||||
|
||||
// Store state with 10 minute expiry
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
if err := s.stateStore.Set(state, expiry); err != nil {
|
||||
return "", fmt.Errorf("failed to store state token: %w", err)
|
||||
}
|
||||
|
||||
// Generate authorization URL
|
||||
url := s.googleConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// HandleOAuthCallback processes the OAuth callback and exchanges the code for a token
|
||||
func (s *OAuthService) HandleOAuthCallback(ctx context.Context, provider, code, state string) (*oauth2.Token, error) {
|
||||
if provider != "google" {
|
||||
return nil, fmt.Errorf("unsupported OAuth provider: %s", provider)
|
||||
}
|
||||
|
||||
// Validate state token
|
||||
valid, err := s.stateStore.Validate(state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate state token: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("invalid or expired state token")
|
||||
}
|
||||
|
||||
// Delete state token after validation
|
||||
_ = s.stateStore.Delete(state)
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := s.googleConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetGoogleConfig returns the Google OAuth config for accessing user info
|
||||
func (s *OAuthService) GetGoogleConfig() *oauth2.Config {
|
||||
return s.googleConfig
|
||||
}
|
||||
|
||||
// generateStateToken generates a cryptographically secure random state token
|
||||
func generateStateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
136
internal/auth/oauth_test.go
Normal file
136
internal/auth/oauth_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInitiateOAuth(t *testing.T) {
|
||||
stateStore := NewMemoryStateStore()
|
||||
service := NewOAuthService(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost:8080/auth/callback/google",
|
||||
stateStore,
|
||||
)
|
||||
|
||||
t.Run("successful OAuth initiation", func(t *testing.T) {
|
||||
redirectURL, err := service.InitiateOAuth("google")
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if redirectURL == "" {
|
||||
t.Fatal("Expected redirect URL, got empty string")
|
||||
}
|
||||
|
||||
// Check that URL contains expected components
|
||||
if !strings.Contains(redirectURL, "accounts.google.com") {
|
||||
t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "client_id=test-client-id") {
|
||||
t.Errorf("Expected client_id in URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "redirect_uri=") {
|
||||
t.Errorf("Expected redirect_uri in URL, got %s", redirectURL)
|
||||
}
|
||||
|
||||
if !strings.Contains(redirectURL, "state=") {
|
||||
t.Errorf("Expected state parameter in URL, got %s", redirectURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported provider", func(t *testing.T) {
|
||||
_, err := service.InitiateOAuth("unsupported")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for unsupported provider, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
||||
t.Errorf("Expected unsupported provider error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleOAuthCallback(t *testing.T) {
|
||||
stateStore := NewMemoryStateStore()
|
||||
service := NewOAuthService(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost:8080/auth/callback/google",
|
||||
stateStore,
|
||||
)
|
||||
|
||||
t.Run("invalid state token", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid state, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid or expired state token") {
|
||||
t.Errorf("Expected invalid state error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid state token but invalid code", func(t *testing.T) {
|
||||
// Store a valid state token
|
||||
state := "valid-state-token"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
if err := stateStore.Set(state, expiry); err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
|
||||
|
||||
// This should fail because the code is invalid (can't exchange with Google)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid code, got nil")
|
||||
}
|
||||
|
||||
// The state should be deleted even on error
|
||||
valid, _ := stateStore.Validate(state)
|
||||
if valid {
|
||||
t.Error("Expected state to be deleted after callback")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported provider", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for unsupported provider, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
||||
t.Errorf("Expected unsupported provider error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateStateToken(t *testing.T) {
|
||||
t.Run("generates unique tokens", func(t *testing.T) {
|
||||
token1, err := generateStateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
token2, err := generateStateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
if token1 == token2 {
|
||||
t.Error("Expected unique tokens, got duplicates")
|
||||
}
|
||||
|
||||
if len(token1) == 0 {
|
||||
t.Error("Expected non-empty token")
|
||||
}
|
||||
})
|
||||
}
|
||||
88
internal/auth/session_store.go
Normal file
88
internal/auth/session_store.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const sessionName = "startpage_session"
|
||||
const userIDKey = "user_id"
|
||||
|
||||
// CookieSessionStore implements SessionStore using gorilla/sessions
|
||||
type CookieSessionStore struct {
|
||||
store *sessions.CookieStore
|
||||
}
|
||||
|
||||
// NewCookieSessionStore creates a new cookie-based session store
|
||||
func NewCookieSessionStore(secretKey string, maxAge int) *CookieSessionStore {
|
||||
store := sessions.NewCookieStore([]byte(secretKey))
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to true in production with HTTPS
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
return &CookieSessionStore{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSession creates a new session for the user
|
||||
func (s *CookieSessionStore) CreateSession(w http.ResponseWriter, r *http.Request, userID string) error {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
// If there's an error getting the session, create a new one
|
||||
session, _ = s.store.New(r, sessionName)
|
||||
}
|
||||
|
||||
session.Values[userIDKey] = userID
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserID retrieves the user ID from the session
|
||||
func (s *CookieSessionStore) GetUserID(r *http.Request) (string, error) {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get session: %w", err)
|
||||
}
|
||||
|
||||
userID, ok := session.Values[userIDKey].(string)
|
||||
if !ok || userID == "" {
|
||||
return "", fmt.Errorf("user ID not found in session")
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// ValidateSession checks if a valid session exists for the request
|
||||
func (s *CookieSessionStore) ValidateSession(r *http.Request) bool {
|
||||
_, err := s.GetUserID(r)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// DestroySession destroys the user's session
|
||||
func (s *CookieSessionStore) DestroySession(w http.ResponseWriter, r *http.Request) error {
|
||||
session, err := s.store.Get(r, sessionName)
|
||||
if err != nil {
|
||||
// Session doesn't exist or is invalid, nothing to destroy
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set MaxAge to -1 to delete the cookie
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to destroy session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
237
internal/auth/session_store_test.go
Normal file
237
internal/auth/session_store_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCookieSessionStore(t *testing.T) {
|
||||
store := NewCookieSessionStore("test-secret-key", 3600)
|
||||
|
||||
t.Run("create and retrieve session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-123"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie from the response
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("Expected session cookie, got none")
|
||||
}
|
||||
|
||||
// Create a new request with the cookie
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
retrievedUserID, err := store.GetUserID(r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user ID: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUserID != userID {
|
||||
t.Errorf("Expected user ID %s, got %s", userID, retrievedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get user ID without session", func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
_, err := store.GetUserID(r)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting user ID without session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("destroy session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-456"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
w2 := httptest.NewRecorder()
|
||||
err = store.DestroySession(w2, r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to destroy session: %v", err)
|
||||
}
|
||||
|
||||
// Check that the cookie has MaxAge set to -1 (deletion marker)
|
||||
destroyCookies := w2.Result().Cookies()
|
||||
if len(destroyCookies) == 0 {
|
||||
t.Fatal("Expected cookie with MaxAge=-1 for deletion")
|
||||
}
|
||||
|
||||
foundDeleteCookie := false
|
||||
for _, cookie := range destroyCookies {
|
||||
if cookie.MaxAge == -1 {
|
||||
foundDeleteCookie = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundDeleteCookie {
|
||||
t.Error("Expected cookie with MaxAge=-1 to indicate deletion")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("update existing session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Create initial session
|
||||
userID1 := "user-1"
|
||||
store.CreateSession(w, r, userID1)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Update session with new user ID
|
||||
w2 := httptest.NewRecorder()
|
||||
userID2 := "user-2"
|
||||
err := store.CreateSession(w2, r2, userID2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update session: %v", err)
|
||||
}
|
||||
|
||||
// Verify new user ID
|
||||
cookies2 := w2.Result().Cookies()
|
||||
r3 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies2 {
|
||||
r3.AddCookie(cookie)
|
||||
}
|
||||
|
||||
retrievedUserID, err := store.GetUserID(r3)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user ID: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUserID != userID2 {
|
||||
t.Errorf("Expected user ID %s, got %s", userID2, retrievedUserID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session with valid session", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-789"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Create request with session cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Validate session
|
||||
if !store.ValidateSession(r2) {
|
||||
t.Error("Expected ValidateSession to return true for valid session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session without session", func(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Validate session without any cookies
|
||||
if store.ValidateSession(r) {
|
||||
t.Error("Expected ValidateSession to return false for missing session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate session after logout", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-logout"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Get the cookie
|
||||
cookies := w.Result().Cookies()
|
||||
r2 := httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range cookies {
|
||||
r2.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Destroy the session
|
||||
w2 := httptest.NewRecorder()
|
||||
err = store.DestroySession(w2, r2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to destroy session: %v", err)
|
||||
}
|
||||
|
||||
// Create a new request without any cookies (simulating browser behavior after logout)
|
||||
r3 := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
// Validate session should return false
|
||||
if store.ValidateSession(r3) {
|
||||
t.Error("Expected ValidateSession to return false after logout")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("session cookie has security settings", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
userID := "test-user-security"
|
||||
err := store.CreateSession(w, r, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("Expected session cookie, got none")
|
||||
}
|
||||
|
||||
cookie := cookies[0]
|
||||
|
||||
// Verify HttpOnly flag is set
|
||||
if !cookie.HttpOnly {
|
||||
t.Error("Expected HttpOnly flag to be true")
|
||||
}
|
||||
|
||||
// Verify SameSite is set
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Expected SameSite to be Lax, got %v", cookie.SameSite)
|
||||
}
|
||||
|
||||
// Verify Path is set
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Expected Path to be /, got %s", cookie.Path)
|
||||
}
|
||||
|
||||
// Verify MaxAge is set
|
||||
if cookie.MaxAge != 3600 {
|
||||
t.Errorf("Expected MaxAge to be 3600, got %d", cookie.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
78
internal/auth/state_store.go
Normal file
78
internal/auth/state_store.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryStateStore is an in-memory implementation of StateStore
|
||||
// Note: This is suitable for development but should be replaced with
|
||||
// a distributed store (Redis, DynamoDB) for production with multiple servers
|
||||
type MemoryStateStore struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]time.Time
|
||||
}
|
||||
|
||||
// NewMemoryStateStore creates a new in-memory state store
|
||||
func NewMemoryStateStore() *MemoryStateStore {
|
||||
store := &MemoryStateStore{
|
||||
states: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine to remove expired states
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// Set stores a state token with an expiry time
|
||||
func (s *MemoryStateStore) Set(state string, expiry time.Time) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.states[state] = expiry
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if a state token is valid and not expired
|
||||
func (s *MemoryStateStore) Validate(state string) (bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
expiry, exists := s.states[state]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(expiry) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Delete removes a state token from the store
|
||||
func (s *MemoryStateStore) Delete(state string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.states, state)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired periodically removes expired state tokens
|
||||
func (s *MemoryStateStore) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
for state, expiry := range s.states {
|
||||
if now.After(expiry) {
|
||||
delete(s.states, state)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
112
internal/auth/state_store_test.go
Normal file
112
internal/auth/state_store_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMemoryStateStore(t *testing.T) {
|
||||
store := NewMemoryStateStore()
|
||||
|
||||
t.Run("set and validate state", func(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to validate state: %v", err)
|
||||
}
|
||||
|
||||
if !valid {
|
||||
t.Error("Expected state to be valid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate non-existent state", func(t *testing.T) {
|
||||
valid, err := store.Validate("non-existent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected non-existent state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("validate expired state", func(t *testing.T) {
|
||||
state := "expired-state"
|
||||
expiry := time.Now().Add(-1 * time.Minute) // Already expired
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected expired state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete state", func(t *testing.T) {
|
||||
state := "delete-test"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
err := store.Set(state, expiry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
|
||||
err = store.Delete(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete state: %v", err)
|
||||
}
|
||||
|
||||
valid, err := store.Validate(state)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if valid {
|
||||
t.Error("Expected deleted state to be invalid")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple states", func(t *testing.T) {
|
||||
state1 := "state-1"
|
||||
state2 := "state-2"
|
||||
expiry := time.Now().Add(10 * time.Minute)
|
||||
|
||||
store.Set(state1, expiry)
|
||||
store.Set(state2, expiry)
|
||||
|
||||
valid1, _ := store.Validate(state1)
|
||||
valid2, _ := store.Validate(state2)
|
||||
|
||||
if !valid1 || !valid2 {
|
||||
t.Error("Expected both states to be valid")
|
||||
}
|
||||
|
||||
store.Delete(state1)
|
||||
|
||||
valid1, _ = store.Validate(state1)
|
||||
valid2, _ = store.Validate(state2)
|
||||
|
||||
if valid1 {
|
||||
t.Error("Expected state1 to be invalid after deletion")
|
||||
}
|
||||
|
||||
if !valid2 {
|
||||
t.Error("Expected state2 to still be valid")
|
||||
}
|
||||
})
|
||||
}
|
||||
100
internal/auth/user_service.go
Normal file
100
internal/auth/user_service.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
"custom-start-page/internal/storage"
|
||||
)
|
||||
|
||||
// GoogleUserInfo represents the user info returned by Google
|
||||
type GoogleUserInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
VerifiedEmail bool `json:"verified_email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
// UserService handles user creation and retrieval
|
||||
type UserService struct {
|
||||
userRepo *storage.UserRepository
|
||||
}
|
||||
|
||||
// NewUserService creates a new user service
|
||||
func NewUserService(userRepo *storage.UserRepository) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreateUserFromGoogle fetches user info from Google and creates or retrieves the user
|
||||
func (s *UserService) GetOrCreateUserFromGoogle(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*models.User, error) {
|
||||
// Fetch user info from Google
|
||||
userInfo, err := s.fetchGoogleUserInfo(ctx, token, oauthConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google user info: %w", err)
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
user, err := s.userRepo.GetByOAuthID(ctx, "google", userInfo.ID)
|
||||
if err == nil {
|
||||
// User exists, update last login time
|
||||
user.UpdatedAt = time.Now()
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// User doesn't exist, create new user
|
||||
user = &models.User{
|
||||
ID: uuid.New().String(),
|
||||
Email: userInfo.Email,
|
||||
OAuthProvider: "google",
|
||||
OAuthID: userInfo.ID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// fetchGoogleUserInfo fetches user information from Google's userinfo endpoint
|
||||
func (s *UserService) fetchGoogleUserInfo(ctx context.Context, token *oauth2.Token, oauthConfig *oauth2.Config) (*GoogleUserInfo, error) {
|
||||
client := oauthConfig.Client(ctx, token)
|
||||
|
||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to get user info: status %d, body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var userInfo GoogleUserInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode user info: %w", err)
|
||||
}
|
||||
|
||||
if !userInfo.VerifiedEmail {
|
||||
return nil, fmt.Errorf("email not verified")
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
Reference in New Issue
Block a user