Files
Kiro/internal/auth/oauth_test.go

137 lines
3.5 KiB
Go

package auth
import (
	"context"
	"net/url"
	"strings"
	"testing"
	"time"
)
// TestInitiateOAuth checks that InitiateOAuth returns an authorization URL
// carrying the configured client ID, the configured redirect URI, and a
// non-empty state value, and that an unknown provider name is rejected.
func TestInitiateOAuth(t *testing.T) {
	const (
		clientID    = "test-client-id"
		callbackURL = "http://localhost:8080/auth/callback/google"
	)

	stateStore := NewMemoryStateStore()
	service := NewOAuthService(
		clientID,
		"test-client-secret",
		callbackURL,
		stateStore,
	)

	t.Run("successful OAuth initiation", func(t *testing.T) {
		redirectURL, err := service.InitiateOAuth("google")
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if redirectURL == "" {
			t.Fatal("Expected redirect URL, got empty string")
		}

		if !strings.Contains(redirectURL, "accounts.google.com") {
			t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
		}

		// Inspect decoded query parameters instead of raw substrings: a
		// substring match on "state=" would also accept an empty state, and
		// "client_id=test-client-id" would accept any ID sharing that prefix.
		parsed, err := url.Parse(redirectURL)
		if err != nil {
			t.Fatalf("Expected parseable redirect URL, got %q: %v", redirectURL, err)
		}
		query := parsed.Query()

		if got := query.Get("client_id"); got != clientID {
			t.Errorf("Expected client_id %q in URL, got %q (%s)", clientID, got, redirectURL)
		}
		// NOTE(review): assumes the service forwards the configured callback
		// URL unchanged as redirect_uri — confirm against NewOAuthService.
		if got := query.Get("redirect_uri"); got != callbackURL {
			t.Errorf("Expected redirect_uri %q in URL, got %q (%s)", callbackURL, got, redirectURL)
		}
		if query.Get("state") == "" {
			t.Errorf("Expected non-empty state parameter in URL, got %s", redirectURL)
		}
	})

	t.Run("unsupported provider", func(t *testing.T) {
		_, err := service.InitiateOAuth("unsupported")
		if err == nil {
			t.Fatal("Expected error for unsupported provider, got nil")
		}
		if !strings.Contains(err.Error(), "unsupported OAuth provider") {
			t.Errorf("Expected unsupported provider error, got %v", err)
		}
	})
}
// TestHandleOAuthCallback exercises the failure paths of HandleOAuthCallback:
// an unknown state token, a known state token paired with a code that cannot
// be exchanged, and an unsupported provider name.
func TestHandleOAuthCallback(t *testing.T) {
	// exchangeTimeout bounds the subtest that reaches the code-exchange step,
	// so a stalled outbound request cannot hang the whole test run.
	const exchangeTimeout = 10 * time.Second

	stateStore := NewMemoryStateStore()
	service := NewOAuthService(
		"test-client-id",
		"test-client-secret",
		"http://localhost:8080/auth/callback/google",
		stateStore,
	)

	t.Run("invalid state token", func(t *testing.T) {
		ctx := context.Background()
		_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
		if err == nil {
			t.Fatal("Expected error for invalid state, got nil")
		}
		if !strings.Contains(err.Error(), "invalid or expired state token") {
			t.Errorf("Expected invalid state error, got %v", err)
		}
	})

	t.Run("valid state token but invalid code", func(t *testing.T) {
		// Store a valid state token so the callback gets past state validation.
		state := "valid-state-token"
		expiry := time.Now().Add(10 * time.Minute)
		if err := stateStore.Set(state, expiry); err != nil {
			t.Fatalf("Failed to set state: %v", err)
		}

		// Use a deadline rather than a bare Background context: either the
		// exchange is rejected or the deadline expires, and both yield the
		// non-nil error this subtest expects.
		ctx, cancel := context.WithTimeout(context.Background(), exchangeTimeout)
		defer cancel()

		_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
		// This should fail because the code is invalid (can't exchange with Google).
		if err == nil {
			t.Fatal("Expected error for invalid code, got nil")
		}

		// The state should be deleted even on error. The error from Validate
		// is deliberately ignored: only the boolean matters here, and the
		// store may report a missing token as an error (assumption — confirm
		// against the state store implementation).
		valid, _ := stateStore.Validate(state)
		if valid {
			t.Error("Expected state to be deleted after callback")
		}
	})

	t.Run("unsupported provider", func(t *testing.T) {
		ctx := context.Background()
		_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
		if err == nil {
			t.Fatal("Expected error for unsupported provider, got nil")
		}
		if !strings.Contains(err.Error(), "unsupported OAuth provider") {
			t.Errorf("Expected unsupported provider error, got %v", err)
		}
	})
}
// TestGenerateStateToken checks that two consecutive calls to
// generateStateToken succeed, yield different values, and that the first
// value is not empty.
func TestGenerateStateToken(t *testing.T) {
	t.Run("generates unique tokens", func(t *testing.T) {
		first, firstErr := generateStateToken()
		if firstErr != nil {
			t.Fatalf("Failed to generate token: %v", firstErr)
		}

		second, secondErr := generateStateToken()
		if secondErr != nil {
			t.Fatalf("Failed to generate token: %v", secondErr)
		}

		// Both checks are reported independently so a single run surfaces
		// every violated expectation.
		if second == first {
			t.Error("Expected unique tokens, got duplicates")
		}
		if len(first) < 1 {
			t.Error("Expected non-empty token")
		}
	})
}