|
14 | 14 | from sqlmodel import Session, SQLModel, create_engine, delete, select |
15 | 15 |
|
16 | 16 | from bot.database.models import ( |
| 17 | + NewUserProbation, |
17 | 18 | PendingCaptchaValidation, |
18 | 19 | PhotoVerificationWhitelist, |
19 | 20 | UserWarning, |
@@ -462,6 +463,135 @@ def get_all_pending_captchas(self) -> list[PendingCaptchaValidation]: |
462 | 463 | statement = select(PendingCaptchaValidation) |
463 | 464 | return list(session.exec(statement).all()) |
464 | 465 |
|
| 466 | + def start_new_user_probation(self, user_id: int, group_id: int) -> NewUserProbation: |
| 467 | + """ |
| 468 | + Start or refresh probation for a new user. |
| 469 | +
|
| 470 | + Called when a user joins or passes captcha verification. |
| 471 | + If a record exists, refreshes joined_at to current time. |
| 472 | +
|
| 473 | + Args: |
| 474 | + user_id: Telegram user ID. |
| 475 | + group_id: Telegram group ID. |
| 476 | +
|
| 477 | + Returns: |
| 478 | + NewUserProbation: Created or updated probation record. |
| 479 | + """ |
| 480 | + with Session(self._engine) as session: |
| 481 | + statement = select(NewUserProbation).where( |
| 482 | + NewUserProbation.user_id == user_id, |
| 483 | + NewUserProbation.group_id == group_id, |
| 484 | + ) |
| 485 | + record = session.exec(statement).first() |
| 486 | + |
| 487 | + if record: |
| 488 | + record.joined_at = datetime.now(UTC) |
| 489 | + record.violation_count = 0 |
| 490 | + record.first_violation_at = None |
| 491 | + record.last_violation_at = None |
| 492 | + else: |
| 493 | + record = NewUserProbation( |
| 494 | + user_id=user_id, |
| 495 | + group_id=group_id, |
| 496 | + ) |
| 497 | + session.add(record) |
| 498 | + session.commit() |
| 499 | + session.refresh(record) |
| 500 | + logger.info(f"Started probation for user_id={user_id}, group_id={group_id}") |
| 501 | + return record |
| 502 | + |
| 503 | + def get_new_user_probation( |
| 504 | + self, user_id: int, group_id: int |
| 505 | + ) -> NewUserProbation | None: |
| 506 | + """ |
| 507 | + Get probation record for a user. |
| 508 | +
|
| 509 | + Args: |
| 510 | + user_id: Telegram user ID. |
| 511 | + group_id: Telegram group ID. |
| 512 | +
|
| 513 | + Returns: |
| 514 | + NewUserProbation | None: Probation record or None if not found. |
| 515 | + """ |
| 516 | + with Session(self._engine) as session: |
| 517 | + statement = select(NewUserProbation).where( |
| 518 | + NewUserProbation.user_id == user_id, |
| 519 | + NewUserProbation.group_id == group_id, |
| 520 | + ) |
| 521 | + return session.exec(statement).first() |
| 522 | + |
| 523 | + def increment_new_user_violation( |
| 524 | + self, user_id: int, group_id: int |
| 525 | + ) -> NewUserProbation: |
| 526 | + """ |
| 527 | + Increment violation count for a user on probation atomically. |
| 528 | +
|
| 529 | + Uses atomic SQL update to prevent race conditions when multiple |
| 530 | + violations occur simultaneously. |
| 531 | +
|
| 532 | + Args: |
| 533 | + user_id: Telegram user ID. |
| 534 | + group_id: Telegram group ID. |
| 535 | +
|
| 536 | + Returns: |
| 537 | + NewUserProbation: Updated probation record. |
| 538 | +
|
| 539 | + Raises: |
| 540 | + ValueError: If no probation record exists. |
| 541 | + """ |
| 542 | + from sqlalchemy import update as sql_update |
| 543 | + |
| 544 | + with Session(self._engine) as session: |
| 545 | + # First check if record exists |
| 546 | + select_stmt = select(NewUserProbation).where( |
| 547 | + NewUserProbation.user_id == user_id, |
| 548 | + NewUserProbation.group_id == group_id, |
| 549 | + ) |
| 550 | + record = session.exec(select_stmt).first() |
| 551 | + |
| 552 | + if not record: |
| 553 | + raise ValueError(f"No probation record for user {user_id} in group {group_id}") |
| 554 | + |
| 555 | + now = datetime.now(UTC) |
| 556 | + |
| 557 | + # Atomic update - increment directly in SQL |
| 558 | + update_stmt = ( |
| 559 | + sql_update(NewUserProbation) |
| 560 | + .where(NewUserProbation.id == record.id) |
| 561 | + .values( |
| 562 | + violation_count=NewUserProbation.violation_count + 1, |
| 563 | + first_violation_at=now if record.first_violation_at is None else record.first_violation_at, |
| 564 | + last_violation_at=now, |
| 565 | + ) |
| 566 | + ) |
| 567 | + session.exec(update_stmt) |
| 568 | + session.commit() |
| 569 | + |
| 570 | + # Refresh to get updated values |
| 571 | + session.refresh(record) |
| 572 | + logger.info( |
| 573 | + f"Incremented violation for user_id={user_id}, group_id={group_id}, " |
| 574 | + f"count={record.violation_count}" |
| 575 | + ) |
| 576 | + return record |
| 577 | + |
| 578 | + def clear_new_user_probation(self, user_id: int, group_id: int) -> None: |
| 579 | + """ |
| 580 | + Remove probation record for a user (when probation expires). |
| 581 | +
|
| 582 | + Args: |
| 583 | + user_id: Telegram user ID. |
| 584 | + group_id: Telegram group ID. |
| 585 | + """ |
| 586 | + with Session(self._engine) as session: |
| 587 | + statement = delete(NewUserProbation).where( |
| 588 | + NewUserProbation.user_id == user_id, |
| 589 | + NewUserProbation.group_id == group_id, |
| 590 | + ) |
| 591 | + session.exec(statement) |
| 592 | + session.commit() |
| 593 | + logger.info(f"Cleared probation for user_id={user_id}, group_id={group_id}") |
| 594 | + |
465 | 595 |
|
466 | 596 | # Module-level singleton for database service |
467 | 597 | _db_service: DatabaseService | None = None |
|
0 commit comments