Files
zsglpt/services/state.py

831 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
线程安全的全局状态管理（P0 / O-01）
约束:
- 业务代码禁止直接读写底层 dict，必须通过本模块 safe_* API 访问
- 读:要么持锁并返回副本,要么以“快照”的方式返回可迭代列表
"""
from __future__ import annotations
import threading
import time
import random
from typing import Any, Dict, List, Optional, Tuple
from app_config import get_config
# Application configuration object. The functions below read their default
# limits from its attributes (e.g. MAX_LOGS_PER_USER, MAX_CAPTCHA_ATTEMPTS,
# IP_LOCK_DURATION) whenever callers do not pass explicit values.
config = get_config()
# ==================== Active tasks (running task handles) ====================
# Maps account_id -> opaque task handle; every access goes through the lock.
_active_tasks: Dict[str, Any] = {}
_active_tasks_lock = threading.RLock()


def safe_set_task(account_id: str, handle: Any) -> None:
    """Register ``handle`` for ``account_id``, replacing any previous handle."""
    with _active_tasks_lock:
        _active_tasks[account_id] = handle


def safe_get_task(account_id: str) -> Any:
    """Return the handle registered for ``account_id``, or ``None``."""
    with _active_tasks_lock:
        handle = _active_tasks.get(account_id)
    return handle


def safe_remove_task(account_id: str) -> Any:
    """Unregister ``account_id`` and return its handle (``None`` if absent)."""
    with _active_tasks_lock:
        removed = _active_tasks.pop(account_id, None)
    return removed


def safe_get_active_task_ids() -> List[str]:
    """Return a snapshot list of the account ids that currently have a handle."""
    with _active_tasks_lock:
        task_ids = [task_id for task_id in _active_tasks]
    return task_ids
# ==================== Task status (state shown to the frontend) ====================
# Maps account_id -> status dict. Readers always receive shallow copies.
_task_status: Dict[str, Dict[str, Any]] = {}
_task_status_lock = threading.RLock()


def safe_set_task_status(account_id: str, status_dict: Dict[str, Any]) -> None:
    """Replace the status of ``account_id`` with a shallow copy of ``status_dict``.

    ``None`` is accepted and stored as an empty dict.
    """
    with _task_status_lock:
        _task_status[account_id] = dict(status_dict or {})


def safe_update_task_status(account_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into the existing status of ``account_id``.

    Returns:
        ``True`` if an entry existed and was updated; ``False`` if there is no
        entry (none is created in that case).
    """
    with _task_status_lock:
        if account_id not in _task_status:
            return False
        _task_status[account_id].update(updates or {})
        return True


def safe_get_task_status(account_id: str) -> Dict[str, Any]:
    """Return a shallow copy of the status of ``account_id`` (``{}`` if absent/empty)."""
    with _task_status_lock:
        value = _task_status.get(account_id)
        return dict(value) if value else {}


def safe_remove_task_status(account_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the status of ``account_id``.

    Returns:
        The removed status dict, or ``None`` when no entry exists.
    """
    with _task_status_lock:
        return _task_status.pop(account_id, None)


def safe_get_all_task_status() -> Dict[str, Dict[str, Any]]:
    """Return ``{account_id: shallow copy of status}`` for every entry."""
    with _task_status_lock:
        return {k: dict(v) for k, v in _task_status.items()}


def safe_iter_task_status_items() -> List[Tuple[str, Dict[str, Any]]]:
    """Return a snapshot list of ``(account_id, shallow copy of status)`` pairs."""
    with _task_status_lock:
        return [(k, dict(v)) for k, v in _task_status.items()]
# ==================== User accounts cache (per-user account objects) ====================
# _user_accounts: user_id -> {account_id: account object}.
# _user_accounts_last_access: user_id -> time.time() of the last touch or write.
# Both maps share one re-entrant lock.
_user_accounts: Dict[int, Dict[str, Any]] = {}
_user_accounts_last_access: Dict[int, float] = {}
_user_accounts_lock = threading.RLock()


def safe_touch_user_accounts(user_id: int) -> None:
    """Mark the account cache of ``user_id`` as accessed now."""
    stamp = time.time()
    with _user_accounts_lock:
        _user_accounts_last_access[int(user_id)] = stamp


def safe_get_user_accounts_last_access_items() -> List[Tuple[int, float]]:
    """Return a snapshot list of ``(user_id, last_access_ts)`` pairs."""
    with _user_accounts_lock:
        pairs = list(_user_accounts_last_access.items())
    return pairs


def safe_get_user_accounts_snapshot(user_id: int) -> Dict[str, Any]:
    """Return a shallow copy of ``{account_id: account}`` for ``user_id`` (``{}`` if unknown)."""
    uid = int(user_id)
    with _user_accounts_lock:
        snapshot = dict(_user_accounts.get(uid, {}))
    return snapshot


def safe_set_user_accounts(user_id: int, accounts_by_id: Dict[str, Any]) -> None:
    """Replace the whole account map of ``user_id`` (shallow copy) and refresh its access time."""
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts[uid] = dict(accounts_by_id or {})
        _user_accounts_last_access[uid] = time.time()


def safe_get_account(user_id: int, account_id: str) -> Any:
    """Return one cached account object, or ``None`` if the user or account is unknown."""
    with _user_accounts_lock:
        per_user = _user_accounts.get(int(user_id)) or {}
        return per_user.get(account_id)


def safe_set_account(user_id: int, account_id: str, account_obj: Any) -> None:
    """Store ``account_obj`` under ``account_id``, creating the user's map if needed."""
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts.setdefault(uid, {})[account_id] = account_obj
        _user_accounts_last_access[uid] = time.time()


def safe_remove_account(user_id: int, account_id: str) -> Any:
    """Remove and return one account object (``None`` if the user or account is unknown).

    The user's last-access time is left untouched.
    """
    uid = int(user_id)
    with _user_accounts_lock:
        per_user = _user_accounts.get(uid)
        if per_user is None:
            return None
        return per_user.pop(account_id, None)


def safe_remove_user_accounts(user_id: int) -> None:
    """Forget ``user_id`` entirely: its account map and its last-access time."""
    uid = int(user_id)
    with _user_accounts_lock:
        for store in (_user_accounts, _user_accounts_last_access):
            store.pop(uid, None)


def safe_iter_user_accounts_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(user_id, shallow copy of account map)`` pairs."""
    with _user_accounts_lock:
        return [(owner, dict(per_user)) for owner, per_user in _user_accounts.items()]


def safe_has_user(user_id: int) -> bool:
    """Return whether an account map is cached for ``user_id``."""
    uid = int(user_id)
    with _user_accounts_lock:
        return uid in _user_accounts
# ==================== Log cache (per-user log buffer) ====================
# _log_cache maps user_id -> list of log entries, oldest first.
# _log_cache_total_count tracks the number of entries across all users so the
# global cap can be enforced without summing every list.
_log_cache: Dict[int, List[Dict[str, Any]]] = {}
_log_cache_lock = threading.RLock()
_log_cache_total_count = 0


def safe_add_log(
    user_id: int,
    log_entry: Dict[str, Any],
    *,
    max_logs_per_user: Optional[int] = None,
    max_total_logs: Optional[int] = None,
) -> None:
    """Append a shallow copy of ``log_entry`` to the log buffer of ``user_id``.

    Two caps are enforced, always evicting the oldest entries first:

    * per user: before appending, the user's buffer is trimmed below
      ``max_logs_per_user`` (the new entry itself is always appended);
    * globally: while the total exceeds ``max_total_logs``, the oldest entry
      of the user holding the most entries is dropped.

    Args:
        user_id: Owner of the entry (coerced with ``int``).
        log_entry: Entry to store; ``None`` is stored as ``{}``.
        max_logs_per_user: Per-user cap; falsy falls back to
            ``config.MAX_LOGS_PER_USER``.
        max_total_logs: Global cap; falsy falls back to
            ``config.MAX_TOTAL_LOGS``.
    """
    global _log_cache_total_count
    uid = int(user_id)
    max_logs_per_user = int(max_logs_per_user or config.MAX_LOGS_PER_USER)
    max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
    with _log_cache_lock:
        logs = _log_cache.setdefault(uid, [])
        # Trim in a loop rather than with a single pop: the limit is a per-call
        # argument, so a buffer filled under a larger limit may need several
        # evictions. The ``logs`` guard keeps a non-positive limit from calling
        # pop() on an empty list.
        while logs and len(logs) >= max_logs_per_user:
            logs.pop(0)
            _log_cache_total_count = max(0, _log_cache_total_count - 1)
        logs.append(dict(log_entry or {}))
        _log_cache_total_count += 1
        while _log_cache_total_count > max_total_logs and _log_cache:
            largest_uid = max(_log_cache, key=lambda u: len(_log_cache[u]))
            largest = _log_cache[largest_uid]
            if not largest:
                # Every buffer is empty; nothing left to evict.
                break
            largest.pop(0)
            _log_cache_total_count -= 1


def safe_get_user_logs(user_id: int) -> List[Dict[str, Any]]:
    """Return a new list of the cached entries of ``user_id``, oldest first.

    The list is a copy; the entry dicts themselves are shared with the cache.
    """
    uid = int(user_id)
    with _log_cache_lock:
        return list(_log_cache.get(uid, []))


def safe_clear_user_logs(user_id: int) -> None:
    """Drop every cached entry of ``user_id`` and adjust the global counter."""
    global _log_cache_total_count
    uid = int(user_id)
    with _log_cache_lock:
        removed = len(_log_cache.get(uid, []))
        _log_cache.pop(uid, None)
        _log_cache_total_count = max(0, _log_cache_total_count - removed)


def safe_get_log_cache_total_count() -> int:
    """Return the number of cached log entries across all users."""
    with _log_cache_lock:
        return int(_log_cache_total_count)
# ==================== Captcha storage ====================
# Maps str(session_id) -> captcha record. Keys read here: "code",
# "expire_time" (epoch seconds) and "failed_attempts".
_captcha_storage: Dict[str, Dict[str, Any]] = {}
_captcha_storage_lock = threading.RLock()


def safe_set_captcha(session_id: str, captcha_data: Dict[str, Any]) -> None:
    """Store a shallow copy of ``captcha_data`` under ``str(session_id)``."""
    with _captcha_storage_lock:
        _captcha_storage[str(session_id)] = dict(captcha_data or {})


def safe_cleanup_expired_captcha(now_ts: Optional[float] = None) -> int:
    """Delete records whose ``expire_time`` is before ``now_ts`` (default: now).

    Records without an ``expire_time`` count as expired. Returns how many were removed.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _captcha_storage_lock:
        expired = [k for k, v in _captcha_storage.items() if float(v.get("expire_time", 0) or 0) < now_ts]
        for k in expired:
            _captcha_storage.pop(k, None)
        return len(expired)


def safe_delete_captcha(session_id: str) -> None:
    """Remove the record of ``session_id`` if present."""
    with _captcha_storage_lock:
        _captcha_storage.pop(str(session_id), None)


def safe_verify_and_consume_captcha(session_id: str, code: str, *, max_attempts: Optional[int] = None) -> Tuple[bool, str]:
    """Check ``code`` against the stored captcha (case-insensitive) and consume it.

    The record is removed on success, on expiry, on too many failures and on
    malformed data. A wrong guess puts it back with an incremented
    ``failed_attempts`` until ``max_attempts`` (default
    ``config.MAX_CAPTCHA_ATTEMPTS``) is reached.

    Returns:
        ``(ok, message)`` where ``message`` is the user-facing text.
    """
    max_attempts = int(max_attempts or config.MAX_CAPTCHA_ATTEMPTS)
    key = str(session_id)
    with _captcha_storage_lock:
        # Pop first so two concurrent verifiers can never both consume one code.
        captcha_data = _captcha_storage.pop(key, None)
        if captcha_data is None:
            return False, "验证码已过期或不存在,请重新获取"
        try:
            if float(captcha_data.get("expire_time", 0) or 0) < time.time():
                return False, "验证码已过期,请重新获取"
            failed_attempts = int(captcha_data.get("failed_attempts", 0) or 0)
            if failed_attempts >= max_attempts:
                return False, f"验证码错误次数过多({max_attempts}次),请重新获取"
            expected = str(captcha_data.get("code", "") or "").lower()
            if not expected:
                # A record without a code must never verify: otherwise an empty
                # guess would match it. Treat it as corrupt and leave it consumed.
                return False, "验证码验证失败,请重新获取"
            actual = str(code or "").lower()
            if expected != actual:
                failed_attempts += 1
                captcha_data["failed_attempts"] = failed_attempts
                if failed_attempts < max_attempts:
                    _captcha_storage[key] = captcha_data
                return False, "验证码错误"
            return True, "验证成功"
        except Exception:
            # Malformed record (e.g. non-numeric expire_time); it stays consumed.
            return False, "验证码验证失败,请重新获取"
# ==================== IP rate limit (captcha-failure throttling) ====================
# Maps str(ip) -> {"attempts": int, "first_attempt": ts, optional "lock_until": ts}.
# Failures are counted in a window of _IP_FAILURE_WINDOW_SECONDS that starts at
# "first_attempt"; reaching the per-hour threshold sets "lock_until".
_ip_rate_limit: Dict[str, Dict[str, Any]] = {}
_ip_rate_limit_lock = threading.RLock()
# Length of the failure-counting window ("per hour").
_IP_FAILURE_WINDOW_SECONDS = 3600


def _ip_window_expired(data: Dict[str, Any], now_ts: float) -> bool:
    """Return True when the failure window of ``data`` has elapsed (or never started)."""
    first_attempt = float(data.get("first_attempt", 0) or 0)
    return (now_ts - first_attempt) > _IP_FAILURE_WINDOW_SECONDS


def _ip_record_expired(data: Dict[str, Any], now_ts: float) -> bool:
    """Return True when ``data`` has neither an active lock nor a live failure window."""
    lock_until = float(data.get("lock_until", 0) or 0)
    return lock_until < now_ts and _ip_window_expired(data, now_ts)


def check_ip_rate_limit(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Return whether ``ip_address`` may attempt a captcha right now.

    Expired records are purged first. A locked IP gets ``(False, message)``;
    otherwise ``(True, None)`` is returned, and a record whose window has
    elapsed is reset to zero attempts.

    ``max_attempts_per_hour`` and ``lock_duration_seconds`` are accepted for
    signature symmetry with :func:`record_failed_captcha` but are not used:
    locking is decided when failures are recorded.
    """
    current_time = time.time()
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        cleanup_expired_ip_rate_limits(current_time)
        ip_data = _ip_rate_limit.get(ip_key)
        if ip_data is None:
            return True, None
        lock_until = float(ip_data.get("lock_until", 0) or 0)
        if lock_until > current_time:
            remaining_time = int(lock_until - current_time)
            return False, f"IP已被锁定,请{remaining_time // 60 + 1}分钟后再试"
        if _ip_window_expired(ip_data, current_time):
            _ip_rate_limit[ip_key] = {"attempts": 0, "first_attempt": current_time}
    return True, None


def record_failed_captcha(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> bool:
    """Count one captcha failure for ``ip_address``.

    A record whose lock and window have both expired is restarted, so stale
    failures never count towards a new lock. Each failure at or above
    ``max_attempts_per_hour`` (default ``config.MAX_IP_ATTEMPTS_PER_HOUR``)
    sets ``lock_until = now + lock_duration_seconds`` (default
    ``config.IP_LOCK_DURATION``).

    Returns:
        True if this failure locked the IP.
    """
    current_time = time.time()
    max_attempts_per_hour = int(max_attempts_per_hour or config.MAX_IP_ATTEMPTS_PER_HOUR)
    lock_duration_seconds = int(lock_duration_seconds or config.IP_LOCK_DURATION)
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        ip_data = _ip_rate_limit.get(ip_key)
        if ip_data is None or _ip_record_expired(ip_data, current_time):
            ip_data = {"attempts": 0, "first_attempt": current_time}
            _ip_rate_limit[ip_key] = ip_data
        ip_data["attempts"] = int(ip_data.get("attempts", 0) or 0) + 1
        if ip_data["attempts"] >= max_attempts_per_hour:
            ip_data["lock_until"] = current_time + lock_duration_seconds
            return True
    return False


def cleanup_expired_ip_rate_limits(now_ts: Optional[float] = None) -> int:
    """Remove records that are neither locked nor inside their failure window.

    Returns how many records were removed.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _ip_rate_limit_lock:
        expired_ips = [ip for ip, data in _ip_rate_limit.items() if _ip_record_expired(data, now_ts)]
        for ip in expired_ips:
            _ip_rate_limit.pop(ip, None)
        return len(expired_ips)


def safe_get_ip_lock_until(ip_address: str) -> float:
    """Return the stored ``lock_until`` timestamp of the IP.

    Returns 0.0 when the IP was never locked or the value is unreadable; the
    returned timestamp may already lie in the past.
    """
    ip_key = str(ip_address)
    with _ip_rate_limit_lock:
        data = _ip_rate_limit.get(ip_key) or {}
        try:
            return float(data.get("lock_until", 0) or 0)
        except (TypeError, ValueError, OverflowError):
            return 0.0
# ==================== IP request rate limit (per-endpoint request throttling) ====================
# Maps "<action>:<ip>" -> {"window_start": ts, "count": n} (fixed windows).
_ip_request_rate: Dict[str, Dict[str, Any]] = {}
_ip_request_rate_lock = threading.RLock()


def _get_action_rate_limit(action: str) -> Tuple[int, int]:
    """Return ``(max_requests, window_seconds)`` from config for ``action``.

    ``"register"`` and ``"email"`` (case-insensitive) have dedicated settings;
    every other action uses the login settings.
    """
    action = str(action or "").lower()
    if action == "register":
        return int(config.IP_RATE_LIMIT_REGISTER_MAX), int(config.IP_RATE_LIMIT_REGISTER_WINDOW_SECONDS)
    if action == "email":
        return int(config.IP_RATE_LIMIT_EMAIL_MAX), int(config.IP_RATE_LIMIT_EMAIL_WINDOW_SECONDS)
    return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)


def check_ip_request_rate(
    ip_address: str,
    action: str,
    *,
    max_requests: Optional[int] = None,
    window_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Count one request for ``(action, ip_address)`` and report whether it is allowed.

    Uses a fixed window that restarts once ``window_seconds`` have passed since
    it began. Explicit ``max_requests`` / ``window_seconds`` override the
    per-action config (falsy values fall back to config). Denied requests are
    not counted.

    Returns:
        ``(True, None)`` when allowed, else ``(False, message)`` with a wait hint.
    """
    now_ts = time.time()
    if max_requests and window_seconds:
        max_requests = int(max_requests)
        window_seconds = int(window_seconds)
    else:
        # Only consult config when the caller did not supply both limits.
        default_max, default_window = _get_action_rate_limit(action)
        max_requests = int(max_requests or default_max)
        window_seconds = int(window_seconds or default_window)
    key = f"{action}:{ip_address}"
    with _ip_request_rate_lock:
        data = _ip_request_rate.get(key)
        if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
            data = {"window_start": now_ts, "count": 0}
            _ip_request_rate[key] = data
        if int(data.get("count", 0) or 0) >= max_requests:
            remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
            # Both branches need an explicit unit ("分钟" minutes / "秒" seconds).
            if remaining >= 60:
                wait_hint = f"{remaining // 60 + 1}分钟"
            else:
                wait_hint = f"{remaining}秒"
            return False, f"请求过于频繁,请{wait_hint}后再试"
        data["count"] = int(data.get("count", 0) or 0) + 1
        return True, None


def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
    """Remove buckets whose window has elapsed; return how many were removed.

    Buckets are judged against the configured window of their action, not any
    per-call override that was passed to :func:`check_ip_request_rate`.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    removed = 0
    with _ip_request_rate_lock:
        for key in list(_ip_request_rate.keys()):
            data = _ip_request_rate.get(key) or {}
            action = key.split(":", 1)[0]
            _, window_seconds = _get_action_rate_limit(action)
            window_start = float(data.get("window_start", 0) or 0)
            if now_ts - window_start >= window_seconds:
                _ip_request_rate.pop(key, None)
                removed += 1
    return removed
# ==================== Login risk control (captcha / rate limits / delay / lockout) ====================
# Failure counters keyed by "ip:<ip>", "user:<name>" or "ipuser:<ip>:<name>";
# each value holds "first_failed" and "count".
_login_failures: Dict[str, Dict[str, Any]] = {}
_login_failures_lock = threading.RLock()
# Fixed-window login request buckets ("window_start", "count"), same key scheme.
_login_rate_limits: Dict[str, Dict[str, Any]] = {}
_login_rate_limits_lock = threading.RLock()
# Per-IP username-scan tracking: "first_seen", "usernames" (set), "scan_until".
_login_scan_state: Dict[str, Dict[str, Any]] = {}
_login_scan_lock = threading.RLock()
# Per ip+username lockouts keyed by "ipuser:<ip>:<name>" ("lock_until").
_login_ip_user_locks: Dict[str, Dict[str, Any]] = {}
_login_ip_user_lock = threading.RLock()
# Per-user ("last_sent", "last_ip") state used to throttle login alerts.
_login_alert_state: Dict[int, Dict[str, Any]] = {}
_login_alert_lock = threading.RLock()
def _normalize_login_key(kind: str, ip_address: str, username: Optional[str] = None) -> str:
    """Build the storage key of a login-throttling bucket.

    ``kind`` is ``"ip"``, ``"user"``, or anything else (meaning ip+user).
    Usernames are stripped and lower-cased. The user and ip+user kinds yield
    ``""`` when no username is given.
    """
    ip_part = str(ip_address or "")
    name_part = str(username or "").strip().lower()
    if kind == "ip":
        return "ip:" + ip_part
    if not name_part:
        return ""
    if kind == "user":
        return "user:" + name_part
    return "ipuser:" + ip_part + ":" + name_part


def _int_config(*names: str) -> Tuple[int, ...]:
    """Read each named attribute of ``config`` in order and coerce it to ``int``."""
    return tuple(int(getattr(config, name)) for name in names)


def _get_login_captcha_config() -> Tuple[int, int]:
    """Return ``(failures before captcha is required, failure window seconds)``."""
    return _int_config("LOGIN_CAPTCHA_AFTER_FAILURES", "LOGIN_CAPTCHA_WINDOW_SECONDS")


def _get_login_rate_limit_config() -> Tuple[int, int, int, int]:
    """Return ``(ip max, username max, ip+username max, window seconds)`` for login attempts."""
    return _int_config(
        "LOGIN_IP_MAX_ATTEMPTS",
        "LOGIN_USERNAME_MAX_ATTEMPTS",
        "LOGIN_IP_USERNAME_MAX_ATTEMPTS",
        "LOGIN_RATE_LIMIT_WINDOW_SECONDS",
    )


def _get_login_lock_config() -> Tuple[int, int, int]:
    """Return ``(failures before lock, lock window seconds, lock duration seconds)``."""
    return _int_config(
        "LOGIN_ACCOUNT_LOCK_FAILURES",
        "LOGIN_ACCOUNT_LOCK_WINDOW_SECONDS",
        "LOGIN_ACCOUNT_LOCK_SECONDS",
    )


def _get_login_scan_config() -> Tuple[int, int, int]:
    """Return ``(distinct-username threshold, scan window seconds, cooldown seconds)``."""
    return _int_config(
        "LOGIN_SCAN_UNIQUE_USERNAME_THRESHOLD",
        "LOGIN_SCAN_WINDOW_SECONDS",
        "LOGIN_SCAN_COOLDOWN_SECONDS",
    )


def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_seconds: int) -> Dict[str, Any]:
    """Return ``data`` while its fixed window is open, else a fresh bucket.

    A bucket is ``{"window_start": ts, "count": n}``; it is stale once more
    than ``window_seconds`` have passed since ``window_start``. A fresh bucket
    is not stored here; callers persist it.
    """
    if not data:
        return {"window_start": now_ts, "count": 0}
    elapsed = now_ts - float(data.get("window_start", 0) or 0)
    if elapsed > window_seconds:
        return {"window_start": now_ts, "count": 0}
    return data
def record_login_username_attempt(ip_address: str, username: str) -> bool:
    """Track the distinct usernames tried from one IP and report a scan lockout.

    Usernames are stripped and lower-cased. Within LOGIN_SCAN_WINDOW_SECONDS of
    the first username recorded for the IP, reaching
    LOGIN_SCAN_UNIQUE_USERNAME_THRESHOLD distinct usernames flags the IP until
    ``now + LOGIN_SCAN_COOLDOWN_SECONDS`` (an existing flag is never shortened).

    Returns:
        True while the IP is flagged; False otherwise, including when the IP
        or the username is empty.
    """
    now_ts = time.time()
    threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
    ip_key = str(ip_address or "")
    user_key = str(username or "").strip().lower()
    if not (ip_key and user_key):
        return False
    with _login_scan_lock:
        state = _login_scan_state.get(ip_key)
        window_elapsed = (not state) or (now_ts - float(state.get("first_seen", 0) or 0)) > window_seconds
        if window_elapsed:
            state = {"first_seen": now_ts, "usernames": set(), "scan_until": 0}
            _login_scan_state[ip_key] = state
        distinct_names = state["usernames"]
        distinct_names.add(user_key)
        flagged_until = float(state.get("scan_until", 0) or 0)
        if len(distinct_names) >= threshold:
            flagged_until = max(flagged_until, now_ts + cooldown_seconds)
            state["scan_until"] = flagged_until
        return now_ts < flagged_until


def is_login_scan_locked(ip_address: str) -> bool:
    """Return whether ``ip_address`` is currently flagged for username scanning."""
    now_ts = time.time()
    ip_key = str(ip_address or "")
    with _login_scan_lock:
        state = _login_scan_state.get(ip_key)
        if not state:
            return False
        return now_ts < float(state.get("scan_until", 0) or 0)
def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optional[str]]:
    """Count one login attempt against the ip, ip+username and username buckets.

    Each bucket is a fixed window of LOGIN_RATE_LIMIT_WINDOW_SECONDS with the
    limits LOGIN_IP_MAX_ATTEMPTS, LOGIN_IP_USERNAME_MAX_ATTEMPTS and
    LOGIN_USERNAME_MAX_ATTEMPTS respectively. A limit <= 0, or an empty key
    (no username), disables that bucket. Buckets are checked in that order,
    and the first exhausted one short-circuits. Buckets checked before it have
    already been incremented.

    Returns:
        ``(True, None)`` when every bucket allows the attempt, else
        ``(False, message)`` with a wait hint.
    """
    now_ts = time.time()
    ip_max, user_max, ip_user_max, window_seconds = _get_login_rate_limit_config()
    ip_key = _normalize_login_key("ip", ip_address)
    user_key = _normalize_login_key("user", "", username)
    ip_user_key = _normalize_login_key("ipuser", ip_address, username)

    def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
        # Called with _login_rate_limits_lock held.
        if not key or max_requests <= 0:
            return True, None
        data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
        if int(data.get("count", 0) or 0) >= max_requests:
            remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
            # Both branches need an explicit unit ("分钟" minutes / "秒" seconds).
            wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
            return False, f"请求过于频繁,请{wait_hint}后再试"
        data["count"] = int(data.get("count", 0) or 0) + 1
        _login_rate_limits[key] = data
        return True, None

    with _login_rate_limits_lock:
        for key, limit in ((ip_key, ip_max), (ip_user_key, ip_user_max), (user_key, user_max)):
            allowed, msg = _check(key, limit)
            if not allowed:
                return False, msg
    return True, None
def _update_login_failure(key: str, now_ts: float, window_seconds: int) -> int:
    """Increment the failure counter stored under ``key`` and return the new count.

    A fresh window (``first_failed = now_ts``, count 0) is started when no
    entry exists or the existing one is older than ``window_seconds``. This
    helper does not lock; record_login_failure calls it while holding
    ``_login_failures_lock``.
    """
    data = _login_failures.get(key)
    if not data or (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
        data = {"first_failed": now_ts, "count": 0}
        _login_failures[key] = data
    data["count"] = int(data.get("count", 0) or 0) + 1
    return int(data["count"])


def record_login_failure(ip_address: str, username: Optional[str] = None) -> None:
    """Record a failed login for the ip, username and ip+username buckets.

    Counters use the captcha window (LOGIN_CAPTCHA_WINDOW_SECONDS) and are
    clamped at ``5 * LOGIN_CAPTCHA_AFTER_FAILURES``. Once the ip+username
    count reaches LOGIN_ACCOUNT_LOCK_FAILURES, that pair is locked for
    LOGIN_ACCOUNT_LOCK_SECONDS.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    lock_failures, lock_window, lock_seconds = _get_login_lock_config()
    ip_key = _normalize_login_key("ip", ip_address)
    user_key = _normalize_login_key("user", "", username or "")
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    failure_cap = max_failures * 5
    ip_user_count = 0
    with _login_failures_lock:
        for key in (ip_key, user_key, ip_user_key):
            if not key:
                # No usable username: the user / ip+user keys are "". Counting
                # them would pool anonymous failures from every IP into one
                # shared "" bucket, which the captcha and delay checks then read.
                continue
            count = _update_login_failure(key, now_ts, window_seconds)
            if key == ip_user_key:
                ip_user_count = count  # pre-clamp value decides the lock below
            if count > failure_cap:
                # Clamp so a long attack cannot grow the counter without bound.
                _login_failures[key]["count"] = failure_cap
    if ip_user_key:
        with _login_ip_user_lock:
            if ip_user_count >= lock_failures:
                _login_ip_user_locks[ip_user_key] = {
                    "lock_until": now_ts + lock_seconds,
                    "first_failed": now_ts - lock_window,
                }


def clear_login_failures(ip_address: str, username: Optional[str] = None) -> None:
    """Forget the failure counters and the ip+username lock for this login."""
    ip_key = _normalize_login_key("ip", ip_address)
    user_key = _normalize_login_key("user", "", username or "")
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    with _login_failures_lock:
        _login_failures.pop(ip_key, None)
        _login_failures.pop(user_key, None)
        _login_failures.pop(ip_user_key, None)
    with _login_ip_user_lock:
        _login_ip_user_locks.pop(ip_user_key, None)


def _get_login_failure_count(ip_address: str, username: Optional[str] = None) -> int:
    """Return the live ip+username failure count (0 without a username or when stale).

    A stale entry (older than the captcha window) is removed as a side effect.
    """
    now_ts = time.time()
    _, window_seconds = _get_login_captcha_config()
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    if not ip_user_key:
        return 0
    with _login_failures_lock:
        data = _login_failures.get(ip_user_key)
        if not data:
            return 0
        if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
            _login_failures.pop(ip_user_key, None)
            return 0
        return int(data.get("count", 0) or 0)


def check_login_captcha_required(ip_address: str, username: Optional[str] = None) -> bool:
    """Return whether the next login from this ip (and username) must solve a captcha.

    True when the ip or the ip+username bucket has at least
    LOGIN_CAPTCHA_AFTER_FAILURES failures inside the window, or when the IP is
    flagged for username scanning.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    ip_key = _normalize_login_key("ip", ip_address)
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    with _login_failures_lock:
        ip_data = _login_failures.get(ip_key)
        if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
            if int(ip_data.get("count", 0) or 0) >= max_failures:
                return True
        ip_user_data = _login_failures.get(ip_user_key) if ip_user_key else None
        if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
            if int(ip_user_data.get("count", 0) or 0) >= max_failures:
                return True
        if is_login_scan_locked(ip_address):
            return True
        return False
def check_login_ip_user_locked(ip_address: str, username: Optional[str]) -> Tuple[bool, int]:
    """Report whether the ip+username pair is currently locked out.

    Returns:
        ``(True, seconds_remaining)`` (at least 1) while locked, otherwise
        ``(False, 0)``. An expired lock entry is removed as a side effect.
    """
    now_ts = time.time()
    if not username:
        return False, 0
    lock_key = _normalize_login_key("ipuser", ip_address, username)
    with _login_ip_user_lock:
        entry = _login_ip_user_locks.get(lock_key)
        if not entry:
            return False, 0
        lock_until = float(entry.get("lock_until", 0) or 0)
        if now_ts < lock_until:
            return True, max(1, int(lock_until - now_ts))
        _login_ip_user_locks.pop(lock_key, None)
        return False, 0


def get_login_failure_delay_seconds(ip_address: str, username: Optional[str]) -> float:
    """Return the artificial delay (seconds) to apply after recent login failures.

    Grows as ``base * 1.6 ** (failures - 1)`` milliseconds, capped at
    LOGIN_FAIL_DELAY_MAX_MS, plus random jitter of up to
    ``max(50, 0.3 * base)`` ms. Returns 0.0 when there are no live failures.
    """
    failures = _get_login_failure_count(ip_address, username)
    if failures <= 0:
        return 0.0
    base_ms = max(0, int(config.LOGIN_FAIL_DELAY_BASE_MS))
    cap_ms = max(base_ms, int(config.LOGIN_FAIL_DELAY_MAX_MS))
    growth = 1.6 ** max(0, failures - 1)
    delay_ms = min(cap_ms, int(base_ms * growth))
    jitter_ms = random.randint(0, max(50, int(base_ms * 0.3)))
    return (delay_ms + jitter_ms) / 1000.0
def should_send_login_alert(user_id: int, ip_address: str) -> bool:
    """Decide whether to send a login alert for ``user_id`` and record the decision.

    An alert is due on the first login seen, when ``ip_address`` is non-empty
    and differs from the last alerted IP, or once
    LOGIN_ALERT_MIN_INTERVAL_SECONDS have passed since the last alert. When
    due, the user's state is updated to ``(now, ip_address)``.
    """
    now_ts = time.time()
    min_interval = int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS)
    uid = int(user_id)
    with _login_alert_lock:
        previous = _login_alert_state.get(uid)
        if previous:
            last_sent = float(previous.get("last_sent", 0) or 0)
            last_ip = str(previous.get("last_ip", "") or "")
            ip_changed = bool(ip_address) and ip_address != last_ip
            if not ip_changed and (now_ts - last_sent) < min_interval:
                return False
        _login_alert_state[uid] = {"last_sent": now_ts, "last_ip": ip_address}
        return True
# ==================== Per-email rate limit ====================
# Maps "<action>:<normalized email>" -> {"window_start": ts, "count": n}.
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
_email_rate_limit_lock = threading.RLock()


def check_email_rate_limit(
    email: str,
    action: str,
    *,
    max_requests: Optional[int] = None,
    window_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Count one request for ``(action, email)`` and report whether it is allowed.

    The email is stripped and lower-cased; an empty email is always allowed
    and never counted. Limits default to ``config.EMAIL_RATE_LIMIT_MAX`` /
    ``config.EMAIL_RATE_LIMIT_WINDOW_SECONDS`` and can be overridden per call
    (mirroring :func:`check_ip_request_rate`). Denied requests are not counted.

    Returns:
        ``(True, None)`` when allowed, else ``(False, message)`` with a wait hint.
    """
    now_ts = time.time()
    max_requests = int(max_requests or config.EMAIL_RATE_LIMIT_MAX)
    window_seconds = int(window_seconds or config.EMAIL_RATE_LIMIT_WINDOW_SECONDS)
    email_key = str(email or "").strip().lower()
    if not email_key:
        return True, None
    key = f"{action}:{email_key}"
    with _email_rate_limit_lock:
        data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
        if int(data.get("count", 0) or 0) >= max_requests:
            remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
            # Both branches need an explicit unit ("分钟" minutes / "秒" seconds).
            wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
            return False, f"请求过于频繁,请{wait_hint}后再试"
        data["count"] = int(data.get("count", 0) or 0) + 1
        _email_rate_limit[key] = data
        return True, None
# ==================== Batch screenshots (per-batch result collection) ====================
# Maps str(batch_id) -> batch info dict. Keys used here: "screenshots" (list of
# result dicts), "completed", "total_accounts", "updated_at", and
# "created_at" / "created_time" as expiry fallbacks.
_batch_task_screenshots: Dict[str, Dict[str, Any]] = {}
_batch_task_lock = threading.RLock()


def safe_create_batch(batch_id: str, batch_info: Dict[str, Any]) -> None:
    """Register (or replace) a batch with a shallow copy of ``batch_info``."""
    with _batch_task_lock:
        _batch_task_screenshots[str(batch_id)] = dict(batch_info or {})


def safe_get_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Return a copy of the batch info, or ``None`` if unknown (or empty).

    The ``screenshots`` list is copied as well, so callers never hold the live
    list that :func:`safe_batch_append_result` keeps appending to under the lock.
    """
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        snapshot = dict(info)
        screenshots = info.get("screenshots")
        if isinstance(screenshots, list):
            snapshot["screenshots"] = list(screenshots)
        return snapshot


def safe_update_batch(batch_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into an existing batch; return False if the batch is unknown."""
    with _batch_task_lock:
        if str(batch_id) not in _batch_task_screenshots:
            return False
        _batch_task_screenshots[str(batch_id)].update(dict(updates or {}))
        return True


def safe_pop_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the stored batch dict, or ``None`` if unknown."""
    with _batch_task_lock:
        return _batch_task_screenshots.pop(str(batch_id), None)


def safe_batch_append_result(batch_id: str, result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Append a result copy, bump ``completed`` and pop the batch once it is complete.

    Returns:
        The popped batch when ``completed`` reaches a positive
        ``total_accounts``; otherwise ``None`` (also for an unknown batch).
    """
    now_ts = time.time()
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        info.setdefault("screenshots", []).append(dict(result or {}))
        info["completed"] = int(info.get("completed", 0) or 0) + 1
        info["updated_at"] = now_ts
        total = int(info.get("total_accounts", 0) or 0)
        if total > 0 and int(info.get("completed", 0) or 0) >= total:
            return _batch_task_screenshots.pop(str(batch_id), None)
        return None


def safe_cleanup_expired_batches(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Drop batches idle for more than ``expire_seconds`` (minimum 1).

    Idle time is measured from ``updated_at``, else ``created_at``, else
    ``created_time``; a batch with none of them is never expired.
    Returns how many batches were removed.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    expire_seconds = max(1, int(expire_seconds))
    with _batch_task_lock:
        expired = []
        for batch_id, info in list(_batch_task_screenshots.items()):
            last_ts = info.get("updated_at") or info.get("created_at") or info.get("created_time") or now_ts
            if (now_ts - float(last_ts)) > expire_seconds:
                expired.append(batch_id)
        for batch_id in expired:
            _batch_task_screenshots.pop(batch_id, None)
        return len(expired)


def safe_finalize_batch_after_dispatch(batch_id: str, total_accounts: int, *, now_ts: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """Scheduled batches: record the total account count after dispatching.

    If ``completed`` already reaches the total, the batch is popped and
    returned so the caller can send the notification email. A total <= 0
    discards the batch and returns ``None``.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if not info:
            return None
        info["total_accounts"] = int(total_accounts or 0)
        info["updated_at"] = now_ts
        if int(total_accounts or 0) <= 0:
            _batch_task_screenshots.pop(str(batch_id), None)
            return None
        if int(info.get("completed", 0) or 0) >= int(total_accounts):
            return _batch_task_screenshots.pop(str(batch_id), None)
        return None
# ==================== Pending random schedules (legacy random-delay support) ====================
# Maps int(schedule_id) -> info dict; "created_at" / "created_time" drive expiry.
_pending_random_schedules: Dict[int, Dict[str, Any]] = {}
_pending_random_lock = threading.RLock()


def safe_set_pending_random_schedule(schedule_id: int, info: Dict[str, Any]) -> None:
    """Store a shallow copy of ``info`` for ``schedule_id`` (coerced with ``int``)."""
    with _pending_random_lock:
        _pending_random_schedules[int(schedule_id)] = dict(info or {})


def safe_get_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Return a shallow copy of the stored info, or ``None`` if absent (or empty)."""
    with _pending_random_lock:
        stored = _pending_random_schedules.get(int(schedule_id))
        if not stored:
            return None
        return dict(stored)


def safe_pop_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Remove and return the stored info dict, or ``None`` if absent."""
    with _pending_random_lock:
        popped = _pending_random_schedules.pop(int(schedule_id), None)
    return popped


def safe_iter_pending_random_schedules_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(schedule_id, shallow copy of info)`` pairs."""
    with _pending_random_lock:
        snapshot = []
        for sid, stored in _pending_random_schedules.items():
            snapshot.append((sid, dict(stored)))
    return snapshot


def safe_cleanup_expired_pending_random(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Drop schedules older than ``expire_seconds`` (minimum 1); return how many.

    Age is measured from ``created_at``, else ``created_time``; entries with
    neither are never expired.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    expire_seconds = max(1, int(expire_seconds))
    with _pending_random_lock:
        stale_ids = [
            sid
            for sid, stored in _pending_random_schedules.items()
            if (now_ts - float(stored.get("created_at") or stored.get("created_time") or now_ts)) > expire_seconds
        ]
        for sid in stale_ids:
            _pending_random_schedules.pop(int(sid), None)
        return len(stale_ids)