Files
zsglpt/security/threat_detector.py
yuyx 46253337eb feat: 实现完整安全防护系统
Phase 1 - 威胁检测引擎:
- security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测
- security/constants.py: 威胁检测规则和评分常量
- 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist

Phase 2 - 风险评分与黑名单:
- security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减
- security/blacklist.py: 黑名单管理,自动封禁规则

Phase 3 - 响应策略:
- security/honeypot.py: 蜜罐响应生成器
- security/response_handler.py: 渐进式响应策略

Phase 4 - 集成:
- security/middleware.py: Flask安全中间件
- routes/admin_api/security.py: 管理后台安全仪表板API
- 36个测试用例全部通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-27 01:28:38 +08:00

317 lines
11 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus

from . import constants as C
@dataclass
class ThreatResult:
    """A single detection hit: which rule fired, on which field, and its score."""

    # Threat category; ThreatDetector fills this from the C.THREAT_TYPE_* constants.
    threat_type: str
    # Score attached to this hit (coerced to int by to_dict()).
    score: int
    # Path of the scanned field, e.g. "args.q", "headers.User-Agent" or "json.a[0].b".
    field_name: str
    # Identifier of the rule that matched (e.g. "SQLI_UNION_SELECT").
    rule: str = ""
    # Text captured by the matching regex.
    matched: str = ""
    # Escaped, length-limited preview of the scanned value.
    value_preview: str = ""

    def to_dict(self) -> dict:
        """Serialize the hit to a plain ``dict``; ``score`` is forced to ``int``."""
        serialized = dict(
            threat_type=self.threat_type,
            score=int(self.score),
            field_name=self.field_name,
            rule=self.rule,
            matched=self.matched,
            value_preview=self.value_preview,
        )
        return serialized
class ThreatDetector:
    """Regex-based detector for injection payloads in request data.

    Every scanned string is checked for JNDI/Log4Shell, SQL injection, XSS,
    path traversal and command injection patterns (regexes and scores come
    from the ``constants`` module). Checks are applied to the raw text and to
    a URL-decoded variant of it.
    """

    # Resolves Log4j "default value" lookups such as ``${env:X:-j}`` or
    # ``${::-j}`` to the single letter they evaluate to. Compiled once here
    # rather than on every call.
    _LOG4J_DEFAULT_VALUE_RE = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")

    # Containers nested deeper than this are not expanded further; the whole
    # remaining sub-structure is yielded as one value (and stringified when
    # scanned). Keeps hostile, deeply nested JSON from exhausting the stack.
    _MAX_FLATTEN_DEPTH = 32

    def __init__(
        self,
        *,
        max_value_length: int = 4096,
        max_decode_rounds: int = 2,
    ) -> None:
        """Create a detector.

        Args:
            max_value_length: Number of characters of each value that are
                scanned; clamped to at least 64.
            max_decode_rounds: Maximum URL-decoding passes used to reveal
                encoded payloads; clamped to at least 0.
        """
        self.max_value_length = max(64, int(max_value_length))
        self.max_decode_rounds = max(0, int(max_decode_rounds))

    def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
        """Scan a single input value (nested dict/list/tuple/set values are walked).

        Args:
            value: Value to scan; containers are flattened into leaf values.
            field_name: Prefix for the ``field_name`` of every reported hit.

        Returns:
            All hits, highest score first; empty when nothing matched.
        """
        results: List[ThreatResult] = []
        for sub_field, leaf in self._flatten_value(value, field_name):
            text = self._stringify(leaf)
            if not text:
                continue
            # NOTE: only the first ``max_value_length`` characters are scanned,
            # so a payload placed after that much padding is not inspected.
            if len(text) > self.max_value_length:
                text = text[: self.max_value_length]
            results.extend(self._scan_text(text, sub_field))
        results.sort(key=lambda r: int(r.score), reverse=True)
        return results

    def scan_request(self, request: Any) -> List[ThreatResult]:
        """Scan a whole request (a Flask ``Request`` or a plain ``dict``).

        Returns:
            All hits across every extracted field, highest score first.
        """
        results: List[ThreatResult] = []
        for field_name, value in self._extract_request_fields(request):
            results.extend(self.scan_input(value, field_name))
        results.sort(key=lambda r: int(r.score), reverse=True)
        return results

    # ==================== Internal scanning ====================

    def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
        """Run every check on ``text``; each check contributes at most one hit."""
        hits: List[ThreatResult] = []
        for check in (
            self._check_jndi_injection,
            self._check_sql_injection,
            self._check_xss,
            self._check_path_traversal,
            self._check_command_injection,
        ):
            result = check(text)
            if result:
                threat_type, score, rule, matched = result
                hits.append(
                    ThreatResult(
                        threat_type=threat_type,
                        score=int(score),
                        field_name=field_name,
                        rule=rule,
                        matched=matched,
                        value_preview=self._preview(text),
                    )
                )
        return hits

    def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
        """Detect JNDI lookups, trying progressively more lenient forms.

        Returns ``(threat_type, score, rule, matched)`` for the first stage that
        matches, or ``None``.
        """
        # 1) Direct match on the raw text.
        m = C.JNDI_DIRECT_RE.search(text)
        if m:
            return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))

        # 2) Direct match after URL decoding.
        decoded = self._multi_unquote(text)
        if decoded != text:
            m2 = C.JNDI_DIRECT_RE.search(decoded)
            if m2:
                return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))

        # 3) Known obfuscation patterns on the raw and decoded text.
        for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
            m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
            if m3:
                return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))

        # 4) Limited de-obfuscation that may reveal a plain ${jndi:...}.
        deobf = self._deobfuscate_log4j(decoded)
        if deobf and deobf != decoded:
            m4 = C.JNDI_DIRECT_RE.search(deobf)
            if m4:
                return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))

        # 5) Nested-expression heuristic (reported as its own threat type).
        for candidate in [text, decoded]:
            m5 = C.NESTED_EXPRESSION_RE.search(candidate)
            if m5:
                return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
        return None

    def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
        """Detect UNION SELECT / OR 1=1 patterns in the raw or URL-decoded text."""
        candidates = [text, self._multi_unquote(text)]
        for candidate in candidates:
            m = C.SQLI_UNION_SELECT_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
            m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
        return None

    def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
        """Detect script tags, ``javascript:`` URLs and inline event handlers."""
        candidates = [text, self._multi_unquote(text)]
        for candidate in candidates:
            m = C.XSS_SCRIPT_TAG_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
            m = C.XSS_JS_PROTOCOL_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
            m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
        return None

    def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
        """Detect directory-traversal sequences in the raw or URL-decoded text."""
        decoded = self._multi_unquote(text)
        candidates = [text, decoded]
        for candidate in candidates:
            m = C.PATH_TRAVERSAL_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
        return None

    def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
        """Detect subshells and shell operators followed by a command."""
        decoded = self._multi_unquote(text)
        candidates = [text, decoded]
        for candidate in candidates:
            m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
            m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
            if m:
                return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
        return None

    # ==================== Helpers ====================

    def _preview(self, text: str, limit: int = 160) -> str:
        """Escape newlines/tabs and truncate to ``limit`` chars (ending in "...")."""
        s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
        if len(s) <= limit:
            return s
        return s[: limit - 3] + "..."

    def _stringify(self, value: Any) -> str:
        """Convert a leaf value to text; ``None`` and failures become ``""``."""
        if value is None:
            return ""
        if isinstance(value, (bytes, bytearray)):
            # errors="ignore" drops undecodable bytes instead of raising.
            return value.decode("utf-8", errors="ignore")
        try:
            return str(value)
        except Exception:
            # A hostile/broken __str__ must not abort the whole scan.
            return ""

    def _multi_unquote(self, text: str) -> str:
        """URL-decode (``+`` as space) repeatedly, up to ``max_decode_rounds``
        times or until the text stops changing, to unwrap double encoding."""
        s = text
        for _ in range(self.max_decode_rounds):
            try:
                nxt = unquote_plus(s)
            except Exception:
                break
            if nxt == s:
                break
            s = nxt
        return s

    def _deobfuscate_log4j(self, text: str) -> str:
        """Replace ``${...:-x}`` (including ``${::-x}``) with ``x``.

        Applied up to 3 times so nested wrappers unwrap. Intentionally
        conservative (single-letter defaults only) to reduce false positives.
        """
        s = text
        for _ in range(3):
            nxt = self._LOG4J_DEFAULT_VALUE_RE.sub(lambda m: m.group(1), s)
            if nxt == s:
                break
            s = nxt
        return s

    def _flatten_value(
        self,
        value: Any,
        field_name: str,
        _depth: int = 0,
        _visited: Optional[set] = None,
    ) -> Iterable[Tuple[str, Any]]:
        """Yield ``(path, leaf)`` pairs for a possibly nested value.

        Dict entries extend the path as ``parent.key`` (just ``key`` when the
        parent name is empty); list/tuple/set items extend it as ``parent[i]``.
        A mutable container seen before (a cycle or shared reference) is
        skipped, and a container at depth ``_MAX_FLATTEN_DEPTH`` is yielded
        whole instead of being expanded, so hostile input cannot recurse
        without bound.

        ``_depth`` and ``_visited`` are internal recursion state.
        """
        if not isinstance(value, (dict, list, tuple, set)):
            yield (field_name, value)
            return

        if _visited is None:
            _visited = set()
        # Cycles must pass through a mutable container; tuples are skipped
        # here because identical immutable tuples may share one object.
        track = not isinstance(value, tuple)
        if track and id(value) in _visited:
            return
        if _depth >= self._MAX_FLATTEN_DEPTH:
            yield (field_name, value)
            return
        if track:
            _visited.add(id(value))

        if isinstance(value, dict):
            for k, v in value.items():
                key = self._stringify(k) or "key"
                sub_name = f"{field_name}.{key}" if field_name else key
                yield from self._flatten_value(v, sub_name, _depth + 1, _visited)
        else:
            for i, v in enumerate(value):
                yield from self._flatten_value(v, f"{field_name}[{i}]", _depth + 1, _visited)

    @staticmethod
    def _safe_getattr(obj: Any, name: str) -> Any:
        """Like ``getattr(obj, name, None)``, but also return ``None`` when the
        attribute access itself raises (e.g. a property whose getter fails)."""
        try:
            return getattr(obj, name, None)
        except Exception:
            return None

    def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
        """Collect ``(field_name, value)`` pairs to scan from a request.

        A plain ``dict`` (handy for unit tests / non-Flask callers) maps
        directly to ``(key, value)`` pairs. Any other object is read
        Flask-style: ``method``/``path``/``full_path``/``url``/``remote_addr``,
        ``args``, ``form``, ``headers``, ``cookies``, then the parsed JSON body
        (``get_json(silent=True)``) or, when there is none, the raw body
        (``get_data(cache=True, as_text=True)``).

        Extraction is best-effort: a source that raises is skipped so that one
        unreadable part does not prevent scanning the rest.
        """
        if isinstance(request, dict):
            return [(self._stringify(k) or "request", v) for k, v in request.items()]

        out: List[Tuple[str, Any]] = []

        # Request line / connection metadata.
        for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
            v = self._safe_getattr(request, attr_name)
            if v:
                out.append((attr_name, v))

        # Query string and form body (multi-value mappings in Flask).
        out.extend(self._extract_multidict(self._safe_getattr(request, "args"), "args"))
        out.extend(self._extract_multidict(self._safe_getattr(request, "form"), "form"))

        # Headers.
        headers = self._safe_getattr(request, "headers")
        if headers is not None:
            try:
                for k, v in headers.items():
                    out.append((f"headers.{self._stringify(k)}", v))
            except Exception:
                pass  # best-effort: keep whatever was collected

        # Cookies: read every value per name so a duplicated cookie name
        # cannot hide a second value from the scan.
        out.extend(self._extract_multidict(self._safe_getattr(request, "cookies"), "cookies"))

        # JSON body; when present, the raw body is not scanned separately.
        data = None
        get_json = self._safe_getattr(request, "get_json")
        if callable(get_json):
            try:
                data = get_json(silent=True)
            except Exception:
                data = None
        if data is not None:
            out.extend(self._flatten_value(data, "json"))
            return out

        # Raw body as a fallback.
        get_data = self._safe_getattr(request, "get_data")
        if callable(get_data):
            try:
                raw = get_data(cache=True, as_text=True)
            except Exception:
                raw = None
            if raw:
                out.append(("body", raw))
        return out

    def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
        """Return ``(f"{prefix}.{key}", value)`` for every entry of ``md``.

        Tries ``md.items(multi=True)`` first (all values per key), then falls
        back to ``md.items()``; returns ``[]`` for ``None`` or when neither
        works.
        """
        out: List[Tuple[str, Any]] = []
        if md is None:
            return out
        try:
            items = md.items(multi=True)
        except Exception:
            try:
                items = md.items()
            except Exception:
                return out
        for k, v in items:
            out.append((f"{prefix}.{self._stringify(k)}", v))
        return out