feat: 添加安全模块 + Dockerfile添加curl支持健康检查
主要更新: - 新增 security/ 安全模块 (风险评估、威胁检测、蜜罐等) - Dockerfile 添加 curl 以支持 Docker 健康检查 - 前端页面更新 (管理后台、用户端) - 数据库迁移和 schema 更新 - 新增 kdocs 上传服务 - 添加安全相关测试用例 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
22
security/__init__.py
Normal file
22
security/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from security.blacklist import BlacklistManager
|
||||
from security.honeypot import HoneypotResponder
|
||||
from security.middleware import init_security_middleware
|
||||
from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
|
||||
from security.risk_scorer import RiskScorer
|
||||
from security.threat_detector import ThreatDetector, ThreatResult
|
||||
|
||||
__all__ = [
|
||||
"BlacklistManager",
|
||||
"HoneypotResponder",
|
||||
"init_security_middleware",
|
||||
"ResponseAction",
|
||||
"ResponseHandler",
|
||||
"ResponseStrategy",
|
||||
"RiskScorer",
|
||||
"ThreatDetector",
|
||||
"ThreatResult",
|
||||
]
|
||||
255
security/blacklist.py
Normal file
255
security/blacklist.py
Normal file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now, get_cst_now_str
|
||||
|
||||
|
||||
class BlacklistManager:
|
||||
"""黑名单管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._schema_ready = False
|
||||
self._schema_lock = threading.Lock()
|
||||
|
||||
def is_ip_banned(self, ip: str) -> bool:
|
||||
"""检查IP是否被封禁"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return False
|
||||
now_str = get_cst_now_str()
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM ip_blacklist
|
||||
WHERE ip = ?
|
||||
AND is_active = 1
|
||||
AND (expires_at IS NULL OR expires_at > ?)
|
||||
LIMIT 1
|
||||
""",
|
||||
(ip_text, now_str),
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def is_user_banned(self, user_id: int) -> bool:
|
||||
"""检查用户是否被封禁"""
|
||||
if user_id is None:
|
||||
return False
|
||||
self._ensure_schema()
|
||||
user_id_int = int(user_id)
|
||||
now_str = get_cst_now_str()
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM user_blacklist
|
||||
WHERE user_id = ?
|
||||
AND is_active = 1
|
||||
AND (expires_at IS NULL OR expires_at > ?)
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id_int, now_str),
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
|
||||
"""封禁IP"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return False
|
||||
reason_text = str(reason or "").strip()[:512]
|
||||
now_str = get_cst_now_str()
|
||||
|
||||
expires_at: Optional[str]
|
||||
if permanent:
|
||||
expires_at = None
|
||||
else:
|
||||
hours = max(1, int(duration_hours))
|
||||
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
|
||||
VALUES (?, ?, 1, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
reason = excluded.reason,
|
||||
is_active = 1,
|
||||
added_at = excluded.added_at,
|
||||
expires_at = excluded.expires_at
|
||||
""",
|
||||
(ip_text, reason_text, now_str, expires_at),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
def ban_user(self, user_id: int, reason: str):
|
||||
"""封禁用户"""
|
||||
return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
|
||||
|
||||
def unban_ip(self, ip: str):
|
||||
"""解除IP封禁"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return False
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def unban_user(self, user_id: int):
|
||||
"""解除用户封禁"""
|
||||
if user_id is None:
|
||||
return False
|
||||
self._ensure_schema()
|
||||
user_id_int = int(user_id)
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_banned_ips(self) -> List[dict]:
|
||||
"""获取所有被封禁的IP"""
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT ip, reason, is_active, added_at, expires_at
|
||||
FROM ip_blacklist
|
||||
WHERE is_active = 1
|
||||
AND (expires_at IS NULL OR expires_at > ?)
|
||||
ORDER BY added_at DESC
|
||||
""",
|
||||
(now_str,),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_banned_users(self) -> List[dict]:
|
||||
"""获取所有被封禁的用户"""
|
||||
self._ensure_schema()
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT user_id, reason, is_active, added_at, expires_at
|
||||
FROM user_blacklist
|
||||
WHERE is_active = 1
|
||||
AND (expires_at IS NULL OR expires_at > ?)
|
||||
ORDER BY added_at DESC
|
||||
""",
|
||||
(now_str,),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""清理过期的封禁记录"""
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE ip_blacklist
|
||||
SET is_active = 0
|
||||
WHERE is_active = 1
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at <= ?
|
||||
""",
|
||||
(now_str,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
self._ensure_schema()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE user_blacklist
|
||||
SET is_active = 0
|
||||
WHERE is_active = 1
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at <= ?
|
||||
""",
|
||||
(now_str,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# ==================== Internal ====================
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
if self._schema_ready:
|
||||
return
|
||||
with self._schema_lock:
|
||||
if self._schema_ready:
|
||||
return
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS user_blacklist (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
reason TEXT,
|
||||
is_active INTEGER DEFAULT 1,
|
||||
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
|
||||
conn.commit()
|
||||
self._schema_ready = True
|
||||
|
||||
def _ban_user_internal(
|
||||
self,
|
||||
user_id: int,
|
||||
*,
|
||||
reason: str,
|
||||
duration_hours: int = 24,
|
||||
permanent: bool = False,
|
||||
) -> bool:
|
||||
if user_id is None:
|
||||
return False
|
||||
self._ensure_schema()
|
||||
user_id_int = int(user_id)
|
||||
reason_text = str(reason or "").strip()[:512]
|
||||
now_str = get_cst_now_str()
|
||||
|
||||
expires_at: Optional[str]
|
||||
if permanent:
|
||||
expires_at = None
|
||||
else:
|
||||
hours = max(1, int(duration_hours))
|
||||
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
|
||||
VALUES (?, ?, 1, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
reason = excluded.reason,
|
||||
is_active = 1,
|
||||
added_at = excluded.added_at,
|
||||
expires_at = excluded.expires_at
|
||||
""",
|
||||
(user_id_int, reason_text, now_str, expires_at),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
146
security/constants.py
Normal file
146
security/constants.py
Normal file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
# ==================== Threat Types ====================
|
||||
|
||||
THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
|
||||
THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
|
||||
THREAT_TYPE_SQL_INJECTION = "sql_injection"
|
||||
THREAT_TYPE_XSS = "xss"
|
||||
THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
|
||||
THREAT_TYPE_COMMAND_INJECTION = "command_injection"
|
||||
THREAT_TYPE_SSRF = "ssrf"
|
||||
THREAT_TYPE_XXE = "xxe"
|
||||
THREAT_TYPE_TEMPLATE_INJECTION = "template_injection"
|
||||
THREAT_TYPE_SENSITIVE_PATH_PROBE = "sensitive_path_probe"
|
||||
|
||||
|
||||
# ==================== Scores ====================
|
||||
|
||||
SCORE_JNDI_DIRECT = 100
|
||||
SCORE_JNDI_OBFUSCATED = 100
|
||||
SCORE_NESTED_EXPRESSION = 80
|
||||
SCORE_SQL_INJECTION = 90
|
||||
SCORE_XSS = 70
|
||||
SCORE_PATH_TRAVERSAL = 60
|
||||
SCORE_COMMAND_INJECTION = 85
|
||||
SCORE_SSRF = 75
|
||||
SCORE_XXE = 85
|
||||
SCORE_TEMPLATE_INJECTION = 70
|
||||
SCORE_SENSITIVE_PATH_PROBE = 40
|
||||
|
||||
|
||||
# ==================== JNDI (Log4j) ====================
|
||||
#
|
||||
# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
|
||||
# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
|
||||
# - Nested expression: ${${...}} => 80
|
||||
|
||||
JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
|
||||
|
||||
# Common Log4j "default value" obfuscation variants:
|
||||
# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
|
||||
# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
|
||||
JNDI_OBFUSCATED_PATTERN = (
|
||||
r"\$\{\s*"
|
||||
r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
|
||||
r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
|
||||
r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
|
||||
r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
|
||||
r":\s*(?:ldap|rmi)\s*://"
|
||||
)
|
||||
|
||||
NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
|
||||
|
||||
|
||||
# ==================== SQL Injection ====================
|
||||
|
||||
SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
|
||||
SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
|
||||
|
||||
|
||||
# ==================== XSS ====================
|
||||
|
||||
XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
|
||||
XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
|
||||
XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
|
||||
|
||||
|
||||
# ==================== Path Traversal ====================
|
||||
|
||||
PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
|
||||
|
||||
|
||||
# ==================== Command Injection ====================
|
||||
|
||||
CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
|
||||
r"(?:;|&&|\|\||\|)\s*"
|
||||
r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
|
||||
)
|
||||
CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
|
||||
|
||||
|
||||
# ==================== SSRF ====================
|
||||
|
||||
SSRF_LOCALHOST_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:127\.0\.0\.1\b|localhost\b|0\.0\.0\.0\b)"
|
||||
SSRF_INTERNAL_IP_URL_PATTERN = r"\bhttps?\s*:\s*//\s*(?:10\.|192\.168\.|172\.(?:1[6-9]|2[0-9]|3[0-1])\.)"
|
||||
SSRF_DANGEROUS_PROTOCOL_PATTERN = r"\b(?:file|gopher|dict)\s*:\s*//"
|
||||
|
||||
|
||||
# ==================== XXE ====================
|
||||
|
||||
XXE_DOCTYPE_PATTERN = r"<!\s*doctype\b|\bdoctype\b"
|
||||
XXE_ENTITY_PATTERN = r"<!\s*entity\b|\bentity\b"
|
||||
XXE_SYSTEM_PUBLIC_PATTERN = r"\b(?:system|public)\b"
|
||||
|
||||
|
||||
# ==================== Template Injection ====================
|
||||
|
||||
TEMPLATE_JINJA_EXPR_PATTERN = r"\{\{\s*[^}]{0,200}\s*\}\}"
|
||||
TEMPLATE_JINJA_STMT_PATTERN = r"\{%\s*[^%]{0,200}\s*%\}"
|
||||
TEMPLATE_VELOCITY_DIRECTIVE_PATTERN = r"#\s*(?:set|if)\b"
|
||||
|
||||
|
||||
# ==================== Sensitive Path Probing ====================
|
||||
|
||||
SENSITIVE_PATH_DOTFILES_PATTERN = r"/\.(?:git|svn|env)(?:/|\b|$)"
|
||||
SENSITIVE_PATH_PROBE_PATTERN = r"/(?:actuator|phpinfo|wp-admin)(?:/|\b|$)"
|
||||
|
||||
|
||||
# ==================== Compiled Regex ====================
|
||||
|
||||
_FLAGS = re.IGNORECASE | re.MULTILINE
|
||||
|
||||
JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS)
|
||||
JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS)
|
||||
NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS)
|
||||
|
||||
SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS)
|
||||
SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS)
|
||||
|
||||
XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS)
|
||||
XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS)
|
||||
XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS)
|
||||
|
||||
PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS)
|
||||
|
||||
CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS)
|
||||
CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS)
|
||||
|
||||
SSRF_LOCALHOST_URL_RE = re.compile(SSRF_LOCALHOST_URL_PATTERN, _FLAGS)
|
||||
SSRF_INTERNAL_IP_URL_RE = re.compile(SSRF_INTERNAL_IP_URL_PATTERN, _FLAGS)
|
||||
SSRF_DANGEROUS_PROTOCOL_RE = re.compile(SSRF_DANGEROUS_PROTOCOL_PATTERN, _FLAGS)
|
||||
|
||||
XXE_DOCTYPE_RE = re.compile(XXE_DOCTYPE_PATTERN, _FLAGS)
|
||||
XXE_ENTITY_RE = re.compile(XXE_ENTITY_PATTERN, _FLAGS)
|
||||
XXE_SYSTEM_PUBLIC_RE = re.compile(XXE_SYSTEM_PUBLIC_PATTERN, _FLAGS)
|
||||
|
||||
TEMPLATE_JINJA_EXPR_RE = re.compile(TEMPLATE_JINJA_EXPR_PATTERN, _FLAGS)
|
||||
TEMPLATE_JINJA_STMT_RE = re.compile(TEMPLATE_JINJA_STMT_PATTERN, _FLAGS)
|
||||
TEMPLATE_VELOCITY_DIRECTIVE_RE = re.compile(TEMPLATE_VELOCITY_DIRECTIVE_PATTERN, _FLAGS)
|
||||
|
||||
SENSITIVE_PATH_DOTFILES_RE = re.compile(SENSITIVE_PATH_DOTFILES_PATTERN, _FLAGS)
|
||||
SENSITIVE_PATH_PROBE_RE = re.compile(SENSITIVE_PATH_PROBE_PATTERN, _FLAGS)
|
||||
126
security/honeypot.py
Normal file
126
security/honeypot.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from app_logger import get_logger
|
||||
|
||||
|
||||
class HoneypotResponder:
|
||||
"""蜜罐响应生成器 - 返回假成功响应,欺骗攻击者"""
|
||||
|
||||
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
|
||||
self._rng = rng or random.SystemRandom()
|
||||
self._logger = get_logger("app")
|
||||
|
||||
def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
|
||||
"""
|
||||
根据端点生成假的成功响应
|
||||
|
||||
策略:
|
||||
- 邮件发送类: {"success": True, "message": "邮件已发送"}
|
||||
- 注册类: {"success": True, "user_id": fake_uuid}
|
||||
- 登录类: {"success": True} 但不设置session
|
||||
- 通用: {"success": True, "message": "操作成功"}
|
||||
"""
|
||||
endpoint_text = str(endpoint or "").strip()
|
||||
endpoint_lc = endpoint_text.lower()
|
||||
|
||||
category = self._classify_endpoint(endpoint_lc)
|
||||
response: dict[str, Any] = {"success": True}
|
||||
|
||||
if category == "email":
|
||||
response["message"] = "邮件已发送"
|
||||
elif category == "register":
|
||||
response["user_id"] = str(uuid.uuid4())
|
||||
elif category == "login":
|
||||
# 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session)
|
||||
pass
|
||||
else:
|
||||
response["message"] = "操作成功"
|
||||
|
||||
response = self._merge_safe_fields(response, original_data)
|
||||
|
||||
self._logger.warning(
|
||||
"蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
|
||||
endpoint_text[:256],
|
||||
category,
|
||||
sorted(response.keys()),
|
||||
)
|
||||
return response
|
||||
|
||||
def should_use_honeypot(self, risk_score: int) -> bool:
|
||||
"""风险分>=80使用蜜罐响应"""
|
||||
score = self._normalize_risk_score(risk_score)
|
||||
use = score >= 80
|
||||
self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
|
||||
return use
|
||||
|
||||
def delay_response(self, risk_score: int) -> float:
|
||||
"""
|
||||
根据风险分计算延迟时间
|
||||
0-20: 0秒
|
||||
21-50: 随机0.5-1秒
|
||||
51-80: 随机1-3秒
|
||||
81-100: 随机3-8秒(蜜罐模式额外延迟消耗攻击者时间)
|
||||
"""
|
||||
score = self._normalize_risk_score(risk_score)
|
||||
|
||||
delay = 0.0
|
||||
if score <= 20:
|
||||
delay = 0.0
|
||||
elif score <= 50:
|
||||
delay = float(self._rng.uniform(0.5, 1.0))
|
||||
elif score <= 80:
|
||||
delay = float(self._rng.uniform(1.0, 3.0))
|
||||
else:
|
||||
delay = float(self._rng.uniform(3.0, 8.0))
|
||||
|
||||
self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
|
||||
return delay
|
||||
|
||||
# ==================== Internal ====================
|
||||
|
||||
def _normalize_risk_score(self, risk_score: Any) -> int:
|
||||
try:
|
||||
score = int(risk_score)
|
||||
except Exception:
|
||||
score = 0
|
||||
return max(0, min(100, score))
|
||||
|
||||
def _classify_endpoint(self, endpoint_lc: str) -> str:
|
||||
if not endpoint_lc:
|
||||
return "generic"
|
||||
|
||||
# 先匹配更具体的:注册 / 登录
|
||||
if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
|
||||
return "register"
|
||||
if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
|
||||
return "login"
|
||||
|
||||
# 邮件相关:发送验证码 / 重置密码 / 重发验证等
|
||||
if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
|
||||
return "email"
|
||||
|
||||
return "generic"
|
||||
|
||||
def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
|
||||
if not isinstance(original_data, dict) or not original_data:
|
||||
return base
|
||||
|
||||
# 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
|
||||
safe_bool_keys = {"need_verify", "need_captcha"}
|
||||
|
||||
merged = dict(base)
|
||||
for key in safe_bool_keys:
|
||||
if key in original_data and key not in merged:
|
||||
try:
|
||||
merged[key] = bool(original_data.get(key))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return merged
|
||||
|
||||
307
security/middleware.py
Normal file
307
security/middleware.py
Normal file
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask import g, jsonify, request
|
||||
from flask_login import current_user
|
||||
|
||||
from app_logger import get_logger
|
||||
from app_security import get_rate_limit_ip
|
||||
|
||||
from .blacklist import BlacklistManager
|
||||
from .honeypot import HoneypotResponder
|
||||
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
|
||||
from .risk_scorer import RiskScorer
|
||||
from .threat_detector import ThreatDetector, ThreatResult
|
||||
|
||||
# 全局实例(保持单例,避免重复初始化开销)
|
||||
detector = ThreatDetector()
|
||||
blacklist = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=blacklist)
|
||||
handler: Optional[ResponseHandler] = None
|
||||
honeypot: Optional[HoneypotResponder] = None
|
||||
|
||||
|
||||
def _get_handler() -> ResponseHandler:
|
||||
global handler
|
||||
if handler is None:
|
||||
handler = ResponseHandler()
|
||||
return handler
|
||||
|
||||
|
||||
def _get_honeypot() -> HoneypotResponder:
|
||||
global honeypot
|
||||
if honeypot is None:
|
||||
honeypot = HoneypotResponder()
|
||||
return honeypot
|
||||
|
||||
|
||||
def _get_security_log_level(app) -> int:
|
||||
level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
|
||||
return int(getattr(logging, level_name, logging.INFO))
|
||||
|
||||
|
||||
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
|
||||
"""按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
|
||||
try:
|
||||
logger = get_logger("app")
|
||||
min_level = _get_security_log_level(app)
|
||||
if int(level) >= int(min_level):
|
||||
logger.log(int(level), message, *args, exc_info=exc_info)
|
||||
except Exception:
|
||||
# 安全模块日志故障不得影响正常请求
|
||||
return
|
||||
|
||||
|
||||
def _is_static_request(app) -> bool:
|
||||
"""对静态文件请求跳过安全检查以提升性能。"""
|
||||
try:
|
||||
path = str(getattr(request, "path", "") or "")
|
||||
except Exception:
|
||||
path = ""
|
||||
|
||||
if path.startswith("/static/"):
|
||||
return True
|
||||
|
||||
try:
|
||||
static_url_path = getattr(app, "static_url_path", None) or "/static"
|
||||
if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
endpoint = getattr(request, "endpoint", None)
|
||||
if endpoint in {"static", "serve_static"}:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _safe_get_user_id() -> Optional[int]:
|
||||
try:
|
||||
if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
|
||||
return getattr(current_user, "id", None)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _scan_request_threats(req) -> list[ThreatResult]:
|
||||
"""仅扫描 GET query 与 POST JSON body(降低开销与误报)。"""
|
||||
threats: list[ThreatResult] = []
|
||||
|
||||
try:
|
||||
# 1) Query 参数(所有方法均可能携带 query string)
|
||||
try:
|
||||
args = getattr(req, "args", None)
|
||||
if args:
|
||||
# MultiDict -> dict(list) 以保留多值
|
||||
args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
|
||||
threats.extend(detector.scan_input(args_dict, "args"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) JSON body(主要针对 POST;其他方法保持兼容)
|
||||
try:
|
||||
method = str(getattr(req, "method", "") or "").upper()
|
||||
except Exception:
|
||||
method = ""
|
||||
|
||||
if method in {"POST", "PUT", "PATCH", "DELETE"}:
|
||||
try:
|
||||
data = req.get_json(silent=True) if hasattr(req, "get_json") else None
|
||||
except Exception:
|
||||
data = None
|
||||
if data is not None:
|
||||
threats.extend(detector.scan_input(data, "json"))
|
||||
except Exception:
|
||||
# 扫描失败不应阻断业务
|
||||
return []
|
||||
|
||||
threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
|
||||
return threats
|
||||
|
||||
|
||||
def init_security_middleware(app):
|
||||
"""初始化安全中间件到 Flask 应用。"""
|
||||
try:
|
||||
scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@app.before_request
|
||||
def security_check():
|
||||
if not bool(app.config.get("SECURITY_ENABLED", True)):
|
||||
return None
|
||||
if _is_static_request(app):
|
||||
return None
|
||||
|
||||
try:
|
||||
ip = get_rate_limit_ip()
|
||||
except Exception:
|
||||
ip = getattr(request, "remote_addr", "") or ""
|
||||
|
||||
user_id = _safe_get_user_id()
|
||||
|
||||
# 默认值,确保后续逻辑可用
|
||||
g.risk_score = 0
|
||||
g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
|
||||
g.honeypot_mode = False
|
||||
g.honeypot_endpoint = None
|
||||
g.honeypot_generated = False
|
||||
|
||||
try:
|
||||
# 1) 检查黑名单(静默拒绝,返回通用错误)
|
||||
try:
|
||||
if blacklist.is_ip_banned(ip):
|
||||
_log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
|
||||
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
|
||||
|
||||
try:
|
||||
if user_id is not None and blacklist.is_user_banned(user_id):
|
||||
_log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
|
||||
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
|
||||
|
||||
# 2) 扫描威胁(GET query / POST JSON)
|
||||
threats = _scan_request_threats(request)
|
||||
|
||||
if threats:
|
||||
max_threat = threats[0]
|
||||
_log(
|
||||
app,
|
||||
logging.WARNING,
|
||||
"威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
|
||||
ip,
|
||||
user_id,
|
||||
getattr(max_threat, "threat_type", "unknown"),
|
||||
getattr(max_threat, "score", 0),
|
||||
getattr(max_threat, "field_name", ""),
|
||||
getattr(max_threat, "rule", ""),
|
||||
)
|
||||
|
||||
# 记录威胁事件(异常不应阻断业务)
|
||||
try:
|
||||
payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
|
||||
scorer.record_threat(
|
||||
ip=ip,
|
||||
user_id=user_id,
|
||||
threat_type=getattr(max_threat, "threat_type", "unknown"),
|
||||
score=int(getattr(max_threat, "score", 0) or 0),
|
||||
request_path=getattr(request, "path", None),
|
||||
payload=str(payload)[:500] if payload else None,
|
||||
)
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
|
||||
|
||||
# 高危威胁启用蜜罐模式
|
||||
if bool(app.config.get("HONEYPOT_ENABLED", True)):
|
||||
try:
|
||||
if int(getattr(max_threat, "score", 0) or 0) >= 80:
|
||||
g.honeypot_mode = True
|
||||
g.honeypot_endpoint = getattr(request, "endpoint", None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3) 综合风险分与响应策略
|
||||
try:
|
||||
risk_score = scorer.get_combined_score(ip, user_id)
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
|
||||
risk_score = 0
|
||||
|
||||
try:
|
||||
strategy = _get_handler().get_strategy(risk_score)
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
|
||||
strategy = ResponseStrategy(action=ResponseAction.ALLOW)
|
||||
|
||||
g.risk_score = int(risk_score or 0)
|
||||
g.response_strategy = strategy
|
||||
|
||||
# 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
|
||||
if bool(app.config.get("HONEYPOT_ENABLED", True)):
|
||||
try:
|
||||
if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
|
||||
g.honeypot_mode = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 4) 应用延迟
|
||||
try:
|
||||
if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
|
||||
_get_handler().apply_delay(strategy)
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "延迟应用失败", exc_info=True)
|
||||
|
||||
# 优先短路:避免业务 side effects(例如发送邮件/修改状态)
|
||||
if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
|
||||
try:
|
||||
fake_payload = None
|
||||
try:
|
||||
fake_payload = request.get_json(silent=True)
|
||||
except Exception:
|
||||
fake_payload = None
|
||||
fake_response = _get_honeypot().generate_fake_response(
|
||||
getattr(g, "honeypot_endpoint", "default"),
|
||||
fake_payload if isinstance(fake_payload, dict) else None,
|
||||
)
|
||||
g.honeypot_generated = True
|
||||
return jsonify(fake_response), 200
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
|
||||
return None
|
||||
except Exception:
|
||||
# 全局兜底:安全模块任何异常都不能阻断正常请求
|
||||
_log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
|
||||
return None
|
||||
|
||||
return None # 继续正常处理
|
||||
|
||||
@app.after_request
|
||||
def security_response(response):
|
||||
"""请求后处理 - 兜底应用蜜罐响应。"""
|
||||
if not bool(app.config.get("SECURITY_ENABLED", True)):
|
||||
return response
|
||||
if not bool(app.config.get("HONEYPOT_ENABLED", True)):
|
||||
return response
|
||||
|
||||
try:
|
||||
if _is_static_request(app):
|
||||
return response
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
|
||||
try:
|
||||
if getattr(g, "honeypot_generated", False):
|
||||
return response
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if getattr(g, "honeypot_mode", False):
|
||||
fake_payload = None
|
||||
try:
|
||||
fake_payload = request.get_json(silent=True)
|
||||
except Exception:
|
||||
fake_payload = None
|
||||
fake_response = _get_honeypot().generate_fake_response(
|
||||
getattr(g, "honeypot_endpoint", "default"),
|
||||
fake_payload if isinstance(fake_payload, dict) else None,
|
||||
)
|
||||
return jsonify(fake_response), 200
|
||||
except Exception:
|
||||
_log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
|
||||
return response
|
||||
|
||||
return response
|
||||
131
security/response_handler.py
Normal file
131
security/response_handler.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from app_logger import get_logger
|
||||
|
||||
|
||||
class ResponseAction(Enum):
|
||||
ALLOW = "allow" # 正常放行
|
||||
ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
|
||||
DELAY = "delay" # 静默延迟
|
||||
HONEYPOT = "honeypot" # 蜜罐响应
|
||||
BLOCK = "block" # 直接拒绝
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseStrategy:
|
||||
action: ResponseAction
|
||||
delay_seconds: float = 0
|
||||
captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class ResponseHandler:
|
||||
"""响应策略处理器"""
|
||||
|
||||
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
|
||||
self._rng = rng or random.SystemRandom()
|
||||
self._logger = get_logger("app")
|
||||
|
||||
def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
|
||||
"""
|
||||
根据风险分获取响应策略
|
||||
|
||||
0-20分: ALLOW, 无延迟, 普通验证码
|
||||
21-40分: ALLOW, 无延迟, 6位验证码
|
||||
41-60分: DELAY, 1-2秒延迟
|
||||
61-80分: DELAY, 2-5秒延迟
|
||||
81-100分: HONEYPOT, 3-8秒延迟
|
||||
已封禁: BLOCK
|
||||
"""
|
||||
score = self._normalize_risk_score(risk_score)
|
||||
|
||||
if is_banned:
|
||||
strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
|
||||
self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
|
||||
return strategy
|
||||
|
||||
if score <= 20:
|
||||
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
|
||||
elif score <= 40:
|
||||
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
|
||||
elif score <= 60:
|
||||
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
|
||||
elif score <= 80:
|
||||
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
|
||||
else:
|
||||
strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
|
||||
|
||||
strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
|
||||
|
||||
self._logger.info(
|
||||
"响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
|
||||
strategy.action.value,
|
||||
score,
|
||||
float(strategy.delay_seconds or 0),
|
||||
int(strategy.captcha_level),
|
||||
)
|
||||
return strategy
|
||||
|
||||
def apply_delay(self, strategy: ResponseStrategy):
|
||||
"""应用延迟(使用time.sleep)"""
|
||||
if strategy is None:
|
||||
return
|
||||
delay = 0.0
|
||||
try:
|
||||
delay = float(getattr(strategy, "delay_seconds", 0) or 0)
|
||||
except Exception:
|
||||
delay = 0.0
|
||||
|
||||
if delay <= 0:
|
||||
return
|
||||
|
||||
self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
|
||||
time.sleep(delay)
|
||||
|
||||
def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
|
||||
"""返回验证码要求 {"required": True, "level": 2}"""
|
||||
level = 1
|
||||
try:
|
||||
level = int(getattr(strategy, "captcha_level", 1) or 1)
|
||||
except Exception:
|
||||
level = 1
|
||||
level = self._normalize_captcha_level(level)
|
||||
|
||||
required = True
|
||||
try:
|
||||
required = getattr(strategy, "action", None) != ResponseAction.BLOCK
|
||||
except Exception:
|
||||
required = True
|
||||
|
||||
payload = {"required": bool(required), "level": level}
|
||||
self._logger.debug("验证码要求: %s", payload)
|
||||
return payload
|
||||
|
||||
# ==================== Internal ====================
|
||||
|
||||
def _normalize_risk_score(self, risk_score: Any) -> int:
|
||||
try:
|
||||
score = int(risk_score)
|
||||
except Exception:
|
||||
score = 0
|
||||
return max(0, min(100, score))
|
||||
|
||||
def _normalize_captcha_level(self, level: Any) -> int:
|
||||
try:
|
||||
i = int(level)
|
||||
except Exception:
|
||||
i = 1
|
||||
if i <= 1:
|
||||
return 1
|
||||
if i == 2:
|
||||
return 2
|
||||
return 3
|
||||
|
||||
389
security/risk_scorer.py
Normal file
389
security/risk_scorer.py
Normal file
@@ -0,0 +1,389 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import db_pool
|
||||
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
|
||||
|
||||
from . import constants as C
|
||||
from .blacklist import BlacklistManager
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ScoreUpdateResult:
|
||||
ip_score: int
|
||||
user_score: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _BanAction:
|
||||
reason: str
|
||||
duration_hours: Optional[int] = None
|
||||
permanent: bool = False
|
||||
|
||||
|
||||
class RiskScorer:
|
||||
"""风险评分引擎 - 计算IP和用户的风险分数"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
auto_ban_enabled: bool = True,
|
||||
auto_ban_duration_hours: int = 24,
|
||||
high_risk_threshold: int = 80,
|
||||
high_risk_window_hours: int = 1,
|
||||
high_risk_permanent_ban_count: int = 3,
|
||||
blacklist_manager: Optional[BlacklistManager] = None,
|
||||
) -> None:
|
||||
self.auto_ban_enabled = bool(auto_ban_enabled)
|
||||
self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
|
||||
self.high_risk_threshold = max(0, int(high_risk_threshold))
|
||||
self.high_risk_window_hours = max(1, int(high_risk_window_hours))
|
||||
self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
|
||||
self.blacklist = blacklist_manager or BlacklistManager()
|
||||
|
||||
def get_ip_score(self, ip_address: str) -> int:
|
||||
"""获取IP风险分(0-100),从数据库读取"""
|
||||
ip_text = str(ip_address or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return 0
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
try:
|
||||
return max(0, min(100, int(row["risk_score"])))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def get_user_score(self, user_id: int) -> int:
|
||||
"""获取用户风险分(0-100)"""
|
||||
if user_id is None:
|
||||
return 0
|
||||
user_id_int = int(user_id)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return 0
|
||||
try:
|
||||
return max(0, min(100, int(row["risk_score"])))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def get_combined_score(self, ip: str, user_id: int = None) -> int:
|
||||
"""综合风险分 = max(IP分, 用户分) + 行为加成"""
|
||||
base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
|
||||
bonus = self._get_behavior_bonus(ip, user_id)
|
||||
return max(0, min(100, int(base + bonus)))
|
||||
|
||||
def record_threat(
|
||||
self,
|
||||
ip: str,
|
||||
user_id: int,
|
||||
threat_type: str,
|
||||
score: int,
|
||||
request_path: str = None,
|
||||
payload: str = None,
|
||||
):
|
||||
"""记录威胁事件到数据库,并更新IP/用户风险分"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
user_id_int = int(user_id) if user_id is not None else None
|
||||
threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
|
||||
score_int = max(0, int(score))
|
||||
path_text = str(request_path or "").strip()[:512] if request_path else None
|
||||
payload_text = str(payload or "").strip() if payload else None
|
||||
if payload_text and len(payload_text) > 2048:
|
||||
payload_text = payload_text[:2048]
|
||||
|
||||
now_str = get_cst_now_str()
|
||||
|
||||
ip_ban_action: Optional[_BanAction] = None
|
||||
user_ban_action: Optional[_BanAction] = None
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO threat_events (
|
||||
threat_type, score, ip, user_id, request_path, value_preview, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
threat_type_text,
|
||||
score_int,
|
||||
ip_text or None,
|
||||
user_id_int,
|
||||
path_text,
|
||||
payload_text,
|
||||
now_str,
|
||||
),
|
||||
)
|
||||
|
||||
update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
|
||||
|
||||
if self.auto_ban_enabled:
|
||||
ip_ban_action, user_ban_action = self._get_auto_ban_actions(
|
||||
cursor,
|
||||
ip_text,
|
||||
user_id_int,
|
||||
threat_type_text,
|
||||
score_int,
|
||||
update,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
if not self.auto_ban_enabled:
|
||||
return
|
||||
|
||||
if ip_ban_action and ip_text:
|
||||
self.blacklist.ban_ip(
|
||||
ip_text,
|
||||
reason=ip_ban_action.reason,
|
||||
duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
|
||||
permanent=ip_ban_action.permanent,
|
||||
)
|
||||
if user_ban_action and user_id_int is not None:
|
||||
self.blacklist._ban_user_internal(
|
||||
user_id_int,
|
||||
reason=user_ban_action.reason,
|
||||
duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
|
||||
permanent=user_ban_action.permanent,
|
||||
)
|
||||
|
||||
def decay_scores(self):
|
||||
"""风险分衰减 - 定期调用,降低历史风险分"""
|
||||
now = get_cst_now()
|
||||
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
|
||||
for row in cursor.fetchall():
|
||||
ip = row["ip"]
|
||||
current_score = int(row["risk_score"] or 0)
|
||||
updated_at = row["updated_at"] or row["created_at"]
|
||||
hours = self._hours_since(updated_at, now)
|
||||
if hours <= 0:
|
||||
continue
|
||||
new_score = self._apply_hourly_decay(current_score, hours)
|
||||
if new_score == current_score:
|
||||
continue
|
||||
cursor.execute(
|
||||
"UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
|
||||
(new_score, now_str, ip),
|
||||
)
|
||||
|
||||
cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
|
||||
for row in cursor.fetchall():
|
||||
user_id = int(row["user_id"])
|
||||
current_score = int(row["risk_score"] or 0)
|
||||
updated_at = row["updated_at"] or row["created_at"]
|
||||
hours = self._hours_since(updated_at, now)
|
||||
if hours <= 0:
|
||||
continue
|
||||
new_score = self._apply_hourly_decay(current_score, hours)
|
||||
if new_score == current_score:
|
||||
continue
|
||||
cursor.execute(
|
||||
"UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
|
||||
(new_score, now_str, user_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
|
||||
def _update_ip_score(self, ip: str, score_delta: int):
|
||||
"""更新IP风险分"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return
|
||||
delta = int(score_delta)
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
self._update_scores(cursor, ip_text, None, delta, now_str)
|
||||
conn.commit()
|
||||
|
||||
def _update_user_score(self, user_id: int, score_delta: int):
|
||||
"""更新用户风险分"""
|
||||
if user_id is None:
|
||||
return
|
||||
user_id_int = int(user_id)
|
||||
delta = int(score_delta)
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
self._update_scores(cursor, "", user_id_int, delta, now_str)
|
||||
conn.commit()
|
||||
|
||||
def reset_ip_score(self, ip: str) -> bool:
|
||||
"""清零指定IP的风险分"""
|
||||
ip_text = str(ip or "").strip()[:64]
|
||||
if not ip_text:
|
||||
return False
|
||||
|
||||
now_str = get_cst_now_str()
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT ip FROM ip_risk_scores WHERE ip = ?", (ip_text,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
cursor.execute(
|
||||
"UPDATE ip_risk_scores SET risk_score = 0, last_seen = ?, updated_at = ? WHERE ip = ?",
|
||||
(now_str, now_str, ip_text),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 0, ?, ?, ?)
|
||||
""",
|
||||
(ip_text, now_str, now_str, now_str),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
def _update_scores(
|
||||
self,
|
||||
cursor,
|
||||
ip: str,
|
||||
user_id: Optional[int],
|
||||
score_delta: int,
|
||||
now_str: str,
|
||||
) -> _ScoreUpdateResult:
|
||||
ip_score = 0
|
||||
user_score = 0
|
||||
|
||||
if ip:
|
||||
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
current = int(row["risk_score"]) if row else 0
|
||||
ip_score = max(0, min(100, current + int(score_delta)))
|
||||
if row:
|
||||
cursor.execute(
|
||||
"UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
|
||||
(ip_score, now_str, now_str, ip),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(ip, ip_score, now_str, now_str, now_str),
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
|
||||
row = cursor.fetchone()
|
||||
current = int(row["risk_score"]) if row else 0
|
||||
user_score = max(0, min(100, current + int(score_delta)))
|
||||
if row:
|
||||
cursor.execute(
|
||||
"UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
|
||||
(user_score, now_str, now_str, int(user_id)),
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(int(user_id), user_score, now_str, now_str, now_str),
|
||||
)
|
||||
|
||||
return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
|
||||
|
||||
def _get_auto_ban_actions(
|
||||
self,
|
||||
cursor,
|
||||
ip: str,
|
||||
user_id: Optional[int],
|
||||
threat_type: str,
|
||||
score: int,
|
||||
update: _ScoreUpdateResult,
|
||||
) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
|
||||
ip_action: Optional[_BanAction] = None
|
||||
user_action: Optional[_BanAction] = None
|
||||
|
||||
if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
|
||||
if ip:
|
||||
ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
|
||||
if user_id is not None:
|
||||
user_action = _BanAction(reason="JNDI injection detected", permanent=True)
|
||||
return ip_action, user_action
|
||||
|
||||
if ip and update.ip_score >= 100:
|
||||
ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
|
||||
if user_id is not None and update.user_score >= 100:
|
||||
user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
|
||||
|
||||
if score < self.high_risk_threshold:
|
||||
return ip_action, user_action
|
||||
|
||||
window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if ip:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM threat_events
|
||||
WHERE ip = ? AND score >= ? AND created_at >= ?
|
||||
""",
|
||||
(ip, int(self.high_risk_threshold), window_start),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cnt = int(row["cnt"]) if row else 0
|
||||
if cnt >= self.high_risk_permanent_ban_count:
|
||||
ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
|
||||
|
||||
if user_id is not None:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) AS cnt
|
||||
FROM threat_events
|
||||
WHERE user_id = ? AND score >= ? AND created_at >= ?
|
||||
""",
|
||||
(int(user_id), int(self.high_risk_threshold), window_start),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cnt = int(row["cnt"]) if row else 0
|
||||
if cnt >= self.high_risk_permanent_ban_count:
|
||||
user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
|
||||
|
||||
return ip_action, user_action
|
||||
|
||||
def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
|
||||
return 0
|
||||
|
||||
def _hours_since(self, dt_str: Optional[str], now) -> int:
|
||||
if not dt_str:
|
||||
return 0
|
||||
try:
|
||||
dt = parse_cst_datetime(str(dt_str))
|
||||
except Exception:
|
||||
return 0
|
||||
seconds = (now - dt).total_seconds()
|
||||
if seconds <= 0:
|
||||
return 0
|
||||
return int(seconds // 3600)
|
||||
|
||||
def _apply_hourly_decay(self, score: int, hours: int) -> int:
|
||||
score_int = max(0, int(score))
|
||||
if score_int <= 0 or hours <= 0:
|
||||
return score_int
|
||||
decayed = int(math.floor(score_int * (0.9**int(hours))))
|
||||
return max(0, min(100, decayed))
|
||||
410
security/threat_detector.py
Normal file
410
security/threat_detector.py
Normal file
@@ -0,0 +1,410 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
from urllib.parse import unquote_plus
|
||||
|
||||
from . import constants as C
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreatResult:
|
||||
threat_type: str
|
||||
score: int
|
||||
field_name: str
|
||||
rule: str = ""
|
||||
matched: str = ""
|
||||
value_preview: str = ""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"threat_type": self.threat_type,
|
||||
"score": int(self.score),
|
||||
"field_name": self.field_name,
|
||||
"rule": self.rule,
|
||||
"matched": self.matched,
|
||||
"value_preview": self.value_preview,
|
||||
}
|
||||
|
||||
|
||||
class ThreatDetector:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_value_length: int = 4096,
|
||||
max_decode_rounds: int = 2,
|
||||
) -> None:
|
||||
self.max_value_length = max(64, int(max_value_length))
|
||||
self.max_decode_rounds = max(0, int(max_decode_rounds))
|
||||
|
||||
def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
|
||||
"""扫描单个输入值(支持 dict/list 等嵌套结构)。"""
|
||||
results: List[ThreatResult] = []
|
||||
for sub_field, leaf in self._flatten_value(value, field_name):
|
||||
text = self._stringify(leaf)
|
||||
if not text:
|
||||
continue
|
||||
if len(text) > self.max_value_length:
|
||||
text = text[: self.max_value_length]
|
||||
results.extend(self._scan_text(text, sub_field))
|
||||
results.sort(key=lambda r: int(r.score), reverse=True)
|
||||
return results
|
||||
|
||||
def scan_request(self, request: Any) -> List[ThreatResult]:
|
||||
"""扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
|
||||
results: List[ThreatResult] = []
|
||||
for field_name, value in self._extract_request_fields(request):
|
||||
results.extend(self.scan_input(value, field_name))
|
||||
results.sort(key=lambda r: int(r.score), reverse=True)
|
||||
return results
|
||||
|
||||
# ==================== Internal scanning ====================
|
||||
|
||||
def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
|
||||
hits: List[ThreatResult] = []
|
||||
|
||||
for check in [
|
||||
self._check_jndi_injection,
|
||||
self._check_sql_injection,
|
||||
self._check_xss,
|
||||
self._check_path_traversal,
|
||||
self._check_command_injection,
|
||||
self._check_ssrf,
|
||||
self._check_xxe,
|
||||
self._check_template_injection,
|
||||
self._check_sensitive_path_probe,
|
||||
]:
|
||||
result = check(text)
|
||||
if result:
|
||||
threat_type, score, rule, matched = result
|
||||
hits.append(
|
||||
ThreatResult(
|
||||
threat_type=threat_type,
|
||||
score=int(score),
|
||||
field_name=field_name,
|
||||
rule=rule,
|
||||
matched=matched,
|
||||
value_preview=self._preview(text),
|
||||
)
|
||||
)
|
||||
|
||||
return hits
|
||||
|
||||
def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
# 1) Direct match
|
||||
m = C.JNDI_DIRECT_RE.search(text)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
|
||||
|
||||
# 2) URL-decoded
|
||||
decoded = self._multi_unquote(text)
|
||||
if decoded != text:
|
||||
m2 = C.JNDI_DIRECT_RE.search(decoded)
|
||||
if m2:
|
||||
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
|
||||
|
||||
# 3) Obfuscation patterns (raw/decoded)
|
||||
for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
|
||||
m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
|
||||
if m3:
|
||||
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
|
||||
|
||||
# 4) Try limited de-obfuscation to reveal ${jndi:...}
|
||||
deobf = self._deobfuscate_log4j(decoded)
|
||||
if deobf and deobf != decoded:
|
||||
m4 = C.JNDI_DIRECT_RE.search(deobf)
|
||||
if m4:
|
||||
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
|
||||
|
||||
# 5) Nested expression heuristic
|
||||
for candidate in [text, decoded]:
|
||||
m5 = C.NESTED_EXPRESSION_RE.search(candidate)
|
||||
if m5:
|
||||
return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
|
||||
|
||||
return None
|
||||
|
||||
def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
candidates = [text, self._multi_unquote(text)]
|
||||
for candidate in candidates:
|
||||
m = C.SQLI_UNION_SELECT_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
|
||||
m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
|
||||
return None
|
||||
|
||||
def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
candidates = [text, self._multi_unquote(text)]
|
||||
for candidate in candidates:
|
||||
m = C.XSS_SCRIPT_TAG_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
|
||||
m = C.XSS_JS_PROTOCOL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
|
||||
m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
|
||||
return None
|
||||
|
||||
def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates = [text, decoded]
|
||||
for candidate in candidates:
|
||||
m = C.PATH_TRAVERSAL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
|
||||
return None
|
||||
|
||||
def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates = [text, decoded]
|
||||
for candidate in candidates:
|
||||
m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
|
||||
m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
|
||||
return None
|
||||
|
||||
def _check_ssrf(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates: List[Tuple[str, str]] = [(text, "")]
|
||||
if decoded != text:
|
||||
candidates.append((decoded, "_URL_DECODED"))
|
||||
|
||||
for candidate, suffix in candidates:
|
||||
m = C.SSRF_LOCALHOST_URL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_LOCALHOST{suffix}", m.group(0))
|
||||
m = C.SSRF_INTERNAL_IP_URL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_INTERNAL_IP{suffix}", m.group(0))
|
||||
m = C.SSRF_DANGEROUS_PROTOCOL_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_SSRF, C.SCORE_SSRF, f"SSRF_DANGEROUS_PROTOCOL{suffix}", m.group(0))
|
||||
|
||||
return None
|
||||
|
||||
def _check_xxe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates: List[Tuple[str, str]] = [(text, "")]
|
||||
if decoded != text:
|
||||
candidates.append((decoded, "_URL_DECODED"))
|
||||
|
||||
for candidate, suffix in candidates:
|
||||
m_doctype = C.XXE_DOCTYPE_RE.search(candidate)
|
||||
if not m_doctype:
|
||||
continue
|
||||
m_entity = C.XXE_ENTITY_RE.search(candidate)
|
||||
if not m_entity:
|
||||
continue
|
||||
m_sys_pub = C.XXE_SYSTEM_PUBLIC_RE.search(candidate)
|
||||
if not m_sys_pub:
|
||||
continue
|
||||
matched = f"{m_doctype.group(0)} {m_entity.group(0)} {m_sys_pub.group(0)}"
|
||||
return (C.THREAT_TYPE_XXE, C.SCORE_XXE, f"XXE_KEYWORD_COMBO{suffix}", matched)
|
||||
|
||||
return None
|
||||
|
||||
def _check_template_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates: List[Tuple[str, str]] = [(text, "")]
|
||||
if decoded != text:
|
||||
candidates.append((decoded, "_URL_DECODED"))
|
||||
|
||||
for candidate, suffix in candidates:
|
||||
m = C.TEMPLATE_JINJA_EXPR_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_EXPR{suffix}", m.group(0))
|
||||
m = C.TEMPLATE_JINJA_STMT_RE.search(candidate)
|
||||
if m:
|
||||
return (C.THREAT_TYPE_TEMPLATE_INJECTION, C.SCORE_TEMPLATE_INJECTION, f"TEMPLATE_JINJA_STMT{suffix}", m.group(0))
|
||||
m = C.TEMPLATE_VELOCITY_DIRECTIVE_RE.search(candidate)
|
||||
if m:
|
||||
return (
|
||||
C.THREAT_TYPE_TEMPLATE_INJECTION,
|
||||
C.SCORE_TEMPLATE_INJECTION,
|
||||
f"TEMPLATE_VELOCITY_DIRECTIVE{suffix}",
|
||||
m.group(0),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _check_sensitive_path_probe(self, text: str) -> Optional[Tuple[str, int, str, str]]:
|
||||
decoded = self._multi_unquote(text)
|
||||
candidates: List[Tuple[str, str]] = [(text, "")]
|
||||
if decoded != text:
|
||||
candidates.append((decoded, "_URL_DECODED"))
|
||||
|
||||
for candidate, suffix in candidates:
|
||||
m = C.SENSITIVE_PATH_DOTFILES_RE.search(candidate)
|
||||
if m:
|
||||
return (
|
||||
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
|
||||
C.SCORE_SENSITIVE_PATH_PROBE,
|
||||
f"SENSITIVE_PATH_DOTFILES{suffix}",
|
||||
m.group(0),
|
||||
)
|
||||
m = C.SENSITIVE_PATH_PROBE_RE.search(candidate)
|
||||
if m:
|
||||
return (
|
||||
C.THREAT_TYPE_SENSITIVE_PATH_PROBE,
|
||||
C.SCORE_SENSITIVE_PATH_PROBE,
|
||||
f"SENSITIVE_PATH_PROBE{suffix}",
|
||||
m.group(0),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
# ==================== Helpers ====================
|
||||
|
||||
def _preview(self, text: str, limit: int = 160) -> str:
|
||||
s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
|
||||
if len(s) <= limit:
|
||||
return s
|
||||
return s[: limit - 3] + "..."
|
||||
|
||||
def _stringify(self, value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bytes):
|
||||
try:
|
||||
return value.decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
return ""
|
||||
try:
|
||||
return str(value)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _multi_unquote(self, text: str) -> str:
|
||||
s = text
|
||||
for _ in range(self.max_decode_rounds):
|
||||
try:
|
||||
nxt = unquote_plus(s)
|
||||
except Exception:
|
||||
break
|
||||
if nxt == s:
|
||||
break
|
||||
s = nxt
|
||||
return s
|
||||
|
||||
def _deobfuscate_log4j(self, text: str) -> str:
|
||||
# Replace ${...:-x} with x (including ${::-x}).
|
||||
# This is intentionally conservative to reduce false positives.
|
||||
import re
|
||||
|
||||
s = text
|
||||
pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
|
||||
for _ in range(3):
|
||||
nxt = pattern.sub(lambda m: m.group(1), s)
|
||||
if nxt == s:
|
||||
break
|
||||
s = nxt
|
||||
return s
|
||||
|
||||
def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
key = self._stringify(k) or "key"
|
||||
sub_name = f"{field_name}.{key}" if field_name else key
|
||||
yield from self._flatten_value(v, sub_name)
|
||||
return
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
for i, v in enumerate(value):
|
||||
sub_name = f"{field_name}[{i}]"
|
||||
yield from self._flatten_value(v, sub_name)
|
||||
return
|
||||
yield (field_name, value)
|
||||
|
||||
def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
|
||||
# dict-like input (useful for unit tests / non-Flask callers)
|
||||
if isinstance(request, dict):
|
||||
out: List[Tuple[str, Any]] = []
|
||||
for k, v in request.items():
|
||||
out.append((self._stringify(k) or "request", v))
|
||||
return out
|
||||
|
||||
out: List[Tuple[str, Any]] = []
|
||||
|
||||
# path / method
|
||||
for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
|
||||
try:
|
||||
v = getattr(request, attr_name, None)
|
||||
except Exception:
|
||||
v = None
|
||||
if v:
|
||||
out.append((attr_name, v))
|
||||
|
||||
# args / form (Flask MultiDict)
|
||||
out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
|
||||
out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
|
||||
|
||||
# headers
|
||||
try:
|
||||
headers = getattr(request, "headers", None)
|
||||
if headers is not None:
|
||||
try:
|
||||
items = headers.items()
|
||||
except Exception:
|
||||
items = []
|
||||
for k, v in items:
|
||||
out.append((f"headers.{self._stringify(k)}", v))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# cookies
|
||||
try:
|
||||
cookies = getattr(request, "cookies", None)
|
||||
if isinstance(cookies, dict):
|
||||
for k, v in cookies.items():
|
||||
out.append((f"cookies.{self._stringify(k)}", v))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# json body
|
||||
data = None
|
||||
try:
|
||||
get_json = getattr(request, "get_json", None)
|
||||
if callable(get_json):
|
||||
data = get_json(silent=True)
|
||||
except Exception:
|
||||
data = None
|
||||
|
||||
if data is not None:
|
||||
for name, v in self._flatten_value(data, "json"):
|
||||
out.append((name, v))
|
||||
return out
|
||||
|
||||
# raw body (as a fallback)
|
||||
try:
|
||||
get_data = getattr(request, "get_data", None)
|
||||
if callable(get_data):
|
||||
raw = get_data(cache=True, as_text=True)
|
||||
if raw:
|
||||
out.append(("body", raw))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return out
|
||||
|
||||
def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
|
||||
out: List[Tuple[str, Any]] = []
|
||||
if md is None:
|
||||
return out
|
||||
try:
|
||||
items = md.items(multi=True)
|
||||
except Exception:
|
||||
try:
|
||||
items = md.items()
|
||||
except Exception:
|
||||
return out
|
||||
for k, v in items:
|
||||
out.append((f"{prefix}.{self._stringify(k)}", v))
|
||||
return out
|
||||
Reference in New Issue
Block a user