Files
zsglpt/db_pool.py
yuyx 70cd95c366 修复多项安全漏洞
安全修复清单:
1. 验证码改为图片方式返回,防止明文泄露
2. CORS配置从环境变量读取,不再使用通配符"*"
3. VIP API添加@admin_required装饰器,统一认证
4. 用户登录统一错误消息,防止用户枚举
5. IP限流不再信任X-Forwarded-For头,防止伪造绕过
6. 密码强度要求提升(8位+字母+数字)
7. 日志不���记录完整session/cookie内容,防止敏感信息泄露
8. XSS防护:日志输出和Bug反馈内容转义HTML
9. SQL注入防护:LIKE查询参数转义
10. 路径遍历防护:截图目录白名单验证
11. 验证码重放防护:验证前删除验证码
12. 数据库连接池健康检查
13. 正则DoS防护:限制数字匹配长度
14. Account类密码私有化,__repr__不暴露密码

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-11 17:53:48 +08:00

253 lines
7.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
数据库连接池模块
使用queue实现固定大小的连接池,防止连接泄漏
"""
import sqlite3
import threading
from queue import Queue, Empty
import time
class ConnectionPool:
"""SQLite连接池"""
def __init__(self, database, pool_size=5, timeout=30):
"""
初始化连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小(默认5)
timeout: 获取连接超时时间(秒)
"""
self.database = database
self.pool_size = pool_size
self.timeout = timeout
self._pool = Queue(maxsize=pool_size)
self._lock = threading.Lock()
self._created_connections = 0
# 预创建连接
self._initialize_pool()
def _initialize_pool(self):
"""预创建连接池中的连接"""
for _ in range(self.pool_size):
conn = self._create_connection()
self._pool.put(conn)
self._created_connections += 1
def _create_connection(self):
"""创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row
# 设置WAL模式提高并发性能
conn.execute('PRAGMA journal_mode=WAL')
# 设置合理的超时时间
conn.execute('PRAGMA busy_timeout=5000')
return conn
def get_connection(self):
"""
从连接池获取连接
Returns:
PooledConnection: 连接包装对象
"""
try:
conn = self._pool.get(timeout=self.timeout)
return PooledConnection(conn, self)
except Empty:
raise RuntimeError(f"无法在{self.timeout}秒内获取数据库连接")
def return_connection(self, conn):
"""
归还连接到连接池 [已修复Bug#7, Bug#11]
Args:
conn: 要归还的连接
"""
import sqlite3
from queue import Full
try:
# 回滚任何未提交的事务
conn.rollback()
# 安全修复:验证连接是否健康,防止损坏的连接污染连接池
conn.execute("SELECT 1")
self._pool.put(conn, block=False)
except sqlite3.Error as e:
# 数据库相关错误,连接可能损坏
print(f"归还连接失败(数据库错误): {e}")
try:
conn.close()
except Exception as close_error:
print(f"关闭损坏的连接失败: {close_error}")
# 创建新连接补充
with self._lock:
try:
new_conn = self._create_connection()
self._pool.put(new_conn, block=False)
except Exception as create_error:
print(f"重建连接失败: {create_error}")
except Full:
# 队列已满(不应该发生)
print(f"警告: 连接池已满,关闭多余连接")
try:
conn.close()
except Exception as close_error:
print(f"关闭多余连接失败: {close_error}")
except Exception as e:
print(f"归还连接失败(未知错误): {e}")
try:
conn.close()
except Exception as close_error:
print(f"关闭异常连接失败: {close_error}")
def close_all(self):
"""关闭所有连接"""
while not self._pool.empty():
try:
conn = self._pool.get(block=False)
conn.close()
except Exception as e:
print(f"关闭连接失败: {e}")
def get_stats(self):
"""获取连接池统计信息"""
return {
'pool_size': self.pool_size,
'available': self._pool.qsize(),
'in_use': self.pool_size - self._pool.qsize(),
'total_created': self._created_connections
}
class PooledConnection:
"""连接池连接包装器,支持with语句自动归还"""
def __init__(self, conn, pool):
"""
初始化
Args:
conn: 实际的数据库连接
pool: 连接池对象
"""
self._conn = conn
self._pool = pool
self._cursor = None
def __enter__(self):
"""支持with语句"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""with语句结束时自动归还连接 [已修复Bug#3]"""
try:
if exc_type is not None:
# 发生异常,回滚事务
self._conn.rollback()
print(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit要求用户显式调用conn.commit()
if self._cursor:
self._cursor.close()
except Exception as e:
print(f"关闭游标失败: {e}")
finally:
# 归还连接
self._pool.return_connection(self._conn)
return False # 不抑制异常
def cursor(self):
"""获取游标"""
self._cursor = self._conn.cursor()
return self._cursor
def commit(self):
"""提交事务"""
self._conn.commit()
def rollback(self):
"""回滚事务"""
self._conn.rollback()
def execute(self, sql, parameters=None):
"""执行SQL"""
cursor = self.cursor()
if parameters:
return cursor.execute(sql, parameters)
return cursor.execute(sql)
def fetchone(self):
"""获取一行"""
if self._cursor:
return self._cursor.fetchone()
return None
def fetchall(self):
"""获取所有行"""
if self._cursor:
return self._cursor.fetchall()
return []
@property
def lastrowid(self):
"""最后插入的行ID"""
if self._cursor:
return self._cursor.lastrowid
return None
@property
def rowcount(self):
"""影响的行数"""
if self._cursor:
return self._cursor.rowcount
return 0
# 全局连接池实例
_pool = None
_pool_lock = threading.Lock()
def init_pool(database, pool_size=5):
"""
初始化全局连接池
Args:
database: 数据库文件路径
pool_size: 连接池大小
"""
global _pool
with _pool_lock:
if _pool is None:
_pool = ConnectionPool(database, pool_size)
print(f"✓ 数据库连接池已初始化 (大小: {pool_size})")
def get_db():
"""
获取数据库连接(替代原有的get_db函数)
Returns:
PooledConnection: 连接对象
"""
global _pool
if _pool is None:
raise RuntimeError("连接池未初始化,请先调用init_pool()")
return _pool.get_connection()
def get_pool_stats():
"""获取连接池统计信息"""
global _pool
if _pool:
return _pool.get_stats()
return None