144 lines
3.9 KiB
Go
144 lines
3.9 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
|
|
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
|
|
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
|
|
"github.com/google/uuid"
|
|
|
|
"custom-start-page/internal/models"
|
|
)
|
|
|
|
const (
	// usersTableName is the DynamoDB table that the UserStorage methods read
	// from and write to.
	usersTableName = "Users"
)
|
|
|
|
// UserStorage handles user data operations
|
|
type UserStorage struct {
|
|
db *DynamoDBClient
|
|
}
|
|
|
|
// NewUserStorage creates a new UserStorage instance
|
|
func NewUserStorage(db *DynamoDBClient) *UserStorage {
|
|
return &UserStorage{db: db}
|
|
}
|
|
|
|
// CreateUser creates a new user in DynamoDB
|
|
func (s *UserStorage) CreateUser(ctx context.Context, email, oauthProvider, oauthID string) (*models.User, error) {
|
|
user := &models.User{
|
|
ID: uuid.New().String(),
|
|
Email: email,
|
|
OAuthProvider: oauthProvider,
|
|
OAuthID: oauthID,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
item, err := attributevalue.MarshalMap(user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal user: %w", err)
|
|
}
|
|
|
|
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
|
TableName: aws.String(usersTableName),
|
|
Item: item,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// GetUserByID retrieves a user by their ID
|
|
func (s *UserStorage) GetUserByID(ctx context.Context, userID string) (*models.User, error) {
|
|
result, err := s.db.client.GetItem(ctx, &dynamodb.GetItemInput{
|
|
TableName: aws.String(usersTableName),
|
|
Key: map[string]types.AttributeValue{
|
|
"user_id": &types.AttributeValueMemberS{Value: userID},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
if result.Item == nil {
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
|
|
var user models.User
|
|
err = attributevalue.UnmarshalMap(result.Item, &user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// GetUserByOAuth retrieves a user by their OAuth provider and ID
|
|
func (s *UserStorage) GetUserByOAuth(ctx context.Context, oauthProvider, oauthID string) (*models.User, error) {
|
|
// Since we don't have a GSI for oauth_provider + oauth_id, we need to scan
|
|
// In production, you might want to add a GSI for this access pattern
|
|
result, err := s.db.client.Scan(ctx, &dynamodb.ScanInput{
|
|
TableName: aws.String(usersTableName),
|
|
FilterExpression: aws.String("oauth_provider = :provider AND oauth_id = :id"),
|
|
ExpressionAttributeValues: map[string]types.AttributeValue{
|
|
":provider": &types.AttributeValueMemberS{Value: oauthProvider},
|
|
":id": &types.AttributeValueMemberS{Value: oauthID},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan for user: %w", err)
|
|
}
|
|
|
|
if len(result.Items) == 0 {
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
|
|
var user models.User
|
|
err = attributevalue.UnmarshalMap(result.Items[0], &user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// UpdateUser updates an existing user
|
|
func (s *UserStorage) UpdateUser(ctx context.Context, user *models.User) error {
|
|
user.UpdatedAt = time.Now()
|
|
|
|
item, err := attributevalue.MarshalMap(user)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal user: %w", err)
|
|
}
|
|
|
|
_, err = s.db.client.PutItem(ctx, &dynamodb.PutItemInput{
|
|
TableName: aws.String(usersTableName),
|
|
Item: item,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteUser deletes a user by their ID
|
|
func (s *UserStorage) DeleteUser(ctx context.Context, userID string) error {
|
|
_, err := s.db.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
|
|
TableName: aws.String(usersTableName),
|
|
Key: map[string]types.AttributeValue{
|
|
"user_id": &types.AttributeValueMemberS{Value: userID},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete user: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|