from __future__ import annotations from datetime import timedelta import pytest from flask import Flask import db_pool from db.schema import ensure_schema from db.utils import get_cst_now from security.blacklist import BlacklistManager from security.risk_scorer import RiskScorer @pytest.fixture() def _test_db(tmp_path): db_file = tmp_path / "admin_security_api_test.db" old_pool = getattr(db_pool, "_pool", None) try: if old_pool is not None: try: old_pool.close_all() except Exception: pass db_pool._pool = None db_pool.init_pool(str(db_file), pool_size=1) with db_pool.get_db() as conn: ensure_schema(conn) yield db_file finally: try: if getattr(db_pool, "_pool", None) is not None: db_pool._pool.close_all() except Exception: pass db_pool._pool = old_pool def _make_app() -> Flask: from routes.admin_api.security import security_bp app = Flask(__name__) app.config.update(SECRET_KEY="test-secret", TESTING=True) app.register_blueprint(security_bp) return app def _login_admin(client) -> None: with client.session_transaction() as sess: sess["admin_id"] = 1 sess["admin_username"] = "admin" def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str): with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( """ INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at) VALUES (?, ?, ?, ?, ?, ?, ?) """, (threat_type, int(score), ip, user_id, "/api/test", payload, created_at), ) conn.commit() def test_dashboard_requires_admin(_test_db): app = _make_app() client = app.test_client() resp = client.get("/api/admin/security/dashboard") assert resp.status_code == 403 assert resp.get_json() == {"error": "需要管理员权限"} def test_dashboard_counts_and_payload_truncation(_test_db): app = _make_app() client = app.test_client() _login_admin(client) now = get_cst_now() within_24h = now.strftime("%Y-%m-%d %H:%M:%S") within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S") older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S") long_payload = "x" * 300 _insert_threat_event( threat_type="sql_injection", score=90, ip="1.2.3.4", user_id=10, created_at=within_24h, payload=long_payload, ) _insert_threat_event( threat_type="xss", score=70, ip="2.3.4.5", user_id=11, created_at=within_24h_2, payload="short", ) _insert_threat_event( threat_type="path_traversal", score=60, ip="9.9.9.9", user_id=None, created_at=older, payload="old", ) manager = BlacklistManager() manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False) manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False) resp = client.get("/api/admin/security/dashboard") assert resp.status_code == 200 data = resp.get_json() assert data["threat_events_24h"] == 2 assert data["banned_ip_count"] == 1 assert data["banned_user_count"] == 1 recent = data["recent_threat_events"] assert isinstance(recent, list) assert len(recent) == 3 payload_preview = recent[0]["value_preview"] assert isinstance(payload_preview, str) assert len(payload_preview) <= 200 assert payload_preview.endswith("...") def test_threats_pagination_and_filters(_test_db): app = _make_app() client = app.test_client() _login_admin(client) now = get_cst_now() t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S") t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S") t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S") _insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a") _insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b") _insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c") resp = client.get("/api/admin/security/threats?page=1&per_page=2") assert resp.status_code == 200 data = resp.get_json() assert data["total"] == 3 assert len(data["items"]) == 2 resp2 = client.get("/api/admin/security/threats?page=2&per_page=2") assert resp2.status_code == 200 data2 = resp2.get_json() assert data2["total"] == 3 assert len(data2["items"]) == 1 resp3 = client.get("/api/admin/security/threats?event_type=sql_injection") assert resp3.status_code == 200 data3 = resp3.get_json() assert data3["total"] == 1 assert data3["items"][0]["threat_type"] == "sql_injection" resp4 = client.get("/api/admin/security/threats?severity=high") assert resp4.status_code == 200 data4 = resp4.get_json() assert data4["total"] == 2 assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"} def test_ban_and_unban_ip(_test_db): app = _make_app() client = app.test_client() _login_admin(client) resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1}) assert resp.status_code == 200 assert resp.get_json()["success"] is True list_resp = client.get("/api/admin/security/banned-ips") assert list_resp.status_code == 200 payload = list_resp.get_json() assert payload["count"] == 1 assert payload["items"][0]["ip"] == "7.7.7.7" resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"}) assert resp2.status_code == 200 assert resp2.get_json()["success"] is True list_resp2 = client.get("/api/admin/security/banned-ips") assert list_resp2.status_code == 200 assert list_resp2.get_json()["count"] == 0 def test_risk_endpoints_and_cleanup(_test_db): app = _make_app() client = app.test_client() _login_admin(client) scorer = RiskScorer(auto_ban_enabled=False) scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="