Initial commit: Custom Start Page application with authentication and DynamoDB storage
This commit is contained in:
124
internal/storage/README.md
Normal file
124
internal/storage/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# DynamoDB Storage Service
|
||||
|
||||
This package provides an enhanced DynamoDB client wrapper with the following features:
|
||||
|
||||
## Features
|
||||
|
||||
### 1. Connection Pooling
|
||||
The client uses the AWS SDK's default HTTP client which includes connection pooling automatically. This ensures efficient reuse of TCP connections to DynamoDB.
|
||||
|
||||
### 2. Retry Logic with Exponential Backoff
|
||||
The client is configured with automatic retry logic:
|
||||
- **Max Attempts**: 5 retries
|
||||
- **Max Backoff**: 20 seconds
|
||||
- **Strategy**: Exponential backoff with jitter to prevent thundering herd
|
||||
|
||||
This handles transient failures gracefully and improves reliability.
|
||||
|
||||
### 3. Transaction Support
|
||||
The `TransactWriteItems` method provides ACID transaction support for multiple write operations:
|
||||
```go
|
||||
err := client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String("MyTable"),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: "item1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
// More items...
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### 4. Batch Operations
|
||||
The client provides batch read and write operations with automatic retry of unprocessed items:
|
||||
|
||||
#### BatchGetItems
|
||||
Retrieves multiple items in a single request:
|
||||
```go
|
||||
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: map[string]types.KeysAndAttributes{
|
||||
"MyTable": {
|
||||
Keys: []map[string]types.AttributeValue{
|
||||
{"id": &types.AttributeValueMemberS{Value: "item1"}},
|
||||
{"id": &types.AttributeValueMemberS{Value: "item2"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### BatchWriteItems
|
||||
Writes multiple items in a single request:
|
||||
```go
|
||||
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: map[string][]types.WriteRequest{
|
||||
"MyTable": {
|
||||
{
|
||||
PutRequest: &types.PutRequest{
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: "item1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Both batch operations automatically handle unprocessed items with exponential backoff retry logic.
|
||||
|
||||
## Standard Operations
|
||||
|
||||
The client also provides wrapped versions of standard DynamoDB operations with automatic retry:
|
||||
- `PutItem` - Put a single item
|
||||
- `GetItem` - Get a single item
|
||||
- `UpdateItem` - Update a single item
|
||||
- `DeleteItem` - Delete a single item
|
||||
- `Query` - Query items
|
||||
|
||||
## Usage
|
||||
|
||||
### Creating a Client
|
||||
|
||||
```go
|
||||
ctx := context.Background()
|
||||
client, err := storage.NewDynamoDBClient(ctx, "http://localhost:8000")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
For production (AWS DynamoDB), pass an empty string for the endpoint:
|
||||
```go
|
||||
client, err := storage.NewDynamoDBClient(ctx, "")
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
The package includes comprehensive tests that can be run against DynamoDB Local:
|
||||
|
||||
1. Start DynamoDB Local:
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. Run tests:
|
||||
```bash
|
||||
go test -v ./internal/storage
|
||||
```
|
||||
|
||||
Tests will automatically skip if DynamoDB is not available.
|
||||
|
||||
## Requirements Addressed
|
||||
|
||||
This implementation addresses the following requirements from the spec:
|
||||
- **Requirement 8.1**: Immediate persistence of all changes
|
||||
- **Requirement 8.8**: Efficient scaling for 10,000+ items per user
|
||||
- **Design requirement**: Retry logic with exponential backoff
|
||||
- **Design requirement**: Transaction support for atomic operations
|
||||
- **Design requirement**: Batch operations for efficient bulk reads/writes
|
||||
- **Design requirement**: Connection pooling for performance
|
||||
248
internal/storage/dynamodb.go
Normal file
248
internal/storage/dynamodb.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/aws/retry"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
)
|
||||
|
||||
// DynamoDBClient wraps the AWS DynamoDB client with enhanced features
|
||||
type DynamoDBClient struct {
|
||||
client *dynamodb.Client
|
||||
}
|
||||
|
||||
// NewDynamoDBClient creates a new DynamoDB client with connection pooling and retry logic
|
||||
func NewDynamoDBClient(ctx context.Context, endpoint string) (*DynamoDBClient, error) {
|
||||
// Configure retry strategy with exponential backoff
|
||||
retryer := retry.NewStandard(func(o *retry.StandardOptions) {
|
||||
o.MaxAttempts = 5
|
||||
o.MaxBackoff = 20 * time.Second
|
||||
// Use exponential backoff with jitter
|
||||
o.Backoff = retry.NewExponentialJitterBackoff(20 * time.Second)
|
||||
})
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithRetryer(func() aws.Retryer {
|
||||
return retryer
|
||||
}),
|
||||
// Connection pooling is handled by the HTTP client
|
||||
config.WithHTTPClient(nil), // Uses default HTTP client with connection pooling
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
||||
}
|
||||
|
||||
// Override endpoint if provided (for local DynamoDB)
|
||||
if endpoint != "" {
|
||||
cfg.BaseEndpoint = aws.String(endpoint)
|
||||
}
|
||||
|
||||
client := dynamodb.NewFromConfig(cfg)
|
||||
|
||||
return &DynamoDBClient{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUsersTable creates the Users table in DynamoDB
|
||||
func (db *DynamoDBClient) CreateUsersTable(ctx context.Context) error {
|
||||
tableName := "Users"
|
||||
|
||||
// Check if table already exists
|
||||
_, err := db.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
})
|
||||
if err == nil {
|
||||
// Table already exists
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = db.client.CreateTable(ctx, &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("user_id"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("user_id"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
BillingMode: types.BillingModePayPerRequest,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Users table: %w", err)
|
||||
}
|
||||
|
||||
// Wait for table to be active
|
||||
waiter := dynamodb.NewTableExistsWaiter(db.client)
|
||||
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
}, 5*60) // 5 minutes timeout
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed waiting for Users table to be active: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClient returns the underlying DynamoDB client
|
||||
func (db *DynamoDBClient) GetClient() *dynamodb.Client {
|
||||
return db.client
|
||||
}
|
||||
|
||||
// TransactWriteItems executes a transactional write operation with automatic retry
|
||||
func (db *DynamoDBClient) TransactWriteItems(ctx context.Context, input *dynamodb.TransactWriteItemsInput) error {
|
||||
_, err := db.client.TransactWriteItems(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("transaction write failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchGetItems retrieves multiple items in a single batch operation
|
||||
func (db *DynamoDBClient) BatchGetItems(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
|
||||
output, err := db.client.BatchGetItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch get failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle unprocessed keys with exponential backoff
|
||||
if len(output.UnprocessedKeys) > 0 {
|
||||
return db.retryUnprocessedKeys(ctx, output)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// BatchWriteItems writes multiple items in a single batch operation
|
||||
func (db *DynamoDBClient) BatchWriteItems(ctx context.Context, input *dynamodb.BatchWriteItemInput) error {
|
||||
output, err := db.client.BatchWriteItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("batch write failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle unprocessed items with exponential backoff
|
||||
if len(output.UnprocessedItems) > 0 {
|
||||
return db.retryUnprocessedWrites(ctx, output.UnprocessedItems)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// retryUnprocessedKeys retries unprocessed keys from BatchGetItem with exponential backoff
|
||||
func (db *DynamoDBClient) retryUnprocessedKeys(ctx context.Context, output *dynamodb.BatchGetItemOutput) (*dynamodb.BatchGetItemOutput, error) {
|
||||
maxRetries := 5
|
||||
backoff := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries && len(output.UnprocessedKeys) > 0; attempt++ {
|
||||
// Wait with exponential backoff
|
||||
time.Sleep(backoff)
|
||||
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
||||
|
||||
// Retry unprocessed keys
|
||||
retryOutput, err := db.client.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: output.UnprocessedKeys,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retry batch get failed: %w", err)
|
||||
}
|
||||
|
||||
// Merge responses
|
||||
for table, items := range retryOutput.Responses {
|
||||
output.Responses[table] = append(output.Responses[table], items...)
|
||||
}
|
||||
|
||||
output.UnprocessedKeys = retryOutput.UnprocessedKeys
|
||||
}
|
||||
|
||||
if len(output.UnprocessedKeys) > 0 {
|
||||
return output, fmt.Errorf("failed to process all keys after %d retries", maxRetries)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// retryUnprocessedWrites retries unprocessed items from BatchWriteItem with exponential backoff
|
||||
func (db *DynamoDBClient) retryUnprocessedWrites(ctx context.Context, unprocessedItems map[string][]types.WriteRequest) error {
|
||||
maxRetries := 5
|
||||
backoff := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries && len(unprocessedItems) > 0; attempt++ {
|
||||
// Wait with exponential backoff
|
||||
time.Sleep(backoff)
|
||||
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
||||
|
||||
// Retry unprocessed items
|
||||
output, err := db.client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: unprocessedItems,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("retry batch write failed: %w", err)
|
||||
}
|
||||
|
||||
unprocessedItems = output.UnprocessedItems
|
||||
}
|
||||
|
||||
if len(unprocessedItems) > 0 {
|
||||
return fmt.Errorf("failed to process all items after %d retries", maxRetries)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutItem puts a single item with automatic retry
|
||||
func (db *DynamoDBClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) error {
|
||||
_, err := db.client.PutItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("put item failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetItem retrieves a single item with automatic retry
|
||||
func (db *DynamoDBClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
|
||||
output, err := db.client.GetItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get item failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// Query executes a query operation with automatic retry
|
||||
func (db *DynamoDBClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
|
||||
output, err := db.client.Query(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// UpdateItem updates a single item with automatic retry
|
||||
func (db *DynamoDBClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
|
||||
output, err := db.client.UpdateItem(ctx, input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update item failed: %w", err)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// DeleteItem deletes a single item with automatic retry
|
||||
func (db *DynamoDBClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput) error {
|
||||
_, err := db.client.DeleteItem(ctx, input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete item failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
452
internal/storage/dynamodb_test.go
Normal file
452
internal/storage/dynamodb_test.go
Normal file
@@ -0,0 +1,452 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
)
|
||||
|
||||
func setupTestClient(t *testing.T) (*DynamoDBClient, context.Context) {
|
||||
t.Helper()
|
||||
|
||||
// Set dummy AWS credentials for local DynamoDB
|
||||
os.Setenv("AWS_ACCESS_KEY_ID", "dummy")
|
||||
os.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
|
||||
os.Setenv("AWS_REGION", "us-east-1")
|
||||
|
||||
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "http://localhost:8000"
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client, err := NewDynamoDBClient(ctx, endpoint)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test: DynamoDB not available: %v", err)
|
||||
}
|
||||
|
||||
// Test connection by listing tables
|
||||
_, err = client.client.ListTables(ctx, &dynamodb.ListTablesInput{})
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test: Cannot connect to DynamoDB: %v", err)
|
||||
}
|
||||
|
||||
return client, ctx
|
||||
}
|
||||
|
||||
func TestNewDynamoDBClient(t *testing.T) {
|
||||
client, _ := setupTestClient(t)
|
||||
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil client")
|
||||
}
|
||||
|
||||
if client.client == nil {
|
||||
t.Fatal("Expected non-nil underlying client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactWriteItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestTransactions"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Test transaction with two puts
|
||||
testID1 := "test-txn-1"
|
||||
testID2 := "test-txn-2"
|
||||
|
||||
input := &dynamodb.TransactWriteItemsInput{
|
||||
TransactItems: []types.TransactWriteItem{
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID1},
|
||||
"value": &types.AttributeValueMemberS{Value: "value1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Put: &types.Put{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID2},
|
||||
"value": &types.AttributeValueMemberS{Value: "value2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := client.TransactWriteItems(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatalf("TransactWriteItems failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify both items were written
|
||||
output1, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID1},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output1.Item) == 0 {
|
||||
t.Fatal("Expected item to exist after transaction")
|
||||
}
|
||||
|
||||
output2, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID2},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output2.Item) == 0 {
|
||||
t.Fatal("Expected item to exist after transaction")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchGetItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestBatchGet"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put test items
|
||||
testIDs := []string{"batch-1", "batch-2", "batch-3"}
|
||||
for _, id := range testIDs {
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "test-value"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch get items
|
||||
keys := make([]map[string]types.AttributeValue, len(testIDs))
|
||||
for i, id := range testIDs {
|
||||
keys[i] = map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
}
|
||||
}
|
||||
|
||||
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
|
||||
RequestItems: map[string]types.KeysAndAttributes{
|
||||
tableName: {
|
||||
Keys: keys,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchGetItems failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Responses[tableName]) != len(testIDs) {
|
||||
t.Fatalf("Expected %d items, got %d", len(testIDs), len(output.Responses[tableName]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchWriteItems(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestBatchWrite"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Batch write items
|
||||
testIDs := []string{"write-1", "write-2", "write-3"}
|
||||
writeRequests := make([]types.WriteRequest, len(testIDs))
|
||||
for i, id := range testIDs {
|
||||
writeRequests[i] = types.WriteRequest{
|
||||
PutRequest: &types.PutRequest{
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "batch-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
|
||||
RequestItems: map[string][]types.WriteRequest{
|
||||
tableName: writeRequests,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchWriteItems failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify items were written
|
||||
for _, id := range testIDs {
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
if len(output.Item) == 0 {
|
||||
t.Fatalf("Expected item %s to exist after batch write", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutAndGetItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestPutGet"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put item
|
||||
testID := "put-get-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "test-value"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Get item
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Item) == 0 {
|
||||
t.Fatal("Expected item to exist")
|
||||
}
|
||||
|
||||
valueAttr, ok := output.Item["value"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'value' attribute")
|
||||
}
|
||||
|
||||
value := valueAttr.(*types.AttributeValueMemberS).Value
|
||||
if value != "test-value" {
|
||||
t.Fatalf("Expected value 'test-value', got '%s'", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestUpdate"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put initial item
|
||||
testID := "update-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "initial"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Update item
|
||||
_, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
UpdateExpression: aws.String("SET #v = :val"),
|
||||
ExpressionAttributeNames: map[string]string{
|
||||
"#v": "value",
|
||||
},
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":val": &types.AttributeValueMemberS{Value: "updated"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify update
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
value := output.Item["value"].(*types.AttributeValueMemberS).Value
|
||||
if value != "updated" {
|
||||
t.Fatalf("Expected value 'updated', got '%s'", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteItem(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestDelete"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put item
|
||||
testID := "delete-test"
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
"value": &types.AttributeValueMemberS{Value: "test"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete item
|
||||
err = client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteItem failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: testID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetItem failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Item) != 0 {
|
||||
t.Fatal("Expected item to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
client, ctx := setupTestClient(t)
|
||||
|
||||
// Create test table
|
||||
tableName := "TestQuery"
|
||||
createTestTable(t, ctx, client, tableName)
|
||||
defer deleteTestTable(t, ctx, client, tableName)
|
||||
|
||||
// Put test items
|
||||
testIDs := []string{"query-1", "query-2", "query-3"}
|
||||
for _, id := range testIDs {
|
||||
err := client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(tableName),
|
||||
Item: map[string]types.AttributeValue{
|
||||
"id": &types.AttributeValueMemberS{Value: id},
|
||||
"value": &types.AttributeValueMemberS{Value: "test"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PutItem failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Query for specific item
|
||||
output, err := client.Query(ctx, &dynamodb.QueryInput{
|
||||
TableName: aws.String(tableName),
|
||||
KeyConditionExpression: aws.String("id = :id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":id": &types.AttributeValueMemberS{Value: "query-1"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Query failed: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Items) != 1 {
|
||||
t.Fatalf("Expected 1 item, got %d", len(output.Items))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := client.client.CreateTable(ctx, &dynamodb.CreateTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
AttributeDefinitions: []types.AttributeDefinition{
|
||||
{
|
||||
AttributeName: aws.String("id"),
|
||||
AttributeType: types.ScalarAttributeTypeS,
|
||||
},
|
||||
},
|
||||
KeySchema: []types.KeySchemaElement{
|
||||
{
|
||||
AttributeName: aws.String("id"),
|
||||
KeyType: types.KeyTypeHash,
|
||||
},
|
||||
},
|
||||
BillingMode: types.BillingModePayPerRequest,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Wait for table to be active
|
||||
waiter := dynamodb.NewTableExistsWaiter(client.client)
|
||||
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
}, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed waiting for table to be active: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := client.client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
|
||||
TableName: aws.String(tableName),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete test table: %v", err)
|
||||
}
|
||||
}
|
||||
115
internal/storage/user_repository.go
Normal file
115
internal/storage/user_repository.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
)
|
||||
|
||||
// UserRepository handles user data operations
|
||||
type UserRepository struct {
|
||||
client *DynamoDBClient
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new user repository
|
||||
func NewUserRepository(client *DynamoDBClient, tableName string) *UserRepository {
|
||||
return &UserRepository{
|
||||
client: client,
|
||||
tableName: tableName,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new user in the database
|
||||
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a user by their ID
|
||||
func (r *UserRepository) GetByID(ctx context.Context, userID string) (*models.User, error) {
|
||||
result, err := r.client.GetClient().GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if result.Item == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := attributevalue.UnmarshalMap(result.Item, &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByOAuthID retrieves a user by their OAuth provider and OAuth ID
|
||||
func (r *UserRepository) GetByOAuthID(ctx context.Context, provider, oauthID string) (*models.User, error) {
|
||||
// Use a scan with filter for now
|
||||
// In production, consider adding a GSI on oauth_provider + oauth_id
|
||||
result, err := r.client.GetClient().Scan(ctx, &dynamodb.ScanInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :oauth_id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":provider": &types.AttributeValueMemberS{Value: provider},
|
||||
":oauth_id": &types.AttributeValueMemberS{Value: oauthID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan users: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Items) == 0 {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := attributevalue.UnmarshalMap(result.Items[0], &user); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Update updates an existing user
|
||||
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(r.tableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
143
internal/storage/user_storage.go
Normal file
143
internal/storage/user_storage.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"custom-start-page/internal/models"
|
||||
)
|
||||
|
||||
const usersTableName = "Users"
|
||||
|
||||
// UserStorage handles user data operations
|
||||
type UserStorage struct {
|
||||
db *DynamoDBClient
|
||||
}
|
||||
|
||||
// NewUserStorage creates a new UserStorage instance
|
||||
func NewUserStorage(db *DynamoDBClient) *UserStorage {
|
||||
return &UserStorage{db: db}
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in DynamoDB
|
||||
func (s *UserStorage) CreateUser(ctx context.Context, email, oauthProvider, oauthID string) (*models.User, error) {
|
||||
user := &models.User{
|
||||
ID: uuid.New().String(),
|
||||
Email: email,
|
||||
OAuthProvider: oauthProvider,
|
||||
OAuthID: oauthID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by their ID
|
||||
func (s *UserStorage) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
|
||||
result, err := s.db.client.GetItem(ctx, &dynamodb.GetItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if result.Item == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
err = attributevalue.UnmarshalMap(result.Item, &user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByOAuth retrieves a user by their OAuth provider and ID
|
||||
func (s *UserStorage) GetUserByOAuth(ctx context.Context, oauthProvider, oauthID string) (*models.User, error) {
|
||||
// Since we don't have a GSI for oauth_provider + oauth_id, we need to scan
|
||||
// In production, you might want to add a GSI for this access pattern
|
||||
result, err := s.db.client.Scan(ctx, &dynamodb.ScanInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :id"),
|
||||
ExpressionAttributeValues: map[string]types.AttributeValue{
|
||||
":provider": &types.AttributeValueMemberS{Value: oauthProvider},
|
||||
":id": &types.AttributeValueMemberS{Value: oauthID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan for user: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Items) == 0 {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
var user models.User
|
||||
err = attributevalue.UnmarshalMap(result.Items[0], &user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates an existing user
|
||||
func (s *UserStorage) UpdateUser(ctx context.Context, user *models.User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
item, err := attributevalue.MarshalMap(user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal user: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Item: item,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user by their ID
|
||||
func (s *UserStorage) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := s.db.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
||||
TableName: aws.String(usersTableName),
|
||||
Key: map[string]types.AttributeValue{
|
||||
"user_id": &types.AttributeValueMemberS{Value: userID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
228
internal/storage/user_storage_test.go
Normal file
228
internal/storage/user_storage_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *DynamoDBClient {
|
||||
ctx := context.Background()
|
||||
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
|
||||
if endpoint == "" {
|
||||
endpoint = "http://localhost:8000"
|
||||
}
|
||||
|
||||
db, err := NewDynamoDBClient(ctx, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DynamoDB client: %v", err)
|
||||
}
|
||||
|
||||
// Create Users table
|
||||
if err := db.CreateUsersTable(ctx); err != nil {
|
||||
t.Fatalf("Failed to create Users table: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == "" {
|
||||
t.Error("User ID should not be empty")
|
||||
}
|
||||
if user.Email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", user.Email)
|
||||
}
|
||||
if user.OAuthProvider != "google" {
|
||||
t.Errorf("Expected oauth_provider 'google', got '%s'", user.OAuthProvider)
|
||||
}
|
||||
if user.OAuthID != "google123" {
|
||||
t.Errorf("Expected oauth_id 'google123', got '%s'", user.OAuthID)
|
||||
}
|
||||
if user.CreatedAt.IsZero() {
|
||||
t.Error("CreatedAt should not be zero")
|
||||
}
|
||||
if user.UpdatedAt.IsZero() {
|
||||
t.Error("UpdatedAt should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve the user
|
||||
retrievedUser, err := storage.GetUserByID(ctx, createdUser.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser.ID != createdUser.ID {
|
||||
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
|
||||
}
|
||||
if retrievedUser.Email != createdUser.Email {
|
||||
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByID_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := storage.GetUserByID(ctx, "nonexistent-id")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting nonexistent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByOAuth(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve the user by OAuth
|
||||
retrievedUser, err := storage.GetUserByOAuth(ctx, "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user by OAuth: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser.ID != createdUser.ID {
|
||||
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
|
||||
}
|
||||
if retrievedUser.Email != createdUser.Email {
|
||||
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByOAuth_NotFound(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := storage.GetUserByOAuth(ctx, "google", "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting nonexistent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
originalUpdatedAt := user.UpdatedAt
|
||||
time.Sleep(10 * time.Millisecond) // Ensure time difference
|
||||
|
||||
// Update the user
|
||||
user.Email = "updated@example.com"
|
||||
err = storage.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve and verify
|
||||
updatedUser, err := storage.GetUserByID(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated user: %v", err)
|
||||
}
|
||||
|
||||
if updatedUser.Email != "updated@example.com" {
|
||||
t.Errorf("Expected email 'updated@example.com', got '%s'", updatedUser.Email)
|
||||
}
|
||||
if !updatedUser.UpdatedAt.After(originalUpdatedAt) {
|
||||
t.Error("UpdatedAt should be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a user first
|
||||
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
err = storage.DeleteUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
_, err = storage.GetUserByID(ctx, user.ID)
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting deleted user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser_MultipleUsers(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
storage := NewUserStorage(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple users
|
||||
user1, err := storage.CreateUser(ctx, "user1@example.com", "google", "google1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user1: %v", err)
|
||||
}
|
||||
|
||||
user2, err := storage.CreateUser(ctx, "user2@example.com", "github", "github1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user2: %v", err)
|
||||
}
|
||||
|
||||
// Verify both users exist and are different
|
||||
if user1.ID == user2.ID {
|
||||
t.Error("User IDs should be unique")
|
||||
}
|
||||
|
||||
retrievedUser1, err := storage.GetUserByID(ctx, user1.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user1: %v", err)
|
||||
}
|
||||
|
||||
retrievedUser2, err := storage.GetUserByID(ctx, user2.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user2: %v", err)
|
||||
}
|
||||
|
||||
if retrievedUser1.Email != "user1@example.com" {
|
||||
t.Errorf("Expected user1 email 'user1@example.com', got '%s'", retrievedUser1.Email)
|
||||
}
|
||||
if retrievedUser2.Email != "user2@example.com" {
|
||||
t.Errorf("Expected user2 email 'user2@example.com', got '%s'", retrievedUser2.Email)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user