Files
zsglpt/services/state.py

482 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
线程安全的全局状态管理(P0 / O-01)
约束:
- 业务代码禁止直接读写底层 dict,必须通过本模块 safe_* API 访问
- 读:要么持锁并返回副本,要么以“快照”的方式返回可迭代列表
"""
from __future__ import annotations
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
from app_config import get_config
# Application config; supplies the default limits used below
# (MAX_LOGS_PER_USER, MAX_TOTAL_LOGS, MAX_CAPTCHA_ATTEMPTS,
# MAX_IP_ATTEMPTS_PER_HOUR, IP_LOCK_DURATION).
config = get_config()
# ==================== Active tasks (handles of running tasks) ====================
# account_id -> task handle; guarded by _active_tasks_lock.
_active_tasks: Dict[str, Any] = {}
_active_tasks_lock = threading.RLock()
def safe_set_task(account_id: str, handle: Any) -> None:
    """Register *handle* as the running task of *account_id*, replacing any previous one."""
    with _active_tasks_lock:
        _active_tasks.update({account_id: handle})
def safe_get_task(account_id: str) -> Any:
    """Return the task handle registered for *account_id*, or ``None`` if there is none."""
    with _active_tasks_lock:
        handle = _active_tasks.get(account_id)
    return handle
def safe_remove_task(account_id: str) -> Any:
    """Unregister and return the task handle of *account_id* (``None`` if absent)."""
    with _active_tasks_lock:
        removed_handle = _active_tasks.pop(account_id, None)
    return removed_handle
def safe_get_active_task_ids() -> List[str]:
    """Return a snapshot list of the account ids that currently have a registered task."""
    with _active_tasks_lock:
        task_ids = list(_active_tasks)
    return task_ids
# ==================== Task status (status shown on the frontend) ====================
# account_id -> status dict; guarded by _task_status_lock. Readers receive copies.
_task_status: Dict[str, Dict[str, Any]] = {}
_task_status_lock = threading.RLock()
def safe_set_task_status(account_id: str, status_dict: Dict[str, Any]) -> None:
    """Replace the status of *account_id* with a shallow copy of *status_dict*.

    A falsy *status_dict* (e.g. ``None``) stores an empty status.
    """
    snapshot = dict(status_dict) if status_dict else {}
    with _task_status_lock:
        _task_status[account_id] = snapshot
def safe_update_task_status(account_id: str, updates: Dict[str, Any]) -> bool:
    """Merge *updates* into the existing status of *account_id*.

    Returns ``True`` on success, ``False`` (changing nothing) if no status exists yet.
    """
    with _task_status_lock:
        if account_id in _task_status:
            _task_status[account_id].update(updates or {})
            return True
        return False
def safe_get_task_status(account_id: str) -> Dict[str, Any]:
    """Return a shallow copy of *account_id*'s status, or ``{}`` if unknown or empty."""
    with _task_status_lock:
        current = _task_status.get(account_id)
        if not current:
            return {}
        return dict(current)
def safe_remove_task_status(account_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the status entry of *account_id*.

    Returns:
        The removed status dict, or ``None`` if no status was stored. (The previous
        annotation claimed a dict was always returned.)
    """
    with _task_status_lock:
        return _task_status.pop(account_id, None)
def safe_get_all_task_status() -> Dict[str, Dict[str, Any]]:
    """Return a snapshot mapping of every account's status (values are shallow copies)."""
    with _task_status_lock:
        snapshot = {acct: dict(status) for acct, status in _task_status.items()}
    return snapshot
def safe_iter_task_status_items() -> List[Tuple[str, Dict[str, Any]]]:
    """Return ``(account_id, status_copy)`` pairs that can be iterated without holding the lock."""
    with _task_status_lock:
        pairs = [(acct, dict(status)) for acct, status in _task_status.items()]
    return pairs
# ==================== User accounts cache (cached account objects) ====================
# user_id -> {account_id: account object}; guarded by _user_accounts_lock.
_user_accounts: Dict[int, Dict[str, Any]] = {}
# user_id -> time.time() of the last touch/set; shares _user_accounts_lock.
_user_accounts_last_access: Dict[int, float] = {}
_user_accounts_lock = threading.RLock()
def safe_touch_user_accounts(user_id: int) -> None:
    """Record the current time as the last access of *user_id*'s account cache."""
    timestamp = time.time()
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts_last_access[uid] = timestamp
def safe_get_user_accounts_last_access_items() -> List[Tuple[int, float]]:
    """Return a snapshot list of ``(user_id, last_access_timestamp)`` pairs."""
    with _user_accounts_lock:
        pairs = list(_user_accounts_last_access.items())
    return pairs
def safe_get_user_accounts_snapshot(user_id: int) -> Dict[str, Any]:
    """Return a shallow copy of *user_id*'s cached accounts (``{}`` if not cached)."""
    uid = int(user_id)
    with _user_accounts_lock:
        cached = _user_accounts.get(uid, {})
        return dict(cached)
def safe_set_user_accounts(user_id: int, accounts_by_id: Dict[str, Any]) -> None:
    """Replace *user_id*'s whole account cache with a copy of *accounts_by_id*.

    Also marks the cache as accessed now.
    """
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts[uid] = dict(accounts_by_id or {})
        _user_accounts_last_access[uid] = time.time()
def safe_get_account(user_id: int, account_id: str) -> Any:
    """Return one cached account object, or ``None`` if the user or account is not cached."""
    uid = int(user_id)
    with _user_accounts_lock:
        accounts = _user_accounts.get(uid, {})
        return accounts.get(account_id)
def safe_set_account(user_id: int, account_id: str, account_obj: Any) -> None:
    """Insert or replace one cached account, creating *user_id*'s cache if needed.

    Also marks the user's cache as accessed now.
    """
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts.setdefault(uid, {})[account_id] = account_obj
        _user_accounts_last_access[uid] = time.time()
def safe_remove_account(user_id: int, account_id: str) -> Any:
    """Remove and return one cached account (``None`` if the user or account is absent).

    The user's last-access timestamp is left untouched.
    """
    uid = int(user_id)
    with _user_accounts_lock:
        if uid in _user_accounts:
            return _user_accounts[uid].pop(account_id, None)
        return None
def safe_remove_user_accounts(user_id: int) -> None:
    """Drop *user_id*'s account cache and last-access timestamp (no-op if absent)."""
    uid = int(user_id)
    with _user_accounts_lock:
        for store in (_user_accounts, _user_accounts_last_access):
            store.pop(uid, None)
def safe_iter_user_accounts_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return ``(user_id, accounts_copy)`` pairs for every cached user."""
    with _user_accounts_lock:
        snapshot = [(owner_id, dict(by_id)) for owner_id, by_id in _user_accounts.items()]
    return snapshot
def safe_has_user(user_id: int) -> bool:
    """Return ``True`` if *user_id* currently has an account cache entry."""
    uid = int(user_id)
    with _user_accounts_lock:
        return uid in _user_accounts
# ==================== Log cache (per-user log cache) ====================
# user_id -> log entries, oldest first; guarded by _log_cache_lock.
_log_cache: Dict[int, List[Dict[str, Any]]] = {}
_log_cache_lock = threading.RLock()
# Running total of cached entries across all users (kept in sync by the log helpers).
_log_cache_total_count = 0
def safe_add_log(
    user_id: int,
    log_entry: Dict[str, Any],
    *,
    max_logs_per_user: Optional[int] = None,
    max_total_logs: Optional[int] = None,
) -> None:
    """Append a shallow copy of *log_entry* to *user_id*'s log cache, enforcing limits.

    Two limits apply:
    - per user: the user's oldest entries are dropped so that, after the append,
      the user holds at most ``max_logs_per_user`` entries;
    - global: while the total exceeds ``max_total_logs``, the oldest entry of the
      user holding the most entries is dropped.

    Args:
        user_id: owner of the entry (normalized with ``int``).
        log_entry: entry to store; a falsy value stores ``{}``.
        max_logs_per_user: per-user cap; ``None``/0 falls back to
            ``config.MAX_LOGS_PER_USER``. Clamped to at least 1.
        max_total_logs: global cap; ``None``/0 falls back to
            ``config.MAX_TOTAL_LOGS``. Clamped to at least 1.
    """
    global _log_cache_total_count
    uid = int(user_id)
    max_logs_per_user = max(1, int(max_logs_per_user or config.MAX_LOGS_PER_USER))
    max_total_logs = max(1, int(max_total_logs or config.MAX_TOTAL_LOGS))
    with _log_cache_lock:
        user_logs = _log_cache.setdefault(uid, [])
        # `while` rather than `if`: if the per-user limit was lowered since the last
        # call, a single pop would leave this user above the new limit.
        while user_logs and len(user_logs) >= max_logs_per_user:
            user_logs.pop(0)
            _log_cache_total_count = max(0, _log_cache_total_count - 1)
        user_logs.append(dict(log_entry or {}))
        _log_cache_total_count += 1

        # Global limit: evict from the largest per-user list first. _log_cache is
        # non-empty here (it contains uid), so max() always has candidates.
        while _log_cache_total_count > max_total_logs:
            largest_uid = max(_log_cache, key=lambda u: len(_log_cache[u]))
            largest_logs = _log_cache[largest_uid]
            if not largest_logs:
                # Every list is empty, so the counter has drifted; nothing to evict.
                break
            largest_logs.pop(0)
            _log_cache_total_count -= 1
def safe_get_user_logs(user_id: int) -> List[Dict[str, Any]]:
    """Return a snapshot list of *user_id*'s cached log entries, oldest first (``[]`` if none)."""
    uid = int(user_id)
    with _log_cache_lock:
        entries = _log_cache.get(uid, [])
        snapshot = list(entries)
    return snapshot
def safe_clear_user_logs(user_id: int) -> None:
    """Drop all cached logs of *user_id* and deduct them from the global counter."""
    global _log_cache_total_count
    uid = int(user_id)
    with _log_cache_lock:
        dropped = _log_cache.pop(uid, [])
        _log_cache_total_count = max(0, _log_cache_total_count - len(dropped))
def safe_get_log_cache_total_count() -> int:
    """Return the number of cached log entries across all users."""
    with _log_cache_lock:
        total = _log_cache_total_count
    return int(total)
# ==================== Captcha storage ====================
# str(session_id) -> captcha record (fields read here: "code", "expire_time",
# "failed_attempts"); guarded by _captcha_storage_lock.
_captcha_storage: Dict[str, Dict[str, Any]] = {}
_captcha_storage_lock = threading.RLock()
def safe_set_captcha(session_id: str, captcha_data: Dict[str, Any]) -> None:
    """Store a shallow copy of *captcha_data* under ``str(session_id)``, replacing any old one."""
    key = str(session_id)
    record = dict(captcha_data or {})
    with _captcha_storage_lock:
        _captcha_storage[key] = record
def safe_cleanup_expired_captcha(now_ts: Optional[float] = None) -> int:
    """Remove captchas whose ``expire_time`` is earlier than *now_ts* (default: now).

    Entries lacking an ``expire_time`` count as expired. Returns how many were removed.
    """
    cutoff = float(time.time() if now_ts is None else now_ts)
    with _captcha_storage_lock:
        stale_keys = [
            key
            for key, record in _captcha_storage.items()
            if float(record.get("expire_time", 0) or 0) < cutoff
        ]
        for key in stale_keys:
            _captcha_storage.pop(key, None)
        return len(stale_keys)
def safe_delete_captcha(session_id: str) -> None:
    """Discard the captcha stored for *session_id*, if any."""
    key = str(session_id)
    with _captcha_storage_lock:
        _captcha_storage.pop(key, None)
def safe_verify_and_consume_captcha(session_id: str, code: str, *, max_attempts: Optional[int] = None) -> Tuple[bool, str]:
    """Check *code* against the captcha stored for *session_id* (case-insensitive).

    The entry is removed on lookup. It is put back only after a wrong answer that
    still leaves fewer than *max_attempts* failures. Success, expiry, too many
    failures, a record without a code, or any error leave it consumed.

    Args:
        session_id: key the captcha was stored under (normalized with ``str``).
        code: the user-submitted answer.
        max_attempts: allowed wrong answers; ``None``/0 falls back to
            ``config.MAX_CAPTCHA_ATTEMPTS``.

    Returns:
        ``(ok, message)`` where *message* is the user-facing text.
    """
    max_attempts = int(max_attempts or config.MAX_CAPTCHA_ATTEMPTS)
    key = str(session_id)
    with _captcha_storage_lock:
        # Remove up front; re-inserted below only after a recoverable wrong answer.
        captcha_data = _captcha_storage.pop(key, None)
        if captcha_data is None:
            return False, "验证码已过期或不存在,请重新获取"
        try:
            if float(captcha_data.get("expire_time", 0) or 0) < time.time():
                return False, "验证码已过期,请重新获取"
            failed_attempts = int(captcha_data.get("failed_attempts", 0) or 0)
            if failed_attempts >= max_attempts:
                return False, f"验证码错误次数过多({max_attempts}次),请重新获取"
            expected = str(captcha_data.get("code", "") or "").lower()
            if not expected:
                # Fail closed: a record without a code must never validate, otherwise
                # an empty submission would match the empty expected value.
                return False, "验证码验证失败,请重新获取"
            actual = str(code or "").lower()
            if expected != actual:
                failed_attempts += 1
                captcha_data["failed_attempts"] = failed_attempts
                if failed_attempts < max_attempts:
                    _captcha_storage[key] = captcha_data
                return False, "验证码错误"
            return True, "验证成功"
        except Exception:
            # Malformed record (e.g. non-numeric expire_time): fail closed; it stays consumed.
            return False, "验证码验证失败,请重新获取"
# ==================== IP rate limit (throttling after failed captchas) ====================
# str(ip) -> {"attempts", "first_attempt", optional "lock_until"}; guarded by _ip_rate_limit_lock.
_ip_rate_limit: Dict[str, Dict[str, Any]] = {}
_ip_rate_limit_lock = threading.RLock()
def check_ip_rate_limit(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Return whether *ip_address* may attempt a captcha right now.

    First purges stale entries (lock expired and window older than one hour). Then:
    - if the IP is locked, returns ``(False, message)`` with the remaining minutes;
    - if its one-hour counting window has elapsed, resets its failure counter;
    - otherwise returns ``(True, None)``.

    ``max_attempts_per_hour`` and ``lock_duration_seconds`` are accepted for signature
    compatibility with ``record_failed_captcha``, which sets the lock. This check
    does not use them.
    """
    current_time = time.time()
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        # Reuse the shared purge rule; the RLock is reentrant, so nesting is safe.
        cleanup_expired_ip_rate_limits(current_time)
        ip_data = _ip_rate_limit.get(ip_key)
        if ip_data is None:
            return True, None
        lock_until = float(ip_data.get("lock_until", 0) or 0)
        if lock_until > current_time:
            remaining_time = int(lock_until - current_time)
            return False, f"IP已被锁定,请{remaining_time // 60 + 1}分钟后再试"
        first_attempt = ip_data.get("first_attempt")
        if first_attempt is None or current_time - float(first_attempt) > 3600:
            # Window elapsed: start counting failures afresh.
            _ip_rate_limit[ip_key] = {"attempts": 0, "first_attempt": current_time}
        return True, None
def record_failed_captcha(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> bool:
    """Count one failed captcha for *ip_address* and lock the IP at the limit.

    Failures are counted within a one-hour window starting at ``first_attempt``.
    If that window has elapsed (or was never set), counting restarts with this
    failure instead of adding to a stale total.

    Args:
        ip_address: client IP (normalized with ``str``).
        max_attempts_per_hour: failures that trigger a lock; ``None``/0 falls back to
            ``config.MAX_IP_ATTEMPTS_PER_HOUR``.
        lock_duration_seconds: lock length; ``None``/0 falls back to
            ``config.IP_LOCK_DURATION``.

    Returns:
        ``True`` if this failure reached the limit and ``lock_until`` was (re)set.
    """
    current_time = time.time()
    max_attempts_per_hour = int(max_attempts_per_hour or config.MAX_IP_ATTEMPTS_PER_HOUR)
    lock_duration_seconds = int(lock_duration_seconds or config.IP_LOCK_DURATION)
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        ip_data = _ip_rate_limit.setdefault(ip_key, {})
        first_attempt = ip_data.get("first_attempt")
        if first_attempt is None or current_time - float(first_attempt) > 3600:
            # Fresh window; other keys (e.g. an active lock_until) are preserved.
            ip_data["attempts"] = 0
            ip_data["first_attempt"] = current_time
        attempts = int(ip_data.get("attempts", 0) or 0) + 1
        ip_data["attempts"] = attempts
        if attempts >= max_attempts_per_hour:
            ip_data["lock_until"] = current_time + lock_duration_seconds
            return True
        return False
def cleanup_expired_ip_rate_limits(now_ts: Optional[float] = None) -> int:
    """Purge IP entries whose lock has expired and whose window is older than one hour.

    Missing ``lock_until``/``first_attempt`` values are treated as 0.
    Returns the number of entries removed.
    """
    reference = float(time.time() if now_ts is None else now_ts)
    with _ip_rate_limit_lock:
        stale_ips = []
        for ip, record in _ip_rate_limit.items():
            locked_until = float(record.get("lock_until", 0) or 0)
            window_start = float(record.get("first_attempt", 0) or 0)
            lock_over = locked_until < reference
            window_over = (reference - window_start) > 3600
            if lock_over and window_over:
                stale_ips.append(ip)
        for ip in stale_ips:
            _ip_rate_limit.pop(ip, None)
        return len(stale_ips)
def safe_get_ip_lock_until(ip_address: str) -> float:
    """Return the stored lock-until timestamp of *ip_address*.

    Returns 0.0 when no lock is recorded or the stored value is not numeric. An
    expired lock still returns its (past) timestamp.
    """
    key = str(ip_address)
    with _ip_rate_limit_lock:
        record = _ip_rate_limit.get(key) or {}
        raw_value = record.get("lock_until", 0)
    try:
        return float(raw_value or 0)
    except Exception:
        return 0.0
# ==================== Batch screenshots (collects screenshots of batch tasks) ====================
# str(batch_id) -> batch info (fields used here include "screenshots", "completed",
# "total_accounts", "updated_at", "created_at"); guarded by _batch_task_lock.
_batch_task_screenshots: Dict[str, Dict[str, Any]] = {}
_batch_task_lock = threading.RLock()
def safe_create_batch(batch_id: str, batch_info: Dict[str, Any]) -> None:
    """Create (or overwrite) batch *batch_id* with a shallow copy of *batch_info*."""
    key = str(batch_id)
    record = dict(batch_info or {})
    with _batch_task_lock:
        _batch_task_screenshots[key] = record
def safe_get_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Return a shallow copy of batch *batch_id*, or ``None`` if it does not exist.

    Existence is tested with ``is None``, so a batch created with empty info is
    returned as ``{}`` instead of being reported as missing. This matches the
    membership test used by ``safe_update_batch``.
    """
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        return None if info is None else dict(info)
def safe_update_batch(batch_id: str, updates: Dict[str, Any]) -> bool:
    """Merge *updates* into an existing batch.

    Returns ``False`` (changing nothing) if the batch does not exist.
    """
    key = str(batch_id)
    with _batch_task_lock:
        if key in _batch_task_screenshots:
            _batch_task_screenshots[key].update(dict(updates or {}))
            return True
        return False
def safe_pop_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return batch *batch_id* (``None`` if absent)."""
    key = str(batch_id)
    with _batch_task_lock:
        popped = _batch_task_screenshots.pop(key, None)
    return popped
def safe_batch_append_result(batch_id: str, result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Record one account result for batch *batch_id*.

    Appends a shallow copy of *result* to ``screenshots``, increments ``completed``
    and stamps ``updated_at``.

    Returns:
        The batch, removed from the store, once ``total_accounts`` is positive and
        ``completed`` has reached it. Otherwise ``None``, including when the batch
        is unknown.
    """
    now_ts = time.time()
    key = str(batch_id)
    with _batch_task_lock:
        info = _batch_task_screenshots.get(key)
        # `is None` rather than a falsy test: a batch created with empty info must
        # still collect results instead of silently dropping them.
        if info is None:
            return None
        info.setdefault("screenshots", []).append(dict(result or {}))
        completed = int(info.get("completed", 0) or 0) + 1
        info["completed"] = completed
        info["updated_at"] = now_ts
        total = int(info.get("total_accounts", 0) or 0)
        if total > 0 and completed >= total:
            return _batch_task_screenshots.pop(key, None)
        return None
def safe_cleanup_expired_batches(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Remove batches idle for more than *expire_seconds* (at least 1).

    Idle time is measured from the first truthy value among ``updated_at``,
    ``created_at`` and ``created_time``. A batch with none of them never expires.
    Returns the number of batches removed.
    """
    reference = float(time.time() if now_ts is None else now_ts)
    ttl = max(1, int(expire_seconds))

    def _last_activity(info: Dict[str, Any]) -> float:
        # Falls back to `reference`, giving an idle time of 0.
        stamp = info.get("updated_at") or info.get("created_at") or info.get("created_time") or reference
        return float(stamp)

    with _batch_task_lock:
        stale_ids = [
            batch_id
            for batch_id, info in _batch_task_screenshots.items()
            if (reference - _last_activity(info)) > ttl
        ]
        for batch_id in stale_ids:
            _batch_task_screenshots.pop(batch_id, None)
        return len(stale_ids)
def safe_finalize_batch_after_dispatch(batch_id: str, total_accounts: int, *, now_ts: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """For scheduled batches: record the final account count after dispatch.

    Sets ``total_accounts`` and ``updated_at``, then handles three cases:
    - a count that is not positive discards the batch and returns ``None``;
    - if ``completed`` has already reached the count, the batch is removed and
      returned (for sending the summary email);
    - otherwise the batch stays and ``None`` is returned.

    Unknown batches return ``None``. A batch created with empty info still counts
    as existing, because the test is ``is None``.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    key = str(batch_id)
    with _batch_task_lock:
        info = _batch_task_screenshots.get(key)
        if info is None:
            return None
        total = int(total_accounts or 0)
        info["total_accounts"] = total
        info["updated_at"] = now_ts
        if total <= 0:
            _batch_task_screenshots.pop(key, None)
            return None
        if int(info.get("completed", 0) or 0) >= total:
            return _batch_task_screenshots.pop(key, None)
        return None
# ==================== Pending random schedules (kept for the legacy random-delay logic) ====================
# int(schedule_id) -> schedule info; guarded by _pending_random_lock.
_pending_random_schedules: Dict[int, Dict[str, Any]] = {}
_pending_random_lock = threading.RLock()
def safe_set_pending_random_schedule(schedule_id: int, info: Dict[str, Any]) -> None:
    """Store a shallow copy of *info* for *schedule_id* (key normalized with ``int``)."""
    key = int(schedule_id)
    payload = dict(info or {})
    with _pending_random_lock:
        _pending_random_schedules[key] = payload
def safe_get_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Return a shallow copy of pending schedule *schedule_id*, or ``None`` if absent.

    Uses an ``is None`` test, so a schedule stored with empty info is returned as
    ``{}`` rather than being reported as missing.
    """
    with _pending_random_lock:
        value = _pending_random_schedules.get(int(schedule_id))
        return None if value is None else dict(value)
def safe_pop_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Remove and return pending schedule *schedule_id* (``None`` if absent)."""
    key = int(schedule_id)
    with _pending_random_lock:
        popped = _pending_random_schedules.pop(key, None)
    return popped
def safe_iter_pending_random_schedules_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return ``(schedule_id, info_copy)`` pairs that can be iterated without the lock."""
    with _pending_random_lock:
        snapshot = [(sched_id, dict(payload)) for sched_id, payload in _pending_random_schedules.items()]
    return snapshot
def safe_cleanup_expired_pending_random(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Remove pending schedules created more than *expire_seconds* (at least 1) ago.

    Age is measured from ``created_at``, falling back to ``created_time``. Entries
    with neither never expire. Returns the number of entries removed.
    """
    reference = float(time.time() if now_ts is None else now_ts)
    ttl = max(1, int(expire_seconds))
    with _pending_random_lock:
        stale_ids = []
        for sched_id, payload in _pending_random_schedules.items():
            born_at = payload.get("created_at") or payload.get("created_time") or reference
            if (reference - float(born_at)) > ttl:
                stale_ids.append(sched_id)
        for sched_id in stale_ids:
            _pending_random_schedules.pop(int(sched_id), None)
        return len(stale_ids)