#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Database connection pool module.

Implements a fixed-size SQLite connection pool on top of ``queue.Queue`` so
connections are reused rather than leaked.
"""

import sqlite3
import threading
import time
from queue import Empty, Full, Queue

from app_config import get_config
from app_logger import get_logger

logger = get_logger("database")
config = get_config()


def _int_setting(name, default, floor):
    """Return config attribute ``name`` as an int (``default`` if absent), clamped to at least ``floor``."""
    return max(floor, int(getattr(config, name, default)))


# Tunables read once at import time; each is clamped to a sane lower bound.
DB_CONNECT_TIMEOUT_SECONDS = _int_setting("DB_CONNECT_TIMEOUT_SECONDS", 10, 1)
DB_BUSY_TIMEOUT_MS = _int_setting("DB_BUSY_TIMEOUT_MS", 10000, 1000)
DB_CACHE_SIZE_KB = _int_setting("DB_CACHE_SIZE_KB", 8192, 1024)
DB_WAL_AUTOCHECKPOINT_PAGES = _int_setting("DB_WAL_AUTOCHECKPOINT_PAGES", 1000, 100)
DB_MMAP_SIZE_MB = _int_setting("DB_MMAP_SIZE_MB", 256, 0)
DB_LOCK_RETRY_COUNT = _int_setting("DB_LOCK_RETRY_COUNT", 3, 0)
DB_LOCK_RETRY_BASE_MS = _int_setting("DB_LOCK_RETRY_BASE_MS", 50, 10)
DB_SLOW_QUERY_MS = _int_setting("DB_SLOW_QUERY_MS", 120, 0)
DB_SLOW_QUERY_SQL_MAX_LEN = _int_setting("DB_SLOW_QUERY_SQL_MAX_LEN", 240, 80)


def _is_lock_conflict_error(error: sqlite3.OperationalError) -> bool:
    """Return True if the error text mentions "locked" or "busy" (case-insensitive)."""
    text = str(error or "").lower()
    return any(marker in text for marker in ("locked", "busy"))


def _compact_sql(sql: str) -> str:
    """Collapse all whitespace runs in ``sql`` to single spaces.

    Results longer than DB_SLOW_QUERY_SQL_MAX_LEN are cut so that, with a
    trailing "...", they are exactly that long.
    """
    collapsed = " ".join(str(sql or "").split())
    limit = DB_SLOW_QUERY_SQL_MAX_LEN
    if len(collapsed) > limit:
        collapsed = collapsed[: limit - 3] + "..."
    return collapsed
def _describe_params(parameters) -> str: if parameters is None: return "none" if isinstance(parameters, dict): return f"dict[{len(parameters)}]" if isinstance(parameters, (list, tuple)): return f"{type(parameters).__name__}[{len(parameters)}]" return type(parameters).__name__ class TracedCursor: """带慢查询检测的游标包装器""" def __init__(self, cursor, on_query_executed): self._cursor = cursor self._on_query_executed = on_query_executed def _trace(self, sql, parameters, execute_fn): start = time.perf_counter() try: execute_fn() finally: elapsed_ms = (time.perf_counter() - start) * 1000.0 try: self._on_query_executed(sql, parameters, elapsed_ms) except Exception: pass def execute(self, sql, parameters=None): if parameters is None: self._trace(sql, None, lambda: self._cursor.execute(sql)) else: self._trace(sql, parameters, lambda: self._cursor.execute(sql, parameters)) return self def executemany(self, sql, seq_of_parameters): self._trace(sql, seq_of_parameters, lambda: self._cursor.executemany(sql, seq_of_parameters)) return self def executescript(self, sql_script): self._trace(sql_script, None, lambda: self._cursor.executescript(sql_script)) return self def fetchone(self): return self._cursor.fetchone() def fetchall(self): return self._cursor.fetchall() def fetchmany(self, size=None): if size is None: return self._cursor.fetchmany() return self._cursor.fetchmany(size) def close(self): return self._cursor.close() @property def rowcount(self): return self._cursor.rowcount @property def lastrowid(self): return self._cursor.lastrowid def __iter__(self): return iter(self._cursor) def __getattr__(self, item): return getattr(self._cursor, item) class ConnectionPool: """SQLite连接池""" def __init__(self, database, pool_size=5, timeout=30): """ 初始化连接池 Args: database: 数据库文件路径 pool_size: 连接池大小(默认5) timeout: 获取连接超时时间(秒) """ self.database = database self.pool_size = pool_size self.timeout = timeout self._pool = Queue(maxsize=pool_size) self._lock = threading.Lock() self._created_connections = 
0 # 预创建连接 self._initialize_pool() def _initialize_pool(self): """预创建连接池中的连接""" for _ in range(self.pool_size): conn = self._create_connection() self._pool.put(conn) self._created_connections += 1 def _create_connection(self): """创建新的数据库连接""" conn = sqlite3.connect( self.database, check_same_thread=False, timeout=DB_CONNECT_TIMEOUT_SECONDS, ) conn.row_factory = sqlite3.Row pragma_statements = [ "PRAGMA foreign_keys=ON", "PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL", f"PRAGMA busy_timeout={DB_BUSY_TIMEOUT_MS}", "PRAGMA temp_store=MEMORY", f"PRAGMA cache_size={-DB_CACHE_SIZE_KB}", f"PRAGMA wal_autocheckpoint={DB_WAL_AUTOCHECKPOINT_PAGES}", ] if DB_MMAP_SIZE_MB > 0: pragma_statements.append(f"PRAGMA mmap_size={DB_MMAP_SIZE_MB * 1024 * 1024}") for statement in pragma_statements: try: conn.execute(statement) except sqlite3.DatabaseError as e: logger.warning(f"设置数据库参数失败 ({statement}): {e}") return conn def _close_connection(self, conn) -> None: if conn is None: return try: conn.close() except Exception as e: logger.warning(f"关闭连接失败: {e}") def _is_connection_healthy(self, conn) -> bool: if conn is None: return False try: conn.rollback() conn.execute("SELECT 1") return True except sqlite3.Error as e: logger.warning(f"连接健康检查失败(数据库错误): {e}") except Exception as e: logger.warning(f"连接健康检查失败(未知错误): {e}") return False def _replenish_pool_if_needed(self) -> None: with self._lock: if self._pool.qsize() >= self.pool_size: return new_conn = None try: new_conn = self._create_connection() self._pool.put(new_conn, block=False) self._created_connections += 1 except Full: if new_conn: self._close_connection(new_conn) except Exception as e: if new_conn: self._close_connection(new_conn) logger.warning(f"重建连接失败: {e}") def get_connection(self): """ 从连接池获取连接 Returns: PooledConnection: 连接包装对象 """ try: conn = self._pool.get(timeout=self.timeout) return PooledConnection(conn, self) except Empty: raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接") def return_connection(self, conn): """ 
归还连接到连接池 [安全修复: 改进竞态条件处理] Args: conn: 要归还的连接 """ if conn is None: return if self._is_connection_healthy(conn): try: self._pool.put(conn, block=False) return except Full: logger.warning("连接池已满,关闭多余连接") self._close_connection(conn) return self._close_connection(conn) self._replenish_pool_if_needed() def close_all(self): """关闭所有连接""" while not self._pool.empty(): try: conn = self._pool.get(block=False) conn.close() except Exception as e: logger.warning(f"关闭连接失败: {e}") def get_stats(self): """获取连接池统计信息""" return { "pool_size": self.pool_size, "available": self._pool.qsize(), "in_use": self.pool_size - self._pool.qsize(), "total_created": self._created_connections, } class PooledConnection: """连接池连接包装器,支持with语句自动归还""" def __init__(self, conn, pool): """ 初始化 Args: conn: 实际的数据库连接 pool: 连接池对象 """ self._conn = conn self._pool = pool self._cursor = None def __enter__(self): """支持with语句""" return self def __exit__(self, exc_type, exc_val, exc_tb): """with语句结束时自动归还连接 [已修复Bug#3]""" try: if exc_type is not None: self._conn.rollback() logger.warning(f"数据库事务已回滚: {exc_type.__name__}") if self._cursor is not None: self._cursor.close() self._cursor = None except Exception as e: logger.warning(f"关闭游标失败: {e}") finally: self._pool.return_connection(self._conn) return False def _on_query_executed(self, sql: str, parameters, elapsed_ms: float) -> None: if DB_SLOW_QUERY_MS <= 0: return if elapsed_ms < DB_SLOW_QUERY_MS: return logger.warning( f"[慢SQL] {elapsed_ms:.1f}ms sql=\"{_compact_sql(sql)}\" params={_describe_params(parameters)}" ) def cursor(self): """获取游标""" if self._cursor is None: raw_cursor = self._conn.cursor() self._cursor = TracedCursor(raw_cursor, self._on_query_executed) return self._cursor def commit(self): """提交事务""" for attempt in range(DB_LOCK_RETRY_COUNT + 1): try: self._conn.commit() return except sqlite3.OperationalError as e: if (not _is_lock_conflict_error(e)) or attempt >= DB_LOCK_RETRY_COUNT: raise sleep_seconds = (DB_LOCK_RETRY_BASE_MS * (2**attempt)) / 1000.0 
logger.warning( f"数据库提交遇到锁冲突,{sleep_seconds:.3f}s 后重试 " f"({attempt + 1}/{DB_LOCK_RETRY_COUNT})" ) time.sleep(sleep_seconds) def rollback(self): """回滚事务""" self._conn.rollback() def execute(self, sql, parameters=None): """执行SQL""" cursor = self.cursor() if parameters is None: return cursor.execute(sql) return cursor.execute(sql, parameters) def fetchone(self): """获取一行""" if self._cursor: return self._cursor.fetchone() return None def fetchall(self): """获取所有行""" if self._cursor: return self._cursor.fetchall() return [] @property def lastrowid(self): """最后插入的行ID""" if self._cursor: return self._cursor.lastrowid return None @property def rowcount(self): """影响的行数""" if self._cursor: return self._cursor.rowcount return 0 # 全局连接池实例 _pool = None _pool_lock = threading.Lock() def init_pool(database, pool_size=5): """ 初始化全局连接池 Args: database: 数据库文件路径 pool_size: 连接池大小 """ global _pool with _pool_lock: if _pool is None: _pool = ConnectionPool(database, pool_size) logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})") def get_db(): """ 获取数据库连接(替代原有的get_db函数) Returns: PooledConnection: 连接对象 """ global _pool if _pool is None: raise RuntimeError("连接池未初始化,请先调用init_pool()") return _pool.get_connection() def get_pool_stats(): """获取连接池统计信息""" global _pool if _pool: return _pool.get_stats() return None