#!/usr/bin/env python3 # -*- coding: utf-8 -*- from __future__ import annotations import logging from typing import Optional from flask import g, jsonify, request from flask_login import current_user from app_logger import get_logger from app_security import get_rate_limit_ip from .blacklist import BlacklistManager from .honeypot import HoneypotResponder from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy from .risk_scorer import RiskScorer from .threat_detector import ThreatDetector, ThreatResult # 全局实例(保持单例,避免重复初始化开销) detector = ThreatDetector() blacklist = BlacklistManager() scorer = RiskScorer(blacklist_manager=blacklist) handler: Optional[ResponseHandler] = None honeypot: Optional[HoneypotResponder] = None def _get_handler() -> ResponseHandler: global handler if handler is None: handler = ResponseHandler() return handler def _get_honeypot() -> HoneypotResponder: global honeypot if honeypot is None: honeypot = HoneypotResponder() return honeypot def _get_security_log_level(app) -> int: level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper() return int(getattr(logging, level_name, logging.INFO)) def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None: """按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。""" try: logger = get_logger("app") min_level = _get_security_log_level(app) if int(level) >= int(min_level): logger.log(int(level), message, *args, exc_info=exc_info) except Exception: # 安全模块日志故障不得影响正常请求 return def _is_static_request(app) -> bool: """对静态文件请求跳过安全检查以提升性能。""" try: path = str(getattr(request, "path", "") or "") except Exception: path = "" if path.startswith("/static/"): return True try: static_url_path = getattr(app, "static_url_path", None) or "/static" if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"): return True except Exception: pass try: endpoint = getattr(request, "endpoint", None) if endpoint in {"static", "serve_static"}: return True except Exception: pass return False def _safe_get_user_id() -> Optional[int]: try: if hasattr(current_user, "is_authenticated") and current_user.is_authenticated: return getattr(current_user, "id", None) except Exception: return None return None def _scan_request_threats(req) -> list[ThreatResult]: """仅扫描 GET query 与 POST JSON body(降低开销与误报)。""" threats: list[ThreatResult] = [] try: # 1) Query 参数(所有方法均可能携带 query string) try: args = getattr(req, "args", None) if args: # MultiDict -> dict(list) 以保留多值 args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args) threats.extend(detector.scan_input(args_dict, "args")) except Exception: pass # 2) JSON body(主要针对 POST;其他方法保持兼容) try: method = str(getattr(req, "method", "") or "").upper() except Exception: method = "" if method in {"POST", "PUT", "PATCH", "DELETE"}: try: data = req.get_json(silent=True) if hasattr(req, "get_json") else None except Exception: data = None if data is not None: threats.extend(detector.scan_input(data, "json")) except Exception: # 扫描失败不应阻断业务 return [] threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True) return threats def init_security_middleware(app): """初始化安全中间件到 Flask 应用。""" try: scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True)) except Exception: pass @app.before_request def security_check(): if not bool(app.config.get("SECURITY_ENABLED", True)): return None if _is_static_request(app): return None try: ip = get_rate_limit_ip() except Exception: ip = getattr(request, "remote_addr", "") or "" user_id = _safe_get_user_id() # 默认值,确保后续逻辑可用 g.risk_score = 0 g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW) g.honeypot_mode = False g.honeypot_endpoint = None g.honeypot_generated = False try: # 1) 检查黑名单(静默拒绝,返回通用错误) try: if blacklist.is_ip_banned(ip): _log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256]) return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503 except Exception: _log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True) try: if user_id is not None and blacklist.is_user_banned(user_id): _log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256]) return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503 except Exception: _log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True) # 2) 扫描威胁(GET query / POST JSON) threats = _scan_request_threats(request) if threats: max_threat = threats[0] _log( app, logging.WARNING, "威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s", ip, user_id, getattr(max_threat, "threat_type", "unknown"), getattr(max_threat, "score", 0), getattr(max_threat, "field_name", ""), getattr(max_threat, "rule", ""), ) # 记录威胁事件(异常不应阻断业务) try: payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or "" scorer.record_threat( ip=ip, user_id=user_id, threat_type=getattr(max_threat, "threat_type", "unknown"), score=int(getattr(max_threat, "score", 0) or 0), request_path=getattr(request, "path", None), payload=str(payload)[:500] if payload else None, ) except Exception: _log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True) # 高危威胁启用蜜罐模式 if bool(app.config.get("HONEYPOT_ENABLED", True)): try: if int(getattr(max_threat, "score", 0) or 0) >= 80: g.honeypot_mode = True g.honeypot_endpoint = getattr(request, "endpoint", None) except Exception: pass # 3) 综合风险分与响应策略 try: risk_score = scorer.get_combined_score(ip, user_id) except Exception: _log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True) risk_score = 0 try: strategy = _get_handler().get_strategy(risk_score) except Exception: _log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True) strategy = ResponseStrategy(action=ResponseAction.ALLOW) g.risk_score = int(risk_score or 0) g.response_strategy = strategy # 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略) if bool(app.config.get("HONEYPOT_ENABLED", True)): try: if getattr(strategy, "action", None) == ResponseAction.HONEYPOT: g.honeypot_mode = True except Exception: pass # 4) 应用延迟 try: if float(getattr(strategy, "delay_seconds", 0) or 0) > 0: _get_handler().apply_delay(strategy) except Exception: _log(app, logging.ERROR, "延迟应用失败", exc_info=True) # 优先短路:避免业务 side effects(例如发送邮件/修改状态) if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)): try: fake_payload = None try: fake_payload = request.get_json(silent=True) except Exception: fake_payload = None fake_response = _get_honeypot().generate_fake_response( getattr(g, "honeypot_endpoint", "default"), fake_payload if isinstance(fake_payload, dict) else None, ) g.honeypot_generated = True return jsonify(fake_response), 200 except Exception: _log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True) return None except Exception: # 全局兜底:安全模块任何异常都不能阻断正常请求 _log(app, logging.ERROR, "安全中间件发生异常", exc_info=True) return None return None # 继续正常处理 @app.after_request def security_response(response): """请求后处理 - 兜底应用蜜罐响应。""" if not bool(app.config.get("SECURITY_ENABLED", True)): return response if not bool(app.config.get("HONEYPOT_ENABLED", True)): return response try: if _is_static_request(app): return response except Exception: pass # 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动 try: if getattr(g, "honeypot_generated", False): return response except Exception: pass try: if getattr(g, "honeypot_mode", False): fake_payload = None try: fake_payload = request.get_json(silent=True) except Exception: fake_payload = None fake_response = _get_honeypot().generate_fake_response( getattr(g, "honeypot_endpoint", "default"), fake_payload if isinstance(fake_payload, dict) else None, ) return jsonify(fake_response), 200 except Exception: _log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True) return response return response