Files
zsglpt/db_pool.py

280 lines
7.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库连接池模块
使用queue实现固定大小的连接池,防止连接泄漏
"""
import sqlite3
import threading
from queue import Empty, Full, Queue
from app_logger import get_logger
logger = get_logger("database")
class ConnectionPool:
"""SQLite连接池"""
def __init__(self, database, pool_size=5, timeout=30):
"""
初始化连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小(默认5)
timeout: 获取连接超时时间(秒)
"""
self.database = database
self.pool_size = pool_size
self.timeout = timeout
self._pool = Queue(maxsize=pool_size)
self._lock = threading.Lock()
self._created_connections = 0
# 预创建连接
self._initialize_pool()
def _initialize_pool(self):
"""预创建连接池中的连接"""
for _ in range(self.pool_size):
conn = self._create_connection()
self._pool.put(conn)
self._created_connections += 1
def _create_connection(self):
"""创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row
# 启用外键约束,确保 ON DELETE CASCADE 等约束生效
conn.execute("PRAGMA foreign_keys=ON")
# 设置WAL模式提高并发性能
conn.execute("PRAGMA journal_mode=WAL")
# 在WAL模式下使用NORMAL同步兼顾性能与可靠性
conn.execute("PRAGMA synchronous=NORMAL")
# 设置合理的超时时间
conn.execute("PRAGMA busy_timeout=5000")
return conn
def _close_connection(self, conn) -> None:
if conn is None:
return
try:
conn.close()
except Exception as e:
logger.warning(f"关闭连接失败: {e}")
def _is_connection_healthy(self, conn) -> bool:
if conn is None:
return False
try:
conn.rollback()
conn.execute("SELECT 1")
return True
except sqlite3.Error as e:
logger.warning(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
logger.warning(f"连接健康检查失败(未知错误): {e}")
return False
def _replenish_pool_if_needed(self) -> None:
with self._lock:
if self._pool.qsize() >= self.pool_size:
return
new_conn = None
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
self._created_connections += 1
except Full:
if new_conn:
self._close_connection(new_conn)
except Exception as e:
if new_conn:
self._close_connection(new_conn)
logger.warning(f"重建连接失败: {e}")
def get_connection(self):
"""
从连接池获取连接
Returns:
PooledConnection: 连接包装对象
"""
try:
conn = self._pool.get(timeout=self.timeout)
return PooledConnection(conn, self)
except Empty:
raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接")
def return_connection(self, conn):
"""
归还连接到连接池 [安全修复: 改进竞态条件处理]
Args:
conn: 要归还的连接
"""
if conn is None:
return
if self._is_connection_healthy(conn):
try:
self._pool.put(conn, block=False)
return
except Full:
logger.warning("连接池已满,关闭多余连接")
self._close_connection(conn)
return
self._close_connection(conn)
self._replenish_pool_if_needed()
def close_all(self):
"""关闭所有连接"""
while not self._pool.empty():
try:
conn = self._pool.get(block=False)
conn.close()
except Exception as e:
logger.warning(f"关闭连接失败: {e}")
def get_stats(self):
"""获取连接池统计信息"""
return {
"pool_size": self.pool_size,
"available": self._pool.qsize(),
"in_use": self.pool_size - self._pool.qsize(),
"total_created": self._created_connections,
}
class PooledConnection:
"""连接池连接包装器,支持with语句自动归还"""
def __init__(self, conn, pool):
"""
初始化
Args:
conn: 实际的数据库连接
pool: 连接池对象
"""
self._conn = conn
self._pool = pool
self._cursor = None
def __enter__(self):
"""支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""with语句结束时自动归还连接 [已修复Bug#3]"""
try:
if exc_type is not None:
# 发生异常,回滚事务
self._conn.rollback()
logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit要求用户显式调用conn.commit()
if self._cursor:
self._cursor.close()
self._cursor = None
except Exception as e:
logger.warning(f"关闭游标失败: {e}")
finally:
# 归还连接
self._pool.return_connection(self._conn)
return False # 不抑制异常
def cursor(self):
"""获取游标"""
if self._cursor is None:
self._cursor = self._conn.cursor()
return self._cursor
def commit(self):
"""提交事务"""
self._conn.commit()
def rollback(self):
"""回滚事务"""
self._conn.rollback()
def execute(self, sql, parameters=None):
"""执行SQL"""
cursor = self.cursor()
if parameters:
return cursor.execute(sql, parameters)
return cursor.execute(sql)
def fetchone(self):
"""获取一行"""
if self._cursor:
return self._cursor.fetchone()
return None
def fetchall(self):
"""获取所有行"""
if self._cursor:
return self._cursor.fetchall()
return []
@property
def lastrowid(self):
"""最后插入的行ID"""
if self._cursor:
return self._cursor.lastrowid
return None
@property
def rowcount(self):
"""影响的行数"""
if self._cursor:
return self._cursor.rowcount
return 0
# 全局连接池实例
_pool = None
_pool_lock = threading.Lock()
def init_pool(database, pool_size=5):
"""
初始化全局连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小
"""
global _pool
with _pool_lock:
if _pool is None:
_pool = ConnectionPool(database, pool_size)
logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
def get_db():
"""
获取数据库连接(替代原有的get_db函数)
Returns:
PooledConnection: 连接对象
"""
global _pool
if _pool is None:
raise RuntimeError("连接池未初始化,请先调用init_pool()")
return _pool.get_connection()
def get_pool_stats():
"""获取连接池统计信息"""
global _pool
if _pool:
return _pool.get_stats()
return None