116 lines
3.2 KiB
Go
116 lines
3.2 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
|
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
|
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
|
|
|
"custom-start-page/internal/models"
|
|
)
|
|
|
|
// UserRepository handles user data operations
|
|
type UserRepository struct {
|
|
client *DynamoDBClient
|
|
tableName string
|
|
}
|
|
|
|
// NewUserRepository creates a new user repository
|
|
func NewUserRepository(client *DynamoDBClient, tableName string) *UserRepository {
|
|
return &UserRepository{
|
|
client: client,
|
|
tableName: tableName,
|
|
}
|
|
}
|
|
|
|
// Create inserts a brand-new user item into the table.
//
// The write is conditional on no item with the same user_id key already
// existing, so Create never silently replaces another user's data. When the
// user_id is already taken, the returned error wraps the SDK's
// *types.ConditionalCheckFailedException (detectable with errors.As).
// Use Update to overwrite an existing user.
//
// A nil user is rejected before any request is sent.
func (r *UserRepository) Create(ctx context.Context, user *models.User) error {
	if user == nil {
		return fmt.Errorf("failed to create user: user is nil")
	}

	item, err := attributevalue.MarshalMap(user)
	if err != nil {
		return fmt.Errorf("failed to marshal user: %w", err)
	}

	_, err = r.client.GetClient().PutItem(ctx, &dynamodb.PutItemInput{
		TableName: aws.String(r.tableName),
		Item:      item,
		// A plain PutItem replaces any existing item with the same key; this
		// condition makes the write insert-only. user_id is the table's key
		// (it is the only key attribute used for lookups by ID).
		ConditionExpression: aws.String("attribute_not_exists(user_id)"),
	})
	if err != nil {
		return fmt.Errorf("failed to create user: %w", err)
	}

	return nil
}
|
|
|
|
// GetByID retrieves a user by their ID
|
|
func (r *UserRepository) GetByID(ctx context.Context, userID string) (*models.User, error) {
|
|
result, err := r.client.GetClient().GetItem(ctx, &dynamodb.GetItemInput{
|
|
TableName: aws.String(r.tableName),
|
|
Key: map[string]types.AttributeValue{
|
|
"user_id": &types.AttributeValueMemberS{Value: userID},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
if result.Item == nil {
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
|
|
var user models.User
|
|
if err := attributevalue.UnmarshalMap(result.Item, &user); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// GetByOAuthID looks up the user linked to the given OAuth provider and the
// provider-side account ID.
//
// There is no index on these attributes, so this performs a filtered Scan of
// the table. DynamoDB returns a Scan one page at a time and applies the
// FilterExpression after reading each page. A page can therefore contain zero
// matching Items while LastEvaluatedKey still points at unread data. The loop
// below keeps following LastEvaluatedKey until it finds a match or the table
// is exhausted. Stopping after the first page would report existing users as
// missing once the table outgrows a single page.
//
// It returns the first matching user. When no item matches, it returns an
// error whose message is "user not found". Cost grows with table size.
// In production, consider adding a GSI on oauth_provider + oauth_id so this
// can become a single Query.
func (r *UserRepository) GetByOAuthID(ctx context.Context, provider, oauthID string) (*models.User, error) {
	input := &dynamodb.ScanInput{
		TableName:        aws.String(r.tableName),
		FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :oauth_id"),
		ExpressionAttributeValues: map[string]types.AttributeValue{
			":provider": &types.AttributeValueMemberS{Value: provider},
			":oauth_id": &types.AttributeValueMemberS{Value: oauthID},
		},
	}

	for {
		result, err := r.client.GetClient().Scan(ctx, input)
		if err != nil {
			return nil, fmt.Errorf("failed to scan users: %w", err)
		}

		if len(result.Items) > 0 {
			var user models.User
			if err := attributevalue.UnmarshalMap(result.Items[0], &user); err != nil {
				return nil, fmt.Errorf("failed to unmarshal user: %w", err)
			}
			return &user, nil
		}

		// An empty LastEvaluatedKey means the scan has covered the whole table.
		if len(result.LastEvaluatedKey) == 0 {
			return nil, fmt.Errorf("user not found")
		}
		input.ExclusiveStartKey = result.LastEvaluatedKey
	}
}
|
|
|
|
// Update writes the given user to the table, replacing the whole item stored
// under the same key.
//
// NOTE(review): this is an unconditional PutItem. It also inserts the user if
// no item with that key exists yet, and it replaces every attribute rather
// than patching individual fields. Confirm callers rely on this upsert
// behavior before adding an attribute_exists(user_id) condition.
func (r *UserRepository) Update(ctx context.Context, user *models.User) error {
	av, marshalErr := attributevalue.MarshalMap(user)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal user: %w", marshalErr)
	}

	put := &dynamodb.PutItemInput{
		Item:      av,
		TableName: aws.String(r.tableName),
	}
	if _, putErr := r.client.GetClient().PutItem(ctx, put); putErr != nil {
		return fmt.Errorf("failed to update user: %w", putErr)
	}
	return nil
}
|