修复高优先级安全和稳定性问题

1. app.py:
   - 添加user_semaphores_lock线程锁,修复get_user_semaphore竞态条件
   - 将9处裸except:改为except Exception:或具体异常类型

2. database.py:
   - 将3处裸except:改为except Exception:

3. email_service.py:
   - 添加_smtp_config_lock线程锁
   - 修复daily_sent计数竞态条件:获取配置时预增,失败时回退
   - _get_available_smtp_config和_get_next_available_smtp_config使用锁保护

🤖 Generated with Claude Code
This commit is contained in:
Yu Yon
2025-12-12 15:04:58 +08:00
parent 8b1014b922
commit b15e6f2af0
3 changed files with 129 additions and 94 deletions

28
app.py
View File

@@ -162,6 +162,7 @@ IP_LOCK_DURATION = config.IP_LOCK_DURATION
max_concurrent_per_account = config.MAX_CONCURRENT_PER_ACCOUNT max_concurrent_per_account = config.MAX_CONCURRENT_PER_ACCOUNT
max_concurrent_global = config.MAX_CONCURRENT_GLOBAL max_concurrent_global = config.MAX_CONCURRENT_GLOBAL
user_semaphores = {} # {user_id: Semaphore} user_semaphores = {} # {user_id: Semaphore}
user_semaphores_lock = threading.Lock() # 保护user_semaphores的线程锁
global_semaphore = threading.Semaphore(max_concurrent_global) global_semaphore = threading.Semaphore(max_concurrent_global)
# 截图专用信号量:限制同时进行的截图任务数量(避免资源竞争) # 截图专用信号量:限制同时进行的截图任务数量(避免资源竞争)
@@ -1084,7 +1085,7 @@ def generate_captcha():
try: try:
font = ImageFont.truetype(font_path, 42) font = ImageFont.truetype(font_path, 42)
break break
except: except Exception:
continue continue
if font is None: if font is None:
font = ImageFont.load_default() font = ImageFont.load_default()
@@ -1937,10 +1938,11 @@ def stop_account(account_id):
def get_user_semaphore(user_id): def get_user_semaphore(user_id):
"""获取或创建用户的信号量""" """获取或创建用户的信号量(线程安全)"""
if user_id not in user_semaphores: with user_semaphores_lock:
user_semaphores[user_id] = threading.Semaphore(max_concurrent_per_account) if user_id not in user_semaphores:
return user_semaphores[user_id] user_semaphores[user_id] = threading.Semaphore(max_concurrent_per_account)
return user_semaphores[user_id]
def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="manual", retry_count=0): def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="manual", retry_count=0):
@@ -4232,7 +4234,7 @@ def get_user_schedules_api():
for s in schedules: for s in schedules:
try: try:
s['account_ids'] = json.loads(s.get('account_ids', '[]') or '[]') s['account_ids'] = json.loads(s.get('account_ids', '[]') or '[]')
except: except (json.JSONDecodeError, TypeError):
s['account_ids'] = [] s['account_ids'] = []
return jsonify(schedules) return jsonify(schedules)
@@ -4282,7 +4284,7 @@ def get_schedule_detail_api(schedule_id):
import json import json
try: try:
schedule['account_ids'] = json.loads(schedule.get('account_ids', '[]') or '[]') schedule['account_ids'] = json.loads(schedule.get('account_ids', '[]') or '[]')
except: except (json.JSONDecodeError, TypeError):
schedule['account_ids'] = [] schedule['account_ids'] = []
return jsonify(schedule) return jsonify(schedule)
@@ -4362,7 +4364,7 @@ def run_schedule_now_api(schedule_id):
try: try:
account_ids = json.loads(schedule.get('account_ids', '[]') or '[]') account_ids = json.loads(schedule.get('account_ids', '[]') or '[]')
except: except (json.JSONDecodeError, TypeError):
account_ids = [] account_ids = []
if not account_ids: if not account_ids:
@@ -4544,7 +4546,7 @@ def cleanup_on_exit():
for user_id in user_accounts: for user_id in user_accounts:
if account_id in user_accounts[user_id]: if account_id in user_accounts[user_id]:
user_accounts[user_id][account_id].should_stop = True user_accounts[user_id][account_id].should_stop = True
except: except Exception:
pass pass
# 2. 等待所有线程完成最多等待5秒 # 2. 等待所有线程完成最多等待5秒
@@ -4553,28 +4555,28 @@ def cleanup_on_exit():
try: try:
if thread and thread.is_alive(): if thread and thread.is_alive():
thread.join(timeout=2) thread.join(timeout=2)
except: except Exception:
pass pass
# 3. 关闭浏览器工作线程池 # 3. 关闭浏览器工作线程池
print("- 关闭浏览器线程池...") print("- 关闭浏览器线程池...")
try: try:
shutdown_browser_worker_pool() shutdown_browser_worker_pool()
except: except Exception:
pass pass
# 3.5 关闭邮件队列 # 3.5 关闭邮件队列
print("- 关闭邮件队列...") print("- 关闭邮件队列...")
try: try:
email_service.shutdown_email_queue() email_service.shutdown_email_queue()
except: except Exception:
pass pass
# 4. 关闭数据库连接池 # 4. 关闭数据库连接池
print("- 关闭数据库连接池...") print("- 关闭数据库连接池...")
try: try:
db_pool._pool.close_all() if db_pool._pool else None db_pool._pool.close_all() if db_pool._pool else None
except: except Exception:
pass pass
print("✓ 资源清理完成") print("✓ 资源清理完成")

View File

@@ -837,7 +837,7 @@ def update_user_email(user_id, email, verified=False):
# 先检查email_verified字段是否存在不存在则添加 # 先检查email_verified字段是否存在不存在则添加
try: try:
cursor.execute('SELECT email_verified FROM users LIMIT 1') cursor.execute('SELECT email_verified FROM users LIMIT 1')
except: except Exception:
cursor.execute('ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0') cursor.execute('ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0')
conn.commit() conn.commit()
@@ -857,7 +857,7 @@ def update_user_email_notify(user_id, enabled):
# 先检查字段是否存在 # 先检查字段是否存在
try: try:
cursor.execute('SELECT email_notify_enabled FROM users LIMIT 1') cursor.execute('SELECT email_notify_enabled FROM users LIMIT 1')
except: except Exception:
cursor.execute('ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1') cursor.execute('ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1')
conn.commit() conn.commit()
@@ -881,7 +881,7 @@ def get_user_email_notify(user_id):
if row is None: if row is None:
return True return True
return bool(row[0]) if row[0] is not None else True return bool(row[0]) if row[0] is not None else True
except: except Exception:
return True # 字段不存在时默认开启 return True # 字段不存在时默认开启

View File

@@ -78,6 +78,9 @@ QUEUE_MAX_SIZE = int(os.environ.get('EMAIL_QUEUE_MAX_SIZE', '100'))
# 为安全起见设置为10MB超过则分批发送 # 为安全起见设置为10MB超过则分批发送
MAX_ATTACHMENT_SIZE = int(os.environ.get('EMAIL_MAX_ATTACHMENT_SIZE', str(10 * 1024 * 1024))) # 10MB MAX_ATTACHMENT_SIZE = int(os.environ.get('EMAIL_MAX_ATTACHMENT_SIZE', str(10 * 1024 * 1024))) # 10MB
# SMTP配置获取锁防止并发获取时竞态条件导致超过每日限额
_smtp_config_lock = threading.Lock()
# ============ 数据库操作 ============ # ============ 数据库操作 ============
@@ -500,80 +503,97 @@ def set_primary_smtp_config(config_id: int) -> bool:
def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]: def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]:
""" """
获取可用的SMTP配置 获取可用的SMTP配置(线程安全)
优先级: 主配置 > 按priority排序的启用配置 优先级: 主配置 > 按priority排序的启用配置
使用锁保护防止并发获取时超过每日限额
""" """
today = datetime.now().strftime('%Y-%m-%d') today = datetime.now().strftime('%Y-%m-%d')
with db_pool.get_db() as conn: with _smtp_config_lock: # 使用锁保护整个获取过程
cursor = conn.cursor() with db_pool.get_db() as conn:
cursor = conn.cursor()
# 先重置过期的每日计数 # 先重置过期的每日计数
cursor.execute(""" cursor.execute("""
UPDATE smtp_configs UPDATE smtp_configs
SET daily_sent = 0, daily_reset_date = ? SET daily_sent = 0, daily_reset_date = ?
WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = '' WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = ''
""", (today, today)) """, (today, today))
conn.commit() conn.commit()
# 获取所有启用的配置,按优先级排序 # 获取所有启用的配置,按优先级排序
cursor.execute(""" cursor.execute("""
SELECT id, name, host, port, username, password, use_ssl, use_tls, SELECT id, name, host, port, username, password, use_ssl, use_tls,
sender_name, sender_email, daily_limit, daily_sent, is_primary sender_name, sender_email, daily_limit, daily_sent, is_primary
FROM smtp_configs FROM smtp_configs
WHERE enabled = 1 WHERE enabled = 1
ORDER BY is_primary DESC, priority ASC, id ASC ORDER BY is_primary DESC, priority ASC, id ASC
""") """)
configs = cursor.fetchall() configs = cursor.fetchall()
for row in configs: for row in configs:
config_id, name, host, port, username, password, use_ssl, use_tls, \ config_id, name, host, port, username, password, use_ssl, use_tls, \
sender_name, sender_email, daily_limit, daily_sent, is_primary = row sender_name, sender_email, daily_limit, daily_sent, is_primary = row
# 检查每日限额 # 检查每日限额
if daily_limit > 0 and daily_sent >= daily_limit: if daily_limit > 0 and daily_sent >= daily_limit:
continue # 超过限额,跳过此配置 continue # 超过限额,跳过此配置
# 解密密码 # 预增计数(在返回配置前先占用配额,防止并发超限)
decrypted_password = decrypt_password(password) if password else '' # 如果发送失败_update_smtp_stats会在失败时回退
cursor.execute("""
UPDATE smtp_configs
SET daily_sent = daily_sent + 1
WHERE id = ?
""", (config_id,))
conn.commit()
return { # 解密密码
'id': config_id, decrypted_password = decrypt_password(password) if password else ''
'name': name,
'host': host,
'port': port,
'username': username,
'password': decrypted_password,
'use_ssl': bool(use_ssl),
'use_tls': bool(use_tls),
'sender_name': sender_name,
'sender_email': sender_email,
'is_primary': bool(is_primary)
}
return None return {
'id': config_id,
'name': name,
'host': host,
'port': port,
'username': username,
'password': decrypted_password,
'use_ssl': bool(use_ssl),
'use_tls': bool(use_tls),
'sender_name': sender_name,
'sender_email': sender_email,
'is_primary': bool(is_primary)
}
return None
def _update_smtp_stats(config_id: int, success: bool, error: str = ''): def _update_smtp_stats(config_id: int, success: bool, error: str = ''):
"""更新SMTP配置的统计信息""" """更新SMTP配置的统计信息
注意daily_sent已在_get_available_smtp_config中预增
成功时只更新success_count<EFBC8C><E5A4B1>时需要回退daily_sent
"""
with db_pool.get_db() as conn: with db_pool.get_db() as conn:
cursor = conn.cursor() cursor = conn.cursor()
if success: if success:
# 成功只更新成功计数daily_sent已在获取配置时预增
cursor.execute(""" cursor.execute("""
UPDATE smtp_configs UPDATE smtp_configs
SET daily_sent = daily_sent + 1, SET success_count = success_count + 1,
success_count = success_count + 1,
last_success_at = CURRENT_TIMESTAMP, last_success_at = CURRENT_TIMESTAMP,
last_error = '', last_error = '',
updated_at = CURRENT_TIMESTAMP updated_at = CURRENT_TIMESTAMP
WHERE id = ? WHERE id = ?
""", (config_id,)) """, (config_id,))
else: else:
# 失败回退daily_sent并更新失败计数
cursor.execute(""" cursor.execute("""
UPDATE smtp_configs UPDATE smtp_configs
SET fail_count = fail_count + 1, SET daily_sent = MAX(0, daily_sent - 1),
fail_count = fail_count + 1,
last_error = ?, last_error = ?,
updated_at = CURRENT_TIMESTAMP updated_at = CURRENT_TIMESTAMP
WHERE id = ? WHERE id = ?
@@ -765,46 +785,59 @@ def send_email(
def _get_next_available_smtp_config(exclude_ids: List[int]) -> Optional[Dict[str, Any]]: def _get_next_available_smtp_config(exclude_ids: List[int]) -> Optional[Dict[str, Any]]:
"""获取下一个可用的SMTP配置排除已尝试的""" """获取下一个可用的SMTP配置排除已尝试的,线程安全"""
today = datetime.now().strftime('%Y-%m-%d') today = datetime.now().strftime('%Y-%m-%d')
with db_pool.get_db() as conn: with _smtp_config_lock: # 使用锁保护
cursor = conn.cursor() with db_pool.get_db() as conn:
cursor = conn.cursor()
placeholders = ','.join(['?' for _ in exclude_ids]) placeholders = ','.join(['?' for _ in exclude_ids])
cursor.execute(f""" cursor.execute(f"""
SELECT id, name, host, port, username, password, use_ssl, use_tls, SELECT id, name, host, port, username, password, use_ssl, use_tls,
sender_name, sender_email, daily_limit, daily_sent, is_primary sender_name, sender_email, daily_limit, daily_sent, is_primary
FROM smtp_configs FROM smtp_configs
WHERE enabled = 1 AND id NOT IN ({placeholders}) WHERE enabled = 1 AND id NOT IN ({placeholders})
ORDER BY is_primary DESC, priority ASC, id ASC ORDER BY is_primary DESC, priority ASC, id ASC
LIMIT 1 LIMIT 1
""", exclude_ids) """, exclude_ids)
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
return None return None
config_id, name, host, port, username, password, use_ssl, use_tls, \ config_id, name, host, port, username, password, use_ssl, use_tls, \
sender_name, sender_email, daily_limit, daily_sent, is_primary = row sender_name, sender_email, daily_limit, daily_sent, is_primary = row
# 检查每日限额 # 检查每日限额
if daily_limit > 0 and daily_sent >= daily_limit: if daily_limit > 0 and daily_sent >= daily_limit:
return _get_next_available_smtp_config(exclude_ids + [config_id]) # 递归调用在锁外进行,避免死锁
pass
else:
# 预增计数
cursor.execute("""
UPDATE smtp_configs
SET daily_sent = daily_sent + 1
WHERE id = ?
""", (config_id,))
conn.commit()
return { return {
'id': config_id, 'id': config_id,
'name': name, 'name': name,
'host': host, 'host': host,
'port': port, 'port': port,
'username': username, 'username': username,
'password': decrypt_password(password) if password else '', 'password': decrypt_password(password) if password else '',
'use_ssl': bool(use_ssl), 'use_ssl': bool(use_ssl),
'use_tls': bool(use_tls), 'use_tls': bool(use_tls),
'sender_name': sender_name, 'sender_name': sender_name,
'sender_email': sender_email, 'sender_email': sender_email,
'is_primary': bool(is_primary) 'is_primary': bool(is_primary)
} }
# 递归调用在锁外进行
return _get_next_available_smtp_config(exclude_ids + [config_id])
def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]: def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]: