package auth

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
)

const (
	// oauthProviderGoogle is the only provider identifier OAuthService accepts.
	oauthProviderGoogle = "google"

	// oauthStateTTL is how long a state token issued by InitiateOAuth stays valid.
	oauthStateTTL = 10 * time.Minute

	// oauthStateTokenBytes is the number of random bytes in a state token
	// before base64 encoding.
	oauthStateTokenBytes = 32
)

var (
	// ErrUnsupportedOAuthProvider is wrapped into the error returned when a
	// provider other than "google" is requested. Detect it with errors.Is.
	ErrUnsupportedOAuthProvider = errors.New("unsupported OAuth provider")

	// ErrInvalidOAuthState is returned when the callback's state token is
	// empty or is rejected by the StateStore (unknown or expired).
	ErrInvalidOAuthState = errors.New("invalid or expired state token")
)

// OAuthService handles OAuth authentication flows.
type OAuthService struct {
	googleConfig *oauth2.Config
	stateStore   StateStore
}

// StateStore manages OAuth state tokens for CSRF protection.
type StateStore interface {
	// Set records state as issued, valid until expiry.
	Set(state string, expiry time.Time) error
	// Validate reports whether state is a known, unexpired token.
	Validate(state string) (bool, error)
	// Delete removes state so it cannot be used again.
	Delete(state string) error
}

// NewOAuthService creates a new OAuth service configured for Google sign-in,
// requesting the userinfo.email and userinfo.profile scopes.
//
// stateStore must be non-nil. It is not checked here, so a nil store panics
// on the first call to InitiateOAuth or HandleOAuthCallback.
func NewOAuthService(googleClientID, googleClientSecret, googleRedirectURL string, stateStore StateStore) *OAuthService {
	googleConfig := &oauth2.Config{
		ClientID:     googleClientID,
		ClientSecret: googleClientSecret,
		RedirectURL:  googleRedirectURL,
		Scopes: []string{
			"https://www.googleapis.com/auth/userinfo.email",
			"https://www.googleapis.com/auth/userinfo.profile",
		},
		Endpoint: google.Endpoint,
	}

	return &OAuthService{
		googleConfig: googleConfig,
		stateStore:   stateStore,
	}
}

// checkProvider returns an error wrapping ErrUnsupportedOAuthProvider unless
// provider is the Google identifier.
func (s *OAuthService) checkProvider(provider string) error {
	if provider != oauthProviderGoogle {
		return fmt.Errorf("%w: %s", ErrUnsupportedOAuthProvider, provider)
	}
	return nil
}

// InitiateOAuth starts the OAuth flow and returns the provider's authorization
// URL to redirect the user to.
//
// It generates a random state token, stores it for oauthStateTTL (10 minutes),
// and embeds it in the URL so HandleOAuthCallback can check that the callback
// belongs to a flow this service started (CSRF protection).
func (s *OAuthService) InitiateOAuth(provider string) (string, error) {
	if err := s.checkProvider(provider); err != nil {
		return "", err
	}

	state, err := generateStateToken()
	if err != nil {
		return "", fmt.Errorf("failed to generate state token: %w", err)
	}

	expiry := time.Now().Add(oauthStateTTL)
	if err := s.stateStore.Set(state, expiry); err != nil {
		return "", fmt.Errorf("failed to store state token: %w", err)
	}

	// AccessTypeOffline requests a refresh token in addition to the access token.
	return s.googleConfig.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}

// HandleOAuthCallback processes the OAuth callback and exchanges the
// authorization code for a token.
//
// The state token is checked against the StateStore and then deleted, so each
// state can be used for at most one successful callback. Returns an error
// wrapping ErrUnsupportedOAuthProvider for unknown providers, and
// ErrInvalidOAuthState for an empty or rejected state.
func (s *OAuthService) HandleOAuthCallback(ctx context.Context, provider, code, state string) (*oauth2.Token, error) {
	if err := s.checkProvider(provider); err != nil {
		return nil, err
	}

	// InitiateOAuth never issues an empty state, so reject it without asking
	// the store.
	if state == "" {
		return nil, ErrInvalidOAuthState
	}

	valid, err := s.stateStore.Validate(state)
	if err != nil {
		return nil, fmt.Errorf("failed to validate state token: %w", err)
	}
	if !valid {
		return nil, ErrInvalidOAuthState
	}

	// The state must be single-use. If it cannot be deleted it stays valid
	// until expiry and could be replayed, so fail closed rather than ignore it.
	// NOTE(review): Validate and Delete are separate calls, so two concurrent
	// callbacks carrying the same state can both pass Validate. An atomic
	// "consume" operation on StateStore would close that window.
	if err := s.stateStore.Delete(state); err != nil {
		return nil, fmt.Errorf("failed to delete state token: %w", err)
	}

	// Check the code only after the state is consumed, so the state cannot be
	// reused even when the callback carries no code.
	if code == "" {
		return nil, errors.New("missing authorization code")
	}

	token, err := s.googleConfig.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange code for token: %w", err)
	}

	return token, nil
}

// GetGoogleConfig returns the Google OAuth config for accessing user info.
//
// The returned pointer is the service's own config, not a copy. Changes made
// through it affect later InitiateOAuth and HandleOAuthCallback calls.
func (s *OAuthService) GetGoogleConfig() *oauth2.Config {
	return s.googleConfig
}

// generateStateToken returns oauthStateTokenBytes bytes from crypto/rand,
// encoded with URL-safe base64 (with padding).
func generateStateToken() (string, error) {
	b := make([]byte, oauthStateTokenBytes)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(b), nil
}