Phase 1 - 威胁检测引擎: - security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测 - security/constants.py: 威胁检测规则和评分常量 - 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist Phase 2 - 风险评分与黑名单: - security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减 - security/blacklist.py: 黑名单管理,自动封禁规则 Phase 3 - 响应策略: - security/honeypot.py: 蜜罐响应生成器 - security/response_handler.py: 渐进式响应策略 Phase 4 - 集成: - security/middleware.py: Flask安全中间件 - routes/admin_api/security.py: 管理后台安全仪表板API - 36个测试用例全部通过 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
156 lines
4.5 KiB
Python
156 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
from flask import Flask, g, jsonify
|
|
from flask_login import LoginManager
|
|
|
|
import db_pool
|
|
from db.schema import ensure_schema
|
|
from security import init_security_middleware
|
|
|
|
|
|
@pytest.fixture()
def _test_db(tmp_path):
    """Point ``db_pool`` at a throwaway database file for one test.

    Any pool currently stored on ``db_pool._pool`` is closed (best effort),
    a new pool with ``pool_size=1`` is initialised on a file under
    ``tmp_path``, and ``ensure_schema`` is run against a connection from it.
    On teardown the temporary pool is closed and the previously stored pool
    object is put back on ``db_pool._pool``.

    Yields:
        The path of the temporary database file.
    """

    def _close_quietly(pool):
        # Cleanup is best effort: a pool that fails to close must not mask
        # the outcome of the test itself.
        if pool is None:
            return
        try:
            pool.close_all()
        except Exception:
            pass

    db_path = tmp_path / "security_middleware_test.db"
    previous_pool = getattr(db_pool, "_pool", None)

    try:
        _close_quietly(previous_pool)
        db_pool._pool = None
        db_pool.init_pool(str(db_path), pool_size=1)

        with db_pool.get_db() as connection:
            ensure_schema(connection)

        yield db_path
    finally:
        _close_quietly(getattr(db_pool, "_pool", None))
        # NOTE(review): the pool restored here had close_all() called on it
        # during setup; presumably it is still usable afterwards - confirm.
        db_pool._pool = previous_pool
|
|
|
|
|
|
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
    """Build a minimal Flask app with the security middleware installed.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; every patch applied
            here is undone automatically when the calling test finishes.
        _test_db: Not used in the body; accepting it makes the caller pass
            the temporary-database fixture so it is active for the app.
        security_enabled: Value stored in ``SECURITY_ENABLED``.
        honeypot_enabled: Value stored in ``HONEYPOT_ENABLED``.

    Returns:
        A ``Flask`` app in testing mode with ``LoginManager`` initialised
        and ``init_security_middleware`` applied.
    """
    import security.middleware as sm
    import security.response_handler as rh

    # Keep tests fast: replace the sleep used by the response handler with a
    # no-op so risk-control delays do not slow the suite down.
    monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)

    # Start every test with handler/honeypot in their not-yet-loaded state.
    # Going through monkeypatch (instead of plain assignment) restores the
    # previous values on teardown, so objects created lazily during this
    # test do not leak into later tests. raising=False keeps the old
    # tolerance for the attribute not existing yet.
    monkeypatch.setattr(sm, "handler", None, raising=False)
    monkeypatch.setattr(sm, "honeypot", None, raising=False)

    app = Flask(__name__)
    app.config.update(
        SECRET_KEY="test-secret",
        TESTING=True,
        SECURITY_ENABLED=bool(security_enabled),
        HONEYPOT_ENABLED=bool(honeypot_enabled),
        SECURITY_LOG_LEVEL="CRITICAL",  # reduce log noise during tests
    )

    login_manager = LoginManager()
    login_manager.init_app(app)

    @login_manager.user_loader
    def _load_user(_user_id: str):
        # No users exist in these tests; every lookup resolves to "not found".
        return None

    init_security_middleware(app)
    return app
|
|
|
|
|
|
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
    """Send a GET for ``path`` through a new test client with ``REMOTE_ADDR`` set to ``ip``.

    Returns:
        The response object produced by the test client.
    """
    client = app.test_client()
    environ = {"REMOTE_ADDR": ip}
    return client.get(path, environ_overrides=environ)
|
|
|
|
|
|
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
    """A request from a banned IP gets HTTP 503 and the generic "busy" JSON body."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/ping")
    def _ping():
        return jsonify({"ok": True})

    import security.middleware as middleware

    banned_ip = "1.2.3.4"
    middleware.blacklist.ban_ip(banned_ip, reason="test", duration_hours=1, permanent=False)

    response = _client_get(app, "/api/ping", ip=banned_ip)

    assert response.status_code == 503
    # The expected body is a generic error message, not the route's own payload.
    expected_body = {"error": "服务暂时繁忙,请稍后重试"}
    assert response.get_json() == expected_body
|
|
|
|
|
|
def test_middleware_skips_static_requests(_test_db, monkeypatch):
    """A ``/static/...`` request from a banned IP still reaches its view and returns 200."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/static/test")
    def _static_test():
        return "ok"

    import security.middleware as middleware

    banned_ip = "1.2.3.4"
    middleware.blacklist.ban_ip(banned_ip, reason="test", duration_hours=1, permanent=False)

    response = _client_get(app, "/static/test", ip=banned_ip)

    assert response.status_code == 200
    body_text = response.get_data(as_text=True)
    assert body_text == "ok"
|
|
|
|
|
|
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
    """With the honeypot enabled, a suspicious query never reaches the real view.

    The response is still a 200 whose JSON body is a dict with
    ``success`` set to ``True``, while the view's call counter stays at zero.
    """
    app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)

    # Mutable counter so the nested view can record that it ran.
    view_calls = {"count": 0}

    @app.get("/api/side-effect")
    def _side_effect():
        view_calls["count"] += 1
        return jsonify({"real": True})

    suspicious_url = "/api/side-effect?q=${${a}}"
    response = _client_get(app, suspicious_url, ip="9.9.9.9")

    assert response.status_code == 200

    body = response.get_json()
    assert isinstance(body, dict)
    assert body.get("success") is True

    # The real handler must not have been executed.
    assert view_calls["count"] == 0
|
|
|
|
|
|
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
    """If the ban check and the input scan both raise, the request is still served (200)."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/ok")
    def _ok():
        return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})

    import security.middleware as middleware

    def _always_raise(*_args, **_kwargs):
        raise RuntimeError("boom")

    # Break both security components the test targets.
    monkeypatch.setattr(middleware.blacklist, "is_ip_banned", _always_raise)
    monkeypatch.setattr(middleware.detector, "scan_input", _always_raise)

    response = _client_get(app, "/api/ok", ip="2.2.2.2")

    assert response.status_code == 200
    body = response.get_json()
    assert body["ok"] is True
|
|
|
|
|
|
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
    """For an ordinary request, the view sees ``g.risk_score == 0`` and an ``"allow"`` action.

    The view echoes ``g.risk_score`` and ``g.response_strategy.action.value``
    (each falling back to ``None`` when absent) so the test can inspect them.
    """
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/context")
    def _context():
        score = getattr(g, "risk_score", None)
        strategy = getattr(g, "response_strategy", None)
        strategy_action = getattr(strategy, "action", None)
        action = getattr(strategy_action, "value", None)
        return jsonify({"risk_score": score, "action": action})

    response = _client_get(app, "/api/context", ip="8.8.8.8")

    assert response.status_code == 200
    expected_body = {"risk_score": 0, "action": "allow"}
    assert response.get_json() == expected_body
|