feat: 实现完整安全防护系统
Phase 1 - 威胁检测引擎: - security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测 - security/constants.py: 威胁检测规则和评分常量 - 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist Phase 2 - 风险评分与黑名单: - security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减 - security/blacklist.py: 黑名单管理,自动封禁规则 Phase 3 - 响应策略: - security/honeypot.py: 蜜罐响应生成器 - security/response_handler.py: 渐进式响应策略 Phase 4 - 集成: - security/middleware.py: Flask安全中间件 - routes/admin_api/security.py: 管理后台安全仪表板API - 36个测试用例全部通过 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
179
tests/test_risk_scorer.py
Normal file
179
tests/test_risk_scorer.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
import db_pool
|
||||
from db.schema import ensure_schema
|
||||
from db.utils import get_cst_now
|
||||
from security import constants as C
|
||||
from security.blacklist import BlacklistManager
|
||||
from security.risk_scorer import RiskScorer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _test_db(tmp_path):
|
||||
db_file = tmp_path / "risk_scorer_test.db"
|
||||
|
||||
old_pool = getattr(db_pool, "_pool", None)
|
||||
try:
|
||||
if old_pool is not None:
|
||||
try:
|
||||
old_pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = None
|
||||
db_pool.init_pool(str(db_file), pool_size=1)
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
ensure_schema(conn)
|
||||
|
||||
yield db_file
|
||||
finally:
|
||||
try:
|
||||
if getattr(db_pool, "_pool", None) is not None:
|
||||
db_pool._pool.close_all()
|
||||
except Exception:
|
||||
pass
|
||||
db_pool._pool = old_pool
|
||||
|
||||
|
||||
def test_record_threat_updates_scores_and_combined(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "1.2.3.4"
|
||||
user_id = 123
|
||||
|
||||
assert scorer.get_ip_score(ip) == 0
|
||||
assert scorer.get_user_score(user_id) == 0
|
||||
assert scorer.get_combined_score(ip, user_id) == 0
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
|
||||
|
||||
assert scorer.get_ip_score(ip) == 30
|
||||
assert scorer.get_user_score(user_id) == 30
|
||||
assert scorer.get_combined_score(ip, user_id) == 30
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
|
||||
|
||||
assert scorer.get_ip_score(ip) == 100
|
||||
assert scorer.get_user_score(user_id) == 100
|
||||
assert scorer.get_combined_score(ip, user_id) == 100
|
||||
|
||||
|
||||
def test_auto_ban_on_score_100(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "5.6.7.8"
|
||||
user_id = 456
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
|
||||
|
||||
assert manager.is_ip_banned(ip) is True
|
||||
assert manager.is_user_banned(user_id) is True
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None
|
||||
|
||||
|
||||
def test_jndi_injection_permanent_ban(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "9.9.9.9"
|
||||
user_id = 999
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
|
||||
|
||||
assert manager.is_ip_banned(ip) is True
|
||||
assert manager.is_user_banned(user_id) is True
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
|
||||
def test_high_risk_three_times_permanent_ban(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
|
||||
|
||||
ip = "10.0.0.1"
|
||||
user_id = 1
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is not None # score hits 100 => temporary ban first
|
||||
|
||||
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None # 3 high-risk threats => permanent
|
||||
|
||||
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["expires_at"] is None
|
||||
|
||||
|
||||
def test_decay_scores_hourly_10_percent(_test_db):
|
||||
manager = BlacklistManager()
|
||||
scorer = RiskScorer(blacklist_manager=manager)
|
||||
|
||||
ip = "3.3.3.3"
|
||||
user_id = 11
|
||||
|
||||
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
with db_pool.get_db() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 100, ?, ?, ?)
|
||||
""",
|
||||
(ip, old_ts, old_ts, old_ts),
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
|
||||
VALUES (?, 100, ?, ?, ?)
|
||||
""",
|
||||
(user_id, old_ts, old_ts, old_ts),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
scorer.decay_scores()
|
||||
|
||||
assert scorer.get_ip_score(ip) == 81
|
||||
assert scorer.get_user_score(user_id) == 81
|
||||
|
||||
Reference in New Issue
Block a user