// user_service.go package services import ( "errors" "fmt" "log/slog" "github.com/theorift/white-zone-bot/models" "github.com/theorift/white-zone-bot/repositories" ) var ( ErrInvalidReferrer = errors.New("invalid referrer ID") ErrSelfReferral = errors.New("self-referral not allowed") ErrAlreadyReferred = errors.New("existing referral cannot be changed") ErrLateReferralAttempt = errors.New("cannot add referrer after initial registration") ErrCircularReference = errors.New("circular referral chain detected") ErrDuplicateUser = errors.New("user already exists") ErrReferralChainTooLong = errors.New("referral chain too deep") ) type UserService struct { ID int64 repo repositories.UserRepository } func NewUserService(userID int64, repo repositories.UserRepository) *UserService { return &UserService{ ID: userID, repo: repo, } } // ProcessStart handles user initialization with referral validation func (s *UserService) ProcessStart(referrerID int64) error { tx, err := s.repo.BeginTx() if err != nil { return fmt.Errorf("transaction start failed: %w", err) } var committed bool defer func() { if !committed { s.repo.Rollback(tx) } }() txRepo := s.repo.WithTx(tx) txService := NewUserService(s.ID, txRepo) // Check existing user within transaction existingUser, err := txRepo.GetUser(s.ID) if err != nil && !errors.Is(err, repositories.ErrRecordNotFound) { return fmt.Errorf("failed to check user existence: %w", err) } if existingUser.ID != 0 { return txService.handleExistingUser(referrerID, existingUser) } if err := txService.handleNewUser(referrerID); err != nil { return err } if err := txRepo.Commit(tx); err != nil { return fmt.Errorf("transaction commit failed: %w", err) } committed = true // Log only after successful commit slog.Info("New user registered", "userID", s.ID, "referrerID", referrerID) if referrerID != 0 { slog.Info("Referral points awarded", "userID", s.ID, "referrerID", referrerID) } return nil } func (s *UserService) handleExistingUser(referrerID int64, user models.User) error { if referrerID == 0 { return nil // No referral attempt } switch { case user.ReferrerID == referrerID: slog.Info("Duplicate referral attempt ignored", "userID", s.ID, "referrerID", referrerID) return nil case user.ReferrerID != 0: return ErrAlreadyReferred default: return ErrLateReferralAttempt } } func (s *UserService) handleNewUser(referrerID int64) error { if referrerID != 0 { if err := s.validateNewReferral(referrerID); err != nil { return err } } err := s.repo.CreateUser(s.ID, referrerID) if err != nil { switch { case errors.Is(err, repositories.ErrForeignKeyViolation): return ErrInvalidReferrer case errors.Is(err, repositories.ErrDuplicateEntry): return ErrDuplicateUser default: return fmt.Errorf("failed to create user: %w", err) } } if referrerID != 0 { if err := s.repo.AddPointsToUser(referrerID, 1); err != nil { return fmt.Errorf("failed to award referral points: %w", err) } } return nil } func (s *UserService) validateNewReferral(referrerID int64) error { const maxChainDepth = 20 depth := 0 if referrerID == s.ID { return ErrSelfReferral } visited := map[int64]struct{}{} currentID := referrerID for depth < maxChainDepth { if currentID == 0 { break // Valid chain with no loop } if currentID == s.ID { return ErrCircularReference } if _, exists := visited[currentID]; exists { return ErrCircularReference } visited[currentID] = struct{}{} referrer, err := s.repo.GetUser(currentID) if err != nil { if errors.Is(err, repositories.ErrRecordNotFound) { return ErrInvalidReferrer } return fmt.Errorf("failed to verify referrer %d: %w", currentID, err) } if referrer.ID == 0 { return ErrInvalidReferrer } currentID = referrer.ReferrerID depth++ } if depth >= maxChainDepth { return ErrReferralChainTooLong } return nil } // GetPoints retrieves the user's current points balance func (s *UserService) GetPoints() (int64, error) { user, err := s.repo.GetUser(s.ID) if err != nil { if errors.Is(err, repositories.ErrRecordNotFound) { return 0, nil } return 0, fmt.Errorf("failed to get user points: %w", err) } return user.Points, nil } func (s *UserService) SpendPoints(amount int64) error { newPoints, err := s.repo.ReducePointsFromUser(s.ID, amount) if err != nil { return fmt.Errorf("failed to spend points: %w", err) } slog.Info("Points updated", "userID", s.ID, "newPoints", newPoints) return nil }