package auth

import (
	"context"
	"net/url"
	"strings"
	"testing"
	"time"
)

// Fixed configuration shared by the OAuth service tests. The redirect URL is
// also compared against the redirect_uri carried in the authorization URL.
const (
	oauthTestClientID     = "test-client-id"
	oauthTestClientSecret = "test-client-secret"
	oauthTestRedirectURL  = "http://localhost:8080/auth/callback/google"
)

// TestInitiateOAuth checks that InitiateOAuth builds a Google authorization
// URL carrying the configured client ID, redirect URI and a state parameter,
// and that it rejects providers it does not know.
func TestInitiateOAuth(t *testing.T) {
	stateStore := NewMemoryStateStore()
	service := NewOAuthService(
		oauthTestClientID,
		oauthTestClientSecret,
		oauthTestRedirectURL,
		stateStore,
	)

	t.Run("successful OAuth initiation", func(t *testing.T) {
		redirectURL, err := service.InitiateOAuth("google")
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if redirectURL == "" {
			t.Fatal("Expected redirect URL, got empty string")
		}

		// Parse rather than substring-match so every assertion targets the
		// intended URL component and query values are compared decoded.
		u, err := url.Parse(redirectURL)
		if err != nil {
			t.Fatalf("Expected a parseable redirect URL, got %q: %v", redirectURL, err)
		}
		if u.Host != "accounts.google.com" {
			t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
		}

		q := u.Query()
		if got := q.Get("client_id"); got != oauthTestClientID {
			t.Errorf("Expected client_id %q in URL, got %q (URL %s)", oauthTestClientID, got, redirectURL)
		}
		if got := q.Get("redirect_uri"); got != oauthTestRedirectURL {
			t.Errorf("Expected redirect_uri %q in URL, got %q (URL %s)", oauthTestRedirectURL, got, redirectURL)
		}
		if q.Get("state") == "" {
			t.Errorf("Expected state parameter in URL, got %s", redirectURL)
		}
	})

	t.Run("unsupported provider", func(t *testing.T) {
		_, err := service.InitiateOAuth("unsupported")
		if err == nil {
			t.Fatal("Expected error for unsupported provider, got nil")
		}
		if !strings.Contains(err.Error(), "unsupported OAuth provider") {
			t.Errorf("Expected unsupported provider error, got %v", err)
		}
	})
}

// TestHandleOAuthCallback checks that HandleOAuthCallback rejects unknown
// state tokens and unsupported providers, and that a stored state token is
// no longer valid after a callback that fails at the code-exchange step.
func TestHandleOAuthCallback(t *testing.T) {
	stateStore := NewMemoryStateStore()
	service := NewOAuthService(
		oauthTestClientID,
		oauthTestClientSecret,
		oauthTestRedirectURL,
		stateStore,
	)

	t.Run("invalid state token", func(t *testing.T) {
		ctx := context.Background()
		_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
		if err == nil {
			t.Fatal("Expected error for invalid state, got nil")
		}
		if !strings.Contains(err.Error(), "invalid or expired state token") {
			t.Errorf("Expected invalid state error, got %v", err)
		}
	})

	t.Run("valid state token but invalid code", func(t *testing.T) {
		// Store a valid state token.
		state := "valid-state-token"
		expiry := time.Now().Add(10 * time.Minute)
		if err := stateStore.Set(state, expiry); err != nil {
			t.Fatalf("Failed to set state: %v", err)
		}

		// The code exchange presumably contacts Google's token endpoint;
		// bound it so a slow or unreachable network cannot hang the test.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
		// This should fail because the code is invalid (can't exchange with Google).
		if err == nil {
			t.Fatal("Expected error for invalid code, got nil")
		}

		// The state should be deleted even on error. The error from Validate is
		// deliberately ignored: only whether the token still validates matters.
		valid, _ := stateStore.Validate(state)
		if valid {
			t.Error("Expected state to be deleted after callback")
		}
	})

	t.Run("unsupported provider", func(t *testing.T) {
		ctx := context.Background()
		_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
		if err == nil {
			t.Fatal("Expected error for unsupported provider, got nil")
		}
		if !strings.Contains(err.Error(), "unsupported OAuth provider") {
			t.Errorf("Expected unsupported provider error, got %v", err)
		}
	})
}

// TestGenerateStateToken checks that generateStateToken returns non-empty
// tokens that do not repeat across calls.
func TestGenerateStateToken(t *testing.T) {
	t.Run("generates unique tokens", func(t *testing.T) {
		// Sample more than two tokens so a constant or low-entropy generator
		// is more likely to be caught; every sample is checked, not just one.
		const samples = 32
		seen := make(map[string]struct{}, samples)
		for i := 0; i < samples; i++ {
			token, err := generateStateToken()
			if err != nil {
				t.Fatalf("Failed to generate token: %v", err)
			}
			if len(token) == 0 {
				t.Fatal("Expected non-empty token")
			}
			if _, dup := seen[token]; dup {
				t.Fatalf("Expected unique tokens, got duplicate %q", token)
			}
			seen[token] = struct{}{}
		}
	})
}