Files
zsglpt/tests/test_admin_security_api.py
yuyx 46253337eb feat: 实现完整安全防护系统
Phase 1 - 威胁检测引擎:
- security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测
- security/constants.py: 威胁检测规则和评分常量
- 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist

Phase 2 - 风险评分与黑名单:
- security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减
- security/blacklist.py: 黑名单管理,自动封禁规则

Phase 3 - 响应策略:
- security/honeypot.py: 蜜罐响应生成器
- security/response_handler.py: 渐进式响应策略

Phase 4 - 集成:
- security/middleware.py: Flask安全中间件
- routes/admin_api/security.py: 管理后台安全仪表板API
- 36个测试用例全部通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-27 01:28:38 +08:00

250 lines
7.8 KiB
Python

from __future__ import annotations
from datetime import timedelta
import pytest
from flask import Flask
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
    """Point ``db_pool`` at a fresh, schema-initialised SQLite file for one test.

    Setup:
    1. Close whatever pool ``db_pool._pool`` currently holds and set it aside.
    2. Create a single-connection pool on ``tmp_path / "admin_security_api_test.db"``.
    3. Run ``ensure_schema`` against it.

    The database path is yielded to the test. On teardown the test pool is closed and
    the previously installed ``db_pool._pool`` object is put back. Every ``close_all``
    call is best-effort, and any exception it raises is ignored.
    """

    def _close_quietly(pool):
        # Teardown/setup must not fail just because a pool refused to close cleanly.
        if pool is None:
            return
        try:
            pool.close_all()
        except Exception:
            pass

    database_path = tmp_path / "admin_security_api_test.db"
    previous_pool = getattr(db_pool, "_pool", None)
    try:
        _close_quietly(previous_pool)
        db_pool._pool = None
        db_pool.init_pool(str(database_path), pool_size=1)
        with db_pool.get_db() as conn:
            ensure_schema(conn)
        yield database_path
    finally:
        _close_quietly(getattr(db_pool, "_pool", None))
        # NOTE(review): ``previous_pool`` was already closed during setup, so the object
        # restored here is a closed pool. Confirm that db_pool tolerates reuse after
        # ``close_all``.
        db_pool._pool = previous_pool
def _make_app() -> Flask:
    """Build a minimal Flask test app that registers only the admin security blueprint.

    ``SECRET_KEY`` is set so session cookies (used by ``_login_admin``) can be signed,
    and ``TESTING`` is enabled. The blueprint module is imported inside the function,
    so it is only loaded when an app is actually built.
    """
    from routes.admin_api.security import security_bp

    application = Flask(__name__)
    application.config["SECRET_KEY"] = "test-secret"
    application.config["TESTING"] = True
    application.register_blueprint(security_bp)
    return application
def _login_admin(client) -> None:
with client.session_transaction() as sess:
sess["admin_id"] = 1
sess["admin_username"] = "admin"
def _insert_threat_event(
    *,
    threat_type: str,
    score: int,
    ip: str,
    user_id: int | None,
    created_at: str,
    payload: str,
):
    """Insert a single row into ``threat_events`` and commit it.

    ``payload`` is stored in the ``value_preview`` column. ``request_path`` is always
    ``"/api/test"``. ``score`` is passed through ``int()``. ``created_at`` is written
    verbatim, so the tests in this file pass "%Y-%m-%d %H:%M:%S" strings.
    """
    insert_sql = """
        INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """
    with db_pool.get_db() as conn:
        row = (threat_type, int(score), ip, user_id, "/api/test", payload, created_at)
        conn.cursor().execute(insert_sql, row)
        conn.commit()
def test_dashboard_requires_admin(_test_db):
    """Without an admin session the dashboard returns 403 and a JSON error body."""
    client = _make_app().test_client()

    response = client.get("/api/admin/security/dashboard")

    assert response.status_code == 403
    # The error message means "administrator privileges required".
    assert response.get_json() == {"error": "需要管理员权限"}
def test_dashboard_counts_and_payload_truncation(_test_db):
    """Dashboard counts recent events and bans, and truncates long payload previews.

    Seeds three threat events: one stamped "now" with a 300-character payload, one an
    hour old, and one 25 hours old. Also bans one IP and one user. Expects:
    * 2 events in the last 24 hours, 1 banned IP, 1 banned user;
    * all 3 events in ``recent_threat_events``;
    * the first entry's ``value_preview`` shortened to at most 200 characters and
      ending in "...". This presumes the list is newest-first, so the first entry is
      the long-payload event.
    """
    client = _make_app().test_client()
    _login_admin(client)

    now = get_cst_now()

    def stamp(age: timedelta) -> str:
        # Timestamps are stored as plain strings in this format.
        return (now - age).strftime("%Y-%m-%d %H:%M:%S")

    seed_events = [
        ("sql_injection", 90, "1.2.3.4", 10, stamp(timedelta(0)), "x" * 300),
        ("xss", 70, "2.3.4.5", 11, stamp(timedelta(hours=1)), "short"),
        ("path_traversal", 60, "9.9.9.9", None, stamp(timedelta(hours=25)), "old"),
    ]
    for threat_type, score, ip, user_id, created_at, payload in seed_events:
        _insert_threat_event(
            threat_type=threat_type,
            score=score,
            ip=ip,
            user_id=user_id,
            created_at=created_at,
            payload=payload,
        )

    blacklist = BlacklistManager()
    blacklist.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
    # The user ban goes through the private ``_ban_user_internal`` helper directly.
    blacklist._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)

    response = client.get("/api/admin/security/dashboard")
    assert response.status_code == 200

    body = response.get_json()
    assert body["threat_events_24h"] == 2
    assert body["banned_ip_count"] == 1
    assert body["banned_user_count"] == 1

    recent_events = body["recent_threat_events"]
    assert isinstance(recent_events, list)
    assert len(recent_events) == 3

    preview = recent_events[0]["value_preview"]
    assert isinstance(preview, str)
    assert len(preview) <= 200
    assert preview.endswith("...")
def test_threats_pagination_and_filters(_test_db):
    """Threat list endpoint paginates and filters by event type and severity.

    Seeds three events one minute apart, with scores 90 (sql_injection),
    70 (xss) and 80 (nested_expression), then checks:
    * ``per_page=2``: page 1 has 2 items and page 2 has 1, each reporting ``total`` == 3.
      The two pages must be disjoint and together cover every seeded event, so an
      implementation that ignores the page offset is caught.
    * ``event_type=sql_injection`` returns only that event.
    * ``severity=high`` returns exactly the 90- and 80-score events, so this test
      expects the "high" cut-off to lie somewhere in (70, 80].
    """
    app = _make_app()
    client = app.test_client()
    _login_admin(client)

    now = get_cst_now()
    t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
    t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
    t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
    _insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
    _insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
    _insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
    all_types = {"sql_injection", "xss", "nested_expression"}

    resp = client.get("/api/admin/security/threats?page=1&per_page=2")
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["total"] == 3
    assert len(data["items"]) == 2

    resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
    assert resp2.status_code == 200
    data2 = resp2.get_json()
    assert data2["total"] == 3
    assert len(data2["items"]) == 1

    # Counts alone would pass even if page 2 repeated a row from page 1. Require the
    # pages to be disjoint and to cover every seeded event between them.
    page1_types = {item["threat_type"] for item in data["items"]}
    page2_types = {item["threat_type"] for item in data2["items"]}
    assert page1_types.isdisjoint(page2_types)
    assert page1_types | page2_types == all_types

    resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
    assert resp3.status_code == 200
    data3 = resp3.get_json()
    assert data3["total"] == 1
    assert data3["items"][0]["threat_type"] == "sql_injection"

    resp4 = client.get("/api/admin/security/threats?severity=high")
    assert resp4.status_code == 200
    data4 = resp4.get_json()
    assert data4["total"] == 2
    assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
def test_ban_and_unban_ip(_test_db):
    """An IP banned via the API shows up in ``banned-ips``; unbanning removes it."""
    client = _make_app().test_client()
    _login_admin(client)
    target_ip = "7.7.7.7"

    def list_banned():
        # Fetch the banned-IP listing, requiring a 200 before returning the JSON body.
        listing = client.get("/api/admin/security/banned-ips")
        assert listing.status_code == 200
        return listing.get_json()

    ban = client.post(
        "/api/admin/security/ban-ip",
        json={"ip": target_ip, "reason": "test", "duration_hours": 1},
    )
    assert ban.status_code == 200
    assert ban.get_json()["success"] is True

    banned = list_banned()
    assert banned["count"] == 1
    assert banned["items"][0]["ip"] == target_ip

    unban = client.post("/api/admin/security/unban-ip", json={"ip": target_ip})
    assert unban.status_code == 200
    assert unban.get_json()["success"] is True

    assert list_banned()["count"] == 0
def test_risk_endpoints_and_cleanup(_test_db):
    """Risk lookup endpoints reflect recorded threats; cleanup decays stale scores.

    1. An XSS threat scored 20 is recorded for IP 4.4.4.4 / user 44 with auto-ban
       disabled. Both ``ip-risk`` and ``user-risk`` must report a score of 20 and a
       non-empty threat history.
    2. Two rows are written straight to the database: a 100-point IP score for 5.5.5.5,
       last seen two hours ago, and an ``ip_blacklist`` row for 6.6.6.6 whose
       ``expires_at`` is already in the past.
    3. The expired ban must not count as banned. After ``POST /cleanup`` the stale
       score must read 81. NOTE(review): 81 is consistent with a 10%-per-hour decay
       over two hours; confirm against RiskScorer.
    """
    client = _make_app().test_client()
    _login_admin(client)

    risk_scorer = RiskScorer(auto_ban_enabled=False)
    risk_scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")

    for endpoint in ("/api/admin/security/ip-risk/4.4.4.4", "/api/admin/security/user-risk/44"):
        risk_resp = client.get(endpoint)
        assert risk_resp.status_code == 200
        risk_data = risk_resp.get_json()
        assert risk_data["risk_score"] == 20
        assert len(risk_data["threat_history"]) >= 1

    # Seed a stale score to be decayed and a ban that has already expired.
    two_hours_ago = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
            VALUES (?, 100, ?, ?, ?)
            """,
            ("5.5.5.5", two_hours_ago, two_hours_ago, two_hours_ago),
        )
        cur.execute(
            """
            INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
            VALUES (?, ?, 1, ?, ?)
            """,
            ("6.6.6.6", "expired", two_hours_ago, two_hours_ago),
        )
        conn.commit()

    # The row is still flagged active, but its expiry time has passed.
    assert BlacklistManager().is_ip_banned("6.6.6.6") is False

    cleanup = client.post("/api/admin/security/cleanup", json={})
    assert cleanup.status_code == 200
    assert cleanup.get_json()["success"] is True

    assert RiskScorer().get_ip_score("5.5.5.5") == 81