from __future__ import annotations from datetime import timedelta import pytest import db_pool from db.schema import ensure_schema from db.utils import get_cst_now from security import constants as C from security.blacklist import BlacklistManager from security.risk_scorer import RiskScorer @pytest.fixture() def _test_db(tmp_path): db_file = tmp_path / "risk_scorer_test.db" old_pool = getattr(db_pool, "_pool", None) try: if old_pool is not None: try: old_pool.close_all() except Exception: pass db_pool._pool = None db_pool.init_pool(str(db_file), pool_size=1) with db_pool.get_db() as conn: ensure_schema(conn) yield db_file finally: try: if getattr(db_pool, "_pool", None) is not None: db_pool._pool.close_all() except Exception: pass db_pool._pool = old_pool def test_record_threat_updates_scores_and_combined(_test_db): manager = BlacklistManager() scorer = RiskScorer(blacklist_manager=manager) ip = "1.2.3.4" user_id = 123 assert scorer.get_ip_score(ip) == 0 assert scorer.get_user_score(user_id) == 0 assert scorer.get_combined_score(ip, user_id) == 0 scorer.record_threat(ip, user_id, threat_type="sql_injection", score=30, request_path="/login", payload="x") assert scorer.get_ip_score(ip) == 30 assert scorer.get_user_score(user_id) == 30 assert scorer.get_combined_score(ip, user_id) == 30 scorer.record_threat(ip, user_id, threat_type="sql_injection", score=80, request_path="/login", payload="y") assert scorer.get_ip_score(ip) == 100 assert scorer.get_user_score(user_id) == 100 assert scorer.get_combined_score(ip, user_id) == 100 def test_auto_ban_on_score_100(_test_db): manager = BlacklistManager() scorer = RiskScorer(blacklist_manager=manager) ip = "5.6.7.8" user_id = 456 scorer.record_threat(ip, user_id, threat_type="sql_injection", score=100, request_path="/api", payload="boom") assert manager.is_ip_banned(ip) is True assert manager.is_user_banned(user_id) is True with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is not None cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is not None def test_jndi_injection_permanent_ban(_test_db): manager = BlacklistManager() scorer = RiskScorer(blacklist_manager=manager) ip = "9.9.9.9" user_id = 999 scorer.record_threat(ip, user_id, threat_type=C.THREAT_TYPE_JNDI_INJECTION, score=100, request_path="/", payload="${jndi:ldap://x}") assert manager.is_ip_banned(ip) is True assert manager.is_user_banned(user_id) is True with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is None cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is None def test_high_risk_three_times_permanent_ban(_test_db): manager = BlacklistManager() scorer = RiskScorer(blacklist_manager=manager, high_risk_threshold=80, high_risk_permanent_ban_count=3) ip = "10.0.0.1" user_id = 1 scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="a") scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="b") with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is not None # score hits 100 => temporary ban first scorer.record_threat(ip, user_id, threat_type="nested_expression", score=80, request_path="/", payload="c") with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT expires_at FROM ip_blacklist WHERE ip = ?", (ip,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is None # 3 high-risk threats => permanent cursor.execute("SELECT expires_at FROM user_blacklist WHERE user_id = ?", (user_id,)) row = cursor.fetchone() assert row is not None assert row["expires_at"] is None def test_decay_scores_hourly_10_percent(_test_db): manager = BlacklistManager() scorer = RiskScorer(blacklist_manager=manager) ip = "3.3.3.3" user_id = 11 old_ts = (get_cst_now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S") with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( """ INSERT INTO ip_risk_scores (ip, risk_score, last_seen, created_at, updated_at) VALUES (?, 100, ?, ?, ?) """, (ip, old_ts, old_ts, old_ts), ) cursor.execute( """ INSERT INTO user_risk_scores (user_id, risk_score, last_seen, created_at, updated_at) VALUES (?, 100, ?, ?, ?) """, (user_id, old_ts, old_ts, old_ts), ) conn.commit() scorer.decay_scores() assert scorer.get_ip_score(ip) == 81 assert scorer.get_user_score(user_id) == 81