feat: 实现完整安全防护系统

Phase 1 - 威胁检测引擎:
- security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测
- security/constants.py: 威胁检测规则和评分常量
- 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist

Phase 2 - 风险评分与黑名单:
- security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减
- security/blacklist.py: 黑名单管理,自动封禁规则

Phase 3 - 响应策略:
- security/honeypot.py: 蜜罐响应生成器
- security/response_handler.py: 渐进式响应策略

Phase 4 - 集成:
- security/middleware.py: Flask安全中间件
- routes/admin_api/security.py: 管理后台安全仪表板API
- 36个测试用例全部通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-27 01:28:38 +08:00
parent e3b0c35da6
commit 46253337eb
24 changed files with 3219 additions and 4 deletions

22
security/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from security.blacklist import BlacklistManager
from security.honeypot import HoneypotResponder
from security.middleware import init_security_middleware
from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from security.risk_scorer import RiskScorer
from security.threat_detector import ThreatDetector, ThreatResult
__all__ = [
"BlacklistManager",
"HoneypotResponder",
"init_security_middleware",
"ResponseAction",
"ResponseHandler",
"ResponseStrategy",
"RiskScorer",
"ThreatDetector",
"ThreatResult",
]

255
security/blacklist.py Normal file
View File

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
from datetime import timedelta
from typing import List, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
class BlacklistManager:
"""黑名单管理器"""
def __init__(self) -> None:
self._schema_ready = False
self._schema_lock = threading.Lock()
def is_ip_banned(self, ip: str) -> bool:
"""检查IP是否被封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM ip_blacklist
WHERE ip = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(ip_text, now_str),
)
return cursor.fetchone() is not None
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM user_blacklist
WHERE user_id = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(user_id_int, now_str),
)
return cursor.fetchone() is not None
def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
"""封禁IP"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(ip_text, reason_text, now_str, expires_at),
)
conn.commit()
return True
def ban_user(self, user_id: int, reason: str):
"""封禁用户"""
return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
def unban_ip(self, ip: str):
"""解除IP封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
conn.commit()
return cursor.rowcount > 0
def unban_user(self, user_id: int):
"""解除用户封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
conn.commit()
return cursor.rowcount > 0
def get_banned_ips(self) -> List[dict]:
"""获取所有被封禁的IP"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT ip, reason, is_active, added_at, expires_at
FROM ip_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def get_banned_users(self) -> List[dict]:
"""获取所有被封禁的用户"""
self._ensure_schema()
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT user_id, reason, is_active, added_at, expires_at
FROM user_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def cleanup_expired(self):
"""清理过期的封禁记录"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE ip_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
self._ensure_schema()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE user_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
# ==================== Internal ====================
def _ensure_schema(self) -> None:
if self._schema_ready:
return
with self._schema_lock:
if self._schema_ready:
return
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
self._schema_ready = True
def _ban_user_internal(
self,
user_id: int,
*,
reason: str,
duration_hours: int = 24,
permanent: bool = False,
) -> bool:
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(user_id_int, reason_text, now_str, expires_at),
)
conn.commit()
return True

97
security/constants.py Normal file
View File

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
# ==================== Threat Types ====================
THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
THREAT_TYPE_SQL_INJECTION = "sql_injection"
THREAT_TYPE_XSS = "xss"
THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
THREAT_TYPE_COMMAND_INJECTION = "command_injection"
# ==================== Scores ====================
SCORE_JNDI_DIRECT = 100
SCORE_JNDI_OBFUSCATED = 100
SCORE_NESTED_EXPRESSION = 80
SCORE_SQL_INJECTION = 90
SCORE_XSS = 70
SCORE_PATH_TRAVERSAL = 60
SCORE_COMMAND_INJECTION = 85
# ==================== JNDI (Log4j) ====================
#
# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
# - Nested expression: ${${...}} => 80
JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
# Common Log4j "default value" obfuscation variants:
# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
JNDI_OBFUSCATED_PATTERN = (
r"\$\{\s*"
r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
r":\s*(?:ldap|rmi)\s*://"
)
NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
# ==================== SQL Injection ====================
SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
# ==================== XSS ====================
XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
# ==================== Path Traversal ====================
PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
# ==================== Command Injection ====================
CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
r"(?:;|&&|\|\||\|)\s*"
r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
)
CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
# ==================== Compiled Regex ====================
_FLAGS = re.IGNORECASE | re.MULTILINE
JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS)
JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS)
NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS)
SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS)
SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS)
XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS)
XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS)
XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS)
PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS)
CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS)
CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS)

126
security/honeypot.py Normal file
View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import uuid
from typing import Any, Optional
from app_logger import get_logger
class HoneypotResponder:
"""蜜罐响应生成器 - 返回假成功响应,欺骗攻击者"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
"""
根据端点生成假的成功响应
策略:
- 邮件发送类: {"success": True, "message": "邮件已发送"}
- 注册类: {"success": True, "user_id": fake_uuid}
- 登录类: {"success": True} 但不设置session
- 通用: {"success": True, "message": "操作成功"}
"""
endpoint_text = str(endpoint or "").strip()
endpoint_lc = endpoint_text.lower()
category = self._classify_endpoint(endpoint_lc)
response: dict[str, Any] = {"success": True}
if category == "email":
response["message"] = "邮件已发送"
elif category == "register":
response["user_id"] = str(uuid.uuid4())
elif category == "login":
# 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session
pass
else:
response["message"] = "操作成功"
response = self._merge_safe_fields(response, original_data)
self._logger.warning(
"蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
endpoint_text[:256],
category,
sorted(response.keys()),
)
return response
def should_use_honeypot(self, risk_score: int) -> bool:
"""风险分>=80使用蜜罐响应"""
score = self._normalize_risk_score(risk_score)
use = score >= 80
self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
return use
def delay_response(self, risk_score: int) -> float:
"""
根据风险分计算延迟时间
0-20: 0秒
21-50: 随机0.5-1秒
51-80: 随机1-3秒
81-100: 随机3-8秒蜜罐模式额外延迟消耗攻击者时间
"""
score = self._normalize_risk_score(risk_score)
delay = 0.0
if score <= 20:
delay = 0.0
elif score <= 50:
delay = float(self._rng.uniform(0.5, 1.0))
elif score <= 80:
delay = float(self._rng.uniform(1.0, 3.0))
else:
delay = float(self._rng.uniform(3.0, 8.0))
self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
return delay
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _classify_endpoint(self, endpoint_lc: str) -> str:
if not endpoint_lc:
return "generic"
# 先匹配更具体的:注册 / 登录
if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
return "register"
if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
return "login"
# 邮件相关:发送验证码 / 重置密码 / 重发验证等
if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
return "email"
return "generic"
def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
if not isinstance(original_data, dict) or not original_data:
return base
# 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
safe_bool_keys = {"need_verify", "need_captcha"}
merged = dict(base)
for key in safe_bool_keys:
if key in original_data and key not in merged:
try:
merged[key] = bool(original_data.get(key))
except Exception:
continue
return merged

307
security/middleware.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
from typing import Optional
from flask import g, jsonify, request
from flask_login import current_user
from app_logger import get_logger
from app_security import get_rate_limit_ip
from .blacklist import BlacklistManager
from .honeypot import HoneypotResponder
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from .risk_scorer import RiskScorer
from .threat_detector import ThreatDetector, ThreatResult
# 全局实例(保持单例,避免重复初始化开销)
detector = ThreatDetector()
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
handler: Optional[ResponseHandler] = None
honeypot: Optional[HoneypotResponder] = None
def _get_handler() -> ResponseHandler:
global handler
if handler is None:
handler = ResponseHandler()
return handler
def _get_honeypot() -> HoneypotResponder:
global honeypot
if honeypot is None:
honeypot = HoneypotResponder()
return honeypot
def _get_security_log_level(app) -> int:
level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
return int(getattr(logging, level_name, logging.INFO))
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
"""按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
try:
logger = get_logger("app")
min_level = _get_security_log_level(app)
if int(level) >= int(min_level):
logger.log(int(level), message, *args, exc_info=exc_info)
except Exception:
# 安全模块日志故障不得影响正常请求
return
def _is_static_request(app) -> bool:
"""对静态文件请求跳过安全检查以提升性能。"""
try:
path = str(getattr(request, "path", "") or "")
except Exception:
path = ""
if path.startswith("/static/"):
return True
try:
static_url_path = getattr(app, "static_url_path", None) or "/static"
if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
return True
except Exception:
pass
try:
endpoint = getattr(request, "endpoint", None)
if endpoint in {"static", "serve_static"}:
return True
except Exception:
pass
return False
def _safe_get_user_id() -> Optional[int]:
try:
if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
return getattr(current_user, "id", None)
except Exception:
return None
return None
def _scan_request_threats(req) -> list[ThreatResult]:
"""仅扫描 GET query 与 POST JSON body降低开销与误报"""
threats: list[ThreatResult] = []
try:
# 1) Query 参数(所有方法均可能携带 query string
try:
args = getattr(req, "args", None)
if args:
# MultiDict -> dict(list) 以保留多值
args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
threats.extend(detector.scan_input(args_dict, "args"))
except Exception:
pass
# 2) JSON body主要针对 POST其他方法保持兼容
try:
method = str(getattr(req, "method", "") or "").upper()
except Exception:
method = ""
if method in {"POST", "PUT", "PATCH", "DELETE"}:
try:
data = req.get_json(silent=True) if hasattr(req, "get_json") else None
except Exception:
data = None
if data is not None:
threats.extend(detector.scan_input(data, "json"))
except Exception:
# 扫描失败不应阻断业务
return []
threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
return threats
def init_security_middleware(app):
"""初始化安全中间件到 Flask 应用。"""
try:
scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
except Exception:
pass
@app.before_request
def security_check():
if not bool(app.config.get("SECURITY_ENABLED", True)):
return None
if _is_static_request(app):
return None
try:
ip = get_rate_limit_ip()
except Exception:
ip = getattr(request, "remote_addr", "") or ""
user_id = _safe_get_user_id()
# 默认值,确保后续逻辑可用
g.risk_score = 0
g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.honeypot_mode = False
g.honeypot_endpoint = None
g.honeypot_generated = False
try:
# 1) 检查黑名单(静默拒绝,返回通用错误)
try:
if blacklist.is_ip_banned(ip):
_log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
try:
if user_id is not None and blacklist.is_user_banned(user_id):
_log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
# 2) 扫描威胁GET query / POST JSON
threats = _scan_request_threats(request)
if threats:
max_threat = threats[0]
_log(
app,
logging.WARNING,
"威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
ip,
user_id,
getattr(max_threat, "threat_type", "unknown"),
getattr(max_threat, "score", 0),
getattr(max_threat, "field_name", ""),
getattr(max_threat, "rule", ""),
)
# 记录威胁事件(异常不应阻断业务)
try:
payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
scorer.record_threat(
ip=ip,
user_id=user_id,
threat_type=getattr(max_threat, "threat_type", "unknown"),
score=int(getattr(max_threat, "score", 0) or 0),
request_path=getattr(request, "path", None),
payload=str(payload)[:500] if payload else None,
)
except Exception:
_log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
# 高危威胁启用蜜罐模式
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if int(getattr(max_threat, "score", 0) or 0) >= 80:
g.honeypot_mode = True
g.honeypot_endpoint = getattr(request, "endpoint", None)
except Exception:
pass
# 3) 综合风险分与响应策略
try:
risk_score = scorer.get_combined_score(ip, user_id)
except Exception:
_log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
risk_score = 0
try:
strategy = _get_handler().get_strategy(risk_score)
except Exception:
_log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.risk_score = int(risk_score or 0)
g.response_strategy = strategy
# 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
g.honeypot_mode = True
except Exception:
pass
# 4) 应用延迟
try:
if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
_get_handler().apply_delay(strategy)
except Exception:
_log(app, logging.ERROR, "延迟应用失败", exc_info=True)
# 优先短路:避免业务 side effects例如发送邮件/修改状态)
if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
g.honeypot_generated = True
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
return None
except Exception:
# 全局兜底:安全模块任何异常都不能阻断正常请求
_log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
return None
return None # 继续正常处理
@app.after_request
def security_response(response):
"""请求后处理 - 兜底应用蜜罐响应。"""
if not bool(app.config.get("SECURITY_ENABLED", True)):
return response
if not bool(app.config.get("HONEYPOT_ENABLED", True)):
return response
try:
if _is_static_request(app):
return response
except Exception:
pass
# 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
try:
if getattr(g, "honeypot_generated", False):
return response
except Exception:
pass
try:
if getattr(g, "honeypot_mode", False):
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
return response
return response

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from app_logger import get_logger
class ResponseAction(Enum):
ALLOW = "allow" # 正常放行
ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
DELAY = "delay" # 静默延迟
HONEYPOT = "honeypot" # 蜜罐响应
BLOCK = "block" # 直接拒绝
@dataclass
class ResponseStrategy:
action: ResponseAction
delay_seconds: float = 0
captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
message: str | None = None
class ResponseHandler:
"""响应策略处理器"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
"""
根据风险分获取响应策略
0-20分: ALLOW, 无延迟, 普通验证码
21-40分: ALLOW, 无延迟, 6位验证码
41-60分: DELAY, 1-2秒延迟
61-80分: DELAY, 2-5秒延迟
81-100分: HONEYPOT, 3-8秒延迟
已封禁: BLOCK
"""
score = self._normalize_risk_score(risk_score)
if is_banned:
strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
return strategy
if score <= 20:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
elif score <= 40:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
elif score <= 60:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
elif score <= 80:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
else:
strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
self._logger.info(
"响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
strategy.action.value,
score,
float(strategy.delay_seconds or 0),
int(strategy.captcha_level),
)
return strategy
def apply_delay(self, strategy: ResponseStrategy):
"""应用延迟使用time.sleep"""
if strategy is None:
return
delay = 0.0
try:
delay = float(getattr(strategy, "delay_seconds", 0) or 0)
except Exception:
delay = 0.0
if delay <= 0:
return
self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
time.sleep(delay)
def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
"""返回验证码要求 {"required": True, "level": 2}"""
level = 1
try:
level = int(getattr(strategy, "captcha_level", 1) or 1)
except Exception:
level = 1
level = self._normalize_captcha_level(level)
required = True
try:
required = getattr(strategy, "action", None) != ResponseAction.BLOCK
except Exception:
required = True
payload = {"required": bool(required), "level": level}
self._logger.debug("验证码要求: %s", payload)
return payload
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _normalize_captcha_level(self, level: Any) -> int:
try:
i = int(level)
except Exception:
i = 1
if i <= 1:
return 1
if i == 2:
return 2
return 3

362
security/risk_scorer.py Normal file
View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
from . import constants as C
from .blacklist import BlacklistManager
@dataclass(frozen=True)
class _ScoreUpdateResult:
ip_score: int
user_score: int
@dataclass(frozen=True)
class _BanAction:
reason: str
duration_hours: Optional[int] = None
permanent: bool = False
class RiskScorer:
"""风险评分引擎 - 计算IP和用户的风险分数"""
def __init__(
self,
*,
auto_ban_enabled: bool = True,
auto_ban_duration_hours: int = 24,
high_risk_threshold: int = 80,
high_risk_window_hours: int = 1,
high_risk_permanent_ban_count: int = 3,
blacklist_manager: Optional[BlacklistManager] = None,
) -> None:
self.auto_ban_enabled = bool(auto_ban_enabled)
self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
self.high_risk_threshold = max(0, int(high_risk_threshold))
self.high_risk_window_hours = max(1, int(high_risk_window_hours))
self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
self.blacklist = blacklist_manager or BlacklistManager()
def get_ip_score(self, ip_address: str) -> int:
"""获取IP风险分(0-100),从数据库读取"""
ip_text = str(ip_address or "").strip()[:64]
if not ip_text:
return 0
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_user_score(self, user_id: int) -> int:
"""获取用户风险分(0-100)"""
if user_id is None:
return 0
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_combined_score(self, ip: str, user_id: int = None) -> int:
"""综合风险分 = max(IP分, 用户分) + 行为加成"""
base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
bonus = self._get_behavior_bonus(ip, user_id)
return max(0, min(100, int(base + bonus)))
def record_threat(
self,
ip: str,
user_id: int,
threat_type: str,
score: int,
request_path: str = None,
payload: str = None,
):
"""记录威胁事件到数据库并更新IP/用户风险分"""
ip_text = str(ip or "").strip()[:64]
user_id_int = int(user_id) if user_id is not None else None
threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
score_int = max(0, int(score))
path_text = str(request_path or "").strip()[:512] if request_path else None
payload_text = str(payload or "").strip() if payload else None
if payload_text and len(payload_text) > 2048:
payload_text = payload_text[:2048]
now_str = get_cst_now_str()
ip_ban_action: Optional[_BanAction] = None
user_ban_action: Optional[_BanAction] = None
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (
threat_type, score, ip, user_id, request_path, value_preview, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
threat_type_text,
score_int,
ip_text or None,
user_id_int,
path_text,
payload_text,
now_str,
),
)
update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
if self.auto_ban_enabled:
ip_ban_action, user_ban_action = self._get_auto_ban_actions(
cursor,
ip_text,
user_id_int,
threat_type_text,
score_int,
update,
)
conn.commit()
if not self.auto_ban_enabled:
return
if ip_ban_action and ip_text:
self.blacklist.ban_ip(
ip_text,
reason=ip_ban_action.reason,
duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=ip_ban_action.permanent,
)
if user_ban_action and user_id_int is not None:
self.blacklist._ban_user_internal(
user_id_int,
reason=user_ban_action.reason,
duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=user_ban_action.permanent,
)
def decay_scores(self):
"""风险分衰减 - 定期调用,降低历史风险分"""
now = get_cst_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
for row in cursor.fetchall():
ip = row["ip"]
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
(new_score, now_str, ip),
)
cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
for row in cursor.fetchall():
user_id = int(row["user_id"])
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
(new_score, now_str, user_id),
)
conn.commit()
def _update_ip_score(self, ip: str, score_delta: int):
"""更新IP风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, ip_text, None, delta, now_str)
conn.commit()
def _update_user_score(self, user_id: int, score_delta: int):
"""更新用户风险分"""
if user_id is None:
return
user_id_int = int(user_id)
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, "", user_id_int, delta, now_str)
conn.commit()
def _update_scores(
self,
cursor,
ip: str,
user_id: Optional[int],
score_delta: int,
now_str: str,
) -> _ScoreUpdateResult:
ip_score = 0
user_score = 0
if ip:
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
ip_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
(ip_score, now_str, now_str, ip),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(ip, ip_score, now_str, now_str, now_str),
)
if user_id is not None:
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
user_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
(user_score, now_str, now_str, int(user_id)),
)
else:
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(int(user_id), user_score, now_str, now_str, now_str),
)
return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
def _get_auto_ban_actions(
self,
cursor,
ip: str,
user_id: Optional[int],
threat_type: str,
score: int,
update: _ScoreUpdateResult,
) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
ip_action: Optional[_BanAction] = None
user_action: Optional[_BanAction] = None
if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
if ip:
ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
if user_id is not None:
user_action = _BanAction(reason="JNDI injection detected", permanent=True)
return ip_action, user_action
if ip and update.ip_score >= 100:
ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if user_id is not None and update.user_score >= 100:
user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if score < self.high_risk_threshold:
return ip_action, user_action
window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
if ip:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE ip = ? AND score >= ? AND created_at >= ?
""",
(ip, int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
if user_id is not None:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE user_id = ? AND score >= ? AND created_at >= ?
""",
(int(user_id), int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
return ip_action, user_action
def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
return 0
def _hours_since(self, dt_str: Optional[str], now) -> int:
if not dt_str:
return 0
try:
dt = parse_cst_datetime(str(dt_str))
except Exception:
return 0
seconds = (now - dt).total_seconds()
if seconds <= 0:
return 0
return int(seconds // 3600)
def _apply_hourly_decay(self, score: int, hours: int) -> int:
score_int = max(0, int(score))
if score_int <= 0 or hours <= 0:
return score_int
decayed = int(math.floor(score_int * (0.9**int(hours))))
return max(0, min(100, decayed))

316
security/threat_detector.py Normal file
View File

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus
from . import constants as C
@dataclass
class ThreatResult:
threat_type: str
score: int
field_name: str
rule: str = ""
matched: str = ""
value_preview: str = ""
def to_dict(self) -> dict:
return {
"threat_type": self.threat_type,
"score": int(self.score),
"field_name": self.field_name,
"rule": self.rule,
"matched": self.matched,
"value_preview": self.value_preview,
}
class ThreatDetector:
def __init__(
self,
*,
max_value_length: int = 4096,
max_decode_rounds: int = 2,
) -> None:
self.max_value_length = max(64, int(max_value_length))
self.max_decode_rounds = max(0, int(max_decode_rounds))
def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
"""扫描单个输入值(支持 dict/list 等嵌套结构)。"""
results: List[ThreatResult] = []
for sub_field, leaf in self._flatten_value(value, field_name):
text = self._stringify(leaf)
if not text:
continue
if len(text) > self.max_value_length:
text = text[: self.max_value_length]
results.extend(self._scan_text(text, sub_field))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
def scan_request(self, request: Any) -> List[ThreatResult]:
"""扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
results: List[ThreatResult] = []
for field_name, value in self._extract_request_fields(request):
results.extend(self.scan_input(value, field_name))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
# ==================== Internal scanning ====================
def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
hits: List[ThreatResult] = []
for check in [
self._check_jndi_injection,
self._check_sql_injection,
self._check_xss,
self._check_path_traversal,
self._check_command_injection,
]:
result = check(text)
if result:
threat_type, score, rule, matched = result
hits.append(
ThreatResult(
threat_type=threat_type,
score=int(score),
field_name=field_name,
rule=rule,
matched=matched,
value_preview=self._preview(text),
)
)
return hits
def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
# 1) Direct match
m = C.JNDI_DIRECT_RE.search(text)
if m:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
# 2) URL-decoded
decoded = self._multi_unquote(text)
if decoded != text:
m2 = C.JNDI_DIRECT_RE.search(decoded)
if m2:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
# 3) Obfuscation patterns (raw/decoded)
for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
if m3:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
# 4) Try limited de-obfuscation to reveal ${jndi:...}
deobf = self._deobfuscate_log4j(decoded)
if deobf and deobf != decoded:
m4 = C.JNDI_DIRECT_RE.search(deobf)
if m4:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
# 5) Nested expression heuristic
for candidate in [text, decoded]:
m5 = C.NESTED_EXPRESSION_RE.search(candidate)
if m5:
return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
return None
def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.SQLI_UNION_SELECT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
return None
def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.XSS_SCRIPT_TAG_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
m = C.XSS_JS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
return None
def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.PATH_TRAVERSAL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
return None
def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
return None
# ==================== Helpers ====================
def _preview(self, text: str, limit: int = 160) -> str:
s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
if len(s) <= limit:
return s
return s[: limit - 3] + "..."
def _stringify(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
try:
return value.decode("utf-8", errors="ignore")
except Exception:
return ""
try:
return str(value)
except Exception:
return ""
def _multi_unquote(self, text: str) -> str:
s = text
for _ in range(self.max_decode_rounds):
try:
nxt = unquote_plus(s)
except Exception:
break
if nxt == s:
break
s = nxt
return s
def _deobfuscate_log4j(self, text: str) -> str:
# Replace ${...:-x} with x (including ${::-x}).
# This is intentionally conservative to reduce false positives.
import re
s = text
pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
for _ in range(3):
nxt = pattern.sub(lambda m: m.group(1), s)
if nxt == s:
break
s = nxt
return s
def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
if isinstance(value, dict):
for k, v in value.items():
key = self._stringify(k) or "key"
sub_name = f"{field_name}.{key}" if field_name else key
yield from self._flatten_value(v, sub_name)
return
if isinstance(value, (list, tuple, set)):
for i, v in enumerate(value):
sub_name = f"{field_name}[{i}]"
yield from self._flatten_value(v, sub_name)
return
yield (field_name, value)
def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
# dict-like input (useful for unit tests / non-Flask callers)
if isinstance(request, dict):
out: List[Tuple[str, Any]] = []
for k, v in request.items():
out.append((self._stringify(k) or "request", v))
return out
out: List[Tuple[str, Any]] = []
# path / method
for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
try:
v = getattr(request, attr_name, None)
except Exception:
v = None
if v:
out.append((attr_name, v))
# args / form (Flask MultiDict)
out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
# headers
try:
headers = getattr(request, "headers", None)
if headers is not None:
try:
items = headers.items()
except Exception:
items = []
for k, v in items:
out.append((f"headers.{self._stringify(k)}", v))
except Exception:
pass
# cookies
try:
cookies = getattr(request, "cookies", None)
if isinstance(cookies, dict):
for k, v in cookies.items():
out.append((f"cookies.{self._stringify(k)}", v))
except Exception:
pass
# json body
data = None
try:
get_json = getattr(request, "get_json", None)
if callable(get_json):
data = get_json(silent=True)
except Exception:
data = None
if data is not None:
for name, v in self._flatten_value(data, "json"):
out.append((name, v))
return out
# raw body (as a fallback)
try:
get_data = getattr(request, "get_data", None)
if callable(get_data):
raw = get_data(cache=True, as_text=True)
if raw:
out.append(("body", raw))
except Exception:
pass
return out
def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
out: List[Tuple[str, Any]] = []
if md is None:
return out
try:
items = md.items(multi=True)
except Exception:
try:
items = md.items()
except Exception:
return out
for k, v in items:
out.append((f"{prefix}.{self._stringify(k)}", v))
return out