安全修复:加固CSRF与凭证保护并修复越权风险
This commit is contained in:
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import ipaddress
|
||||
import socket
|
||||
@@ -78,7 +79,13 @@ def sanitize_filename(filename):
|
||||
class IPRateLimiter:
|
||||
"""IP访问频率限制器"""
|
||||
|
||||
def __init__(self, max_attempts=10, window_seconds=3600, lock_duration=3600):
|
||||
def __init__(
|
||||
self,
|
||||
max_attempts=10,
|
||||
window_seconds=3600,
|
||||
lock_duration=3600,
|
||||
max_tracked_ips=20000,
|
||||
):
|
||||
"""
|
||||
初始化限流器
|
||||
|
||||
@@ -90,6 +97,7 @@ class IPRateLimiter:
|
||||
self.max_attempts = max_attempts
|
||||
self.window_seconds = window_seconds
|
||||
self.lock_duration = lock_duration
|
||||
self.max_tracked_ips = max(1000, int(max_tracked_ips or 0))
|
||||
|
||||
# IP访问记录: {ip: [(timestamp, success), ...]}
|
||||
self._attempts = defaultdict(list)
|
||||
@@ -97,6 +105,47 @@ class IPRateLimiter:
|
||||
self._locked = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _prune_if_oversized(self, now_ts: float) -> None:
|
||||
"""限制内部映射大小,避免在高频随机IP攻击下持续膨胀。"""
|
||||
tracked = len(self._attempts) + len(self._locked)
|
||||
if tracked <= self.max_tracked_ips:
|
||||
return
|
||||
|
||||
cutoff_time = now_ts - self.window_seconds
|
||||
for ip in list(self._attempts.keys()):
|
||||
self._attempts[ip] = [
|
||||
(ts, succ) for ts, succ in self._attempts[ip]
|
||||
if ts > cutoff_time
|
||||
]
|
||||
if not self._attempts[ip]:
|
||||
del self._attempts[ip]
|
||||
|
||||
for ip in list(self._locked.keys()):
|
||||
if now_ts >= self._locked[ip]:
|
||||
del self._locked[ip]
|
||||
|
||||
tracked = len(self._attempts) + len(self._locked)
|
||||
if tracked <= self.max_tracked_ips:
|
||||
return
|
||||
|
||||
# 优先按“最近访问时间最早”淘汰 attempts 中的 IP 记录。
|
||||
overflow = tracked - self.max_tracked_ips
|
||||
oldest = []
|
||||
for ip, attempt_items in self._attempts.items():
|
||||
if attempt_items:
|
||||
oldest.append((attempt_items[-1][0], ip))
|
||||
else:
|
||||
oldest.append((0.0, ip))
|
||||
oldest.sort(key=lambda item: item[0])
|
||||
|
||||
removed = 0
|
||||
for _, ip in oldest:
|
||||
self._attempts.pop(ip, None)
|
||||
self._locked.pop(ip, None)
|
||||
removed += 1
|
||||
if removed >= overflow:
|
||||
break
|
||||
|
||||
def is_locked(self, ip_address):
|
||||
"""
|
||||
检查IP是否被锁定
|
||||
@@ -129,6 +178,7 @@ class IPRateLimiter:
|
||||
"""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
self._prune_if_oversized(now)
|
||||
|
||||
# 清理过期记录
|
||||
cutoff_time = now - self.window_seconds
|
||||
@@ -357,7 +407,19 @@ def generate_csrf_token():
|
||||
|
||||
def validate_csrf_token(token):
|
||||
"""验证CSRF令牌"""
|
||||
return token == session.get('csrf_token')
|
||||
expected = session.get("csrf_token")
|
||||
if (token is None) or (expected is None):
|
||||
return False
|
||||
|
||||
provided_text = str(token or "")
|
||||
expected_text = str(expected or "")
|
||||
if (not provided_text) or (not expected_text):
|
||||
return False
|
||||
|
||||
return hmac.compare_digest(
|
||||
provided_text.encode("utf-8"),
|
||||
expected_text.encode("utf-8"),
|
||||
)
|
||||
|
||||
|
||||
# ==================== 内容安全 ====================
|
||||
|
||||
Reference in New Issue
Block a user