refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

View File

@@ -13,7 +13,7 @@ from __future__ import annotations
import threading
import time
import random
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from app_config import get_config
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
_log_cache_total_count = 0
def _pop_oldest_log_for_user(uid: int) -> bool:
    """Remove the oldest cached log entry for *uid*.

    Keeps the module-level total counter in sync (clamped at zero) and
    prunes the user's key from the cache once its list is empty.

    Returns:
        True when an entry was actually removed, False when the user had
        nothing cached (a stale empty list is pruned either way).
    """
    global _log_cache_total_count
    entries = _log_cache.get(uid)
    if not entries:
        # Missing or empty list: nothing to evict; drop the stale key so
        # the cache dict does not accumulate empty buckets.
        _log_cache.pop(uid, None)
        return False
    del entries[0]
    _log_cache_total_count = max(0, _log_cache_total_count - 1)
    if not entries:
        _log_cache.pop(uid, None)
    return True
def _pop_oldest_log_from_largest_user() -> bool:
    """Evict one log entry from whichever user currently caches the most.

    Returns:
        True when an entry was evicted via ``_pop_oldest_log_for_user``,
        False when the cache is empty or every cached list is empty.
    """
    if not _log_cache:
        return False
    # max() returns the first key with the maximal list length, matching
    # a strict ">" scan over insertion order.
    busiest_uid = max(_log_cache, key=lambda u: len(_log_cache[u]))
    if not _log_cache[busiest_uid]:
        # Largest bucket is empty, so every bucket is: nothing to evict.
        return False
    return _pop_oldest_log_for_user(int(busiest_uid))
def safe_add_log(
user_id: int,
log_entry: Dict[str, Any],
@@ -175,24 +205,17 @@ def safe_add_log(
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
with _log_cache_lock:
if uid not in _log_cache:
_log_cache[uid] = []
logs = _log_cache.setdefault(uid, [])
if len(_log_cache[uid]) >= max_logs_per_user:
_log_cache[uid].pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if len(logs) >= max_logs_per_user:
_pop_oldest_log_for_user(uid)
logs = _log_cache.setdefault(uid, [])
_log_cache[uid].append(dict(log_entry or {}))
logs.append(dict(log_entry or {}))
_log_cache_total_count += 1
while _log_cache_total_count > max_total_logs:
if not _log_cache:
break
max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
if _log_cache.get(max_user):
_log_cache[max_user].pop(0)
_log_cache_total_count -= 1
else:
if not _pop_oldest_log_from_largest_user():
break
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
def _format_wait_hint(remaining_seconds: int) -> str:
    """Format a human-readable wait duration for rate-limit messages.

    Args:
        remaining_seconds: Seconds left in the current window; falsy or
            non-positive values are clamped to 1 so the hint is never empty.

    Returns:
        ``"<N>分钟"`` when at least one minute remains (rounded up to whole
        minutes), otherwise ``"<N>秒"``.
    """
    remaining = max(1, int(remaining_seconds or 0))
    if remaining >= 60:
        # True ceiling division. The previous `remaining // 60 + 1`
        # overstated exact multiples (120 s -> "3分钟" instead of "2分钟").
        return f"{-(-remaining // 60)}分钟"
    # Bug fix: the sub-minute branch previously omitted the unit, yielding
    # messages like "请30后再试" instead of "请30秒后再试".
    return f"{remaining}秒"
def _check_and_increment_rate_bucket(
    *,
    buckets: Dict[str, Dict[str, Any]],
    key: str,
    now_ts: float,
    max_requests: int,
    window_seconds: int,
) -> Tuple[bool, Optional[str]]:
    """Consume one slot from a fixed-window rate bucket.

    Looks up (or resets, via ``_get_or_reset_bucket``) the window bucket
    for *key*, rejects the request with a localized wait hint when the
    window is already full, and otherwise increments the counter.

    Returns:
        ``(True, None)`` when the request is allowed (or limiting is
        disabled by an empty key / non-positive limit);
        ``(False, message)`` when the caller should back off.
    """
    limit = int(max_requests)
    if not key or limit <= 0:
        # An empty key or non-positive limit disables limiting entirely.
        return True, None
    bucket = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
    used = int(bucket.get("count", 0) or 0)
    if used >= limit:
        window_start = float(bucket.get("window_start", 0) or 0)
        remaining = max(1, int(window_seconds - (now_ts - window_start)))
        return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
    bucket["count"] = used + 1
    buckets[key] = bucket
    return True, None
def check_ip_request_rate(
ip_address: str,
action: str,
@@ -392,21 +443,13 @@ def check_ip_request_rate(
key = f"{action}:{ip_address}"
with _ip_request_rate_lock:
data = _ip_request_rate.get(key)
if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
data = {"window_start": now_ts, "count": 0}
_ip_request_rate[key] = data
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
if remaining >= 60:
wait_hint = f"{remaining // 60 + 1}分钟"
else:
wait_hint = f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
return True, None
return _check_and_increment_rate_bucket(
buckets=_ip_request_rate,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
data = _ip_request_rate.get(key) or {}
action = key.split(":", 1)[0]
_, window_seconds = _get_action_rate_limit(action)
window_start = float(data.get("window_start", 0) or 0)
if now_ts - window_start >= window_seconds:
if _is_bucket_expired(data, now_ts, window_seconds):
_ip_request_rate.pop(key, None)
removed += 1
return removed
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
return data
def _is_bucket_expired(
    data: Optional[Dict[str, Any]],
    now_ts: float,
    window_seconds: int,
    *,
    time_field: str = "window_start",
) -> bool:
    """Tell whether a fixed-window bucket's time window has fully elapsed.

    Missing *data* (or a missing/zero *time_field* value) is treated as
    started at epoch 0 and is therefore expired for any realistic clock.
    The window length is clamped to at least one second.
    """
    bucket = data or {}
    started_at = float(bucket.get(time_field, 0) or 0)
    window = max(1, int(window_seconds))
    return (now_ts - started_at) >= window
def _cleanup_map_entries(
    store: Dict[Any, Dict[str, Any]],
    should_remove: Callable[[Dict[str, Any]], bool],
) -> int:
    """Remove every entry of *store* for which *should_remove* is true.

    Iterates over a snapshot so the mapping can be mutated while walking
    it; non-dict values are presented to the predicate as an empty dict
    (and are therefore removable like any other entry).

    Returns:
        The number of entries removed.
    """
    removed_total = 0
    snapshot = list(store.items())
    for entry_key, raw_value in snapshot:
        candidate = raw_value if isinstance(raw_value, dict) else {}
        if not should_remove(candidate):
            continue
        store.pop(entry_key, None)
        removed_total += 1
    return removed_total
def record_login_username_attempt(ip_address: str, username: str) -> bool:
now_ts = time.time()
threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
user_key = _normalize_login_key("user", "", username)
ip_user_key = _normalize_login_key("ipuser", ip_address, username)
def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
if not key or max_requests <= 0:
return True, None
data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_login_rate_limits[key] = data
return True, None
with _login_rate_limits_lock:
allowed, msg = _check(ip_key, ip_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_key,
now_ts=now_ts,
max_requests=ip_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(ip_user_key, ip_user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_user_key,
now_ts=now_ts,
max_requests=ip_user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(user_key, user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=user_key,
now_ts=now_ts,
max_requests=user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
ip_key = _normalize_login_key("ip", ip_address)
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
if not data:
return False
if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
return False
return int(data.get("count", 0) or 0) >= max_failures
with _login_failures_lock:
ip_data = _login_failures.get(ip_key)
if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_data.get("count", 0) or 0) >= max_failures:
return True
ip_user_data = _login_failures.get(ip_user_key)
if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_user_data.get("count", 0) or 0) >= max_failures:
return True
if _is_over_threshold(_login_failures.get(ip_key)):
return True
if _is_over_threshold(_login_failures.get(ip_user_key)):
return True
if is_login_scan_locked(ip_address):
return True
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
return False
def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
    """Purge expired login-security bookkeeping from all in-memory maps.

    Sweeps failure counters, rate-limit buckets, scan-detection state,
    per-IP/user locks and alert throttles, each under its dedicated lock,
    delegating the actual removal to ``_cleanup_map_entries``.

    Args:
        now_ts: Reference timestamp; defaults to the current time.

    Returns:
        Per-map counts of removed entries, keyed by map name.
    """
    reference_ts = float(time.time() if now_ts is None else now_ts)
    _, captcha_window = _get_login_captcha_config()
    _, _, _, rate_window = _get_login_rate_limit_config()
    _, lock_window, _ = _get_login_lock_config()
    _, scan_window, _ = _get_login_scan_config()
    # NOTE(review): `config` is presumably a module-level get_config()
    # result — confirm it is bound at import time.
    alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
    # Failures live until the longer of the captcha and lock windows.
    failure_window = max(captcha_window, lock_window)

    def _failure_expired(data: Dict[str, Any]) -> bool:
        return (reference_ts - float(data.get("first_failed", 0) or 0)) > failure_window

    def _scan_expired(data: Dict[str, Any]) -> bool:
        # Drop only when the observation window has passed AND any active
        # scan lock has elapsed.
        stale = (reference_ts - float(data.get("first_seen", 0) or 0)) > scan_window
        unlocked = reference_ts >= float(data.get("scan_until", 0) or 0)
        return stale and unlocked

    with _login_failures_lock:
        failures_removed = _cleanup_map_entries(_login_failures, _failure_expired)
    with _login_rate_limits_lock:
        rate_removed = _cleanup_map_entries(
            _login_rate_limits,
            lambda data: _is_bucket_expired(data, reference_ts, rate_window),
        )
    with _login_scan_lock:
        scan_removed = _cleanup_map_entries(_login_scan_state, _scan_expired)
    with _login_ip_user_lock:
        ip_user_locks_removed = _cleanup_map_entries(
            _login_ip_user_locks,
            lambda data: reference_ts >= float(data.get("lock_until", 0) or 0),
        )
    with _login_alert_lock:
        alerts_removed = _cleanup_map_entries(
            _login_alert_state,
            lambda data: (reference_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
        )
    return {
        "failures": failures_removed,
        "rate_limits": rate_removed,
        "scan_states": scan_removed,
        "ip_user_locks": ip_user_locks_removed,
        "alerts": alerts_removed,
    }
# ==================== 邮箱维度限流 ====================
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
key = f"{action}:{email_key}"
with _email_rate_limit_lock:
data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_email_rate_limit[key] = data
return True, None
return _check_and_increment_rate_bucket(
buckets=_email_rate_limit,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
# ==================== Batch screenshots批次任务截图收集 ====================