Files
zsglpt/db_pool.py

459 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库连接池模块
使用queue实现固定大小的连接池,防止连接泄漏
"""
import sqlite3
import threading
import time
from queue import Empty, Full, Queue
from app_config import get_config
from app_logger import get_logger
# Module-wide logger and application configuration object.
logger = get_logger("database")
config = get_config()
# Tunables below are read from the config object with getattr() fallbacks,
# coerced to int, and clamped to a lower bound with max().
# Timeout (seconds) passed to sqlite3.connect().
DB_CONNECT_TIMEOUT_SECONDS = max(1, int(getattr(config, "DB_CONNECT_TIMEOUT_SECONDS", 10)))
# Value (milliseconds) used for "PRAGMA busy_timeout".
DB_BUSY_TIMEOUT_MS = max(1000, int(getattr(config, "DB_BUSY_TIMEOUT_MS", 10000)))
# Value used (negated) for "PRAGMA cache_size".
DB_CACHE_SIZE_KB = max(1024, int(getattr(config, "DB_CACHE_SIZE_KB", 8192)))
# Value used for "PRAGMA wal_autocheckpoint".
DB_WAL_AUTOCHECKPOINT_PAGES = max(100, int(getattr(config, "DB_WAL_AUTOCHECKPOINT_PAGES", 1000)))
# Converted to bytes for "PRAGMA mmap_size"; 0 means the pragma is not issued.
DB_MMAP_SIZE_MB = max(0, int(getattr(config, "DB_MMAP_SIZE_MB", 256)))
# Number of extra commit attempts after a lock conflict (0 disables retrying).
DB_LOCK_RETRY_COUNT = max(0, int(getattr(config, "DB_LOCK_RETRY_COUNT", 3)))
# Base delay (milliseconds) for the commit retry backoff; doubled on each attempt.
DB_LOCK_RETRY_BASE_MS = max(10, int(getattr(config, "DB_LOCK_RETRY_BASE_MS", 50)))
# Startup default for the slow-query threshold (milliseconds); 0 disables reporting.
DB_SLOW_QUERY_MS = max(0, int(getattr(config, "DB_SLOW_QUERY_MS", 120)))
# Startup default for the maximum length of SQL text in slow-query log lines.
DB_SLOW_QUERY_SQL_MAX_LEN = max(80, int(getattr(config, "DB_SLOW_QUERY_SQL_MAX_LEN", 240)))
# Runtime-adjustable copies of the two slow-query settings, guarded by a lock
# so they can be changed via configure_slow_query_runtime() while queries run.
_slow_query_runtime_lock = threading.Lock()
_slow_query_runtime_threshold_ms = DB_SLOW_QUERY_MS
_slow_query_runtime_sql_max_len = DB_SLOW_QUERY_SQL_MAX_LEN
def _get_slow_query_runtime_values() -> tuple[int, int]:
    """Return the current slow-query settings as ``(threshold_ms, sql_max_len)``.

    Both module-level values are read while holding the runtime lock so the
    returned pair is a consistent snapshot.
    """
    with _slow_query_runtime_lock:
        threshold = int(_slow_query_runtime_threshold_ms)
        max_len = int(_slow_query_runtime_sql_max_len)
    return threshold, max_len
def get_slow_query_runtime() -> dict:
    """Return the current slow-query settings.

    Returns:
        A dict with the integer keys ``"threshold_ms"`` and ``"sql_max_len"``.
    """
    snapshot = _get_slow_query_runtime_values()
    return dict(zip(("threshold_ms", "sql_max_len"), snapshot))
def configure_slow_query_runtime(*, threshold_ms=None, sql_max_len=None) -> dict:
    """Update the runtime slow-query settings and return the effective values.

    Args:
        threshold_ms: New slow-query threshold in milliseconds; clamped to be
            at least 0. ``None`` leaves the current value unchanged.
        sql_max_len: New maximum SQL text length for log output; clamped to be
            at least 80. ``None`` leaves the current value unchanged.

    Returns:
        A dict with the keys ``"threshold_ms"`` and ``"sql_max_len"`` holding
        the values in effect after the update.
    """
    global _slow_query_runtime_threshold_ms, _slow_query_runtime_sql_max_len

    with _slow_query_runtime_lock:
        if threshold_ms is not None:
            _slow_query_runtime_threshold_ms = max(0, int(threshold_ms))
        if sql_max_len is not None:
            _slow_query_runtime_sql_max_len = max(80, int(sql_max_len))
        effective = {
            "threshold_ms": int(_slow_query_runtime_threshold_ms),
            "sql_max_len": int(_slow_query_runtime_sql_max_len),
        }

    # Best-effort propagation to the metrics service: the import happens here
    # (not at module top) and any failure is deliberately ignored.
    try:
        from services.slow_sql_metrics import configure_slow_sql_runtime

        configure_slow_sql_runtime(**effective)
    except Exception:
        pass

    return effective
def _is_lock_conflict_error(error: sqlite3.OperationalError) -> bool:
message = str(error or "").lower()
return ("locked" in message) or ("busy" in message)
def _compact_sql(sql: str) -> str:
    """Collapse whitespace in ``sql`` and truncate it for log output.

    Runs of whitespace become single spaces. If the result is longer than the
    current runtime ``sql_max_len``, it is cut so that the text plus a trailing
    "..." is exactly ``sql_max_len`` characters long.
    """
    limit = _get_slow_query_runtime_values()[1]
    collapsed = " ".join(str(sql or "").split())
    if len(collapsed) > limit:
        return collapsed[: limit - 3] + "..."
    return collapsed
def _describe_params(parameters) -> str:
if parameters is None:
return "none"
if isinstance(parameters, dict):
return f"dict[{len(parameters)}]"
if isinstance(parameters, (list, tuple)):
return f"{type(parameters).__name__}[{len(parameters)}]"
return type(parameters).__name__
class TracedCursor:
    """Cursor wrapper that times each executed statement.

    Every ``execute*`` call is timed with ``time.perf_counter`` and reported
    to the ``on_query_executed(sql, parameters, elapsed_ms)`` callback, even
    when the statement raises. Attributes not defined here are delegated to
    the wrapped cursor.
    """

    def __init__(self, cursor, on_query_executed):
        """Wrap ``cursor`` and report timings to ``on_query_executed``."""
        self._cursor = cursor
        self._on_query_executed = on_query_executed

    def _trace(self, sql, parameters, execute_fn):
        """Run ``execute_fn`` and report its duration to the callback.

        The callback runs in ``finally`` so failing statements are reported
        too. Errors raised by the callback itself are deliberately ignored so
        that tracing can never break a query.
        """
        start = time.perf_counter()
        try:
            execute_fn()
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            try:
                self._on_query_executed(sql, parameters, elapsed_ms)
            except Exception:
                pass

    def execute(self, sql, parameters=None):
        """Execute one statement and return this wrapper for chaining."""
        if parameters is None:
            self._trace(sql, None, lambda: self._cursor.execute(sql))
        else:
            self._trace(sql, parameters, lambda: self._cursor.execute(sql, parameters))
        return self

    def executemany(self, sql, seq_of_parameters):
        """Execute a statement once per parameter set; returns this wrapper."""
        self._trace(sql, seq_of_parameters, lambda: self._cursor.executemany(sql, seq_of_parameters))
        return self

    def executescript(self, sql_script):
        """Execute a multi-statement script; returns this wrapper."""
        self._trace(sql_script, None, lambda: self._cursor.executescript(sql_script))
        return self

    def fetchone(self):
        """Return the next row from the wrapped cursor."""
        return self._cursor.fetchone()

    def fetchall(self):
        """Return all remaining rows from the wrapped cursor."""
        return self._cursor.fetchall()

    def fetchmany(self, size=None):
        """Return the next batch of rows; ``None`` uses the cursor's default size."""
        if size is None:
            return self._cursor.fetchmany()
        return self._cursor.fetchmany(size)

    def close(self):
        """Close the wrapped cursor."""
        return self._cursor.close()

    @property
    def rowcount(self):
        """Row count reported by the wrapped cursor."""
        return self._cursor.rowcount

    @property
    def lastrowid(self):
        """Last inserted row id reported by the wrapped cursor."""
        return self._cursor.lastrowid

    def __iter__(self):
        """Iterate over the rows of the wrapped cursor."""
        return iter(self._cursor)

    def __getattr__(self, item):
        """Delegate unknown attributes to the wrapped cursor.

        ``__getattr__`` is only called when normal lookup fails. If the
        wrapper's own fields are missing (``__init__`` has not run), looking
        up ``self._cursor`` would re-enter this method forever, so those
        names raise ``AttributeError`` instead of recursing.
        """
        if item in ("_cursor", "_on_query_executed"):
            raise AttributeError(item)
        return getattr(self._cursor, item)
class ConnectionPool:
    """Fixed-size SQLite connection pool backed by a ``queue.Queue``.

    Connections are created eagerly, opened with ``check_same_thread=False``,
    and health-checked every time they are returned to the pool.
    """

    def __init__(self, database, pool_size=5, timeout=30):
        """Create the pool and pre-open all of its connections.

        Args:
            database: Path of the SQLite database file.
            pool_size: Number of connections kept in the pool (default 5).
            timeout: Seconds to wait for a free connection in
                ``get_connection`` before giving up.
        """
        self.database = database
        self.pool_size = pool_size
        self.timeout = timeout
        self._pool = Queue(maxsize=pool_size)
        self._lock = threading.Lock()
        self._created_connections = 0
        # Pre-create the connections.
        self._initialize_pool()

    def _initialize_pool(self):
        """Open ``pool_size`` connections and place them in the queue.

        If opening any connection fails, the connections opened so far are
        closed before the error is re-raised, so a failed construction does
        not leave open database handles behind.
        """
        try:
            for _ in range(self.pool_size):
                conn = self._create_connection()
                self._pool.put(conn)
                self._created_connections += 1
        except Exception:
            self.close_all()
            raise

    def _create_connection(self):
        """Open a new connection and apply the module's PRAGMA settings.

        A PRAGMA that fails with ``sqlite3.DatabaseError`` is logged and
        skipped; the connection is still returned.
        """
        conn = sqlite3.connect(
            self.database,
            check_same_thread=False,
            timeout=DB_CONNECT_TIMEOUT_SECONDS,
        )
        conn.row_factory = sqlite3.Row
        pragma_statements = [
            "PRAGMA foreign_keys=ON",
            "PRAGMA journal_mode=WAL",
            "PRAGMA synchronous=NORMAL",
            f"PRAGMA busy_timeout={DB_BUSY_TIMEOUT_MS}",
            "PRAGMA temp_store=MEMORY",
            # Negated value: the size is given to SQLite as a negative number.
            f"PRAGMA cache_size={-DB_CACHE_SIZE_KB}",
            f"PRAGMA wal_autocheckpoint={DB_WAL_AUTOCHECKPOINT_PAGES}",
        ]
        if DB_MMAP_SIZE_MB > 0:
            pragma_statements.append(f"PRAGMA mmap_size={DB_MMAP_SIZE_MB * 1024 * 1024}")
        for statement in pragma_statements:
            try:
                conn.execute(statement)
            except sqlite3.DatabaseError as e:
                logger.warning(f"设置数据库参数失败 ({statement}): {e}")
        return conn

    def _close_connection(self, conn) -> None:
        """Close ``conn``, logging (not raising) any failure; ``None`` is ignored."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            logger.warning(f"关闭连接失败: {e}")

    def _is_connection_healthy(self, conn) -> bool:
        """Return True if ``conn`` can roll back and run ``SELECT 1``.

        The rollback also discards any transaction left open by the previous
        user of the connection.
        """
        if conn is None:
            return False
        try:
            conn.rollback()
            conn.execute("SELECT 1")
            return True
        except sqlite3.Error as e:
            logger.warning(f"连接健康检查失败(数据库错误): {e}")
        except Exception as e:
            logger.warning(f"连接健康检查失败(未知错误): {e}")
        return False

    def _replenish_pool_if_needed(self) -> None:
        """Add one fresh connection to the queue if it is not already full.

        Failures are logged; a connection that cannot be queued is closed.
        """
        with self._lock:
            if self._pool.qsize() >= self.pool_size:
                return
            new_conn = None
            try:
                new_conn = self._create_connection()
                self._pool.put(new_conn, block=False)
                self._created_connections += 1
            except Full:
                if new_conn:
                    self._close_connection(new_conn)
            except Exception as e:
                if new_conn:
                    self._close_connection(new_conn)
                logger.warning(f"重建连接失败: {e}")

    def get_connection(self):
        """Take a connection from the pool.

        Returns:
            PooledConnection: Wrapper around the pooled connection.

        Raises:
            RuntimeError: If no connection becomes free within ``timeout``
                seconds.
        """
        try:
            conn = self._pool.get(timeout=self.timeout)
            return PooledConnection(conn, self)
        except Empty:
            raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接")

    def return_connection(self, conn):
        """Return ``conn`` to the pool, replacing it if it is unhealthy.

        A healthy connection is queued again (or closed if the queue is
        already full). An unhealthy one is closed and a replacement is
        created when the queue has room.

        Args:
            conn: The raw connection being returned; ``None`` is ignored.
        """
        if conn is None:
            return
        if self._is_connection_healthy(conn):
            try:
                self._pool.put(conn, block=False)
                return
            except Full:
                logger.warning("连接池已满,关闭多余连接")
                self._close_connection(conn)
                return
        self._close_connection(conn)
        self._replenish_pool_if_needed()

    def close_all(self):
        """Close every connection currently sitting in the queue.

        Drains with non-blocking gets until the queue reports ``Empty``, so a
        concurrent consumer emptying the queue first is not logged as a close
        failure. Connections that are checked out are not affected.
        """
        while True:
            try:
                conn = self._pool.get(block=False)
            except Empty:
                break
            self._close_connection(conn)

    def get_stats(self):
        """Return pool statistics.

        Returns:
            dict with ``pool_size``, ``available`` (queued connections),
            ``in_use`` (``pool_size`` minus queued) and ``total_created``.
        """
        return {
            "pool_size": self.pool_size,
            "available": self._pool.qsize(),
            "in_use": self.pool_size - self._pool.qsize(),
            "total_created": self._created_connections,
        }
class PooledConnection:
    """Wrapper around a pooled connection that returns it on ``with`` exit.

    The wrapper lazily creates a single ``TracedCursor`` and exposes a small
    cursor-like API (``execute``, ``fetchone``, ``fetchall``, ``lastrowid``,
    ``rowcount``) on top of it.
    """

    def __init__(self, conn, pool):
        """Store the raw connection and the pool it belongs to.

        Args:
            conn: The underlying database connection.
            pool: The pool object that ``conn`` is returned to.
        """
        self._conn = conn
        self._pool = pool
        self._cursor = None

    def __enter__(self):
        """Enter the ``with`` block; returns this wrapper."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back on error, close the cursor, and return the connection.

        Rollback and cursor close are handled independently, so a failing
        rollback neither skips the cursor cleanup nor gets logged as a cursor
        error. The connection is always handed back to the pool. Returns
        ``False`` so exceptions from the ``with`` body propagate.
        """
        try:
            if exc_type is not None:
                try:
                    self._conn.rollback()
                    logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
                except Exception as e:
                    logger.warning(f"数据库事务回滚失败: {e}")
            if self._cursor is not None:
                try:
                    self._cursor.close()
                except Exception as e:
                    logger.warning(f"关闭游标失败: {e}")
                finally:
                    # Drop the reference even if close() failed.
                    self._cursor = None
        finally:
            self._pool.return_connection(self._conn)
        return False

    def _on_query_executed(self, sql: str, parameters, elapsed_ms: float) -> None:
        """Record and log a statement that ran at or above the slow threshold.

        Nothing happens when the threshold is 0 or less, or when the
        statement was faster than the threshold. Recording to the metrics
        service is best-effort: import or call failures are ignored.
        """
        slow_query_ms, _ = _get_slow_query_runtime_values()
        if slow_query_ms <= 0:
            return
        if elapsed_ms < slow_query_ms:
            return
        params_info = _describe_params(parameters)
        try:
            from services.slow_sql_metrics import record_slow_sql

            record_slow_sql(sql=sql, duration_ms=elapsed_ms, params_info=params_info)
        except Exception:
            pass
        logger.warning(f"[慢SQL] {elapsed_ms:.1f}ms sql=\"{_compact_sql(sql)}\" params={params_info}")

    def cursor(self):
        """Return the wrapper's cursor, creating it on first use."""
        if self._cursor is None:
            raw_cursor = self._conn.cursor()
            self._cursor = TracedCursor(raw_cursor, self._on_query_executed)
        return self._cursor

    def commit(self):
        """Commit the transaction, retrying on lock conflicts.

        Up to ``DB_LOCK_RETRY_COUNT`` retries are made, sleeping
        ``DB_LOCK_RETRY_BASE_MS * 2**attempt`` milliseconds between attempts.

        Raises:
            sqlite3.OperationalError: If the error is not a lock conflict or
                the retries are exhausted.
        """
        for attempt in range(DB_LOCK_RETRY_COUNT + 1):
            try:
                self._conn.commit()
                return
            except sqlite3.OperationalError as e:
                if (not _is_lock_conflict_error(e)) or attempt >= DB_LOCK_RETRY_COUNT:
                    raise
                sleep_seconds = (DB_LOCK_RETRY_BASE_MS * (2**attempt)) / 1000.0
                logger.warning(
                    f"数据库提交遇到锁冲突,{sleep_seconds:.3f}s 后重试 "
                    f"({attempt + 1}/{DB_LOCK_RETRY_COUNT})"
                )
                time.sleep(sleep_seconds)

    def rollback(self):
        """Roll back the current transaction."""
        self._conn.rollback()

    def execute(self, sql, parameters=None):
        """Execute ``sql`` on the wrapper's cursor and return that cursor."""
        cursor = self.cursor()
        if parameters is None:
            return cursor.execute(sql)
        return cursor.execute(sql, parameters)

    def fetchone(self):
        """Return the next row, or ``None`` if no cursor has been created."""
        if self._cursor is not None:
            return self._cursor.fetchone()
        return None

    def fetchall(self):
        """Return all remaining rows, or ``[]`` if no cursor has been created."""
        if self._cursor is not None:
            return self._cursor.fetchall()
        return []

    @property
    def lastrowid(self):
        """Id of the last inserted row, or ``None`` without a cursor."""
        if self._cursor is not None:
            return self._cursor.lastrowid
        return None

    @property
    def rowcount(self):
        """Number of affected rows, or ``0`` without a cursor."""
        if self._cursor is not None:
            return self._cursor.rowcount
        return 0
# Global connection pool instance; stays None until init_pool() creates it.
_pool = None
# Serialises creation of the global pool in init_pool().
_pool_lock = threading.Lock()
def init_pool(database, pool_size=5):
    """Create the global connection pool if it does not exist yet.

    Calling this again after the pool has been created is a no-op.

    Args:
        database: Path of the database file.
        pool_size: Number of connections in the pool.
    """
    global _pool
    with _pool_lock:
        if _pool is not None:
            # Already initialised; keep the existing pool.
            return
        _pool = ConnectionPool(database, pool_size)
        logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
def get_db():
    """Take a connection from the global pool.

    Returns:
        PooledConnection: Connection wrapper obtained from the pool.

    Raises:
        RuntimeError: If ``init_pool()`` has not been called yet.
    """
    pool = _pool
    if pool is None:
        raise RuntimeError("连接池未初始化,请先调用init_pool()")
    return pool.get_connection()
def get_pool_stats():
    """Return the global pool's statistics, or ``None`` if no pool exists."""
    return _pool.get_stats() if _pool else None