831 lines
30 KiB
Python
831 lines
30 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
线程安全的全局状态管理(P0 / O-01)
|
||
|
||
约束:
|
||
- 业务代码禁止直接读写底层 dict;必须通过本模块 safe_* API 访问
|
||
- 读:要么持锁并返回副本,要么以“快照”的方式返回可迭代列表
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import threading
|
||
import time
|
||
import random
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
from app_config import get_config
|
||
|
||
# Module-wide configuration object, resolved once at import time. The
# functions below read its attributes (limits, windows, thresholds) when called.
config = get_config()
|
||
|
||
|
||
# ==================== Active tasks (handles of running tasks) ====================

# account_id -> handle of the task currently running for that account.
_active_tasks: Dict[str, Any] = {}
# Reentrant lock guarding every access to ``_active_tasks``.
_active_tasks_lock = threading.RLock()
|
||
|
||
|
||
def safe_set_task(account_id: str, handle: Any) -> None:
    """Register ``handle`` as the running task of ``account_id``, replacing any previous one."""
    lock = _active_tasks_lock
    with lock:
        _active_tasks.update({account_id: handle})
|
||
|
||
|
||
def safe_get_task(account_id: str) -> Any:
    """Return the task handle stored for ``account_id``, or ``None`` when there is none."""
    with _active_tasks_lock:
        handle = _active_tasks.get(account_id, None)
    return handle
|
||
|
||
|
||
def safe_remove_task(account_id: str) -> Any:
    """Detach and return the task handle of ``account_id`` (``None`` if there was none)."""
    with _active_tasks_lock:
        removed = _active_tasks.pop(account_id, None)
    return removed
|
||
|
||
|
||
def safe_get_active_task_ids() -> List[str]:
    """Return a snapshot list of the account ids that currently have a task handle."""
    with _active_tasks_lock:
        ids = [account_id for account_id in _active_tasks]
    return ids
|
||
|
||
|
||
# ==================== Task status (state shown to the frontend) ====================

# account_id -> status dict; callers only ever receive copies of these dicts.
_task_status: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_task_status``.
_task_status_lock = threading.RLock()
|
||
|
||
|
||
def safe_set_task_status(account_id: str, status_dict: Dict[str, Any]) -> None:
    """Replace the status of ``account_id`` with a shallow copy of ``status_dict``.

    ``None`` or an empty mapping stores an empty dict.
    """
    snapshot = dict(status_dict) if status_dict else {}
    with _task_status_lock:
        _task_status[account_id] = snapshot
|
||
|
||
|
||
def safe_update_task_status(account_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into the existing status of ``account_id``.

    Returns ``False`` (and changes nothing) when no status is stored yet.
    """
    with _task_status_lock:
        exists = account_id in _task_status
        if exists:
            _task_status[account_id].update(updates or {})
        return exists
|
||
|
||
|
||
def safe_get_task_status(account_id: str) -> Dict[str, Any]:
    """Return a copy of the status of ``account_id``; an empty dict if missing or empty."""
    with _task_status_lock:
        current = _task_status.get(account_id)
        if not current:
            return {}
        return dict(current)
|
||
|
||
|
||
def safe_remove_task_status(account_id: str) -> Optional[Dict[str, Any]]:
    """Remove and return the status dict of ``account_id``.

    Returns ``None`` when no status was stored. The old ``Dict`` return
    annotation hid this case from type checkers. The returned dict is no longer
    referenced by the module, so it is handed out without copying.
    """
    with _task_status_lock:
        return _task_status.pop(account_id, None)
|
||
|
||
|
||
def safe_get_all_task_status() -> Dict[str, Dict[str, Any]]:
    """Return all statuses as a new outer dict whose inner status dicts are copies too."""
    with _task_status_lock:
        result: Dict[str, Dict[str, Any]] = {}
        for account_id, status in _task_status.items():
            result[account_id] = dict(status)
        return result
|
||
|
||
|
||
def safe_iter_task_status_items() -> List[Tuple[str, Dict[str, Any]]]:
    """Return a snapshot list of ``(account_id, status_copy)`` pairs, safe to iterate without the lock."""
    pairs: List[Tuple[str, Dict[str, Any]]] = []
    with _task_status_lock:
        for account_id, status in _task_status.items():
            pairs.append((account_id, dict(status)))
    return pairs
|
||
|
||
|
||
# ==================== User accounts cache (per-user account objects) ====================

# user_id -> {account_id: account object}.
_user_accounts: Dict[int, Dict[str, Any]] = {}
# user_id -> time.time() of the last touch/write of that user's cache entry.
_user_accounts_last_access: Dict[int, float] = {}
# A single reentrant lock guards both dicts above so they stay consistent.
_user_accounts_lock = threading.RLock()
|
||
|
||
|
||
def safe_touch_user_accounts(user_id: int) -> None:
    """Record the current time as the last access of ``user_id``'s cached accounts."""
    stamp = time.time()
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts_last_access[uid] = stamp
|
||
|
||
|
||
def safe_get_user_accounts_last_access_items() -> List[Tuple[int, float]]:
    """Return a snapshot list of ``(user_id, last_access_ts)`` pairs."""
    with _user_accounts_lock:
        return [(uid, stamp) for uid, stamp in _user_accounts_last_access.items()]
|
||
|
||
|
||
def safe_get_user_accounts_snapshot(user_id: int) -> Dict[str, Any]:
    """Return a shallow copy of ``{account_id: account}`` for ``user_id`` (empty if unknown)."""
    uid = int(user_id)
    with _user_accounts_lock:
        accounts = _user_accounts.get(uid, {})
        return dict(accounts)
|
||
|
||
|
||
def safe_set_user_accounts(user_id: int, accounts_by_id: Dict[str, Any]) -> None:
    """Replace ``user_id``'s cached accounts with a copy of ``accounts_by_id`` and mark them accessed."""
    uid = int(user_id)
    with _user_accounts_lock:
        _user_accounts[uid] = dict(accounts_by_id) if accounts_by_id else {}
        _user_accounts_last_access[uid] = time.time()
|
||
|
||
|
||
def safe_get_account(user_id: int, account_id: str) -> Any:
    """Return one cached account object, or ``None`` if the user or account is unknown."""
    uid = int(user_id)
    with _user_accounts_lock:
        accounts = _user_accounts.get(uid, {})
        return accounts.get(account_id)
|
||
|
||
|
||
def safe_set_account(user_id: int, account_id: str, account_obj: Any) -> None:
    """Store ``account_obj`` for ``user_id``/``account_id``, creating the user bucket, and mark it accessed."""
    with _user_accounts_lock:
        uid = int(user_id)
        _user_accounts.setdefault(uid, {})[account_id] = account_obj
        _user_accounts_last_access[uid] = time.time()
|
||
|
||
|
||
def safe_remove_account(user_id: int, account_id: str) -> Any:
    """Remove one cached account and return it (``None`` if the user or account is unknown).

    The user's last-access timestamp is left untouched.
    """
    with _user_accounts_lock:
        uid = int(user_id)
        if uid in _user_accounts:
            return _user_accounts[uid].pop(account_id, None)
        return None
|
||
|
||
|
||
def safe_remove_user_accounts(user_id: int) -> None:
    """Drop both the cached accounts and the last-access timestamp of ``user_id``."""
    with _user_accounts_lock:
        uid = int(user_id)
        for store in (_user_accounts, _user_accounts_last_access):
            store.pop(uid, None)
|
||
|
||
|
||
def safe_iter_user_accounts_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(user_id, accounts_copy)`` pairs."""
    items: List[Tuple[int, Dict[str, Any]]] = []
    with _user_accounts_lock:
        for uid, accounts in _user_accounts.items():
            items.append((uid, dict(accounts)))
    return items
|
||
|
||
|
||
def safe_has_user(user_id: int) -> bool:
    """Return whether an account bucket is cached for ``user_id``."""
    uid = int(user_id)
    with _user_accounts_lock:
        present = uid in _user_accounts
    return present
|
||
|
||
|
||
# ==================== Log cache (per-user in-memory log buffer) ====================

# user_id -> list of log entry dicts, oldest first.
_log_cache: Dict[int, List[Dict[str, Any]]] = {}
# Reentrant lock guarding ``_log_cache`` and ``_log_cache_total_count``.
_log_cache_lock = threading.RLock()
# Running total of entries across all users, maintained by the safe_* helpers.
_log_cache_total_count = 0
|
||
|
||
|
||
def safe_add_log(
    user_id: int,
    log_entry: Dict[str, Any],
    *,
    max_logs_per_user: Optional[int] = None,
    max_total_logs: Optional[int] = None,
) -> None:
    """Append a copy of ``log_entry`` to ``user_id``'s log buffer, enforcing size caps.

    Args:
        user_id: Owner of the entry (coerced with ``int``).
        log_entry: Entry to store; a shallow copy is kept (``None`` stores ``{}``).
        max_logs_per_user: Per-user cap; falsy values fall back to
            ``config.MAX_LOGS_PER_USER``.
        max_total_logs: Cap across all users; falsy values fall back to
            ``config.MAX_TOTAL_LOGS``.

    The oldest entries are evicted first. Entries are first trimmed from this
    user's buffer to honour the per-user cap, then from whichever user holds the
    most entries until the global cap is honoured.
    """
    global _log_cache_total_count

    uid = int(user_id)
    max_logs_per_user = int(max_logs_per_user or config.MAX_LOGS_PER_USER)
    max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)

    with _log_cache_lock:
        user_logs = _log_cache.setdefault(uid, [])

        # ``while`` instead of ``if``: the caps are per-call arguments, so a call
        # with a smaller cap than earlier calls must trim more than one entry.
        # The emptiness guard also avoids ``pop(0)`` on [] for a cap <= 0.
        while user_logs and len(user_logs) >= max_logs_per_user:
            user_logs.pop(0)
            _log_cache_total_count = max(0, _log_cache_total_count - 1)

        user_logs.append(dict(log_entry or {}))
        _log_cache_total_count += 1

        while _log_cache_total_count > max_total_logs and _log_cache:
            largest_uid = max(_log_cache, key=lambda u: len(_log_cache[u]))
            largest_logs = _log_cache[largest_uid]
            if not largest_logs:
                break
            largest_logs.pop(0)
            _log_cache_total_count -= 1
            # Drop emptied buckets so evicted users do not linger as empty keys.
            if not largest_logs:
                del _log_cache[largest_uid]
|
||
|
||
|
||
def safe_get_user_logs(user_id: int) -> List[Dict[str, Any]]:
    """Return a snapshot list of ``user_id``'s log entries, oldest first (empty if none)."""
    uid = int(user_id)
    with _log_cache_lock:
        snapshot = list(_log_cache.get(uid, ()))
    return snapshot
|
||
|
||
|
||
def safe_clear_user_logs(user_id: int) -> None:
    """Delete every cached log entry of ``user_id`` and adjust the global counter."""
    global _log_cache_total_count

    uid = int(user_id)
    with _log_cache_lock:
        dropped = _log_cache.pop(uid, None) or []
        _log_cache_total_count = max(0, _log_cache_total_count - len(dropped))
|
||
|
||
|
||
def safe_get_log_cache_total_count() -> int:
    """Return the number of log entries currently cached across all users."""
    with _log_cache_lock:
        total = int(_log_cache_total_count)
    return total
|
||
|
||
|
||
# ==================== Captcha storage ====================

# session_id -> captcha record; the helpers below read "code", "expire_time"
# and "failed_attempts" from it.
_captcha_storage: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_captcha_storage``.
_captcha_storage_lock = threading.RLock()
|
||
|
||
|
||
def safe_set_captcha(session_id: str, captcha_data: Dict[str, Any]) -> None:
    """Store a shallow copy of ``captcha_data`` for ``session_id``, replacing any previous one."""
    key = str(session_id)
    record = dict(captcha_data) if captcha_data else {}
    with _captcha_storage_lock:
        _captcha_storage[key] = record
|
||
|
||
|
||
def safe_cleanup_expired_captcha(now_ts: Optional[float] = None) -> int:
    """Delete captchas whose ``expire_time`` lies before ``now_ts`` (default: now).

    Records with a missing or falsy ``expire_time`` count as expired.
    Returns the number of records removed.
    """
    cutoff = float(time.time() if now_ts is None else now_ts)
    with _captcha_storage_lock:
        stale = [
            session_id
            for session_id, record in _captcha_storage.items()
            if float(record.get("expire_time", 0) or 0) < cutoff
        ]
        for session_id in stale:
            del _captcha_storage[session_id]
        return len(stale)
|
||
|
||
|
||
def safe_delete_captcha(session_id: str) -> None:
    """Forget the captcha of ``session_id``; a no-op when none is stored."""
    key = str(session_id)
    with _captcha_storage_lock:
        if key in _captcha_storage:
            del _captcha_storage[key]
|
||
|
||
|
||
def safe_verify_and_consume_captcha(session_id: str, code: str, *, max_attempts: Optional[int] = None) -> Tuple[bool, str]:
    """Check ``code`` against the stored captcha of ``session_id`` (case-insensitive).

    The record is popped up front, so two concurrent requests can never both
    succeed with the same captcha. It is put back only after a wrong guess that
    still leaves attempts, carrying the incremented ``failed_attempts``.

    Args:
        session_id: Key the captcha was stored under.
        code: User-supplied answer.
        max_attempts: Wrong guesses allowed; falsy values fall back to
            ``config.MAX_CAPTCHA_ATTEMPTS``.

    Returns:
        ``(ok, message)``, where ``message`` is a user-facing string.
    """
    max_attempts = int(max_attempts or config.MAX_CAPTCHA_ATTEMPTS)
    session_key = str(session_id)

    with _captcha_storage_lock:
        captcha_data = _captcha_storage.pop(session_key, None)
        if captcha_data is None:
            return False, "验证码已过期或不存在,请重新获取"

        try:
            if float(captcha_data.get("expire_time", 0) or 0) < time.time():
                return False, "验证码已过期,请重新获取"

            failed_attempts = int(captcha_data.get("failed_attempts", 0) or 0)
            if failed_attempts >= max_attempts:
                return False, f"验证码错误次数过多({max_attempts}次),请重新获取"

            expected = str(captcha_data.get("code", "") or "").lower()
            # A record without a code must never verify. Otherwise an empty
            # answer would equal the empty expected value and pass.
            if not expected:
                return False, "验证码验证失败,请重新获取"

            actual = str(code or "").lower()
            if expected != actual:
                failed_attempts += 1
                captcha_data["failed_attempts"] = failed_attempts
                if failed_attempts < max_attempts:
                    _captcha_storage[session_key] = captcha_data
                return False, "验证码错误"

            return True, "验证成功"
        except Exception:
            # Malformed record (e.g. non-numeric expire_time): fail closed. The
            # record has already been consumed by the pop above.
            return False, "验证码验证失败,请重新获取"
|
||
|
||
|
||
# ==================== IP rate limit (lockout after failed captchas) ====================

# ip -> {"attempts": int, "first_attempt": ts} plus an optional "lock_until": ts.
_ip_rate_limit: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_ip_rate_limit``.
_ip_rate_limit_lock = threading.RLock()
|
||
|
||
|
||
def check_ip_rate_limit(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Return whether ``ip_address`` may attempt a captcha right now.

    Also purges expired entries (via ``cleanup_expired_ip_rate_limits``). When
    this IP's previous one-hour window has lapsed, it starts a fresh window.

    ``max_attempts_per_hour`` and ``lock_duration_seconds`` are accepted only
    for signature compatibility. The lock itself is set by
    ``record_failed_captcha``, so this check never needs them. The previous
    code resolved them from config and then ignored the result.

    Returns:
        ``(True, None)`` when allowed, otherwise ``(False, message)`` with the
        remaining lock time in minutes.
    """
    current_time = time.time()
    ip_key = str(ip_address)

    with _ip_rate_limit_lock:
        # Same expiry rule the inline loop used to duplicate.
        cleanup_expired_ip_rate_limits(current_time)

        ip_data = _ip_rate_limit.get(ip_key)
        if ip_data is None:
            return True, None

        lock_until = float(ip_data.get("lock_until", 0) or 0)
        if lock_until > current_time:
            remaining_time = int(lock_until - current_time)
            return False, f"IP已被锁定,请{remaining_time // 60 + 1}分钟后再试"

        first_attempt = ip_data.get("first_attempt")
        if first_attempt is None or current_time - float(first_attempt) > 3600:
            _ip_rate_limit[ip_key] = {"attempts": 0, "first_attempt": current_time}

    return True, None
|
||
|
||
|
||
def record_failed_captcha(
    ip_address: str,
    *,
    max_attempts_per_hour: Optional[int] = None,
    lock_duration_seconds: Optional[int] = None,
) -> bool:
    """Count one failed captcha for ``ip_address``; lock the IP once the hourly limit is hit.

    Args:
        ip_address: Client IP (coerced with ``str``).
        max_attempts_per_hour: Failures allowed per one-hour window; falsy
            values fall back to ``config.MAX_IP_ATTEMPTS_PER_HOUR``.
        lock_duration_seconds: Lock length once the limit is reached; falsy
            values fall back to ``config.IP_LOCK_DURATION``.

    Returns:
        ``True`` if this failure (re)locked the IP, otherwise ``False``.
    """
    current_time = time.time()
    max_attempts_per_hour = int(max_attempts_per_hour or config.MAX_IP_ATTEMPTS_PER_HOUR)
    lock_duration_seconds = int(lock_duration_seconds or config.IP_LOCK_DURATION)
    ip_key = str(ip_address)

    with _ip_rate_limit_lock:
        ip_data = _ip_rate_limit.get(ip_key)
        first_attempt = ip_data.get("first_attempt") if ip_data else None
        if first_attempt is None or current_time - float(first_attempt) > 3600:
            # Start a fresh hourly window here rather than relying on
            # check_ip_rate_limit having run first. Otherwise failures from
            # earlier hours keep accumulating. A still-active lock is preserved.
            fresh: Dict[str, Any] = {"attempts": 0, "first_attempt": current_time}
            lock_until = float(ip_data.get("lock_until", 0) or 0) if ip_data else 0.0
            if lock_until > current_time:
                fresh["lock_until"] = lock_until
            ip_data = fresh
            _ip_rate_limit[ip_key] = ip_data

        ip_data["attempts"] = int(ip_data.get("attempts", 0) or 0) + 1
        if ip_data["attempts"] >= max_attempts_per_hour:
            ip_data["lock_until"] = current_time + lock_duration_seconds
            return True

    return False
|
||
|
||
|
||
def cleanup_expired_ip_rate_limits(now_ts: Optional[float] = None) -> int:
    """Purge IP entries whose lock has ended and whose one-hour window has lapsed.

    Returns the number of entries removed.
    """
    now = float(time.time() if now_ts is None else now_ts)

    def _is_stale(entry: Dict[str, Any]) -> bool:
        lock_over = float(entry.get("lock_until", 0) or 0) < now
        window_over = now - float(entry.get("first_attempt", 0) or 0) > 3600
        return lock_over and window_over

    with _ip_rate_limit_lock:
        stale_ips = [ip for ip, entry in _ip_rate_limit.items() if _is_stale(entry)]
        for ip in stale_ips:
            _ip_rate_limit.pop(ip, None)
        return len(stale_ips)
|
||
|
||
|
||
def safe_get_ip_lock_until(ip_address: str) -> float:
    """Return the stored lock-until timestamp of ``ip_address``.

    The value may lie in the past. Returns 0.0 when nothing is stored or the
    stored value cannot be converted to ``float``.
    """
    with _ip_rate_limit_lock:
        entry = _ip_rate_limit.get(str(ip_address))
        try:
            return float((entry or {}).get("lock_until", 0) or 0)
        except Exception:
            return 0.0
|
||
|
||
|
||
# ==================== IP request rate limit (per-endpoint request frequency) ====================

# "<action>:<ip>" -> fixed-window bucket with at least "window_start" and "count".
_ip_request_rate: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_ip_request_rate``.
_ip_request_rate_lock = threading.RLock()
|
||
|
||
|
||
def _get_action_rate_limit(action: str) -> Tuple[int, int]:
    """Return ``(max_requests, window_seconds)`` configured for ``action``.

    Recognised actions (case-insensitive) are ``"register"`` and ``"email"``;
    any other value uses the login limits. Only the two selected config
    attributes are read.
    """
    setting_names = {
        "register": ("IP_RATE_LIMIT_REGISTER_MAX", "IP_RATE_LIMIT_REGISTER_WINDOW_SECONDS"),
        "email": ("IP_RATE_LIMIT_EMAIL_MAX", "IP_RATE_LIMIT_EMAIL_WINDOW_SECONDS"),
    }
    max_name, window_name = setting_names.get(
        str(action or "").lower(),
        ("IP_RATE_LIMIT_LOGIN_MAX", "IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS"),
    )
    return int(getattr(config, max_name)), int(getattr(config, window_name))
|
||
|
||
|
||
def check_ip_request_rate(
    ip_address: str,
    action: str,
    *,
    max_requests: Optional[int] = None,
    window_seconds: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:
    """Apply a fixed-window rate limit per ``(action, ip_address)``; count the request when allowed.

    Args:
        ip_address: Client IP.
        action: Endpoint kind; defaults come from ``_get_action_rate_limit``.
        max_requests: Requests allowed per window; falsy -> action default.
        window_seconds: Window length in seconds; falsy -> action default.

    Returns:
        ``(True, None)`` if allowed, otherwise ``(False, message)`` with a wait hint.
    """
    now_ts = time.time()
    default_max, default_window = _get_action_rate_limit(action)
    max_requests = int(max_requests or default_max)
    window_seconds = int(window_seconds or default_window)

    key = f"{action}:{ip_address}"
    with _ip_request_rate_lock:
        data = _ip_request_rate.get(key)
        if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
            data = {"window_start": now_ts, "count": 0}
            _ip_request_rate[key] = data
        # Record the effective window so cleanup_expired_ip_request_rates below
        # expires this bucket on the same window the check used, even when the
        # caller overrode ``window_seconds``.
        data["window_seconds"] = window_seconds

        count = int(data.get("count", 0) or 0)
        if count >= max_requests:
            remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
            wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
            return False, f"请求过于频繁,请{wait_hint}后再试"

        data["count"] = count + 1
        return True, None


def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
    """Delete request-rate buckets whose window has elapsed; return how many were removed.

    A bucket's own ``"window_seconds"`` is used when present. Otherwise the
    window falls back to the action's configured default, looked up once per
    action.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    removed = 0
    default_windows: Dict[str, int] = {}
    with _ip_request_rate_lock:
        for key in list(_ip_request_rate):
            data = _ip_request_rate.get(key) or {}
            window_seconds = data.get("window_seconds")
            if not window_seconds:
                # Keys are "<action>:<ip>"; IPv6 colons stay in the ip part.
                action = key.split(":", 1)[0]
                if action not in default_windows:
                    default_windows[action] = _get_action_rate_limit(action)[1]
                window_seconds = default_windows[action]
            window_start = float(data.get("window_start", 0) or 0)
            if now_ts - window_start >= float(window_seconds):
                _ip_request_rate.pop(key, None)
                removed += 1
    return removed
|
||
|
||
|
||
# ==================== Login risk control (captcha / rate limit / delay / lockout) ====================

# Failure counters keyed by "ip:...", "user:..." or "ipuser:...:...".
_login_failures: Dict[str, Dict[str, Any]] = {}
_login_failures_lock = threading.RLock()

# Fixed-window login request buckets, using the same key scheme.
_login_rate_limits: Dict[str, Dict[str, Any]] = {}
_login_rate_limits_lock = threading.RLock()

# ip -> username-scan detection state ("first_seen", "usernames", "scan_until").
_login_scan_state: Dict[str, Dict[str, Any]] = {}
_login_scan_lock = threading.RLock()

# "ipuser:..." key -> temporary lock record holding "lock_until".
_login_ip_user_locks: Dict[str, Dict[str, Any]] = {}
_login_ip_user_lock = threading.RLock()

# user_id -> last login-alert state ("last_sent", "last_ip").
_login_alert_state: Dict[int, Dict[str, Any]] = {}
_login_alert_lock = threading.RLock()
|
||
|
||
|
||
def _normalize_login_key(kind: str, ip_address: str, username: Optional[str] = None) -> str:
    """Build the dict key used by the login risk-control tables.

    ``kind == "ip"`` gives ``"ip:<ip>"``, ``kind == "user"`` gives
    ``"user:<username>"``, and anything else gives ``"ipuser:<ip>:<username>"``.
    Usernames are stripped and lower-cased. The two username-based schemes
    return ``""`` when the username is empty.
    """
    ip_part = str(ip_address or "")
    user_part = str(username or "").strip().lower()
    if kind == "ip":
        return "ip:" + ip_part
    if not user_part:
        return ""
    if kind == "user":
        return "user:" + user_part
    return f"ipuser:{ip_part}:{user_part}"
|
||
|
||
|
||
def _get_login_captcha_config() -> Tuple[int, int]:
    """Return ``(failures_before_captcha, window_seconds)`` for the login captcha rule."""
    after_failures = int(config.LOGIN_CAPTCHA_AFTER_FAILURES)
    window = int(config.LOGIN_CAPTCHA_WINDOW_SECONDS)
    return after_failures, window
|
||
|
||
|
||
def _get_login_rate_limit_config() -> Tuple[int, int, int, int]:
    """Return ``(ip_max, username_max, ip_username_max, window_seconds)`` for login rate limiting."""
    names = (
        "LOGIN_IP_MAX_ATTEMPTS",
        "LOGIN_USERNAME_MAX_ATTEMPTS",
        "LOGIN_IP_USERNAME_MAX_ATTEMPTS",
        "LOGIN_RATE_LIMIT_WINDOW_SECONDS",
    )
    ip_max, user_max, ip_user_max, window = (int(getattr(config, name)) for name in names)
    return ip_max, user_max, ip_user_max, window
|
||
|
||
|
||
def _get_login_lock_config() -> Tuple[int, int, int]:
    """Return ``(failures_to_lock, lock_window_seconds, lock_seconds)`` for the IP+username lockout."""
    failures_to_lock = int(config.LOGIN_ACCOUNT_LOCK_FAILURES)
    lock_window = int(config.LOGIN_ACCOUNT_LOCK_WINDOW_SECONDS)
    lock_seconds = int(config.LOGIN_ACCOUNT_LOCK_SECONDS)
    return failures_to_lock, lock_window, lock_seconds
|
||
|
||
|
||
def _get_login_scan_config() -> Tuple[int, int, int]:
    """Return ``(unique_username_threshold, window_seconds, cooldown_seconds)`` for scan detection."""
    threshold = int(config.LOGIN_SCAN_UNIQUE_USERNAME_THRESHOLD)
    window = int(config.LOGIN_SCAN_WINDOW_SECONDS)
    cooldown = int(config.LOGIN_SCAN_COOLDOWN_SECONDS)
    return threshold, window, cooldown
|
||
|
||
|
||
def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_seconds: int) -> Dict[str, Any]:
    """Return ``data`` while its fixed window is open, otherwise a fresh zeroed bucket.

    A missing or empty bucket counts as expired. The fresh bucket is not stored
    anywhere; the caller has to save it.
    """
    if not data:
        return {"window_start": now_ts, "count": 0}
    elapsed = now_ts - float(data.get("window_start", 0) or 0)
    if elapsed > window_seconds:
        return {"window_start": now_ts, "count": 0}
    return data
|
||
|
||
|
||
def record_login_username_attempt(ip_address: str, username: str) -> bool:
    """Track distinct usernames tried from ``ip_address`` and report whether it is scan-locked.

    Once an IP tries ``threshold`` distinct usernames within ``window_seconds``,
    it is flagged for ``cooldown_seconds``. These values come from
    ``_get_login_scan_config``.

    Returns:
        ``True`` while the IP is inside its scan cooldown, otherwise ``False``.
        An empty IP or username is ignored and returns ``False``.
    """
    now_ts = time.time()
    threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
    ip_key = str(ip_address or "")
    user_key = str(username or "").strip().lower()
    if not ip_key or not user_key:
        return False

    with _login_scan_lock:
        data = _login_scan_state.get(ip_key)
        if not data or (now_ts - float(data.get("first_seen", 0) or 0)) > window_seconds:
            # Carry an unexpired cooldown into the new window. Resetting it to 0
            # let a scanner lift its own lock just by outlasting the detection
            # window whenever the cooldown is longer than that window.
            previous_until = float(data.get("scan_until", 0) or 0) if data else 0.0
            data = {
                "first_seen": now_ts,
                "usernames": set(),
                "scan_until": previous_until if previous_until > now_ts else 0,
            }
            _login_scan_state[ip_key] = data

        data["usernames"].add(user_key)
        if len(data["usernames"]) >= threshold:
            data["scan_until"] = max(float(data.get("scan_until", 0) or 0), now_ts + cooldown_seconds)

        return now_ts < float(data.get("scan_until", 0) or 0)
|
||
|
||
|
||
def is_login_scan_locked(ip_address: str) -> bool:
    """Return whether ``ip_address`` is currently inside a username-scan cooldown."""
    now_ts = time.time()
    ip_key = str(ip_address or "")
    with _login_scan_lock:
        state = _login_scan_state.get(ip_key)
        if not state:
            return False
        scan_until = float(state.get("scan_until", 0) or 0)
    return now_ts < scan_until
|
||
|
||
|
||
def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optional[str]]:
    """Apply the per-IP, per-IP+username and per-username login rate limits.

    Buckets are checked in that order, and the first exhausted one produces the
    rejection message. The request is counted against the applicable buckets
    only when all of them still have room. Previously, a request rejected by a
    later bucket had already consumed quota in the earlier ones. A bucket with
    an empty key (no username) or a limit <= 0 is skipped.

    Returns:
        ``(True, None)`` if allowed, otherwise ``(False, message)`` with a wait hint.
    """
    now_ts = time.time()
    ip_max, user_max, ip_user_max, window_seconds = _get_login_rate_limit_config()
    checks = (
        (_normalize_login_key("ip", ip_address), ip_max),
        (_normalize_login_key("ipuser", ip_address, username), ip_user_max),
        (_normalize_login_key("user", "", username), user_max),
    )

    with _login_rate_limits_lock:
        pending: List[Tuple[str, Dict[str, Any]]] = []
        for key, max_requests in checks:
            if not key or max_requests <= 0:
                continue
            data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
            if int(data.get("count", 0) or 0) >= max_requests:
                remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
                wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
                return False, f"请求过于频繁,请{wait_hint}后再试"
            pending.append((key, data))

        # Every bucket has room: commit the increments together.
        for key, data in pending:
            data["count"] = int(data.get("count", 0) or 0) + 1
            _login_rate_limits[key] = data

    return True, None
|
||
|
||
|
||
def _update_login_failure(key: str, now_ts: float, window_seconds: int) -> int:
    """Increment the failure counter stored under ``key`` and return the new count.

    The counter restarts when more than ``window_seconds`` have passed since its
    ``first_failed`` timestamp. An empty ``key`` is ignored and returns 0; that
    is what ``_normalize_login_key`` yields for user-based keys without a
    username. Previously, all such failures were pooled into one shared ``""``
    counter across every IP.

    This helper takes no lock itself; call it with ``_login_failures_lock`` held.
    """
    if not key:
        return 0
    data = _login_failures.get(key)
    if not data or (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
        data = {"first_failed": now_ts, "count": 0}
        _login_failures[key] = data
    data["count"] = int(data.get("count", 0) or 0) + 1
    return int(data["count"])
|
||
|
||
|
||
def record_login_failure(ip_address: str, username: Optional[str] = None) -> None:
    """Count one failed login against the IP, the username and the IP+username pair.

    All counters use the captcha window from ``_get_login_captcha_config`` and
    are then clamped to ``5 * failures_before_captcha``. Locking happens when a
    username is given and the IP+username count (taken before clamping)
    reaches the positive threshold from ``_get_login_lock_config``. The pair is
    then locked for ``lock_seconds``.

    NOTE(review): the lock threshold is compared against a counter that uses the
    captcha window, and the clamp can keep it below a large lock threshold.
    ``lock_window`` only feeds the stored ``"first_failed"`` value. Confirm
    this is intended.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    lock_failures, lock_window, lock_seconds = _get_login_lock_config()

    ip_key = _normalize_login_key("ip", ip_address)
    user_key = _normalize_login_key("user", "", username or "")
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    # Without a (non-blank) username the user-based keys are "". Skip them so
    # anonymous failures from every IP are not pooled into one "" counter.
    keys = [key for key in (ip_key, user_key, ip_user_key) if key]
    cap = max_failures * 5

    with _login_failures_lock:
        counts = {key: _update_login_failure(key, now_ts, window_seconds) for key in keys}
        for key in keys:
            data = _login_failures.get(key)
            if data and int(data.get("count", 0) or 0) > cap:
                data["count"] = cap

    ip_user_count = counts.get(ip_user_key, 0)
    # A non-positive threshold disables locking, mirroring the "<= 0 means no
    # limit" rule of the login rate limits, instead of locking on every failure.
    if ip_user_key and lock_failures > 0 and ip_user_count >= lock_failures:
        with _login_ip_user_lock:
            _login_ip_user_locks[ip_user_key] = {
                "lock_until": now_ts + lock_seconds,
                # Kept to preserve the existing record shape.
                "first_failed": now_ts - lock_window,
            }
|
||
|
||
|
||
def clear_login_failures(ip_address: str, username: Optional[str] = None) -> None:
    """Forget the failure counters of the IP, the username and the IP+username pair, and lift that pair's lock."""
    user = username or ""
    keys = (
        _normalize_login_key("ip", ip_address),
        _normalize_login_key("user", "", user),
        _normalize_login_key("ipuser", ip_address, user),
    )
    with _login_failures_lock:
        for key in keys:
            _login_failures.pop(key, None)
    with _login_ip_user_lock:
        _login_ip_user_locks.pop(keys[2], None)
|
||
|
||
|
||
def _get_login_failure_count(ip_address: str, username: Optional[str] = None) -> int:
    """Return the in-window failure count of the IP+username pair (0 if none or expired).

    An expired counter is deleted on read. Without a username the pair key is
    empty, and 0 is returned instead of reading a ``""`` counter shared by every IP.
    """
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    if not ip_user_key:
        return 0
    now_ts = time.time()
    _, window_seconds = _get_login_captcha_config()
    with _login_failures_lock:
        data = _login_failures.get(ip_user_key)
        if not data:
            return 0
        if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
            _login_failures.pop(ip_user_key, None)
            return 0
        return int(data.get("count", 0) or 0)
|
||
|
||
|
||
def check_login_captcha_required(ip_address: str, username: Optional[str] = None) -> bool:
    """Return whether the next login from this IP (and username) must solve a captcha.

    A captcha is required in either of two cases:

    - Within the captcha window, the IP or the IP+username pair has at least
      ``failures_before_captcha`` failures.
    - The IP is in a username-scan cooldown.
    """
    now_ts = time.time()
    max_failures, window_seconds = _get_login_captcha_config()
    keys = [_normalize_login_key("ip", ip_address)]
    ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
    # Without a username the pair key is "". Skip it rather than consulting a
    # counter shared by every IP.
    if ip_user_key:
        keys.append(ip_user_key)

    with _login_failures_lock:
        for key in keys:
            data = _login_failures.get(key)
            if not data:
                continue
            if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
                continue
            if int(data.get("count", 0) or 0) >= max_failures:
                return True

    return is_login_scan_locked(ip_address)
|
||
|
||
|
||
def check_login_ip_user_locked(ip_address: str, username: Optional[str]) -> Tuple[bool, int]:
    """Return ``(locked, remaining_seconds)`` for the IP+username pair.

    Without a username nothing is locked. An expired lock record is removed and
    reported as unlocked. An active lock reports at least 1 second remaining.
    """
    now_ts = time.time()
    if not username:
        return False, 0
    key = _normalize_login_key("ipuser", ip_address, username)
    with _login_ip_user_lock:
        record = _login_ip_user_locks.get(key)
        if not record:
            return False, 0
        lock_until = float(record.get("lock_until", 0) or 0)
        if lock_until <= now_ts:
            del _login_ip_user_locks[key]
            return False, 0
    return True, max(1, int(lock_until - now_ts))
|
||
|
||
|
||
def get_login_failure_delay_seconds(ip_address: str, username: Optional[str]) -> float:
    """Return an artificial delay in seconds to apply after failed logins.

    The delay starts at ``config.LOGIN_FAIL_DELAY_BASE_MS`` and grows by a
    factor of 1.6 per in-window failure of the IP+username pair. It is capped
    at ``config.LOGIN_FAIL_DELAY_MAX_MS``, plus random jitter of up to
    ``max(50, 30% of base)`` ms. Returns 0.0 when there are no failures.
    """
    failures = _get_login_failure_count(ip_address, username)
    if failures <= 0:
        return 0.0
    base_ms = max(0, int(config.LOGIN_FAIL_DELAY_BASE_MS))
    cap_ms = max(base_ms, int(config.LOGIN_FAIL_DELAY_MAX_MS))
    growth = 1.6 ** max(0, failures - 1)
    delay_ms = min(cap_ms, int(base_ms * growth))
    jitter_ms = random.randint(0, max(50, int(base_ms * 0.3)))
    return (delay_ms + jitter_ms) / 1000.0
|
||
|
||
|
||
def should_send_login_alert(user_id: int, ip_address: str) -> bool:
    """Decide whether to send a login alert for ``user_id``, recording it when sent.

    An alert is sent in three cases:

    - The first login seen for this user.
    - A non-empty ``ip_address`` that differs from the last recorded one.
    - At least ``config.LOGIN_ALERT_MIN_INTERVAL_SECONDS`` have passed since
      the last alert.

    Sending stores the current time and IP.
    """
    now_ts = time.time()
    min_interval = int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS)
    with _login_alert_lock:
        uid = int(user_id)
        previous = _login_alert_state.get(uid)
        if previous:
            last_sent = float(previous.get("last_sent", 0) or 0)
            last_ip = str(previous.get("last_ip", "") or "")
            ip_changed = bool(ip_address) and ip_address != last_ip
            send = ip_changed or (now_ts - last_sent) >= min_interval
        else:
            send = True
        if send:
            _login_alert_state[uid] = {"last_sent": now_ts, "last_ip": ip_address}
        return send
|
||
|
||
|
||
# ==================== Per-email rate limiting ====================

# "<action>:<normalized email>" -> fixed-window bucket.
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_email_rate_limit``.
_email_rate_limit_lock = threading.RLock()
|
||
|
||
|
||
def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]]:
    """Apply a fixed-window rate limit per ``(action, email)``; count the request when allowed.

    The email is stripped and lower-cased, and an empty email is always
    allowed. Limits come from ``config.EMAIL_RATE_LIMIT_MAX`` and
    ``config.EMAIL_RATE_LIMIT_WINDOW_SECONDS``.
    """
    now_ts = time.time()
    limit = int(config.EMAIL_RATE_LIMIT_MAX)
    window = int(config.EMAIL_RATE_LIMIT_WINDOW_SECONDS)
    normalized = str(email or "").strip().lower()
    if not normalized:
        return True, None
    bucket_key = f"{action}:{normalized}"

    with _email_rate_limit_lock:
        bucket = _get_or_reset_bucket(_email_rate_limit.get(bucket_key), now_ts, window)
        used = int(bucket.get("count", 0) or 0)
        if used < limit:
            bucket["count"] = used + 1
            _email_rate_limit[bucket_key] = bucket
            return True, None
        elapsed = now_ts - float(bucket.get("window_start", 0) or 0)
        remaining = max(1, int(window - elapsed))
        wait_hint = f"{remaining}秒" if remaining < 60 else f"{remaining // 60 + 1}分钟"
        return False, f"请求过于频繁,请{wait_hint}后再试"
|
||
|
||
|
||
# ==================== Batch screenshots (collecting batch task screenshots) ====================

# batch_id -> batch info dict (e.g. "screenshots", "completed", "total_accounts", timestamps).
_batch_task_screenshots: Dict[str, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_batch_task_screenshots``.
_batch_task_lock = threading.RLock()
|
||
|
||
|
||
def safe_create_batch(batch_id: str, batch_info: Dict[str, Any]) -> None:
    """Register, or replace, batch ``batch_id`` with a shallow copy of ``batch_info``."""
    key = str(batch_id)
    info = dict(batch_info) if batch_info else {}
    with _batch_task_lock:
        _batch_task_screenshots[key] = info
|
||
|
||
|
||
def safe_get_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Return a copy of batch ``batch_id``, or ``None`` if it does not exist.

    The ``"screenshots"`` list is copied too. ``safe_batch_append_result``
    appends to it in place under the lock, so handing out the live list would
    let callers iterate it while another thread mutates it. An existing but
    empty batch returns ``{}`` instead of being reported as missing.
    """
    with _batch_task_lock:
        info = _batch_task_screenshots.get(str(batch_id))
        if info is None:
            return None
        snapshot = dict(info)
        screenshots = snapshot.get("screenshots")
        if isinstance(screenshots, list):
            snapshot["screenshots"] = list(screenshots)
        return snapshot
|
||
|
||
|
||
def safe_update_batch(batch_id: str, updates: Dict[str, Any]) -> bool:
    """Merge ``updates`` into batch ``batch_id``; return ``False`` if the batch does not exist."""
    key = str(batch_id)
    with _batch_task_lock:
        found = key in _batch_task_screenshots
        if found:
            _batch_task_screenshots[key].update(dict(updates or {}))
    return found
|
||
|
||
|
||
def safe_pop_batch(batch_id: str) -> Optional[Dict[str, Any]]:
    """Remove batch ``batch_id`` and return its data (``None`` if it did not exist)."""
    key = str(batch_id)
    with _batch_task_lock:
        removed = _batch_task_screenshots.pop(key, None)
    return removed
|
||
|
||
|
||
def safe_batch_append_result(batch_id: str, result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Append one result to batch ``batch_id`` and pop the batch once it is complete.

    Appends a copy of ``result`` to ``"screenshots"``, increments
    ``completed`` and refreshes ``updated_at``. When ``total_accounts`` is
    positive and ``completed`` has reached it, the batch is removed and
    returned. Otherwise, or when the batch does not exist, ``None`` is returned.
    """
    now_ts = time.time()
    key = str(batch_id)
    with _batch_task_lock:
        info = _batch_task_screenshots.get(key)
        # ``is None`` rather than falsiness: a batch created with empty info is
        # still a real batch and must collect results.
        if info is None:
            return None
        info.setdefault("screenshots", []).append(dict(result or {}))
        completed = int(info.get("completed", 0) or 0) + 1
        info["completed"] = completed
        info["updated_at"] = now_ts
        total = int(info.get("total_accounts", 0) or 0)
        if 0 < total <= completed:
            return _batch_task_screenshots.pop(key, None)
        return None
|
||
|
||
|
||
def safe_cleanup_expired_batches(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Drop batches idle for more than ``expire_seconds`` (at least 1); return how many were dropped.

    Idle time is measured from the first truthy value among ``updated_at``,
    ``created_at`` and ``created_time``. Batches with none of them never expire.
    """
    now = float(time.time() if now_ts is None else now_ts)
    ttl = max(1, int(expire_seconds))
    with _batch_task_lock:
        stale = []
        for key, info in _batch_task_screenshots.items():
            last_seen = info.get("updated_at") or info.get("created_at") or info.get("created_time") or now
            if now - float(last_seen) > ttl:
                stale.append(key)
        for key in stale:
            del _batch_task_screenshots[key]
        return len(stale)
|
||
|
||
|
||
def safe_finalize_batch_after_dispatch(batch_id: str, total_accounts: int, *, now_ts: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """For scheduled batches: record the final account count and pop the batch if it is already complete.

    Sets ``total_accounts`` and ``updated_at``, then:

    - A non-positive total discards the batch and returns ``None``.
    - If ``completed >= total_accounts``, the batch is removed and returned
      (for sending the summary email).
    - Otherwise ``None`` is returned.
    """
    now_ts = float(now_ts if now_ts is not None else time.time())
    total = int(total_accounts or 0)
    key = str(batch_id)
    with _batch_task_lock:
        info = _batch_task_screenshots.get(key)
        # A batch created with empty info still exists; only a missing one is skipped.
        if info is None:
            return None
        info["total_accounts"] = total
        info["updated_at"] = now_ts
        if total <= 0:
            _batch_task_screenshots.pop(key, None)
            return None
        if int(info.get("completed", 0) or 0) >= total:
            return _batch_task_screenshots.pop(key, None)
        return None
|
||
|
||
|
||
# ==================== Pending random schedules (legacy random-delay support) ====================

# schedule_id -> pending schedule info dict.
_pending_random_schedules: Dict[int, Dict[str, Any]] = {}
# Reentrant lock guarding every access to ``_pending_random_schedules``.
_pending_random_lock = threading.RLock()
|
||
|
||
|
||
def safe_set_pending_random_schedule(schedule_id: int, info: Dict[str, Any]) -> None:
    """Store a shallow copy of ``info`` for ``schedule_id``, replacing any previous entry."""
    sid = int(schedule_id)
    entry = dict(info) if info else {}
    with _pending_random_lock:
        _pending_random_schedules[sid] = entry
|
||
|
||
|
||
def safe_get_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Return a copy of pending schedule ``schedule_id``, or ``None`` if it is absent.

    Uses ``is None``, so an entry stored with empty info (``{}``) is reported
    as existing. This matches what ``safe_pop_pending_random_schedule`` returns
    for it.
    """
    with _pending_random_lock:
        value = _pending_random_schedules.get(int(schedule_id))
        return None if value is None else dict(value)
|
||
|
||
|
||
def safe_pop_pending_random_schedule(schedule_id: int) -> Optional[Dict[str, Any]]:
    """Remove and return pending schedule ``schedule_id`` (``None`` if absent)."""
    sid = int(schedule_id)
    with _pending_random_lock:
        popped = _pending_random_schedules.pop(sid, None)
    return popped
|
||
|
||
|
||
def safe_iter_pending_random_schedules_items() -> List[Tuple[int, Dict[str, Any]]]:
    """Return a snapshot list of ``(schedule_id, info_copy)`` pairs."""
    snapshot: List[Tuple[int, Dict[str, Any]]] = []
    with _pending_random_lock:
        for sid, entry in _pending_random_schedules.items():
            snapshot.append((sid, dict(entry)))
    return snapshot
|
||
|
||
|
||
def safe_cleanup_expired_pending_random(expire_seconds: int, now_ts: Optional[float] = None) -> int:
    """Drop pending schedules older than ``expire_seconds`` (at least 1); return how many were dropped.

    Age is measured from ``created_at``, falling back to ``created_time``.
    Entries with neither never expire.
    """
    now = float(time.time() if now_ts is None else now_ts)
    ttl = max(1, int(expire_seconds))
    with _pending_random_lock:
        stale = [
            sid
            for sid, entry in _pending_random_schedules.items()
            if now - float(entry.get("created_at") or entry.get("created_time") or now) > ttl
        ]
        for sid in stale:
            _pending_random_schedules.pop(int(sid), None)
        return len(stale)
|