249 lines
7.4 KiB
Go
249 lines
7.4 KiB
Go
package storage
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"math"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/retry"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
|
|
|
|
// DynamoDBClient wraps the AWS DynamoDB client with enhanced features
|
|
type DynamoDBClient struct {
|
|
client *dynamodb.Client
|
|
}
|
|
|
|
// NewDynamoDBClient creates a new DynamoDB client with connection pooling and retry logic
|
|
func NewDynamoDBClient(ctx context.Context, endpoint string) (*DynamoDBClient, error) {
|
|
// Configure retry strategy with exponential backoff
|
|
retryer := retry.NewStandard(func(o *retry.StandardOptions) {
|
|
o.MaxAttempts = 5
|
|
o.MaxBackoff = 20 * time.Second
|
|
// Use exponential backoff with jitter
|
|
o.Backoff = retry.NewExponentialJitterBackoff(20 * time.Second)
|
|
})
|
|
|
|
cfg, err := config.LoadDefaultConfig(ctx,
|
|
config.WithRetryer(func() aws.Retryer {
|
|
return retryer
|
|
}),
|
|
// Connection pooling is handled by the HTTP client
|
|
config.WithHTTPClient(nil), // Uses default HTTP client with connection pooling
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
|
}
|
|
|
|
// Override endpoint if provided (for local DynamoDB)
|
|
if endpoint != "" {
|
|
cfg.BaseEndpoint = aws.String(endpoint)
|
|
}
|
|
|
|
client := dynamodb.NewFromConfig(cfg)
|
|
|
|
return &DynamoDBClient{
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
// CreateUsersTable creates the Users table in DynamoDB
|
|
func (db *DynamoDBClient) CreateUsersTable(ctx context.Context) error {
|
|
tableName := "Users"
|
|
|
|
// Check if table already exists
|
|
_, err := db.client.DescribeTable(ctx, &dynamodb.DescribeTableInput{
|
|
TableName: aws.String(tableName),
|
|
})
|
|
if err == nil {
|
|
// Table already exists
|
|
return nil
|
|
}
|
|
|
|
// Create table
|
|
_, err = db.client.CreateTable(ctx, &dynamodb.CreateTableInput{
|
|
TableName: aws.String(tableName),
|
|
AttributeDefinitions: []types.AttributeDefinition{
|
|
{
|
|
AttributeName: aws.String("user_id"),
|
|
AttributeType: types.ScalarAttributeTypeS,
|
|
},
|
|
},
|
|
KeySchema: []types.KeySchemaElement{
|
|
{
|
|
AttributeName: aws.String("user_id"),
|
|
KeyType: types.KeyTypeHash,
|
|
},
|
|
},
|
|
BillingMode: types.BillingModePayPerRequest,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create Users table: %w", err)
|
|
}
|
|
|
|
// Wait for table to be active
|
|
waiter := dynamodb.NewTableExistsWaiter(db.client)
|
|
err = waiter.Wait(ctx, &dynamodb.DescribeTableInput{
|
|
TableName: aws.String(tableName),
|
|
}, 5*60) // 5 minutes timeout
|
|
if err != nil {
|
|
return fmt.Errorf("failed waiting for Users table to be active: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetClient returns the underlying DynamoDB client
|
|
func (db *DynamoDBClient) GetClient() *dynamodb.Client {
|
|
return db.client
|
|
}
|
|
|
|
// TransactWriteItems executes a transactional write operation with automatic retry
|
|
func (db *DynamoDBClient) TransactWriteItems(ctx context.Context, input *dynamodb.TransactWriteItemsInput) error {
|
|
_, err := db.client.TransactWriteItems(ctx, input)
|
|
if err != nil {
|
|
return fmt.Errorf("transaction write failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// BatchGetItems retrieves multiple items in a single batch operation
|
|
func (db *DynamoDBClient) BatchGetItems(ctx context.Context, input *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) {
|
|
output, err := db.client.BatchGetItem(ctx, input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("batch get failed: %w", err)
|
|
}
|
|
|
|
// Handle unprocessed keys with exponential backoff
|
|
if len(output.UnprocessedKeys) > 0 {
|
|
return db.retryUnprocessedKeys(ctx, output)
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// BatchWriteItems writes multiple items in a single batch operation
|
|
func (db *DynamoDBClient) BatchWriteItems(ctx context.Context, input *dynamodb.BatchWriteItemInput) error {
|
|
output, err := db.client.BatchWriteItem(ctx, input)
|
|
if err != nil {
|
|
return fmt.Errorf("batch write failed: %w", err)
|
|
}
|
|
|
|
// Handle unprocessed items with exponential backoff
|
|
if len(output.UnprocessedItems) > 0 {
|
|
return db.retryUnprocessedWrites(ctx, output.UnprocessedItems)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// retryUnprocessedKeys retries unprocessed keys from BatchGetItem with exponential backoff
|
|
func (db *DynamoDBClient) retryUnprocessedKeys(ctx context.Context, output *dynamodb.BatchGetItemOutput) (*dynamodb.BatchGetItemOutput, error) {
|
|
maxRetries := 5
|
|
backoff := 100 * time.Millisecond
|
|
|
|
for attempt := 0; attempt < maxRetries && len(output.UnprocessedKeys) > 0; attempt++ {
|
|
// Wait with exponential backoff
|
|
time.Sleep(backoff)
|
|
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
|
|
|
// Retry unprocessed keys
|
|
retryOutput, err := db.client.BatchGetItem(ctx, &dynamodb.BatchGetItemInput{
|
|
RequestItems: output.UnprocessedKeys,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("retry batch get failed: %w", err)
|
|
}
|
|
|
|
// Merge responses
|
|
for table, items := range retryOutput.Responses {
|
|
output.Responses[table] = append(output.Responses[table], items...)
|
|
}
|
|
|
|
output.UnprocessedKeys = retryOutput.UnprocessedKeys
|
|
}
|
|
|
|
if len(output.UnprocessedKeys) > 0 {
|
|
return output, fmt.Errorf("failed to process all keys after %d retries", maxRetries)
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// retryUnprocessedWrites retries unprocessed items from BatchWriteItem with exponential backoff
|
|
func (db *DynamoDBClient) retryUnprocessedWrites(ctx context.Context, unprocessedItems map[string][]types.WriteRequest) error {
|
|
maxRetries := 5
|
|
backoff := 100 * time.Millisecond
|
|
|
|
for attempt := 0; attempt < maxRetries && len(unprocessedItems) > 0; attempt++ {
|
|
// Wait with exponential backoff
|
|
time.Sleep(backoff)
|
|
backoff = time.Duration(math.Min(float64(backoff*2), float64(20*time.Second)))
|
|
|
|
// Retry unprocessed items
|
|
output, err := db.client.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
|
|
RequestItems: unprocessedItems,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("retry batch write failed: %w", err)
|
|
}
|
|
|
|
unprocessedItems = output.UnprocessedItems
|
|
}
|
|
|
|
if len(unprocessedItems) > 0 {
|
|
return fmt.Errorf("failed to process all items after %d retries", maxRetries)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PutItem puts a single item with automatic retry
|
|
func (db *DynamoDBClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput) error {
|
|
_, err := db.client.PutItem(ctx, input)
|
|
if err != nil {
|
|
return fmt.Errorf("put item failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetItem retrieves a single item with automatic retry
|
|
func (db *DynamoDBClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput) (*dynamodb.GetItemOutput, error) {
|
|
output, err := db.client.GetItem(ctx, input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get item failed: %w", err)
|
|
}
|
|
return output, nil
|
|
}
|
|
|
|
// Query executes a query operation with automatic retry
|
|
func (db *DynamoDBClient) Query(ctx context.Context, input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) {
|
|
output, err := db.client.Query(ctx, input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query failed: %w", err)
|
|
}
|
|
return output, nil
|
|
}
|
|
|
|
// UpdateItem updates a single item with automatic retry
|
|
func (db *DynamoDBClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput) (*dynamodb.UpdateItemOutput, error) {
|
|
output, err := db.client.UpdateItem(ctx, input)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("update item failed: %w", err)
|
|
}
|
|
return output, nil
|
|
}
|
|
|
|
// DeleteItem deletes a single item with automatic retry
|
|
func (db *DynamoDBClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput) error {
|
|
_, err := db.client.DeleteItem(ctx, input)
|
|
if err != nil {
|
|
return fmt.Errorf("delete item failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|