主要更新: - 新增 security/ 安全模块 (风险评估、威胁检测、蜜罐等) - Dockerfile 添加 curl 以支持 Docker 健康检查 - 前端页面更新 (管理后台、用户端) - 数据库迁移和 schema 更新 - 新增 kdocs 上传服务 - 添加安全相关测试用例 Co-Authored-By: Claude <noreply@anthropic.com>
250 lines
7.8 KiB
Python
250 lines
7.8 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import timedelta
|
|
|
|
import pytest
|
|
from flask import Flask
|
|
|
|
import db_pool
|
|
from db.schema import ensure_schema
|
|
from db.utils import get_cst_now
|
|
from security.blacklist import BlacklistManager
|
|
from security.risk_scorer import RiskScorer
|
|
|
|
|
|
@pytest.fixture()
def _test_db(tmp_path):
    """Point ``db_pool`` at a throwaway SQLite file with the schema applied.

    Setup closes (best effort) whatever pool ``db_pool._pool`` currently holds,
    creates a single-connection pool on a file under ``tmp_path`` and runs
    ``ensure_schema`` on it. Teardown closes (best effort) the pool installed
    during the test and puts the previous ``db_pool._pool`` object back.

    Yields:
        Path of the temporary database file.
    """

    def _close_quietly(pool) -> None:
        # Best effort: a failure while closing must not mask the test outcome.
        try:
            pool.close_all()
        except Exception:
            pass

    previous = getattr(db_pool, "_pool", None)
    database_path = tmp_path / "admin_security_api_test.db"
    try:
        if previous is not None:
            _close_quietly(previous)
        # Clear the global so init_pool builds a fresh pool for this test.
        db_pool._pool = None
        db_pool.init_pool(str(database_path), pool_size=1)
        with db_pool.get_db() as conn:
            ensure_schema(conn)
        yield database_path
    finally:
        current = getattr(db_pool, "_pool", None)
        if current is not None:
            _close_quietly(current)
        # NOTE(review): `previous` was closed during setup above; whether the
        # restored object is still usable depends on close_all semantics — confirm.
        db_pool._pool = previous
|
|
|
|
|
|
def _make_app() -> Flask:
    """Build a minimal Flask app that exposes only the admin security blueprint.

    The blueprint import happens inside the function, presumably so the route
    module is loaded only after the ``_test_db`` fixture has configured
    ``db_pool`` — confirm if import-time side effects matter.
    """
    from routes.admin_api.security import security_bp

    application = Flask(__name__)
    application.config["SECRET_KEY"] = "test-secret"
    application.config["TESTING"] = True
    application.register_blueprint(security_bp)
    return application
|
|
|
|
|
|
def _login_admin(client) -> None:
    """Populate the test client's session so it acts as admin #1 ("admin")."""
    admin_fields = (("admin_id", 1), ("admin_username", "admin"))
    with client.session_transaction() as sess:
        for key, value in admin_fields:
            sess[key] = value
|
|
|
|
|
|
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
    """Write one row straight into ``threat_events`` and commit it.

    The request path is always ``/api/test``. ``payload`` is stored verbatim in
    the ``value_preview`` column and ``score`` is coerced with ``int()``.
    ``created_at`` is expected in ``"%Y-%m-%d %H:%M:%S"`` form, which is how
    every caller in this module formats it.
    """
    insert_sql = """
        INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """
    with db_pool.get_db() as conn:
        row = (threat_type, int(score), ip, user_id, "/api/test", payload, created_at)
        conn.cursor().execute(insert_sql, row)
        conn.commit()
|
|
|
|
|
|
def test_dashboard_requires_admin(_test_db):
    """An anonymous request to the dashboard gets 403 and a JSON error body."""
    anonymous = _make_app().test_client()

    response = anonymous.get("/api/admin/security/dashboard")

    assert response.status_code == 403
    # Expected message means "admin permission required".
    assert response.get_json() == {"error": "需要管理员权限"}
|
|
|
|
|
|
def test_dashboard_counts_and_payload_truncation(_test_db):
    """Dashboard counts only recent threats and active bans, and truncates long payloads.

    Seeds three threat events (two within the last 24 hours, one 25 hours old),
    one temporary IP ban and one temporary user ban, then checks the dashboard
    aggregates and the ``value_preview`` truncation of the 300-char payload.
    """
    app = _make_app()
    client = app.test_client()
    _login_admin(client)

    ts_format = "%Y-%m-%d %H:%M:%S"
    now = get_cst_now()
    within_24h = now.strftime(ts_format)
    within_24h_2 = (now - timedelta(hours=1)).strftime(ts_format)
    older = (now - timedelta(hours=25)).strftime(ts_format)

    long_payload = "x" * 300
    _insert_threat_event(
        threat_type="sql_injection",
        score=90,
        ip="1.2.3.4",
        user_id=10,
        created_at=within_24h,
        payload=long_payload,
    )
    _insert_threat_event(
        threat_type="xss",
        score=70,
        ip="2.3.4.5",
        user_id=11,
        created_at=within_24h_2,
        payload="short",
    )
    _insert_threat_event(
        threat_type="path_traversal",
        score=60,
        ip="9.9.9.9",
        user_id=None,
        created_at=older,
        payload="old",
    )

    manager = BlacklistManager()
    manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
    # Uses the manager's private helper to seed a user ban directly.
    manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)

    resp = client.get("/api/admin/security/dashboard")
    assert resp.status_code == 200
    data = resp.get_json()

    # The 25-hour-old path_traversal event must not be counted in the 24h window.
    assert data["threat_events_24h"] == 2
    assert data["banned_ip_count"] == 1
    assert data["banned_user_count"] == 1

    recent = data["recent_threat_events"]
    assert isinstance(recent, list)
    assert len(recent) == 3

    # Index events by type instead of position so this test does not silently
    # depend on the endpoint's sort order.
    by_type = {event["threat_type"]: event for event in recent}
    assert set(by_type) == {"sql_injection", "xss", "path_traversal"}

    payload_preview = by_type["sql_injection"]["value_preview"]
    assert isinstance(payload_preview, str)
    assert len(payload_preview) <= 200
    assert payload_preview.endswith("...")
    # The preview must be a prefix of the original payload, not arbitrary text.
    assert payload_preview.startswith("x")

    # Payloads short enough to fit are returned untouched.
    assert by_type["xss"]["value_preview"] == "short"
|
|
|
|
|
|
def test_threats_pagination_and_filters(_test_db):
    """The threats listing paginates without overlap and filters by type and severity.

    Three events with distinct types and scores 90 / 70 / 80 are seeded a few
    minutes apart. Pagination must split them 2 + 1 with no event repeated
    across pages. ``event_type`` must select exactly one type. ``severity=high``
    is expected to match scores 90 and 80 but not 70.
    """
    app = _make_app()
    client = app.test_client()
    _login_admin(client)

    ts_format = "%Y-%m-%d %H:%M:%S"
    now = get_cst_now()
    t1 = (now - timedelta(minutes=1)).strftime(ts_format)
    t2 = (now - timedelta(minutes=2)).strftime(ts_format)
    t3 = (now - timedelta(minutes=3)).strftime(ts_format)

    _insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
    _insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
    _insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")

    resp = client.get("/api/admin/security/threats?page=1&per_page=2")
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["total"] == 3
    assert len(data["items"]) == 2

    resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
    assert resp2.status_code == 200
    data2 = resp2.get_json()
    assert data2["total"] == 3
    assert len(data2["items"]) == 1

    # Each seeded event has a unique threat_type, so the types identify rows.
    # The pages must not overlap and must together cover every event;
    # size checks alone would pass even if a row were returned twice.
    page1_types = {item["threat_type"] for item in data["items"]}
    page2_types = {item["threat_type"] for item in data2["items"]}
    assert len(page1_types) == 2
    assert page1_types.isdisjoint(page2_types)
    assert page1_types | page2_types == {"sql_injection", "xss", "nested_expression"}

    resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
    assert resp3.status_code == 200
    data3 = resp3.get_json()
    assert data3["total"] == 1
    assert data3["items"][0]["threat_type"] == "sql_injection"

    # Expected: scores 90 and 80 count as "high", 70 does not.
    resp4 = client.get("/api/admin/security/threats?severity=high")
    assert resp4.status_code == 200
    data4 = resp4.get_json()
    assert data4["total"] == 2
    assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
|
|
|
|
|
|
def test_ban_and_unban_ip(_test_db):
    """Banning an IP makes it appear in banned-ips; unbanning empties the list again."""
    client = _make_app().test_client()
    _login_admin(client)
    target_ip = "7.7.7.7"

    ban = client.post("/api/admin/security/ban-ip", json={"ip": target_ip, "reason": "test", "duration_hours": 1})
    assert ban.status_code == 200
    assert ban.get_json()["success"] is True

    listing = client.get("/api/admin/security/banned-ips")
    assert listing.status_code == 200
    body = listing.get_json()
    assert body["count"] == 1
    assert body["items"][0]["ip"] == target_ip

    unban = client.post("/api/admin/security/unban-ip", json={"ip": target_ip})
    assert unban.status_code == 200
    assert unban.get_json()["success"] is True

    listing_after = client.get("/api/admin/security/banned-ips")
    assert listing_after.status_code == 200
    assert listing_after.get_json()["count"] == 0
|
|
|
|
|
|
def test_risk_endpoints_and_cleanup(_test_db):
    """IP/user risk endpoints reflect a recorded threat, and cleanup decays stale scores.

    Records a single xss threat (score 20) for IP 4.4.4.4 and user 44, checks
    both risk endpoints, then seeds a stale IP score and an already-expired ban
    and verifies the score is decayed after calling the cleanup endpoint.
    """
    client = _make_app().test_client()
    _login_admin(client)

    # Auto-ban is disabled so recording the threat only updates scores and history.
    RiskScorer(auto_ban_enabled=False).record_threat(
        "4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>"
    )

    # Both endpoints must report the same score and at least one history entry.
    for url in ("/api/admin/security/ip-risk/4.4.4.4", "/api/admin/security/user-risk/44"):
        response = client.get(url)
        assert response.status_code == 200
        body = response.get_json()
        assert body["risk_score"] == 20
        assert len(body["threat_history"]) >= 1

    two_hours_ago = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
    with db_pool.get_db() as conn:
        cur = conn.cursor()
        # A score of 100 last seen two hours ago, giving cleanup something to decay.
        cur.execute(
            """
            INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
            VALUES (?, 100, ?, ?, ?)
            """,
            ("5.5.5.5", two_hours_ago, two_hours_ago, two_hours_ago),
        )
        # A ban still flagged active whose expiry timestamp is already in the past.
        cur.execute(
            """
            INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
            VALUES (?, ?, 1, ?, ?)
            """,
            ("6.6.6.6", "expired", two_hours_ago, two_hours_ago),
        )
        conn.commit()

    # The expired ban must not be enforced even before cleanup runs.
    assert BlacklistManager().is_ip_banned("6.6.6.6") is False

    cleanup = client.post("/api/admin/security/cleanup", json={})
    assert cleanup.status_code == 200
    assert cleanup.get_json()["success"] is True

    # 100 -> 81 over ~2 hours matches a 10% multiplicative decay per hour
    # (100 * 0.9**2); NOTE(review): confirm against RiskScorer's decay rule.
    assert RiskScorer().get_ip_score("5.5.5.5") == 81
|
|
|