Files
zsglpt/db_pool.py
yuyx acb22cf96b 修复12项安全漏洞和代码质量问题
安全修复:
- 使用secrets替代random生成验证码,提升安全性
- 添加内存清理调度器,防止内存泄漏
- PIL缺失时返回503而非降级服务
- 改进会话安全配置,支持环境自动检测
- 密钥文件路径支持环境变量配置

Bug修复:
- 改进异常处理,不再吞掉SystemExit/KeyboardInterrupt
- 清理死代码(if False占位符)
- 改进浏览器资源释放逻辑,使用try-finally确保关闭
- 重构数据库连接池归还逻辑,修复竞态条件
- 添加安全的JSON解析方法,处理损坏数据
- 日志级别默认值改为INFO
- 提取魔法数字为可配置常量

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-11 20:00:19 +08:00

267 lines
7.5 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库连接池模块
使用queue实现固定大小的连接池,防止连接泄漏
"""
import sqlite3
import threading
from queue import Queue, Empty
import time
class ConnectionPool:
"""SQLite连接池"""
def __init__(self, database, pool_size=5, timeout=30):
"""
初始化连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小(默认5)
timeout: 获取连接超时时间(秒)
"""
self.database = database
self.pool_size = pool_size
self.timeout = timeout
self._pool = Queue(maxsize=pool_size)
self._lock = threading.Lock()
self._created_connections = 0
# 预创建连接
self._initialize_pool()
def _initialize_pool(self):
"""预创建连接池中的连接"""
for _ in range(self.pool_size):
conn = self._create_connection()
self._pool.put(conn)
self._created_connections += 1
def _create_connection(self):
"""创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row
# 设置WAL模式提高并发性能
conn.execute('PRAGMA journal_mode=WAL')
# 设置合理的超时时间
conn.execute('PRAGMA busy_timeout=5000')
return conn
def get_connection(self):
"""
从连接池获取连接
Returns:
PooledConnection: 连接包装对象
"""
try:
conn = self._pool.get(timeout=self.timeout)
return PooledConnection(conn, self)
except Empty:
raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接")
def return_connection(self, conn):
"""
归还连接到连接池 [安全修复: 改进竞态条件处理]
Args:
conn: 要归还的连接
"""
import sqlite3
from queue import Full
if conn is None:
return
connection_healthy = False
try:
# 回滚任何未提交的事务
conn.rollback()
# 安全修复:验证连接是否健康,防止损坏的连接污染连接池
conn.execute("SELECT 1")
connection_healthy = True
except sqlite3.Error as e:
# 数据库相关错误,连接可能损坏
print(f"连接健康检查失败(数据库错误): {e}")
except Exception as e:
print(f"连接健康检查失败(未知错误): {e}")
if connection_healthy:
try:
self._pool.put(conn, block=False)
return # 成功归还
except Full:
# 队列已满(不应该发生,但处理它)
print(f"警告: 连接池已满,关闭多余连接")
connection_healthy = False # 标记为需要关闭
# 连接不健康或队列已满,关闭它
try:
conn.close()
except Exception as close_error:
print(f"关闭连接失败: {close_error}")
# 如果连接不健康,尝试创建新连接补充池
if not connection_healthy:
with self._lock:
# 双重检查:确保池确实需要补充
if self._pool.qsize() < self.pool_size:
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
except Full:
# 在获取锁期间池被填满了,关闭新建的连接
try:
new_conn.close()
except Exception:
pass
except Exception as create_error:
print(f"重建连接失败: {create_error}")
def close_all(self):
"""关闭所有连接"""
while not self._pool.empty():
try:
conn = self._pool.get(block=False)
conn.close()
except Exception as e:
print(f"关闭连接失败: {e}")
def get_stats(self):
"""获取连接池统计信息"""
return {
'pool_size': self.pool_size,
'available': self._pool.qsize(),
'in_use': self.pool_size - self._pool.qsize(),
'total_created': self._created_connections
}
class PooledConnection:
"""连接池连接包装器,支持with语句自动归还"""
def __init__(self, conn, pool):
"""
初始化
Args:
conn: 实际的数据库连接
pool: 连接池对象
"""
self._conn = conn
self._pool = pool
self._cursor = None
def __enter__(self):
"""支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""with语句结束时自动归还连接 [已修复Bug#3]"""
try:
if exc_type is not None:
# 发生异常,回滚事务
self._conn.rollback()
print(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit要求用户显式调用conn.commit()
if self._cursor:
self._cursor.close()
except Exception as e:
print(f"关闭游标失败: {e}")
finally:
# 归还连接
self._pool.return_connection(self._conn)
return False # 不抑制异常
def cursor(self):
"""获取游标"""
self._cursor = self._conn.cursor()
return self._cursor
def commit(self):
"""提交事务"""
self._conn.commit()
def rollback(self):
"""回滚事务"""
self._conn.rollback()
def execute(self, sql, parameters=None):
"""执行SQL"""
cursor = self.cursor()
if parameters:
return cursor.execute(sql, parameters)
return cursor.execute(sql)
def fetchone(self):
"""获取一行"""
if self._cursor:
return self._cursor.fetchone()
return None
def fetchall(self):
"""获取所有行"""
if self._cursor:
return self._cursor.fetchall()
return []
@property
def lastrowid(self):
"""最后插入的行ID"""
if self._cursor:
return self._cursor.lastrowid
return None
@property
def rowcount(self):
"""影响的行数"""
if self._cursor:
return self._cursor.rowcount
return 0
# 全局连接池实例
_pool = None
_pool_lock = threading.Lock()
def init_pool(database, pool_size=5):
"""
初始化全局连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小
"""
global _pool
with _pool_lock:
if _pool is None:
_pool = ConnectionPool(database, pool_size)
print(f"✓ 数据库连接池已初始化 (大小: {pool_size})")
def get_db():
"""
获取数据库连接(替代原有的get_db函数)
Returns:
PooledConnection: 连接对象
"""
global _pool
if _pool is None:
raise RuntimeError("连接池未初始化,请先调用init_pool()")
return _pool.get_connection()
def get_pool_stats():
"""获取连接池统计信息"""
global _pool
if _pool:
return _pool.get_stats()
return None