Files
zsglpt/db/security.py

288 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
from datetime import timedelta
from typing import Any, Dict, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
# Column list interpolated into the row-returning SELECT statements issued
# against the threat_events table in this module, so those queries all return
# rows with the same set of keys.
_THREAT_EVENT_SELECT_COLUMNS = """
id,
threat_type,
score,
rule,
field_name,
matched,
value_preview,
ip,
user_id,
request_method,
request_path,
user_agent,
created_at
"""
def _normalize_page(page: int) -> int:
try:
page_i = int(page)
except Exception:
page_i = 1
return max(1, page_i)
def _normalize_per_page(per_page: int, default: int = 20) -> int:
try:
value = int(per_page)
except Exception:
value = default
return max(1, min(200, value))
def _normalize_limit(limit: int, default: int = 50) -> int:
try:
value = int(limit)
except Exception:
value = default
return max(1, min(200, value))
def _row_value(row, key: str, index: int = 0, default=None):
if row is None:
return default
try:
return row[key]
except Exception:
try:
return row[index]
except Exception:
return default
def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
    """Fetch the newest threat events matching a condition.

    Args:
        where_clause: SQL condition placed after ``WHERE``. It is inserted
            into the statement text as-is, so it must use ``?`` placeholders
            for any values.
        params: Values bound to the placeholders in *where_clause*.
        limit_i: Maximum number of rows, bound to the trailing ``LIMIT ?``.

    Returns:
        The matching rows as dicts, ordered by ``created_at`` then ``id``,
        both descending.
    """
    sql = (
        "\nSELECT\n"
        f"{_THREAT_EVENT_SELECT_COLUMNS}\n"
        "FROM threat_events\n"
        f"WHERE {where_clause}\n"
        "ORDER BY created_at DESC, id DESC\n"
        "LIMIT ?\n"
    )
    bound_values = (*params, limit_i)
    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(sql, bound_values)
        rows = cursor.fetchall()
        return [dict(record) for record in rows]
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
    """Record the environment of a login and report whether it is new.

    The user agent (trimmed, at most 512 characters) is looked up in
    ``login_fingerprints`` and the IP (trimmed, at most 64 characters) in
    ``login_ips``, both scoped to *user_id*. An existing row has its
    ``last_seen`` refreshed (and ``last_ip`` for fingerprints); a missing row
    is inserted. An empty user agent or IP skips its table entirely.

    Returns:
        ``{"new_device": bool, "new_ip": bool}``, each ``True`` only when the
        corresponding row had to be inserted.
    """
    user_id = int(user_id)
    ip_text = str(ip_address or "").strip()[:64]
    ua_text = str(user_agent or "").strip()[:512]
    now_str = get_cst_now_str()

    outcome = {"new_device": False, "new_ip": False}

    touch_fingerprint_sql = (
        "\nUPDATE login_fingerprints\n"
        "SET last_seen = ?, last_ip = ?\n"
        "WHERE id = ?\n"
    )
    add_fingerprint_sql = (
        "\nINSERT INTO login_fingerprints (user_id, user_agent, first_seen, last_seen, last_ip)\n"
        "VALUES (?, ?, ?, ?, ?)\n"
    )
    touch_ip_sql = (
        "\nUPDATE login_ips\n"
        "SET last_seen = ?\n"
        "WHERE id = ?\n"
    )
    add_ip_sql = (
        "\nINSERT INTO login_ips (user_id, ip, first_seen, last_seen)\n"
        "VALUES (?, ?, ?, ?)\n"
    )

    with db_pool.get_db() as conn:
        cursor = conn.cursor()

        if ua_text:
            cursor.execute(
                "SELECT id FROM login_fingerprints WHERE user_id = ? AND user_agent = ?",
                (user_id, ua_text),
            )
            known_fingerprint = cursor.fetchone()
            if not known_fingerprint:
                cursor.execute(
                    add_fingerprint_sql,
                    (user_id, ua_text, now_str, now_str, ip_text),
                )
                outcome["new_device"] = True
            else:
                cursor.execute(
                    touch_fingerprint_sql,
                    (now_str, ip_text, _row_value(known_fingerprint, "id", 0)),
                )

        if ip_text:
            cursor.execute(
                "SELECT id FROM login_ips WHERE user_id = ? AND ip = ?",
                (user_id, ip_text),
            )
            known_ip = cursor.fetchone()
            if not known_ip:
                cursor.execute(add_ip_sql, (user_id, ip_text, now_str, now_str))
                outcome["new_ip"] = True
            else:
                cursor.execute(
                    touch_ip_sql,
                    (now_str, _row_value(known_ip, "id", 0)),
                )

        # Committed unconditionally, even when neither table was touched.
        conn.commit()

    return outcome
def get_threat_events_count(hours: int = 24) -> int:
    """Count the threat events created within the last *hours* hours.

    An *hours* value that ``int()`` cannot convert falls back to 24. A zero
    or negative window returns 0 without querying the database, and a
    failure while reading the count from the fetched row also yields 0.
    """
    try:
        window_hours = max(0, int(hours))
    except Exception:
        window_hours = 24
    if window_hours <= 0:
        return 0

    window_start = get_cst_now() - timedelta(hours=window_hours)
    start_time = window_start.strftime("%Y-%m-%d %H:%M:%S")

    with db_pool.get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) AS cnt FROM threat_events WHERE created_at >= ?",
            (start_time,),
        )
        count_row = cursor.fetchone()
        if not count_row:
            return 0
        try:
            return int(count_row["cnt"])
        except Exception:
            return 0
def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, list[Any]]:
clauses: list[str] = []
params: list[Any] = []
if not isinstance(filters, dict):
return "", []
event_type = filters.get("event_type") or filters.get("threat_type")
if event_type:
raw = str(event_type).strip()
types = [t.strip()[:64] for t in raw.split(",") if t.strip()]
if len(types) == 1:
clauses.append("threat_type = ?")
params.append(types[0])
elif types:
placeholders = ", ".join(["?"] * len(types))
clauses.append(f"threat_type IN ({placeholders})")
params.extend(types)
severity = filters.get("severity")
if severity is not None and str(severity).strip():
sev = str(severity).strip().lower()
if "-" in sev:
parts = [p.strip() for p in sev.split("-", 1)]
try:
min_score = int(parts[0])
max_score = int(parts[1])
clauses.append("score >= ? AND score <= ?")
params.extend([min_score, max_score])
except Exception:
pass
elif sev.isdigit():
clauses.append("score >= ?")
params.append(int(sev))
elif sev in {"high", "critical"}:
clauses.append("score >= ?")
params.append(80)
elif sev in {"medium", "med"}:
clauses.append("score >= ? AND score < ?")
params.extend([50, 80])
elif sev in {"low", "info"}:
clauses.append("score < ?")
params.append(50)
ip = filters.get("ip")
if ip is not None and str(ip).strip():
ip_text = str(ip).strip()[:64]
clauses.append("ip = ?")
params.append(ip_text)
user_id = filters.get("user_id")
if user_id is not None and str(user_id).strip():
try:
user_id_int = int(user_id)
except Exception:
user_id_int = None
if user_id_int is not None:
clauses.append("user_id = ?")
params.append(user_id_int)
if not clauses:
return "", []
return " WHERE " + " AND ".join(clauses), params
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
    """Return one page of threat events together with the total match count.

    Args:
        page: Requested page, normalised by ``_normalize_page``.
        per_page: Requested page size, normalised by ``_normalize_per_page``
            with a default of 20.
        filters: Optional filter mapping handed to
            ``_build_threat_events_where_clause``.

    Returns:
        A dict with the normalised ``page`` and ``per_page``, the ``total``
        number of matching rows, the page's ``items`` as dicts (newest
        first), and the ``filters`` that were passed in (``{}`` when falsy).
    """
    page_i = _normalize_page(page)
    per_page_i = _normalize_per_page(per_page, default=20)
    where_sql, params = _build_threat_events_where_clause(filters)
    offset = per_page_i * (page_i - 1)

    count_sql = f"SELECT COUNT(*) AS cnt FROM threat_events{where_sql}"
    page_sql = (
        "\nSELECT\n"
        f"{_THREAT_EVENT_SELECT_COLUMNS}\n"
        "FROM threat_events\n"
        f"{where_sql}\n"
        "ORDER BY created_at DESC, id DESC\n"
        "LIMIT ? OFFSET ?\n"
    )

    with db_pool.get_db() as conn:
        cursor = conn.cursor()

        cursor.execute(count_sql, tuple(params))
        count_row = cursor.fetchone()
        total = int(count_row["cnt"]) if count_row else 0

        cursor.execute(page_sql, (*params, per_page_i, offset))
        items = [dict(record) for record in cursor.fetchall()]

    return {
        "page": page_i,
        "per_page": per_page_i,
        "total": total,
        "items": items,
        "filters": filters or {},
    }
def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
    """Return the most recent threat events recorded for an IP address.

    The IP is trimmed and cut to 64 characters; an empty result returns
    ``[]`` without querying. *limit* is normalised by ``_normalize_limit``
    (default 50) before being used as the row cap.
    """
    ip_text = str(ip or "").strip()[:64]
    if ip_text:
        row_cap = _normalize_limit(limit, default=50)
        return _fetch_threat_events_history("ip = ?", (ip_text,), row_cap)
    return []
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
    """Return the most recent threat events recorded for a user.

    A ``None`` *user_id*, or one that ``int()`` cannot convert, returns ``[]``
    without querying. *limit* is normalised by ``_normalize_limit``
    (default 50) before being used as the row cap.
    """
    if user_id is None:
        return []
    try:
        uid = int(user_id)
    except Exception:
        return []
    else:
        row_cap = _normalize_limit(limit, default=50)
        return _fetch_threat_events_history("user_id = ?", (uid,), row_cap)