#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Thread-safe global state management (P0 / O-01).

Rules:
- Business code must never read or write the underlying dicts directly; all
  access goes through the ``safe_*`` API in this module.
- Reads either hold the lock and return a copy, or return a snapshot list that
  can be iterated after the lock has been released.
"""

from __future__ import annotations

import threading
import time
from typing import Any, Dict, List, Optional, Tuple

from app_config import get_config

config = get_config()


# ==================== Active tasks (handles of running tasks) ====================

_active_tasks: Dict[str, Any] = {}
_active_tasks_lock = threading.RLock()


def safe_set_task(account_id: str, handle: Any) -> None:
    """Store (or replace) the running-task handle for ``account_id``."""
    with _active_tasks_lock:
        _active_tasks[account_id] = handle


def safe_get_task(account_id: str) -> Any:
    """Return the task handle for ``account_id``, or ``None`` if none is stored."""
    with _active_tasks_lock:
        return _active_tasks.get(account_id)


def safe_remove_task(account_id: str) -> Any:
    """Remove and return the task handle for ``account_id`` (``None`` if absent)."""
    with _active_tasks_lock:
        return _active_tasks.pop(account_id, None)


def safe_get_active_task_ids() -> List[str]:
    """Return a snapshot list of the account ids that currently have a task handle."""
    with _active_tasks_lock:
        return list(_active_tasks.keys())


# ==================== Task status (state shown to the front end) ====================

_task_status: Dict[str, Dict[str, Any]] = {}
_task_status_lock = threading.RLock()


def safe_set_task_status(account_id: str, status_dict: Dict[str, Any]) -> None:
    """Replace the status of ``account_id`` with a shallow copy of ``status_dict``.

    A ``None``/empty ``status_dict`` stores an empty dict.
    """
    with _task_status_lock:
        _task_status[account_id] = dict(status_dict or {})


def safe_update_task_status(account_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into the existing status of ``account_id``.

    Returns:
        ``True`` if a status existed and was updated, ``False`` if there was no
        status for ``account_id`` (nothing is created in that case).
    """
    with _task_status_lock:
        if account_id not in _task_status:
            return False
        _task_status[account_id].update(updates or {})
        return True


def safe_get_task_status(account_id: str) -> Dict[str, Any]:
    """Return a shallow copy of the status of ``account_id`` (``{}`` if missing or empty)."""
    with _task_status_lock:
        value = _task_status.get(account_id)
        return dict(value) if value else {}


def safe_remove_task_status(account_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the status dict of ``account_id``.

    Returns ``None`` when no status was stored. The returned dict has already
    been detached from the shared store, so it is safe for the caller to keep.
    """
    with _task_status_lock:
        return _task_status.pop(account_id, None)


def safe_get_all_task_status() -> Dict[str, Dict[str, Any]]:
    """Return ``{account_id: shallow copy of status}`` for every stored status."""
    with _task_status_lock:
        return {k: dict(v) for k, v in _task_status.items()}


def safe_iter_task_status_items() -> List[Tuple[str, Dict[str, Any]]]:
    """Return a snapshot list of ``(account_id, shallow copy of status)`` pairs."""
    with _task_status_lock:
        return [(k, dict(v)) for k, v in _task_status.items()]
# ==================== User accounts cache (per-user account objects) ====================

_user_accounts: Dict[int, Dict[str, Any]] = {}
_user_accounts_last_access: Dict[int, float] = {}
_user_accounts_lock = threading.RLock()


def safe_touch_user_accounts(user_id: int) -> None:
    """Mark ``user_id``'s account cache as used right now (updates its last-access time)."""
    stamp = time.time()
    key = int(user_id)
    with _user_accounts_lock:
        _user_accounts_last_access[key] = stamp


def safe_get_user_accounts_last_access_items() -> List[Tuple[int, float]]:
    """Return a snapshot list of ``(user_id, last_access_timestamp)`` pairs."""
    with _user_accounts_lock:
        snapshot = [(uid, stamp) for uid, stamp in _user_accounts_last_access.items()]
    return snapshot


def safe_get_user_accounts_snapshot(user_id: int) -> Dict[str, Any]:
    """Return a shallow copy of ``{account_id: account}`` for ``user_id`` (``{}`` if unknown)."""
    with _user_accounts_lock:
        per_user = _user_accounts.get(int(user_id))
        return {} if per_user is None else dict(per_user)


def safe_set_user_accounts(user_id: int, accounts_by_id: Dict[str, Any]) -> None:
    """Replace all cached accounts of ``user_id`` and refresh its last-access time."""
    key = int(user_id)
    fresh = dict(accounts_by_id) if accounts_by_id else {}
    with _user_accounts_lock:
        _user_accounts[key] = fresh
        _user_accounts_last_access[key] = time.time()


def safe_get_account(user_id: int, account_id: str) -> Any:
    """Return one cached account object, or ``None`` if the user or account is unknown."""
    with _user_accounts_lock:
        per_user = _user_accounts.get(int(user_id))
        if per_user is None:
            return None
        return per_user.get(account_id)


def safe_set_account(user_id: int, account_id: str, account_obj: Any) -> None:
    """Cache ``account_obj`` under ``(user_id, account_id)`` and refresh last-access time.

    The per-user mapping is created on first use.
    """
    key = int(user_id)
    with _user_accounts_lock:
        _user_accounts.setdefault(key, {})[account_id] = account_obj
        _user_accounts_last_access[key] = time.time()


def safe_remove_account(user_id: int, account_id: str) -> Any:
    """Remove and return one cached account (``None`` if the user or account is unknown).

    The last-access timestamp is left unchanged.
    """
    key = int(user_id)
    with _user_accounts_lock:
        per_user = _user_accounts.get(key)
        return None if per_user is None else per_user.pop(account_id, None)


def safe_remove_user_accounts(user_id: int) -> None:
    """Forget every cached account and the last-access timestamp of ``user_id``."""
    key = int(user_id)
    with _user_accounts_lock:
        for store in (_user_accounts, _user_accounts_last_access):
            store.pop(key, None)


def safe_iter_user_accounts_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(user_id, shallow copy of its accounts)`` pairs."""
    with _user_accounts_lock:
        snapshot = []
        for uid, per_user in _user_accounts.items():
            snapshot.append((uid, dict(per_user)))
        return snapshot


def safe_has_user(user_id: int) -> bool:
    """Return whether an account mapping (possibly empty) is cached for ``user_id``."""
    key = int(user_id)
    with _user_accounts_lock:
        return key in _user_accounts


# ==================== Log cache (per-user log buffer) ====================
_log_cache: Dict[int, List[Dict[str, Any]]] = {}
_log_cache_lock = threading.RLock()
# Running total of buffered entries across all users; only mutated under _log_cache_lock.
_log_cache_total_count = 0


def safe_add_log(
    user_id: int,
    log_entry: Dict[str, Any],
    *,
    max_logs_per_user: Optional[int] = None,
    max_total_logs: Optional[int] = None,
) -> None:
    """Append a shallow copy of ``log_entry`` to ``user_id``'s log buffer.

    Two caps are enforced:
    - per user: when the user's buffer is full, its oldest entry is dropped first;
    - global: while the total exceeds ``max_total_logs``, the oldest entry of the
      user holding the most entries is dropped.

    A falsy limit (``None``/``0``) falls back to ``config.MAX_LOGS_PER_USER`` /
    ``config.MAX_TOTAL_LOGS``.
    """
    global _log_cache_total_count
    uid = int(user_id)
    max_logs_per_user = int(max_logs_per_user or config.MAX_LOGS_PER_USER)
    max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
    with _log_cache_lock:
        user_logs = _log_cache.setdefault(uid, [])
        if len(user_logs) >= max_logs_per_user:
            user_logs.pop(0)
            _log_cache_total_count = max(0, _log_cache_total_count - 1)
        user_logs.append(dict(log_entry or {}))
        _log_cache_total_count += 1

        # Global cap: trim from whichever user currently holds the most entries.
        while _log_cache_total_count > max_total_logs and _log_cache:
            max_user = max(_log_cache, key=lambda u: len(_log_cache[u]))
            if not _log_cache[max_user]:
                # Counter is out of sync with the contents; nothing left to evict.
                break
            _log_cache[max_user].pop(0)
            _log_cache_total_count -= 1


def safe_get_user_logs(user_id: int) -> List[Dict[str, Any]]:
    """Return a snapshot list of ``user_id``'s buffered log entries (oldest first)."""
    uid = int(user_id)
    with _log_cache_lock:
        return list(_log_cache.get(uid, []))


def safe_clear_user_logs(user_id: int) -> None:
    """Drop every buffered log entry of ``user_id`` and adjust the global counter."""
    global _log_cache_total_count
    uid = int(user_id)
    with _log_cache_lock:
        removed = len(_log_cache.get(uid, []))
        _log_cache.pop(uid, None)
        _log_cache_total_count = max(0, _log_cache_total_count - removed)


def safe_get_log_cache_total_count() -> int:
    """Return the number of log entries currently buffered across all users."""
    with _log_cache_lock:
        return int(_log_cache_total_count)


# ==================== Captcha storage ====================

_captcha_storage: Dict[str, Dict[str, Any]] = {}
_captcha_storage_lock = threading.RLock()


def safe_set_captcha(session_id: str, captcha_data: Dict[str, Any]) -> None:
    """Store a shallow copy of ``captcha_data`` for ``session_id`` (replacing any previous one)."""
    with _captcha_storage_lock:
        _captcha_storage[str(session_id)] = dict(captcha_data or {})


def safe_cleanup_expired_captcha(now_ts: Optional[float] = None) -> int:
    """Remove captchas whose ``expire_time`` is before ``now_ts``; return how many were removed.

    Entries without an ``expire_time`` are treated as expired.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _captcha_storage_lock:
        expired = [
            k for k, v in _captcha_storage.items()
            if float(v.get("expire_time", 0) or 0) < now_ts
        ]
        for k in expired:
            _captcha_storage.pop(k, None)
        return len(expired)


def safe_delete_captcha(session_id: str) -> None:
    """Remove the captcha of ``session_id`` if present."""
    with _captcha_storage_lock:
        _captcha_storage.pop(str(session_id), None)


def safe_verify_and_consume_captcha(session_id: str, code: str, *, max_attempts: Optional[int] = None) -> Tuple[bool, str]:
    """Check ``code`` against the stored captcha of ``session_id``.

    The comparison is case-insensitive. The captcha is consumed on success, on
    expiry, on errors, and once ``max_attempts`` wrong answers have been given;
    after a wrong answer below that limit it is put back with its failure count
    incremented. A stored entry without a code never validates.

    Returns:
        ``(ok, message)`` where ``message`` is a user-facing text.
    """
    max_attempts = int(max_attempts or config.MAX_CAPTCHA_ATTEMPTS)
    key = str(session_id)
    with _captcha_storage_lock:
        # Pop first so concurrent verifications cannot both succeed.
        captcha_data = _captcha_storage.pop(key, None)
        if captcha_data is None:
            return False, "验证码已过期或不存在,请重新获取"

        try:
            if float(captcha_data.get("expire_time", 0) or 0) < time.time():
                return False, "验证码已过期,请重新获取"

            failed_attempts = int(captcha_data.get("failed_attempts", 0) or 0)
            if failed_attempts >= max_attempts:
                return False, f"验证码错误次数过多({max_attempts}次),请重新获取"

            expected = str(captcha_data.get("code", "") or "").lower()
            if not expected:
                # Without a stored code an empty answer would compare equal to ""
                # and pass; treat such an entry as invalid instead.
                return False, "验证码验证失败,请重新获取"

            actual = str(code or "").lower()
            if expected != actual:
                failed_attempts += 1
                captcha_data["failed_attempts"] = failed_attempts
                if failed_attempts < max_attempts:
                    _captcha_storage[key] = captcha_data
                return False, "验证码错误"

            return True, "验证成功"
        except Exception:
            # Deliberate boundary: malformed stored data fails closed (captcha consumed).
            return False, "验证码验证失败,请重新获取"


# ==================== IP rate limit (captcha-failure throttling) ====================

_ip_rate_limit: Dict[str, Dict[str, Any]] = {}
_ip_rate_limit_lock = threading.RLock()

# Length of the "per hour" window in which captcha failures are counted.
_IP_FAILURE_WINDOW_SECONDS = 3600


def _ip_window_expired(first_attempt: Any, now_ts: float) -> bool:
    """Return whether a failure-counting window that began at ``first_attempt`` has elapsed."""
    if first_attempt is None:
        return True
    return (now_ts - float(first_attempt or 0)) > _IP_FAILURE_WINDOW_SECONDS


def check_ip_rate_limit(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Return ``(allowed, error_message)`` for ``ip_address``.

    As a side effect, purges entries whose lock has expired and whose counting
    window has elapsed, and restarts a stale counting window for ``ip_address``.

    ``max_attempts_per_hour`` and ``lock_duration_seconds`` are accepted for
    signature symmetry with :func:`record_failed_captcha`; the check itself does
    not use them (locks are set by ``record_failed_captcha``).
    """
    current_time = time.time()
    with _ip_rate_limit_lock:
        expired_ips = [
            ip for ip, data in _ip_rate_limit.items()
            if float(data.get("lock_until", 0) or 0) < current_time
            and _ip_window_expired(data.get("first_attempt"), current_time)
        ]
        for ip in expired_ips:
            _ip_rate_limit.pop(ip, None)

        ip_key = str(ip_address)
        ip_data = _ip_rate_limit.get(ip_key)
        if ip_data is not None:
            lock_until = float(ip_data.get("lock_until", 0) or 0)
            if lock_until > current_time:
                remaining_time = int(lock_until - current_time)
                return False, f"IP已被锁定,请{remaining_time // 60 + 1}分钟后再试"
            if _ip_window_expired(ip_data.get("first_attempt"), current_time):
                _ip_rate_limit[ip_key] = {"attempts": 0, "first_attempt": current_time}

    return True, None


def record_failed_captcha(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> bool:
    """Count one captcha failure for ``ip_address``; return ``True`` if it is now locked.

    If the IP's counting window has elapsed, a fresh window is started first
    (an existing ``lock_until`` is kept). Reaching ``max_attempts_per_hour``
    failures sets ``lock_until = now + lock_duration_seconds``. Falsy limits fall
    back to ``config.MAX_IP_ATTEMPTS_PER_HOUR`` / ``config.IP_LOCK_DURATION``.
    """
    current_time = time.time()
    max_attempts_per_hour = int(max_attempts_per_hour or config.MAX_IP_ATTEMPTS_PER_HOUR)
    lock_duration_seconds = int(lock_duration_seconds or config.IP_LOCK_DURATION)
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        data = _ip_rate_limit.get(ip_key)
        if data is None:
            data = {"attempts": 0, "first_attempt": current_time}
            _ip_rate_limit[ip_key] = data
        elif _ip_window_expired(data.get("first_attempt"), current_time):
            # check_ip_rate_limit was not called in between: don't keep counting
            # failures from an hour-old window.
            data["attempts"] = 0
            data["first_attempt"] = current_time

        data["attempts"] = int(data.get("attempts", 0) or 0) + 1
        if data["attempts"] >= max_attempts_per_hour:
            data["lock_until"] = current_time + lock_duration_seconds
            return True
        return False


def cleanup_expired_ip_rate_limits(now_ts: Optional[float] = None) -> int:
    """Remove IP entries with an expired lock and an elapsed window; return how many."""
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _ip_rate_limit_lock:
        expired_ips = [
            ip for ip, data in _ip_rate_limit.items()
            if float(data.get("lock_until", 0) or 0) < now_ts
            and _ip_window_expired(data.get("first_attempt"), now_ts)
        ]
        for ip in expired_ips:
            _ip_rate_limit.pop(ip, None)
        return len(expired_ips)


def safe_get_ip_lock_until(ip_address: str) -> float:
    """Return the lock-until timestamp of ``ip_address`` (``0.0`` when it is not locked)."""
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        data = _ip_rate_limit.get(ip_key) or {}
        try:
            return float(data.get("lock_until", 0) or 0)
        except (TypeError, ValueError):
            return 0.0


# ==================== IP request rate limit (per-endpoint request frequency) ====================

_ip_request_rate: Dict[str, Dict[str, Any]] = {}
_ip_request_rate_lock = threading.RLock()


def _get_action_rate_limit(action: str) -> Tuple[int, int]:
    """Return the configured ``(max_requests, window_seconds)`` for ``action``.

    ``"register"`` and ``"email"`` have dedicated limits; anything else uses the
    login limits.
    """
    action = str(action or "").lower()
    if action == "register":
        return int(config.IP_RATE_LIMIT_REGISTER_MAX), int(config.IP_RATE_LIMIT_REGISTER_WINDOW_SECONDS)
    if action == "email":
        return int(config.IP_RATE_LIMIT_EMAIL_MAX), int(config.IP_RATE_LIMIT_EMAIL_WINDOW_SECONDS)
    return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)


def check_ip_request_rate(
    ip_address: str,
    action: str,
    *,
    max_requests: Optional[int] = None,
    window_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Fixed-window request limiter keyed by ``(action, ip_address)``.

    Returns ``(True, None)`` and counts the request when under the limit,
    otherwise ``(False, message)`` without counting it. Falsy limits fall back
    to the action's configured defaults.
    """
    now_ts = time.time()
    default_max, default_window = _get_action_rate_limit(action)
    max_requests = int(max_requests or default_max)
    window_seconds = int(window_seconds or default_window)
    key = f"{action}:{ip_address}"
    with _ip_request_rate_lock:
        data = _ip_request_rate.get(key)
        if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
            # Remember the window length so cleanup honours a caller-supplied
            # window instead of evicting the entry after the action's default.
            data = {"window_start": now_ts, "count": 0, "window_seconds": window_seconds}
            _ip_request_rate[key] = data

        count = int(data.get("count", 0) or 0)
        if count >= max_requests:
            remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
            if remaining >= 60:
                wait_hint = f"{remaining // 60 + 1}分钟"
            else:
                wait_hint = f"{remaining}秒"
            return False, f"请求过于频繁,请{wait_hint}后再试"

        data["count"] = count + 1
        return True, None


def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
    """Remove request-rate entries whose window has elapsed; return how many were removed.

    Uses the window length stored on each entry, falling back to the action's
    configured window for entries that predate that field.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    removed = 0
    with _ip_request_rate_lock:
        for key in list(_ip_request_rate):
            data = _ip_request_rate.get(key) or {}
            window_seconds = data.get("window_seconds")
            if not window_seconds:
                action = key.split(":", 1)[0]
                _, window_seconds = _get_action_rate_limit(action)
            window_start = float(data.get("window_start", 0) or 0)
            if now_ts - window_start >= float(window_seconds):
                _ip_request_rate.pop(key, None)
                removed += 1
    return removed


# ==================== Login failure tracking (triggers the login captcha) ====================

_login_failures: Dict[str, Dict[str, Any]] = {}
_login_failures_lock = threading.RLock()


def _get_login_captcha_config() -> Tuple[int, int]:
    """Return ``(failures_before_captcha, window_seconds)`` from config."""
    return int(config.LOGIN_CAPTCHA_AFTER_FAILURES), int(config.LOGIN_CAPTCHA_WINDOW_SECONDS)


def record_login_failure(ip_address: str) -> None:
    """Count one failed login for ``ip_address`` within the current window.

    A new window starts when the previous one has elapsed. The counter is capped
    at ``5 * failures_before_captcha`` so it cannot grow without bound.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    ip_key = str(ip_address or "")
    with _login_failures_lock:
        data = _login_failures.get(ip_key)
        if not data or (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
            data = {"first_failed": now_ts, "count": 0}
            _login_failures[ip_key] = data
        data["count"] = int(data.get("count", 0) or 0) + 1
        if int(data["count"]) > max_failures * 5:
            data["count"] = max_failures * 5


def clear_login_failures(ip_address: str) -> None:
    """Forget the login-failure record of ``ip_address``."""
    ip_key = str(ip_address or "")
    with _login_failures_lock:
        _login_failures.pop(ip_key, None)


def check_login_captcha_required(ip_address: str) -> bool:
    """Return whether ``ip_address`` reached the failure threshold in the current window.

    A record whose window has elapsed is removed and reported as not requiring
    a captcha.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    ip_key = str(ip_address or "")
    with _login_failures_lock:
        data = _login_failures.get(ip_key)
        if not data:
            return False
        if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
            _login_failures.pop(ip_key, None)
            return False
        return int(data.get("count", 0) or 0) >= max_failures


def cleanup_expired_login_failures(now_ts: Optional[float] = None) -> int:
    """Remove login-failure records whose window has elapsed; return how many were removed.

    Without a sweep, IPs that fail once and never come back stay in memory until
    :func:`check_login_captcha_required` happens to be called for them.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    _, window_seconds = _get_login_captcha_config()
    with _login_failures_lock:
        expired = [
            ip for ip, data in _login_failures.items()
            if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds
        ]
        for ip in expired:
            _login_failures.pop(ip, None)
        return len(expired)


# ==================== Batch screenshots (per-batch screenshot collection) ====================

_batch_task_screenshots: Dict[str, Dict[str, Any]] = {}
_batch_task_lock = threading.RLock()


def safe_create_batch(batch_id: str, batch_info: Dict[str, Any]) -> None:
    """Store a shallow copy of ``batch_info`` under ``batch_id`` (replacing any previous batch)."""
    with _batch_task_lock:
        _batch_task_screenshots[str(batch_id)] = dict(batch_info or {})


def safe_get_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Return a copy of the batch info, or ``None`` if it is missing or empty.

    The ``screenshots`` list is copied as well, because
    :func:`safe_batch_append_result` appends to it in place; handing out the
    live list would expose shared state to callers.
    """
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        snapshot = dict(info)
        screenshots = snapshot.get("screenshots")
        if isinstance(screenshots, list):
            snapshot["screenshots"] = list(screenshots)
        return snapshot


def safe_update_batch(batch_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into an existing batch; return ``False`` if the batch does not exist."""
    with _batch_task_lock:
        if str(batch_id) not in _batch_task_screenshots:
            return False
        _batch_task_screenshots[str(batch_id)].update(dict(updates or {}))
        return True


def safe_pop_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the batch info (``None`` if absent)."""
    with _batch_task_lock:
        return _batch_task_screenshots.pop(str(batch_id), None)


def safe_batch_append_result(batch_id: str, result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Record one finished account result for a batch.

    Appends a copy of ``result`` to ``screenshots``, increments ``completed`` and
    stamps ``updated_at``. When ``total_accounts`` is positive and has been
    reached, the batch is removed and returned. Otherwise, or when the batch
    does not exist, ``None`` is returned.
    """
    now_ts = time.time()
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        info.setdefault("screenshots", []).append(dict(result or {}))
        info["completed"] = int(info.get("completed", 0) or 0) + 1
        info["updated_at"] = now_ts
        total = int(info.get("total_accounts", 0) or 0)
        if total > 0 and int(info.get("completed", 0) or 0) >= total:
            return _batch_task_screenshots.pop(str(batch_id), None)
        return None


def safe_cleanup_expired_batches(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Remove batches idle for longer than ``expire_seconds``; return how many were removed.

    Idle time is measured from ``updated_at``, falling back to ``created_at`` /
    ``created_time``. Batches with none of these are kept.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    expire_seconds = max(1, int(expire_seconds))
    with _batch_task_lock:
        expired = []
        for batch_id, info in list(_batch_task_screenshots.items()):
            last_ts = info.get("updated_at") or info.get("created_at") or info.get("created_time") or now_ts
            if (now_ts - float(last_ts)) > expire_seconds:
                expired.append(batch_id)
        for batch_id in expired:
            _batch_task_screenshots.pop(batch_id, None)
        return len(expired)


def safe_finalize_batch_after_dispatch(batch_id: str, total_accounts: int, *, now_ts: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """Scheduled batches: set the final account total once dispatching is done.

    If ``completed >= total_accounts`` the batch is removed and returned (for
    sending the email). A non-positive total discards the batch and returns
    ``None``. Otherwise ``None`` is returned and the batch stays in place.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        info["total_accounts"] = int(total_accounts or 0)
        info["updated_at"] = now_ts
        if int(total_accounts or 0) <= 0:
            _batch_task_screenshots.pop(str(batch_id), None)
            return None
        if int(info.get("completed", 0) or 0) >= int(total_accounts):
            return _batch_task_screenshots.pop(str(batch_id), None)
        return None


# ==================== Pending random schedules (compat with legacy random-delay logic) ====================

_pending_random_schedules: Dict[int, Dict[str, Any]] = {}
_pending_random_lock = threading.RLock()


def safe_set_pending_random_schedule(schedule_id: int, info: Dict[str, Any]) -> None:
    """Store a shallow copy of ``info`` for ``schedule_id``."""
    with _pending_random_lock:
        _pending_random_schedules[int(schedule_id)] = dict(info or {})


def safe_get_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Return a shallow copy of the pending schedule, or ``None`` if missing or empty."""
    with _pending_random_lock:
        value = _pending_random_schedules.get(int(schedule_id))
        return dict(value) if value else None


def safe_pop_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Remove and return the pending schedule (``None`` if absent)."""
    with _pending_random_lock:
        return _pending_random_schedules.pop(int(schedule_id), None)


def safe_iter_pending_random_schedules_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(schedule_id, shallow copy of info)`` pairs."""
    with _pending_random_lock:
        return [(sid, dict(info)) for sid, info in _pending_random_schedules.items()]


def safe_cleanup_expired_pending_random(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Remove pending schedules older than ``expire_seconds``; return how many were removed.

    Age is measured from ``created_at`` (or ``created_time``). Entries with
    neither field are kept.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    expire_seconds = max(1, int(expire_seconds))
    with _pending_random_lock:
        expired = []
        for schedule_id, info in list(_pending_random_schedules.items()):
            created_at = info.get("created_at") or info.get("created_time") or now_ts
            if (now_ts - float(created_at)) > expire_seconds:
                expired.append(schedule_id)
        for schedule_id in expired:
            _pending_random_schedules.pop(int(schedule_id), None)
        return len(expired)