Initial commit: Custom Start Page application with authentication and DynamoDB storage

This commit is contained in:
2026-02-18 22:06:43 -05:00
commit 7175ff14ba
47 changed files with 7592 additions and 0 deletions

124
internal/storage/README.md Normal file
View File

@@ -0,0 +1,124 @@
# DynamoDB Storage Service
This package provides an enhanced DynamoDB client wrapper with the following features:
## Features
### 1. Connection Pooling
The client uses the AWS SDK's default HTTP client which includes connection pooling automatically. This ensures efficient reuse of TCP connections to DynamoDB.
### 2. Retry Logic with Exponential Backoff
The client is configured with automatic retry logic:
- **Max Attempts**: 5 retries
- **Max Backoff**: 20 seconds
- **Strategy**: Exponential backoff with jitter to prevent thundering herd
This handles transient failures gracefully and improves reliability.
### 3. Transaction Support
The `TransactWriteItems` method provides ACID transaction support for multiple write operations:
```go
err := client.TransactWriteItems(ctx, &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
Put: &types.Put{
TableName: aws.String("MyTable"),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: "item1"},
},
},
},
// More items...
},
})
```
### 4. Batch Operations
The client provides batch read and write operations with automatic retry of unprocessed items:
#### BatchGetItems
Retrieves multiple items in a single request:
```go
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
RequestItems: map[string]types.KeysAndAttributes{
"MyTable": {
Keys: []map[string]types.AttributeValue{
{"id": &types.AttributeValueMemberS{Value: "item1"}},
{"id": &types.AttributeValueMemberS{Value: "item2"}},
},
},
},
})
```
#### BatchWriteItems
Writes multiple items in a single request:
```go
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
"MyTable": {
{
PutRequest: &types.PutRequest{
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: "item1"},
},
},
},
},
},
})
```
Both batch operations automatically handle unprocessed items with exponential backoff retry logic.
## Standard Operations
The client also provides wrapped versions of standard DynamoDB operations with automatic retry:
- `PutItem` - Put a single item
- `GetItem` - Get a single item
- `UpdateItem` - Update a single item
- `DeleteItem` - Delete a single item
- `Query` - Query items
## Usage
### Creating a Client
```go
ctx := context.Background()
client, err := storage.NewDynamoDBClient(ctx, "http://localhost:8000")
if err != nil {
log.Fatal(err)
}
```
For production (AWS DynamoDB), pass an empty string for the endpoint:
```go
client, err := storage.NewDynamoDBClient(ctx, "")
```
### Testing
The package includes comprehensive tests that can be run against DynamoDB Local:
1. Start DynamoDB Local:
```bash
docker-compose up -d
```
2. Run tests:
```bash
go test -v ./internal/storage
```
Tests will automatically skip if DynamoDB is not available.
## Requirements Addressed
This implementation addresses the following requirements from the spec:
- **Requirement 8.1**: Immediate persistence of all changes
- **Requirement 8.8**: Efficient scaling for 10,000+ items per user
- **Design requirement**: Retry logic with exponential backoff
- **Design requirement**: Transaction support for atomic operations
- **Design requirement**: Batch operations for efficient bulk reads/writes
- **Design requirement**: Connection pooling for performance

View File

@@ -0,0 +1,248 @@
package storage
import (
"context"
"fmt"
"math"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
// DynamoDBClient wraps the AWS DynamoDB client with enhanced features
type DynamoDBClient struct {
client *dynamodb.Client
}
// NewDynamoDBClient creates a new DynamoDB client with connection pooling and retry logic
func NewDynamoDBClient(ctx context.Context, endpoint string) (*DynamoDBClient, error) {
// Configure retry strategy with exponential backoff
retryer := retry.NewStandard(func(o *retry.StandardOptions) {
o.MaxAttempts = 5
o.MaxBackoff = 20 * time.Second
// Use exponential backoff with jitter
o.Backoff = retry.NewExponentialJitterBackoff(20 * time.Second)
})
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRetryer(func() aws.Retryer {
return retryer
}),
// Connection pooling is handled by the HTTP client
config.WithHTTPClient(nil), // Uses default HTTP client with connection pooling
)
if err != nil {
return nil, fmt.Errorf("failed to load AWS config: %w", err)
}
// Override endpoint if provided (for local DynamoDB)
if endpoint != "" {
cfg.BaseEndpoint = aws.String(endpoint)
}
client := dynamodb.NewFromConfig(cfg)
return &DynamoDBClient{
client: client,
}, nil
}
// CreateUsersTable creates the Users table in DynamoDB
func (db *DynamoDBClient) CreateUsersTable(ctx context.Context) error {
tableName := "Users"
// Check if table already exists
_, err := db.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
})
if err == nil {
// Table already exists
return nil
}
// Create table
_, err = db.client.CreateTable(ctx, &dynamodb.CreateTableInput{
TableName: aws.String(tableName),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("user_id"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("user_id"),
KeyType: types.KeyTypeHash,
},
},
BillingMode: types.BillingModePayPerRequest,
})
if err != nil {
return fmt.Errorf("failed to create Users table: %w", err)
}
// Wait for table to be active
waiter := dynamodb.NewTableExistsWaiter(db.client)
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
}, 5*60) // 5 minutes timeout
if err != nil {
return fmt.Errorf("failed waiting for Users table to be active: %w", err)
}
return nil
}
// GetClient returns the underlying DynamoDB client
func (db *DynamoDBClient) GetClient() *dynamodb.Client {
return db.client
}
// TransactWriteItems executes a transactional write operation with automatic retry
func (db *DynamoDBClient) TransactWriteItems(ctx context.Context, input *dynamodb.TransactWriteItemsInput) error {
_, err := db.client.TransactWriteItems(ctx, input)
if err != nil {
return fmt.Errorf("transaction write failed: %w", err)
}
return nil
}
// BatchGetItems retrieves multiple items in a single batch operation
func (db *DynamoDBClient) BatchGetItems(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
output, err := db.client.BatchGetItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("batch get failed: %w", err)
}
// Handle unprocessed keys with exponential backoff
if len(output.UnprocessedKeys) > 0 {
return db.retryUnprocessedKeys(ctx, output)
}
return output, nil
}
// BatchWriteItems writes multiple items in a single batch operation
func (db *DynamoDBClient) BatchWriteItems(ctx context.Context, input *dynamodb.BatchWriteItemInput) error {
output, err := db.client.BatchWriteItem(ctx, input)
if err != nil {
return fmt.Errorf("batch write failed: %w", err)
}
// Handle unprocessed items with exponential backoff
if len(output.UnprocessedItems) > 0 {
return db.retryUnprocessedWrites(ctx, output.UnprocessedItems)
}
return nil
}
// retryUnprocessedKeys retries unprocessed keys from BatchGetItem with exponential backoff
func (db *DynamoDBClient) retryUnprocessedKeys(ctx context.Context, output *dynamodb.BatchGetItemOutput) (*dynamodb.BatchGetItemOutput, error) {
maxRetries := 5
backoff := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries && len(output.UnprocessedKeys) > 0; attempt++ {
// Wait with exponential backoff
time.Sleep(backoff)
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
// Retry unprocessed keys
retryOutput, err := db.client.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{
RequestItems: output.UnprocessedKeys,
})
if err != nil {
return nil, fmt.Errorf("retry batch get failed: %w", err)
}
// Merge responses
for table, items := range retryOutput.Responses {
output.Responses[table] = append(output.Responses[table], items...)
}
output.UnprocessedKeys = retryOutput.UnprocessedKeys
}
if len(output.UnprocessedKeys) > 0 {
return output, fmt.Errorf("failed to process all keys after %d retries", maxRetries)
}
return output, nil
}
// retryUnprocessedWrites retries unprocessed items from BatchWriteItem with exponential backoff
func (db *DynamoDBClient) retryUnprocessedWrites(ctx context.Context, unprocessedItems map[string][]types.WriteRequest) error {
maxRetries := 5
backoff := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries && len(unprocessedItems) > 0; attempt++ {
// Wait with exponential backoff
time.Sleep(backoff)
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
// Retry unprocessed items
output, err := db.client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: unprocessedItems,
})
if err != nil {
return fmt.Errorf("retry batch write failed: %w", err)
}
unprocessedItems = output.UnprocessedItems
}
if len(unprocessedItems) > 0 {
return fmt.Errorf("failed to process all items after %d retries", maxRetries)
}
return nil
}
// PutItem puts a single item with automatic retry
func (db *DynamoDBClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) error {
_, err := db.client.PutItem(ctx, input)
if err != nil {
return fmt.Errorf("put item failed: %w", err)
}
return nil
}
// GetItem retrieves a single item with automatic retry
func (db *DynamoDBClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
output, err := db.client.GetItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("get item failed: %w", err)
}
return output, nil
}
// Query executes a query operation with automatic retry
func (db *DynamoDBClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
output, err := db.client.Query(ctx, input)
if err != nil {
return nil, fmt.Errorf("query failed: %w", err)
}
return output, nil
}
// UpdateItem updates a single item with automatic retry
func (db *DynamoDBClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
output, err := db.client.UpdateItem(ctx, input)
if err != nil {
return nil, fmt.Errorf("update item failed: %w", err)
}
return output, nil
}
// DeleteItem deletes a single item with automatic retry
func (db *DynamoDBClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput) error {
_, err := db.client.DeleteItem(ctx, input)
if err != nil {
return fmt.Errorf("delete item failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,452 @@
package storage
import (
"context"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
func setupTestClient(t *testing.T) (*DynamoDBClient, context.Context) {
t.Helper()
// Set dummy AWS credentials for local DynamoDB
os.Setenv("AWS_ACCESS_KEY_ID", "dummy")
os.Setenv("AWS_SECRET_ACCESS_KEY", "dummy")
os.Setenv("AWS_REGION", "us-east-1")
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
if endpoint == "" {
endpoint = "http://localhost:8000"
}
ctx := context.Background()
client, err := NewDynamoDBClient(ctx, endpoint)
if err != nil {
t.Skipf("Skipping test: DynamoDB not available: %v", err)
}
// Test connection by listing tables
_, err = client.client.ListTables(ctx, &dynamodb.ListTablesInput{})
if err != nil {
t.Skipf("Skipping test: Cannot connect to DynamoDB: %v", err)
}
return client, ctx
}
func TestNewDynamoDBClient(t *testing.T) {
client, _ := setupTestClient(t)
if client == nil {
t.Fatal("Expected non-nil client")
}
if client.client == nil {
t.Fatal("Expected non-nil underlying client")
}
}
func TestTransactWriteItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestTransactions"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Test transaction with two puts
testID1 := "test-txn-1"
testID2 := "test-txn-2"
input := &dynamodb.TransactWriteItemsInput{
TransactItems: []types.TransactWriteItem{
{
Put: &types.Put{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID1},
"value": &types.AttributeValueMemberS{Value: "value1"},
},
},
},
{
Put: &types.Put{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID2},
"value": &types.AttributeValueMemberS{Value: "value2"},
},
},
},
},
}
err := client.TransactWriteItems(ctx, input)
if err != nil {
t.Fatalf("TransactWriteItems failed: %v", err)
}
// Verify both items were written
output1, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID1},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output1.Item) == 0 {
t.Fatal("Expected item to exist after transaction")
}
output2, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID2},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output2.Item) == 0 {
t.Fatal("Expected item to exist after transaction")
}
}
func TestBatchGetItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestBatchGet"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put test items
testIDs := []string{"batch-1", "batch-2", "batch-3"}
for _, id := range testIDs {
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "test-value"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
}
// Batch get items
keys := make([]map[string]types.AttributeValue, len(testIDs))
for i, id := range testIDs {
keys[i] = map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
}
}
output, err := client.BatchGetItems(ctx, &dynamodb.BatchGetItemInput{
RequestItems: map[string]types.KeysAndAttributes{
tableName: {
Keys: keys,
},
},
})
if err != nil {
t.Fatalf("BatchGetItems failed: %v", err)
}
if len(output.Responses[tableName]) != len(testIDs) {
t.Fatalf("Expected %d items, got %d", len(testIDs), len(output.Responses[tableName]))
}
}
func TestBatchWriteItems(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestBatchWrite"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Batch write items
testIDs := []string{"write-1", "write-2", "write-3"}
writeRequests := make([]types.WriteRequest, len(testIDs))
for i, id := range testIDs {
writeRequests[i] = types.WriteRequest{
PutRequest: &types.PutRequest{
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "batch-value"},
},
},
}
}
err := client.BatchWriteItems(ctx, &dynamodb.BatchWriteItemInput{
RequestItems: map[string][]types.WriteRequest{
tableName: writeRequests,
},
})
if err != nil {
t.Fatalf("BatchWriteItems failed: %v", err)
}
// Verify items were written
for _, id := range testIDs {
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) == 0 {
t.Fatalf("Expected item %s to exist after batch write", id)
}
}
}
func TestPutAndGetItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestPutGet"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put item
testID := "put-get-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "test-value"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Get item
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) == 0 {
t.Fatal("Expected item to exist")
}
valueAttr, ok := output.Item["value"]
if !ok {
t.Fatal("Expected 'value' attribute")
}
value := valueAttr.(*types.AttributeValueMemberS).Value
if value != "test-value" {
t.Fatalf("Expected value 'test-value', got '%s'", value)
}
}
func TestUpdateItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestUpdate"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put initial item
testID := "update-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "initial"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Update item
_, err = client.UpdateItem(ctx, &dynamodb.UpdateItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
UpdateExpression: aws.String("SET #v = :val"),
ExpressionAttributeNames: map[string]string{
"#v": "value",
},
ExpressionAttributeValues: map[string]types.AttributeValue{
":val": &types.AttributeValueMemberS{Value: "updated"},
},
})
if err != nil {
t.Fatalf("UpdateItem failed: %v", err)
}
// Verify update
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
value := output.Item["value"].(*types.AttributeValueMemberS).Value
if value != "updated" {
t.Fatalf("Expected value 'updated', got '%s'", value)
}
}
func TestDeleteItem(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestDelete"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put item
testID := "delete-test"
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
"value": &types.AttributeValueMemberS{Value: "test"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
// Delete item
err = client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("DeleteItem failed: %v", err)
}
// Verify deletion
output, err := client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(tableName),
Key: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: testID},
},
})
if err != nil {
t.Fatalf("GetItem failed: %v", err)
}
if len(output.Item) != 0 {
t.Fatal("Expected item to be deleted")
}
}
func TestQuery(t *testing.T) {
client, ctx := setupTestClient(t)
// Create test table
tableName := "TestQuery"
createTestTable(t, ctx, client, tableName)
defer deleteTestTable(t, ctx, client, tableName)
// Put test items
testIDs := []string{"query-1", "query-2", "query-3"}
for _, id := range testIDs {
err := client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(tableName),
Item: map[string]types.AttributeValue{
"id": &types.AttributeValueMemberS{Value: id},
"value": &types.AttributeValueMemberS{Value: "test"},
},
})
if err != nil {
t.Fatalf("PutItem failed: %v", err)
}
}
// Query for specific item
output, err := client.Query(ctx, &dynamodb.QueryInput{
TableName: aws.String(tableName),
KeyConditionExpression: aws.String("id = :id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":id": &types.AttributeValueMemberS{Value: "query-1"},
},
})
if err != nil {
t.Fatalf("Query failed: %v", err)
}
if len(output.Items) != 1 {
t.Fatalf("Expected 1 item, got %d", len(output.Items))
}
}
// Helper functions
func createTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
t.Helper()
_, err := client.client.CreateTable(ctx, &dynamodb.CreateTableInput{
TableName: aws.String(tableName),
AttributeDefinitions: []types.AttributeDefinition{
{
AttributeName: aws.String("id"),
AttributeType: types.ScalarAttributeTypeS,
},
},
KeySchema: []types.KeySchemaElement{
{
AttributeName: aws.String("id"),
KeyType: types.KeyTypeHash,
},
},
BillingMode: types.BillingModePayPerRequest,
})
if err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
// Wait for table to be active
waiter := dynamodb.NewTableExistsWaiter(client.client)
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
TableName: aws.String(tableName),
}, 30*time.Second)
if err != nil {
t.Fatalf("Failed waiting for table to be active: %v", err)
}
}
func deleteTestTable(t *testing.T, ctx context.Context, client *DynamoDBClient, tableName string) {
t.Helper()
_, err := client.client.DeleteTable(ctx, &dynamodb.DeleteTableInput{
TableName: aws.String(tableName),
})
if err != nil {
t.Logf("Warning: Failed to delete test table: %v", err)
}
}

View File

@@ -0,0 +1,115 @@
package storage
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"custom-start-page/internal/models"
)
// UserRepository handles user data operations
type UserRepository struct {
client *DynamoDBClient
tableName string
}
// NewUserRepository creates a new user repository
func NewUserRepository(client *DynamoDBClient, tableName string) *UserRepository {
return &UserRepository{
client: client,
tableName: tableName,
}
}
// Create creates a new user in the database
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(r.tableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID retrieves a user by their ID
func (r *UserRepository) GetByID(ctx context.Context, userID string) (*models.User, error) {
result, err := r.client.GetClient().GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(r.tableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if result.Item == nil {
return nil, fmt.Errorf("user not found")
}
var user models.User
if err := attributevalue.UnmarshalMap(result.Item, &user); err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// GetByOAuthID retrieves a user by their OAuth provider and OAuth ID
func (r *UserRepository) GetByOAuthID(ctx context.Context, provider, oauthID string) (*models.User, error) {
// Use a scan with filter for now
// In production, consider adding a GSI on oauth_provider + oauth_id
result, err := r.client.GetClient().Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String(r.tableName),
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :oauth_id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":provider": &types.AttributeValueMemberS{Value: provider},
":oauth_id": &types.AttributeValueMemberS{Value: oauthID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to scan users: %w", err)
}
if len(result.Items) == 0 {
return nil, fmt.Errorf("user not found")
}
var user models.User
if err := attributevalue.UnmarshalMap(result.Items[0], &user); err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// Update updates an existing user
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(r.tableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}

View File

@@ -0,0 +1,143 @@
package storage
import (
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
"custom-start-page/internal/models"
)
const usersTableName = "Users"
// UserStorage handles user data operations
type UserStorage struct {
db *DynamoDBClient
}
// NewUserStorage creates a new UserStorage instance
func NewUserStorage(db *DynamoDBClient) *UserStorage {
return &UserStorage{db: db}
}
// CreateUser creates a new user in DynamoDB
func (s *UserStorage) CreateUser(ctx context.Context, email, oauthProvider, oauthID string) (*models.User, error) {
user := &models.User{
ID: uuid.New().String(),
Email: email,
OAuthProvider: oauthProvider,
OAuthID: oauthID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
item, err := attributevalue.MarshalMap(user)
if err != nil {
return nil, fmt.Errorf("failed to marshal user: %w", err)
}
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(usersTableName),
Item: item,
})
if err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
}
return user, nil
}
// GetUserByID retrieves a user by their ID
func (s *UserStorage) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
result, err := s.db.client.GetItem(ctx, &dynamodb.GetItemInput{
TableName: aws.String(usersTableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if result.Item == nil {
return nil, fmt.Errorf("user not found")
}
var user models.User
err = attributevalue.UnmarshalMap(result.Item, &user)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// GetUserByOAuth retrieves a user by their OAuth provider and ID
func (s *UserStorage) GetUserByOAuth(ctx context.Context, oauthProvider, oauthID string) (*models.User, error) {
// Since we don't have a GSI for oauth_provider + oauth_id, we need to scan
// In production, you might want to add a GSI for this access pattern
result, err := s.db.client.Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String(usersTableName),
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :id"),
ExpressionAttributeValues: map[string]types.AttributeValue{
":provider": &types.AttributeValueMemberS{Value: oauthProvider},
":id": &types.AttributeValueMemberS{Value: oauthID},
},
})
if err != nil {
return nil, fmt.Errorf("failed to scan for user: %w", err)
}
if len(result.Items) == 0 {
return nil, fmt.Errorf("user not found")
}
var user models.User
err = attributevalue.UnmarshalMap(result.Items[0], &user)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
}
return &user, nil
}
// UpdateUser updates an existing user
func (s *UserStorage) UpdateUser(ctx context.Context, user *models.User) error {
user.UpdatedAt = time.Now()
item, err := attributevalue.MarshalMap(user)
if err != nil {
return fmt.Errorf("failed to marshal user: %w", err)
}
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
TableName: aws.String(usersTableName),
Item: item,
})
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// DeleteUser deletes a user by their ID
func (s *UserStorage) DeleteUser(ctx context.Context, userID string) error {
_, err := s.db.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
TableName: aws.String(usersTableName),
Key: map[string]types.AttributeValue{
"user_id": &types.AttributeValueMemberS{Value: userID},
},
})
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
}

View File

@@ -0,0 +1,228 @@
package storage
import (
"context"
"os"
"testing"
"time"
)
func setupTestDB(t *testing.T) *DynamoDBClient {
ctx := context.Background()
endpoint := os.Getenv("DYNAMODB_ENDPOINT")
if endpoint == "" {
endpoint = "http://localhost:8000"
}
db, err := NewDynamoDBClient(ctx, endpoint)
if err != nil {
t.Fatalf("Failed to create DynamoDB client: %v", err)
}
// Create Users table
if err := db.CreateUsersTable(ctx); err != nil {
t.Fatalf("Failed to create Users table: %v", err)
}
return db
}
func TestCreateUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
if user.ID == "" {
t.Error("User ID should not be empty")
}
if user.Email != "test@example.com" {
t.Errorf("Expected email 'test@example.com', got '%s'", user.Email)
}
if user.OAuthProvider != "google" {
t.Errorf("Expected oauth_provider 'google', got '%s'", user.OAuthProvider)
}
if user.OAuthID != "google123" {
t.Errorf("Expected oauth_id 'google123', got '%s'", user.OAuthID)
}
if user.CreatedAt.IsZero() {
t.Error("CreatedAt should not be zero")
}
if user.UpdatedAt.IsZero() {
t.Error("UpdatedAt should not be zero")
}
}
func TestGetUserByID(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Retrieve the user
retrievedUser, err := storage.GetUserByID(ctx, createdUser.ID)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
if retrievedUser.ID != createdUser.ID {
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
}
if retrievedUser.Email != createdUser.Email {
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
}
}
func TestGetUserByID_NotFound(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
_, err := storage.GetUserByID(ctx, "nonexistent-id")
if err == nil {
t.Error("Expected error when getting nonexistent user")
}
}
func TestGetUserByOAuth(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
createdUser, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Retrieve the user by OAuth
retrievedUser, err := storage.GetUserByOAuth(ctx, "google", "google123")
if err != nil {
t.Fatalf("Failed to get user by OAuth: %v", err)
}
if retrievedUser.ID != createdUser.ID {
t.Errorf("Expected ID '%s', got '%s'", createdUser.ID, retrievedUser.ID)
}
if retrievedUser.Email != createdUser.Email {
t.Errorf("Expected email '%s', got '%s'", createdUser.Email, retrievedUser.Email)
}
}
func TestGetUserByOAuth_NotFound(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
_, err := storage.GetUserByOAuth(ctx, "google", "nonexistent")
if err == nil {
t.Error("Expected error when getting nonexistent user")
}
}
func TestUpdateUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
originalUpdatedAt := user.UpdatedAt
time.Sleep(10 * time.Millisecond) // Ensure time difference
// Update the user
user.Email = "updated@example.com"
err = storage.UpdateUser(ctx, user)
if err != nil {
t.Fatalf("Failed to update user: %v", err)
}
// Retrieve and verify
updatedUser, err := storage.GetUserByID(ctx, user.ID)
if err != nil {
t.Fatalf("Failed to get updated user: %v", err)
}
if updatedUser.Email != "updated@example.com" {
t.Errorf("Expected email 'updated@example.com', got '%s'", updatedUser.Email)
}
if !updatedUser.UpdatedAt.After(originalUpdatedAt) {
t.Error("UpdatedAt should be updated")
}
}
func TestDeleteUser(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create a user first
user, err := storage.CreateUser(ctx, "test@example.com", "google", "google123")
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
// Delete the user
err = storage.DeleteUser(ctx, user.ID)
if err != nil {
t.Fatalf("Failed to delete user: %v", err)
}
// Verify deletion
_, err = storage.GetUserByID(ctx, user.ID)
if err == nil {
t.Error("Expected error when getting deleted user")
}
}
func TestCreateUser_MultipleUsers(t *testing.T) {
db := setupTestDB(t)
storage := NewUserStorage(db)
ctx := context.Background()
// Create multiple users
user1, err := storage.CreateUser(ctx, "user1@example.com", "google", "google1")
if err != nil {
t.Fatalf("Failed to create user1: %v", err)
}
user2, err := storage.CreateUser(ctx, "user2@example.com", "github", "github1")
if err != nil {
t.Fatalf("Failed to create user2: %v", err)
}
// Verify both users exist and are different
if user1.ID == user2.ID {
t.Error("User IDs should be unique")
}
retrievedUser1, err := storage.GetUserByID(ctx, user1.ID)
if err != nil {
t.Fatalf("Failed to get user1: %v", err)
}
retrievedUser2, err := storage.GetUserByID(ctx, user2.ID)
if err != nil {
t.Fatalf("Failed to get user2: %v", err)
}
if retrievedUser1.Email != "user1@example.com" {
t.Errorf("Expected user1 email 'user1@example.com', got '%s'", retrievedUser1.Email)
}
if retrievedUser2.Email != "user2@example.com" {
t.Errorf("Expected user2 email 'user2@example.com', got '%s'", retrievedUser2.Email)
}
}