from __future__ import annotations

import pytest
from flask import Flask, g, jsonify
from flask_login import LoginManager

import db_pool
from db.schema import ensure_schema
from security import init_security_middleware


@pytest.fixture()
def _test_db(tmp_path):
    """Point ``db_pool`` at a fresh per-test database file with the schema applied.

    Any pool installed before the test is closed, then a new pool of size 1 is
    created on a file under ``tmp_path`` and ``ensure_schema`` is run on it.
    On teardown the test pool is closed and the previous ``db_pool._pool``
    object is put back.

    Yields:
        The path of the per-test database file.
    """
    db_file = tmp_path / "security_middleware_test.db"
    old_pool = getattr(db_pool, "_pool", None)
    try:
        if old_pool is not None:
            try:
                old_pool.close_all()
            except Exception:
                # Best effort: failing to close the previous pool must not abort setup.
                pass
        db_pool._pool = None
        db_pool.init_pool(str(db_file), pool_size=1)
        with db_pool.get_db() as conn:
            ensure_schema(conn)
        yield db_file
    finally:
        try:
            if getattr(db_pool, "_pool", None) is not None:
                db_pool._pool.close_all()
        except Exception:
            # Best effort: a cleanup failure should not turn into a test error.
            pass
        # NOTE(review): old_pool had close_all() called on it during setup;
        # restoring it assumes db_pool pools remain usable after close_all()
        # (e.g. reconnect lazily) — confirm against db_pool.
        db_pool._pool = old_pool


def _make_app(
    monkeypatch,
    _test_db,
    *,
    security_enabled: bool = True,
    honeypot_enabled: bool = True,
) -> Flask:
    """Build a minimal Flask app with Flask-Login and the security middleware installed.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture; every patch applied here is
            undone automatically when the test finishes.
        _test_db: value of the ``_test_db`` fixture. It is accepted to make the
            dependency on that fixture explicit and is not otherwise used.
        security_enabled: value stored in ``SECURITY_ENABLED``.
        honeypot_enabled: value stored in ``HONEYPOT_ENABLED``.

    Returns:
        The configured app. Each test registers its own routes on it.
    """
    import security.middleware as sm
    import security.response_handler as rh

    # Keep the tests fast: neutralise the risk-control delay (time.sleep).
    monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
    # Start every test with handler/honeypot in their lazily-initialised state.
    # Going through monkeypatch (rather than plain assignment) restores the
    # previous module globals after the test, so instances created here do not
    # leak into other tests or test modules. raising=False keeps this working
    # even if the attribute does not exist yet.
    monkeypatch.setattr(sm, "handler", None, raising=False)
    monkeypatch.setattr(sm, "honeypot", None, raising=False)

    app = Flask(__name__)
    app.config.update(
        SECRET_KEY="test-secret",
        TESTING=True,
        SECURITY_ENABLED=bool(security_enabled),
        HONEYPOT_ENABLED=bool(honeypot_enabled),
        SECURITY_LOG_LEVEL="CRITICAL",  # reduce log noise during tests
    )

    login_manager = LoginManager()
    login_manager.init_app(app)

    @login_manager.user_loader
    def _load_user(_user_id: str):
        # No user is ever loaded: every request is anonymous.
        return None

    init_security_middleware(app)
    return app


def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
    """Issue a GET to ``path`` with ``REMOTE_ADDR`` set to ``ip``; return the response."""
    return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})


def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
    """A request from a banned IP is answered with 503 and a generic JSON error."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/ping")
    def _ping():
        return jsonify({"ok": True})

    import security.middleware as sm

    sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)

    resp = _client_get(app, "/api/ping", ip="1.2.3.4")
    assert resp.status_code == 503
    assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}


def test_middleware_skips_static_requests(_test_db, monkeypatch):
    """Requests under /static/ reach the view even when the client IP is banned."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/static/test")
    def _static_test():
        return "ok"

    import security.middleware as sm

    sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)

    resp = _client_get(app, "/static/test", ip="1.2.3.4")
    assert resp.status_code == 200
    assert resp.get_data(as_text=True) == "ok"


def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
    """A malicious-looking query gets a fake success and never runs the real view."""
    app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)

    called = {"count": 0}

    @app.get("/api/side-effect")
    def _side_effect():
        called["count"] += 1
        return jsonify({"real": True})

    resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
    assert resp.status_code == 200
    payload = resp.get_json()
    assert isinstance(payload, dict)
    assert payload.get("success") is True
    assert called["count"] == 0


def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
    """If the blacklist and detector both raise, the request still reaches the view."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/ok")
    def _ok():
        return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})

    import security.middleware as sm

    def boom(*_args, **_kwargs):
        raise RuntimeError("boom")

    monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
    monkeypatch.setattr(sm.detector, "scan_input", boom)

    resp = _client_get(app, "/api/ok", ip="2.2.2.2")
    assert resp.status_code == 200
    assert resp.get_json()["ok"] is True


def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
    """A benign request gets g.risk_score == 0 and a response strategy of "allow"."""
    app = _make_app(monkeypatch, _test_db)

    @app.get("/api/context")
    def _context():
        strategy = getattr(g, "response_strategy", None)
        action = getattr(getattr(strategy, "action", None), "value", None)
        return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})

    resp = _client_get(app, "/api/context", ip="8.8.8.8")
    assert resp.status_code == 200
    assert resp.get_json() == {"risk_score": 0, "action": "allow"}