Files
zsglpt/security/middleware.py
Yu Yon 53c78e8e3c feat: 添加安全模块 + Dockerfile添加curl支持健康检查
主要更新:
- 新增 security/ 安全模块 (风险评估、威胁检测、蜜罐等)
- Dockerfile 添加 curl 以支持 Docker 健康检查
- 前端页面更新 (管理后台、用户端)
- 数据库迁移和 schema 更新
- 新增 kdocs 上传服务
- 添加安全相关测试用例

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-08 17:48:33 +08:00

308 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
from typing import Optional
from flask import g, jsonify, request
from flask_login import current_user
from app_logger import get_logger
from app_security import get_rate_limit_ip
from .blacklist import BlacklistManager
from .honeypot import HoneypotResponder
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from .risk_scorer import RiskScorer
from .threat_detector import ThreatDetector, ThreatResult
# Module-level instances, kept as singletons to avoid repeated initialisation cost.
detector: ThreatDetector = ThreatDetector()
blacklist: BlacklistManager = BlacklistManager()
# The scorer is given the same blacklist instance that the request hook uses for ban checks.
scorer: RiskScorer = RiskScorer(blacklist_manager=blacklist)
# Created lazily on first use by _get_handler() / _get_honeypot().
handler: Optional[ResponseHandler] = None
honeypot: Optional[HoneypotResponder] = None
def _get_handler() -> ResponseHandler:
    """Return the module-level ResponseHandler, instantiating it on first call."""
    global handler
    if handler is not None:
        return handler
    handler = ResponseHandler()
    return handler
def _get_honeypot() -> HoneypotResponder:
    """Return the module-level HoneypotResponder, instantiating it on first call."""
    global honeypot
    if honeypot is not None:
        return honeypot
    honeypot = HoneypotResponder()
    return honeypot
def _get_security_log_level(app) -> int:
    """Return the minimum level for security log records from ``SECURITY_LOG_LEVEL``.

    Accepts either a numeric level (``logging.WARNING``, ``30`` or ``"30"``)
    or a level name such as ``"warning"`` (case- and whitespace-insensitive).
    Missing, empty or unrecognised values fall back to ``logging.INFO``.
    """
    config = getattr(app, "config", None) or {}
    raw = config.get("SECURITY_LOG_LEVEL", "INFO")
    # Numeric levels used to be stringified and looked up by name on the
    # logging module, which silently turned e.g. logging.WARNING into INFO.
    if isinstance(raw, int) and not isinstance(raw, bool):
        return raw
    name = str(raw or "INFO").strip().upper()
    if name.isdecimal():
        return int(name)
    level = getattr(logging, name, None)
    # Only accept real level constants; other upper-case module attributes
    # (e.g. BASIC_FORMAT, a str) must not be treated as levels.
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return logging.INFO
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
    """Emit a record on the "app" logger if *level* meets the app's SECURITY_LOG_LEVEL.

    Records below the configured threshold are dropped, which keeps security
    logging cheap. Any failure while logging is swallowed so the security
    module can never break a request.
    """
    try:
        security_logger = get_logger("app")
        threshold = _get_security_log_level(app)
        wanted = int(level)
        if wanted < int(threshold):
            return
        security_logger.log(wanted, message, *args, exc_info=exc_info)
    except Exception:
        # A broken logger must not affect normal request handling.
        return
def _is_static_request(app) -> bool:
    """Return True when the current request targets static files.

    Static requests skip the security checks for performance. A request counts
    as static if its path is under ``/static/``, under the app's
    ``static_url_path`` (``/static`` when unset), or if it was routed to the
    ``static`` / ``serve_static`` endpoint. A failing lookup makes only that
    particular check count as "not static".
    """
    try:
        req_path = str(getattr(request, "path", "") or "")
    except Exception:
        req_path = ""

    if req_path.startswith("/static/"):
        return True

    try:
        prefix = getattr(app, "static_url_path", None) or "/static"
        if req_path.startswith(str(prefix).rstrip("/") + "/"):
            return True
    except Exception:
        pass

    try:
        return getattr(request, "endpoint", None) in {"static", "serve_static"}
    except Exception:
        return False
def _safe_get_user_id() -> Optional[int]:
    """Return the authenticated user's ``id``; None for anonymous users or on any error."""
    try:
        authenticated = hasattr(current_user, "is_authenticated") and current_user.is_authenticated
        if not authenticated:
            return None
        return getattr(current_user, "id", None)
    except Exception:
        return None
def _scan_request_threats(req) -> list[ThreatResult]:
    """Scan the request's query string and JSON body for threats.

    Query parameters are scanned for every HTTP method; the JSON body only for
    POST/PUT/PATCH/DELETE. Each source is scanned independently, so a failure
    while scanning one source (e.g. the detector raising on an unusual JSON
    payload) no longer discards threats already found in the other. This
    function never raises: scan failures must not block the request.

    Returns the threats sorted by ``score``, highest first; a threat whose
    score cannot be read as an int sorts as score 0.
    """

    def _sort_score(threat) -> int:
        # Sort key that tolerates missing or non-numeric scores.
        try:
            return int(getattr(threat, "score", 0) or 0)
        except (TypeError, ValueError):
            return 0

    threats: list[ThreatResult] = []

    # 1) Query parameters (any method may carry a query string).
    try:
        args = getattr(req, "args", None)
        if args:
            # MultiDict -> dict of lists so repeated keys keep every value.
            args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
            threats.extend(detector.scan_input(args_dict, "args"))
    except Exception:
        pass

    # 2) JSON body, only for methods that normally carry one.
    try:
        method = str(getattr(req, "method", "") or "").upper()
    except Exception:
        method = ""
    if method in {"POST", "PUT", "PATCH", "DELETE"}:
        try:
            data = req.get_json(silent=True) if hasattr(req, "get_json") else None
        except Exception:
            data = None
        if data is not None:
            try:
                threats.extend(detector.scan_input(data, "json"))
            except Exception:
                # Keep whatever the query-string scan already found.
                pass

    threats.sort(key=_sort_score, reverse=True)
    return threats
def _honeypot_json_response():
    """Build the honeypot's fake JSON reply for the current request.

    The fake payload comes from ``_get_honeypot().generate_fake_response`` and
    is keyed by ``g.honeypot_endpoint``. That attribute is reset to None on
    every checked request and only filled in for high-score threats, so the
    ``"default"`` key is used whenever it is unset or None. The request's JSON
    body is passed along only when it is a dict.

    Returns a Flask ``Response`` (not a ``(body, status)`` tuple), so the result
    is valid from both ``before_request`` and ``after_request`` hooks.
    Exceptions propagate to the caller.
    """
    try:
        body = request.get_json(silent=True)
    except Exception:
        body = None
    endpoint = getattr(g, "honeypot_endpoint", None) or "default"
    fake_payload = _get_honeypot().generate_fake_response(
        endpoint,
        body if isinstance(body, dict) else None,
    )
    return jsonify(fake_payload)


def init_security_middleware(app):
    """Install the security middleware on a Flask application.

    Registers two hooks on *app*:

    * ``security_check`` (``before_request``):
      - skips static requests;
      - rejects banned IPs/users with a generic 503;
      - scans the query string / JSON body and records the worst threat;
      - computes the combined risk score and response strategy;
      - applies any strategy delay;
      - in honeypot mode, short-circuits with a fake JSON response so the real
        view (and its side effects) never runs.
    * ``security_response`` (``after_request``): fallback that swaps in a
      honeypot reply when honeypot mode was flagged but the before-request
      hook did not produce one.

    Config keys read from ``app.config``:
    * ``AUTO_BAN_ENABLED``, ``SECURITY_ENABLED``, ``HONEYPOT_ENABLED`` —
      booleans, default True;
    * ``HONEYPOT_SCORE_THRESHOLD`` — threat score at or above which honeypot
      mode is enabled, default 80.

    Failures inside the hooks are logged and swallowed (fail-open): the security
    module must never break normal request handling.
    """
    try:
        scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
    except Exception:
        pass

    @app.before_request
    def security_check():
        if not bool(app.config.get("SECURITY_ENABLED", True)):
            return None
        if _is_static_request(app):
            return None

        try:
            ip = get_rate_limit_ip()
        except Exception:
            ip = getattr(request, "remote_addr", "") or ""
        user_id = _safe_get_user_id()

        # Defaults, so later code (and the after_request hook) can rely on them.
        g.risk_score = 0
        g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
        g.honeypot_mode = False
        g.honeypot_endpoint = None
        g.honeypot_generated = False

        try:
            # 1) Blacklist checks: reject quietly with a generic error message.
            try:
                if blacklist.is_ip_banned(ip):
                    _log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
                    return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
            except Exception:
                _log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
            try:
                if user_id is not None and blacklist.is_user_banned(user_id):
                    _log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
                    return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
            except Exception:
                _log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)

            # 2) Threat scan (query string / JSON body); the result is sorted worst-first.
            threats = _scan_request_threats(request)
            if threats:
                max_threat = threats[0]
                _log(
                    app,
                    logging.WARNING,
                    "威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
                    ip,
                    user_id,
                    getattr(max_threat, "threat_type", "unknown"),
                    getattr(max_threat, "score", 0),
                    getattr(max_threat, "field_name", ""),
                    getattr(max_threat, "rule", ""),
                )
                # Record the threat event; failures must not block the request.
                try:
                    payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
                    scorer.record_threat(
                        ip=ip,
                        user_id=user_id,
                        threat_type=getattr(max_threat, "threat_type", "unknown"),
                        score=int(getattr(max_threat, "score", 0) or 0),
                        request_path=getattr(request, "path", None),
                        payload=str(payload)[:500] if payload else None,
                    )
                except Exception:
                    _log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
                # A high-score threat switches the request into honeypot mode.
                if bool(app.config.get("HONEYPOT_ENABLED", True)):
                    try:
                        threshold = int(app.config.get("HONEYPOT_SCORE_THRESHOLD", 80))
                        if int(getattr(max_threat, "score", 0) or 0) >= threshold:
                            g.honeypot_mode = True
                            g.honeypot_endpoint = getattr(request, "endpoint", None)
                    except Exception:
                        pass

            # 3) Combined risk score and response strategy.
            try:
                risk_score = scorer.get_combined_score(ip, user_id)
            except Exception:
                _log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
                risk_score = 0
            try:
                strategy = _get_handler().get_strategy(risk_score)
            except Exception:
                _log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
                strategy = ResponseStrategy(action=ResponseAction.ALLOW)
            g.risk_score = int(risk_score or 0)
            g.response_strategy = strategy

            # The ResponseHandler's HONEYPOT strategy also enables honeypot mode.
            if bool(app.config.get("HONEYPOT_ENABLED", True)):
                try:
                    if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
                        g.honeypot_mode = True
                except Exception:
                    pass

            # 4) Apply the strategy's delay, if any.
            try:
                if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
                    _get_handler().apply_delay(strategy)
            except Exception:
                _log(app, logging.ERROR, "延迟应用失败", exc_info=True)

            # Short-circuit before the view runs, so honeypot traffic never
            # triggers real side effects (e.g. sending mail, changing state).
            if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
                try:
                    fake_response = _honeypot_json_response()
                    g.honeypot_generated = True
                    return fake_response
                except Exception:
                    _log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
                    return None
        except Exception:
            # Catch-all: nothing in the security module may block a normal request.
            _log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
            return None
        return None  # continue with normal request handling

    @app.after_request
    def security_response(response):
        """After-request fallback that substitutes the honeypot response.

        Acts only when honeypot mode was flagged but ``security_check`` did not
        already produce the fake response. Always returns a ``Response`` object,
        as Flask requires of ``after_request`` hooks.
        """
        if not bool(app.config.get("SECURITY_ENABLED", True)):
            return response
        if not bool(app.config.get("HONEYPOT_ENABLED", True)):
            return response
        try:
            if _is_static_request(app):
                return response
        except Exception:
            pass
        # If before_request already produced the honeypot response, don't
        # overwrite it, so changes made by other after_request hooks survive.
        try:
            if getattr(g, "honeypot_generated", False):
                return response
        except Exception:
            pass
        try:
            if getattr(g, "honeypot_mode", False):
                # Must be a Response, not a (body, status) tuple: Flask does
                # not normalise after_request return values.
                return _honeypot_json_response()
        except Exception:
            _log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
            return response
        return response