安全修复:加固CSRF与凭证保护并修复越权风险

This commit is contained in:
2026-02-16 01:19:43 +08:00
parent 14b506e8a1
commit 1389ec7434
22 changed files with 375 additions and 83 deletions

View File

@@ -9,6 +9,7 @@ import os
import re
import time
import hashlib
import hmac
import secrets
import ipaddress
import socket
@@ -78,7 +79,13 @@ def sanitize_filename(filename):
class IPRateLimiter:
"""IP访问频率限制器"""
def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600):
def __init__(
self,
max_attempts=10,
window_seconds=3600,
lock_duration=3600,
max_tracked_ips=20000,
):
"""
初始化限流器
@@ -90,6 +97,7 @@ class IPRateLimiter:
self.max_attempts = max_attempts
self.window_seconds = window_seconds
self.lock_duration = lock_duration
self.max_tracked_ips = max(1000, int(max_tracked_ips or 0))
# IP访问记录: {ip: [(timestamp, success), ...]}
self._attempts = defaultdict(list)
@@ -97,6 +105,47 @@ class IPRateLimiter:
self._locked = {}
self._lock = threading.Lock()
def _prune_if_oversized(self, now_ts: float) -> None:
"""限制内部映射大小避免在高频随机IP攻击下持续膨胀。"""
tracked = len(self._attempts) + len(self._locked)
if tracked <= self.max_tracked_ips:
return
cutoff_time = now_ts - self.window_seconds
for ip in list(self._attempts.keys()):
self._attempts[ip] = [
(ts, succ) for ts, succ in self._attempts[ip]
if ts > cutoff_time
]
if not self._attempts[ip]:
del self._attempts[ip]
for ip in list(self._locked.keys()):
if now_ts >= self._locked[ip]:
del self._locked[ip]
tracked = len(self._attempts) + len(self._locked)
if tracked <= self.max_tracked_ips:
return
# 优先按“最近访问时间最早”淘汰 attempts 中的 IP 记录。
overflow = tracked - self.max_tracked_ips
oldest = []
for ip, attempt_items in self._attempts.items():
if attempt_items:
oldest.append((attempt_items[-1][0], ip))
else:
oldest.append((0.0, ip))
oldest.sort(key=lambda item: item[0])
removed = 0
for _, ip in oldest:
self._attempts.pop(ip, None)
self._locked.pop(ip, None)
removed += 1
if removed >= overflow:
break
def is_locked(self, ip_address):
"""
检查IP是否被锁定
@@ -129,6 +178,7 @@ class IPRateLimiter:
"""
with self._lock:
now = time.time()
self._prune_if_oversized(now)
# 清理过期记录
cutoff_time = now - self.window_seconds
@@ -357,7 +407,19 @@ def generate_csrf_token():
def validate_csrf_token(token):
"""验证CSRF令牌"""
return token == session.get('csrf_token')
expected = session.get("csrf_token")
if (token is None) or (expected is None):
return False
provided_text = str(token or "")
expected_text = str(expected or "")
if (not provided_text) or (not expected_text):
return False
return hmac.compare_digest(
provided_text.encode("utf-8"),
expected_text.encode("utf-8"),
)
# ==================== 内容安全 ====================