#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Heuristic detection of common injection payloads in request data.

Every leaf value is scanned both as received and after bounded URL-decoding
for JNDI lookups (Log4j-style), SQL injection, XSS, path traversal and
command injection. Regexes, threat-type labels and scores come from the
sibling ``constants`` module.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus

from . import constants as C

# Result of a single rule-family check: (threat_type, score, rule_name, matched_text).
_CheckResult = Tuple[str, int, str, str]

# ``${<up to 50 non-brace chars>:-x}`` -> ``x`` (also covers ``${::-x}``).
# Kept to a single replacement letter on purpose, to limit false positives.
_LOG4J_DEFAULT_VALUE_RE = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
# ``${lower:x}`` / ``${upper:x}`` -> case-converted ``x`` (Log4j lower/upper lookups).
_LOG4J_CASE_LOOKUP_RE = re.compile(r"\$\{(lower|upper):([a-zA-Z])\}", re.IGNORECASE)


def _apply_log4j_case_lookup(m: "re.Match[str]") -> str:
    """Resolve one ``${lower:x}`` / ``${upper:x}`` match to its single character."""
    ch = m.group(2)
    return ch.lower() if m.group(1).lower() == "lower" else ch.upper()


@dataclass
class ThreatResult:
    """A single rule hit produced by :class:`ThreatDetector`."""

    threat_type: str  # one of the C.THREAT_TYPE_* labels
    score: int  # one of the C.SCORE_* values for the rule family that fired
    field_name: str  # dotted/indexed path of the scanned value, e.g. "json.a[0]"
    rule: str = ""  # rule identifier, e.g. "SQLI_UNION_SELECT"
    matched: str = ""  # exact text matched by the rule's regex
    value_preview: str = ""  # escaped, length-limited preview of the scanned value

    def to_dict(self) -> dict:
        """Return a plain-dict representation (``score`` coerced to ``int``)."""
        return {
            "threat_type": self.threat_type,
            "score": int(self.score),
            "field_name": self.field_name,
            "rule": self.rule,
            "matched": self.matched,
            "value_preview": self.value_preview,
        }


class ThreatDetector:
    """Scan individual values or whole request objects for injection payloads."""

    # Class-level default so instances built without ``__init__`` still work.
    max_nesting_depth: int = 32

    def __init__(
        self,
        *,
        max_value_length: int = 4096,
        max_decode_rounds: int = 2,
        max_nesting_depth: int = 32,
    ) -> None:
        """
        Args:
            max_value_length: Each leaf string is truncated to this many
                characters before scanning (clamped to at least 64).
            max_decode_rounds: Maximum number of ``unquote_plus`` passes used to
                reveal URL-encoded payloads (clamped to at least 0).
            max_nesting_depth: Nested dict/list/tuple/set containers deeper than
                this are scanned as a single stringified leaf instead of being
                recursed into (clamped to at least 1). This bounds recursion on
                deeply nested or self-referential input.
        """
        self.max_value_length = max(64, int(max_value_length))
        self.max_decode_rounds = max(0, int(max_decode_rounds))
        self.max_nesting_depth = max(1, int(max_nesting_depth))

    def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
        """Scan a single input value, including nested dict/list structures.

        Returns all hits across every leaf, sorted by descending score. The sort
        is stable, so ties keep rule/field order.
        """
        results: List[ThreatResult] = []
        for sub_field, leaf in self._flatten_value(value, field_name):
            text = self._stringify(leaf)
            if not text:
                continue
            if len(text) > self.max_value_length:
                text = text[: self.max_value_length]
            results.extend(self._scan_text(text, sub_field))
        results.sort(key=lambda r: int(r.score), reverse=True)
        return results

    def scan_request(self, request: Any) -> List[ThreatResult]:
        """Scan a whole request object (Flask ``Request``-like or a plain dict).

        Returns all hits, sorted by descending score.
        """
        results: List[ThreatResult] = []
        for field_name, value in self._extract_request_fields(request):
            results.extend(self.scan_input(value, field_name))
        results.sort(key=lambda r: int(r.score), reverse=True)
        return results

    # ==================== Internal scanning ====================

    def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
        """Run every rule family against *text*; at most one hit per family."""
        checks = (
            self._check_jndi_injection,
            self._check_sql_injection,
            self._check_xss,
            self._check_path_traversal,
            self._check_command_injection,
        )
        hits: List[ThreatResult] = []
        for check in checks:
            result = check(text)
            if not result:
                continue
            threat_type, score, rule, matched = result
            hits.append(
                ThreatResult(
                    threat_type=threat_type,
                    score=int(score),
                    field_name=field_name,
                    rule=rule,
                    matched=matched,
                    value_preview=self._preview(text),
                )
            )
        return hits

    @staticmethod
    def _first_match(candidates: Iterable[str], rules: Iterable[Tuple[Any, str, int, str]]) -> Optional[_CheckResult]:
        """Return the first regex hit, trying every rule on each candidate in order.

        ``rules`` holds ``(compiled_pattern, threat_type, score, rule_name)`` tuples.
        Candidates form the outer loop, so raw text is fully checked before its
        decoded form.
        """
        rules = tuple(rules)
        for candidate in candidates:
            for pattern, threat_type, score, rule in rules:
                m = pattern.search(candidate)
                if m:
                    return (threat_type, score, rule, m.group(0))
        return None

    def _check_jndi_injection(self, text: str) -> Optional[_CheckResult]:
        """Detect JNDI lookups: direct, URL-encoded, obfuscated, or nested expressions."""
        jndi = C.THREAT_TYPE_JNDI_INJECTION

        # 1) Direct match on the raw text.
        m = C.JNDI_DIRECT_RE.search(text)
        if m:
            return (jndi, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))

        # 2) Direct match after URL-decoding.
        decoded = self._multi_unquote(text)
        if decoded != text:
            m = C.JNDI_DIRECT_RE.search(decoded)
            if m:
                return (jndi, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m.group(0))

        # 3) Known obfuscation patterns, raw first and then decoded.
        for candidate, rule in ((text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")):
            m = C.JNDI_OBFUSCATED_RE.search(candidate)
            if m:
                return (jndi, C.SCORE_JNDI_OBFUSCATED, rule, m.group(0))

        # 4) Limited de-obfuscation to reveal a hidden ``${jndi:...}``.
        deobf = self._deobfuscate_log4j(decoded)
        if deobf and deobf != decoded:
            m = C.JNDI_DIRECT_RE.search(deobf)
            if m:
                return (jndi, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m.group(0))

        # 5) Nested-expression heuristic.
        for candidate in (text, decoded):
            m = C.NESTED_EXPRESSION_RE.search(candidate)
            if m:
                return (
                    C.THREAT_TYPE_NESTED_EXPRESSION,
                    C.SCORE_NESTED_EXPRESSION,
                    "NESTED_EXPRESSION",
                    m.group(0),
                )
        return None

    def _check_sql_injection(self, text: str) -> Optional[_CheckResult]:
        """Detect ``UNION SELECT`` / ``OR 1=1`` style SQL injection (raw, then decoded)."""
        return self._first_match(
            (text, self._multi_unquote(text)),
            (
                (C.SQLI_UNION_SELECT_RE, C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT"),
                (C.SQLI_OR_1_EQ_1_RE, C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1"),
            ),
        )

    def _check_xss(self, text: str) -> Optional[_CheckResult]:
        """Detect script tags, ``javascript:`` URLs and inline event handlers."""
        return self._first_match(
            (text, self._multi_unquote(text)),
            (
                (C.XSS_SCRIPT_TAG_RE, C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG"),
                (C.XSS_JS_PROTOCOL_RE, C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL"),
                (C.XSS_INLINE_EVENT_HANDLER_RE, C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER"),
            ),
        )

    def _check_path_traversal(self, text: str) -> Optional[_CheckResult]:
        """Detect directory-traversal sequences (raw, then decoded)."""
        return self._first_match(
            (text, self._multi_unquote(text)),
            ((C.PATH_TRAVERSAL_RE, C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL"),),
        )

    def _check_command_injection(self, text: str) -> Optional[_CheckResult]:
        """Detect shell subshells and operator-plus-command sequences."""
        return self._first_match(
            (text, self._multi_unquote(text)),
            (
                (C.CMD_INJECTION_SUBSHELL_RE, C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL"),
                (
                    C.CMD_INJECTION_OPERATOR_WITH_CMD_RE,
                    C.THREAT_TYPE_COMMAND_INJECTION,
                    C.SCORE_COMMAND_INJECTION,
                    "CMD_OPERATOR_WITH_CMD",
                ),
            ),
        )

    # ==================== Helpers ====================

    def _preview(self, text: str, limit: int = 160) -> str:
        """Escape CR/LF/TAB and cap the length at *limit* (with a ``...`` suffix)."""
        s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
        if len(s) <= limit:
            return s
        return s[: limit - 3] + "..."

    def _stringify(self, value: Any) -> str:
        """Convert a leaf to text.

        ``None`` maps to ``""``. Bytes are decoded as UTF-8 with invalid bytes
        dropped. Any other value goes through ``str()``, with ``""`` if that
        raises.
        """
        if value is None:
            return ""
        if isinstance(value, (bytes, bytearray)):
            # errors="ignore" cannot raise for bytes input.
            return bytes(value).decode("utf-8", errors="ignore")
        try:
            return str(value)
        except Exception:
            # Includes RecursionError from repr() of pathologically nested containers.
            return ""

    def _multi_unquote(self, text: str) -> str:
        """Apply ``unquote_plus`` up to ``max_decode_rounds`` times, stopping at a fixed point."""
        s = text
        for _ in range(self.max_decode_rounds):
            nxt = unquote_plus(s)
            if nxt == s:
                break
            s = nxt
        return s

    def _deobfuscate_log4j(self, text: str) -> str:
        """Collapse simple Log4j lookup obfuscations so a hidden ``${jndi:`` becomes visible.

        The method replaces ``${...:-x}`` (including ``${::-x}``) with ``x``. It
        replaces ``${lower:x}`` / ``${upper:x}`` with the case-converted ``x``.
        It runs at most three rounds, so nested forms unwrap from the inside out.
        It is deliberately conservative: only single-letter results are
        substituted.
        """
        s = text
        for _ in range(3):
            nxt = _LOG4J_DEFAULT_VALUE_RE.sub(r"\1", s)
            nxt = _LOG4J_CASE_LOOKUP_RE.sub(_apply_log4j_case_lookup, nxt)
            if nxt == s:
                break
            s = nxt
        return s

    def _flatten_value(self, value: Any, field_name: str, _depth: int = 0) -> Iterable[Tuple[str, Any]]:
        """Yield ``(field_path, leaf)`` pairs for nested dict/list/tuple/set input.

        Dict entries become ``parent.key`` and sequence items become
        ``parent[i]``. Containers at ``max_nesting_depth`` or deeper are yielded
        as-is, so the caller stringifies and scans them in one piece. This keeps
        recursion bounded for adversarial or cyclic input.

        Dict keys are only used to build field paths; they are not scanned.
        NOTE(review): attacker-controlled keys (e.g. JSON object keys) therefore
        bypass detection — confirm this is intended.
        """
        is_container = isinstance(value, (dict, list, tuple, set))
        if not is_container or _depth >= self.max_nesting_depth:
            yield (field_name, value)
            return
        if isinstance(value, dict):
            for k, v in value.items():
                key = self._stringify(k) or "key"
                sub_name = f"{field_name}.{key}" if field_name else key
                yield from self._flatten_value(v, sub_name, _depth + 1)
            return
        for i, v in enumerate(value):
            yield from self._flatten_value(v, f"{field_name}[{i}]", _depth + 1)

    @staticmethod
    def _safe_getattr(obj: Any, name: str) -> Any:
        """Return ``getattr(obj, name)``, or ``None`` if the lookup raises anything.

        On request objects, attribute access can run property code that may
        fail. Field extraction is best-effort, so such a failure must not abort
        the whole scan.
        """
        try:
            return getattr(obj, name, None)
        except Exception:
            return None

    def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
        """Collect ``(field_name, value)`` pairs to scan from a request.

        A plain ``dict`` is used as-is, keyed by its stringified keys. Otherwise
        the method gathers, best-effort and in this order: basic attributes,
        ``args``, ``form``, ``headers``, ``cookies``, then either the flattened
        JSON body or, when there is no JSON, the raw body text.
        """
        # dict-like input (useful for unit tests / non-Flask callers)
        if isinstance(request, dict):
            return [(self._stringify(k) or "request", v) for k, v in request.items()]

        out: List[Tuple[str, Any]] = []

        # path / method and friends
        for attr_name in ("method", "path", "full_path", "url", "remote_addr"):
            v = self._safe_getattr(request, attr_name)
            if v:
                out.append((attr_name, v))

        # query string / form (MultiDict-like)
        out.extend(self._extract_multidict(self._safe_getattr(request, "args"), "args"))
        out.extend(self._extract_multidict(self._safe_getattr(request, "form"), "form"))

        # headers: keep whatever was collected before any failure (best-effort)
        headers = self._safe_getattr(request, "headers")
        if headers is not None:
            try:
                for k, v in headers.items():
                    out.append((f"headers.{self._stringify(k)}", v))
            except Exception:
                pass

        # cookies
        cookies = self._safe_getattr(request, "cookies")
        if isinstance(cookies, dict):
            try:
                for k, v in cookies.items():
                    out.append((f"cookies.{self._stringify(k)}", v))
            except Exception:
                pass

        # JSON body: when present, it replaces the raw-body fallback below
        data = None
        get_json = self._safe_getattr(request, "get_json")
        if callable(get_json):
            try:
                data = get_json(silent=True)
            except Exception:
                data = None
        if data is not None:
            out.extend(self._flatten_value(data, "json"))
            return out

        # raw body (as a fallback)
        get_data = self._safe_getattr(request, "get_data")
        if callable(get_data):
            try:
                raw = get_data(cache=True, as_text=True)
            except Exception:
                raw = None
            if raw:
                out.append(("body", raw))
        return out

    def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
        """Return ``(prefix.key, value)`` pairs from a MultiDict-like or plain mapping.

        The method first tries ``items(multi=True)`` so repeated keys yield
        every value. It falls back to ``items()``. Extraction is best-effort:
        pairs gathered before an iteration error are kept.
        """
        out: List[Tuple[str, Any]] = []
        if md is None:
            return out
        try:
            items = md.items(multi=True)
        except Exception:
            try:
                items = md.items()
            except Exception:
                return out
        try:
            for k, v in items:
                out.append((f"{prefix}.{self._stringify(k)}", v))
        except Exception:
            pass
        return out