322 lines
9.4 KiB
Python
Executable File
322 lines
9.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
数据库连接池模块
|
||
使用queue实现固定大小的连接池,防止连接泄漏
|
||
"""
|
||
|
||
import sqlite3
|
||
import threading
|
||
import time
|
||
from queue import Empty, Full, Queue
|
||
|
||
from app_config import get_config
|
||
from app_logger import get_logger
|
||
|
||
|
||
logger = get_logger("database")
config = get_config()


def _config_int(name, default, minimum):
    """Return integer setting *name* from config (or *default*), clamped to >= *minimum*."""
    return max(minimum, int(getattr(config, name, default)))


# Tunables read once at import time. Each value is converted with int() and
# clamped to the minimum given as the third argument.
DB_CONNECT_TIMEOUT_SECONDS = _config_int("DB_CONNECT_TIMEOUT_SECONDS", 10, 1)
DB_BUSY_TIMEOUT_MS = _config_int("DB_BUSY_TIMEOUT_MS", 10000, 1000)
DB_CACHE_SIZE_KB = _config_int("DB_CACHE_SIZE_KB", 8192, 1024)
DB_WAL_AUTOCHECKPOINT_PAGES = _config_int("DB_WAL_AUTOCHECKPOINT_PAGES", 1000, 100)
DB_MMAP_SIZE_MB = _config_int("DB_MMAP_SIZE_MB", 256, 0)
DB_LOCK_RETRY_COUNT = _config_int("DB_LOCK_RETRY_COUNT", 3, 0)
DB_LOCK_RETRY_BASE_MS = _config_int("DB_LOCK_RETRY_BASE_MS", 50, 10)
|
||
|
||
|
||
def _is_lock_conflict_error(error: sqlite3.OperationalError) -> bool:
    """Return True when *error* looks like an SQLite lock/busy conflict.

    The check is a case-insensitive search for "locked" or "busy" in the
    error message. A falsy *error* (e.g. ``None``) is treated as an empty
    message and yields False.
    """
    text = str(error) if error else ""
    lowered = text.lower()
    return any(word in lowered for word in ("locked", "busy"))
|
||
|
||
|
||
class ConnectionPool:
    """Fixed-size pool of SQLite connections backed by ``queue.Queue``.

    All ``pool_size`` connections are opened up front. Callers borrow one via
    :meth:`get_connection` and hand it back via :meth:`return_connection`;
    connections that fail the health check on return are closed and replaced.
    """

    def __init__(self, database, pool_size=5, timeout=30):
        """Initialize the pool and pre-open its connections.

        Args:
            database: Path to the SQLite database file.
            pool_size: Number of connections kept in the pool (default 5).
            timeout: Seconds get_connection() waits for a free connection.

        Raises:
            Whatever opening a connection raises. Connections opened before
            the failure are closed before the error propagates.
        """
        self.database = database
        self.pool_size = pool_size
        self.timeout = timeout
        self._pool = Queue(maxsize=pool_size)
        # Held while replenishing the pool (see _replenish_pool_if_needed).
        self._lock = threading.Lock()
        self._created_connections = 0

        # Pre-create every connection.
        self._initialize_pool()

    def _initialize_pool(self):
        """Open ``pool_size`` connections and put them in the queue.

        If opening any connection fails, the connections opened so far are
        closed before re-raising, so a failed constructor does not leave open
        database handles behind.
        """
        try:
            for _ in range(self.pool_size):
                conn = self._create_connection()
                self._pool.put(conn)
                self._created_connections += 1
        except BaseException:
            self._drain_and_close()
            raise

    def _create_connection(self):
        """Open a new connection configured with sqlite3.Row and PRAGMA tuning.

        A PRAGMA that fails is logged and skipped, so one unsupported setting
        does not make the connection unusable.
        """
        conn = sqlite3.connect(
            self.database,
            # Connections are handed to whichever thread borrows them.
            check_same_thread=False,
            timeout=DB_CONNECT_TIMEOUT_SECONDS,
        )
        conn.row_factory = sqlite3.Row
        pragma_statements = [
            "PRAGMA foreign_keys=ON",
            "PRAGMA journal_mode=WAL",
            "PRAGMA synchronous=NORMAL",
            f"PRAGMA busy_timeout={DB_BUSY_TIMEOUT_MS}",
            "PRAGMA temp_store=MEMORY",
            # A negative cache_size is interpreted by SQLite as KiB, not pages.
            f"PRAGMA cache_size={-DB_CACHE_SIZE_KB}",
            f"PRAGMA wal_autocheckpoint={DB_WAL_AUTOCHECKPOINT_PAGES}",
        ]
        if DB_MMAP_SIZE_MB > 0:
            pragma_statements.append(f"PRAGMA mmap_size={DB_MMAP_SIZE_MB * 1024 * 1024}")

        for statement in pragma_statements:
            try:
                conn.execute(statement)
            except sqlite3.DatabaseError as e:
                logger.warning(f"设置数据库参数失败 ({statement}): {e}")
        return conn

    def _close_connection(self, conn) -> None:
        """Close *conn*, ignoring None and logging (not raising) close errors."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            logger.warning(f"关闭连接失败: {e}")

    def _drain_and_close(self) -> None:
        """Remove every idle connection from the queue and close it.

        Loops on ``get_nowait`` until ``Empty`` rather than testing ``empty()``
        first, so a concurrent borrower emptying the queue cannot make this
        fail part-way through.
        """
        while True:
            try:
                conn = self._pool.get_nowait()
            except Empty:
                return
            self._close_connection(conn)

    def _is_connection_healthy(self, conn) -> bool:
        """Reset *conn* and check that it still answers queries.

        Rolls back any open transaction (uncommitted work is discarded), then
        runs ``SELECT 1``. Returns False, after logging, for None or any error.
        """
        if conn is None:
            return False
        try:
            conn.rollback()
            conn.execute("SELECT 1")
            return True
        except sqlite3.Error as e:
            logger.warning(f"连接健康检查失败(数据库错误): {e}")
        except Exception as e:
            logger.warning(f"连接健康检查失败(未知错误): {e}")
        return False

    def _replenish_pool_if_needed(self) -> None:
        """Open one replacement connection if the queue is below ``pool_size``.

        Creation failures are logged, not raised; the pool then stays one
        connection short.
        """
        with self._lock:
            if self._pool.qsize() >= self.pool_size:
                return

            new_conn = None
            try:
                new_conn = self._create_connection()
                self._pool.put(new_conn, block=False)
                self._created_connections += 1
            except Full:
                if new_conn:
                    self._close_connection(new_conn)
            except Exception as e:
                if new_conn:
                    self._close_connection(new_conn)
                logger.warning(f"重建连接失败: {e}")

    def get_connection(self):
        """Borrow a connection from the pool.

        Returns:
            PooledConnection: wrapper that returns the connection on ``with`` exit.

        Raises:
            RuntimeError: if no connection becomes free within ``self.timeout``
                seconds.
        """
        try:
            conn = self._pool.get(timeout=self.timeout)
        except Empty:
            # The queue's Empty is an implementation detail; hide it.
            raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接") from None
        return PooledConnection(conn, self)

    def return_connection(self, conn):
        """Give a borrowed connection back to the pool.

        A healthy connection is put back into the queue (or closed if the
        queue is already full). An unhealthy one is closed and a replacement
        is opened.

        Args:
            conn: The raw connection being returned; None is ignored.
        """
        if conn is None:
            return

        if self._is_connection_healthy(conn):
            try:
                self._pool.put(conn, block=False)
                return
            except Full:
                logger.warning("连接池已满,关闭多余连接")
                self._close_connection(conn)
                return

        self._close_connection(conn)
        self._replenish_pool_if_needed()

    def close_all(self):
        """Close every idle connection in the pool.

        Connections currently borrowed are not affected.
        """
        self._drain_and_close()

    def get_stats(self):
        """Return a snapshot of pool statistics.

        ``in_use`` is derived as ``pool_size - available`` from a single
        ``qsize()`` reading, so the two fields always sum to ``pool_size``.
        """
        available = self._pool.qsize()
        return {
            "pool_size": self.pool_size,
            "available": available,
            "in_use": self.pool_size - available,
            "total_created": self._created_connections,
        }
|
||
|
||
|
||
class PooledConnection:
    """Wrapper around a pooled connection that returns it on ``with`` exit.

    A single cursor is created lazily and shared by execute(), fetchone(),
    fetchall(), lastrowid and rowcount. Leaving the ``with`` block never
    commits; callers must call commit() explicitly.
    """

    def __init__(self, conn, pool):
        """Wrap *conn*, which was borrowed from *pool*.

        Args:
            conn: The underlying database connection.
            pool: The pool object; its return_connection() is called on exit.
        """
        self._conn = conn
        self._pool = pool
        self._cursor = None
        # Set once the connection has been handed back, so a repeated
        # __exit__ cannot put the same connection into the pool twice.
        self._returned = False

    def __enter__(self):
        """Support the ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back on error, close the cursor, and return the connection.

        Rollback and cursor-close failures are logged separately and do not
        prevent each other or the return to the pool. Exceptions from the
        ``with`` body are never suppressed.
        """
        try:
            if exc_type is not None:
                # An exception escaped the with-body: discard the transaction.
                try:
                    self._conn.rollback()
                    logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
                except Exception as e:
                    logger.warning(f"回滚事务失败: {e}")
            # Note: no automatic commit; callers must call commit() themselves.

            if self._cursor is not None:
                try:
                    self._cursor.close()
                except Exception as e:
                    logger.warning(f"关闭游标失败: {e}")
                finally:
                    self._cursor = None
        finally:
            if not self._returned:
                self._returned = True
                self._pool.return_connection(self._conn)

        return False  # do not suppress exceptions

    def cursor(self):
        """Return the shared cursor, creating it on first use."""
        if self._cursor is None:
            self._cursor = self._conn.cursor()
        return self._cursor

    def commit(self):
        """Commit the transaction, retrying on SQLite lock/busy conflicts.

        Up to DB_LOCK_RETRY_COUNT retries are made, sleeping
        DB_LOCK_RETRY_BASE_MS * 2**attempt milliseconds before each one. Any
        other OperationalError, or a conflict on the final attempt, is
        re-raised.
        """
        for attempt in range(DB_LOCK_RETRY_COUNT + 1):
            try:
                self._conn.commit()
                return
            except sqlite3.OperationalError as e:
                if (not _is_lock_conflict_error(e)) or attempt >= DB_LOCK_RETRY_COUNT:
                    raise

                sleep_seconds = (DB_LOCK_RETRY_BASE_MS * (2**attempt)) / 1000.0
                logger.warning(
                    f"数据库提交遇到锁冲突,{sleep_seconds:.3f}s 后重试 "
                    f"({attempt + 1}/{DB_LOCK_RETRY_COUNT})"
                )
                time.sleep(sleep_seconds)

    def rollback(self):
        """Roll back the current transaction."""
        self._conn.rollback()

    def execute(self, sql, parameters=None):
        """Execute *sql* on the shared cursor, binding *parameters* if truthy."""
        cursor = self.cursor()
        if parameters:
            return cursor.execute(sql, parameters)
        return cursor.execute(sql)

    def fetchone(self):
        """Fetch the next row from the shared cursor, or None if none exists."""
        if self._cursor:
            return self._cursor.fetchone()
        return None

    def fetchall(self):
        """Fetch all remaining rows, or [] if no cursor has been created."""
        if self._cursor:
            return self._cursor.fetchall()
        return []

    @property
    def lastrowid(self):
        """Row id of the last inserted row, or None if no cursor exists."""
        if self._cursor:
            return self._cursor.lastrowid
        return None

    @property
    def rowcount(self):
        """Rows affected by the last statement, or 0 if no cursor exists."""
        if self._cursor:
            return self._cursor.rowcount
        return 0
|
||
|
||
|
||
# Process-wide pool instance: created by init_pool(), read by get_db() and
# get_pool_stats(). Stays None until init_pool() runs.
_pool: "ConnectionPool | None" = None
# Guards the one-time creation of _pool inside init_pool().
_pool_lock = threading.Lock()
|
||
|
||
|
||
def init_pool(database, pool_size=5):
    """Create the process-wide connection pool if it does not exist yet.

    Calls made after the pool exists do nothing, even when they pass a
    different *database* or *pool_size*.

    Args:
        database: Path to the SQLite database file.
        pool_size: Number of pooled connections.
    """
    global _pool
    with _pool_lock:
        if _pool is not None:
            return
        _pool = ConnectionPool(database, pool_size)
        logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
|
||
|
||
|
||
def get_db():
    """Borrow a connection from the process-wide pool.

    Replaces the older standalone get_db().

    Returns:
        PooledConnection: wrapper returned by the pool's get_connection().

    Raises:
        RuntimeError: if init_pool() has not been called yet.
    """
    pool = _pool
    if pool is None:
        raise RuntimeError("连接池未初始化,请先调用init_pool()")
    return pool.get_connection()
|
||
|
||
|
||
def get_pool_stats():
    """Return the global pool's get_stats() dict, or None before init_pool()."""
    pool = _pool
    return pool.get_stats() if pool is not None else None
|