#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Risk scoring engine: persists per-IP / per-user risk scores, records threat
# events, and decides when an IP or user should be auto-banned.
from __future__ import annotations

import math
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import db_pool
from db.utils import get_cst_now, get_cst_now_str, parse_cst_datetime

from . import constants as C
from .blacklist import BlacklistManager

# Timestamp format this module uses when it formats datetimes itself.
_TS_FORMAT = "%Y-%m-%d %H:%M:%S"


def _clamp_score(value) -> int:
    """Return ``int(value)`` clamped to the valid risk-score range 0-100."""
    return max(0, min(100, int(value)))


def _coerce_score(raw) -> int:
    """Convert a stored ``risk_score`` column value into a clamped int.

    NULL or otherwise unparsable values are treated as 0 instead of raising,
    so a single corrupt row cannot break reads, updates, or decay.
    """
    try:
        return _clamp_score(raw)
    except (TypeError, ValueError, OverflowError):
        return 0


@dataclass(frozen=True)
class _ScoreUpdateResult:
    """Scores after an update; 0 for a subject (IP or user) that was not updated."""

    ip_score: int
    user_score: int


@dataclass(frozen=True)
class _BanAction:
    """A ban decided inside the DB transaction, applied after it commits."""

    reason: str
    # None means "use the scorer's auto_ban_duration_hours".
    duration_hours: Optional[int] = None
    permanent: bool = False


class RiskScorer:
    """Risk scoring engine: computes and persists IP and user risk scores.

    Scores live in the ``ip_risk_scores`` / ``user_risk_scores`` tables and are
    kept within 0-100. Threat events are appended to ``threat_events``; when
    auto-ban is enabled, :meth:`record_threat` may ban the offending IP/user
    through the configured ``BlacklistManager``.
    """

    def __init__(
        self,
        *,
        auto_ban_enabled: bool = True,
        auto_ban_duration_hours: int = 24,
        high_risk_threshold: int = 80,
        high_risk_window_hours: int = 1,
        high_risk_permanent_ban_count: int = 3,
        blacklist_manager: Optional[BlacklistManager] = None,
    ) -> None:
        """Configure the scorer.

        Args:
            auto_ban_enabled: Whether :meth:`record_threat` may issue bans.
            auto_ban_duration_hours: Length of temporary bans (minimum 1).
            high_risk_threshold: Event score at or above which a threat counts
                as high-risk (minimum 0).
            high_risk_window_hours: Look-back window, in hours, for counting
                high-risk events (minimum 1).
            high_risk_permanent_ban_count: Number of high-risk events inside the
                window that triggers a permanent ban (minimum 1).
            blacklist_manager: Ban backend; a default ``BlacklistManager`` is
                created when omitted.
        """
        self.auto_ban_enabled = bool(auto_ban_enabled)
        self.auto_ban_duration_hours = max(1, int(auto_ban_duration_hours))
        self.high_risk_threshold = max(0, int(high_risk_threshold))
        self.high_risk_window_hours = max(1, int(high_risk_window_hours))
        self.high_risk_permanent_ban_count = max(1, int(high_risk_permanent_ban_count))
        self.blacklist = blacklist_manager or BlacklistManager()

    def get_ip_score(self, ip_address: str) -> int:
        """Return the stored risk score (0-100) for *ip_address*.

        The address is stripped and truncated to 64 characters, matching how
        :meth:`record_threat` normalizes it. Blank or unknown addresses and
        unparsable stored values score 0.
        """
        ip_text = str(ip_address or "").strip()[:64]
        if not ip_text:
            return 0
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip_text,))
            row = cursor.fetchone()
            return _coerce_score(row["risk_score"]) if row else 0

    def get_user_score(self, user_id: Optional[int]) -> int:
        """Return the stored risk score (0-100) for *user_id*.

        ``None``, unknown users, and unparsable stored values score 0.
        """
        if user_id is None:
            return 0
        user_id_int = int(user_id)
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT risk_score FROM user_risk_scores WHERE user_id = ?",
                (user_id_int,),
            )
            row = cursor.fetchone()
            return _coerce_score(row["risk_score"]) if row else 0

    def get_combined_score(self, ip: str, user_id: Optional[int] = None) -> int:
        """Return max(IP score, user score) plus the behavior bonus, clamped to 0-100."""
        # get_user_score already returns 0 for a None user_id.
        base = max(self.get_ip_score(ip), self.get_user_score(user_id))
        bonus = self._get_behavior_bonus(ip, user_id)
        return _clamp_score(base + bonus)

    def record_threat(
        self,
        ip: str,
        user_id: Optional[int],
        threat_type: str,
        score: int,
        request_path: Optional[str] = None,
        payload: Optional[str] = None,
    ) -> None:
        """Record a threat event and raise the IP/user risk scores by *score*.

        The event insert, score updates and ban decisions all run in a single
        transaction. Any resulting bans are applied only after that transaction
        commits, and only when ``auto_ban_enabled`` is set.

        Args:
            ip: Client IP; stripped and truncated to 64 chars. Blank skips IP scoring.
            user_id: Affected user, or ``None`` to skip user scoring.
            threat_type: Threat category; truncated to 64 chars, blank -> "unknown".
            score: Event score; negative values are raised to 0.
            request_path: Optional request path, truncated to 512 chars.
            payload: Optional offending value, truncated to 2048 chars.
        """
        ip_text = str(ip or "").strip()[:64]
        user_id_int = int(user_id) if user_id is not None else None
        threat_type_text = str(threat_type or "").strip()[:64] or "unknown"
        score_int = max(0, int(score))
        # Blank or whitespace-only path/payload are stored as NULL rather than "".
        path_text = str(request_path or "").strip()[:512] or None
        payload_text = str(payload or "").strip()[:2048] or None
        now_str = get_cst_now_str()

        ip_ban_action: Optional[_BanAction] = None
        user_ban_action: Optional[_BanAction] = None

        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            # Insert the event first so the high-risk count in
            # _get_auto_ban_actions includes this event.
            cursor.execute(
                """
                INSERT INTO threat_events (
                    threat_type, score, ip, user_id, request_path, value_preview, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    threat_type_text,
                    score_int,
                    ip_text or None,
                    user_id_int,
                    path_text,
                    payload_text,
                    now_str,
                ),
            )
            update = self._update_scores(cursor, ip_text, user_id_int, score_int, now_str)
            if self.auto_ban_enabled:
                ip_ban_action, user_ban_action = self._get_auto_ban_actions(
                    cursor,
                    ip_text,
                    user_id_int,
                    threat_type_text,
                    score_int,
                    update,
                )
            conn.commit()

        if not self.auto_ban_enabled:
            return

        if ip_ban_action and ip_text:
            self.blacklist.ban_ip(
                ip_text,
                reason=ip_ban_action.reason,
                duration_hours=ip_ban_action.duration_hours or self.auto_ban_duration_hours,
                permanent=ip_ban_action.permanent,
            )
        if user_ban_action and user_id_int is not None:
            # NOTE(review): relies on BlacklistManager's private API; consider
            # exposing a public ban_user method on the blacklist manager.
            self.blacklist._ban_user_internal(
                user_id_int,
                reason=user_ban_action.reason,
                duration_hours=user_ban_action.duration_hours or self.auto_ban_duration_hours,
                permanent=user_ban_action.permanent,
            )

    def decay_scores(self) -> None:
        """Decay all stored IP and user scores by 10% per full elapsed hour.

        Meant to be called periodically. Elapsed time is measured from each row's
        ``updated_at`` (falling back to ``created_at``). Rows whose score would
        not change are left untouched.
        """
        now = get_cst_now()
        with db_pool.get_db() as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT ip, risk_score, updated_at, created_at FROM ip_risk_scores")
            self._decay_rows(
                cursor,
                cursor.fetchall(),
                "ip",
                "UPDATE ip_risk_scores SET risk_score = ?, updated_at = ? WHERE ip = ?",
                now,
            )

            cursor.execute(
                "SELECT user_id, risk_score, updated_at, created_at FROM user_risk_scores"
            )
            self._decay_rows(
                cursor,
                cursor.fetchall(),
                "user_id",
                "UPDATE user_risk_scores SET risk_score = ?, updated_at = ? WHERE user_id = ?",
                now,
                key_cast=int,
            )

            conn.commit()

    def _decay_rows(self, cursor, rows, key_column: str, update_sql: str, now: datetime, key_cast=None) -> None:
        """Apply hourly decay to already-fetched *rows* (does not commit).

        *update_sql* must take ``(risk_score, updated_at, key)`` parameters.
        *key_cast*, when given, converts the key column value before it is
        bound into the UPDATE.
        """
        for row in rows:
            key = row[key_column]
            if key_cast is not None:
                key = key_cast(key)
            anchor = row["updated_at"] or row["created_at"]
            hours = self._hours_since(anchor, now)
            if hours <= 0:
                continue
            current = _coerce_score(row["risk_score"])
            new_score = self._apply_hourly_decay(current, hours)
            if new_score == current:
                continue
            # Advance the anchor by exactly the whole hours consumed instead of
            # resetting it to "now". Resetting discards the leftover fraction of
            # an hour on every run, so scores would decay slower than 10%/hour
            # whenever the job does not run on exact hour multiples.
            # _hours_since returned > 0, so the anchor is known to parse.
            new_anchor = parse_cst_datetime(str(anchor)) + timedelta(hours=hours)
            cursor.execute(update_sql, (new_score, new_anchor.strftime(_TS_FORMAT), key))

    def _update_ip_score(self, ip: str, score_delta: int) -> None:
        """Add *score_delta* (may be negative) to an IP's score in its own transaction."""
        ip_text = str(ip or "").strip()[:64]
        if not ip_text:
            return
        delta = int(score_delta)
        now_str = get_cst_now_str()
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            self._update_scores(cursor, ip_text, None, delta, now_str)
            conn.commit()

    def _update_user_score(self, user_id: Optional[int], score_delta: int) -> None:
        """Add *score_delta* (may be negative) to a user's score in its own transaction."""
        if user_id is None:
            return
        user_id_int = int(user_id)
        delta = int(score_delta)
        now_str = get_cst_now_str()
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            self._update_scores(cursor, "", user_id_int, delta, now_str)
            conn.commit()

    def _update_scores(
        self,
        cursor,
        ip: str,
        user_id: Optional[int],
        score_delta: int,
        now_str: str,
    ) -> _ScoreUpdateResult:
        """Add *score_delta* to the IP and/or user score inside the caller's transaction.

        A blank *ip* skips the IP row, and a ``None`` *user_id* skips the user
        row. Missing rows are inserted. New scores are clamped to 0-100, and
        ``last_seen`` / ``updated_at`` are set to *now_str*. Does not commit.
        """
        delta = int(score_delta)
        ip_score = 0
        user_score = 0

        if ip:
            cursor.execute("SELECT risk_score FROM ip_risk_scores WHERE ip = ?", (ip,))
            row = cursor.fetchone()
            # NULL/corrupt stored scores count as 0 instead of crashing the write path.
            current = _coerce_score(row["risk_score"]) if row else 0
            ip_score = _clamp_score(current + delta)
            if row:
                cursor.execute(
                    "UPDATE ip_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE ip = ?",
                    (ip_score, now_str, now_str, ip),
                )
            else:
                cursor.execute(
                    """
                    INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (ip, ip_score, now_str, now_str, now_str),
                )

        if user_id is not None:
            user_id_int = int(user_id)
            cursor.execute(
                "SELECT risk_score FROM user_risk_scores WHERE user_id = ?", (user_id_int,)
            )
            row = cursor.fetchone()
            current = _coerce_score(row["risk_score"]) if row else 0
            user_score = _clamp_score(current + delta)
            if row:
                cursor.execute(
                    "UPDATE user_risk_scores SET risk_score = ?, last_seen = ?, updated_at = ? WHERE user_id = ?",
                    (user_score, now_str, now_str, user_id_int),
                )
            else:
                cursor.execute(
                    """
                    INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id_int, user_score, now_str, now_str, now_str),
                )

        return _ScoreUpdateResult(ip_score=ip_score, user_score=user_score)

    def _get_auto_ban_actions(
        self,
        cursor,
        ip: str,
        user_id: Optional[int],
        threat_type: str,
        score: int,
        update: _ScoreUpdateResult,
    ) -> tuple[Optional[_BanAction], Optional[_BanAction]]:
        """Decide which (IP, user) bans to apply for the event just recorded.

        Rules, in order:
          1. JNDI injection -> permanent ban for the IP and user; nothing else
             is evaluated.
          2. An updated score of 100 -> temporary ban of
             ``auto_ban_duration_hours``.
          3. If this event's *score* is at least ``high_risk_threshold``, count
             high-risk events in the last ``high_risk_window_hours``. Reaching
             ``high_risk_permanent_ban_count`` yields a permanent ban, which
             replaces any temporary ban from rule 2.
        """
        ip_action: Optional[_BanAction] = None
        user_action: Optional[_BanAction] = None

        if threat_type == C.THREAT_TYPE_JNDI_INJECTION:
            if ip:
                ip_action = _BanAction(reason="JNDI injection detected", permanent=True)
            if user_id is not None:
                user_action = _BanAction(reason="JNDI injection detected", permanent=True)
            return ip_action, user_action

        if ip and update.ip_score >= 100:
            ip_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)
        if user_id is not None and update.user_score >= 100:
            user_action = _BanAction(reason="Risk score reached 100", duration_hours=self.auto_ban_duration_hours)

        if score < self.high_risk_threshold:
            return ip_action, user_action

        window_start = (get_cst_now() - timedelta(hours=self.high_risk_window_hours)).strftime(_TS_FORMAT)
        threshold = int(self.high_risk_threshold)

        if ip:
            cursor.execute(
                """
                SELECT COUNT(*) AS cnt
                FROM threat_events
                WHERE ip = ? AND score >= ? AND created_at >= ?
                """,
                (ip, threshold, window_start),
            )
            row = cursor.fetchone()
            cnt = int(row["cnt"]) if row else 0
            if cnt >= self.high_risk_permanent_ban_count:
                ip_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)

        if user_id is not None:
            cursor.execute(
                """
                SELECT COUNT(*) AS cnt
                FROM threat_events
                WHERE user_id = ? AND score >= ? AND created_at >= ?
                """,
                (int(user_id), threshold, window_start),
            )
            row = cursor.fetchone()
            cnt = int(row["cnt"]) if row else 0
            if cnt >= self.high_risk_permanent_ban_count:
                user_action = _BanAction(reason="High-risk threats threshold reached", permanent=True)

        return ip_action, user_action

    def _get_behavior_bonus(self, ip: str, user_id: Optional[int]) -> int:
        """Extra risk from recent behavior; currently always 0 (placeholder hook)."""
        return 0

    def _hours_since(self, dt_str: Optional[str], now: datetime) -> int:
        """Return the number of whole hours elapsed from *dt_str* to *now*.

        Returns 0 for empty or unparsable timestamps and for timestamps that
        are not in the past.
        """
        if not dt_str:
            return 0
        try:
            dt = parse_cst_datetime(str(dt_str))
        except Exception:
            # Best-effort: an unparsable timestamp simply means "no decay".
            return 0
        seconds = (now - dt).total_seconds()
        if seconds <= 0:
            return 0
        return int(seconds // 3600)

    def _apply_hourly_decay(self, score: int, hours: int) -> int:
        """Return *score* after *hours* hours of 10%-per-hour decay.

        The result is floored and clamped to 0-100. A non-positive score, or
        non-positive *hours*, returns ``max(0, score)`` unchanged.
        """
        score_int = max(0, int(score))
        if score_int <= 0 or hours <= 0:
            return score_int
        decayed = int(math.floor(score_int * (0.9 ** int(hours))))
        return max(0, min(100, decayed))