137 lines
3.5 KiB
Go
137 lines
3.5 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"net/url"
	"strings"
	"testing"
	"time"
)
|
|
|
|
func TestInitiateOAuth(t *testing.T) {
|
|
stateStore := NewMemoryStateStore()
|
|
service := NewOAuthService(
|
|
"test-client-id",
|
|
"test-client-secret",
|
|
"http://localhost:8080/auth/callback/google",
|
|
stateStore,
|
|
)
|
|
|
|
t.Run("successful OAuth initiation", func(t *testing.T) {
|
|
redirectURL, err := service.InitiateOAuth("google")
|
|
if err != nil {
|
|
t.Fatalf("Expected no error, got %v", err)
|
|
}
|
|
|
|
if redirectURL == "" {
|
|
t.Fatal("Expected redirect URL, got empty string")
|
|
}
|
|
|
|
// Check that URL contains expected components
|
|
if !strings.Contains(redirectURL, "accounts.google.com") {
|
|
t.Errorf("Expected Google OAuth URL, got %s", redirectURL)
|
|
}
|
|
|
|
if !strings.Contains(redirectURL, "client_id=test-client-id") {
|
|
t.Errorf("Expected client_id in URL, got %s", redirectURL)
|
|
}
|
|
|
|
if !strings.Contains(redirectURL, "redirect_uri=") {
|
|
t.Errorf("Expected redirect_uri in URL, got %s", redirectURL)
|
|
}
|
|
|
|
if !strings.Contains(redirectURL, "state=") {
|
|
t.Errorf("Expected state parameter in URL, got %s", redirectURL)
|
|
}
|
|
})
|
|
|
|
t.Run("unsupported provider", func(t *testing.T) {
|
|
_, err := service.InitiateOAuth("unsupported")
|
|
if err == nil {
|
|
t.Fatal("Expected error for unsupported provider, got nil")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
|
t.Errorf("Expected unsupported provider error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHandleOAuthCallback(t *testing.T) {
|
|
stateStore := NewMemoryStateStore()
|
|
service := NewOAuthService(
|
|
"test-client-id",
|
|
"test-client-secret",
|
|
"http://localhost:8080/auth/callback/google",
|
|
stateStore,
|
|
)
|
|
|
|
t.Run("invalid state token", func(t *testing.T) {
|
|
ctx := context.Background()
|
|
_, err := service.HandleOAuthCallback(ctx, "google", "test-code", "invalid-state")
|
|
if err == nil {
|
|
t.Fatal("Expected error for invalid state, got nil")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "invalid or expired state token") {
|
|
t.Errorf("Expected invalid state error, got %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("valid state token but invalid code", func(t *testing.T) {
|
|
// Store a valid state token
|
|
state := "valid-state-token"
|
|
expiry := time.Now().Add(10 * time.Minute)
|
|
if err := stateStore.Set(state, expiry); err != nil {
|
|
t.Fatalf("Failed to set state: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
_, err := service.HandleOAuthCallback(ctx, "google", "invalid-code", state)
|
|
|
|
// This should fail because the code is invalid (can't exchange with Google)
|
|
if err == nil {
|
|
t.Fatal("Expected error for invalid code, got nil")
|
|
}
|
|
|
|
// The state should be deleted even on error
|
|
valid, _ := stateStore.Validate(state)
|
|
if valid {
|
|
t.Error("Expected state to be deleted after callback")
|
|
}
|
|
})
|
|
|
|
t.Run("unsupported provider", func(t *testing.T) {
|
|
ctx := context.Background()
|
|
_, err := service.HandleOAuthCallback(ctx, "unsupported", "test-code", "test-state")
|
|
if err == nil {
|
|
t.Fatal("Expected error for unsupported provider, got nil")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), "unsupported OAuth provider") {
|
|
t.Errorf("Expected unsupported provider error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGenerateStateToken(t *testing.T) {
|
|
t.Run("generates unique tokens", func(t *testing.T) {
|
|
token1, err := generateStateToken()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
token2, err := generateStateToken()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate token: %v", err)
|
|
}
|
|
|
|
if token1 == token2 {
|
|
t.Error("Expected unique tokens, got duplicates")
|
|
}
|
|
|
|
if len(token1) == 0 {
|
|
t.Error("Expected non-empty token")
|
|
}
|
|
})
|
|
}
|