diff --git a/app.py b/app.py
index 94dad79..4fcfee3 100644
--- a/app.py
+++ b/app.py
@@ -173,10 +173,28 @@ def serve_static(filename):
if not is_safe_path("static", filename):
return jsonify({"error": "非法路径"}), 403
- response = send_from_directory("static", filename)
- response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
- response.headers["Pragma"] = "no-cache"
- response.headers["Expires"] = "0"
+ cache_ttl = 3600
+ lowered = filename.lower()
+ if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")):
+ cache_ttl = 604800 # 7天
+
+ if request.args.get("v"):
+ cache_ttl = max(cache_ttl, 604800)
+
+ response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)
+
+ # 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304
+ try:
+ response.add_etag(overwrite=False)
+ except Exception:
+ pass
+ try:
+ response.make_conditional(request)
+ except Exception:
+ pass
+
+ response.headers.setdefault("Vary", "Accept-Encoding")
+ response.headers["Cache-Control"] = f"public, max-age={cache_ttl}"
return response
@@ -232,6 +250,93 @@ def _signal_handler(sig, frame):
sys.exit(0)
+def _cleanup_stale_task_state() -> None:
+ logger.info("清理遗留任务状态...")
+ try:
+ from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status
+
+ for _, accounts in safe_iter_user_accounts_items():
+ for acc in accounts.values():
+ if not getattr(acc, "is_running", False):
+ continue
+ acc.is_running = False
+ acc.should_stop = False
+ acc.status = "未开始"
+
+ for account_id in list(safe_get_active_task_ids()):
+ safe_remove_task(account_id)
+ safe_remove_task_status(account_id)
+
+ logger.info("[OK] 遗留任务状态已清理")
+ except Exception as e:
+ logger.warning(f"清理遗留任务状态失败: {e}")
+
+
+def _init_optional_email_service() -> None:
+ try:
+ email_service.init_email_service()
+ logger.info("[OK] 邮件服务已初始化")
+ except Exception as e:
+ logger.warning(f"警告: 邮件服务初始化失败: {e}")
+
+
+def _load_and_apply_scheduler_limits() -> None:
+ try:
+ system_config = database.get_system_config() or {}
+ max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
+ max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
+ get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
+ logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
+ except Exception as e:
+ logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
+
+
+def _start_background_workers() -> None:
+ logger.info("启动定时任务调度器...")
+ threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
+ logger.info("[OK] 定时任务调度器已启动")
+
+ logger.info("[OK] 状态推送线程已启动(默认2秒/次)")
+ threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
+
+
+def _init_screenshot_worker_pool() -> None:
+ try:
+ pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
+ except Exception:
+ pool_size = 3
+
+ try:
+ logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
+ init_browser_worker_pool(pool_size=pool_size)
+ logger.info("[OK] 截图线程池初始化完成")
+ except Exception as e:
+ logger.warning(f"警告: 截图线程池初始化失败: {e}")
+
+
+def _warmup_api_connection() -> None:
+ logger.info("预热 API 连接...")
+ try:
+ from api_browser import warmup_api_connection
+
+ threading.Thread(
+ target=warmup_api_connection,
+ kwargs={"log_callback": lambda msg: logger.info(msg)},
+ daemon=True,
+ name="api-warmup",
+ ).start()
+ except Exception as e:
+ logger.warning(f"API 预热失败: {e}")
+
+
+def _log_startup_urls() -> None:
+ logger.info("服务器启动中...")
+ logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
+ logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
+ logger.info("默认管理员: admin (首次运行随机密码见日志)")
+ logger.info("=" * 60)
+
+
if __name__ == "__main__":
atexit.register(cleanup_on_exit)
signal.signal(signal.SIGINT, _signal_handler)
@@ -245,81 +350,17 @@ if __name__ == "__main__":
init_checkpoint_manager()
logger.info("[OK] 任务断点管理器已初始化")
- # 【新增】容器重启时清理遗留的任务状态
- logger.info("清理遗留任务状态...")
- try:
- from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status
- # 重置所有账号的运行状态
- for _, accounts in safe_iter_user_accounts_items():
- for acc in accounts.values():
- if getattr(acc, "is_running", False):
- acc.is_running = False
- acc.should_stop = False
- acc.status = "未开始"
- # 清理活跃任务句柄
- for account_id in list(safe_get_active_task_ids()):
- safe_remove_task(account_id)
- safe_remove_task_status(account_id)
- logger.info("[OK] 遗留任务状态已清理")
- except Exception as e:
- logger.warning(f"清理遗留任务状态失败: {e}")
-
- try:
- email_service.init_email_service()
- logger.info("[OK] 邮件服务已初始化")
- except Exception as e:
- logger.warning(f"警告: 邮件服务初始化失败: {e}")
+ _cleanup_stale_task_state()
+ _init_optional_email_service()
start_cleanup_scheduler()
start_kdocs_monitor()
- try:
- system_config = database.get_system_config() or {}
- max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
- max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT))
- get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
- logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
- except Exception as e:
- logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")
-
- logger.info("启动定时任务调度器...")
- threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
- logger.info("[OK] 定时任务调度器已启动")
-
- logger.info("[OK] 状态推送线程已启动(默认2秒/次)")
- threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
-
- logger.info("服务器启动中...")
- logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
- logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
- logger.info("默认管理员: admin (首次运行随机密码见日志)")
- logger.info("=" * 60)
-
- try:
- pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
- except Exception:
- pool_size = 3
- try:
- logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
- init_browser_worker_pool(pool_size=pool_size)
- logger.info("[OK] 截图线程池初始化完成")
- except Exception as e:
- logger.warning(f"警告: 截图线程池初始化失败: {e}")
-
- # 预热 API 连接(后台进行,不阻塞启动)
- logger.info("预热 API 连接...")
- try:
- from api_browser import warmup_api_connection
- import threading
-
- threading.Thread(
- target=warmup_api_connection,
- kwargs={"log_callback": lambda msg: logger.info(msg)},
- daemon=True,
- name="api-warmup",
- ).start()
- except Exception as e:
- logger.warning(f"API 预热失败: {e}")
+ _load_and_apply_scheduler_limits()
+ _start_background_workers()
+ _log_startup_urls()
+ _init_screenshot_worker_pool()
+ _warmup_api_connection()
socketio.run(
app,
diff --git a/database.py b/database.py
index d6a132a..56121e1 100644
--- a/database.py
+++ b/database.py
@@ -120,7 +120,7 @@ config = get_config()
DB_FILE = config.DB_FILE
# 数据库版本 (用于迁移管理)
-DB_VERSION = 17
+DB_VERSION = 18
# ==================== 系统配置缓存(P1 / O-03) ====================
@@ -142,6 +142,37 @@ def invalidate_system_config_cache() -> None:
_system_config_cache_loaded_at = 0.0
+def _normalize_system_config_value(value) -> dict:
+ try:
+ return dict(value or {})
+ except Exception:
+ return {}
+
+
+def _is_system_config_cache_valid(now_ts: float) -> bool:
+ if _system_config_cache_value is None:
+ return False
+ if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0:
+ return True
+ return (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS
+
+
+def _read_system_config_cache(now_ts: float, *, ignore_ttl: bool = False) -> Optional[dict]:
+ with _system_config_cache_lock:
+ if _system_config_cache_value is None:
+ return None
+ if (not ignore_ttl) and (not _is_system_config_cache_valid(now_ts)):
+ return None
+ return dict(_system_config_cache_value)
+
+
+def _write_system_config_cache(value: dict, now_ts: float) -> None:
+ global _system_config_cache_value, _system_config_cache_loaded_at
+ with _system_config_cache_lock:
+ _system_config_cache_value = dict(value)
+ _system_config_cache_loaded_at = now_ts
+
+
def init_database():
"""初始化数据库表结构 + 迁移(入口统一)。"""
db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE)
@@ -165,19 +196,21 @@ def migrate_database():
def get_system_config():
"""获取系统配置(带进程内缓存)。"""
- global _system_config_cache_value, _system_config_cache_loaded_at
-
now_ts = time.time()
- with _system_config_cache_lock:
- if _system_config_cache_value is not None:
- if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0 or (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS:
- return dict(_system_config_cache_value)
- value = _get_system_config_raw()
+ cached_value = _read_system_config_cache(now_ts)
+ if cached_value is not None:
+ return cached_value
- with _system_config_cache_lock:
- _system_config_cache_value = dict(value)
- _system_config_cache_loaded_at = now_ts
+ try:
+ value = _normalize_system_config_value(_get_system_config_raw())
+ except Exception:
+ fallback_value = _read_system_config_cache(now_ts, ignore_ttl=True)
+ if fallback_value is not None:
+ return fallback_value
+ raise
+
+ _write_system_config_cache(value, now_ts)
return dict(value)
diff --git a/db/accounts.py b/db/accounts.py
index 85e4132..fb98789 100644
--- a/db/accounts.py
+++ b/db/accounts.py
@@ -6,19 +6,51 @@ import db_pool
from crypto_utils import decrypt_password, encrypt_password
from db.utils import get_cst_now_str
+_ACCOUNT_STATUS_QUERY_SQL = """
+ SELECT status, login_fail_count, last_login_error
+ FROM accounts
+ WHERE id = ?
+"""
+
+
+def _decode_account_password(account_dict: dict) -> dict:
+ account_dict["password"] = decrypt_password(account_dict.get("password", ""))
+ return account_dict
+
+
+def _normalize_account_ids(account_ids) -> list[str]:
+ normalized = []
+ seen = set()
+ for account_id in account_ids or []:
+ if not account_id:
+ continue
+ account_key = str(account_id)
+ if account_key in seen:
+ continue
+ seen.add(account_key)
+ normalized.append(account_key)
+ return normalized
+
def create_account(user_id, account_id, username, password, remember=True, remark=""):
"""创建账号(密码加密存储)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_time = get_cst_now_str()
encrypted_password = encrypt_password(password)
cursor.execute(
"""
INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
- (account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time),
+ (
+ account_id,
+ user_id,
+ username,
+ encrypted_password,
+ 1 if remember else 0,
+ remark,
+ get_cst_now_str(),
+ ),
)
conn.commit()
return cursor.lastrowid
@@ -29,12 +61,7 @@ def get_user_accounts(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,))
- accounts = []
- for row in cursor.fetchall():
- account = dict(row)
- account["password"] = decrypt_password(account.get("password", ""))
- accounts.append(account)
- return accounts
+ return [_decode_account_password(dict(row)) for row in cursor.fetchall()]
def get_account(account_id):
@@ -43,11 +70,9 @@ def get_account(account_id):
cursor = conn.cursor()
cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,))
row = cursor.fetchone()
- if row:
- account = dict(row)
- account["password"] = decrypt_password(account.get("password", ""))
- return account
- return None
+ if not row:
+ return None
+ return _decode_account_password(dict(row))
def update_account_remark(account_id, remark):
@@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message):
if not row:
return False
- fail_count = (row["login_fail_count"] or 0) + 1
-
- if fail_count >= 3:
- cursor.execute(
- """
- UPDATE accounts
- SET login_fail_count = ?,
- last_login_error = ?,
- status = 'suspended'
- WHERE id = ?
- """,
- (fail_count, error_message, account_id),
- )
- conn.commit()
- return True
+ fail_count = int(row["login_fail_count"] or 0) + 1
+ is_suspended = fail_count >= 3
cursor.execute(
"""
UPDATE accounts
SET login_fail_count = ?,
- last_login_error = ?
+ last_login_error = ?,
+ status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END
WHERE id = ?
""",
- (fail_count, error_message, account_id),
+ (fail_count, error_message, 1 if is_suspended else 0, account_id),
)
conn.commit()
- return False
+ return is_suspended
def reset_account_login_status(account_id):
@@ -129,29 +142,22 @@ def get_account_status(account_id):
"""获取账号状态信息"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cursor.execute(
- """
- SELECT status, login_fail_count, last_login_error
- FROM accounts
- WHERE id = ?
- """,
- (account_id,),
- )
+ cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,))
return cursor.fetchone()
def get_account_status_batch(account_ids):
"""批量获取账号状态信息"""
- account_ids = [str(account_id) for account_id in (account_ids or []) if account_id]
- if not account_ids:
+ normalized_ids = _normalize_account_ids(account_ids)
+ if not normalized_ids:
return {}
results = {}
chunk_size = 900 # 避免触发 SQLite 绑定参数上限
with db_pool.get_db() as conn:
cursor = conn.cursor()
- for idx in range(0, len(account_ids), chunk_size):
- chunk = account_ids[idx : idx + chunk_size]
+ for idx in range(0, len(normalized_ids), chunk_size):
+ chunk = normalized_ids[idx : idx + chunk_size]
placeholders = ",".join("?" for _ in chunk)
cursor.execute(
f"""
diff --git a/db/admin.py b/db/admin.py
index b087ad3..f3be7ca 100644
--- a/db/admin.py
+++ b/db/admin.py
@@ -3,9 +3,6 @@
from __future__ import annotations
import sqlite3
-from datetime import datetime, timedelta
-
-import pytz
import db_pool
from db.utils import get_cst_now_str
@@ -16,6 +13,99 @@ from password_utils import (
verify_password_sha256,
)
+_DEFAULT_SYSTEM_CONFIG = {
+ "max_concurrent_global": 2,
+ "max_concurrent_per_account": 1,
+ "max_screenshot_concurrent": 3,
+ "schedule_enabled": 0,
+ "schedule_time": "02:00",
+ "schedule_browse_type": "应读",
+ "schedule_weekdays": "1,2,3,4,5,6,7",
+ "proxy_enabled": 0,
+ "proxy_api_url": "",
+ "proxy_expire_minutes": 3,
+ "enable_screenshot": 1,
+ "auto_approve_enabled": 0,
+ "auto_approve_hourly_limit": 10,
+ "auto_approve_vip_days": 7,
+ "kdocs_enabled": 0,
+ "kdocs_doc_url": "",
+ "kdocs_default_unit": "",
+ "kdocs_sheet_name": "",
+ "kdocs_sheet_index": 0,
+ "kdocs_unit_column": "A",
+ "kdocs_image_column": "D",
+ "kdocs_admin_notify_enabled": 0,
+ "kdocs_admin_notify_email": "",
+ "kdocs_row_start": 0,
+ "kdocs_row_end": 0,
+}
+
+_SYSTEM_CONFIG_UPDATERS = (
+ ("max_concurrent_global", "max_concurrent"),
+ ("schedule_enabled", "schedule_enabled"),
+ ("schedule_time", "schedule_time"),
+ ("schedule_browse_type", "schedule_browse_type"),
+ ("schedule_weekdays", "schedule_weekdays"),
+ ("max_concurrent_per_account", "max_concurrent_per_account"),
+ ("max_screenshot_concurrent", "max_screenshot_concurrent"),
+ ("enable_screenshot", "enable_screenshot"),
+ ("proxy_enabled", "proxy_enabled"),
+ ("proxy_api_url", "proxy_api_url"),
+ ("proxy_expire_minutes", "proxy_expire_minutes"),
+ ("auto_approve_enabled", "auto_approve_enabled"),
+ ("auto_approve_hourly_limit", "auto_approve_hourly_limit"),
+ ("auto_approve_vip_days", "auto_approve_vip_days"),
+ ("kdocs_enabled", "kdocs_enabled"),
+ ("kdocs_doc_url", "kdocs_doc_url"),
+ ("kdocs_default_unit", "kdocs_default_unit"),
+ ("kdocs_sheet_name", "kdocs_sheet_name"),
+ ("kdocs_sheet_index", "kdocs_sheet_index"),
+ ("kdocs_unit_column", "kdocs_unit_column"),
+ ("kdocs_image_column", "kdocs_image_column"),
+ ("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"),
+ ("kdocs_admin_notify_email", "kdocs_admin_notify_email"),
+ ("kdocs_row_start", "kdocs_row_start"),
+ ("kdocs_row_end", "kdocs_row_end"),
+)
+
+
+def _count_scalar(cursor, sql: str, params=()) -> int:
+ cursor.execute(sql, params)
+ row = cursor.fetchone()
+ if not row:
+ return 0
+ try:
+ if "count" in row.keys():
+ return int(row["count"] or 0)
+ except Exception:
+ pass
+ try:
+ return int(row[0] or 0)
+ except Exception:
+ return 0
+
+
+def _table_exists(cursor, table_name: str) -> bool:
+ cursor.execute(
+ """
+ SELECT name FROM sqlite_master
+ WHERE type='table' AND name=?
+ """,
+ (table_name,),
+ )
+ return bool(cursor.fetchone())
+
+
+def _normalize_days(days, default: int = 30) -> int:
+ try:
+ value = int(days)
+ except Exception:
+ value = default
+ if value < 0:
+ return 0
+ return value
+
def ensure_default_admin() -> bool:
"""确保存在默认管理员账号(行为保持不变)。"""
@@ -24,10 +114,9 @@ def ensure_default_admin() -> bool:
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cursor.execute("SELECT COUNT(*) as count FROM admins")
- result = cursor.fetchone()
+ count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins")
- if result["count"] == 0:
+ if count == 0:
alphabet = string.ascii_letters + string.digits
random_password = "".join(secrets.choice(alphabet) for _ in range(12))
@@ -101,41 +190,33 @@ def get_system_stats() -> dict:
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cursor.execute("SELECT COUNT(*) as count FROM users")
- total_users = cursor.fetchone()["count"]
-
- cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
- approved_users = cursor.fetchone()["count"]
-
- cursor.execute(
+ total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users")
+ approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'")
+ new_users_today = _count_scalar(
+ cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE date(created_at) = date('now', 'localtime')
- """
+ """,
)
- new_users_today = cursor.fetchone()["count"]
-
- cursor.execute(
+ new_users_7d = _count_scalar(
+ cursor,
"""
SELECT COUNT(*) as count
FROM users
WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days')
- """
+ """,
)
- new_users_7d = cursor.fetchone()["count"]
-
- cursor.execute("SELECT COUNT(*) as count FROM accounts")
- total_accounts = cursor.fetchone()["count"]
-
- cursor.execute(
+ total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts")
+ vip_users = _count_scalar(
+ cursor,
"""
SELECT COUNT(*) as count FROM users
WHERE vip_expire_time IS NOT NULL
AND datetime(vip_expire_time) > datetime('now', 'localtime')
- """
+ """,
)
- vip_users = cursor.fetchone()["count"]
return {
"total_users": total_users,
@@ -153,37 +234,9 @@ def get_system_config_raw() -> dict:
cursor = conn.cursor()
cursor.execute("SELECT * FROM system_config WHERE id = 1")
row = cursor.fetchone()
-
if row:
return dict(row)
-
- return {
- "max_concurrent_global": 2,
- "max_concurrent_per_account": 1,
- "max_screenshot_concurrent": 3,
- "schedule_enabled": 0,
- "schedule_time": "02:00",
- "schedule_browse_type": "应读",
- "schedule_weekdays": "1,2,3,4,5,6,7",
- "proxy_enabled": 0,
- "proxy_api_url": "",
- "proxy_expire_minutes": 3,
- "enable_screenshot": 1,
- "auto_approve_enabled": 0,
- "auto_approve_hourly_limit": 10,
- "auto_approve_vip_days": 7,
- "kdocs_enabled": 0,
- "kdocs_doc_url": "",
- "kdocs_default_unit": "",
- "kdocs_sheet_name": "",
- "kdocs_sheet_index": 0,
- "kdocs_unit_column": "A",
- "kdocs_image_column": "D",
- "kdocs_admin_notify_enabled": 0,
- "kdocs_admin_notify_email": "",
- "kdocs_row_start": 0,
- "kdocs_row_end": 0,
- }
+ return dict(_DEFAULT_SYSTEM_CONFIG)
def update_system_config(
@@ -215,127 +268,51 @@ def update_system_config(
kdocs_row_end=None,
) -> bool:
"""更新系统配置(仅更新DB,不做缓存处理)。"""
- allowed_fields = {
- "max_concurrent_global",
- "schedule_enabled",
- "schedule_time",
- "schedule_browse_type",
- "schedule_weekdays",
- "max_concurrent_per_account",
- "max_screenshot_concurrent",
- "enable_screenshot",
- "proxy_enabled",
- "proxy_api_url",
- "proxy_expire_minutes",
- "auto_approve_enabled",
- "auto_approve_hourly_limit",
- "auto_approve_vip_days",
- "kdocs_enabled",
- "kdocs_doc_url",
- "kdocs_default_unit",
- "kdocs_sheet_name",
- "kdocs_sheet_index",
- "kdocs_unit_column",
- "kdocs_image_column",
- "kdocs_admin_notify_enabled",
- "kdocs_admin_notify_email",
- "kdocs_row_start",
- "kdocs_row_end",
- "updated_at",
+ arg_values = {
+ "max_concurrent": max_concurrent,
+ "schedule_enabled": schedule_enabled,
+ "schedule_time": schedule_time,
+ "schedule_browse_type": schedule_browse_type,
+ "schedule_weekdays": schedule_weekdays,
+ "max_concurrent_per_account": max_concurrent_per_account,
+ "max_screenshot_concurrent": max_screenshot_concurrent,
+ "enable_screenshot": enable_screenshot,
+ "proxy_enabled": proxy_enabled,
+ "proxy_api_url": proxy_api_url,
+ "proxy_expire_minutes": proxy_expire_minutes,
+ "auto_approve_enabled": auto_approve_enabled,
+ "auto_approve_hourly_limit": auto_approve_hourly_limit,
+ "auto_approve_vip_days": auto_approve_vip_days,
+ "kdocs_enabled": kdocs_enabled,
+ "kdocs_doc_url": kdocs_doc_url,
+ "kdocs_default_unit": kdocs_default_unit,
+ "kdocs_sheet_name": kdocs_sheet_name,
+ "kdocs_sheet_index": kdocs_sheet_index,
+ "kdocs_unit_column": kdocs_unit_column,
+ "kdocs_image_column": kdocs_image_column,
+ "kdocs_admin_notify_enabled": kdocs_admin_notify_enabled,
+ "kdocs_admin_notify_email": kdocs_admin_notify_email,
+ "kdocs_row_start": kdocs_row_start,
+ "kdocs_row_end": kdocs_row_end,
}
+ updates = []
+ params = []
+ for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS:
+ value = arg_values.get(arg_name)
+ if value is None:
+ continue
+ updates.append(f"{db_field} = ?")
+ params.append(value)
+
+ if not updates:
+ return False
+
+ updates.append("updated_at = ?")
+ params.append(get_cst_now_str())
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
- updates = []
- params = []
-
- if max_concurrent is not None:
- updates.append("max_concurrent_global = ?")
- params.append(max_concurrent)
- if schedule_enabled is not None:
- updates.append("schedule_enabled = ?")
- params.append(schedule_enabled)
- if schedule_time is not None:
- updates.append("schedule_time = ?")
- params.append(schedule_time)
- if schedule_browse_type is not None:
- updates.append("schedule_browse_type = ?")
- params.append(schedule_browse_type)
- if max_concurrent_per_account is not None:
- updates.append("max_concurrent_per_account = ?")
- params.append(max_concurrent_per_account)
- if max_screenshot_concurrent is not None:
- updates.append("max_screenshot_concurrent = ?")
- params.append(max_screenshot_concurrent)
- if enable_screenshot is not None:
- updates.append("enable_screenshot = ?")
- params.append(enable_screenshot)
- if schedule_weekdays is not None:
- updates.append("schedule_weekdays = ?")
- params.append(schedule_weekdays)
- if proxy_enabled is not None:
- updates.append("proxy_enabled = ?")
- params.append(proxy_enabled)
- if proxy_api_url is not None:
- updates.append("proxy_api_url = ?")
- params.append(proxy_api_url)
- if proxy_expire_minutes is not None:
- updates.append("proxy_expire_minutes = ?")
- params.append(proxy_expire_minutes)
- if auto_approve_enabled is not None:
- updates.append("auto_approve_enabled = ?")
- params.append(auto_approve_enabled)
- if auto_approve_hourly_limit is not None:
- updates.append("auto_approve_hourly_limit = ?")
- params.append(auto_approve_hourly_limit)
- if auto_approve_vip_days is not None:
- updates.append("auto_approve_vip_days = ?")
- params.append(auto_approve_vip_days)
- if kdocs_enabled is not None:
- updates.append("kdocs_enabled = ?")
- params.append(kdocs_enabled)
- if kdocs_doc_url is not None:
- updates.append("kdocs_doc_url = ?")
- params.append(kdocs_doc_url)
- if kdocs_default_unit is not None:
- updates.append("kdocs_default_unit = ?")
- params.append(kdocs_default_unit)
- if kdocs_sheet_name is not None:
- updates.append("kdocs_sheet_name = ?")
- params.append(kdocs_sheet_name)
- if kdocs_sheet_index is not None:
- updates.append("kdocs_sheet_index = ?")
- params.append(kdocs_sheet_index)
- if kdocs_unit_column is not None:
- updates.append("kdocs_unit_column = ?")
- params.append(kdocs_unit_column)
- if kdocs_image_column is not None:
- updates.append("kdocs_image_column = ?")
- params.append(kdocs_image_column)
- if kdocs_admin_notify_enabled is not None:
- updates.append("kdocs_admin_notify_enabled = ?")
- params.append(kdocs_admin_notify_enabled)
- if kdocs_admin_notify_email is not None:
- updates.append("kdocs_admin_notify_email = ?")
- params.append(kdocs_admin_notify_email)
- if kdocs_row_start is not None:
- updates.append("kdocs_row_start = ?")
- params.append(kdocs_row_start)
- if kdocs_row_end is not None:
- updates.append("kdocs_row_end = ?")
- params.append(kdocs_row_end)
-
- if not updates:
- return False
-
- updates.append("updated_at = ?")
- params.append(get_cst_now_str())
-
- for update_clause in updates:
- field_name = update_clause.split("=")[0].strip()
- if field_name not in allowed_fields:
- raise ValueError(f"非法字段名: {field_name}")
-
sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1"
cursor.execute(sql, params)
conn.commit()
@@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int:
"""获取最近一小时内的注册用户数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cursor.execute(
+ return _count_scalar(
+ cursor,
"""
- SELECT COUNT(*) FROM users
+ SELECT COUNT(*) as count FROM users
WHERE created_at >= datetime('now', 'localtime', '-1 hour')
- """
+ """,
)
- return cursor.fetchone()[0]
# ==================== 密码重置(管理员) ====================
@@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool:
def clean_old_operation_logs(days: int = 30) -> int:
"""清理指定天数前的操作日志(如果存在operation_logs表)"""
+ safe_days = _normalize_days(days, default=30)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cursor.execute(
- """
- SELECT name FROM sqlite_master
- WHERE type='table' AND name='operation_logs'
- """
- )
-
- if not cursor.fetchone():
+ if not _table_exists(cursor, "operation_logs"):
return 0
try:
@@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int:
DELETE FROM operation_logs
WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
""",
- (days,),
+ (safe_days,),
)
deleted_count = cursor.rowcount
conn.commit()
- print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)")
+ print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)")
return deleted_count
except Exception as e:
print(f"清理旧操作日志失败: {e}")
diff --git a/db/announcements.py b/db/announcements.py
index c680816..937d84a 100644
--- a/db/announcements.py
+++ b/db/announcements.py
@@ -6,12 +6,38 @@ import db_pool
from db.utils import get_cst_now_str
+def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+ parsed = max(minimum, parsed)
+ parsed = min(maximum, parsed)
+ return parsed
+
+
+def _normalize_offset(value, default: int = 0) -> int:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+ return max(0, parsed)
+
+
+def _normalize_announcement_payload(title, content, image_url):
+ normalized_title = str(title or "").strip()
+ normalized_content = str(content or "").strip()
+ normalized_image = str(image_url or "").strip() or None
+ return normalized_title, normalized_content, normalized_image
+
+
+def _deactivate_all_active_announcements(cursor, cst_time: str) -> None:
+ cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
+
+
def create_announcement(title, content, image_url=None, is_active=True):
"""创建公告(默认启用;启用时会自动停用其他公告)"""
- title = (title or "").strip()
- content = (content or "").strip()
- image_url = (image_url or "").strip()
- image_url = image_url or None
+ title, content, image_url = _normalize_announcement_payload(title, content, image_url)
if not title or not content:
return None
@@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True):
cst_time = get_cst_now_str()
if is_active:
- cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
+ _deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
@@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id):
def get_announcements(limit=50, offset=0):
"""获取公告列表(管理员用)"""
+ safe_limit = _normalize_limit(limit, 50)
+ safe_offset = _normalize_offset(offset, 0)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0):
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
""",
- (limit, offset),
+ (safe_limit, safe_offset),
)
return [dict(row) for row in cursor.fetchall()]
@@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active):
cst_time = get_cst_now_str()
if is_active:
- cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,))
+ _deactivate_all_active_announcements(cursor, cst_time)
cursor.execute(
"""
UPDATE announcements
@@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id):
"""用户永久关闭某条公告(幂等)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at)
VALUES (?, ?, ?)
""",
- (user_id, announcement_id, cst_time),
+ (user_id, announcement_id, get_cst_now_str()),
)
conn.commit()
return cursor.rowcount >= 0
diff --git a/db/email.py b/db/email.py
index 2716e50..b0efc4f 100644
--- a/db/email.py
+++ b/db/email.py
@@ -5,6 +5,27 @@ from __future__ import annotations
import db_pool
+def _to_bool_with_default(value, default: bool = True) -> bool:
+ if value is None:
+ return default
+ try:
+ return bool(int(value))
+ except Exception:
+ try:
+ return bool(value)
+ except Exception:
+ return default
+
+
+def _normalize_notify_enabled(enabled) -> int:
+ if isinstance(enabled, bool):
+ return 1 if enabled else 0
+ try:
+ return 1 if int(enabled) else 0
+ except Exception:
+ return 1
+
+
def get_user_by_email(email):
"""根据邮箱获取用户"""
with db_pool.get_db() as conn:
@@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False):
SET email = ?, email_verified = ?
WHERE id = ?
""",
- (email, int(verified), user_id),
+ (email, 1 if verified else 0, user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled):
SET email_notify_enabled = ?
WHERE id = ?
""",
- (int(enabled), user_id),
+ (_normalize_notify_enabled(enabled), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -57,6 +78,6 @@ def get_user_email_notify(user_id):
row = cursor.fetchone()
if row is None:
return True
- return bool(row[0]) if row[0] is not None else True
+ return _to_bool_with_default(row[0], default=True)
except Exception:
return True
diff --git a/db/feedbacks.py b/db/feedbacks.py
index 828f96a..2245ff7 100644
--- a/db/feedbacks.py
+++ b/db/feedbacks.py
@@ -2,32 +2,73 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
-from datetime import datetime
-
-import pytz
-
import db_pool
-from db.utils import escape_html
+from db.utils import escape_html, get_cst_now_str
+
+
+def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+ parsed = max(minimum, parsed)
+ parsed = min(maximum, parsed)
+ return parsed
+
+
+def _normalize_offset(value, default: int = 0) -> int:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+ return max(0, parsed)
+
+
+def _safe_text(value) -> str:
+ if value is None:
+ return ""
+ text = str(value)
+ return escape_html(text) if text else ""
+
+
+def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]:
+ where_clauses = ["1=1"]
+ params = []
+
+ if status_filter:
+ where_clauses.append("status = ?")
+ params.append(status_filter)
+
+ return " AND ".join(where_clauses), params
+
+
+def _normalize_feedback_stats_row(row) -> dict:
+ row_dict = dict(row) if row else {}
+ return {
+ "total": int(row_dict.get("total") or 0),
+ "pending": int(row_dict.get("pending") or 0),
+ "replied": int(row_dict.get("replied") or 0),
+ "closed": int(row_dict.get("closed") or 0),
+ }
def create_bug_feedback(user_id, username, title, description, contact=""):
"""创建Bug反馈(带XSS防护)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_tz = pytz.timezone("Asia/Shanghai")
- cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
-
- safe_title = escape_html(title) if title else ""
- safe_description = escape_html(description) if description else ""
- safe_contact = escape_html(contact) if contact else ""
- safe_username = escape_html(username) if username else ""
-
cursor.execute(
"""
INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at)
VALUES (?, ?, ?, ?, ?, ?)
""",
- (user_id, safe_username, safe_title, safe_description, safe_contact, cst_time),
+ (
+ user_id,
+ _safe_text(username),
+ _safe_text(title),
+ _safe_text(description),
+ _safe_text(contact),
+ get_cst_now_str(),
+ ),
)
conn.commit()
@@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""):
def get_bug_feedbacks(limit=100, offset=0, status_filter=None):
"""获取Bug反馈列表(管理员用)"""
+ safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000)
+ safe_offset = _normalize_offset(offset, 0)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
-
- sql = "SELECT * FROM bug_feedbacks WHERE 1=1"
- params = []
-
- if status_filter:
- sql += " AND status = ?"
- params.append(status_filter)
-
- sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
- params.extend([limit, offset])
-
- cursor.execute(sql, params)
+ where_sql, params = _build_feedback_filter_sql(status_filter=status_filter)
+ sql = f"""
+ SELECT * FROM bug_feedbacks
+ WHERE {where_sql}
+ ORDER BY created_at DESC
+ LIMIT ? OFFSET ?
+ """
+ cursor.execute(sql, params + [safe_limit, safe_offset])
return [dict(row) for row in cursor.fetchall()]
def get_user_feedbacks(user_id, limit=50):
"""获取用户自己的反馈列表"""
+ safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50):
ORDER BY created_at DESC
LIMIT ?
""",
- (user_id, limit),
+ (user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply):
"""管理员回复反馈(带XSS防护)"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_tz = pytz.timezone("Asia/Shanghai")
- cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
-
- safe_reply = escape_html(admin_reply) if admin_reply else ""
-
cursor.execute(
"""
UPDATE bug_feedbacks
SET admin_reply = ?, status = 'replied', replied_at = ?
WHERE id = ?
""",
- (safe_reply, cst_time, feedback_id),
+ (_safe_text(admin_reply), get_cst_now_str(), feedback_id),
)
conn.commit()
@@ -139,6 +175,4 @@ def get_feedback_stats():
FROM bug_feedbacks
"""
)
- row = cursor.fetchone()
- return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0}
-
+ return _normalize_feedback_stats_row(cursor.fetchone())
diff --git a/db/migrations.py b/db/migrations.py
index 7d9347a..c1e88db 100644
--- a/db/migrations.py
+++ b/db/migrations.py
@@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None:
conn.commit()
+def _table_exists(cursor, table_name: str) -> bool:
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),))
+ return cursor.fetchone() is not None
+
+
+def _get_table_columns(cursor, table_name: str) -> set[str]:
+ cursor.execute(f"PRAGMA table_info({table_name})")
+ return {col[1] for col in cursor.fetchall()}
+
+
+def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool:
+ if column_name in columns:
+ return False
+ cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}")
+ columns.add(column_name)
+ print(ok_message)
+ return True
+
+
+def _read_row_value(row, key: str, index: int):
+ if isinstance(row, sqlite3.Row):
+ return row[key]
+ return row[index]
+
+
+def _get_migration_steps():
+ return [
+ (1, _migrate_to_v1),
+ (2, _migrate_to_v2),
+ (3, _migrate_to_v3),
+ (4, _migrate_to_v4),
+ (5, _migrate_to_v5),
+ (6, _migrate_to_v6),
+ (7, _migrate_to_v7),
+ (8, _migrate_to_v8),
+ (9, _migrate_to_v9),
+ (10, _migrate_to_v10),
+ (11, _migrate_to_v11),
+ (12, _migrate_to_v12),
+ (13, _migrate_to_v13),
+ (14, _migrate_to_v14),
+ (15, _migrate_to_v15),
+ (16, _migrate_to_v16),
+ (17, _migrate_to_v17),
+ (18, _migrate_to_v18),
+ ]
+
+
def migrate_database(conn, target_version: int) -> None:
"""数据库迁移:按版本增量升级(向前兼容)。"""
cursor = conn.cursor()
cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),))
conn.commit()
+ target_version = int(target_version)
current_version = get_current_version(conn)
- if current_version < 1:
- _migrate_to_v1(conn)
- current_version = 1
- if current_version < 2:
- _migrate_to_v2(conn)
- current_version = 2
- if current_version < 3:
- _migrate_to_v3(conn)
- current_version = 3
- if current_version < 4:
- _migrate_to_v4(conn)
- current_version = 4
- if current_version < 5:
- _migrate_to_v5(conn)
- current_version = 5
- if current_version < 6:
- _migrate_to_v6(conn)
- current_version = 6
- if current_version < 7:
- _migrate_to_v7(conn)
- current_version = 7
- if current_version < 8:
- _migrate_to_v8(conn)
- current_version = 8
- if current_version < 9:
- _migrate_to_v9(conn)
- current_version = 9
- if current_version < 10:
- _migrate_to_v10(conn)
- current_version = 10
- if current_version < 11:
- _migrate_to_v11(conn)
- current_version = 11
- if current_version < 12:
- _migrate_to_v12(conn)
- current_version = 12
- if current_version < 13:
- _migrate_to_v13(conn)
- current_version = 13
- if current_version < 14:
- _migrate_to_v14(conn)
- current_version = 14
- if current_version < 15:
- _migrate_to_v15(conn)
- current_version = 15
- if current_version < 16:
- _migrate_to_v16(conn)
- current_version = 16
- if current_version < 17:
- _migrate_to_v17(conn)
- current_version = 17
- if current_version < 18:
- _migrate_to_v18(conn)
- current_version = 18
+ for version, migrate_fn in _get_migration_steps():
+ if version > target_version or current_version >= version:
+ continue
+ migrate_fn(conn)
+ current_version = version
- if current_version != int(target_version):
- set_current_version(conn, int(target_version))
+ if current_version != target_version:
+ set_current_version(conn, target_version)
def _migrate_to_v1(conn):
"""迁移到版本1 - 添加缺失字段"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(system_config)")
- columns = [col[1] for col in cursor.fetchall()]
+ system_columns = _get_table_columns(cursor, "system_config")
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "schedule_weekdays",
+ 'TEXT DEFAULT "1,2,3,4,5,6,7"',
+ ok_message=" [OK] 添加 schedule_weekdays 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "max_screenshot_concurrent",
+ "INTEGER DEFAULT 3",
+ ok_message=" [OK] 添加 max_screenshot_concurrent 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "max_concurrent_per_account",
+ "INTEGER DEFAULT 1",
+ ok_message=" [OK] 添加 max_concurrent_per_account 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "auto_approve_enabled",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 auto_approve_enabled 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "auto_approve_hourly_limit",
+ "INTEGER DEFAULT 10",
+ ok_message=" [OK] 添加 auto_approve_hourly_limit 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ "auto_approve_vip_days",
+ "INTEGER DEFAULT 7",
+ ok_message=" [OK] 添加 auto_approve_vip_days 字段",
+ )
- if "schedule_weekdays" not in columns:
- cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"')
- print(" [OK] 添加 schedule_weekdays 字段")
-
- if "max_screenshot_concurrent" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3")
- print(" [OK] 添加 max_screenshot_concurrent 字段")
- if "max_concurrent_per_account" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1")
- print(" [OK] 添加 max_concurrent_per_account 字段")
- if "auto_approve_enabled" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0")
- print(" [OK] 添加 auto_approve_enabled 字段")
- if "auto_approve_hourly_limit" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10")
- print(" [OK] 添加 auto_approve_hourly_limit 字段")
- if "auto_approve_vip_days" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7")
- print(" [OK] 添加 auto_approve_vip_days 字段")
-
- cursor.execute("PRAGMA table_info(task_logs)")
- columns = [col[1] for col in cursor.fetchall()]
- if "duration" not in columns:
- cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER")
- print(" [OK] 添加 duration 字段到 task_logs")
+ task_log_columns = _get_table_columns(cursor, "task_logs")
+ _add_column_if_missing(
+ cursor,
+ "task_logs",
+ task_log_columns,
+ "duration",
+ "INTEGER",
+ ok_message=" [OK] 添加 duration 字段到 task_logs",
+ )
conn.commit()
@@ -135,24 +166,39 @@ def _migrate_to_v2(conn):
"""迁移到版本2 - 添加代理配置字段"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(system_config)")
- columns = [col[1] for col in cursor.fetchall()]
-
- if "proxy_enabled" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0")
- print(" [OK] 添加 proxy_enabled 字段")
-
- if "proxy_api_url" not in columns:
- cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""')
- print(" [OK] 添加 proxy_api_url 字段")
-
- if "proxy_expire_minutes" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3")
- print(" [OK] 添加 proxy_expire_minutes 字段")
-
- if "enable_screenshot" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1")
- print(" [OK] 添加 enable_screenshot 字段")
+ columns = _get_table_columns(cursor, "system_config")
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "proxy_enabled",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 proxy_enabled 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "proxy_api_url",
+ 'TEXT DEFAULT ""',
+ ok_message=" [OK] 添加 proxy_api_url 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "proxy_expire_minutes",
+ "INTEGER DEFAULT 3",
+ ok_message=" [OK] 添加 proxy_expire_minutes 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "enable_screenshot",
+ "INTEGER DEFAULT 1",
+ ok_message=" [OK] 添加 enable_screenshot 字段",
+ )
conn.commit()
@@ -161,20 +207,31 @@ def _migrate_to_v3(conn):
"""迁移到版本3 - 添加账号状态和登录失败计数字段"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(accounts)")
- columns = [col[1] for col in cursor.fetchall()]
-
- if "status" not in columns:
- cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"')
- print(" [OK] 添加 accounts.status 字段 (账号状态)")
-
- if "login_fail_count" not in columns:
- cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0")
- print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)")
-
- if "last_login_error" not in columns:
- cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT")
- print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)")
+ columns = _get_table_columns(cursor, "accounts")
+ _add_column_if_missing(
+ cursor,
+ "accounts",
+ columns,
+ "status",
+ 'TEXT DEFAULT "active"',
+ ok_message=" [OK] 添加 accounts.status 字段 (账号状态)",
+ )
+ _add_column_if_missing(
+ cursor,
+ "accounts",
+ columns,
+ "login_fail_count",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)",
+ )
+ _add_column_if_missing(
+ cursor,
+ "accounts",
+ columns,
+ "last_login_error",
+ "TEXT",
+ ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)",
+ )
conn.commit()
@@ -183,12 +240,15 @@ def _migrate_to_v4(conn):
"""迁移到版本4 - 添加任务来源字段"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(task_logs)")
- columns = [col[1] for col in cursor.fetchall()]
-
- if "source" not in columns:
- cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"')
- print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)")
+ columns = _get_table_columns(cursor, "task_logs")
+ _add_column_if_missing(
+ cursor,
+ "task_logs",
+ columns,
+ "source",
+ 'TEXT DEFAULT "manual"',
+ ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)",
+ )
conn.commit()
@@ -300,20 +360,17 @@ def _migrate_to_v6(conn):
def _migrate_to_v7(conn):
"""迁移到版本7 - 统一存储北京时间(将历史UTC时间字段整体+8小时)"""
cursor = conn.cursor()
-
- def table_exists(table_name: str) -> bool:
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
- return cursor.fetchone() is not None
-
- def column_exists(table_name: str, column_name: str) -> bool:
- cursor.execute(f"PRAGMA table_info({table_name})")
- return any(row[1] == column_name for row in cursor.fetchall())
+ columns_cache: dict[str, set[str]] = {}
def shift_utc_to_cst(table_name: str, column_name: str) -> None:
- if not table_exists(table_name):
+ if not _table_exists(cursor, table_name):
return
- if not column_exists(table_name, column_name):
+
+ if table_name not in columns_cache:
+ columns_cache[table_name] = _get_table_columns(cursor, table_name)
+ if column_name not in columns_cache[table_name]:
return
+
cursor.execute(
f"""
UPDATE {table_name}
@@ -329,10 +386,6 @@ def _migrate_to_v7(conn):
("accounts", "created_at"),
("password_reset_requests", "created_at"),
("password_reset_requests", "processed_at"),
- ]:
- shift_utc_to_cst(table, col)
-
- for table, col in [
("smtp_configs", "created_at"),
("smtp_configs", "updated_at"),
("smtp_configs", "last_success_at"),
@@ -340,10 +393,6 @@ def _migrate_to_v7(conn):
("email_tokens", "created_at"),
("email_logs", "created_at"),
("email_stats", "last_updated"),
- ]:
- shift_utc_to_cst(table, col)
-
- for table, col in [
("task_checkpoints", "created_at"),
("task_checkpoints", "updated_at"),
("task_checkpoints", "completed_at"),
@@ -359,15 +408,23 @@ def _migrate_to_v8(conn):
cursor = conn.cursor()
# 1) 增量字段:random_delay(旧库可能不存在)
- cursor.execute("PRAGMA table_info(user_schedules)")
- columns = [col[1] for col in cursor.fetchall()]
- if "random_delay" not in columns:
- cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0")
- print(" [OK] 添加 user_schedules.random_delay 字段")
-
- if "next_run_at" not in columns:
- cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP")
- print(" [OK] 添加 user_schedules.next_run_at 字段")
+ columns = _get_table_columns(cursor, "user_schedules")
+ _add_column_if_missing(
+ cursor,
+ "user_schedules",
+ columns,
+ "random_delay",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 user_schedules.random_delay 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "user_schedules",
+ columns,
+ "next_run_at",
+ "TIMESTAMP",
+ ok_message=" [OK] 添加 user_schedules.next_run_at 字段",
+ )
cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)")
conn.commit()
@@ -392,12 +449,12 @@ def _migrate_to_v8(conn):
fixed = 0
for row in rows:
try:
- schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0]
- schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1]
- weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2]
- random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3]
- last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4]
- next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5]
+ schedule_id = _read_row_value(row, "id", 0)
+ schedule_time = _read_row_value(row, "schedule_time", 1)
+ weekdays = _read_row_value(row, "weekdays", 2)
+ random_delay = _read_row_value(row, "random_delay", 3)
+ last_run_at = _read_row_value(row, "last_run_at", 4)
+ next_run_at = _read_row_value(row, "next_run_at", 5)
except Exception:
continue
@@ -430,27 +487,46 @@ def _migrate_to_v9(conn):
"""迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE)"""
cursor = conn.cursor()
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
- if not cursor.fetchone():
+ if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
- cursor.execute("PRAGMA table_info(email_settings)")
- columns = [col[1] for col in cursor.fetchall()]
+ columns = _get_table_columns(cursor, "email_settings")
changed = False
- if "register_verify_enabled" not in columns:
- cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0")
- print(" [OK] 添加 email_settings.register_verify_enabled 字段")
- changed = True
- if "base_url" not in columns:
- cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''")
- print(" [OK] 添加 email_settings.base_url 字段")
- changed = True
- if "task_notify_enabled" not in columns:
- cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0")
- print(" [OK] 添加 email_settings.task_notify_enabled 字段")
- changed = True
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "email_settings",
+ columns,
+ "register_verify_enabled",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段",
+ )
+ or changed
+ )
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "email_settings",
+ columns,
+ "base_url",
+ "TEXT DEFAULT ''",
+ ok_message=" [OK] 添加 email_settings.base_url 字段",
+ )
+ or changed
+ )
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "email_settings",
+ columns,
+ "task_notify_enabled",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段",
+ )
+ or changed
+ )
if changed:
conn.commit()
@@ -459,18 +535,31 @@ def _migrate_to_v9(conn):
def _migrate_to_v10(conn):
"""迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE)"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(users)")
- columns = [col[1] for col in cursor.fetchall()]
+ columns = _get_table_columns(cursor, "users")
changed = False
- if "email_verified" not in columns:
- cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0")
- print(" [OK] 添加 users.email_verified 字段")
- changed = True
- if "email_notify_enabled" not in columns:
- cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1")
- print(" [OK] 添加 users.email_notify_enabled 字段")
- changed = True
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "users",
+ columns,
+ "email_verified",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 users.email_verified 字段",
+ )
+ or changed
+ )
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "users",
+ columns,
+ "email_notify_enabled",
+ "INTEGER DEFAULT 1",
+ ok_message=" [OK] 添加 users.email_notify_enabled 字段",
+ )
+ or changed
+ )
if changed:
conn.commit()
@@ -657,19 +746,24 @@ def _migrate_to_v15(conn):
"""迁移到版本15 - 邮件设置:新设备登录提醒全局开关"""
cursor = conn.cursor()
- cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'")
- if not cursor.fetchone():
+ if not _table_exists(cursor, "email_settings"):
# 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移
return
- cursor.execute("PRAGMA table_info(email_settings)")
- columns = [col[1] for col in cursor.fetchall()]
+ columns = _get_table_columns(cursor, "email_settings")
changed = False
- if "login_alert_enabled" not in columns:
- cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1")
- print(" [OK] 添加 email_settings.login_alert_enabled 字段")
- changed = True
+ changed = (
+ _add_column_if_missing(
+ cursor,
+ "email_settings",
+ columns,
+ "login_alert_enabled",
+ "INTEGER DEFAULT 1",
+ ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段",
+ )
+ or changed
+ )
try:
cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL")
@@ -686,22 +780,24 @@ def _migrate_to_v15(conn):
def _migrate_to_v16(conn):
"""迁移到版本16 - 公告支持图片字段"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(announcements)")
- columns = [col[1] for col in cursor.fetchall()]
+ columns = _get_table_columns(cursor, "announcements")
- if "image_url" not in columns:
- cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT")
+ if _add_column_if_missing(
+ cursor,
+ "announcements",
+ columns,
+ "image_url",
+ "TEXT",
+ ok_message=" [OK] 添加 announcements.image_url 字段",
+ ):
conn.commit()
- print(" [OK] 添加 announcements.image_url 字段")
def _migrate_to_v17(conn):
"""迁移到版本17 - 金山文档上传配置与用户开关"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(system_config)")
- columns = [col[1] for col in cursor.fetchall()]
-
+ system_columns = _get_table_columns(cursor, "system_config")
system_fields = [
("kdocs_enabled", "INTEGER DEFAULT 0"),
("kdocs_doc_url", "TEXT DEFAULT ''"),
@@ -714,21 +810,29 @@ def _migrate_to_v17(conn):
("kdocs_admin_notify_email", "TEXT DEFAULT ''"),
]
for field, ddl in system_fields:
- if field not in columns:
- cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}")
- print(f" [OK] 添加 system_config.{field} 字段")
-
- cursor.execute("PRAGMA table_info(users)")
- columns = [col[1] for col in cursor.fetchall()]
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ system_columns,
+ field,
+ ddl,
+ ok_message=f" [OK] 添加 system_config.{field} 字段",
+ )
+ user_columns = _get_table_columns(cursor, "users")
user_fields = [
("kdocs_unit", "TEXT DEFAULT ''"),
("kdocs_auto_upload", "INTEGER DEFAULT 0"),
]
for field, ddl in user_fields:
- if field not in columns:
- cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}")
- print(f" [OK] 添加 users.{field} 字段")
+ _add_column_if_missing(
+ cursor,
+ "users",
+ user_columns,
+ field,
+ ddl,
+ ok_message=f" [OK] 添加 users.{field} 字段",
+ )
conn.commit()
@@ -737,15 +841,22 @@ def _migrate_to_v18(conn):
"""迁移到版本18 - 金山文档上传:有效行范围配置"""
cursor = conn.cursor()
- cursor.execute("PRAGMA table_info(system_config)")
- columns = [col[1] for col in cursor.fetchall()]
-
- if "kdocs_row_start" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0")
- print(" [OK] 添加 system_config.kdocs_row_start 字段")
-
- if "kdocs_row_end" not in columns:
- cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0")
- print(" [OK] 添加 system_config.kdocs_row_end 字段")
+ columns = _get_table_columns(cursor, "system_config")
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "kdocs_row_start",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 system_config.kdocs_row_start 字段",
+ )
+ _add_column_if_missing(
+ cursor,
+ "system_config",
+ columns,
+ "kdocs_row_end",
+ "INTEGER DEFAULT 0",
+ ok_message=" [OK] 添加 system_config.kdocs_row_end 字段",
+ )
conn.commit()
diff --git a/db/schedules.py b/db/schedules.py
index 4294169..d769804 100644
--- a/db/schedules.py
+++ b/db/schedules.py
@@ -2,12 +2,93 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
-from datetime import datetime
+import json
+from datetime import datetime, timedelta
import db_pool
from services.schedule_utils import compute_next_run_at, format_cst
from services.time_utils import get_beijing_now
+_SCHEDULE_DEFAULT_TIME = "08:00"
+_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5"
+
+_ALLOWED_SCHEDULE_UPDATE_FIELDS = (
+ "name",
+ "enabled",
+ "schedule_time",
+ "weekdays",
+ "browse_type",
+ "enable_screenshot",
+ "random_delay",
+ "account_ids",
+)
+
+_ALLOWED_EXEC_LOG_UPDATE_FIELDS = (
+ "total_accounts",
+ "success_accounts",
+ "failed_accounts",
+ "total_items",
+ "total_attachments",
+ "total_screenshots",
+ "duration_seconds",
+ "status",
+ "error_message",
+)
+
+
+def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int:
+ try:
+ parsed = int(limit)
+ except Exception:
+ parsed = default
+ if parsed < minimum:
+ return minimum
+ return parsed
+
+
+def _to_int(value, default: int = 0) -> int:
+ try:
+ return int(value)
+ except Exception:
+ return default
+
+
+def _format_optional_datetime(dt: datetime | None) -> str | None:
+ if dt is None:
+ return None
+ return format_cst(dt)
+
+
+def _serialize_account_ids(account_ids) -> str:
+ return json.dumps(account_ids) if account_ids else "[]"
+
+
+def _compute_schedule_next_run_str(
+ *,
+ now_dt,
+ schedule_time,
+ weekdays,
+ random_delay,
+ last_run_at,
+) -> str:
+ next_dt = compute_next_run_at(
+ now=now_dt,
+ schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME),
+ weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS),
+ random_delay=_to_int(random_delay, 0),
+ last_run_at=str(last_run_at or "") if last_run_at else None,
+ )
+ return format_cst(next_dt)
+
+
+def _map_schedule_log_row(row) -> dict:
+ log = dict(row)
+ log["created_at"] = log.get("execute_time")
+ log["success_count"] = log.get("success_accounts", 0)
+ log["failed_count"] = log.get("failed_accounts", 0)
+ log["duration"] = log.get("duration_seconds", 0)
+ return log
+
def get_user_schedules(user_id):
"""获取用户的所有定时任务"""
@@ -44,14 +125,10 @@ def create_user_schedule(
account_ids=None,
):
"""创建用户定时任务"""
- import json
-
with db_pool.get_db() as conn:
cursor = conn.cursor()
cst_time = format_cst(get_beijing_now())
- account_ids_str = json.dumps(account_ids) if account_ids else "[]"
-
cursor.execute(
"""
INSERT INTO user_schedules (
@@ -66,8 +143,8 @@ def create_user_schedule(
weekdays,
browse_type,
enable_screenshot,
- int(random_delay or 0),
- account_ids_str,
+ _to_int(random_delay, 0),
+ _serialize_account_ids(account_ids),
cst_time,
cst_time,
),
@@ -79,28 +156,11 @@ def create_user_schedule(
def update_user_schedule(schedule_id, **kwargs):
"""更新用户定时任务"""
- import json
-
with db_pool.get_db() as conn:
cursor = conn.cursor()
now_dt = get_beijing_now()
now_str = format_cst(now_dt)
- updates = []
- params = []
-
- allowed_fields = [
- "name",
- "enabled",
- "schedule_time",
- "weekdays",
- "browse_type",
- "enable_screenshot",
- "random_delay",
- "account_ids",
- ]
-
- # 读取旧值,用于决定是否需要重算 next_run_at
cursor.execute(
"""
SELECT enabled, schedule_time, weekdays, random_delay, last_run_at
@@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs):
current = cursor.fetchone()
if not current:
return False
- current_enabled = int(current[0] or 0)
+
+ current_enabled = _to_int(current[0], 0)
current_time = current[1]
current_weekdays = current[2]
- current_random_delay = int(current[3] or 0)
+ current_random_delay = _to_int(current[3], 0)
current_last_run_at = current[4]
will_enabled = current_enabled
@@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs):
next_weekdays = current_weekdays
next_random_delay = current_random_delay
- for field in allowed_fields:
- if field in kwargs:
- value = kwargs[field]
- if field == "account_ids" and isinstance(value, list):
- value = json.dumps(value)
- if field == "enabled":
- will_enabled = 1 if value else 0
- if field == "schedule_time":
- next_time = value
- if field == "weekdays":
- next_weekdays = value
- if field == "random_delay":
- next_random_delay = int(value or 0)
- updates.append(f"{field} = ?")
- params.append(value)
+ updates = []
+ params = []
+
+ for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS:
+ if field not in kwargs:
+ continue
+
+ value = kwargs[field]
+ if field == "account_ids" and isinstance(value, list):
+ value = json.dumps(value)
+
+ if field == "enabled":
+ will_enabled = 1 if value else 0
+ if field == "schedule_time":
+ next_time = value
+ if field == "weekdays":
+ next_weekdays = value
+ if field == "random_delay":
+ next_random_delay = int(value or 0)
+
+ updates.append(f"{field} = ?")
+ params.append(value)
if not updates:
return False
@@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs):
updates.append("updated_at = ?")
params.append(now_str)
- # 关键字段变更后重算 next_run_at,确保索引驱动不会跑偏
- #
- # 需求:当用户修改“执行时间/执行日期/随机±15分钟”后,即使今天已经执行过,也允许按新配置在今天再次触发。
- # 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。
- config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"])
+ config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay"))
enabled_toggled = "enabled" in kwargs
should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1)
+
if should_recompute_next:
- next_dt = compute_next_run_at(
- now=now_dt,
- schedule_time=str(next_time or "08:00"),
- weekdays=str(next_weekdays or "1,2,3,4,5"),
- random_delay=int(next_random_delay or 0),
- last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None),
+ next_run_at = _compute_schedule_next_run_str(
+ now_dt=now_dt,
+ schedule_time=next_time,
+ weekdays=next_weekdays,
+ random_delay=next_random_delay,
+ last_run_at=None if config_changed else current_last_run_at,
)
updates.append("next_run_at = ?")
- params.append(format_cst(next_dt))
+ params.append(next_run_at)
- # 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致)
if enabled_toggled and will_enabled == 0:
updates.append("next_run_at = ?")
params.append(None)
- params.append(schedule_id)
+ params.append(schedule_id)
sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?"
cursor.execute(sql, params)
conn.commit()
@@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled):
)
row = cursor.fetchone()
if row:
- schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = (
- row[0],
- row[1],
- row[2],
- row[3],
- row[4],
- )
-
+ schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row
existing_next_run_at = str(existing_next_run_at or "").strip() or None
- # 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用,
- # 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。
+
if existing_next_run_at and existing_next_run_at > now_str:
next_run_at = existing_next_run_at
else:
- next_dt = compute_next_run_at(
- now=now_dt,
- schedule_time=str(schedule_time or "08:00"),
- weekdays=str(weekdays or "1,2,3,4,5"),
- random_delay=int(random_delay or 0),
- last_run_at=str(last_run_at or "") if last_run_at else None,
+ next_run_at = _compute_schedule_next_run_str(
+ now_dt=now_dt,
+ schedule_time=schedule_time,
+ weekdays=weekdays,
+ random_delay=random_delay,
+ last_run_at=last_run_at,
)
- next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id):
row = cursor.fetchone()
if not row:
return False
- schedule_time, weekdays, random_delay = row[0], row[1], row[2]
- next_dt = compute_next_run_at(
- now=now_dt,
- schedule_time=str(schedule_time or "08:00"),
- weekdays=str(weekdays or "1,2,3,4,5"),
- random_delay=int(random_delay or 0),
+ schedule_time, weekdays, random_delay = row
+ next_run_at = _compute_schedule_next_run_str(
+ now_dt=now_dt,
+ schedule_time=schedule_time,
+ weekdays=weekdays,
+ random_delay=random_delay,
last_run_at=now_str,
)
- next_run_at = format_cst(next_dt)
cursor.execute(
"""
@@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool:
SET next_run_at = ?, updated_at = ?
WHERE id = ?
""",
- (str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)),
+ (
+ str(next_run_at or "").strip() or None,
+ format_cst(get_beijing_now()),
+ int(schedule_id),
+ ),
)
conn.commit()
return cursor.rowcount > 0
@@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool:
if not row:
return False
- schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3]
- next_dt = compute_next_run_at(
- now=now_dt,
- schedule_time=str(schedule_time or "08:00"),
- weekdays=str(weekdays or "1,2,3,4,5"),
- random_delay=int(random_delay or 0),
- last_run_at=str(last_run_at or "") if last_run_at else None,
+ schedule_time, weekdays, random_delay, last_run_at = row
+ next_run_at = _compute_schedule_next_run_str(
+ now_dt=now_dt,
+ schedule_time=schedule_time,
+ weekdays=weekdays,
+ random_delay=random_delay,
+ last_run_at=last_run_at,
)
- return update_schedule_next_run(int(schedule_id), format_cst(next_dt))
+ return update_schedule_next_run(int(schedule_id), next_run_at)
def get_due_user_schedules(now_cst: str, limit: int = 50):
@@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
if not now_cst:
now_cst = format_cst(get_beijing_now())
+ safe_limit = _normalize_limit(limit, 50, minimum=1)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50):
ORDER BY us.next_run_at ASC
LIMIT ?
""",
- (now_cst, int(limit)),
+ (now_cst, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name):
"""创建定时任务执行日志"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- execute_time = format_cst(get_beijing_now())
-
cursor.execute(
"""
INSERT INTO schedule_execution_logs (
schedule_id, user_id, schedule_name, execute_time, status
) VALUES (?, ?, ?, ?, 'running')
""",
- (schedule_id, user_id, schedule_name, execute_time),
+ (schedule_id, user_id, schedule_name, format_cst(get_beijing_now())),
)
conn.commit()
@@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs):
updates = []
params = []
- allowed_fields = [
- "total_accounts",
- "success_accounts",
- "failed_accounts",
- "total_items",
- "total_attachments",
- "total_screenshots",
- "duration_seconds",
- "status",
- "error_message",
- ]
-
- for field in allowed_fields:
- if field in kwargs:
- updates.append(f"{field} = ?")
- params.append(kwargs[field])
+ for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS:
+ if field not in kwargs:
+ continue
+ updates.append(f"{field} = ?")
+ params.append(kwargs[field])
if not updates:
return False
@@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs):
def get_schedule_execution_logs(schedule_id, limit=10):
"""获取定时任务执行日志"""
try:
+ safe_limit = _normalize_limit(limit, 10, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10):
ORDER BY execute_time DESC
LIMIT ?
""",
- (schedule_id, limit),
+ (schedule_id, safe_limit),
)
logs = []
- rows = cursor.fetchall()
-
- for row in rows:
+ for row in cursor.fetchall():
try:
- log = dict(row)
- log["created_at"] = log.get("execute_time")
- log["success_count"] = log.get("success_accounts", 0)
- log["failed_count"] = log.get("failed_accounts", 0)
- log["duration"] = log.get("duration_seconds", 0)
- logs.append(log)
+ logs.append(_map_schedule_log_row(row))
except Exception as e:
print(f"[数据库] 处理日志行时出错: {e}")
continue
-
return logs
except Exception as e:
print(f"[数据库] 查询定时任务日志时出错: {e}")
@@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10):
def get_user_all_schedule_logs(user_id, limit=50):
"""获取用户所有定时任务的执行日志"""
+ safe_limit = _normalize_limit(limit, 50, minimum=1)
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
@@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50):
ORDER BY execute_time DESC
LIMIT ?
""",
- (user_id, limit),
+ (user_id, safe_limit),
)
return [dict(row) for row in cursor.fetchall()]
@@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id):
def clean_old_schedule_logs(days=30):
"""清理指定天数前的定时任务执行日志"""
+ safe_days = _to_int(days, 30)
+ if safe_days < 0:
+ safe_days = 0
+
+ cutoff_dt = get_beijing_now() - timedelta(days=safe_days)
+ cutoff_str = format_cst(cutoff_dt)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute(
"""
DELETE FROM schedule_execution_logs
- WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days')
+ WHERE execute_time < ?
""",
- (days,),
+ (cutoff_str,),
)
conn.commit()
return cursor.rowcount
diff --git a/db/schema.py b/db/schema.py
index 59108a7..14ec136 100644
--- a/db/schema.py
+++ b/db/schema.py
@@ -362,6 +362,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)")
@@ -391,6 +393,8 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)")
@@ -409,6 +413,9 @@ def ensure_schema(conn) -> None:
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)")
# 初始化VIP配置(幂等)
try:
diff --git a/db/security.py b/db/security.py
index 79ad0f3..677627f 100644
--- a/db/security.py
+++ b/db/security.py
@@ -3,13 +3,82 @@
from __future__ import annotations
from datetime import timedelta
-from typing import Any, Optional
-from typing import Dict
+from typing import Any, Dict, Optional
import db_pool
from db.utils import get_cst_now, get_cst_now_str
+_THREAT_EVENT_SELECT_COLUMNS = """
+ id,
+ threat_type,
+ score,
+ rule,
+ field_name,
+ matched,
+ value_preview,
+ ip,
+ user_id,
+ request_method,
+ request_path,
+ user_agent,
+ created_at
+"""
+
+
+def _normalize_page(page: int) -> int:
+ try:
+ page_i = int(page)
+ except Exception:
+ page_i = 1
+ return max(1, page_i)
+
+
+def _normalize_per_page(per_page: int, default: int = 20) -> int:
+ try:
+ value = int(per_page)
+ except Exception:
+ value = default
+ return max(1, min(200, value))
+
+
+def _normalize_limit(limit: int, default: int = 50) -> int:
+ try:
+ value = int(limit)
+ except Exception:
+ value = default
+ return max(1, min(200, value))
+
+
+def _row_value(row, key: str, index: int = 0, default=None):
+ if row is None:
+ return default
+ try:
+ return row[key]
+ except Exception:
+ try:
+ return row[index]
+ except Exception:
+ return default
+
+
+def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]:
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ f"""
+ SELECT
+ {_THREAT_EVENT_SELECT_COLUMNS}
+ FROM threat_events
+ WHERE {where_clause}
+ ORDER BY created_at DESC, id DESC
+ LIMIT ?
+ """,
+ tuple(params) + (limit_i,),
+ )
+ return [dict(r) for r in cursor.fetchall()]
+
+
def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]:
"""记录登录环境信息,返回是否新设备/新IP。"""
user_id = int(user_id)
@@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?, last_ip = ?
WHERE id = ?
""",
- (now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]),
+ (now_str, ip_text, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict
SET last_seen = ?
WHERE id = ?
""",
- (now_str, row["id"] if isinstance(row, dict) else row[0]),
+ (now_str, _row_value(row, "id", 0)),
)
else:
cursor.execute(
@@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis
def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict:
"""分页获取威胁事件。"""
- try:
- page_i = max(1, int(page))
- except Exception:
- page_i = 1
- try:
- per_page_i = int(per_page)
- except Exception:
- per_page_i = 20
- per_page_i = max(1, min(200, per_page_i))
+ page_i = _normalize_page(page)
+ per_page_i = _normalize_per_page(per_page, default=20)
where_sql, params = _build_threat_events_where_clause(filters)
offset = (page_i - 1) * per_page_i
@@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N
cursor.execute(
f"""
SELECT
- id,
- threat_type,
- score,
- rule,
- field_name,
- matched,
- value_preview,
- ip,
- user_id,
- request_method,
- request_path,
- user_agent,
- created_at
+ {_THREAT_EVENT_SELECT_COLUMNS}
FROM threat_events
{where_sql}
ORDER BY created_at DESC, id DESC
@@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]:
ip_text = str(ip or "").strip()[:64]
if not ip_text:
return []
- try:
- limit_i = max(1, min(200, int(limit)))
- except Exception:
- limit_i = 50
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute(
- """
- SELECT
- id,
- threat_type,
- score,
- rule,
- field_name,
- matched,
- value_preview,
- ip,
- user_id,
- request_method,
- request_path,
- user_agent,
- created_at
- FROM threat_events
- WHERE ip = ?
- ORDER BY created_at DESC, id DESC
- LIMIT ?
- """,
- (ip_text, limit_i),
- )
- return [dict(r) for r in cursor.fetchall()]
+ limit_i = _normalize_limit(limit, default=50)
+ return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i)
def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]:
"""获取用户的威胁历史(最近limit条)。"""
if user_id is None:
return []
+
try:
user_id_int = int(user_id)
except Exception:
return []
- try:
- limit_i = max(1, min(200, int(limit)))
- except Exception:
- limit_i = 50
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute(
- """
- SELECT
- id,
- threat_type,
- score,
- rule,
- field_name,
- matched,
- value_preview,
- ip,
- user_id,
- request_method,
- request_path,
- user_agent,
- created_at
- FROM threat_events
- WHERE user_id = ?
- ORDER BY created_at DESC, id DESC
- LIMIT ?
- """,
- (user_id_int, limit_i),
- )
- return [dict(r) for r in cursor.fetchall()]
+ limit_i = _normalize_limit(limit, default=50)
+ return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i)
diff --git a/db/tasks.py b/db/tasks.py
index 8b761cf..f28e280 100644
--- a/db/tasks.py
+++ b/db/tasks.py
@@ -2,12 +2,135 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
-from datetime import datetime
-
-import pytz
+from datetime import datetime, timedelta
import db_pool
-from db.utils import sanitize_sql_like_pattern
+from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern
+
+_TASK_STATS_SELECT_SQL = """
+ SELECT
+ COUNT(*) as total_tasks,
+ SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
+ SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
+ SUM(total_items) as total_items,
+ SUM(total_attachments) as total_attachments
+ FROM task_logs
+"""
+
+_USER_RUN_STATS_SELECT_SQL = """
+ SELECT
+ SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
+ SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
+ SUM(total_items) as total_items,
+ SUM(total_attachments) as total_attachments
+ FROM task_logs
+"""
+
+
+def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]:
+ """将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。"""
+ try:
+ day_start = datetime.strptime(str(date_filter), "%Y-%m-%d")
+ except Exception:
+ return None, None
+
+ day_end = day_start + timedelta(days=1)
+ return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S")
+
+
+def _normalize_int(value, default: int, *, minimum: int | None = None) -> int:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+ if minimum is not None and parsed < minimum:
+ return minimum
+ return parsed
+
+
+def _stat_value(row, key: str) -> int:
+ try:
+ value = row[key] if row else 0
+ except Exception:
+ value = 0
+ return int(value or 0)
+
+
+def _build_task_logs_where_sql(
+ *,
+ date_filter=None,
+ status_filter=None,
+ source_filter=None,
+ user_id_filter=None,
+ account_filter=None,
+) -> tuple[str, list]:
+ where_clauses = ["1=1"]
+ params = []
+
+ if date_filter:
+ day_start, day_end = _build_day_bounds(date_filter)
+ if day_start and day_end:
+ where_clauses.append("tl.created_at >= ? AND tl.created_at < ?")
+ params.extend([day_start, day_end])
+ else:
+ where_clauses.append("date(tl.created_at) = ?")
+ params.append(date_filter)
+
+ if status_filter:
+ where_clauses.append("tl.status = ?")
+ params.append(status_filter)
+
+ if source_filter:
+ source_filter = str(source_filter or "").strip()
+ if source_filter == "user_scheduled":
+ where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
+ params.append("user_scheduled:%")
+ elif source_filter.endswith("*"):
+ prefix = source_filter[:-1]
+ safe_prefix = sanitize_sql_like_pattern(prefix)
+ where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
+ params.append(f"{safe_prefix}%")
+ else:
+ where_clauses.append("tl.source = ?")
+ params.append(source_filter)
+
+ if user_id_filter:
+ where_clauses.append("tl.user_id = ?")
+ params.append(user_id_filter)
+
+ if account_filter:
+ safe_filter = sanitize_sql_like_pattern(account_filter)
+ where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
+ params.append(f"%{safe_filter}%")
+
+ return " AND ".join(where_clauses), params
+
+
+def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict:
+ sql = _TASK_STATS_SELECT_SQL
+ if where_clause:
+ sql = f"{sql}\nWHERE {where_clause}"
+ cursor.execute(sql, params)
+ row = cursor.fetchone()
+ return {
+ "total_tasks": _stat_value(row, "total_tasks"),
+ "success_tasks": _stat_value(row, "success_tasks"),
+ "failed_tasks": _stat_value(row, "failed_tasks"),
+ "total_items": _stat_value(row, "total_items"),
+ "total_attachments": _stat_value(row, "total_attachments"),
+ }
+
+
+def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict:
+ sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}"
+ cursor.execute(sql, params)
+ row = cursor.fetchone()
+ return {
+ "completed": _stat_value(row, "completed"),
+ "failed": _stat_value(row, "failed"),
+ "total_items": _stat_value(row, "total_items"),
+ "total_attachments": _stat_value(row, "total_attachments"),
+ }
def create_task_log(
@@ -25,8 +148,6 @@ def create_task_log(
"""创建任务日志记录"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_tz = pytz.timezone("Asia/Shanghai")
- cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S")
cursor.execute(
"""
@@ -45,7 +166,7 @@ def create_task_log(
total_attachments,
error_message,
duration,
- cst_time,
+ get_cst_now_str(),
source,
),
)
@@ -64,54 +185,27 @@ def get_task_logs(
account_filter=None,
):
"""获取任务日志列表(支持分页和多种筛选)"""
+ limit = _normalize_int(limit, 100, minimum=1)
+ offset = _normalize_int(offset, 0, minimum=0)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
- where_clauses = ["1=1"]
- params = []
-
- if date_filter:
- where_clauses.append("date(tl.created_at) = ?")
- params.append(date_filter)
-
- if status_filter:
- where_clauses.append("tl.status = ?")
- params.append(status_filter)
-
- if source_filter:
- source_filter = str(source_filter or "").strip()
- # 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值
- if source_filter == "user_scheduled":
- where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
- params.append("user_scheduled:%")
- elif source_filter.endswith("*"):
- prefix = source_filter[:-1]
- safe_prefix = sanitize_sql_like_pattern(prefix)
- where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'")
- params.append(f"{safe_prefix}%")
- else:
- where_clauses.append("tl.source = ?")
- params.append(source_filter)
-
- if user_id_filter:
- where_clauses.append("tl.user_id = ?")
- params.append(user_id_filter)
-
- if account_filter:
- safe_filter = sanitize_sql_like_pattern(account_filter)
- where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'")
- params.append(f"%{safe_filter}%")
-
- where_sql = " AND ".join(where_clauses)
+ where_sql, params = _build_task_logs_where_sql(
+ date_filter=date_filter,
+ status_filter=status_filter,
+ source_filter=source_filter,
+ user_id_filter=user_id_filter,
+ account_filter=account_filter,
+ )
count_sql = f"""
SELECT COUNT(*) as total
FROM task_logs tl
- LEFT JOIN users u ON tl.user_id = u.id
WHERE {where_sql}
"""
cursor.execute(count_sql, params)
- total = cursor.fetchone()["total"]
+ total = _stat_value(cursor.fetchone(), "total")
data_sql = f"""
SELECT
@@ -123,9 +217,10 @@ def get_task_logs(
ORDER BY tl.created_at DESC
LIMIT ? OFFSET ?
"""
- params.extend([limit, offset])
+ data_params = list(params)
+ data_params.extend([limit, offset])
- cursor.execute(data_sql, params)
+ cursor.execute(data_sql, data_params)
logs = [dict(row) for row in cursor.fetchall()]
return {"logs": logs, "total": total}
@@ -133,61 +228,39 @@ def get_task_logs(
def get_task_stats(date_filter=None):
"""获取任务统计信息"""
+ if date_filter is None:
+ date_filter = get_cst_now().strftime("%Y-%m-%d")
+
+ day_start, day_end = _build_day_bounds(date_filter)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_tz = pytz.timezone("Asia/Shanghai")
- if date_filter is None:
- date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
+ if day_start and day_end:
+ today_stats = _fetch_task_stats_row(
+ cursor,
+ where_clause="created_at >= ? AND created_at < ?",
+ params=(day_start, day_end),
+ )
+ else:
+ today_stats = _fetch_task_stats_row(
+ cursor,
+ where_clause="date(created_at) = ?",
+ params=(date_filter,),
+ )
- cursor.execute(
- """
- SELECT
- COUNT(*) as total_tasks,
- SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
- SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
- SUM(total_items) as total_items,
- SUM(total_attachments) as total_attachments
- FROM task_logs
- WHERE date(created_at) = ?
- """,
- (date_filter,),
- )
- today_stats = cursor.fetchone()
+ total_stats = _fetch_task_stats_row(cursor)
- cursor.execute(
- """
- SELECT
- COUNT(*) as total_tasks,
- SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks,
- SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks,
- SUM(total_items) as total_items,
- SUM(total_attachments) as total_attachments
- FROM task_logs
- """
- )
- total_stats = cursor.fetchone()
-
- return {
- "today": {
- "total_tasks": today_stats["total_tasks"] or 0,
- "success_tasks": today_stats["success_tasks"] or 0,
- "failed_tasks": today_stats["failed_tasks"] or 0,
- "total_items": today_stats["total_items"] or 0,
- "total_attachments": today_stats["total_attachments"] or 0,
- },
- "total": {
- "total_tasks": total_stats["total_tasks"] or 0,
- "success_tasks": total_stats["success_tasks"] or 0,
- "failed_tasks": total_stats["failed_tasks"] or 0,
- "total_items": total_stats["total_items"] or 0,
- "total_attachments": total_stats["total_attachments"] or 0,
- },
- }
+ return {"today": today_stats, "total": total_stats}
def delete_old_task_logs(days=30, batch_size=1000):
"""删除N天前的任务日志(分批删除,避免长时间锁表)"""
+ days = _normalize_int(days, 30, minimum=0)
+ batch_size = _normalize_int(batch_size, 1000, minimum=1)
+
+ cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
+
total_deleted = 0
while True:
with db_pool.get_db() as conn:
@@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000):
DELETE FROM task_logs
WHERE rowid IN (
SELECT rowid FROM task_logs
- WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days')
+ WHERE created_at < ?
LIMIT ?
)
""",
- (days, batch_size),
+ (cutoff, batch_size),
)
deleted = cursor.rowcount
conn.commit()
- if deleted == 0:
+ if deleted <= 0:
break
total_deleted += deleted
@@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000):
def get_user_run_stats(user_id, date_filter=None):
"""获取用户的运行统计信息"""
+ if date_filter is None:
+ date_filter = get_cst_now().strftime("%Y-%m-%d")
+
+ day_start, day_end = _build_day_bounds(date_filter)
+
with db_pool.get_db() as conn:
- cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor()
- if date_filter is None:
- date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d")
+ if day_start and day_end:
+ return _fetch_user_run_stats_row(
+ cursor,
+ where_clause="user_id = ? AND created_at >= ? AND created_at < ?",
+ params=(user_id, day_start, day_end),
+ )
- cursor.execute(
- """
- SELECT
- SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed,
- SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed,
- SUM(total_items) as total_items,
- SUM(total_attachments) as total_attachments
- FROM task_logs
- WHERE user_id = ? AND date(created_at) = ?
- """,
- (user_id, date_filter),
+ return _fetch_user_run_stats_row(
+ cursor,
+ where_clause="user_id = ? AND date(created_at) = ?",
+ params=(user_id, date_filter),
)
-
- stats = cursor.fetchone()
-
- return {
- "completed": stats["completed"] or 0,
- "failed": stats["failed"] or 0,
- "total_items": stats["total_items"] or 0,
- "total_attachments": stats["total_attachments"] or 0,
- }
diff --git a/db/users.py b/db/users.py
index 42423a5..c5cf5d1 100644
--- a/db/users.py
+++ b/db/users.py
@@ -16,8 +16,41 @@ from password_utils import (
verify_password_bcrypt,
verify_password_sha256,
)
+
logger = get_logger(__name__)
+_CST_TZ = pytz.timezone("Asia/Shanghai")
+_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59"
+
+
+def _row_to_dict(row):
+ return dict(row) if row else None
+
+
+def _get_user_by_field(field_name: str, field_value):
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,))
+ return _row_to_dict(cursor.fetchone())
+
+
+def _parse_cst_datetime(datetime_str: str | None):
+ if not datetime_str:
+ return None
+ try:
+ naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S")
+ return _CST_TZ.localize(naive_dt)
+ except Exception:
+ return None
+
+
+def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str:
+ if int(days) == 999999:
+ return _PERMANENT_VIP_EXPIRE
+ if base_dt is None:
+ base_dt = datetime.now(_CST_TZ)
+ return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S")
+
def get_vip_config():
"""获取VIP配置"""
@@ -32,13 +65,12 @@ def set_default_vip_days(days):
"""设置默认VIP天数"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_time = get_cst_now_str()
cursor.execute(
"""
INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at)
VALUES (1, ?, ?)
""",
- (days, cst_time),
+ (days, get_cst_now_str()),
)
conn.commit()
return True
@@ -47,14 +79,8 @@ def set_default_vip_days(days):
def set_user_vip(user_id, days):
"""设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久"""
with db_pool.get_db() as conn:
- cst_tz = pytz.timezone("Asia/Shanghai")
cursor = conn.cursor()
-
- if days == 999999:
- expire_time = "2099-12-31 23:59:59"
- else:
- expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
-
+ expire_time = _format_vip_expire(days)
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -63,29 +89,26 @@ def set_user_vip(user_id, days):
def extend_user_vip(user_id, days):
"""延长用户VIP时间"""
user = get_user_by_id(user_id)
- cst_tz = pytz.timezone("Asia/Shanghai")
-
if not user:
return False
+ current_expire = user.get("vip_expire_time")
+ now_dt = datetime.now(_CST_TZ)
+
+ if current_expire and current_expire != _PERMANENT_VIP_EXPIRE:
+ expire_time = _parse_cst_datetime(current_expire)
+ if expire_time is not None:
+ if expire_time < now_dt:
+ expire_time = now_dt
+ new_expire = _format_vip_expire(days, base_dt=expire_time)
+ else:
+ logger.warning("解析VIP过期时间失败,使用当前时间")
+ new_expire = _format_vip_expire(days, base_dt=now_dt)
+ else:
+ new_expire = _format_vip_expire(days, base_dt=now_dt)
+
with db_pool.get_db() as conn:
cursor = conn.cursor()
- current_expire = user.get("vip_expire_time")
-
- if current_expire and current_expire != "2099-12-31 23:59:59":
- try:
- expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S")
- expire_time = cst_tz.localize(expire_time_naive)
- now = datetime.now(cst_tz)
- if expire_time < now:
- expire_time = now
- new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
- except (ValueError, AttributeError) as e:
- logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间")
- new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
- else:
- new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S")
-
cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id))
conn.commit()
return cursor.rowcount > 0
@@ -105,45 +128,49 @@ def is_user_vip(user_id):
注意:数据库中存储的时间统一使用CST(Asia/Shanghai)时区
"""
- cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
-
- if not user or not user.get("vip_expire_time"):
+ if not user:
return False
- try:
- expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S")
- expire_time = cst_tz.localize(expire_time_naive)
- now = datetime.now(cst_tz)
- return now < expire_time
- except (ValueError, AttributeError) as e:
- logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}")
+ vip_expire_time = user.get("vip_expire_time")
+ if not vip_expire_time:
return False
+ expire_time = _parse_cst_datetime(vip_expire_time)
+ if expire_time is None:
+ logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间")
+ return False
+
+ return datetime.now(_CST_TZ) < expire_time
+
def get_user_vip_info(user_id):
"""获取用户VIP信息"""
- cst_tz = pytz.timezone("Asia/Shanghai")
user = get_user_by_id(user_id)
-
if not user:
return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""}
vip_expire_time = user.get("vip_expire_time")
+ username = user.get("username", "")
+
if not vip_expire_time:
- return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
+ return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
- try:
- expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S")
- expire_time = cst_tz.localize(expire_time_naive)
- now = datetime.now(cst_tz)
- is_vip = now < expire_time
- days_left = (expire_time - now).days if is_vip else 0
+ expire_time = _parse_cst_datetime(vip_expire_time)
+ if expire_time is None:
+ logger.warning("VIP信息获取错误: 无法解析过期时间")
+ return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username}
- return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)}
- except Exception as e:
- logger.warning(f"VIP信息获取错误: {e}")
- return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")}
+ now_dt = datetime.now(_CST_TZ)
+ is_vip = now_dt < expire_time
+ days_left = (expire_time - now_dt).days if is_vip else 0
+
+ return {
+ "username": username,
+ "is_vip": is_vip,
+ "expire_time": vip_expire_time,
+ "days_left": max(0, days_left),
+ }
# ==================== 用户相关 ====================
@@ -151,8 +178,6 @@ def get_user_vip_info(user_id):
def create_user(username, password, email=""):
"""创建新用户(默认直接通过,赠送默认VIP)"""
- cst_tz = pytz.timezone("Asia/Shanghai")
-
with db_pool.get_db() as conn:
cursor = conn.cursor()
password_hash = hash_password_bcrypt(password)
@@ -160,12 +185,8 @@ def create_user(username, password, email=""):
default_vip_days = get_vip_config()["default_vip_days"]
vip_expire_time = None
-
- if default_vip_days > 0:
- if default_vip_days == 999999:
- vip_expire_time = "2099-12-31 23:59:59"
- else:
- vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S")
+ if int(default_vip_days or 0) > 0:
+ vip_expire_time = _format_vip_expire(int(default_vip_days))
try:
cursor.execute(
@@ -210,28 +231,28 @@ def verify_user(username, password):
def get_user_by_id(user_id):
"""根据ID获取用户"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
- user = cursor.fetchone()
- return dict(user) if user else None
+ return _get_user_by_field("id", user_id)
def get_user_kdocs_settings(user_id):
"""获取用户的金山文档配置"""
- user = get_user_by_id(user_id)
- if not user:
- return None
- return {
- "kdocs_unit": user.get("kdocs_unit") or "",
- "kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0,
- }
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,))
+ row = cursor.fetchone()
+ if not row:
+ return None
+ return {
+ "kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""),
+ "kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0,
+ }
def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool:
"""更新用户的金山文档配置"""
updates = []
params = []
+
if kdocs_unit is not None:
updates.append("kdocs_unit = ?")
params.append(kdocs_unit)
@@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No
def get_user_by_username(username):
"""根据用户名获取用户"""
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
- user = cursor.fetchone()
- return dict(user) if user else None
+ return _get_user_by_field("username", username)
def get_all_users():
@@ -279,14 +296,13 @@ def approve_user(user_id):
"""审核通过用户"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- cst_time = get_cst_now_str()
cursor.execute(
"""
UPDATE users
SET status = 'approved', approved_at = ?
WHERE id = ?
""",
- (cst_time, user_id),
+ (get_cst_now_str(), user_id),
)
conn.commit()
return cursor.rowcount > 0
@@ -315,5 +331,5 @@ def get_user_stats(user_id):
with db_pool.get_db() as conn:
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,))
- account_count = cursor.fetchone()["count"]
- return {"account_count": account_count}
+ row = cursor.fetchone()
+ return {"account_count": int((row["count"] if row else 0) or 0)}
diff --git a/db_pool.py b/db_pool.py
index 3ce43fb..b5f4fe9 100755
--- a/db_pool.py
+++ b/db_pool.py
@@ -7,8 +7,12 @@
import sqlite3
import threading
-from queue import Queue, Empty
-import time
+from queue import Empty, Full, Queue
+
+from app_logger import get_logger
+
+
+logger = get_logger("database")
class ConnectionPool:
@@ -44,12 +48,55 @@ class ConnectionPool:
"""创建新的数据库连接"""
conn = sqlite3.connect(self.database, check_same_thread=False)
conn.row_factory = sqlite3.Row
+ # 启用外键约束,确保 ON DELETE CASCADE 等约束生效
+ conn.execute("PRAGMA foreign_keys=ON")
# 设置WAL模式提高并发性能
conn.execute("PRAGMA journal_mode=WAL")
+ # 在WAL模式下使用NORMAL同步,兼顾性能与可靠性
+ conn.execute("PRAGMA synchronous=NORMAL")
# 设置合理的超时时间
conn.execute("PRAGMA busy_timeout=5000")
return conn
+ def _close_connection(self, conn) -> None:
+ if conn is None:
+ return
+ try:
+ conn.close()
+ except Exception as e:
+ logger.warning(f"关闭连接失败: {e}")
+
+ def _is_connection_healthy(self, conn) -> bool:
+ if conn is None:
+ return False
+ try:
+ conn.rollback()
+ conn.execute("SELECT 1")
+ return True
+ except sqlite3.Error as e:
+ logger.warning(f"连接健康检查失败(数据库错误): {e}")
+ except Exception as e:
+ logger.warning(f"连接健康检查失败(未知错误): {e}")
+ return False
+
+ def _replenish_pool_if_needed(self) -> None:
+ with self._lock:
+ if self._pool.qsize() >= self.pool_size:
+ return
+
+ new_conn = None
+ try:
+ new_conn = self._create_connection()
+ self._pool.put(new_conn, block=False)
+ self._created_connections += 1
+ except Full:
+ if new_conn:
+ self._close_connection(new_conn)
+ except Exception as e:
+ if new_conn:
+ self._close_connection(new_conn)
+ logger.warning(f"重建连接失败: {e}")
+
def get_connection(self):
"""
从连接池获取连接
@@ -70,66 +117,20 @@ class ConnectionPool:
Args:
conn: 要归还的连接
"""
- import sqlite3
- from queue import Full
-
if conn is None:
return
- connection_healthy = False
- try:
- # 回滚任何未提交的事务
- conn.rollback()
- # 安全修复:验证连接是否健康,防止损坏的连接污染连接池
- conn.execute("SELECT 1")
- connection_healthy = True
- except sqlite3.Error as e:
- # 数据库相关错误,连接可能损坏
- print(f"连接健康检查失败(数据库错误): {e}")
- except Exception as e:
- print(f"连接健康检查失败(未知错误): {e}")
-
- if connection_healthy:
+ if self._is_connection_healthy(conn):
try:
self._pool.put(conn, block=False)
- return # 成功归还
+ return
except Full:
- # 队列已满(不应该发生,但处理它)
- print(f"警告: 连接池已满,关闭多余连接")
- connection_healthy = False # 标记为需要关闭
+ logger.warning("连接池已满,关闭多余连接")
+ self._close_connection(conn)
+ return
- # 连接不健康或队列已满,关闭它
- try:
- conn.close()
- except Exception as close_error:
- print(f"关闭连接失败: {close_error}")
-
- # 如果连接不健康,尝试创建新连接补充池
- if not connection_healthy:
- with self._lock:
- # 双重检查:确保池确实需要补充
- if self._pool.qsize() < self.pool_size:
- new_conn = None
- try:
- new_conn = self._create_connection()
- self._pool.put(new_conn, block=False)
- # 只有成功放入池后才增加计数
- self._created_connections += 1
- except Full:
- # 在获取锁期间池被填满了,关闭新建的连接
- if new_conn:
- try:
- new_conn.close()
- except Exception:
- pass
- except Exception as create_error:
- # 创建连接失败,确保关闭已创建的连接
- if new_conn:
- try:
- new_conn.close()
- except Exception:
- pass
- print(f"重建连接失败: {create_error}")
+ self._close_connection(conn)
+ self._replenish_pool_if_needed()
def close_all(self):
"""关闭所有连接"""
@@ -138,7 +139,7 @@ class ConnectionPool:
conn = self._pool.get(block=False)
conn.close()
except Exception as e:
- print(f"关闭连接失败: {e}")
+ logger.warning(f"关闭连接失败: {e}")
def get_stats(self):
"""获取连接池统计信息"""
@@ -175,14 +176,14 @@ class PooledConnection:
if exc_type is not None:
# 发生异常,回滚事务
self._conn.rollback()
- print(f"数据库事务已回滚: {exc_type.__name__}")
+ logger.warning(f"数据库事务已回滚: {exc_type.__name__}")
# 注意: 不自动commit,要求用户显式调用conn.commit()
if self._cursor:
self._cursor.close()
self._cursor = None
except Exception as e:
- print(f"关闭游标失败: {e}")
+ logger.warning(f"关闭游标失败: {e}")
finally:
# 归还连接
self._pool.return_connection(self._conn)
@@ -254,7 +255,7 @@ def init_pool(database, pool_size=5):
with _pool_lock:
if _pool is None:
_pool = ConnectionPool(database, pool_size)
- print(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
+ logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})")
def get_db():
diff --git a/email_service.py b/email_service.py
index c02629c..32fcdb8 100644
--- a/email_service.py
+++ b/email_service.py
@@ -34,15 +34,14 @@ def get_beijing_now_str():
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.base import MIMEBase
-from email.mime.image import MIMEImage
from email import encoders
from email.header import Header
from email.utils import formataddr
-from typing import Optional, List, Dict, Any, Callable
+from typing import Optional, List, Dict, Any, Callable, Tuple
from io import BytesIO
import db_pool
-from crypto_utils import encrypt_password, decrypt_password, is_encrypted
+from crypto_utils import encrypt_password, decrypt_password
from app_logger import get_logger
logger = get_logger("email_service")
@@ -102,6 +101,333 @@ QUEUE_MAX_SIZE = int(os.environ.get('EMAIL_QUEUE_MAX_SIZE', '100'))
MAX_ATTACHMENT_SIZE = int(os.environ.get('EMAIL_MAX_ATTACHMENT_SIZE', str(10 * 1024 * 1024))) # 10MB
+def _resolve_base_url(base_url: Optional[str] = None) -> str:
+ """解析系统基础URL,按参数 -> 邮件设置 -> 配置文件 -> 默认值顺序。"""
+ if base_url:
+ return str(base_url).strip() or 'http://localhost:51233'
+
+ try:
+ settings = get_email_settings()
+ configured_url = (settings or {}).get('base_url', '')
+ if configured_url:
+ return configured_url
+ except Exception:
+ pass
+
+ try:
+ from app_config import Config
+ configured_url = getattr(Config, 'BASE_URL', '')
+ if configured_url:
+ return configured_url
+ except Exception:
+ pass
+
+ return 'http://localhost:51233'
+
+
+def _load_email_template(template_filename: str, fallback_html: str) -> str:
+ """读取邮件模板,读取失败时回退到内置模板。"""
+ template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', template_filename)
+ try:
+ with open(template_path, 'r', encoding='utf-8') as f:
+ return f.read()
+ except FileNotFoundError:
+ return fallback_html
+
+
+def _render_template(template: str, values: Dict[str, Any]) -> str:
+ """替换模板变量,兼容 {{ key }} 和 {{key}} 两种写法。"""
+ rendered = template
+ for key, value in values.items():
+ value_text = '' if value is None else str(value)
+ rendered = rendered.replace(f'{{{{ {key} }}}}', value_text)
+ rendered = rendered.replace(f'{{{{{key}}}}}', value_text)
+ return rendered
+
+
+def _mark_unused_tokens_as_used(user_id: int, token_type: str) -> None:
+ """使指定类型的旧 token 失效。"""
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE email_tokens SET used = 1
+ WHERE user_id = ? AND token_type = ? AND used = 0
+ """,
+ (user_id, token_type),
+ )
+ conn.commit()
+
+
+def _send_token_email(
+ *,
+ email: str,
+ username: str,
+ user_id: int,
+ base_url: Optional[str],
+ token_type: str,
+ rate_limit_error: str,
+ subject: str,
+ template_filename: str,
+ fallback_html: str,
+ url_path: str,
+ url_template_key: str,
+ text_template: str,
+ invalidate_existing_tokens: bool = False,
+) -> Dict[str, Any]:
+ """通用 token 邮件发送流程(注册/重置/绑定)。"""
+ if not check_rate_limit(email, token_type):
+ return {'success': False, 'error': rate_limit_error, 'token': None}
+
+ if invalidate_existing_tokens:
+ _mark_unused_tokens_as_used(user_id, token_type)
+
+ token = generate_email_token(email, token_type, user_id)
+ resolved_base_url = _resolve_base_url(base_url)
+ normalized_path = url_path if url_path.startswith('/') else f'/{url_path}'
+ action_url = f"{resolved_base_url.rstrip('/')}{normalized_path.format(token=token)}"
+
+ html_template = _load_email_template(template_filename, fallback_html)
+ html_body = _render_template(
+ html_template,
+ {
+ 'username': username,
+ url_template_key: action_url,
+ },
+ )
+
+ text_body = text_template.format(username=username, action_url=action_url)
+
+ result = send_email(
+ to_email=email,
+ subject=subject,
+ body=text_body,
+ html_body=html_body,
+ email_type=token_type,
+ user_id=user_id,
+ )
+ if result.get('success'):
+ return {'success': True, 'error': '', 'token': token}
+ return {'success': False, 'error': result.get('error', '发送失败'), 'token': None}
+
+
+def _task_notify_precheck(email: str, *, require_screenshots: bool = False, screenshots: Optional[List] = None) -> str:
+ """任务通知发送前置检查,返回空字符串表示通过。"""
+ settings = get_email_settings()
+ if not settings.get('enabled', False):
+ return '邮件功能未启用'
+ if not settings.get('task_notify_enabled', False):
+ return '任务通知功能未启用'
+ if not email:
+ return '用户未设置邮箱'
+ if require_screenshots and not screenshots:
+ return '没有截图需要发送'
+ return ''
+
+
+def _load_screenshot_data(screenshot_path: Optional[str], log_callback: Optional[Callable] = None) -> Tuple[Optional[bytes], Optional[str]]:
+ """读取截图数据,失败时仅记录日志并继续。"""
+ if not screenshot_path or not os.path.exists(screenshot_path):
+ return None, None
+ try:
+ with open(screenshot_path, 'rb') as f:
+ return f.read(), os.path.basename(screenshot_path)
+ except Exception as e:
+ if log_callback:
+ log_callback(f"[邮件] 读取截图文件失败: {e}")
+ return None, None
+
+
+def _render_task_complete_bodies(
+ *,
+ html_template: str,
+ username: str,
+ account_name: str,
+ browse_type: str,
+ total_items: int,
+ total_attachments: int,
+ complete_time: str,
+ batch_info: str,
+ screenshot_text: str,
+) -> Tuple[str, str]:
+ """渲染任务完成通知邮件内容。"""
+ html_body = _render_template(
+ html_template,
+ {
+ 'username': username,
+ 'account_name': account_name,
+ 'browse_type': browse_type,
+ 'total_items': total_items,
+ 'total_attachments': total_attachments,
+ 'complete_time': complete_time,
+ 'batch_info': batch_info,
+ },
+ )
+
+ text_body = f"""
+您好,{username}!
+
+您的浏览任务已完成。
+
+账号:{account_name}
+浏览类型:{browse_type}
+浏览条目:{total_items} 条
+附件数量:{total_attachments} 个
+完成时间:{complete_time}
+
+{screenshot_text}
+"""
+ return html_body, text_body
+
+
+def _build_batch_accounts_html_rows(screenshots: List[Dict[str, Any]]) -> str:
+ """构建批次任务邮件中的账号详情表格。"""
+ rows = []
+ for item in screenshots:
+ rows.append(
+ f"""
+
+ | {item.get('account_name', '未知')} |
+ {item.get('items', 0)} |
+ {item.get('attachments', 0)} |
+
+ """
+ )
+ return ''.join(rows)
+
+
+def _collect_existing_screenshot_paths(screenshots: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
+ """收集存在的截图路径与压缩包内文件名。"""
+ screenshot_paths: List[Tuple[str, str]] = []
+ for item in screenshots:
+ path = item.get('path')
+ if path and os.path.exists(path):
+ arcname = f"{item.get('account_name', 'screenshot')}_{os.path.basename(path)}"
+ screenshot_paths.append((path, arcname))
+ return screenshot_paths
+
+
+def _build_zip_attachment_from_paths(screenshot_paths: List[Tuple[str, str]]) -> Tuple[Optional[bytes], Optional[str], str]:
+ """从路径列表构建 ZIP 附件,返回 (zip_data, zip_filename, note)。"""
+ if not screenshot_paths:
+ return None, None, '本次无可用截图文件(可能截图失败或未启用截图)。'
+
+ import tempfile
+
+ zip_path = None
+ try:
+ with tempfile.NamedTemporaryFile(prefix='screenshots_', suffix='.zip', delete=False) as tmp:
+ zip_path = tmp.name
+
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
+ for file_path, arcname in screenshot_paths:
+ try:
+ zf.write(file_path, arcname=arcname)
+ except Exception as e:
+ logger.warning(f"[邮件] 写入ZIP失败: {e}")
+
+ zip_size = os.path.getsize(zip_path) if zip_path and os.path.exists(zip_path) else 0
+ if zip_size <= 0:
+ return None, None, '本次无可用截图文件(可能截图失败或文件不存在)。'
+ if zip_size > MAX_ATTACHMENT_SIZE:
+ return None, None, f'截图打包文件过大({zip_size} bytes),本次不附加附件。'
+
+ with open(zip_path, 'rb') as f:
+ zip_data = f.read()
+ zip_filename = f"screenshots_{datetime.now(BEIJING_TZ).strftime('%Y%m%d_%H%M%S')}.zip"
+ return zip_data, zip_filename, '截图已打包为ZIP附件,请查收。'
+ except Exception as e:
+ logger.warning(f"[邮件] 打包截图失败: {e}")
+ return None, None, '截图打包失败,本次不附加附件。'
+ finally:
+ if zip_path and os.path.exists(zip_path):
+ try:
+ os.remove(zip_path)
+ except Exception:
+ pass
+
+
+def _reset_smtp_daily_quota_if_needed(cursor, today: Optional[str] = None) -> int:
+ """在日期切换后重置 SMTP 每日计数,返回更新记录数。"""
+ reset_day = today or get_beijing_today()
+ cursor.execute(
+ """
+ UPDATE smtp_configs
+ SET daily_sent = 0, daily_reset_date = ?
+ WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = ''
+ """,
+ (reset_day, reset_day),
+ )
+ return int(cursor.rowcount or 0)
+
+
+def _build_smtp_admin_config(row, include_password: bool = False) -> Dict[str, Any]:
+ """将 smtp_configs 全字段查询行转换为管理接口配置字典。"""
+ encrypted_password = row[8]
+ password_text = ''
+ if encrypted_password:
+ if include_password:
+ password_text = decrypt_password(encrypted_password)
+ else:
+ password_text = '******'
+
+ config = {
+ 'id': row[0],
+ 'name': row[1],
+ 'enabled': bool(row[2]),
+ 'is_primary': bool(row[3]),
+ 'priority': row[4],
+ 'host': row[5],
+ 'port': row[6],
+ 'username': row[7],
+ 'password': password_text,
+ 'has_password': bool(encrypted_password),
+ 'use_ssl': bool(row[9]),
+ 'use_tls': bool(row[10]),
+ 'sender_name': row[11],
+ 'sender_email': row[12],
+ 'daily_limit': row[13],
+ 'daily_sent': row[14],
+ 'daily_reset_date': row[15],
+ 'last_success_at': row[16],
+ 'last_error': row[17],
+ 'success_count': row[18],
+ 'fail_count': row[19],
+ 'created_at': row[20],
+ 'updated_at': row[21],
+ }
+ total = (config['success_count'] or 0) + (config['fail_count'] or 0)
+ config['success_rate'] = round((config['success_count'] or 0) / total * 100, 1) if total > 0 else 0
+ return config
+
+
+def _build_smtp_runtime_config(row) -> Dict[str, Any]:
+ """将 smtp_configs 运行时查询行转换为发送字典(含每日限额字段)。"""
+ return {
+ 'id': row[0],
+ 'name': row[1],
+ 'host': row[2],
+ 'port': row[3],
+ 'username': row[4],
+ 'password': decrypt_password(row[5]) if row[5] else '',
+ 'use_ssl': bool(row[6]),
+ 'use_tls': bool(row[7]),
+ 'sender_name': row[8],
+ 'sender_email': row[9],
+ 'daily_limit': row[10] or 0,
+ 'daily_sent': row[11] or 0,
+ 'is_primary': bool(row[12]),
+ }
+
+
+def _strip_smtp_quota_fields(runtime_config: Dict[str, Any]) -> Dict[str, Any]:
+ """移除内部限额字段,返回 send_email 使用的 SMTP 配置字典。"""
+ config = dict(runtime_config)
+ config.pop('daily_limit', None)
+ config.pop('daily_sent', None)
+ return config
+
+
# ============ 数据库操作 ============
def init_email_tables():
@@ -317,12 +643,7 @@ def get_smtp_configs(include_password: bool = False) -> List[Dict[str, Any]]:
with db_pool.get_db() as conn:
cursor = conn.cursor()
# 确保每天的配额在日期切换后能及时重置(即使当天没有触发邮件发送)
- today = get_beijing_today()
- cursor.execute("""
- UPDATE smtp_configs
- SET daily_sent = 0, daily_reset_date = ?
- WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = ''
- """, (today, today))
+ _reset_smtp_daily_quota_if_needed(cursor)
conn.commit()
cursor.execute("""
@@ -335,53 +656,14 @@ def get_smtp_configs(include_password: bool = False) -> List[Dict[str, Any]]:
ORDER BY is_primary DESC, priority ASC, id ASC
""")
- configs = []
- for row in cursor.fetchall():
- config = {
- 'id': row[0],
- 'name': row[1],
- 'enabled': bool(row[2]),
- 'is_primary': bool(row[3]),
- 'priority': row[4],
- 'host': row[5],
- 'port': row[6],
- 'username': row[7],
- 'password': '******' if row[8] and not include_password else (decrypt_password(row[8]) if include_password and row[8] else ''),
- 'has_password': bool(row[8]),
- 'use_ssl': bool(row[9]),
- 'use_tls': bool(row[10]),
- 'sender_name': row[11],
- 'sender_email': row[12],
- 'daily_limit': row[13],
- 'daily_sent': row[14],
- 'daily_reset_date': row[15],
- 'last_success_at': row[16],
- 'last_error': row[17],
- 'success_count': row[18],
- 'fail_count': row[19],
- 'created_at': row[20],
- 'updated_at': row[21]
- }
-
- # 计算成功率
- total = config['success_count'] + config['fail_count']
- config['success_rate'] = round(config['success_count'] / total * 100, 1) if total > 0 else 0
-
- configs.append(config)
-
- return configs
+ return [_build_smtp_admin_config(row, include_password=include_password) for row in cursor.fetchall()]
def get_smtp_config(config_id: int, include_password: bool = False) -> Optional[Dict[str, Any]]:
"""获取单个SMTP配置"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
- today = get_beijing_today()
- cursor.execute("""
- UPDATE smtp_configs
- SET daily_sent = 0, daily_reset_date = ?
- WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = ''
- """, (today, today))
+ _reset_smtp_daily_quota_if_needed(cursor)
conn.commit()
cursor.execute("""
@@ -397,31 +679,7 @@ def get_smtp_config(config_id: int, include_password: bool = False) -> Optional[
if not row:
return None
- return {
- 'id': row[0],
- 'name': row[1],
- 'enabled': bool(row[2]),
- 'is_primary': bool(row[3]),
- 'priority': row[4],
- 'host': row[5],
- 'port': row[6],
- 'username': row[7],
- 'password': '******' if row[8] and not include_password else (decrypt_password(row[8]) if include_password and row[8] else ''),
- 'has_password': bool(row[8]),
- 'use_ssl': bool(row[9]),
- 'use_tls': bool(row[10]),
- 'sender_name': row[11],
- 'sender_email': row[12],
- 'daily_limit': row[13],
- 'daily_sent': row[14],
- 'daily_reset_date': row[15],
- 'last_success_at': row[16],
- 'last_error': row[17],
- 'success_count': row[18],
- 'fail_count': row[19],
- 'created_at': row[20],
- 'updated_at': row[21]
- }
+ return _build_smtp_admin_config(row, include_password=include_password)
def create_smtp_config(data: Dict[str, Any]) -> int:
@@ -556,61 +814,50 @@ def clear_primary_smtp_config() -> bool:
return True
-def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]:
- """
- 获取可用的SMTP配置
- 优先级: 主配置 > 按priority排序的启用配置
- """
- today = get_beijing_today()
+def _fetch_candidate_smtp_configs(cursor, *, exclude_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]:
+ """获取可用 SMTP 候选配置(已过滤不可用/超额配置)。"""
+ excluded = [int(i) for i in (exclude_ids or []) if i is not None]
+ sql = """
+ SELECT id, name, host, port, username, password, use_ssl, use_tls,
+ sender_name, sender_email, daily_limit, daily_sent, is_primary
+ FROM smtp_configs
+ WHERE enabled = 1
+ """
+ params: List[Any] = []
+ if excluded:
+ placeholders = ",".join(["?" for _ in excluded])
+ sql += f" AND id NOT IN ({placeholders})"
+ params.extend(excluded)
+
+ sql += " ORDER BY is_primary DESC, priority ASC, id ASC"
+ cursor.execute(sql, params)
+
+ candidates: List[Dict[str, Any]] = []
+ for row in cursor.fetchall() or []:
+ runtime_config = _build_smtp_runtime_config(row)
+ daily_limit = int(runtime_config.get('daily_limit') or 0)
+ daily_sent = int(runtime_config.get('daily_sent') or 0)
+ if daily_limit > 0 and daily_sent >= daily_limit:
+ continue
+ candidates.append(_strip_smtp_quota_fields(runtime_config))
+
+ return candidates
+
+
+def _get_smtp_candidates(*, exclude_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]:
+ """读取并返回当前可发送的 SMTP 候选配置列表。"""
with db_pool.get_db() as conn:
cursor = conn.cursor()
-
- # 先重置过期的每日计数
- cursor.execute("""
- UPDATE smtp_configs
- SET daily_sent = 0, daily_reset_date = ?
- WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = ''
- """, (today, today))
+ _reset_smtp_daily_quota_if_needed(cursor)
conn.commit()
+ return _fetch_candidate_smtp_configs(cursor, exclude_ids=exclude_ids)
- # 获取所有启用的配置,按优先级排序
- cursor.execute("""
- SELECT id, name, host, port, username, password, use_ssl, use_tls,
- sender_name, sender_email, daily_limit, daily_sent, is_primary
- FROM smtp_configs
- WHERE enabled = 1
- ORDER BY is_primary DESC, priority ASC, id ASC
- """)
- configs = cursor.fetchall()
-
- for row in configs:
- config_id, name, host, port, username, password, use_ssl, use_tls, \
- sender_name, sender_email, daily_limit, daily_sent, is_primary = row
-
- # 检查每日限额
- if daily_limit > 0 and daily_sent >= daily_limit:
- continue # 超过限额,跳过此配置
-
- # 解密密码
- decrypted_password = decrypt_password(password) if password else ''
-
- return {
- 'id': config_id,
- 'name': name,
- 'host': host,
- 'port': port,
- 'username': username,
- 'password': decrypted_password,
- 'use_ssl': bool(use_ssl),
- 'use_tls': bool(use_tls),
- 'sender_name': sender_name,
- 'sender_email': sender_email,
- 'is_primary': bool(is_primary)
- }
-
- return None
+def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]:
+ """获取首个可用 SMTP 配置。"""
+ candidates = _get_smtp_candidates()
+ return candidates[0] if candidates else None
def _update_smtp_stats(config_id: int, success: bool, error: str = ''):
@@ -751,6 +998,44 @@ class EmailSender:
raise
+def _create_email_sender_from_config(config: Dict[str, Any]) -> EmailSender:
+ """从 SMTP 配置创建发送器。"""
+ return EmailSender(
+ {
+ 'name': config.get('name', ''),
+ 'host': config.get('host', ''),
+ 'port': config.get('port', 465),
+ 'username': config.get('username', ''),
+ 'password': config.get('password', ''),
+ 'use_ssl': bool(config.get('use_ssl', True)),
+ 'use_tls': bool(config.get('use_tls', False)),
+ 'sender_name': config.get('sender_name', '自动化学习'),
+ 'sender_email': config.get('sender_email', ''),
+ }
+ )
+
+
+def _send_with_smtp_config(
+ *,
+ config: Dict[str, Any],
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: str = None,
+ attachments: Optional[List[Dict[str, Any]]] = None,
+) -> Tuple[bool, str]:
+ """使用指定 SMTP 配置发送一封邮件。"""
+ sender = _create_email_sender_from_config(config)
+ try:
+ sender.connect()
+ sender.send(to_email, subject, body, html_body, attachments)
+ sender.disconnect()
+ return True, ''
+ except Exception as e:
+ sender.disconnect()
+ return False, str(e)
+
+
def send_email(
to_email: str,
subject: str,
@@ -767,103 +1052,59 @@ def send_email(
Returns:
{'success': bool, 'error': str, 'config_id': int}
"""
- # 检查全局开关
settings = get_email_settings()
if not settings['enabled']:
return {'success': False, 'error': '邮件功能未启用', 'config_id': None}
- # 获取可用配置
- config = _get_available_smtp_config(settings['failover_enabled'])
- if not config:
+ candidates = _get_smtp_candidates()
+ if not settings['failover_enabled']:
+ candidates = candidates[:1]
+ if not candidates:
return {'success': False, 'error': '没有可用的SMTP配置', 'config_id': None}
- tried_configs = []
+ tried_configs: List[int] = []
last_error = ''
- while config:
- tried_configs.append(config['id'])
- sender = EmailSender(config)
+ for config in candidates:
+ config_id = int(config.get('id') or 0)
+ tried_configs.append(config_id)
- try:
- sender.connect()
- sender.send(to_email, subject, body, html_body, attachments)
- sender.disconnect()
+ ok, error = _send_with_smtp_config(
+ config=config,
+ to_email=to_email,
+ subject=subject,
+ body=body,
+ html_body=html_body,
+ attachments=attachments,
+ )
- # 更新统计
- _update_smtp_stats(config['id'], True)
- _log_email_send(user_id, config['id'], to_email, email_type, subject, 'success', '', attachments)
+ if ok:
+ _update_smtp_stats(config_id, True)
+ _log_email_send(user_id, config_id, to_email, email_type, subject, 'success', '', attachments)
_update_email_stats(email_type, True)
if log_callback:
- log_callback(f"[邮件服务] 发送成功: {to_email} (使用: {config['name']})")
+ log_callback(f"[邮件服务] 发送成功: {to_email} (使用: {config.get('name', '')})")
- return {'success': True, 'error': '', 'config_id': config['id']}
+ return {'success': True, 'error': '', 'config_id': config_id}
- except Exception as e:
- last_error = str(e)
- sender.disconnect()
+ last_error = error
+ _update_smtp_stats(config_id, False, last_error)
- # 更新失败统计
- _update_smtp_stats(config['id'], False, last_error)
+ if log_callback:
+ log_callback(f"[邮件服务] 发送失败 [{config.get('name', '')}]: {last_error}")
- if log_callback:
- log_callback(f"[邮件服务] 发送失败 [{config['name']}]: {e}")
-
- # 故障转移:尝试下一个配置
- if settings['failover_enabled']:
- config = _get_next_available_smtp_config(tried_configs)
- else:
- config = None
-
- # 所有配置都失败
- _log_email_send(user_id, tried_configs[0] if tried_configs else None,
- to_email, email_type, subject, 'failed', last_error, attachments)
+ first_config_id = tried_configs[0] if tried_configs else None
+ _log_email_send(user_id, first_config_id, to_email, email_type, subject, 'failed', last_error, attachments)
_update_email_stats(email_type, False)
- return {'success': False, 'error': last_error, 'config_id': tried_configs[0] if tried_configs else None}
+ return {'success': False, 'error': last_error, 'config_id': first_config_id}
def _get_next_available_smtp_config(exclude_ids: List[int]) -> Optional[Dict[str, Any]]:
- """获取下一个可用的SMTP配置(排除已尝试的)"""
- today = get_beijing_today()
-
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
-
- placeholders = ','.join(['?' for _ in exclude_ids])
- cursor.execute(f"""
- SELECT id, name, host, port, username, password, use_ssl, use_tls,
- sender_name, sender_email, daily_limit, daily_sent, is_primary
- FROM smtp_configs
- WHERE enabled = 1 AND id NOT IN ({placeholders})
- ORDER BY is_primary DESC, priority ASC, id ASC
- LIMIT 1
- """, exclude_ids)
-
- row = cursor.fetchone()
- if not row:
- return None
-
- config_id, name, host, port, username, password, use_ssl, use_tls, \
- sender_name, sender_email, daily_limit, daily_sent, is_primary = row
-
- # 检查每日限额
- if daily_limit > 0 and daily_sent >= daily_limit:
- return _get_next_available_smtp_config(exclude_ids + [config_id])
-
- return {
- 'id': config_id,
- 'name': name,
- 'host': host,
- 'port': port,
- 'username': username,
- 'password': decrypt_password(password) if password else '',
- 'use_ssl': bool(use_ssl),
- 'use_tls': bool(use_tls),
- 'sender_name': sender_name,
- 'sender_email': sender_email,
- 'is_primary': bool(is_primary)
- }
+ """获取下一个可用的SMTP配置(排除已尝试的)。"""
+ candidates = _get_smtp_candidates(exclude_ids=exclude_ids)
+ return candidates[0] if candidates else None
def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]:
@@ -872,28 +1113,14 @@ def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]:
if not config:
return {'success': False, 'error': '配置不存在'}
- sender = EmailSender({
- 'name': config['name'],
- 'host': config['host'],
- 'port': config['port'],
- 'username': config['username'],
- 'password': config['password'],
- 'use_ssl': config['use_ssl'],
- 'use_tls': config['use_tls'],
- 'sender_name': config['sender_name'],
- 'sender_email': config['sender_email']
- })
+ ok, error = _send_with_smtp_config(
+ config=config,
+ to_email=test_email,
+ subject='自动化学习 - SMTP配置测试',
+ body=f'这是一封测试邮件。\n\n配置名称: {config["name"]}\nSMTP服务器: {config["host"]}:{config["port"]}\n\n如果您收到此邮件,说明SMTP配置正确。',
+ )
- try:
- sender.connect()
- sender.send(
- test_email,
- '自动化学习 - SMTP配置测试',
- f'这是一封测试邮件。\n\n配置名称: {config["name"]}\nSMTP服务器: {config["host"]}:{config["port"]}\n\n如果您收到此邮件,说明SMTP配置正确。',
- None,
- None
- )
- sender.disconnect()
+ if ok:
# 更新最后成功时间
with db_pool.get_db() as conn:
@@ -908,16 +1135,18 @@ def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]:
return {'success': True, 'error': ''}
- except Exception as e:
- # 记录错误
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- UPDATE smtp_configs SET last_error = ? WHERE id = ?
- """, (str(e)[:500], config_id))
- conn.commit()
+ # 记录错误
+ with db_pool.get_db() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ UPDATE smtp_configs SET last_error = ? WHERE id = ?
+ """,
+ (error[:500], config_id),
+ )
+ conn.commit()
- return {'success': False, 'error': str(e)}
+ return {'success': False, 'error': error}
# ============ 邮件日志 ============
@@ -1248,84 +1477,43 @@ def send_register_verification_email(
Returns:
{'success': bool, 'error': str, 'token': str}
"""
- # 检查发送频率限制
- if not check_rate_limit(email, EMAIL_TYPE_REGISTER):
- return {
- 'success': False,
- 'error': '发送太频繁,请稍后再试',
- 'token': None
- }
-
- # 生成验证Token
- token = generate_email_token(email, EMAIL_TYPE_REGISTER, user_id)
-
- # 获取base_url
- if not base_url:
- settings = get_email_settings()
- base_url = settings.get('base_url', '')
-
- if not base_url:
- # 尝试从配置获取
- try:
- from app_config import Config
- base_url = Config.BASE_URL
- except:
- base_url = 'http://localhost:51233'
-
- # 生成验证链接
- verify_url = f"{base_url.rstrip('/')}/api/verify-email/{token}"
-
- # 读取邮件模板
- template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'register.html')
- try:
- with open(template_path, 'r', encoding='utf-8') as f:
- html_template = f.read()
- except FileNotFoundError:
- # 使用简单的HTML模板
- html_template = """
-
-
- 邮箱验证
- 您好,{{ username }}!
- 请点击下面的链接验证您的邮箱地址:
- {{ verify_url }}
- 此链接24小时内有效。
-
-
- """
-
- # 替换模板变量
- html_body = html_template.replace('{{ username }}', username)
- html_body = html_body.replace('{{ verify_url }}', verify_url)
-
- # 纯文本版本
- text_body = f"""
+ fallback_html = """
+
+
+ 邮箱验证
+ 您好,{{ username }}!
+ 请点击下面的链接验证您的邮箱地址:
+ {{ verify_url }}
+ 此链接24小时内有效。
+
+
+ """
+ text_template = """
您好,{username}!
感谢您注册自动化学习。请点击下面的链接验证您的邮箱地址:
-{verify_url}
+{action_url}
此链接24小时内有效。
如果您没有注册过账号,请忽略此邮件。
"""
-
- # 发送邮件
- result = send_email(
- to_email=email,
+ return _send_token_email(
+ email=email,
+ username=username,
+ user_id=user_id,
+ base_url=base_url,
+ token_type=EMAIL_TYPE_REGISTER,
+ rate_limit_error='发送太频繁,请稍后再试',
subject='【自动化学习】邮箱验证',
- body=text_body,
- html_body=html_body,
- email_type=EMAIL_TYPE_REGISTER,
- user_id=user_id
+ template_filename='register.html',
+ fallback_html=fallback_html,
+ url_path='/api/verify-email/{token}',
+ url_template_key='verify_url',
+ text_template=text_template,
)
- if result['success']:
- return {'success': True, 'error': '', 'token': token}
- else:
- return {'success': False, 'error': result['error'], 'token': None}
-
def resend_register_verification_email(user_id: int, email: str, username: str) -> Dict[str, Any]:
"""
@@ -1340,14 +1528,7 @@ def resend_register_verification_email(user_id: int, email: str, username: str)
{'success': bool, 'error': str}
"""
# 检查是否有未过期的token
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- # 先使旧token失效
- cursor.execute("""
- UPDATE email_tokens SET used = 1
- WHERE user_id = ? AND token_type = ? AND used = 0
- """, (user_id, EMAIL_TYPE_REGISTER))
- conn.commit()
+ _mark_unused_tokens_as_used(user_id, EMAIL_TYPE_REGISTER)
# 发送新的验证邮件
return send_register_verification_email(email, username, user_id)
@@ -1373,91 +1554,44 @@ def send_password_reset_email(
Returns:
{'success': bool, 'error': str, 'token': str}
"""
- # 检查发送频率限制(密码重置限制5分钟)
- if not check_rate_limit(email, EMAIL_TYPE_RESET):
- return {
- 'success': False,
- 'error': '发送太频繁,请5分钟后再试',
- 'token': None
- }
-
- # 使旧的重置token失效
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- UPDATE email_tokens SET used = 1
- WHERE user_id = ? AND token_type = ? AND used = 0
- """, (user_id, EMAIL_TYPE_RESET))
- conn.commit()
-
- # 生成新的验证Token
- token = generate_email_token(email, EMAIL_TYPE_RESET, user_id)
-
- # 获取base_url
- if not base_url:
- settings = get_email_settings()
- base_url = settings.get('base_url', '')
-
- if not base_url:
- try:
- from app_config import Config
- base_url = Config.BASE_URL
- except:
- base_url = 'http://localhost:51233'
-
- # 生成重置链接
- reset_url = f"{base_url.rstrip('/')}/reset-password/{token}"
-
- # 读取邮件模板
- template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'reset_password.html')
- try:
- with open(template_path, 'r', encoding='utf-8') as f:
- html_template = f.read()
- except FileNotFoundError:
- html_template = """
-
-
- 密码重置
- 您好,{{ username }}!
- 请点击下面的链接重置您的密码:
- {{ reset_url }}
- 此链接30分钟内有效。
-
-
- """
-
- # 替换模板变量
- html_body = html_template.replace('{{ username }}', username)
- html_body = html_body.replace('{{ reset_url }}', reset_url)
-
- # 纯文本版本
- text_body = f"""
+ fallback_html = """
+
+
+ 密码重置
+ 您好,{{ username }}!
+ 请点击下面的链接重置您的密码:
+ {{ reset_url }}
+ 此链接30分钟内有效。
+
+
+ """
+ text_template = """
您好,{username}!
我们收到了您的密码重置请求。请点击下面的链接重置您的密码:
-{reset_url}
+{action_url}
此链接30分钟内有效。
如果您没有申请过密码重置,请忽略此邮件。
"""
-
- # 发送邮件
- result = send_email(
- to_email=email,
+ return _send_token_email(
+ email=email,
+ username=username,
+ user_id=user_id,
+ base_url=base_url,
+ token_type=EMAIL_TYPE_RESET,
+ rate_limit_error='发送太频繁,请5分钟后再试',
subject='【自动化学习】密码重置',
- body=text_body,
- html_body=html_body,
- email_type=EMAIL_TYPE_RESET,
- user_id=user_id
+ template_filename='reset_password.html',
+ fallback_html=fallback_html,
+ url_path='/reset-password/{token}',
+ url_template_key='reset_url',
+ text_template=text_template,
+ invalidate_existing_tokens=True,
)
- if result['success']:
- return {'success': True, 'error': '', 'token': token}
- else:
- return {'success': False, 'error': result['error'], 'token': None}
-
def verify_password_reset_token(token: str) -> Optional[Dict[str, Any]]:
"""
@@ -1533,91 +1667,44 @@ def send_bind_email_verification(
Returns:
{'success': bool, 'error': str, 'token': str}
"""
- # 检查发送频率限制(绑定邮件限制1分钟)
- if not check_rate_limit(email, EMAIL_TYPE_BIND):
- return {
- 'success': False,
- 'error': '发送太频繁,请1分钟后再试',
- 'token': None
- }
-
- # 使旧的绑定token失效
- with db_pool.get_db() as conn:
- cursor = conn.cursor()
- cursor.execute("""
- UPDATE email_tokens SET used = 1
- WHERE user_id = ? AND token_type = ? AND used = 0
- """, (user_id, EMAIL_TYPE_BIND))
- conn.commit()
-
- # 生成新的验证Token
- token = generate_email_token(email, EMAIL_TYPE_BIND, user_id)
-
- # 获取base_url
- if not base_url:
- settings = get_email_settings()
- base_url = settings.get('base_url', '')
-
- if not base_url:
- try:
- from app_config import Config
- base_url = Config.BASE_URL
- except:
- base_url = 'http://localhost:51233'
-
- # 生成验证链接
- verify_url = f"{base_url.rstrip('/')}/api/verify-bind-email/{token}"
-
- # 读取邮件模板
- template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'bind_email.html')
- try:
- with open(template_path, 'r', encoding='utf-8') as f:
- html_template = f.read()
- except FileNotFoundError:
- html_template = """
-
-
- 邮箱绑定验证
- 您好,{{ username }}!
- 请点击下面的链接完成邮箱绑定:
- {{ verify_url }}
- 此链接1小时内有效。
-
-
- """
-
- # 替换模板变量
- html_body = html_template.replace('{{ username }}', username)
- html_body = html_body.replace('{{ verify_url }}', verify_url)
-
- # 纯文本版本
- text_body = f"""
+ fallback_html = """
+
+
+ 邮箱绑定验证
+ 您好,{{ username }}!
+ 请点击下面的链接完成邮箱绑定:
+ {{ verify_url }}
+ 此链接1小时内有效。
+
+
+ """
+ text_template = """
您好,{username}!
您正在绑定此邮箱到您的账号。请点击下面的链接完成验证:
-{verify_url}
+{action_url}
此链接1小时内有效。
如果这不是您的操作,请忽略此邮件。
"""
-
- # 发送邮件
- result = send_email(
- to_email=email,
+ return _send_token_email(
+ email=email,
+ username=username,
+ user_id=user_id,
+ base_url=base_url,
+ token_type=EMAIL_TYPE_BIND,
+ rate_limit_error='发送太频繁,请1分钟后再试',
subject='【自动化学习】邮箱绑定验证',
- body=text_body,
- html_body=html_body,
- email_type=EMAIL_TYPE_BIND,
- user_id=user_id
+ template_filename='bind_email.html',
+ fallback_html=fallback_html,
+ url_path='/api/verify-bind-email/{token}',
+ url_template_key='verify_url',
+ text_template=text_template,
+ invalidate_existing_tokens=True,
)
- if result['success']:
- return {'success': True, 'error': '', 'token': token}
- else:
- return {'success': False, 'error': result['error'], 'token': None}
-
def verify_bind_email_token(token: str) -> Optional[Dict[str, Any]]:
"""
@@ -1897,6 +1984,37 @@ def create_zip_attachment(files: List[Dict[str, Any]], zip_filename: str = 'scre
# ============ 任务完成通知邮件 ============
+def _send_email_and_collect(
+ *,
+ to_email: str,
+ subject: str,
+ body: str,
+ html_body: Optional[str] = None,
+ attachments: Optional[List[Dict[str, Any]]] = None,
+ email_type: str,
+ user_id: Optional[int],
+ log_callback: Optional[Callable] = None,
+ success_log: Optional[str] = None,
+) -> Tuple[bool, str]:
+ result = send_email(
+ to_email=to_email,
+ subject=subject,
+ body=body,
+ html_body=html_body,
+ attachments=attachments,
+ email_type=email_type,
+ user_id=user_id,
+ log_callback=log_callback,
+ )
+
+ if result.get('success'):
+ if success_log and log_callback:
+ log_callback(success_log)
+ return True, ''
+
+ return False, str(result.get('error') or '发送失败')
+
+
def send_task_complete_email(
user_id: int,
email: str,
@@ -1925,39 +2043,16 @@ def send_task_complete_email(
Returns:
{'success': bool, 'error': str, 'emails_sent': int}
"""
- # 检查邮件功能是否启用
- settings = get_email_settings()
- if not settings.get('enabled', False):
- return {'success': False, 'error': '邮件功能未启用', 'emails_sent': 0}
+ precheck_error = _task_notify_precheck(email)
+ if precheck_error:
+ return {'success': False, 'error': precheck_error, 'emails_sent': 0}
- if not settings.get('task_notify_enabled', False):
- return {'success': False, 'error': '任务通知功能未启用', 'emails_sent': 0}
-
- if not email:
- return {'success': False, 'error': '用户未设置邮箱', 'emails_sent': 0}
-
- # 获取完成时间
complete_time = get_beijing_now_str()
+ screenshot_data, screenshot_filename = _load_screenshot_data(screenshot_path, log_callback)
- # 读取截图文件
- screenshot_data = None
- screenshot_filename = None
- if screenshot_path and os.path.exists(screenshot_path):
- try:
- with open(screenshot_path, 'rb') as f:
- screenshot_data = f.read()
- screenshot_filename = os.path.basename(screenshot_path)
- except Exception as e:
- if log_callback:
- log_callback(f"[邮件] 读取截图文件失败: {e}")
-
- # 读取邮件模板
- template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'task_complete.html')
- try:
- with open(template_path, 'r', encoding='utf-8') as f:
- html_template = f.read()
- except FileNotFoundError:
- html_template = """
+ html_template = _load_email_template(
+ 'task_complete.html',
+ """
任务完成通知
@@ -1970,108 +2065,72 @@ def send_task_complete_email(
完成时间:{{ complete_time }}
- """
+ """,
+ )
- # 准备发送
emails_sent = 0
last_error = ''
+ is_oversized_attachment = bool(screenshot_data and len(screenshot_data) > MAX_ATTACHMENT_SIZE)
- # 检查附件大小,决定是否需要分批发送
- if screenshot_data and len(screenshot_data) > MAX_ATTACHMENT_SIZE:
- # 附件超过限制,不压缩直接发送(大图片压缩效果不大)
- # 这种情况很少见,但作为容错处理
- batch_info = '(由于文件较大,截图将单独发送)'
+ if is_oversized_attachment:
+ html_body, text_body = _render_task_complete_bodies(
+ html_template=html_template,
+ username=username,
+ account_name=account_name,
+ browse_type=browse_type,
+ total_items=total_items,
+ total_attachments=total_attachments,
+ complete_time=complete_time,
+ batch_info='(由于文件较大,截图将单独发送)',
+ screenshot_text='截图将在下一封邮件中发送。',
+ )
- # 先发送不带附件的通知邮件
- html_body = html_template.replace('{{ username }}', username)
- html_body = html_body.replace('{{ account_name }}', account_name)
- html_body = html_body.replace('{{ browse_type }}', browse_type)
- html_body = html_body.replace('{{ total_items }}', str(total_items))
- html_body = html_body.replace('{{ total_attachments }}', str(total_attachments))
- html_body = html_body.replace('{{ complete_time }}', complete_time)
- html_body = html_body.replace('{{ batch_info }}', batch_info)
-
- text_body = f"""
-您好,{username}!
-
-您的浏览任务已完成。
-
-账号:{account_name}
-浏览类型:{browse_type}
-浏览条目:{total_items} 条
-附件数量:{total_attachments} 个
-完成时间:{complete_time}
-
-截图将在下一封邮件中发送。
-"""
-
- result = send_email(
+ ok, error = _send_email_and_collect(
to_email=email,
subject=f'【自动化学习】任务完成 - {account_name}',
body=text_body,
html_body=html_body,
email_type=EMAIL_TYPE_TASK_COMPLETE,
user_id=user_id,
- log_callback=log_callback
+ log_callback=log_callback,
+ success_log='[邮件] 任务通知已发送',
)
-
- if result['success']:
+ if ok:
emails_sent += 1
- if log_callback:
- log_callback(f"[邮件] 任务通知已发送")
else:
- last_error = result['error']
+ last_error = error
- # 单独发送截图附件
attachment = [{'filename': screenshot_filename, 'data': screenshot_data}]
- result2 = send_email(
+ ok, error = _send_email_and_collect(
to_email=email,
subject=f'【自动化学习】任务截图 - {account_name}',
body=f'这是 {account_name} 的任务截图。',
attachments=attachment,
email_type=EMAIL_TYPE_TASK_COMPLETE,
user_id=user_id,
- log_callback=log_callback
+ log_callback=log_callback,
+ success_log='[邮件] 截图附件已发送',
)
-
- if result2['success']:
+ if ok:
emails_sent += 1
- if log_callback:
- log_callback(f"[邮件] 截图附件已发送")
else:
- last_error = result2['error']
+ last_error = error
else:
- # 正常情况:附件大小在限制内,一次性发送
- batch_info = ''
- attachments = None
+ attachments = [{'filename': screenshot_filename, 'data': screenshot_data}] if screenshot_data else None
+ html_body, text_body = _render_task_complete_bodies(
+ html_template=html_template,
+ username=username,
+ account_name=account_name,
+ browse_type=browse_type,
+ total_items=total_items,
+ total_attachments=total_attachments,
+ complete_time=complete_time,
+ batch_info='',
+ screenshot_text='截图已附在邮件中。' if screenshot_data else '',
+ )
- if screenshot_data:
- attachments = [{'filename': screenshot_filename, 'data': screenshot_data}]
-
- html_body = html_template.replace('{{ username }}', username)
- html_body = html_body.replace('{{ account_name }}', account_name)
- html_body = html_body.replace('{{ browse_type }}', browse_type)
- html_body = html_body.replace('{{ total_items }}', str(total_items))
- html_body = html_body.replace('{{ total_attachments }}', str(total_attachments))
- html_body = html_body.replace('{{ complete_time }}', complete_time)
- html_body = html_body.replace('{{ batch_info }}', batch_info)
-
- text_body = f"""
-您好,{username}!
-
-您的浏览任务已完成。
-
-账号:{account_name}
-浏览类型:{browse_type}
-浏览条目:{total_items} 条
-附件数量:{total_attachments} 个
-完成时间:{complete_time}
-
-{'截图已附在邮件中。' if screenshot_data else ''}
-"""
-
- result = send_email(
+ ok, error = _send_email_and_collect(
to_email=email,
subject=f'【自动化学习】任务完成 - {account_name}',
body=text_body,
@@ -2079,15 +2138,13 @@ def send_task_complete_email(
attachments=attachments,
email_type=EMAIL_TYPE_TASK_COMPLETE,
user_id=user_id,
- log_callback=log_callback
+ log_callback=log_callback,
+ success_log='[邮件] 任务通知已发送',
)
-
- if result['success']:
+ if ok:
emails_sent += 1
- if log_callback:
- log_callback(f"[邮件] 任务通知已发送")
else:
- last_error = result['error']
+ last_error = error
return {
'success': emails_sent > 0,
@@ -2118,63 +2175,25 @@ def send_task_complete_email_async(
log_callback("[邮件] 邮件队列已满,任务通知未发送")
-def send_batch_task_complete_email(
- user_id: int,
- email: str,
+def _summarize_batch_screenshots(screenshots: List[Dict[str, Any]]) -> Tuple[int, int, int]:
+ total_items_sum = sum(int(s.get('items', 0) or 0) for s in screenshots)
+ total_attachments_sum = sum(int(s.get('attachments', 0) or 0) for s in screenshots)
+ account_count = len(screenshots)
+ return total_items_sum, total_attachments_sum, account_count
+
+
+def _render_batch_task_complete_html(
+ *,
username: str,
schedule_name: str,
browse_type: str,
- screenshots: List[Dict[str, Any]]
-) -> Dict[str, Any]:
- """
- 发送批次任务完成通知邮件(多账号截图打包)
-
- Args:
- user_id: 用户ID
- email: 收件人邮箱
- username: 用户名
- schedule_name: 定时任务名称
- browse_type: 浏览类型
- screenshots: 截图列表 [{'account_name': x, 'path': y, 'items': n, 'attachments': m}, ...]
-
- Returns:
- {'success': bool, 'error': str}
- """
- # 检查邮件功能是否启用
- settings = get_email_settings()
- if not settings.get('enabled', False):
- return {'success': False, 'error': '邮件功能未启用'}
-
- if not settings.get('task_notify_enabled', False):
- return {'success': False, 'error': '任务通知功能未启用'}
-
- if not email:
- return {'success': False, 'error': '用户未设置邮箱'}
-
- if not screenshots:
- return {'success': False, 'error': '没有截图需要发送'}
-
- # 获取完成时间
- complete_time = get_beijing_now_str()
-
- # 统计信息
- total_items_sum = sum(s.get('items', 0) for s in screenshots)
- total_attachments_sum = sum(s.get('attachments', 0) for s in screenshots)
- account_count = len(screenshots)
-
- # 构建账号详情HTML
- accounts_html = ""
- for s in screenshots:
- accounts_html += f"""
-
- | {s.get('account_name', '未知')} |
- {s.get('items', 0)} |
- {s.get('attachments', 0)} |
-
- """
-
- # 构建HTML邮件内容
- html_content = f"""
+ complete_time: str,
+ account_count: int,
+ total_items_sum: int,
+ total_attachments_sum: int,
+ accounts_html: str,
+) -> str:
+ return f"""
@@ -2208,52 +2227,50 @@ def send_batch_task_complete_email(
"""
- # 收集可用截图文件路径(避免把所有图片一次性读入内存)
- screenshot_paths = []
- for s in screenshots:
- path = s.get('path')
- if path and os.path.exists(path):
- arcname = f"{s.get('account_name', 'screenshot')}_{os.path.basename(path)}"
- screenshot_paths.append((path, arcname))
- # 如果有截图,优先落盘打包ZIP,再按大小决定是否附加(降低内存峰值)
- zip_data = None
- zip_filename = None
- attachment_note = ""
- if screenshot_paths:
- import tempfile
- zip_path = None
- try:
- with tempfile.NamedTemporaryFile(prefix="screenshots_", suffix=".zip", delete=False) as tmp:
- zip_path = tmp.name
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
- for file_path, arcname in screenshot_paths:
- try:
- zf.write(file_path, arcname=arcname)
- except Exception as e:
- logger.warning(f"[邮件] 写入ZIP失败: {e}")
+def send_batch_task_complete_email(
+ user_id: int,
+ email: str,
+ username: str,
+ schedule_name: str,
+ browse_type: str,
+ screenshots: List[Dict[str, Any]]
+) -> Dict[str, Any]:
+ """
+ 发送批次任务完成通知邮件(多账号截图打包)
- zip_size = os.path.getsize(zip_path) if zip_path and os.path.exists(zip_path) else 0
- if zip_size <= 0:
- attachment_note = "本次无可用截图文件(可能截图失败或文件不存在)。"
- elif zip_size > MAX_ATTACHMENT_SIZE:
- attachment_note = f"截图打包文件过大({zip_size} bytes),本次不附加附件。"
- else:
- with open(zip_path, 'rb') as f:
- zip_data = f.read()
- zip_filename = f"screenshots_{datetime.now(BEIJING_TZ).strftime('%Y%m%d_%H%M%S')}.zip"
- attachment_note = "截图已打包为ZIP附件,请查收。"
- except Exception as e:
- logger.warning(f"[邮件] 打包截图失败: {e}")
- attachment_note = "截图打包失败,本次不附加附件。"
- finally:
- if zip_path and os.path.exists(zip_path):
- try:
- os.remove(zip_path)
- except Exception:
- pass
- else:
- attachment_note = "本次无可用截图文件(可能截图失败或未启用截图)。"
+ Args:
+ user_id: 用户ID
+ email: 收件人邮箱
+ username: 用户名
+ schedule_name: 定时任务名称
+ browse_type: 浏览类型
+ screenshots: 截图列表 [{'account_name': x, 'path': y, 'items': n, 'attachments': m}, ...]
+
+ Returns:
+ {'success': bool, 'error': str}
+ """
+ precheck_error = _task_notify_precheck(email, require_screenshots=True, screenshots=screenshots)
+ if precheck_error:
+ return {'success': False, 'error': precheck_error}
+
+ complete_time = get_beijing_now_str()
+ total_items_sum, total_attachments_sum, account_count = _summarize_batch_screenshots(screenshots)
+ accounts_html = _build_batch_accounts_html_rows(screenshots)
+
+ html_content = _render_batch_task_complete_html(
+ username=username,
+ schedule_name=schedule_name,
+ browse_type=browse_type,
+ complete_time=complete_time,
+ account_count=account_count,
+ total_items_sum=total_items_sum,
+ total_attachments_sum=total_attachments_sum,
+ accounts_html=accounts_html,
+ )
+
+ screenshot_paths = _collect_existing_screenshot_paths(screenshots)
+ zip_data, zip_filename, attachment_note = _build_zip_attachment_from_paths(screenshot_paths)
# 将附件说明写入邮件内容
html_content = html_content.replace("截图已打包为ZIP附件,请查收。", attachment_note)
@@ -2267,7 +2284,7 @@ def send_batch_task_complete_email(
'mime_type': 'application/zip'
})
- result = send_email(
+ ok, error = _send_email_and_collect(
to_email=email,
subject=f'【自动化学习】定时任务完成 - {schedule_name}',
body='',
@@ -2277,10 +2294,9 @@ def send_batch_task_complete_email(
user_id=user_id,
)
- if result['success']:
+ if ok:
return {'success': True}
- else:
- return {'success': False, 'error': result.get('error', '发送失败')}
+ return {'success': False, 'error': error or '发送失败'}
def send_batch_task_complete_email_async(
diff --git a/realtime/status_push.py b/realtime/status_push.py
index 665dfdf..481d672 100644
--- a/realtime/status_push.py
+++ b/realtime/status_push.py
@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+import json
import os
import time
@@ -9,8 +10,40 @@ from services.runtime import get_logger, get_socketio
from services.state import safe_get_account, safe_iter_task_status_items
+def _to_int(value, default: int = 0) -> int:
+ try:
+ return int(value)
+ except Exception:
+ return int(default)
+
+
+def _payload_signature(payload: dict) -> str:
+ try:
+ return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str)
+ except Exception:
+ return repr(payload)
+
+
+def _should_emit(
+ *,
+ last_sig: str | None,
+ last_ts: float,
+ new_sig: str,
+ now_ts: float,
+ min_interval: float,
+ force_interval: float,
+) -> bool:
+ if last_sig is None:
+ return True
+ if (now_ts - last_ts) >= force_interval:
+ return True
+ if new_sig != last_sig and (now_ts - last_ts) >= min_interval:
+ return True
+ return False
+
+
def status_push_worker() -> None:
- """后台线程:按间隔推送排队/运行中任务的状态更新(可节流)。"""
+ """后台线程:按间隔推送排队/运行中任务状态(变更驱动+心跳兜底)。"""
logger = get_logger()
try:
push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1"))
@@ -18,18 +51,41 @@ def status_push_worker() -> None:
push_interval = 1.0
push_interval = max(0.5, push_interval)
+ try:
+ queue_min_interval = float(os.environ.get("STATUS_PUSH_MIN_QUEUE_INTERVAL_SECONDS", str(push_interval)))
+ except Exception:
+ queue_min_interval = push_interval
+ queue_min_interval = max(push_interval, queue_min_interval)
+
+ try:
+ progress_min_interval = float(
+ os.environ.get("STATUS_PUSH_MIN_PROGRESS_INTERVAL_SECONDS", str(push_interval))
+ )
+ except Exception:
+ progress_min_interval = push_interval
+ progress_min_interval = max(push_interval, progress_min_interval)
+
+ try:
+ force_interval = float(os.environ.get("STATUS_PUSH_FORCE_INTERVAL_SECONDS", "10"))
+ except Exception:
+ force_interval = 10.0
+ force_interval = max(push_interval, force_interval)
+
socketio = get_socketio()
from services.tasks import get_task_scheduler
scheduler = get_task_scheduler()
+ emitted_state: dict[str, dict] = {}
while True:
try:
+ now_ts = time.time()
queue_snapshot = scheduler.get_queue_state_snapshot()
pending_total = int(queue_snapshot.get("pending_total", 0) or 0)
running_total = int(queue_snapshot.get("running_total", 0) or 0)
running_by_user = queue_snapshot.get("running_by_user") or {}
positions = queue_snapshot.get("positions") or {}
+ active_account_ids = set()
status_items = safe_iter_task_status_items()
for account_id, status_info in status_items:
@@ -39,11 +95,15 @@ def status_push_worker() -> None:
user_id = status_info.get("user_id")
if not user_id:
continue
+
+ active_account_ids.add(str(account_id))
account = safe_get_account(user_id, account_id)
if not account:
continue
+
+ user_id_int = _to_int(user_id)
account_data = account.to_dict()
- pos = positions.get(account_id) or {}
+ pos = positions.get(account_id) or positions.get(str(account_id)) or {}
account_data.update(
{
"queue_pending_total": pending_total,
@@ -51,10 +111,23 @@ def status_push_worker() -> None:
"queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"),
"queue_is_vip": pos.get("is_vip"),
- "queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
+ "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
}
)
- socketio.emit("account_update", account_data, room=f"user_{user_id}")
+
+ cache_entry = emitted_state.setdefault(str(account_id), {})
+ account_sig = _payload_signature(account_data)
+ if _should_emit(
+ last_sig=cache_entry.get("account_sig"),
+ last_ts=float(cache_entry.get("account_ts", 0) or 0),
+ new_sig=account_sig,
+ now_ts=now_ts,
+ min_interval=queue_min_interval,
+ force_interval=force_interval,
+ ):
+ socketio.emit("account_update", account_data, room=f"user_{user_id}")
+ cache_entry["account_sig"] = account_sig
+ cache_entry["account_ts"] = now_ts
if status != "运行中":
continue
@@ -74,9 +147,26 @@ def status_push_worker() -> None:
"queue_running_total": running_total,
"queue_ahead": pos.get("queue_ahead"),
"queue_position": pos.get("queue_position"),
- "queue_running_user": int(running_by_user.get(int(user_id), 0) or 0),
+ "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))),
}
- socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
+
+ progress_sig = _payload_signature(progress_data)
+ if _should_emit(
+ last_sig=cache_entry.get("progress_sig"),
+ last_ts=float(cache_entry.get("progress_ts", 0) or 0),
+ new_sig=progress_sig,
+ now_ts=now_ts,
+ min_interval=progress_min_interval,
+ force_interval=force_interval,
+ ):
+ socketio.emit("task_progress", progress_data, room=f"user_{user_id}")
+ cache_entry["progress_sig"] = progress_sig
+ cache_entry["progress_ts"] = now_ts
+
+ if emitted_state:
+ stale_ids = [account_id for account_id in emitted_state.keys() if account_id not in active_account_ids]
+ for account_id in stale_ids:
+ emitted_state.pop(account_id, None)
time.sleep(push_interval)
except Exception as e:
diff --git a/routes/admin_api/__init__.py b/routes/admin_api/__init__.py
index ea19e72..32cd44b 100644
--- a/routes/admin_api/__init__.py
+++ b/routes/admin_api/__init__.py
@@ -8,6 +8,15 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api")
# Import side effects: register routes on blueprint
from routes.admin_api import core as _core # noqa: F401
+from routes.admin_api import system_config_api as _system_config_api # noqa: F401
+from routes.admin_api import operations_api as _operations_api # noqa: F401
+from routes.admin_api import announcements_api as _announcements_api # noqa: F401
+from routes.admin_api import users_api as _users_api # noqa: F401
+from routes.admin_api import account_api as _account_api # noqa: F401
+from routes.admin_api import feedback_api as _feedback_api # noqa: F401
+from routes.admin_api import infra_api as _infra_api # noqa: F401
+from routes.admin_api import tasks_api as _tasks_api # noqa: F401
+from routes.admin_api import email_api as _email_api # noqa: F401
# Export security blueprint for app registration
from routes.admin_api.security import security_bp # noqa: F401
diff --git a/routes/admin_api/account_api.py b/routes/admin_api/account_api.py
new file mode 100644
index 0000000..6a7d3ba
--- /dev/null
+++ b/routes/admin_api/account_api.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import database
+from app_security import validate_password
+from flask import jsonify, request, session
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+
+# ==================== 密码重置 / 反馈(管理员) ====================
+
+
+@admin_api_bp.route("/admin/password", methods=["PUT"])
+@admin_required
+def update_admin_password():
+ """修改管理员密码"""
+ data = request.json or {}
+ new_password = (data.get("new_password") or "").strip()
+
+ if not new_password:
+ return jsonify({"error": "密码不能为空"}), 400
+
+ username = session.get("admin_username")
+ if database.update_admin_password(username, new_password):
+ return jsonify({"success": True})
+ return jsonify({"error": "修改失败"}), 400
+
+
+@admin_api_bp.route("/admin/username", methods=["PUT"])
+@admin_required
+def update_admin_username():
+ """修改管理员用户名"""
+ data = request.json or {}
+ new_username = (data.get("new_username") or "").strip()
+
+ if not new_username:
+ return jsonify({"error": "用户名不能为空"}), 400
+
+ old_username = session.get("admin_username")
+ if database.update_admin_username(old_username, new_username):
+ session["admin_username"] = new_username
+ return jsonify({"success": True})
+ return jsonify({"error": "修改失败,用户名可能已存在"}), 400
+
+
+@admin_api_bp.route("/users//reset_password", methods=["POST"])
+@admin_required
+def admin_reset_password_route(user_id):
+ """管理员直接重置用户密码(无需审核)"""
+ data = request.json or {}
+ new_password = (data.get("new_password") or "").strip()
+
+ if not new_password:
+ return jsonify({"error": "新密码不能为空"}), 400
+
+ is_valid, error_msg = validate_password(new_password)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
+
+ if database.admin_reset_user_password(user_id, new_password):
+ return jsonify({"message": "密码重置成功"})
+ return jsonify({"error": "重置失败,用户不存在"}), 400
diff --git a/routes/admin_api/announcements_api.py b/routes/admin_api/announcements_api.py
new file mode 100644
index 0000000..70947f4
--- /dev/null
+++ b/routes/admin_api/announcements_api.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import os
+import posixpath
+import secrets
+import time
+
+import database
+from app_config import get_config
+from app_logger import get_logger
+from app_security import is_safe_path, sanitize_filename
+from flask import current_app, jsonify, request, url_for
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+
+logger = get_logger("app")
+config = get_config()
+
+
+def _get_upload_dir():
+ rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
+ if not is_safe_path(current_app.root_path, rel_dir):
+ rel_dir = "static/announcements"
+ abs_dir = os.path.join(current_app.root_path, rel_dir)
+ os.makedirs(abs_dir, exist_ok=True)
+ return abs_dir, rel_dir
+
+
+def _get_file_size(file_storage):
+ try:
+ file_storage.stream.seek(0, os.SEEK_END)
+ size = file_storage.stream.tell()
+ file_storage.stream.seek(0)
+ return size
+ except Exception:
+ return None
+
+
+# ==================== 公告管理API(管理员) ====================
+
+
+@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
+@admin_required
+def admin_upload_announcement_image():
+ """上传公告图片(返回可访问URL)"""
+ file = request.files.get("file")
+ if not file or not file.filename:
+ return jsonify({"error": "请选择图片"}), 400
+
+ filename = sanitize_filename(file.filename)
+ ext = os.path.splitext(filename)[1].lower()
+ allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
+ if not ext or ext not in allowed_exts:
+ return jsonify({"error": "不支持的图片格式"}), 400
+ if file.mimetype and not str(file.mimetype).startswith("image/"):
+ return jsonify({"error": "文件类型无效"}), 400
+
+ size = _get_file_size(file)
+ max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
+ if size is not None and size > max_size:
+ max_mb = max_size // 1024 // 1024
+ return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
+
+ abs_dir, rel_dir = _get_upload_dir()
+ token = secrets.token_hex(6)
+ name = f"announcement_{int(time.time())}_{token}{ext}"
+ save_path = os.path.join(abs_dir, name)
+ file.save(save_path)
+
+ static_root = os.path.join(current_app.root_path, "static")
+ rel_to_static = os.path.relpath(abs_dir, static_root)
+ if rel_to_static.startswith(".."):
+ rel_to_static = "announcements"
+ url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
+ return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
+
+
+@admin_api_bp.route("/announcements", methods=["GET"])
+@admin_required
+def admin_get_announcements():
+ """获取公告列表"""
+ try:
+ limit = int(request.args.get("limit", 50))
+ offset = int(request.args.get("offset", 0))
+ except (TypeError, ValueError):
+ limit, offset = 50, 0
+
+ limit = max(1, min(200, limit))
+ offset = max(0, offset)
+ return jsonify(database.get_announcements(limit=limit, offset=offset))
+
+
+@admin_api_bp.route("/announcements", methods=["POST"])
+@admin_required
+def admin_create_announcement():
+ """创建公告(默认启用并替换旧公告)"""
+ data = request.json or {}
+ title = (data.get("title") or "").strip()
+ content = (data.get("content") or "").strip()
+ image_url = (data.get("image_url") or "").strip()
+ is_active = bool(data.get("is_active", True))
+
+ if image_url and len(image_url) > 1000:
+ return jsonify({"error": "图片地址过长"}), 400
+
+ announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
+ if not announcement_id:
+ return jsonify({"error": "标题和内容不能为空"}), 400
+
+ return jsonify({"success": True, "id": announcement_id})
+
+
+@admin_api_bp.route("/announcements//activate", methods=["POST"])
+@admin_required
+def admin_activate_announcement(announcement_id):
+ """启用公告(会自动停用其他公告)"""
+ if not database.get_announcement_by_id(announcement_id):
+ return jsonify({"error": "公告不存在"}), 404
+ ok = database.set_announcement_active(announcement_id, True)
+ return jsonify({"success": ok})
+
+
+@admin_api_bp.route("/announcements//deactivate", methods=["POST"])
+@admin_required
+def admin_deactivate_announcement(announcement_id):
+ """停用公告"""
+ if not database.get_announcement_by_id(announcement_id):
+ return jsonify({"error": "公告不存在"}), 404
+ ok = database.set_announcement_active(announcement_id, False)
+ return jsonify({"success": ok})
+
+
+@admin_api_bp.route("/announcements/", methods=["DELETE"])
+@admin_required
+def admin_delete_announcement(announcement_id):
+ """删除公告"""
+ if not database.get_announcement_by_id(announcement_id):
+ return jsonify({"error": "公告不存在"}), 404
+ ok = database.delete_announcement(announcement_id)
+ return jsonify({"success": ok})
+
+
diff --git a/routes/admin_api/core.py b/routes/admin_api/core.py
index a744c4d..3449169 100644
--- a/routes/admin_api/core.py
+++ b/routes/admin_api/core.py
@@ -3,39 +3,19 @@
from __future__ import annotations
import os
-import posixpath
-import secrets
-import threading
import time
-from datetime import datetime
import database
-import email_service
-import requests
from app_config import get_config
from app_logger import get_logger
-from app_security import (
- get_rate_limit_ip,
- is_safe_outbound_url,
- is_safe_path,
- require_ip_not_locked,
- sanitize_filename,
- validate_email,
- validate_password,
-)
+from app_security import get_rate_limit_ip, require_ip_not_locked
from flask import current_app, jsonify, redirect, request, session, url_for
from routes.admin_api import admin_api_bp
from routes.decorators import admin_required
from services.accounts_service import load_user_accounts
-from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type
from services.checkpoints import get_checkpoint_mgr
-from services.scheduler import run_scheduled_task
from services.state import (
- safe_clear_user_logs,
- safe_get_account,
safe_get_user_accounts_snapshot,
- safe_iter_task_status_items,
- safe_remove_user_accounts,
safe_verify_and_consume_captcha,
check_login_ip_user_locked,
check_login_rate_limits,
@@ -46,42 +26,11 @@ from services.state import (
clear_login_failures,
record_login_failure,
)
-from services.tasks import get_task_scheduler, submit_account_task
-from services.time_utils import BEIJING_TZ, get_beijing_now
+from services.tasks import submit_account_task
logger = get_logger("app")
config = get_config()
-_server_cpu_percent_lock = threading.Lock()
-_server_cpu_percent_last: float | None = None
-_server_cpu_percent_last_ts = 0.0
-
-
-def _get_server_cpu_percent() -> float:
- import psutil
-
- global _server_cpu_percent_last, _server_cpu_percent_last_ts
-
- now = time.time()
- with _server_cpu_percent_lock:
- if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
- return _server_cpu_percent_last
-
- try:
- if _server_cpu_percent_last is None:
- cpu_percent = float(psutil.cpu_percent(interval=0.1))
- else:
- cpu_percent = float(psutil.cpu_percent(interval=None))
- except Exception:
- cpu_percent = float(_server_cpu_percent_last or 0.0)
-
- if cpu_percent < 0:
- cpu_percent = 0.0
-
- _server_cpu_percent_last = cpu_percent
- _server_cpu_percent_last_ts = now
- return cpu_percent
-
def _admin_reauth_required() -> bool:
try:
@@ -95,25 +44,6 @@ def _require_admin_reauth():
return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401
return None
-def _get_upload_dir():
- rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements")
- if not is_safe_path(current_app.root_path, rel_dir):
- rel_dir = "static/announcements"
- abs_dir = os.path.join(current_app.root_path, rel_dir)
- os.makedirs(abs_dir, exist_ok=True)
- return abs_dir, rel_dir
-
-
-def _get_file_size(file_storage):
- try:
- file_storage.stream.seek(0, os.SEEK_END)
- size = file_storage.stream.tell()
- file_storage.stream.seek(0)
- return size
- except Exception:
- return None
-
-
@admin_api_bp.route("/debug-config", methods=["GET"])
@admin_required
def debug_config():
@@ -247,948 +177,6 @@ def admin_reauth():
session.modified = True
return jsonify({"success": True, "expires_in": int(config.ADMIN_REAUTH_WINDOW_SECONDS)})
-
-# ==================== 公告管理API(管理员) ====================
-
-
-@admin_api_bp.route("/announcements/upload_image", methods=["POST"])
-@admin_required
-def admin_upload_announcement_image():
- """上传公告图片(返回可访问URL)"""
- file = request.files.get("file")
- if not file or not file.filename:
- return jsonify({"error": "请选择图片"}), 400
-
- filename = sanitize_filename(file.filename)
- ext = os.path.splitext(filename)[1].lower()
- allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"})
- if not ext or ext not in allowed_exts:
- return jsonify({"error": "不支持的图片格式"}), 400
- if file.mimetype and not str(file.mimetype).startswith("image/"):
- return jsonify({"error": "文件类型无效"}), 400
-
- size = _get_file_size(file)
- max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024))
- if size is not None and size > max_size:
- max_mb = max_size // 1024 // 1024
- return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400
-
- abs_dir, rel_dir = _get_upload_dir()
- token = secrets.token_hex(6)
- name = f"announcement_{int(time.time())}_{token}{ext}"
- save_path = os.path.join(abs_dir, name)
- file.save(save_path)
-
- static_root = os.path.join(current_app.root_path, "static")
- rel_to_static = os.path.relpath(abs_dir, static_root)
- if rel_to_static.startswith(".."):
- rel_to_static = "announcements"
- url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name)
- return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)})
-
-
-@admin_api_bp.route("/announcements", methods=["GET"])
-@admin_required
-def admin_get_announcements():
- """获取公告列表"""
- try:
- limit = int(request.args.get("limit", 50))
- offset = int(request.args.get("offset", 0))
- except (TypeError, ValueError):
- limit, offset = 50, 0
-
- limit = max(1, min(200, limit))
- offset = max(0, offset)
- return jsonify(database.get_announcements(limit=limit, offset=offset))
-
-
-@admin_api_bp.route("/announcements", methods=["POST"])
-@admin_required
-def admin_create_announcement():
- """创建公告(默认启用并替换旧公告)"""
- data = request.json or {}
- title = (data.get("title") or "").strip()
- content = (data.get("content") or "").strip()
- image_url = (data.get("image_url") or "").strip()
- is_active = bool(data.get("is_active", True))
-
- if image_url and len(image_url) > 1000:
- return jsonify({"error": "图片地址过长"}), 400
-
- announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active)
- if not announcement_id:
- return jsonify({"error": "标题和内容不能为空"}), 400
-
- return jsonify({"success": True, "id": announcement_id})
-
-
-@admin_api_bp.route("/announcements//activate", methods=["POST"])
-@admin_required
-def admin_activate_announcement(announcement_id):
- """启用公告(会自动停用其他公告)"""
- if not database.get_announcement_by_id(announcement_id):
- return jsonify({"error": "公告不存在"}), 404
- ok = database.set_announcement_active(announcement_id, True)
- return jsonify({"success": ok})
-
-
-@admin_api_bp.route("/announcements//deactivate", methods=["POST"])
-@admin_required
-def admin_deactivate_announcement(announcement_id):
- """停用公告"""
- if not database.get_announcement_by_id(announcement_id):
- return jsonify({"error": "公告不存在"}), 404
- ok = database.set_announcement_active(announcement_id, False)
- return jsonify({"success": ok})
-
-
-@admin_api_bp.route("/announcements/", methods=["DELETE"])
-@admin_required
-def admin_delete_announcement(announcement_id):
- """删除公告"""
- if not database.get_announcement_by_id(announcement_id):
- return jsonify({"error": "公告不存在"}), 404
- ok = database.delete_announcement(announcement_id)
- return jsonify({"success": ok})
-
-
-# ==================== 用户管理/统计(管理员) ====================
-
-
-@admin_api_bp.route("/users", methods=["GET"])
-@admin_required
-def get_all_users():
- """获取所有用户"""
- users = database.get_all_users()
- return jsonify(users)
-
-
-@admin_api_bp.route("/users/pending", methods=["GET"])
-@admin_required
-def get_pending_users():
- """获取待审核用户"""
- users = database.get_pending_users()
- return jsonify(users)
-
-
-@admin_api_bp.route("/users//approve", methods=["POST"])
-@admin_required
-def approve_user_route(user_id):
- """审核通过用户"""
- if database.approve_user(user_id):
- return jsonify({"success": True})
- return jsonify({"error": "审核失败"}), 400
-
-
-@admin_api_bp.route("/users//reject", methods=["POST"])
-@admin_required
-def reject_user_route(user_id):
- """拒绝用户"""
- if database.reject_user(user_id):
- return jsonify({"success": True})
- return jsonify({"error": "拒绝失败"}), 400
-
-
-@admin_api_bp.route("/users/", methods=["DELETE"])
-@admin_required
-def delete_user_route(user_id):
- """删除用户"""
- if database.delete_user(user_id):
- safe_remove_user_accounts(user_id)
- safe_clear_user_logs(user_id)
- return jsonify({"success": True})
- return jsonify({"error": "删除失败"}), 400
-
-
-@admin_api_bp.route("/stats", methods=["GET"])
-@admin_required
-def get_system_stats():
- """获取系统统计"""
- stats = database.get_system_stats()
- stats["admin_username"] = session.get("admin_username", "admin")
- return jsonify(stats)
-
-
-@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
-@admin_required
-def get_browser_pool_stats():
- """获取截图线程池状态"""
- try:
- from browser_pool_worker import get_browser_worker_pool
-
- pool = get_browser_worker_pool()
- stats = pool.get_stats() or {}
-
- worker_details = []
- for w in stats.get("workers") or []:
- last_ts = float(w.get("last_active_ts") or 0)
- last_active_at = None
- if last_ts > 0:
- try:
- last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
- except Exception:
- last_active_at = None
-
- created_ts = w.get("browser_created_at")
- created_at = None
- if created_ts:
- try:
- created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
- except Exception:
- created_at = None
-
- worker_details.append(
- {
- "worker_id": w.get("worker_id"),
- "idle": bool(w.get("idle")),
- "has_browser": bool(w.get("has_browser")),
- "total_tasks": int(w.get("total_tasks") or 0),
- "failed_tasks": int(w.get("failed_tasks") or 0),
- "browser_use_count": int(w.get("browser_use_count") or 0),
- "browser_created_at": created_at,
- "browser_created_ts": created_ts,
- "last_active_at": last_active_at,
- "last_active_ts": last_ts,
- "thread_alive": bool(w.get("thread_alive")),
- }
- )
-
- total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
- return jsonify(
- {
- "total_workers": total_workers,
- "active_workers": int(stats.get("busy_workers") or 0),
- "idle_workers": int(stats.get("idle_workers") or 0),
- "queue_size": int(stats.get("queue_size") or 0),
- "workers": worker_details,
- "summary": {
- "total_tasks": int(stats.get("total_tasks") or 0),
- "failed_tasks": int(stats.get("failed_tasks") or 0),
- "success_rate": stats.get("success_rate"),
- },
- "server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
- }
- )
- except Exception as e:
- logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
- return jsonify({"error": "获取截图线程池状态失败"}), 500
-
-
-@admin_api_bp.route("/docker_stats", methods=["GET"])
-@admin_required
-def get_docker_stats():
- """获取Docker容器运行状态"""
- import subprocess
-
- docker_status = {
- "running": False,
- "container_name": "N/A",
- "uptime": "N/A",
- "memory_usage": "N/A",
- "memory_limit": "N/A",
- "memory_percent": "N/A",
- "cpu_percent": "N/A",
- "status": "Unknown",
- }
-
- try:
- if os.path.exists("/.dockerenv"):
- docker_status["running"] = True
-
- try:
- with open("/etc/hostname", "r") as f:
- docker_status["container_name"] = f.read().strip()
- except Exception as e:
- logger.debug(f"读取容器名称失败: {e}")
-
- try:
- if os.path.exists("/sys/fs/cgroup/memory.current"):
- with open("/sys/fs/cgroup/memory.current", "r") as f:
- mem_total = int(f.read().strip())
-
- cache = 0
- if os.path.exists("/sys/fs/cgroup/memory.stat"):
- with open("/sys/fs/cgroup/memory.stat", "r") as f:
- for line in f:
- if line.startswith("inactive_file "):
- cache = int(line.split()[1])
- break
-
- mem_bytes = mem_total - cache
- docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
-
- if os.path.exists("/sys/fs/cgroup/memory.max"):
- with open("/sys/fs/cgroup/memory.max", "r") as f:
- limit_str = f.read().strip()
- if limit_str != "max":
- limit_bytes = int(limit_str)
- docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
- docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
- elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
- with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
- mem_bytes = int(f.read().strip())
- docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
-
- with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
- limit_bytes = int(f.read().strip())
- if limit_bytes < 1e18:
- docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
- docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
- except Exception as e:
- logger.debug(f"读取内存信息失败: {e}")
-
- try:
- if os.path.exists("/sys/fs/cgroup/cpu.stat"):
- cpu_usage = 0
- with open("/sys/fs/cgroup/cpu.stat", "r") as f:
- for line in f:
- if line.startswith("usage_usec"):
- cpu_usage = int(line.split()[1])
- break
-
- time.sleep(0.1)
- cpu_usage2 = 0
- with open("/sys/fs/cgroup/cpu.stat", "r") as f:
- for line in f:
- if line.startswith("usage_usec"):
- cpu_usage2 = int(line.split()[1])
- break
-
- cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100
- docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
- elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"):
- with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
- cpu_usage = int(f.read().strip())
-
- time.sleep(0.1)
- with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
- cpu_usage2 = int(f.read().strip())
-
- cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100
- docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
- except Exception as e:
- logger.debug(f"读取CPU信息失败: {e}")
-
- try:
- # 读取系统运行时间
- with open('/proc/uptime', 'r') as f:
- system_uptime = float(f.read().split()[0])
-
- # 读取 PID 1 的启动时间 (jiffies)
- with open('/proc/1/stat', 'r') as f:
- stat = f.read().split()
- starttime_jiffies = int(stat[21])
-
- # 获取 CLK_TCK (通常是 100)
- clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK'])
-
- # 计算容器运行时长(秒)
- container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck)
-
- # 格式化为可读字符串
- days = int(container_uptime_seconds // 86400)
- hours = int((container_uptime_seconds % 86400) // 3600)
- minutes = int((container_uptime_seconds % 3600) // 60)
-
- if days > 0:
- docker_status["uptime"] = f"{days}天{hours}小时{minutes}分钟"
- elif hours > 0:
- docker_status["uptime"] = f"{hours}小时{minutes}分钟"
- else:
- docker_status["uptime"] = f"{minutes}分钟"
- except Exception as e:
- logger.debug(f"获取容器运行时间失败: {e}")
-
- docker_status["status"] = "Running"
-
- else:
- docker_status["status"] = "Not in Docker"
- except Exception as e:
- docker_status["status"] = f"Error: {str(e)}"
-
- return jsonify(docker_status)
-
-
-# ==================== VIP 管理(管理员) ====================
-
-
-@admin_api_bp.route("/vip/config", methods=["GET"])
-@admin_required
-def get_vip_config_api():
- """获取VIP配置"""
- config = database.get_vip_config()
- return jsonify(config)
-
-
-@admin_api_bp.route("/vip/config", methods=["POST"])
-@admin_required
-def set_vip_config_api():
- """设置默认VIP天数"""
- data = request.json or {}
- days = data.get("default_vip_days", 0)
-
- if not isinstance(days, int) or days < 0:
- return jsonify({"error": "VIP天数必须是非负整数"}), 400
-
- database.set_default_vip_days(days)
- return jsonify({"message": "VIP配置已更新", "default_vip_days": days})
-
-
-@admin_api_bp.route("/users//vip", methods=["POST"])
-@admin_required
-def set_user_vip_api(user_id):
- """设置用户VIP"""
- data = request.json or {}
- days = data.get("days", 30)
-
- valid_days = [7, 30, 365, 999999]
- if days not in valid_days:
- return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400
-
- if database.set_user_vip(user_id, days):
- vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days]
- return jsonify({"message": f"VIP设置成功: {vip_type}"})
- return jsonify({"error": "设置失败,用户不存在"}), 400
-
-
-@admin_api_bp.route("/users//vip", methods=["DELETE"])
-@admin_required
-def remove_user_vip_api(user_id):
- """移除用户VIP"""
- if database.remove_user_vip(user_id):
- return jsonify({"message": "VIP已移除"})
- return jsonify({"error": "移除失败"}), 400
-
-
-@admin_api_bp.route("/users//vip", methods=["GET"])
-@admin_required
-def get_user_vip_info_api(user_id):
- """获取用户VIP信息(管理员)"""
- vip_info = database.get_user_vip_info(user_id)
- return jsonify(vip_info)
-
-
-# ==================== 系统配置 / 定时 / 代理(管理员) ====================
-
-
-@admin_api_bp.route("/system/config", methods=["GET"])
-@admin_required
-def get_system_config_api():
- """获取系统配置"""
- return jsonify(database.get_system_config())
-
-
-@admin_api_bp.route("/system/config", methods=["POST"])
-@admin_required
-def update_system_config_api():
- """更新系统配置"""
- data = request.json or {}
-
- max_concurrent = data.get("max_concurrent_global")
- schedule_enabled = data.get("schedule_enabled")
- schedule_time = data.get("schedule_time")
- schedule_browse_type = data.get("schedule_browse_type")
- schedule_weekdays = data.get("schedule_weekdays")
- new_max_concurrent_per_account = data.get("max_concurrent_per_account")
- new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
- enable_screenshot = data.get("enable_screenshot")
- auto_approve_enabled = data.get("auto_approve_enabled")
- auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
- auto_approve_vip_days = data.get("auto_approve_vip_days")
- kdocs_enabled = data.get("kdocs_enabled")
- kdocs_doc_url = data.get("kdocs_doc_url")
- kdocs_default_unit = data.get("kdocs_default_unit")
- kdocs_sheet_name = data.get("kdocs_sheet_name")
- kdocs_sheet_index = data.get("kdocs_sheet_index")
- kdocs_unit_column = data.get("kdocs_unit_column")
- kdocs_image_column = data.get("kdocs_image_column")
- kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
- kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
- kdocs_row_start = data.get("kdocs_row_start")
- kdocs_row_end = data.get("kdocs_row_end")
-
- if max_concurrent is not None:
- if not isinstance(max_concurrent, int) or max_concurrent < 1:
- return jsonify({"error": "全局并发数必须大于0(建议:小型服务器2-5,中型5-10,大型10-20)"}), 400
-
- if new_max_concurrent_per_account is not None:
- if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1:
- return jsonify({"error": "单账号并发数必须大于0(建议设为1,避免同一用户任务相互影响)"}), 400
-
- if new_max_screenshot_concurrent is not None:
- if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
- return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400
-
- if enable_screenshot is not None:
- if isinstance(enable_screenshot, bool):
- enable_screenshot = 1 if enable_screenshot else 0
- if enable_screenshot not in (0, 1):
- return jsonify({"error": "截图开关必须是0或1"}), 400
-
- if schedule_time is not None:
- import re
-
- if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time):
- return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
-
- if schedule_browse_type is not None:
- normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ)
- if not normalized:
- return jsonify({"error": "浏览类型无效"}), 400
- schedule_browse_type = normalized
-
- if schedule_weekdays is not None:
- try:
- days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()]
- if not all(1 <= d <= 7 for d in days):
- return jsonify({"error": "星期数字必须在1-7之间"}), 400
- except (ValueError, AttributeError):
- return jsonify({"error": "星期格式错误"}), 400
-
- if auto_approve_hourly_limit is not None:
- if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1:
- return jsonify({"error": "每小时注册限制必须大于0"}), 400
-
- if auto_approve_vip_days is not None:
- if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
- return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
-
- if kdocs_enabled is not None:
- if isinstance(kdocs_enabled, bool):
- kdocs_enabled = 1 if kdocs_enabled else 0
- if kdocs_enabled not in (0, 1):
- return jsonify({"error": "表格上传开关必须是0或1"}), 400
-
- if kdocs_doc_url is not None:
- kdocs_doc_url = str(kdocs_doc_url or "").strip()
- if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
- return jsonify({"error": "文档链接格式不正确"}), 400
-
- if kdocs_default_unit is not None:
- kdocs_default_unit = str(kdocs_default_unit or "").strip()
- if len(kdocs_default_unit) > 50:
- return jsonify({"error": "默认县区长度不能超过50"}), 400
-
- if kdocs_sheet_name is not None:
- kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
- if len(kdocs_sheet_name) > 50:
- return jsonify({"error": "Sheet名称长度不能超过50"}), 400
-
- if kdocs_sheet_index is not None:
- try:
- kdocs_sheet_index = int(kdocs_sheet_index)
- except Exception:
- return jsonify({"error": "Sheet序号必须是数字"}), 400
- if kdocs_sheet_index < 0:
- return jsonify({"error": "Sheet序号不能为负数"}), 400
-
- if kdocs_unit_column is not None:
- kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
- if not kdocs_unit_column:
- return jsonify({"error": "县区列不能为空"}), 400
- import re
-
- if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
- return jsonify({"error": "县区列格式错误"}), 400
-
- if kdocs_image_column is not None:
- kdocs_image_column = str(kdocs_image_column or "").strip().upper()
- if not kdocs_image_column:
- return jsonify({"error": "图片列不能为空"}), 400
- import re
-
- if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
- return jsonify({"error": "图片列格式错误"}), 400
-
- if kdocs_admin_notify_enabled is not None:
- if isinstance(kdocs_admin_notify_enabled, bool):
- kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
- if kdocs_admin_notify_enabled not in (0, 1):
- return jsonify({"error": "管理员通知开关必须是0或1"}), 400
-
- if kdocs_admin_notify_email is not None:
- kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
- if kdocs_admin_notify_email:
- is_valid, error_msg = validate_email(kdocs_admin_notify_email)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
-
- if kdocs_row_start is not None:
- try:
- kdocs_row_start = int(kdocs_row_start)
- except (ValueError, TypeError):
- return jsonify({"error": "起始行必须是数字"}), 400
- if kdocs_row_start < 0:
- return jsonify({"error": "起始行不能为负数"}), 400
-
- if kdocs_row_end is not None:
- try:
- kdocs_row_end = int(kdocs_row_end)
- except (ValueError, TypeError):
- return jsonify({"error": "结束行必须是数字"}), 400
- if kdocs_row_end < 0:
- return jsonify({"error": "结束行不能为负数"}), 400
-
- old_config = database.get_system_config() or {}
-
- if not database.update_system_config(
- max_concurrent=max_concurrent,
- schedule_enabled=schedule_enabled,
- schedule_time=schedule_time,
- schedule_browse_type=schedule_browse_type,
- schedule_weekdays=schedule_weekdays,
- max_concurrent_per_account=new_max_concurrent_per_account,
- max_screenshot_concurrent=new_max_screenshot_concurrent,
- enable_screenshot=enable_screenshot,
- auto_approve_enabled=auto_approve_enabled,
- auto_approve_hourly_limit=auto_approve_hourly_limit,
- auto_approve_vip_days=auto_approve_vip_days,
- kdocs_enabled=kdocs_enabled,
- kdocs_doc_url=kdocs_doc_url,
- kdocs_default_unit=kdocs_default_unit,
- kdocs_sheet_name=kdocs_sheet_name,
- kdocs_sheet_index=kdocs_sheet_index,
- kdocs_unit_column=kdocs_unit_column,
- kdocs_image_column=kdocs_image_column,
- kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
- kdocs_admin_notify_email=kdocs_admin_notify_email,
- kdocs_row_start=kdocs_row_start,
- kdocs_row_end=kdocs_row_end,
- ):
- return jsonify({"error": "更新失败"}), 400
-
- try:
- new_config = database.get_system_config() or {}
- scheduler = get_task_scheduler()
- scheduler.update_limits(
- max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
- max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
- )
- if new_max_screenshot_concurrent is not None:
- try:
- from browser_pool_worker import resize_browser_worker_pool
-
- if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
- logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
- except Exception as pool_error:
- logger.warning(f"截图线程池并发更新失败: {pool_error}")
- except Exception:
- pass
-
- if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"):
- logger.info(f"全局并发数已更新为: {max_concurrent}")
- if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"):
- logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}")
- if new_max_screenshot_concurrent is not None:
- logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}")
-
- return jsonify({"message": "系统配置已更新"})
-
-
-@admin_api_bp.route("/kdocs/status", methods=["GET"])
-@admin_required
-def get_kdocs_status_api():
- """获取金山文档上传状态"""
- try:
- from services.kdocs_uploader import get_kdocs_uploader
-
- uploader = get_kdocs_uploader()
- status = uploader.get_status()
- live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
- if live:
- live_status = uploader.refresh_login_status()
- if live_status.get("success"):
- logged_in = bool(live_status.get("logged_in"))
- status["logged_in"] = logged_in
- status["last_login_ok"] = logged_in
- status["login_required"] = not logged_in
- if live_status.get("error"):
- status["last_error"] = live_status.get("error")
- else:
- status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
- if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
- status["last_error"] = None
- return jsonify(status)
- except Exception as e:
- return jsonify({"error": f"获取状态失败: {e}"}), 500
-
-
-@admin_api_bp.route("/kdocs/qr", methods=["POST"])
-@admin_required
-def get_kdocs_qr_api():
- """获取金山文档登录二维码"""
- try:
- from services.kdocs_uploader import get_kdocs_uploader
-
- uploader = get_kdocs_uploader()
- data = request.get_json(silent=True) or {}
- force = bool(data.get("force"))
- if not force:
- force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
- result = uploader.request_qr(force=force)
- if not result.get("success"):
- return jsonify({"error": result.get("error", "获取二维码失败")}), 400
- return jsonify(result)
- except Exception as e:
- return jsonify({"error": f"获取二维码失败: {e}"}), 500
-
-
-@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
-@admin_required
-def clear_kdocs_login_api():
- """清除金山文档登录态"""
- try:
- from services.kdocs_uploader import get_kdocs_uploader
-
- uploader = get_kdocs_uploader()
- result = uploader.clear_login()
- if not result.get("success"):
- return jsonify({"error": result.get("error", "清除失败")}), 400
- return jsonify({"success": True})
- except Exception as e:
- return jsonify({"error": f"清除失败: {e}"}), 500
-
-
-@admin_api_bp.route("/schedule/execute", methods=["POST"])
-@admin_required
-def execute_schedule_now():
- """立即执行定时任务(无视定时时间和星期限制)"""
- try:
- threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start()
- logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)")
- return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"})
- except Exception as e:
- logger.error(f"[立即执行定时任务] 启动失败: {str(e)}")
- return jsonify({"error": f"启动失败: {str(e)}"}), 500
-
-
-@admin_api_bp.route("/proxy/config", methods=["GET"])
-@admin_required
-def get_proxy_config_api():
- """获取代理配置"""
- config_data = database.get_system_config()
- return jsonify(
- {
- "proxy_enabled": config_data.get("proxy_enabled", 0),
- "proxy_api_url": config_data.get("proxy_api_url", ""),
- "proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3),
- }
- )
-
-
-@admin_api_bp.route("/proxy/config", methods=["POST"])
-@admin_required
-def update_proxy_config_api():
- """更新代理配置"""
- data = request.json or {}
- proxy_enabled = data.get("proxy_enabled")
- proxy_api_url = (data.get("proxy_api_url", "") or "").strip()
- proxy_expire_minutes = data.get("proxy_expire_minutes")
-
- if proxy_enabled is not None and proxy_enabled not in [0, 1]:
- return jsonify({"error": "proxy_enabled必须是0或1"}), 400
-
- if proxy_expire_minutes is not None:
- if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1:
- return jsonify({"error": "代理有效期必须是大于0的整数"}), 400
-
- if database.update_system_config(
- proxy_enabled=proxy_enabled,
- proxy_api_url=proxy_api_url,
- proxy_expire_minutes=proxy_expire_minutes,
- ):
- return jsonify({"message": "代理配置已更新"})
- return jsonify({"error": "更新失败"}), 400
-
-
-@admin_api_bp.route("/proxy/test", methods=["POST"])
-@admin_required
-def test_proxy_api():
- """测试代理连接"""
- data = request.json or {}
- api_url = (data.get("api_url") or "").strip()
-
- if not api_url:
- return jsonify({"error": "请提供API地址"}), 400
-
- if not is_safe_outbound_url(api_url):
- return jsonify({"error": "API地址不可用或不安全"}), 400
-
- try:
- response = requests.get(api_url, timeout=10)
- if response.status_code == 200:
- ip_port = response.text.strip()
- if ip_port and ":" in ip_port:
- return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"})
- return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400
- return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400
- except Exception as e:
- return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500
-
-
-@admin_api_bp.route("/server/info", methods=["GET"])
-@admin_required
-def get_server_info_api():
- """获取服务器信息"""
- import psutil
-
- cpu_percent = _get_server_cpu_percent()
-
- memory = psutil.virtual_memory()
- memory_total = f"{memory.total / (1024**3):.1f}GB"
- memory_used = f"{memory.used / (1024**3):.1f}GB"
- memory_percent = memory.percent
-
- disk = psutil.disk_usage("/")
- disk_total = f"{disk.total / (1024**3):.1f}GB"
- disk_used = f"{disk.used / (1024**3):.1f}GB"
- disk_percent = disk.percent
-
- boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ)
- uptime_delta = get_beijing_now() - boot_time
- days = uptime_delta.days
- hours = uptime_delta.seconds // 3600
- uptime = f"{days}天{hours}小时"
-
- return jsonify(
- {
- "cpu_percent": cpu_percent,
- "memory_total": memory_total,
- "memory_used": memory_used,
- "memory_percent": memory_percent,
- "disk_total": disk_total,
- "disk_used": disk_used,
- "disk_percent": disk_percent,
- "uptime": uptime,
- }
- )
-
-
-# ==================== 任务统计与日志(管理员) ====================
-
-
-@admin_api_bp.route("/task/stats", methods=["GET"])
-@admin_required
-def get_task_stats_api():
- """获取任务统计数据"""
- date_filter = request.args.get("date")
- stats = database.get_task_stats(date_filter)
- return jsonify(stats)
-
-
-@admin_api_bp.route("/task/running", methods=["GET"])
-@admin_required
-def get_running_tasks_api():
- """获取当前运行中和排队中的任务"""
- import time as time_mod
-
- current_time = time_mod.time()
- running = []
- queuing = []
-
- for account_id, info in safe_iter_task_status_items():
- elapsed = int(current_time - info.get("start_time", current_time))
-
- user = database.get_user_by_id(info.get("user_id"))
- user_username = user["username"] if user else "N/A"
-
- progress = info.get("progress", {"items": 0, "attachments": 0})
- task_info = {
- "account_id": account_id,
- "user_id": info.get("user_id"),
- "user_username": user_username,
- "username": info.get("username"),
- "browse_type": info.get("browse_type"),
- "source": info.get("source", "manual"),
- "detail_status": info.get("detail_status", "未知"),
- "progress_items": progress.get("items", 0),
- "progress_attachments": progress.get("attachments", 0),
- "elapsed_seconds": elapsed,
- "elapsed_display": f"{elapsed // 60}分{elapsed % 60}秒" if elapsed >= 60 else f"{elapsed}秒",
- }
-
- if info.get("status") == "运行中":
- running.append(task_info)
- else:
- queuing.append(task_info)
-
- running.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
- queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
-
- try:
- max_concurrent = int(get_task_scheduler().max_global)
- except Exception:
- max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2))
-
- return jsonify(
- {
- "running": running,
- "queuing": queuing,
- "running_count": len(running),
- "queuing_count": len(queuing),
- "max_concurrent": max_concurrent,
- }
- )
-
-
-@admin_api_bp.route("/task/logs", methods=["GET"])
-@admin_required
-def get_task_logs_api():
- """获取任务日志列表(支持分页和多种筛选)"""
- try:
- limit = int(request.args.get("limit", 20))
- limit = max(1, min(limit, 200)) # 限制 1-200 条
- except (ValueError, TypeError):
- limit = 20
-
- try:
- offset = int(request.args.get("offset", 0))
- offset = max(0, offset)
- except (ValueError, TypeError):
- offset = 0
-
- date_filter = request.args.get("date")
- status_filter = request.args.get("status")
- source_filter = request.args.get("source")
- user_id_filter = request.args.get("user_id")
- account_filter = (request.args.get("account") or "").strip()
-
- if user_id_filter:
- try:
- user_id_filter = int(user_id_filter)
- except (ValueError, TypeError):
- user_id_filter = None
-
- try:
- result = database.get_task_logs(
- limit=limit,
- offset=offset,
- date_filter=date_filter,
- status_filter=status_filter,
- source_filter=source_filter,
- user_id_filter=user_id_filter,
- account_filter=account_filter if account_filter else None,
- )
- return jsonify(result)
- except Exception as e:
- logger.error(f"获取任务日志失败: {e}")
- return jsonify({"logs": [], "total": 0, "error": "查询失败"})
-
-
-@admin_api_bp.route("/task/logs/clear", methods=["POST"])
-@admin_required
-def clear_old_task_logs_api():
- """清理旧的任务日志"""
- data = request.json or {}
- days = data.get("days", 30)
-
- if not isinstance(days, int) or days < 1:
- return jsonify({"error": "天数必须是大于0的整数"}), 400
-
- deleted_count = database.delete_old_task_logs(days)
- return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"})
-
-
@admin_api_bp.route("/docker/restart", methods=["POST"])
@admin_required
def restart_docker_container():
@@ -1216,7 +204,7 @@ os._exit(0)
f.write(restart_script)
subprocess.Popen(
- ["python", "/tmp/restart_container.py"],
+ ["python3", "/tmp/restart_container.py"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
@@ -1228,112 +216,6 @@ os._exit(0)
return jsonify({"error": f"重启失败: {str(e)}"}), 500
-# ==================== 密码重置 / 反馈(管理员) ====================
-
-
-@admin_api_bp.route("/admin/password", methods=["PUT"])
-@admin_required
-def update_admin_password():
- """修改管理员密码"""
- data = request.json or {}
- new_password = (data.get("new_password") or "").strip()
-
- if not new_password:
- return jsonify({"error": "密码不能为空"}), 400
-
- username = session.get("admin_username")
- if database.update_admin_password(username, new_password):
- return jsonify({"success": True})
- return jsonify({"error": "修改失败"}), 400
-
-
-@admin_api_bp.route("/admin/username", methods=["PUT"])
-@admin_required
-def update_admin_username():
- """修改管理员用户名"""
- data = request.json or {}
- new_username = (data.get("new_username") or "").strip()
-
- if not new_username:
- return jsonify({"error": "用户名不能为空"}), 400
-
- old_username = session.get("admin_username")
- if database.update_admin_username(old_username, new_username):
- session["admin_username"] = new_username
- return jsonify({"success": True})
- return jsonify({"error": "修改失败,用户名可能已存在"}), 400
-
-
-@admin_api_bp.route("/users//reset_password", methods=["POST"])
-@admin_required
-def admin_reset_password_route(user_id):
- """管理员直接重置用户密码(无需审核)"""
- data = request.json or {}
- new_password = (data.get("new_password") or "").strip()
-
- if not new_password:
- return jsonify({"error": "新密码不能为空"}), 400
-
- is_valid, error_msg = validate_password(new_password)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
-
- if database.admin_reset_user_password(user_id, new_password):
- return jsonify({"message": "密码重置成功"})
- return jsonify({"error": "重置失败,用户不存在"}), 400
-
-
-@admin_api_bp.route("/feedbacks", methods=["GET"])
-@admin_required
-def get_all_feedbacks():
- """管理员获取所有反馈"""
- status = request.args.get("status")
- try:
- limit = int(request.args.get("limit", 100))
- offset = int(request.args.get("offset", 0))
- limit = min(max(1, limit), 1000)
- offset = max(0, offset)
- except (ValueError, TypeError):
- return jsonify({"error": "无效的分页参数"}), 400
-
- feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status)
- stats = database.get_feedback_stats()
- return jsonify({"feedbacks": feedbacks, "stats": stats})
-
-
-@admin_api_bp.route("/feedbacks//reply", methods=["POST"])
-@admin_required
-def reply_to_feedback(feedback_id):
- """管理员回复反馈"""
- data = request.get_json() or {}
- reply = (data.get("reply") or "").strip()
-
- if not reply:
- return jsonify({"error": "回复内容不能为空"}), 400
-
- if database.reply_feedback(feedback_id, reply):
- return jsonify({"message": "回复成功"})
- return jsonify({"error": "反馈不存在"}), 404
-
-
-@admin_api_bp.route("/feedbacks//close", methods=["POST"])
-@admin_required
-def close_feedback_api(feedback_id):
- """管理员关闭反馈"""
- if database.close_feedback(feedback_id):
- return jsonify({"message": "已关闭"})
- return jsonify({"error": "反馈不存在"}), 404
-
-
-@admin_api_bp.route("/feedbacks/", methods=["DELETE"])
-@admin_required
-def delete_feedback_api(feedback_id):
- """管理员删除反馈"""
- if database.delete_feedback(feedback_id):
- return jsonify({"message": "已删除"})
- return jsonify({"error": "反馈不存在"}), 404
-
-
# ==================== 断点续传(管理员) ====================
@@ -1388,208 +270,3 @@ def checkpoint_abandon(task_id):
return jsonify({"success": False}), 404
except Exception as e:
return jsonify({"success": False, "message": str(e)}), 500
-
-
-# ==================== 邮件服务(管理员) ====================
-
-
-@admin_api_bp.route("/email/settings", methods=["GET"])
-@admin_required
-def get_email_settings_api():
- """获取全局邮件设置"""
- try:
- settings = email_service.get_email_settings()
- return jsonify(settings)
- except Exception as e:
- logger.error(f"获取邮件设置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/email/settings", methods=["POST"])
-@admin_required
-def update_email_settings_api():
- """更新全局邮件设置"""
- try:
- data = request.json or {}
- enabled = data.get("enabled", False)
- failover_enabled = data.get("failover_enabled", True)
- register_verify_enabled = data.get("register_verify_enabled")
- login_alert_enabled = data.get("login_alert_enabled")
- base_url = data.get("base_url")
- task_notify_enabled = data.get("task_notify_enabled")
-
- email_service.update_email_settings(
- enabled=enabled,
- failover_enabled=failover_enabled,
- register_verify_enabled=register_verify_enabled,
- login_alert_enabled=login_alert_enabled,
- base_url=base_url,
- task_notify_enabled=task_notify_enabled,
- )
- return jsonify({"success": True})
- except Exception as e:
- logger.error(f"更新邮件设置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs", methods=["GET"])
-@admin_required
-def get_smtp_configs_api():
- """获取所有SMTP配置列表"""
- try:
- configs = email_service.get_smtp_configs(include_password=False)
- return jsonify(configs)
- except Exception as e:
- logger.error(f"获取SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs", methods=["POST"])
-@admin_required
-def create_smtp_config_api():
- """创建SMTP配置"""
- try:
- data = request.json or {}
- if not data.get("host"):
- return jsonify({"error": "SMTP服务器地址不能为空"}), 400
- if not data.get("username"):
- return jsonify({"error": "SMTP用户名不能为空"}), 400
-
- config_id = email_service.create_smtp_config(data)
- return jsonify({"success": True, "id": config_id})
- except Exception as e:
- logger.error(f"创建SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs/", methods=["GET"])
-@admin_required
-def get_smtp_config_api(config_id):
- """获取单个SMTP配置详情"""
- try:
- config_data = email_service.get_smtp_config(config_id, include_password=False)
- if not config_data:
- return jsonify({"error": "配置不存在"}), 404
- return jsonify(config_data)
- except Exception as e:
- logger.error(f"获取SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs/", methods=["PUT"])
-@admin_required
-def update_smtp_config_api(config_id):
- """更新SMTP配置"""
- try:
- data = request.json or {}
- if email_service.update_smtp_config(config_id, data):
- return jsonify({"success": True})
- return jsonify({"error": "更新失败"}), 400
- except Exception as e:
- logger.error(f"更新SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs/", methods=["DELETE"])
-@admin_required
-def delete_smtp_config_api(config_id):
- """删除SMTP配置"""
- try:
- if email_service.delete_smtp_config(config_id):
- return jsonify({"success": True})
- return jsonify({"error": "删除失败"}), 400
- except Exception as e:
- logger.error(f"删除SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs//test", methods=["POST"])
-@admin_required
-def test_smtp_config_api(config_id):
- """测试SMTP配置"""
- try:
- data = request.json or {}
- test_email = str(data.get("email", "") or "").strip()
- if not test_email:
- return jsonify({"error": "请提供测试邮箱"}), 400
-
- is_valid, error_msg = validate_email(test_email)
- if not is_valid:
- return jsonify({"error": error_msg}), 400
-
- result = email_service.test_smtp_config(config_id, test_email)
- return jsonify(result)
- except Exception as e:
- logger.error(f"测试SMTP配置失败: {e}")
- return jsonify({"success": False, "error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs//primary", methods=["POST"])
-@admin_required
-def set_primary_smtp_config_api(config_id):
- """设置主SMTP配置"""
- try:
- if email_service.set_primary_smtp_config(config_id):
- return jsonify({"success": True})
- return jsonify({"error": "设置失败"}), 400
- except Exception as e:
- logger.error(f"设置主SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
-@admin_required
-def clear_primary_smtp_config_api():
- """取消主SMTP配置"""
- try:
- email_service.clear_primary_smtp_config()
- return jsonify({"success": True})
- except Exception as e:
- logger.error(f"取消主SMTP配置失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/email/stats", methods=["GET"])
-@admin_required
-def get_email_stats_api():
- """获取邮件发送统计"""
- try:
- stats = email_service.get_email_stats()
- return jsonify(stats)
- except Exception as e:
- logger.error(f"获取邮件统计失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/email/logs", methods=["GET"])
-@admin_required
-def get_email_logs_api():
- """获取邮件发送日志"""
- try:
- page = request.args.get("page", 1, type=int)
- page_size = request.args.get("page_size", 20, type=int)
- email_type = request.args.get("type", None)
- status = request.args.get("status", None)
-
- page_size = min(max(page_size, 10), 100)
- result = email_service.get_email_logs(page, page_size, email_type, status)
- return jsonify(result)
- except Exception as e:
- logger.error(f"获取邮件日志失败: {e}")
- return jsonify({"error": str(e)}), 500
-
-
-@admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
-@admin_required
-def cleanup_email_logs_api():
- """清理过期邮件日志"""
- try:
- data = request.json or {}
- days = data.get("days", 30)
- days = min(max(days, 7), 365)
-
- deleted = email_service.cleanup_email_logs(days)
- return jsonify({"success": True, "deleted": deleted})
- except Exception as e:
- logger.error(f"清理邮件日志失败: {e}")
- return jsonify({"error": str(e)}), 500
diff --git a/routes/admin_api/email_api.py b/routes/admin_api/email_api.py
new file mode 100644
index 0000000..d4aaa5f
--- /dev/null
+++ b/routes/admin_api/email_api.py
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import email_service
+from app_logger import get_logger
+from app_security import validate_email
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+
+logger = get_logger("app")
+
+
+@admin_api_bp.route("/email/settings", methods=["GET"])
+@admin_required
+def get_email_settings_api():
+ """获取全局邮件设置"""
+ try:
+ settings = email_service.get_email_settings()
+ return jsonify(settings)
+ except Exception as e:
+ logger.error(f"获取邮件设置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/email/settings", methods=["POST"])
+@admin_required
+def update_email_settings_api():
+ """更新全局邮件设置"""
+ try:
+ data = request.json or {}
+ enabled = data.get("enabled", False)
+ failover_enabled = data.get("failover_enabled", True)
+ register_verify_enabled = data.get("register_verify_enabled")
+ login_alert_enabled = data.get("login_alert_enabled")
+ base_url = data.get("base_url")
+ task_notify_enabled = data.get("task_notify_enabled")
+
+ email_service.update_email_settings(
+ enabled=enabled,
+ failover_enabled=failover_enabled,
+ register_verify_enabled=register_verify_enabled,
+ login_alert_enabled=login_alert_enabled,
+ base_url=base_url,
+ task_notify_enabled=task_notify_enabled,
+ )
+ return jsonify({"success": True})
+ except Exception as e:
+ logger.error(f"更新邮件设置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs", methods=["GET"])
+@admin_required
+def get_smtp_configs_api():
+ """获取所有SMTP配置列表"""
+ try:
+ configs = email_service.get_smtp_configs(include_password=False)
+ return jsonify(configs)
+ except Exception as e:
+ logger.error(f"获取SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs", methods=["POST"])
+@admin_required
+def create_smtp_config_api():
+ """创建SMTP配置"""
+ try:
+ data = request.json or {}
+ if not data.get("host"):
+ return jsonify({"error": "SMTP服务器地址不能为空"}), 400
+ if not data.get("username"):
+ return jsonify({"error": "SMTP用户名不能为空"}), 400
+
+ config_id = email_service.create_smtp_config(data)
+ return jsonify({"success": True, "id": config_id})
+ except Exception as e:
+ logger.error(f"创建SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs/", methods=["GET"])
+@admin_required
+def get_smtp_config_api(config_id):
+ """获取单个SMTP配置详情"""
+ try:
+ config_data = email_service.get_smtp_config(config_id, include_password=False)
+ if not config_data:
+ return jsonify({"error": "配置不存在"}), 404
+ return jsonify(config_data)
+ except Exception as e:
+ logger.error(f"获取SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs/", methods=["PUT"])
+@admin_required
+def update_smtp_config_api(config_id):
+ """更新SMTP配置"""
+ try:
+ data = request.json or {}
+ if email_service.update_smtp_config(config_id, data):
+ return jsonify({"success": True})
+ return jsonify({"error": "更新失败"}), 400
+ except Exception as e:
+ logger.error(f"更新SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs/", methods=["DELETE"])
+@admin_required
+def delete_smtp_config_api(config_id):
+ """删除SMTP配置"""
+ try:
+ if email_service.delete_smtp_config(config_id):
+ return jsonify({"success": True})
+ return jsonify({"error": "删除失败"}), 400
+ except Exception as e:
+ logger.error(f"删除SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs//test", methods=["POST"])
+@admin_required
+def test_smtp_config_api(config_id):
+ """测试SMTP配置"""
+ try:
+ data = request.json or {}
+ test_email = str(data.get("email", "") or "").strip()
+ if not test_email:
+ return jsonify({"error": "请提供测试邮箱"}), 400
+
+ is_valid, error_msg = validate_email(test_email)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
+
+ result = email_service.test_smtp_config(config_id, test_email)
+ return jsonify(result)
+ except Exception as e:
+ logger.error(f"测试SMTP配置失败: {e}")
+ return jsonify({"success": False, "error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs//primary", methods=["POST"])
+@admin_required
+def set_primary_smtp_config_api(config_id):
+ """设置主SMTP配置"""
+ try:
+ if email_service.set_primary_smtp_config(config_id):
+ return jsonify({"success": True})
+ return jsonify({"error": "设置失败"}), 400
+ except Exception as e:
+ logger.error(f"设置主SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"])
+@admin_required
+def clear_primary_smtp_config_api():
+ """取消主SMTP配置"""
+ try:
+ email_service.clear_primary_smtp_config()
+ return jsonify({"success": True})
+ except Exception as e:
+ logger.error(f"取消主SMTP配置失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/email/stats", methods=["GET"])
+@admin_required
+def get_email_stats_api():
+ """获取邮件发送统计"""
+ try:
+ stats = email_service.get_email_stats()
+ return jsonify(stats)
+ except Exception as e:
+ logger.error(f"获取邮件统计失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/email/logs", methods=["GET"])
+@admin_required
+def get_email_logs_api():
+ """获取邮件发送日志"""
+ try:
+ page = request.args.get("page", 1, type=int)
+ page_size = request.args.get("page_size", 20, type=int)
+ email_type = request.args.get("type", None)
+ status = request.args.get("status", None)
+
+ page_size = min(max(page_size, 10), 100)
+ result = email_service.get_email_logs(page, page_size, email_type, status)
+ return jsonify(result)
+ except Exception as e:
+ logger.error(f"获取邮件日志失败: {e}")
+ return jsonify({"error": str(e)}), 500
+
+
+@admin_api_bp.route("/email/logs/cleanup", methods=["POST"])
+@admin_required
+def cleanup_email_logs_api():
+ """清理过期邮件日志"""
+ try:
+ data = request.json or {}
+ days = data.get("days", 30)
+ days = min(max(days, 7), 365)
+
+ deleted = email_service.cleanup_email_logs(days)
+ return jsonify({"success": True, "deleted": deleted})
+ except Exception as e:
+ logger.error(f"清理邮件日志失败: {e}")
+ return jsonify({"error": str(e)}), 500
diff --git a/routes/admin_api/feedback_api.py b/routes/admin_api/feedback_api.py
new file mode 100644
index 0000000..e921563
--- /dev/null
+++ b/routes/admin_api/feedback_api.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import database
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+
+@admin_api_bp.route("/feedbacks", methods=["GET"])
+@admin_required
+def get_all_feedbacks():
+ """管理员获取所有反馈"""
+ status = request.args.get("status")
+ try:
+ limit = int(request.args.get("limit", 100))
+ offset = int(request.args.get("offset", 0))
+ limit = min(max(1, limit), 1000)
+ offset = max(0, offset)
+ except (ValueError, TypeError):
+ return jsonify({"error": "无效的分页参数"}), 400
+
+ feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status)
+ stats = database.get_feedback_stats()
+ return jsonify({"feedbacks": feedbacks, "stats": stats})
+
+
+@admin_api_bp.route("/feedbacks//reply", methods=["POST"])
+@admin_required
+def reply_to_feedback(feedback_id):
+ """管理员回复反馈"""
+ data = request.get_json() or {}
+ reply = (data.get("reply") or "").strip()
+
+ if not reply:
+ return jsonify({"error": "回复内容不能为空"}), 400
+
+ if database.reply_feedback(feedback_id, reply):
+ return jsonify({"message": "回复成功"})
+ return jsonify({"error": "反馈不存在"}), 404
+
+
+@admin_api_bp.route("/feedbacks//close", methods=["POST"])
+@admin_required
+def close_feedback_api(feedback_id):
+ """管理员关闭反馈"""
+ if database.close_feedback(feedback_id):
+ return jsonify({"message": "已关闭"})
+ return jsonify({"error": "反馈不存在"}), 404
+
+
+@admin_api_bp.route("/feedbacks/", methods=["DELETE"])
+@admin_required
+def delete_feedback_api(feedback_id):
+ """管理员删除反馈"""
+ if database.delete_feedback(feedback_id):
+ return jsonify({"message": "已删除"})
+ return jsonify({"error": "反馈不存在"}), 404
diff --git a/routes/admin_api/infra_api.py b/routes/admin_api/infra_api.py
new file mode 100644
index 0000000..fd8371b
--- /dev/null
+++ b/routes/admin_api/infra_api.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import os
+import time
+from datetime import datetime
+
+import database
+from app_logger import get_logger
+from flask import jsonify, session
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+from services.time_utils import BEIJING_TZ, get_beijing_now
+
+logger = get_logger("app")
+
+
+@admin_api_bp.route("/stats", methods=["GET"])
+@admin_required
+def get_system_stats():
+ """获取系统统计"""
+ stats = database.get_system_stats()
+ stats["admin_username"] = session.get("admin_username", "admin")
+ return jsonify(stats)
+
+
+@admin_api_bp.route("/browser_pool/stats", methods=["GET"])
+@admin_required
+def get_browser_pool_stats():
+ """获取截图线程池状态"""
+ try:
+ from browser_pool_worker import get_browser_worker_pool
+
+ pool = get_browser_worker_pool()
+ stats = pool.get_stats() or {}
+
+ worker_details = []
+ for w in stats.get("workers") or []:
+ last_ts = float(w.get("last_active_ts") or 0)
+ last_active_at = None
+ if last_ts > 0:
+ try:
+ last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
+ except Exception:
+ last_active_at = None
+
+ created_ts = w.get("browser_created_at")
+ created_at = None
+ if created_ts:
+ try:
+ created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S")
+ except Exception:
+ created_at = None
+
+ worker_details.append(
+ {
+ "worker_id": w.get("worker_id"),
+ "idle": bool(w.get("idle")),
+ "has_browser": bool(w.get("has_browser")),
+ "total_tasks": int(w.get("total_tasks") or 0),
+ "failed_tasks": int(w.get("failed_tasks") or 0),
+ "browser_use_count": int(w.get("browser_use_count") or 0),
+ "browser_created_at": created_at,
+ "browser_created_ts": created_ts,
+ "last_active_at": last_active_at,
+ "last_active_ts": last_ts,
+ "thread_alive": bool(w.get("thread_alive")),
+ }
+ )
+
+ total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0)
+ return jsonify(
+ {
+ "total_workers": total_workers,
+ "active_workers": int(stats.get("busy_workers") or 0),
+ "idle_workers": int(stats.get("idle_workers") or 0),
+ "queue_size": int(stats.get("queue_size") or 0),
+ "workers": worker_details,
+ "summary": {
+ "total_tasks": int(stats.get("total_tasks") or 0),
+ "failed_tasks": int(stats.get("failed_tasks") or 0),
+ "success_rate": stats.get("success_rate"),
+ },
+ "server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
+ }
+ )
+ except Exception as e:
+ logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}")
+ return jsonify({"error": "获取截图线程池状态失败"}), 500
+
+
+@admin_api_bp.route("/docker_stats", methods=["GET"])
+@admin_required
+def get_docker_stats():
+ """获取Docker容器运行状态"""
+ import subprocess
+
+ docker_status = {
+ "running": False,
+ "container_name": "N/A",
+ "uptime": "N/A",
+ "memory_usage": "N/A",
+ "memory_limit": "N/A",
+ "memory_percent": "N/A",
+ "cpu_percent": "N/A",
+ "status": "Unknown",
+ }
+
+ try:
+ if os.path.exists("/.dockerenv"):
+ docker_status["running"] = True
+
+ try:
+ with open("/etc/hostname", "r") as f:
+ docker_status["container_name"] = f.read().strip()
+ except Exception as e:
+ logger.debug(f"读取容器名称失败: {e}")
+
+ try:
+ if os.path.exists("/sys/fs/cgroup/memory.current"):
+ with open("/sys/fs/cgroup/memory.current", "r") as f:
+ mem_total = int(f.read().strip())
+
+ cache = 0
+ if os.path.exists("/sys/fs/cgroup/memory.stat"):
+ with open("/sys/fs/cgroup/memory.stat", "r") as f:
+ for line in f:
+ if line.startswith("inactive_file "):
+ cache = int(line.split()[1])
+ break
+
+ mem_bytes = mem_total - cache
+ docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
+
+ if os.path.exists("/sys/fs/cgroup/memory.max"):
+ with open("/sys/fs/cgroup/memory.max", "r") as f:
+ limit_str = f.read().strip()
+ if limit_str != "max":
+ limit_bytes = int(limit_str)
+ docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
+ docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
+ elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"):
+ with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
+ mem_bytes = int(f.read().strip())
+ docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024)
+
+ with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
+ limit_bytes = int(f.read().strip())
+ if limit_bytes < 1e18:
+ docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024)
+ docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100)
+ except Exception as e:
+ logger.debug(f"读取内存信息失败: {e}")
+
+ try:
+ if os.path.exists("/sys/fs/cgroup/cpu.stat"):
+ cpu_usage = 0
+ with open("/sys/fs/cgroup/cpu.stat", "r") as f:
+ for line in f:
+ if line.startswith("usage_usec"):
+ cpu_usage = int(line.split()[1])
+ break
+
+ time.sleep(0.1)
+ cpu_usage2 = 0
+ with open("/sys/fs/cgroup/cpu.stat", "r") as f:
+ for line in f:
+ if line.startswith("usage_usec"):
+ cpu_usage2 = int(line.split()[1])
+ break
+
+ cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100
+ docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
+ elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"):
+ with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
+ cpu_usage = int(f.read().strip())
+
+ time.sleep(0.1)
+ with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f:
+ cpu_usage2 = int(f.read().strip())
+
+ cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100
+ docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent)
+ except Exception as e:
+ logger.debug(f"读取CPU信息失败: {e}")
+
+ try:
+ # 读取系统运行时间
+ with open('/proc/uptime', 'r') as f:
+ system_uptime = float(f.read().split()[0])
+
+ # 读取 PID 1 的启动时间 (jiffies)
+ with open('/proc/1/stat', 'r') as f:
+ stat = f.read().split()
+ starttime_jiffies = int(stat[21])
+
+ # 获取 CLK_TCK (通常是 100)
+ clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK'])
+
+ # 计算容器运行时长(秒)
+ container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck)
+
+ # 格式化为可读字符串
+ days = int(container_uptime_seconds // 86400)
+ hours = int((container_uptime_seconds % 86400) // 3600)
+ minutes = int((container_uptime_seconds % 3600) // 60)
+
+ if days > 0:
+ docker_status["uptime"] = f"{days}天{hours}小时{minutes}分钟"
+ elif hours > 0:
+ docker_status["uptime"] = f"{hours}小时{minutes}分钟"
+ else:
+ docker_status["uptime"] = f"{minutes}分钟"
+ except Exception as e:
+ logger.debug(f"获取容器运行时间失败: {e}")
+
+ docker_status["status"] = "Running"
+
+ else:
+ docker_status["status"] = "Not in Docker"
+ except Exception as e:
+ docker_status["status"] = f"Error: {str(e)}"
+
+ return jsonify(docker_status)
+
diff --git a/routes/admin_api/operations_api.py b/routes/admin_api/operations_api.py
new file mode 100644
index 0000000..7f35d8c
--- /dev/null
+++ b/routes/admin_api/operations_api.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import threading
+import time
+from datetime import datetime
+
+import database
+import requests
+from app_logger import get_logger
+from app_security import is_safe_outbound_url
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+from services.scheduler import run_scheduled_task
+from services.time_utils import BEIJING_TZ, get_beijing_now
+
+logger = get_logger("app")
+
+_server_cpu_percent_lock = threading.Lock()
+_server_cpu_percent_last: float | None = None
+_server_cpu_percent_last_ts = 0.0
+
+
+def _get_server_cpu_percent() -> float:
+ import psutil
+
+ global _server_cpu_percent_last, _server_cpu_percent_last_ts
+
+ now = time.time()
+ with _server_cpu_percent_lock:
+ if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5:
+ return _server_cpu_percent_last
+
+ try:
+ if _server_cpu_percent_last is None:
+ cpu_percent = float(psutil.cpu_percent(interval=0.1))
+ else:
+ cpu_percent = float(psutil.cpu_percent(interval=None))
+ except Exception:
+ cpu_percent = float(_server_cpu_percent_last or 0.0)
+
+ if cpu_percent < 0:
+ cpu_percent = 0.0
+
+ _server_cpu_percent_last = cpu_percent
+ _server_cpu_percent_last_ts = now
+ return cpu_percent
+
+
+@admin_api_bp.route("/kdocs/status", methods=["GET"])
+@admin_required
+def get_kdocs_status_api():
+ """获取金山文档上传状态"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ status = uploader.get_status()
+ live = str(request.args.get("live", "")).lower() in ("1", "true", "yes")
+ if live:
+ live_status = uploader.refresh_login_status()
+ if live_status.get("success"):
+ logged_in = bool(live_status.get("logged_in"))
+ status["logged_in"] = logged_in
+ status["last_login_ok"] = logged_in
+ status["login_required"] = not logged_in
+ if live_status.get("error"):
+ status["last_error"] = live_status.get("error")
+ else:
+ status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None
+ if status.get("last_login_ok") is True and status.get("last_error") == "操作超时":
+ status["last_error"] = None
+ return jsonify(status)
+ except Exception as e:
+ return jsonify({"error": f"获取状态失败: {e}"}), 500
+
+
+@admin_api_bp.route("/kdocs/qr", methods=["POST"])
+@admin_required
+def get_kdocs_qr_api():
+ """获取金山文档登录二维码"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ data = request.get_json(silent=True) or {}
+ force = bool(data.get("force"))
+ if not force:
+ force = str(request.args.get("force", "")).lower() in ("1", "true", "yes")
+ result = uploader.request_qr(force=force)
+ if not result.get("success"):
+ return jsonify({"error": result.get("error", "获取二维码失败")}), 400
+ return jsonify(result)
+ except Exception as e:
+ return jsonify({"error": f"获取二维码失败: {e}"}), 500
+
+
+@admin_api_bp.route("/kdocs/clear-login", methods=["POST"])
+@admin_required
+def clear_kdocs_login_api():
+ """清除金山文档登录态"""
+ try:
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ uploader = get_kdocs_uploader()
+ result = uploader.clear_login()
+ if not result.get("success"):
+ return jsonify({"error": result.get("error", "清除失败")}), 400
+ return jsonify({"success": True})
+ except Exception as e:
+ return jsonify({"error": f"清除失败: {e}"}), 500
+
+
+@admin_api_bp.route("/schedule/execute", methods=["POST"])
+@admin_required
+def execute_schedule_now():
+ """立即执行定时任务(无视定时时间和星期限制)"""
+ try:
+ threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start()
+ logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)")
+ return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"})
+ except Exception as e:
+ logger.error(f"[立即执行定时任务] 启动失败: {str(e)}")
+ return jsonify({"error": f"启动失败: {str(e)}"}), 500
+
+
+@admin_api_bp.route("/proxy/config", methods=["GET"])
+@admin_required
+def get_proxy_config_api():
+ """获取代理配置"""
+ config_data = database.get_system_config()
+ return jsonify(
+ {
+ "proxy_enabled": config_data.get("proxy_enabled", 0),
+ "proxy_api_url": config_data.get("proxy_api_url", ""),
+ "proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3),
+ }
+ )
+
+
+@admin_api_bp.route("/proxy/config", methods=["POST"])
+@admin_required
+def update_proxy_config_api():
+ """更新代理配置"""
+ data = request.json or {}
+ proxy_enabled = data.get("proxy_enabled")
+ proxy_api_url = (data.get("proxy_api_url", "") or "").strip()
+ proxy_expire_minutes = data.get("proxy_expire_minutes")
+
+ if proxy_enabled is not None and proxy_enabled not in [0, 1]:
+ return jsonify({"error": "proxy_enabled必须是0或1"}), 400
+
+ if proxy_expire_minutes is not None:
+ if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1:
+ return jsonify({"error": "代理有效期必须是大于0的整数"}), 400
+
+ if database.update_system_config(
+ proxy_enabled=proxy_enabled,
+ proxy_api_url=proxy_api_url,
+ proxy_expire_minutes=proxy_expire_minutes,
+ ):
+ return jsonify({"message": "代理配置已更新"})
+ return jsonify({"error": "更新失败"}), 400
+
+
+@admin_api_bp.route("/proxy/test", methods=["POST"])
+@admin_required
+def test_proxy_api():
+ """测试代理连接"""
+ data = request.json or {}
+ api_url = (data.get("api_url") or "").strip()
+
+ if not api_url:
+ return jsonify({"error": "请提供API地址"}), 400
+
+ if not is_safe_outbound_url(api_url):
+ return jsonify({"error": "API地址不可用或不安全"}), 400
+
+ try:
+ response = requests.get(api_url, timeout=10)
+ if response.status_code == 200:
+ ip_port = response.text.strip()
+ if ip_port and ":" in ip_port:
+ return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"})
+ return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400
+ return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400
+ except Exception as e:
+ return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500
+
+
+@admin_api_bp.route("/server/info", methods=["GET"])
+@admin_required
+def get_server_info_api():
+ """获取服务器信息"""
+ import psutil
+
+ cpu_percent = _get_server_cpu_percent()
+
+ memory = psutil.virtual_memory()
+ memory_total = f"{memory.total / (1024**3):.1f}GB"
+ memory_used = f"{memory.used / (1024**3):.1f}GB"
+ memory_percent = memory.percent
+
+ disk = psutil.disk_usage("/")
+ disk_total = f"{disk.total / (1024**3):.1f}GB"
+ disk_used = f"{disk.used / (1024**3):.1f}GB"
+ disk_percent = disk.percent
+
+ boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ)
+ uptime_delta = get_beijing_now() - boot_time
+ days = uptime_delta.days
+ hours = uptime_delta.seconds // 3600
+ uptime = f"{days}天{hours}小时"
+
+ return jsonify(
+ {
+ "cpu_percent": cpu_percent,
+ "memory_total": memory_total,
+ "memory_used": memory_used,
+ "memory_percent": memory_percent,
+ "disk_total": disk_total,
+ "disk_used": disk_used,
+ "disk_percent": disk_percent,
+ "uptime": uptime,
+ }
+ )
diff --git a/routes/admin_api/security.py b/routes/admin_api/security.py
index f6a4caa..9399f7c 100644
--- a/routes/admin_api/security.py
+++ b/routes/admin_api/security.py
@@ -62,6 +62,19 @@ def _parse_bool(value: Any) -> bool:
return text in {"1", "true", "yes", "y", "on"}
+def _parse_int(value: Any, *, default: int | None = None, min_value: int | None = None) -> int | None:
+ try:
+ parsed = int(value)
+ except Exception:
+ parsed = default
+
+ if parsed is None:
+ return None
+ if min_value is not None:
+ parsed = max(int(min_value), parsed)
+ return parsed
+
+
def _sanitize_threat_event(event: dict) -> dict:
return {
"id": event.get("id"),
@@ -199,10 +212,7 @@ def ban_ip():
if not reason:
return jsonify({"error": "reason不能为空"}), 400
- try:
- duration_hours = max(1, int(duration_hours_raw))
- except Exception:
- duration_hours = 24
+ duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
@@ -235,20 +245,14 @@ def ban_user():
duration_hours_raw = data.get("duration_hours", 24)
permanent = _parse_bool(data.get("permanent", False))
- try:
- user_id = int(user_id_raw)
- except Exception:
- user_id = None
+ user_id = _parse_int(user_id_raw)
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
if not reason:
return jsonify({"error": "reason不能为空"}), 400
- try:
- duration_hours = max(1, int(duration_hours_raw))
- except Exception:
- duration_hours = 24
+ duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24
ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent)
if not ok:
@@ -262,10 +266,7 @@ def unban_user():
"""解除用户封禁"""
data = _parse_json()
user_id_raw = data.get("user_id")
- try:
- user_id = int(user_id_raw)
- except Exception:
- user_id = None
+ user_id = _parse_int(user_id_raw)
if user_id is None:
return jsonify({"error": "user_id不能为空"}), 400
diff --git a/routes/admin_api/system_config_api.py b/routes/admin_api/system_config_api.py
new file mode 100644
index 0000000..478b950
--- /dev/null
+++ b/routes/admin_api/system_config_api.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import database
+from app_logger import get_logger
+from app_security import is_safe_outbound_url, validate_email
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type
+from services.tasks import get_task_scheduler
+
+logger = get_logger("app")
+
+
+@admin_api_bp.route("/system/config", methods=["GET"])
+@admin_required
+def get_system_config_api():
+ """获取系统配置"""
+ return jsonify(database.get_system_config())
+
+
+@admin_api_bp.route("/system/config", methods=["POST"])
+@admin_required
+def update_system_config_api():
+ """更新系统配置"""
+ data = request.json or {}
+
+ max_concurrent = data.get("max_concurrent_global")
+ schedule_enabled = data.get("schedule_enabled")
+ schedule_time = data.get("schedule_time")
+ schedule_browse_type = data.get("schedule_browse_type")
+ schedule_weekdays = data.get("schedule_weekdays")
+ new_max_concurrent_per_account = data.get("max_concurrent_per_account")
+ new_max_screenshot_concurrent = data.get("max_screenshot_concurrent")
+ enable_screenshot = data.get("enable_screenshot")
+ auto_approve_enabled = data.get("auto_approve_enabled")
+ auto_approve_hourly_limit = data.get("auto_approve_hourly_limit")
+ auto_approve_vip_days = data.get("auto_approve_vip_days")
+ kdocs_enabled = data.get("kdocs_enabled")
+ kdocs_doc_url = data.get("kdocs_doc_url")
+ kdocs_default_unit = data.get("kdocs_default_unit")
+ kdocs_sheet_name = data.get("kdocs_sheet_name")
+ kdocs_sheet_index = data.get("kdocs_sheet_index")
+ kdocs_unit_column = data.get("kdocs_unit_column")
+ kdocs_image_column = data.get("kdocs_image_column")
+ kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled")
+ kdocs_admin_notify_email = data.get("kdocs_admin_notify_email")
+ kdocs_row_start = data.get("kdocs_row_start")
+ kdocs_row_end = data.get("kdocs_row_end")
+
+ if max_concurrent is not None:
+ if not isinstance(max_concurrent, int) or max_concurrent < 1:
+ return jsonify({"error": "全局并发数必须大于0(建议:小型服务器2-5,中型5-10,大型10-20)"}), 400
+
+ if new_max_concurrent_per_account is not None:
+ if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1:
+ return jsonify({"error": "单账号并发数必须大于0(建议设为1,避免同一用户任务相互影响)"}), 400
+
+ if new_max_screenshot_concurrent is not None:
+ if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1:
+ return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400
+
+ if enable_screenshot is not None:
+ if isinstance(enable_screenshot, bool):
+ enable_screenshot = 1 if enable_screenshot else 0
+ if enable_screenshot not in (0, 1):
+ return jsonify({"error": "截图开关必须是0或1"}), 400
+
+ if schedule_time is not None:
+ import re
+
+ if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time):
+ return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400
+
+ if schedule_browse_type is not None:
+ normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ)
+ if not normalized:
+ return jsonify({"error": "浏览类型无效"}), 400
+ schedule_browse_type = normalized
+
+ if schedule_weekdays is not None:
+ try:
+ days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()]
+ if not all(1 <= d <= 7 for d in days):
+ return jsonify({"error": "星期数字必须在1-7之间"}), 400
+ except (ValueError, AttributeError):
+ return jsonify({"error": "星期格式错误"}), 400
+
+ if auto_approve_hourly_limit is not None:
+ if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1:
+ return jsonify({"error": "每小时注册限制必须大于0"}), 400
+
+ if auto_approve_vip_days is not None:
+ if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0:
+ return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400
+
+ if kdocs_enabled is not None:
+ if isinstance(kdocs_enabled, bool):
+ kdocs_enabled = 1 if kdocs_enabled else 0
+ if kdocs_enabled not in (0, 1):
+ return jsonify({"error": "表格上传开关必须是0或1"}), 400
+
+ if kdocs_doc_url is not None:
+ kdocs_doc_url = str(kdocs_doc_url or "").strip()
+ if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url):
+ return jsonify({"error": "文档链接格式不正确"}), 400
+
+ if kdocs_default_unit is not None:
+ kdocs_default_unit = str(kdocs_default_unit or "").strip()
+ if len(kdocs_default_unit) > 50:
+ return jsonify({"error": "默认县区长度不能超过50"}), 400
+
+ if kdocs_sheet_name is not None:
+ kdocs_sheet_name = str(kdocs_sheet_name or "").strip()
+ if len(kdocs_sheet_name) > 50:
+ return jsonify({"error": "Sheet名称长度不能超过50"}), 400
+
+ if kdocs_sheet_index is not None:
+ try:
+ kdocs_sheet_index = int(kdocs_sheet_index)
+ except Exception:
+ return jsonify({"error": "Sheet序号必须是数字"}), 400
+ if kdocs_sheet_index < 0:
+ return jsonify({"error": "Sheet序号不能为负数"}), 400
+
+ if kdocs_unit_column is not None:
+ kdocs_unit_column = str(kdocs_unit_column or "").strip().upper()
+ if not kdocs_unit_column:
+ return jsonify({"error": "县区列不能为空"}), 400
+ import re
+
+ if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column):
+ return jsonify({"error": "县区列格式错误"}), 400
+
+ if kdocs_image_column is not None:
+ kdocs_image_column = str(kdocs_image_column or "").strip().upper()
+ if not kdocs_image_column:
+ return jsonify({"error": "图片列不能为空"}), 400
+ import re
+
+ if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column):
+ return jsonify({"error": "图片列格式错误"}), 400
+
+ if kdocs_admin_notify_enabled is not None:
+ if isinstance(kdocs_admin_notify_enabled, bool):
+ kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0
+ if kdocs_admin_notify_enabled not in (0, 1):
+ return jsonify({"error": "管理员通知开关必须是0或1"}), 400
+
+ if kdocs_admin_notify_email is not None:
+ kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip()
+ if kdocs_admin_notify_email:
+ is_valid, error_msg = validate_email(kdocs_admin_notify_email)
+ if not is_valid:
+ return jsonify({"error": error_msg}), 400
+
+ if kdocs_row_start is not None:
+ try:
+ kdocs_row_start = int(kdocs_row_start)
+ except (ValueError, TypeError):
+ return jsonify({"error": "起始行必须是数字"}), 400
+ if kdocs_row_start < 0:
+ return jsonify({"error": "起始行不能为负数"}), 400
+
+ if kdocs_row_end is not None:
+ try:
+ kdocs_row_end = int(kdocs_row_end)
+ except (ValueError, TypeError):
+ return jsonify({"error": "结束行必须是数字"}), 400
+ if kdocs_row_end < 0:
+ return jsonify({"error": "结束行不能为负数"}), 400
+
+ old_config = database.get_system_config() or {}
+
+ if not database.update_system_config(
+ max_concurrent=max_concurrent,
+ schedule_enabled=schedule_enabled,
+ schedule_time=schedule_time,
+ schedule_browse_type=schedule_browse_type,
+ schedule_weekdays=schedule_weekdays,
+ max_concurrent_per_account=new_max_concurrent_per_account,
+ max_screenshot_concurrent=new_max_screenshot_concurrent,
+ enable_screenshot=enable_screenshot,
+ auto_approve_enabled=auto_approve_enabled,
+ auto_approve_hourly_limit=auto_approve_hourly_limit,
+ auto_approve_vip_days=auto_approve_vip_days,
+ kdocs_enabled=kdocs_enabled,
+ kdocs_doc_url=kdocs_doc_url,
+ kdocs_default_unit=kdocs_default_unit,
+ kdocs_sheet_name=kdocs_sheet_name,
+ kdocs_sheet_index=kdocs_sheet_index,
+ kdocs_unit_column=kdocs_unit_column,
+ kdocs_image_column=kdocs_image_column,
+ kdocs_admin_notify_enabled=kdocs_admin_notify_enabled,
+ kdocs_admin_notify_email=kdocs_admin_notify_email,
+ kdocs_row_start=kdocs_row_start,
+ kdocs_row_end=kdocs_row_end,
+ ):
+ return jsonify({"error": "更新失败"}), 400
+
+ try:
+ new_config = database.get_system_config() or {}
+ scheduler = get_task_scheduler()
+ scheduler.update_limits(
+ max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))),
+ max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))),
+ )
+ if new_max_screenshot_concurrent is not None:
+ try:
+ from browser_pool_worker import resize_browser_worker_pool
+
+ if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))):
+ logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}")
+ except Exception as pool_error:
+ logger.warning(f"截图线程池并发更新失败: {pool_error}")
+ except Exception:
+ pass
+
+ if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"):
+ logger.info(f"全局并发数已更新为: {max_concurrent}")
+ if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"):
+ logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}")
+ if new_max_screenshot_concurrent is not None:
+ logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}")
+
+ return jsonify({"message": "系统配置已更新"})
diff --git a/routes/admin_api/tasks_api.py b/routes/admin_api/tasks_api.py
new file mode 100644
index 0000000..3814717
--- /dev/null
+++ b/routes/admin_api/tasks_api.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import database
+from app_logger import get_logger
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+from services.state import safe_iter_task_status_items
+from services.tasks import get_task_scheduler
+
+logger = get_logger("app")
+
+
+def _parse_page_int(name: str, default: int, *, minimum: int, maximum: int) -> int:
+ try:
+ value = int(request.args.get(name, default))
+ return max(minimum, min(value, maximum))
+ except (ValueError, TypeError):
+ return default
+
+
+@admin_api_bp.route("/task/stats", methods=["GET"])
+@admin_required
+def get_task_stats_api():
+ """获取任务统计数据"""
+ date_filter = request.args.get("date")
+ stats = database.get_task_stats(date_filter)
+ return jsonify(stats)
+
+
+@admin_api_bp.route("/task/running", methods=["GET"])
+@admin_required
+def get_running_tasks_api():
+ """获取当前运行中和排队中的任务"""
+ import time as time_mod
+
+ current_time = time_mod.time()
+ running = []
+ queuing = []
+ user_cache = {}
+
+ for account_id, info in safe_iter_task_status_items():
+ elapsed = int(current_time - info.get("start_time", current_time))
+
+ info_user_id = info.get("user_id")
+ if info_user_id not in user_cache:
+ user_cache[info_user_id] = database.get_user_by_id(info_user_id)
+ user = user_cache.get(info_user_id)
+ user_username = user["username"] if user else "N/A"
+
+ progress = info.get("progress", {"items": 0, "attachments": 0})
+ task_info = {
+ "account_id": account_id,
+ "user_id": info.get("user_id"),
+ "user_username": user_username,
+ "username": info.get("username"),
+ "browse_type": info.get("browse_type"),
+ "source": info.get("source", "manual"),
+ "detail_status": info.get("detail_status", "未知"),
+ "progress_items": progress.get("items", 0),
+ "progress_attachments": progress.get("attachments", 0),
+ "elapsed_seconds": elapsed,
+ "elapsed_display": f"{elapsed // 60}分{elapsed % 60}秒" if elapsed >= 60 else f"{elapsed}秒",
+ }
+
+ if info.get("status") == "运行中":
+ running.append(task_info)
+ else:
+ queuing.append(task_info)
+
+ running.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
+ queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True)
+
+ try:
+ max_concurrent = int(get_task_scheduler().max_global)
+ except Exception:
+ max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2))
+
+ return jsonify(
+ {
+ "running": running,
+ "queuing": queuing,
+ "running_count": len(running),
+ "queuing_count": len(queuing),
+ "max_concurrent": max_concurrent,
+ }
+ )
+
+
+@admin_api_bp.route("/task/logs", methods=["GET"])
+@admin_required
+def get_task_logs_api():
+ """获取任务日志列表(支持分页和多种筛选)"""
+ limit = _parse_page_int("limit", 20, minimum=1, maximum=200)
+ offset = _parse_page_int("offset", 0, minimum=0, maximum=10**9)
+
+ date_filter = request.args.get("date")
+ status_filter = request.args.get("status")
+ source_filter = request.args.get("source")
+ user_id_filter = request.args.get("user_id")
+ account_filter = (request.args.get("account") or "").strip()
+
+ if user_id_filter:
+ try:
+ user_id_filter = int(user_id_filter)
+ except (ValueError, TypeError):
+ user_id_filter = None
+
+ try:
+ result = database.get_task_logs(
+ limit=limit,
+ offset=offset,
+ date_filter=date_filter,
+ status_filter=status_filter,
+ source_filter=source_filter,
+ user_id_filter=user_id_filter,
+ account_filter=account_filter if account_filter else None,
+ )
+ return jsonify(result)
+ except Exception as e:
+ logger.error(f"获取任务日志失败: {e}")
+ return jsonify({"logs": [], "total": 0, "error": "查询失败"})
+
+
+@admin_api_bp.route("/task/logs/clear", methods=["POST"])
+@admin_required
+def clear_old_task_logs_api():
+ """清理旧的任务日志"""
+ data = request.json or {}
+ days = data.get("days", 30)
+
+ if not isinstance(days, int) or days < 1:
+ return jsonify({"error": "天数必须是大于0的整数"}), 400
+
+ deleted_count = database.delete_old_task_logs(days)
+ return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"})
diff --git a/routes/admin_api/users_api.py b/routes/admin_api/users_api.py
new file mode 100644
index 0000000..8357d43
--- /dev/null
+++ b/routes/admin_api/users_api.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+
+import database
+from flask import jsonify, request
+from routes.admin_api import admin_api_bp
+from routes.decorators import admin_required
+from services.state import safe_clear_user_logs, safe_remove_user_accounts
+
+
+# ==================== 用户管理/统计(管理员) ====================
+
+
+@admin_api_bp.route("/users", methods=["GET"])
+@admin_required
+def get_all_users():
+ """获取所有用户"""
+ users = database.get_all_users()
+ return jsonify(users)
+
+
+@admin_api_bp.route("/users/pending", methods=["GET"])
+@admin_required
+def get_pending_users():
+ """获取待审核用户"""
+ users = database.get_pending_users()
+ return jsonify(users)
+
+
+@admin_api_bp.route("/users//approve", methods=["POST"])
+@admin_required
+def approve_user_route(user_id):
+ """审核通过用户"""
+ if database.approve_user(user_id):
+ return jsonify({"success": True})
+ return jsonify({"error": "审核失败"}), 400
+
+
+@admin_api_bp.route("/users//reject", methods=["POST"])
+@admin_required
+def reject_user_route(user_id):
+ """拒绝用户"""
+ if database.reject_user(user_id):
+ return jsonify({"success": True})
+ return jsonify({"error": "拒绝失败"}), 400
+
+
+@admin_api_bp.route("/users/", methods=["DELETE"])
+@admin_required
+def delete_user_route(user_id):
+ """删除用户"""
+ if database.delete_user(user_id):
+ safe_remove_user_accounts(user_id)
+ safe_clear_user_logs(user_id)
+ return jsonify({"success": True})
+ return jsonify({"error": "删除失败"}), 400
+
+
+# ==================== VIP 管理(管理员) ====================
+
+
+@admin_api_bp.route("/vip/config", methods=["GET"])
+@admin_required
+def get_vip_config_api():
+ """获取VIP配置"""
+ config = database.get_vip_config()
+ return jsonify(config)
+
+
+@admin_api_bp.route("/vip/config", methods=["POST"])
+@admin_required
+def set_vip_config_api():
+ """设置默认VIP天数"""
+ data = request.json or {}
+ days = data.get("default_vip_days", 0)
+
+ if not isinstance(days, int) or days < 0:
+ return jsonify({"error": "VIP天数必须是非负整数"}), 400
+
+ database.set_default_vip_days(days)
+ return jsonify({"message": "VIP配置已更新", "default_vip_days": days})
+
+
+@admin_api_bp.route("/users//vip", methods=["POST"])
+@admin_required
+def set_user_vip_api(user_id):
+ """设置用户VIP"""
+ data = request.json or {}
+ days = data.get("days", 30)
+
+ valid_days = [7, 30, 365, 999999]
+ if days not in valid_days:
+ return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400
+
+ if database.set_user_vip(user_id, days):
+ vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days]
+ return jsonify({"message": f"VIP设置成功: {vip_type}"})
+ return jsonify({"error": "设置失败,用户不存在"}), 400
+
+
+@admin_api_bp.route("/users//vip", methods=["DELETE"])
+@admin_required
+def remove_user_vip_api(user_id):
+ """移除用户VIP"""
+ if database.remove_user_vip(user_id):
+ return jsonify({"message": "VIP已移除"})
+ return jsonify({"error": "移除失败"}), 400
+
+
+@admin_api_bp.route("/users//vip", methods=["GET"])
+@admin_required
+def get_user_vip_info_api(user_id):
+ """获取用户VIP信息(管理员)"""
+ vip_info = database.get_user_vip_info(user_id)
+ return jsonify(vip_info)
diff --git a/routes/api_accounts.py b/routes/api_accounts.py
index 2bb5402..0c733d9 100644
--- a/routes/api_accounts.py
+++ b/routes/api_accounts.py
@@ -40,6 +40,48 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
+def _emit_account_update(user_id: int, account) -> None:
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+
+def _request_json(default=None):
+ if default is None:
+ default = {}
+ data = request.get_json(silent=True)
+ return data if isinstance(data, dict) else default
+
+
+def _ensure_accounts_loaded(user_id: int) -> dict:
+ accounts = safe_get_user_accounts_snapshot(user_id)
+ if accounts:
+ return accounts
+ load_user_accounts(user_id)
+ return safe_get_user_accounts_snapshot(user_id)
+
+
+def _get_user_account(user_id: int, account_id: str, *, refresh_if_missing: bool = False):
+ account = safe_get_account(user_id, account_id)
+ if account or (not refresh_if_missing):
+ return account
+ load_user_accounts(user_id)
+ return safe_get_account(user_id, account_id)
+
+
+def _validate_browse_type_input(raw_browse_type, *, default=BROWSE_TYPE_SHOULD_READ):
+ browse_type = validate_browse_type(raw_browse_type, default=default)
+ if not browse_type:
+ return None, (jsonify({"error": "浏览类型无效"}), 400)
+ return browse_type, None
+
+
+def _cancel_pending_account_task(user_id: int, account_id: str) -> bool:
+ try:
+ scheduler = get_task_scheduler()
+ return bool(scheduler.cancel_pending_task(user_id=user_id, account_id=account_id))
+ except Exception:
+ return False
+
+
@api_accounts_bp.route("/api/accounts", methods=["GET"])
@login_required
def get_accounts():
@@ -49,8 +91,7 @@ def get_accounts():
accounts = safe_get_user_accounts_snapshot(user_id)
if refresh or not accounts:
- load_user_accounts(user_id)
- accounts = safe_get_user_accounts_snapshot(user_id)
+ accounts = _ensure_accounts_loaded(user_id)
return jsonify([acc.to_dict() for acc in accounts.values()])
@@ -63,20 +104,18 @@ def add_account():
current_count = len(database.get_user_accounts(user_id))
is_vip = database.is_user_vip(user_id)
- if not is_vip and current_count >= 3:
+ if (not is_vip) and current_count >= 3:
return jsonify({"error": "普通用户最多添加3个账号,升级VIP可无限添加"}), 403
- data = request.json
- username = data.get("username", "").strip()
- password = data.get("password", "").strip()
- remark = data.get("remark", "").strip()[:200]
+
+ data = _request_json()
+ username = str(data.get("username", "")).strip()
+ password = str(data.get("password", "")).strip()
+ remark = str(data.get("remark", "")).strip()[:200]
if not username or not password:
return jsonify({"error": "用户名和密码不能为空"}), 400
- accounts = safe_get_user_accounts_snapshot(user_id)
- if not accounts:
- load_user_accounts(user_id)
- accounts = safe_get_user_accounts_snapshot(user_id)
+ accounts = _ensure_accounts_loaded(user_id)
for acc in accounts.values():
if acc.username == username:
return jsonify({"error": f"账号 '{username}' 已存在"}), 400
@@ -92,7 +131,7 @@ def add_account():
safe_set_account(user_id, account_id, account)
log_to_client(f"添加账号: {username}", user_id)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ _emit_account_update(user_id, account)
return jsonify(account.to_dict())
@@ -103,15 +142,15 @@ def update_account(account_id):
"""更新账号信息(密码等)"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "账号正在运行中,请先停止"}), 400
- data = request.json
- new_password = data.get("password", "").strip()
+ data = _request_json()
+ new_password = str(data.get("password", "")).strip()
new_remember = data.get("remember", account.remember)
if not new_password:
@@ -147,7 +186,7 @@ def delete_account(account_id):
"""删除账号"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
@@ -159,7 +198,6 @@ def delete_account(account_id):
username = account.username
database.delete_account(account_id)
-
safe_remove_account(user_id, account_id)
log_to_client(f"删除账号: {username}", user_id)
@@ -196,12 +234,12 @@ def update_remark(account_id):
"""更新备注"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
- data = request.json
- remark = data.get("remark", "").strip()[:200]
+ data = _request_json()
+ remark = str(data.get("remark", "")).strip()[:200]
database.update_account_remark(account_id, remark)
@@ -217,17 +255,18 @@ def start_account(account_id):
"""启动账号任务"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "任务已在运行中"}), 400
- data = request.json or {}
- browse_type = validate_browse_type(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
- if not browse_type:
- return jsonify({"error": "浏览类型无效"}), 400
+ data = _request_json()
+ browse_type, browse_error = _validate_browse_type_input(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
+ if browse_error:
+ return browse_error
+
enable_screenshot = data.get("enable_screenshot", True)
ok, message = submit_account_task(
user_id=user_id,
@@ -249,7 +288,7 @@ def stop_account(account_id):
"""停止账号任务"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
return jsonify({"error": "账号不存在"}), 404
@@ -259,20 +298,16 @@ def stop_account(account_id):
account.should_stop = True
account.status = "正在停止"
- try:
- scheduler = get_task_scheduler()
- if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
- account.status = "已停止"
- account.is_running = False
- safe_remove_task_status(account_id)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- log_to_client(f"任务已取消: {account.username}", user_id)
- return jsonify({"success": True, "canceled": True})
- except Exception:
- pass
+ if _cancel_pending_account_task(user_id, account_id):
+ account.status = "已停止"
+ account.is_running = False
+ safe_remove_task_status(account_id)
+ _emit_account_update(user_id, account)
+ log_to_client(f"任务已取消: {account.username}", user_id)
+ return jsonify({"success": True, "canceled": True})
log_to_client(f"停止任务: {account.username}", user_id)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ _emit_account_update(user_id, account)
return jsonify({"success": True})
@@ -283,23 +318,20 @@ def manual_screenshot(account_id):
"""手动为指定账号截图"""
user_id = current_user.id
- account = safe_get_account(user_id, account_id)
- if not account:
- load_user_accounts(user_id)
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id, refresh_if_missing=True)
if not account:
return jsonify({"error": "账号不存在"}), 404
if account.is_running:
return jsonify({"error": "任务运行中,无法截图"}), 400
- data = request.json or {}
+ data = _request_json()
requested_browse_type = data.get("browse_type", None)
if requested_browse_type is None:
browse_type = normalize_browse_type(account.last_browse_type)
else:
- browse_type = validate_browse_type(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
- if not browse_type:
- return jsonify({"error": "浏览类型无效"}), 400
+ browse_type, browse_error = _validate_browse_type_input(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ)
+ if browse_error:
+ return browse_error
account.last_browse_type = browse_type
@@ -317,12 +349,16 @@ def manual_screenshot(account_id):
def batch_start_accounts():
"""批量启动账号"""
user_id = current_user.id
- data = request.json or {}
+ data = _request_json()
account_ids = data.get("account_ids", [])
- browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
- if not browse_type:
- return jsonify({"error": "浏览类型无效"}), 400
+ browse_type, browse_error = _validate_browse_type_input(
+ data.get("browse_type", BROWSE_TYPE_SHOULD_READ),
+ default=BROWSE_TYPE_SHOULD_READ,
+ )
+ if browse_error:
+ return browse_error
+
enable_screenshot = data.get("enable_screenshot", True)
if not account_ids:
@@ -331,11 +367,10 @@ def batch_start_accounts():
started = []
failed = []
- if not safe_get_user_accounts_snapshot(user_id):
- load_user_accounts(user_id)
+ _ensure_accounts_loaded(user_id)
for account_id in account_ids:
- account = safe_get_account(user_id, account_id)
+ account = _get_user_account(user_id, account_id)
if not account:
failed.append({"id": account_id, "reason": "账号不存在"})
continue
@@ -357,7 +392,13 @@ def batch_start_accounts():
failed.append({"id": account_id, "reason": msg})
return jsonify(
- {"success": True, "started_count": len(started), "failed_count": len(failed), "started": started, "failed": failed}
+ {
+ "success": True,
+ "started_count": len(started),
+ "failed_count": len(failed),
+ "started": started,
+ "failed": failed,
+ }
)
@@ -366,39 +407,29 @@ def batch_start_accounts():
def batch_stop_accounts():
"""批量停止账号"""
user_id = current_user.id
- data = request.json
+ data = _request_json()
account_ids = data.get("account_ids", [])
-
if not account_ids:
return jsonify({"error": "请选择要停止的账号"}), 400
stopped = []
-
- if not safe_get_user_accounts_snapshot(user_id):
- load_user_accounts(user_id)
+ _ensure_accounts_loaded(user_id)
for account_id in account_ids:
- account = safe_get_account(user_id, account_id)
- if not account:
- continue
-
- if not account.is_running:
+ account = _get_user_account(user_id, account_id)
+ if (not account) or (not account.is_running):
continue
account.should_stop = True
account.status = "正在停止"
stopped.append(account_id)
- try:
- scheduler = get_task_scheduler()
- if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id):
- account.status = "已停止"
- account.is_running = False
- safe_remove_task_status(account_id)
- except Exception:
- pass
+ if _cancel_pending_account_task(user_id, account_id):
+ account.status = "已停止"
+ account.is_running = False
+ safe_remove_task_status(account_id)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ _emit_account_update(user_id, account)
return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped})
diff --git a/routes/api_auth.py b/routes/api_auth.py
index 3621892..2e65443 100644
--- a/routes/api_auth.py
+++ b/routes/api_auth.py
@@ -2,16 +2,20 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+import base64
import random
import secrets
+import threading
import time
+import uuid
+from io import BytesIO
import database
import email_service
from app_config import get_config
from app_logger import get_logger
from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username
-from flask import Blueprint, jsonify, redirect, render_template, request, url_for
+from flask import Blueprint, jsonify, request
from flask_login import login_required, login_user, logout_user
from routes.pages import render_app_spa_or_legacy
from services.accounts_service import load_user_accounts
@@ -39,12 +43,162 @@ config = get_config()
api_auth_bp = Blueprint("api_auth", __name__)
+_CAPTCHA_FONT_PATHS = [
+ "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
+ "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
+ "/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
+]
+_CAPTCHA_FONT = None
+_CAPTCHA_FONT_LOCK = threading.Lock()
+
+
+def _get_json_payload() -> dict:
+ data = request.get_json(silent=True)
+ return data if isinstance(data, dict) else {}
+
+
+def _load_captcha_font(image_font_module):
+ global _CAPTCHA_FONT
+
+ if _CAPTCHA_FONT is not None:
+ return _CAPTCHA_FONT
+
+ with _CAPTCHA_FONT_LOCK:
+ if _CAPTCHA_FONT is not None:
+ return _CAPTCHA_FONT
+
+ for font_path in _CAPTCHA_FONT_PATHS:
+ try:
+ _CAPTCHA_FONT = image_font_module.truetype(font_path, 42)
+ break
+ except Exception:
+ continue
+
+ if _CAPTCHA_FONT is None:
+ _CAPTCHA_FONT = image_font_module.load_default()
+
+ return _CAPTCHA_FONT
+
+
+def _generate_captcha_image_data_uri(code: str) -> str:
+ from PIL import Image, ImageDraw, ImageFont
+
+ width, height = 160, 60
+ image = Image.new("RGB", (width, height), color=(255, 255, 255))
+ draw = ImageDraw.Draw(image)
+
+ for _ in range(6):
+ x1 = random.randint(0, width)
+ y1 = random.randint(0, height)
+ x2 = random.randint(0, width)
+ y2 = random.randint(0, height)
+ draw.line(
+ [(x1, y1), (x2, y2)],
+ fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
+ width=1,
+ )
+
+ for _ in range(80):
+ x = random.randint(0, width)
+ y = random.randint(0, height)
+ draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
+
+ font = _load_captcha_font(ImageFont)
+ for i, char in enumerate(code):
+ x = 12 + i * 35 + random.randint(-3, 3)
+ y = random.randint(5, 12)
+ color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
+ draw.text((x, y), char, font=font, fill=color)
+
+ buffer = BytesIO()
+ image.save(buffer, format="PNG")
+ img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+ return f"data:image/png;base64,{img_base64}"
+
+
+def _with_vip_suffix(message: str, auto_approve_enabled: bool, auto_approve_vip_days: int) -> str:
+ if auto_approve_enabled and auto_approve_vip_days > 0:
+ return f"{message},赠送{auto_approve_vip_days}天VIP"
+ return message
+
+
+def _verify_common_captcha(client_ip: str, captcha_session: str, captcha_code: str):
+ success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
+ if success:
+ return True, None
+
+ is_locked = record_failed_captcha(client_ip)
+ if is_locked:
+ return False, (jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429)
+ return False, (jsonify({"error": message}), 400)
+
+
+def _verify_login_captcha_if_needed(
+ *,
+ captcha_required: bool,
+ captcha_session: str,
+ captcha_code: str,
+ client_ip: str,
+ username_key: str,
+):
+ if not captcha_required:
+ return True, None
+
+ if not captcha_session or not captcha_code:
+ return False, (jsonify({"error": "请填写验证码", "need_captcha": True}), 400)
+
+ success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
+ if success:
+ return True, None
+
+ record_login_failure(client_ip, username_key)
+ return False, (jsonify({"error": message, "need_captcha": True}), 400)
+
+
+def _send_password_reset_email_if_possible(email: str, username: str, user_id: int) -> None:
+ result = email_service.send_password_reset_email(email=email, username=username, user_id=user_id)
+ if not result["success"]:
+ logger.error(f"密码重置邮件发送失败: {result['error']}")
+
+
+def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: str) -> None:
+ try:
+ user_agent = request.headers.get("User-Agent", "")
+ context = database.record_login_context(user["id"], client_ip, user_agent)
+ if not context or (not context.get("new_ip") and not context.get("new_device")):
+ return
+
+ if not config.LOGIN_ALERT_ENABLED:
+ return
+ if not should_send_login_alert(user["id"], client_ip):
+ return
+ if not email_service.get_email_settings().get("login_alert_enabled", True):
+ return
+
+ user_info = database.get_user_by_id(user["id"]) or {}
+ if (not user_info.get("email")) or (not user_info.get("email_verified")):
+ return
+ if not database.get_user_email_notify(user["id"]):
+ return
+
+ email_service.send_security_alert_email(
+ email=user_info.get("email"),
+ username=user_info.get("username") or username,
+ ip_address=client_ip,
+ user_agent=user_agent,
+ new_ip=context.get("new_ip", False),
+ new_device=context.get("new_device", False),
+ user_id=user["id"],
+ )
+ except Exception:
+ pass
+
@api_auth_bp.route("/api/register", methods=["POST"])
@require_ip_not_locked
def register():
"""用户注册"""
- data = request.json or {}
+ data = _get_json_payload()
username = data.get("username", "").strip()
password = data.get("password", "").strip()
email = data.get("email", "").strip().lower()
@@ -67,12 +221,9 @@ def register():
if not allowed:
return jsonify({"error": error_msg}), 429
- success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
- if not success:
- is_locked = record_failed_captcha(client_ip)
- if is_locked:
- return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
- return jsonify({"error": message}), 400
+ captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
+ if not captcha_ok:
+ return captcha_error_response
email_settings = email_service.get_email_settings()
email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False)
@@ -105,20 +256,22 @@ def register():
if email_verify_enabled and email:
result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id)
if result["success"]:
- message = "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)"
- if auto_approve_enabled and auto_approve_vip_days > 0:
- message += f",赠送{auto_approve_vip_days}天VIP"
+ message = _with_vip_suffix(
+ "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)",
+ auto_approve_enabled,
+ auto_approve_vip_days,
+ )
return jsonify({"success": True, "message": message, "need_verify": True})
logger.error(f"注册验证邮件发送失败: {result['error']}")
- message = f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录"
- if auto_approve_enabled and auto_approve_vip_days > 0:
- message += f",赠送{auto_approve_vip_days}天VIP"
+ message = _with_vip_suffix(
+ f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录",
+ auto_approve_enabled,
+ auto_approve_vip_days,
+ )
return jsonify({"success": True, "message": message, "need_verify": True})
- message = "注册成功!可直接登录"
- if auto_approve_enabled and auto_approve_vip_days > 0:
- message += f",赠送{auto_approve_vip_days}天VIP"
+ message = _with_vip_suffix("注册成功!可直接登录", auto_approve_enabled, auto_approve_vip_days)
return jsonify({"success": True, "message": message})
return jsonify({"error": "用户名已存在"}), 400
@@ -175,7 +328,7 @@ def verify_email(token):
@require_ip_not_locked
def resend_verify_email():
"""重发验证邮件"""
- data = request.json or {}
+ data = _get_json_payload()
email = data.get("email", "").strip().lower()
captcha_session = data.get("captcha_session", "")
captcha_code = data.get("captcha", "").strip()
@@ -195,12 +348,9 @@ def resend_verify_email():
if not allowed:
return jsonify({"error": error_msg}), 429
- success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
- if not success:
- is_locked = record_failed_captcha(client_ip)
- if is_locked:
- return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
- return jsonify({"error": message}), 400
+ captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
+ if not captcha_ok:
+ return captcha_error_response
user = database.get_user_by_email(email)
if not user:
@@ -235,7 +385,7 @@ def get_email_verify_status():
@require_ip_not_locked
def forgot_password():
"""发送密码重置邮件"""
- data = request.json or {}
+ data = _get_json_payload()
email = data.get("email", "").strip().lower()
username = data.get("username", "").strip()
captcha_session = data.get("captcha_session", "")
@@ -263,12 +413,9 @@ def forgot_password():
if not allowed:
return jsonify({"error": error_msg}), 429
- success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
- if not success:
- is_locked = record_failed_captcha(client_ip)
- if is_locked:
- return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429
- return jsonify({"error": message}), 400
+ captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code)
+ if not captcha_ok:
+ return captcha_error_response
email_settings = email_service.get_email_settings()
if not email_settings.get("enabled", False):
@@ -293,20 +440,16 @@ def forgot_password():
if not allowed:
return jsonify({"error": error_msg}), 429
- result = email_service.send_password_reset_email(
+ _send_password_reset_email_if_possible(
email=bound_email,
username=user["username"],
user_id=user["id"],
)
- if not result["success"]:
- logger.error(f"密码重置邮件发送失败: {result['error']}")
return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"})
user = database.get_user_by_email(email)
if user and user.get("status") == "approved":
- result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"])
- if not result["success"]:
- logger.error(f"密码重置邮件发送失败: {result['error']}")
+ _send_password_reset_email_if_possible(email=email, username=user["username"], user_id=user["id"])
return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"})
@@ -331,7 +474,7 @@ def reset_password_page(token):
@api_auth_bp.route("/api/reset-password-confirm", methods=["POST"])
def reset_password_confirm():
"""确认密码重置"""
- data = request.json or {}
+ data = _get_json_payload()
token = data.get("token", "").strip()
new_password = data.get("new_password", "").strip()
@@ -356,67 +499,15 @@ def reset_password_confirm():
@api_auth_bp.route("/api/generate_captcha", methods=["POST"])
def generate_captcha():
"""生成4位数字验证码图片"""
- import base64
- import uuid
- from io import BytesIO
-
session_id = str(uuid.uuid4())
-
- code = "".join([str(secrets.randbelow(10)) for _ in range(4)])
+ code = "".join(str(secrets.randbelow(10)) for _ in range(4))
safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0})
safe_cleanup_expired_captcha()
try:
- from PIL import Image, ImageDraw, ImageFont
- import io
-
- width, height = 160, 60
- image = Image.new("RGB", (width, height), color=(255, 255, 255))
- draw = ImageDraw.Draw(image)
-
- for _ in range(6):
- x1 = random.randint(0, width)
- y1 = random.randint(0, height)
- x2 = random.randint(0, width)
- y2 = random.randint(0, height)
- draw.line(
- [(x1, y1), (x2, y2)],
- fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)),
- width=1,
- )
-
- for _ in range(80):
- x = random.randint(0, width)
- y = random.randint(0, height)
- draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)))
-
- font = None
- font_paths = [
- "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
- "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
- "/usr/share/fonts/truetype/freefont/FreeSansBold.ttf",
- ]
- for font_path in font_paths:
- try:
- font = ImageFont.truetype(font_path, 42)
- break
- except Exception:
- continue
- if font is None:
- font = ImageFont.load_default()
-
- for i, char in enumerate(code):
- x = 12 + i * 35 + random.randint(-3, 3)
- y = random.randint(5, 12)
- color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150))
- draw.text((x, y), char, font=font, fill=color)
-
- buffer = io.BytesIO()
- image.save(buffer, format="PNG")
- img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
- return jsonify({"session_id": session_id, "captcha_image": f"data:image/png;base64,{img_base64}"})
+ captcha_image = _generate_captcha_image_data_uri(code)
+ return jsonify({"session_id": session_id, "captcha_image": captcha_image})
except ImportError as e:
logger.error(f"PIL库未安装,验证码功能不可用: {e}")
safe_delete_captcha(session_id)
@@ -427,7 +518,7 @@ def generate_captcha():
@require_ip_not_locked
def login():
"""用户登录"""
- data = request.json or {}
+ data = _get_json_payload()
username = data.get("username", "").strip()
password = data.get("password", "").strip()
captcha_session = data.get("captcha_session", "")
@@ -452,13 +543,15 @@ def login():
return jsonify({"error": error_msg, "need_captcha": True}), 429
captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha)
- if captcha_required:
- if not captcha_session or not captcha_code:
- return jsonify({"error": "请填写验证码", "need_captcha": True}), 400
- success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code)
- if not success:
- record_login_failure(client_ip, username_key)
- return jsonify({"error": message, "need_captcha": True}), 400
+ captcha_ok, captcha_error_response = _verify_login_captcha_if_needed(
+ captcha_required=captcha_required,
+ captcha_session=captcha_session,
+ captcha_code=captcha_code,
+ client_ip=client_ip,
+ username_key=username_key,
+ )
+ if not captcha_ok:
+ return captcha_error_response
user = database.verify_user(username, password)
if not user:
@@ -476,29 +569,7 @@ def login():
login_user(user_obj)
load_user_accounts(user["id"])
- try:
- user_agent = request.headers.get("User-Agent", "")
- context = database.record_login_context(user["id"], client_ip, user_agent)
- if context and (context.get("new_ip") or context.get("new_device")):
- if (
- config.LOGIN_ALERT_ENABLED
- and should_send_login_alert(user["id"], client_ip)
- and email_service.get_email_settings().get("login_alert_enabled", True)
- ):
- user_info = database.get_user_by_id(user["id"]) or {}
- if user_info.get("email") and user_info.get("email_verified"):
- if database.get_user_email_notify(user["id"]):
- email_service.send_security_alert_email(
- email=user_info.get("email"),
- username=user_info.get("username") or username,
- ip_address=client_ip,
- user_agent=user_agent,
- new_ip=context.get("new_ip", False),
- new_device=context.get("new_device", False),
- user_id=user["id"],
- )
- except Exception:
- pass
+ _send_login_security_alert_if_needed(user=user, username=username, client_ip=client_ip)
return jsonify({"success": True})
diff --git a/routes/api_schedules.py b/routes/api_schedules.py
index 1d81026..14844e5 100644
--- a/routes/api_schedules.py
+++ b/routes/api_schedules.py
@@ -2,7 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+import json
import re
+import threading
+import time as time_mod
+import uuid
import database
from flask import Blueprint, jsonify, request
@@ -17,6 +21,13 @@ api_schedules_bp = Blueprint("api_schedules", __name__)
_HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$")
+def _request_json(default=None):
+ if default is None:
+ default = {}
+ data = request.get_json(silent=True)
+ return data if isinstance(data, dict) else default
+
+
def _normalize_hhmm(value: object) -> str | None:
match = _HHMM_RE.match(str(value or "").strip())
if not match:
@@ -28,18 +39,53 @@ def _normalize_hhmm(value: object) -> str | None:
return f"{hour:02d}:{minute:02d}"
+def _normalize_random_delay(value) -> tuple[int | None, str | None]:
+ try:
+ normalized = int(value or 0)
+ except Exception:
+ return None, "random_delay必须是0或1"
+ if normalized not in (0, 1):
+ return None, "random_delay必须是0或1"
+ return normalized, None
+
+
+def _parse_schedule_account_ids(raw_value) -> list:
+ try:
+ parsed = json.loads(raw_value or "[]")
+ except (json.JSONDecodeError, TypeError):
+ return []
+ return parsed if isinstance(parsed, list) else []
+
+
+def _get_owned_schedule_or_error(schedule_id: int):
+ schedule = database.get_schedule_by_id(schedule_id)
+ if not schedule:
+ return None, (jsonify({"error": "定时任务不存在"}), 404)
+ if schedule.get("user_id") != current_user.id:
+ return None, (jsonify({"error": "无权访问"}), 403)
+ return schedule, None
+
+
+def _ensure_user_accounts_loaded(user_id: int) -> None:
+ if safe_get_user_accounts_snapshot(user_id):
+ return
+ load_user_accounts(user_id)
+
+
+def _parse_browse_type_or_error(raw_value, *, default=BROWSE_TYPE_SHOULD_READ):
+ browse_type = validate_browse_type(raw_value, default=default)
+ if not browse_type:
+ return None, (jsonify({"error": "浏览类型无效"}), 400)
+ return browse_type, None
+
+
@api_schedules_bp.route("/api/schedules", methods=["GET"])
@login_required
def get_user_schedules_api():
"""获取当前用户的所有定时任务"""
schedules = database.get_user_schedules(current_user.id)
- import json
-
- for s in schedules:
- try:
- s["account_ids"] = json.loads(s.get("account_ids", "[]") or "[]")
- except (json.JSONDecodeError, TypeError):
- s["account_ids"] = []
+ for schedule in schedules:
+ schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
return jsonify(schedules)
@@ -47,23 +93,26 @@ def get_user_schedules_api():
@login_required
def create_user_schedule_api():
"""创建用户定时任务"""
- data = request.json or {}
+ data = _request_json()
name = data.get("name", "我的定时任务")
schedule_time = data.get("schedule_time", "08:00")
weekdays = data.get("weekdays", "1,2,3,4,5")
- browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ)
- if not browse_type:
- return jsonify({"error": "浏览类型无效"}), 400
+
+ browse_type, browse_error = _parse_browse_type_or_error(data.get("browse_type", BROWSE_TYPE_SHOULD_READ))
+ if browse_error:
+ return browse_error
+
enable_screenshot = data.get("enable_screenshot", 1)
- random_delay = int(data.get("random_delay", 0) or 0)
+ random_delay, delay_error = _normalize_random_delay(data.get("random_delay", 0))
+ if delay_error:
+ return jsonify({"error": delay_error}), 400
+
account_ids = data.get("account_ids", [])
normalized_time = _normalize_hhmm(schedule_time)
if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
- if random_delay not in (0, 1):
- return jsonify({"error": "random_delay必须是0或1"}), 400
schedule_id = database.create_user_schedule(
user_id=current_user.id,
@@ -85,18 +134,11 @@ def create_user_schedule_api():
@login_required
def get_schedule_detail_api(schedule_id):
"""获取定时任务详情"""
- schedule = database.get_schedule_by_id(schedule_id)
- if not schedule:
- return jsonify({"error": "定时任务不存在"}), 404
- if schedule["user_id"] != current_user.id:
- return jsonify({"error": "无权访问"}), 403
+ schedule, error_response = _get_owned_schedule_or_error(schedule_id)
+ if error_response:
+ return error_response
- import json
-
- try:
- schedule["account_ids"] = json.loads(schedule.get("account_ids", "[]") or "[]")
- except (json.JSONDecodeError, TypeError):
- schedule["account_ids"] = []
+ schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids"))
return jsonify(schedule)
@@ -104,14 +146,12 @@ def get_schedule_detail_api(schedule_id):
@login_required
def update_schedule_api(schedule_id):
"""更新定时任务"""
- schedule = database.get_schedule_by_id(schedule_id)
- if not schedule:
- return jsonify({"error": "定时任务不存在"}), 404
- if schedule["user_id"] != current_user.id:
- return jsonify({"error": "无权访问"}), 403
+ _, error_response = _get_owned_schedule_or_error(schedule_id)
+ if error_response:
+ return error_response
- data = request.json or {}
- allowed_fields = [
+ data = _request_json()
+ allowed_fields = {
"name",
"schedule_time",
"weekdays",
@@ -120,27 +160,26 @@ def update_schedule_api(schedule_id):
"random_delay",
"account_ids",
"enabled",
- ]
-
- update_data = {k: v for k, v in data.items() if k in allowed_fields}
+ }
+ update_data = {key: value for key, value in data.items() if key in allowed_fields}
if "schedule_time" in update_data:
normalized_time = _normalize_hhmm(update_data["schedule_time"])
if not normalized_time:
return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400
update_data["schedule_time"] = normalized_time
+
if "random_delay" in update_data:
- try:
- update_data["random_delay"] = int(update_data.get("random_delay") or 0)
- except Exception:
- return jsonify({"error": "random_delay必须是0或1"}), 400
- if update_data["random_delay"] not in (0, 1):
- return jsonify({"error": "random_delay必须是0或1"}), 400
+ random_delay, delay_error = _normalize_random_delay(update_data.get("random_delay"))
+ if delay_error:
+ return jsonify({"error": delay_error}), 400
+ update_data["random_delay"] = random_delay
+
if "browse_type" in update_data:
- normalized = validate_browse_type(update_data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ)
- if not normalized:
- return jsonify({"error": "浏览类型无效"}), 400
- update_data["browse_type"] = normalized
+ normalized_browse_type, browse_error = _parse_browse_type_or_error(update_data.get("browse_type"))
+ if browse_error:
+ return browse_error
+ update_data["browse_type"] = normalized_browse_type
success = database.update_user_schedule(schedule_id, **update_data)
if success:
@@ -152,11 +191,9 @@ def update_schedule_api(schedule_id):
@login_required
def delete_schedule_api(schedule_id):
"""删除定时任务"""
- schedule = database.get_schedule_by_id(schedule_id)
- if not schedule:
- return jsonify({"error": "定时任务不存在"}), 404
- if schedule["user_id"] != current_user.id:
- return jsonify({"error": "无权访问"}), 403
+ _, error_response = _get_owned_schedule_or_error(schedule_id)
+ if error_response:
+ return error_response
success = database.delete_user_schedule(schedule_id)
if success:
@@ -168,13 +205,11 @@ def delete_schedule_api(schedule_id):
@login_required
def toggle_schedule_api(schedule_id):
"""启用/禁用定时任务"""
- schedule = database.get_schedule_by_id(schedule_id)
- if not schedule:
- return jsonify({"error": "定时任务不存在"}), 404
- if schedule["user_id"] != current_user.id:
- return jsonify({"error": "无权访问"}), 403
+ schedule, error_response = _get_owned_schedule_or_error(schedule_id)
+ if error_response:
+ return error_response
- data = request.json
+ data = _request_json()
enabled = data.get("enabled", not schedule["enabled"])
success = database.toggle_user_schedule(schedule_id, enabled)
@@ -187,22 +222,11 @@ def toggle_schedule_api(schedule_id):
@login_required
def run_schedule_now_api(schedule_id):
"""立即执行定时任务"""
- import json
- import threading
- import time as time_mod
- import uuid
-
- schedule = database.get_schedule_by_id(schedule_id)
- if not schedule:
- return jsonify({"error": "定时任务不存在"}), 404
- if schedule["user_id"] != current_user.id:
- return jsonify({"error": "无权访问"}), 403
-
- try:
- account_ids = json.loads(schedule.get("account_ids", "[]") or "[]")
- except (json.JSONDecodeError, TypeError):
- account_ids = []
+ schedule, error_response = _get_owned_schedule_or_error(schedule_id)
+ if error_response:
+ return error_response
+ account_ids = _parse_schedule_account_ids(schedule.get("account_ids"))
if not account_ids:
return jsonify({"error": "没有配置账号"}), 400
@@ -210,8 +234,7 @@ def run_schedule_now_api(schedule_id):
browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule["enable_screenshot"]
- if not safe_get_user_accounts_snapshot(user_id):
- load_user_accounts(user_id)
+ _ensure_user_accounts_loaded(user_id)
from services.state import safe_create_batch, safe_finalize_batch_after_dispatch
from services.task_batches import _send_batch_task_email_if_configured
@@ -250,6 +273,7 @@ def run_schedule_now_api(schedule_id):
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
+
execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log(
log_id,
@@ -260,19 +284,17 @@ def run_schedule_now_api(schedule_id):
status="completed",
)
+ task_source = f"user_scheduled:{batch_id}"
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
- if not account:
- skipped_count += 1
- continue
- if account.is_running:
+ if (not account) or account.is_running:
skipped_count += 1
continue
- task_source = f"user_scheduled:{batch_id}"
with completion_lock:
remaining["count"] += 1
- ok, msg = submit_account_task(
+
+ ok, _ = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
diff --git a/routes/api_screenshots.py b/routes/api_screenshots.py
index 0df1ad3..008e67b 100644
--- a/routes/api_screenshots.py
+++ b/routes/api_screenshots.py
@@ -4,6 +4,7 @@ from __future__ import annotations
import os
from datetime import datetime
+from typing import Iterator
import database
from app_config import get_config
@@ -15,41 +16,67 @@ from services.time_utils import BEIJING_TZ
config = get_config()
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
+_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
api_screenshots_bp = Blueprint("api_screenshots", __name__)
+def _get_user_prefix(user_id: int) -> str:
+ user_info = database.get_user_by_id(user_id)
+ return user_info["username"] if user_info else f"user{user_id}"
+
+
+def _is_user_screenshot(filename: str, username_prefix: str) -> bool:
+ return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS)
+
+
+def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]:
+ if not os.path.exists(SCREENSHOTS_DIR):
+ return
+
+ with os.scandir(SCREENSHOTS_DIR) as entries:
+ for entry in entries:
+ if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)):
+ continue
+ yield entry
+
+
+def _build_display_name(filename: str) -> str:
+ base_name, ext = filename.rsplit(".", 1)
+ parts = base_name.split("_", 1)
+ if len(parts) > 1:
+ return f"{parts[1]}.{ext}"
+ return filename
+
+
@api_screenshots_bp.route("/api/screenshots", methods=["GET"])
@login_required
def get_screenshots():
"""获取当前用户的截图列表"""
user_id = current_user.id
- user_info = database.get_user_by_id(user_id)
- username_prefix = user_info["username"] if user_info else f"user{user_id}"
+ username_prefix = _get_user_prefix(user_id)
try:
screenshots = []
- if os.path.exists(SCREENSHOTS_DIR):
- for filename in os.listdir(SCREENSHOTS_DIR):
- if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
- filepath = os.path.join(SCREENSHOTS_DIR, filename)
- stat = os.stat(filepath)
- created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
- parts = filename.rsplit(".", 1)[0].split("_", 1)
- if len(parts) > 1:
- display_name = parts[1] + "." + filename.rsplit(".", 1)[1]
- else:
- display_name = filename
+ for entry in _iter_user_screenshot_entries(username_prefix):
+ filename = entry.name
+ stat = entry.stat()
+ created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ)
+
+ screenshots.append(
+ {
+ "filename": filename,
+ "display_name": _build_display_name(filename),
+ "size": stat.st_size,
+ "created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
+ "_created_ts": stat.st_mtime,
+ }
+ )
+
+ screenshots.sort(key=lambda item: item.get("_created_ts", 0), reverse=True)
+ for item in screenshots:
+ item.pop("_created_ts", None)
- screenshots.append(
- {
- "filename": filename,
- "display_name": display_name,
- "size": stat.st_size,
- "created": created_time.strftime("%Y-%m-%d %H:%M:%S"),
- }
- )
- screenshots.sort(key=lambda x: x["created"], reverse=True)
return jsonify(screenshots)
except Exception as e:
return jsonify({"error": str(e)}), 500
@@ -60,10 +87,9 @@ def get_screenshots():
def serve_screenshot(filename):
"""提供截图文件访问"""
user_id = current_user.id
- user_info = database.get_user_by_id(user_id)
- username_prefix = user_info["username"] if user_info else f"user{user_id}"
+ username_prefix = _get_user_prefix(user_id)
- if not filename.startswith(username_prefix + "_"):
+ if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权访问"}), 403
if not is_safe_path(SCREENSHOTS_DIR, filename):
@@ -77,12 +103,14 @@ def serve_screenshot(filename):
def delete_screenshot(filename):
"""删除指定截图"""
user_id = current_user.id
- user_info = database.get_user_by_id(user_id)
- username_prefix = user_info["username"] if user_info else f"user{user_id}"
+ username_prefix = _get_user_prefix(user_id)
- if not filename.startswith(username_prefix + "_"):
+ if not _is_user_screenshot(filename, username_prefix):
return jsonify({"error": "无权删除"}), 403
+ if not is_safe_path(SCREENSHOTS_DIR, filename):
+ return jsonify({"error": "非法路径"}), 403
+
try:
filepath = os.path.join(SCREENSHOTS_DIR, filename)
if os.path.exists(filepath):
@@ -99,19 +127,15 @@ def delete_screenshot(filename):
def clear_all_screenshots():
"""清空当前用户的所有截图"""
user_id = current_user.id
- user_info = database.get_user_by_id(user_id)
- username_prefix = user_info["username"] if user_info else f"user{user_id}"
+ username_prefix = _get_user_prefix(user_id)
try:
deleted_count = 0
- if os.path.exists(SCREENSHOTS_DIR):
- for filename in os.listdir(SCREENSHOTS_DIR):
- if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"):
- filepath = os.path.join(SCREENSHOTS_DIR, filename)
- os.remove(filepath)
- deleted_count += 1
+ for entry in _iter_user_screenshot_entries(username_prefix):
+ os.remove(entry.path)
+ deleted_count += 1
+
log_to_client(f"清理了 {deleted_count} 个截图文件", user_id)
return jsonify({"success": True, "deleted": deleted_count})
except Exception as e:
return jsonify({"error": str(e)}), 500
-
diff --git a/routes/api_user.py b/routes/api_user.py
index cb40502..f5bcb2e 100644
--- a/routes/api_user.py
+++ b/routes/api_user.py
@@ -10,12 +10,96 @@ from flask import Blueprint, jsonify, request
from flask_login import current_user, login_required
from routes.pages import render_app_spa_or_legacy
from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items
+from services.tasks import get_task_scheduler
logger = get_logger("app")
api_user_bp = Blueprint("api_user", __name__)
+def _get_current_user_record():
+ return database.get_user_by_id(current_user.id)
+
+
+def _get_current_user_or_404():
+ user = _get_current_user_record()
+ if user:
+ return user, None
+ return None, (jsonify({"error": "用户不存在"}), 404)
+
+
+def _get_current_username(*, fallback: str) -> str:
+ user = _get_current_user_record()
+ username = (user or {}).get("username", "")
+ return username if username else fallback
+
+
+def _coerce_binary_flag(value, *, field_label: str):
+ if isinstance(value, bool):
+ value = 1 if value else 0
+ try:
+ value = int(value)
+ except Exception:
+ return None, f"{field_label}必须是0或1"
+ if value not in (0, 1):
+ return None, f"{field_label}必须是0或1"
+ return value, None
+
+
+def _check_bind_email_rate_limits(email: str):
+ client_ip = get_rate_limit_ip()
+ allowed, error_msg = check_ip_request_rate(client_ip, "email")
+ if not allowed:
+ return False, error_msg, 429
+ allowed, error_msg = check_email_rate_limit(email, "bind_email")
+ if not allowed:
+ return False, error_msg, 429
+ return True, "", 200
+
+
+def _render_verify_bind_failed(*, title: str, error_message: str):
+ spa_initial_state = {
+ "page": "verify_result",
+ "success": False,
+ "title": title,
+ "error_message": error_message,
+ "primary_label": "返回登录",
+ "primary_url": "/login",
+ }
+ return render_app_spa_or_legacy(
+ "verify_failed.html",
+ legacy_context={"error_message": error_message},
+ spa_initial_state=spa_initial_state,
+ )
+
+
+def _render_verify_bind_success(email: str):
+ spa_initial_state = {
+ "page": "verify_result",
+ "success": True,
+ "title": "邮箱绑定成功",
+ "message": f"邮箱 {email} 已成功绑定到您的账号!",
+ "primary_label": "返回登录",
+ "primary_url": "/login",
+ "redirect_url": "/login",
+ "redirect_seconds": 5,
+ }
+ return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
+
+
+def _get_current_running_count(user_id: int) -> int:
+ try:
+ queue_snapshot = get_task_scheduler().get_queue_state_snapshot() or {}
+ running_by_user = queue_snapshot.get("running_by_user") or {}
+ return int(running_by_user.get(int(user_id), running_by_user.get(str(user_id), 0)) or 0)
+ except Exception:
+ current_running = 0
+ for _, info in safe_iter_task_status_items():
+ if info.get("user_id") == user_id and info.get("status") == "运行中":
+ current_running += 1
+ return current_running
+
+
@api_user_bp.route("/api/announcements/active", methods=["GET"])
@login_required
def get_active_announcement():
@@ -77,8 +161,7 @@ def submit_feedback():
if len(description) > 2000:
return jsonify({"error": "描述不能超过2000个字符"}), 400
- user_info = database.get_user_by_id(current_user.id)
- username = user_info["username"] if user_info else f"用户{current_user.id}"
+ username = _get_current_username(fallback=f"用户{current_user.id}")
feedback_id = database.create_bug_feedback(
user_id=current_user.id,
@@ -104,8 +187,7 @@ def get_my_feedbacks():
def get_current_user_vip():
"""获取当前用户VIP信息"""
vip_info = database.get_user_vip_info(current_user.id)
- user_info = database.get_user_by_id(current_user.id)
- vip_info["username"] = user_info["username"] if user_info else "Unknown"
+ vip_info["username"] = _get_current_username(fallback="Unknown")
return jsonify(vip_info)
@@ -124,9 +206,9 @@ def change_user_password():
if not is_valid:
return jsonify({"error": error_msg}), 400
- user = database.get_user_by_id(current_user.id)
- if not user:
- return jsonify({"error": "用户不存在"}), 404
+ user, error_response = _get_current_user_or_404()
+ if error_response:
+ return error_response
username = user.get("username", "")
if not username or not database.verify_user(username, current_password):
@@ -141,9 +223,9 @@ def change_user_password():
@login_required
def get_user_email():
"""获取当前用户的邮箱信息"""
- user = database.get_user_by_id(current_user.id)
- if not user:
- return jsonify({"error": "用户不存在"}), 404
+ user, error_response = _get_current_user_or_404()
+ if error_response:
+ return error_response
return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)})
@@ -172,14 +254,9 @@ def update_user_kdocs_settings():
return jsonify({"error": "县区长度不能超过50"}), 400
if kdocs_auto_upload is not None:
- if isinstance(kdocs_auto_upload, bool):
- kdocs_auto_upload = 1 if kdocs_auto_upload else 0
- try:
- kdocs_auto_upload = int(kdocs_auto_upload)
- except Exception:
- return jsonify({"error": "自动上传开关必须是0或1"}), 400
- if kdocs_auto_upload not in (0, 1):
- return jsonify({"error": "自动上传开关必须是0或1"}), 400
+ kdocs_auto_upload, parse_error = _coerce_binary_flag(kdocs_auto_upload, field_label="自动上传开关")
+ if parse_error:
+ return jsonify({"error": parse_error}), 400
if not database.update_user_kdocs_settings(
current_user.id,
@@ -207,13 +284,9 @@ def bind_user_email():
if not is_valid:
return jsonify({"error": error_msg}), 400
- client_ip = get_rate_limit_ip()
- allowed, error_msg = check_ip_request_rate(client_ip, "email")
+ allowed, error_msg, status_code = _check_bind_email_rate_limits(email)
if not allowed:
- return jsonify({"error": error_msg}), 429
- allowed, error_msg = check_email_rate_limit(email, "bind_email")
- if not allowed:
- return jsonify({"error": error_msg}), 429
+ return jsonify({"error": error_msg}), status_code
settings = email_service.get_email_settings()
if not settings.get("enabled", False):
@@ -223,9 +296,9 @@ def bind_user_email():
if existing_user and existing_user["id"] != current_user.id:
return jsonify({"error": "该邮箱已被其他用户绑定"}), 400
- user = database.get_user_by_id(current_user.id)
- if not user:
- return jsonify({"error": "用户不存在"}), 404
+ user, error_response = _get_current_user_or_404()
+ if error_response:
+ return error_response
if user.get("email") == email and user.get("email_verified"):
return jsonify({"error": "该邮箱已绑定并验证"}), 400
@@ -247,56 +320,20 @@ def verify_bind_email(token):
email = result["email"]
if database.update_user_email(user_id, email, verified=True):
- spa_initial_state = {
- "page": "verify_result",
- "success": True,
- "title": "邮箱绑定成功",
- "message": f"邮箱 {email} 已成功绑定到您的账号!",
- "primary_label": "返回登录",
- "primary_url": "/login",
- "redirect_url": "/login",
- "redirect_seconds": 5,
- }
- return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state)
+ return _render_verify_bind_success(email)
- error_message = "邮箱绑定失败,请重试"
- spa_initial_state = {
- "page": "verify_result",
- "success": False,
- "title": "绑定失败",
- "error_message": error_message,
- "primary_label": "返回登录",
- "primary_url": "/login",
- }
- return render_app_spa_or_legacy(
- "verify_failed.html",
- legacy_context={"error_message": error_message},
- spa_initial_state=spa_initial_state,
- )
+ return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试")
- error_message = "验证链接已过期或无效,请重新发送验证邮件"
- spa_initial_state = {
- "page": "verify_result",
- "success": False,
- "title": "链接无效",
- "error_message": error_message,
- "primary_label": "返回登录",
- "primary_url": "/login",
- }
- return render_app_spa_or_legacy(
- "verify_failed.html",
- legacy_context={"error_message": error_message},
- spa_initial_state=spa_initial_state,
- )
+ return _render_verify_bind_failed(title="链接无效", error_message="验证链接已过期或无效,请重新发送验证邮件")
@api_user_bp.route("/api/user/unbind-email", methods=["POST"])
@login_required
def unbind_user_email():
"""解绑用户邮箱"""
- user = database.get_user_by_id(current_user.id)
- if not user:
- return jsonify({"error": "用户不存在"}), 404
+ user, error_response = _get_current_user_or_404()
+ if error_response:
+ return error_response
if not user.get("email"):
return jsonify({"error": "当前未绑定邮箱"}), 400
@@ -334,10 +371,7 @@ def get_run_stats():
stats = database.get_user_run_stats(user_id)
- current_running = 0
- for _, info in safe_iter_task_status_items():
- if info.get("user_id") == user_id and info.get("status") == "运行中":
- current_running += 1
+ current_running = _get_current_running_count(user_id)
return jsonify(
{
diff --git a/routes/decorators.py b/routes/decorators.py
index f99f9c3..5267db4 100644
--- a/routes/decorators.py
+++ b/routes/decorators.py
@@ -31,7 +31,7 @@ def admin_required(f):
if is_api:
return jsonify({"error": "需要管理员权限"}), 403
return redirect(url_for("pages.admin_login_page"))
- logger.info(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
+ logger.debug(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}")
return f(*args, **kwargs)
return decorated_function
diff --git a/routes/health.py b/routes/health.py
index 57abdc4..c1dc202 100644
--- a/routes/health.py
+++ b/routes/health.py
@@ -2,12 +2,62 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+import os
+import time
+
from flask import Blueprint, jsonify
import database
+import db_pool
from services.time_utils import get_beijing_now
health_bp = Blueprint("health", __name__)
+_PROCESS_START_TS = time.time()
+
+
+def _build_runtime_metrics() -> dict:
+ metrics = {
+ "uptime_seconds": max(0, int(time.time() - _PROCESS_START_TS)),
+ }
+
+ try:
+ pool_stats = db_pool.get_pool_stats() or {}
+ metrics["db_pool"] = {
+ "pool_size": int(pool_stats.get("pool_size", 0) or 0),
+ "available": int(pool_stats.get("available", 0) or 0),
+ "in_use": int(pool_stats.get("in_use", 0) or 0),
+ }
+ except Exception:
+ pass
+
+ try:
+ import psutil
+
+ proc = psutil.Process(os.getpid())
+ with proc.oneshot():
+ mem_info = proc.memory_info()
+ metrics["process"] = {
+ "rss_mb": round(float(mem_info.rss) / 1024 / 1024, 2),
+ "cpu_percent": round(float(proc.cpu_percent(interval=None)), 2),
+ "threads": int(proc.num_threads()),
+ }
+ except Exception:
+ pass
+
+ try:
+ from services import tasks as tasks_module
+
+ scheduler = getattr(tasks_module, "_task_scheduler", None)
+ if scheduler is not None:
+ queue_snapshot = scheduler.get_queue_state_snapshot() or {}
+ metrics["task_queue"] = {
+ "pending_total": int(queue_snapshot.get("pending_total", 0) or 0),
+ "running_total": int(queue_snapshot.get("running_total", 0) or 0),
+ }
+ except Exception:
+ pass
+
+ return metrics
@health_bp.route("/health", methods=["GET"])
@@ -26,6 +76,6 @@ def health_check():
"time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"),
"db_ok": db_ok,
"db_error": db_error,
+ "metrics": _build_runtime_metrics(),
}
return jsonify(payload), (200 if db_ok else 500)
-
diff --git a/scripts/HEALTH_MONITOR_README.md b/scripts/HEALTH_MONITOR_README.md
new file mode 100644
index 0000000..e40fc26
--- /dev/null
+++ b/scripts/HEALTH_MONITOR_README.md
@@ -0,0 +1,60 @@
+# 健康监控(邮件版)
+
+本目录提供 `health_email_monitor.py`,通过调用 `/health` 接口并使用**容器内已有邮件配置**发告警邮件。
+
+## 1) 快速试跑
+
+```bash
+cd /root/zsglpt
+python3 scripts/health_email_monitor.py \
+ --to 你的告警邮箱@example.com \
+ --container knowledge-automation-multiuser \
+ --url http://127.0.0.1:51232/health \
+ --dry-run
+```
+
+去掉 `--dry-run` 即会实际发邮件。
+
+## 2) 建议 cron(每分钟)
+
+```bash
+* * * * * cd /root/zsglpt && /usr/bin/python3 scripts/health_email_monitor.py \
+ --to 你的告警邮箱@example.com \
+ --container knowledge-automation-multiuser \
+ --url http://127.0.0.1:51232/health \
+ >> /root/zsglpt/logs/health_monitor.log 2>&1
+```
+
+## 3) 支持的规则
+
+- `service_down`:健康接口请求失败(立即告警)
+- `health_fail`:返回 `ok/db_ok` 异常或 HTTP 5xx(立即告警)
+- `db_pool_exhausted`:连接池耗尽(默认连续 3 次才告警)
+- `queue_backlog_high`:任务堆积过高(默认 `pending_total >= 50` 且连续 5 次)
+
+脚本支持恢复通知(规则恢复正常会发“恢复”邮件)。
+
+## 4) 常用参数
+
+- `--to`:收件人(必填)
+- `--container`:Docker 容器名(默认 `knowledge-automation-multiuser`)
+- `--url`:健康地址(默认 `http://127.0.0.1:51232/health`)
+- `--state-file`:状态文件路径(默认 `/tmp/zsglpt_health_monitor_state.json`)
+- `--remind-seconds`:重复告警间隔(默认 3600 秒)
+- `--queue-threshold`:队列告警阈值(默认 50)
+- `--queue-streak`:队列连续次数阈值(默认 5)
+- `--db-pool-streak`:连接池连续次数阈值(默认 3)
+
+## 5) 环境变量方式(可选)
+
+也可不用命令行参数,改用环境变量:
+
+- `MONITOR_EMAIL_TO`
+- `MONITOR_DOCKER_CONTAINER`
+- `HEALTH_URL`
+- `MONITOR_STATE_FILE`
+- `MONITOR_REMIND_SECONDS`
+- `MONITOR_QUEUE_THRESHOLD`
+- `MONITOR_QUEUE_STREAK`
+- `MONITOR_DB_POOL_STREAK`
+
diff --git a/scripts/health_email_monitor.py b/scripts/health_email_monitor.py
new file mode 100644
index 0000000..0abfb37
--- /dev/null
+++ b/scripts/health_email_monitor.py
@@ -0,0 +1,348 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import subprocess
+import time
+from datetime import datetime
+from typing import Any, Dict, Tuple
+from urllib.error import HTTPError, URLError
+from urllib.request import Request, urlopen
+
+DEFAULT_STATE_FILE = "/tmp/zsglpt_health_monitor_state.json"
+
+
+def _now_text() -> str:
+ return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+
+def _safe_int(value: Any, default: int = 0) -> int:
+ try:
+ return int(value)
+ except Exception:
+ return int(default)
+
+
+def _safe_float(value: Any, default: float = 0.0) -> float:
+ try:
+ return float(value)
+ except Exception:
+ return float(default)
+
+
+def _load_state(path: str) -> Dict[str, Any]:
+ if not path or not os.path.exists(path):
+ return {
+ "version": 1,
+ "rules": {},
+ "counters": {},
+ }
+ try:
+ with open(path, "r", encoding="utf-8") as f:
+ raw = json.load(f)
+ if not isinstance(raw, dict):
+ raise ValueError("state is not dict")
+ raw.setdefault("version", 1)
+ raw.setdefault("rules", {})
+ raw.setdefault("counters", {})
+ return raw
+ except Exception:
+ return {
+ "version": 1,
+ "rules": {},
+ "counters": {},
+ }
+
+
+def _save_state(path: str, state: Dict[str, Any]) -> None:
+ if not path:
+ return
+ state_dir = os.path.dirname(path)
+ if state_dir:
+ os.makedirs(state_dir, exist_ok=True)
+ tmp_path = f"{path}.tmp"
+ with open(tmp_path, "w", encoding="utf-8") as f:
+ json.dump(state, f, ensure_ascii=False, indent=2)
+ os.replace(tmp_path, path)
+
+
+def _fetch_health(url: str, timeout: int) -> Tuple[int | None, Dict[str, Any], str | None]:
+ req = Request(
+ url,
+ headers={
+ "User-Agent": "zsglpt-health-email-monitor/1.0",
+ "Accept": "application/json",
+ },
+ method="GET",
+ )
+ try:
+ with urlopen(req, timeout=max(1, int(timeout))) as resp:
+ status = int(resp.getcode())
+ body = resp.read().decode("utf-8", errors="ignore")
+ except HTTPError as e:
+ status = int(getattr(e, "code", 0) or 0)
+ body = ""
+ try:
+ body = e.read().decode("utf-8", errors="ignore")
+ except Exception:
+ pass
+ data = {}
+ if body:
+ try:
+ data = json.loads(body)
+ if not isinstance(data, dict):
+ data = {}
+ except Exception:
+ data = {}
+ return status, data, f"HTTPError: {e}"
+ except URLError as e:
+ return None, {}, f"URLError: {e}"
+ except Exception as e:
+ return None, {}, f"RequestError: {e}"
+
+ data: Dict[str, Any] = {}
+ if body:
+ try:
+ loaded = json.loads(body)
+ if isinstance(loaded, dict):
+ data = loaded
+ except Exception:
+ data = {}
+
+ return status, data, None
+
+
+def _inc_streak(state: Dict[str, Any], key: str, bad: bool) -> int:
+ counters = state.setdefault("counters", {})
+ current = _safe_int(counters.get(key), 0)
+ current = (current + 1) if bad else 0
+ counters[key] = current
+ return current
+
+
+def _rule_transition(
+ state: Dict[str, Any],
+ *,
+ rule_name: str,
+ bad: bool,
+ streak: int,
+ threshold: int,
+ remind_seconds: int,
+ now_ts: float,
+) -> str | None:
+ rules = state.setdefault("rules", {})
+ rule_state = rules.setdefault(rule_name, {"active": False, "last_sent": 0})
+
+ is_active = bool(rule_state.get("active", False))
+ last_sent = _safe_float(rule_state.get("last_sent", 0), 0.0)
+ threshold = max(1, int(threshold))
+ remind_seconds = max(60, int(remind_seconds))
+
+ if bad and streak >= threshold:
+ if not is_active:
+ rule_state["active"] = True
+ rule_state["last_sent"] = now_ts
+ return "alert"
+ if (now_ts - last_sent) >= remind_seconds:
+ rule_state["last_sent"] = now_ts
+ return "alert"
+ return None
+
+ if is_active and (not bad):
+ rule_state["active"] = False
+ rule_state["last_sent"] = now_ts
+ return "recover"
+
+ return None
+
+
+def _send_email_via_container(
+ *,
+ container_name: str,
+ to_email: str,
+ subject: str,
+ body: str,
+ timeout_seconds: int = 45,
+) -> Tuple[bool, str]:
+ code = (
+ "import sys,email_service;"
+ "res=email_service.send_email(to_email=sys.argv[1],subject=sys.argv[2],body=sys.argv[3],email_type='health_monitor');"
+ "ok=bool(res.get('success'));"
+ "print('ok' if ok else ('error:'+str(res.get('error'))));"
+ "raise SystemExit(0 if ok else 2)"
+ )
+ cmd = [
+ "docker",
+ "exec",
+ container_name,
+ "python",
+ "-c",
+ code,
+ to_email,
+ subject,
+ body,
+ ]
+ try:
+ proc = subprocess.run(
+ cmd,
+ capture_output=True,
+ text=True,
+ timeout=max(5, int(timeout_seconds)),
+ check=False,
+ )
+ except Exception as e:
+ return False, str(e)
+
+ output = (proc.stdout or "") + (proc.stderr or "")
+ output = output.strip()
+ return proc.returncode == 0, output
+
+
+def _build_common_lines(status: int | None, data: Dict[str, Any], fetch_error: str | None) -> list[str]:
+ metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
+ db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
+ process = metrics.get("process") if isinstance(metrics.get("process"), dict) else {}
+ task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
+
+ lines = [
+ f"时间: {_now_text()}",
+ f"健康地址: {data.get('_monitor_url', '')}",
+ f"HTTP状态: {status if status is not None else '请求失败'}",
+ f"ok/db_ok: {data.get('ok')} / {data.get('db_ok')}",
+ ]
+ if fetch_error:
+ lines.append(f"请求错误: {fetch_error}")
+ lines.extend(
+ [
+ f"队列: pending={task_queue.get('pending_total', 'N/A')}, running={task_queue.get('running_total', 'N/A')}",
+ f"连接池: size={db_pool.get('pool_size', 'N/A')}, available={db_pool.get('available', 'N/A')}, in_use={db_pool.get('in_use', 'N/A')}",
+ f"进程: rss_mb={process.get('rss_mb', 'N/A')}, cpu%={process.get('cpu_percent', 'N/A')}, threads={process.get('threads', 'N/A')}",
+ f"运行时长: {metrics.get('uptime_seconds', 'N/A')} 秒",
+ ]
+ )
+ return lines
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser(description="zsglpt 邮件健康监控(基于 /health)")
+ parser.add_argument("--url", default=os.environ.get("HEALTH_URL", "http://127.0.0.1:51232/health"))
+ parser.add_argument("--to", default=os.environ.get("MONITOR_EMAIL_TO", ""))
+ parser.add_argument(
+ "--container",
+ default=os.environ.get("MONITOR_DOCKER_CONTAINER", "knowledge-automation-multiuser"),
+ )
+ parser.add_argument("--state-file", default=os.environ.get("MONITOR_STATE_FILE", DEFAULT_STATE_FILE))
+ parser.add_argument("--timeout", type=int, default=_safe_int(os.environ.get("MONITOR_TIMEOUT", 8), 8))
+ parser.add_argument(
+ "--remind-seconds",
+ type=int,
+ default=_safe_int(os.environ.get("MONITOR_REMIND_SECONDS", 3600), 3600),
+ )
+ parser.add_argument(
+ "--queue-threshold",
+ type=int,
+ default=_safe_int(os.environ.get("MONITOR_QUEUE_THRESHOLD", 50), 50),
+ )
+ parser.add_argument(
+ "--queue-streak",
+ type=int,
+ default=_safe_int(os.environ.get("MONITOR_QUEUE_STREAK", 5), 5),
+ )
+ parser.add_argument(
+ "--db-pool-streak",
+ type=int,
+ default=_safe_int(os.environ.get("MONITOR_DB_POOL_STREAK", 3), 3),
+ )
+ parser.add_argument("--dry-run", action="store_true", help="仅打印,不实际发邮件")
+ args = parser.parse_args()
+
+ if not args.to:
+ print("[monitor] 缺少收件人,请设置 --to 或 MONITOR_EMAIL_TO", flush=True)
+ return 2
+
+ state = _load_state(args.state_file)
+ now_ts = time.time()
+
+ status, data, fetch_error = _fetch_health(args.url, args.timeout)
+ if not isinstance(data, dict):
+ data = {}
+ data["_monitor_url"] = args.url
+
+ metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {}
+ db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {}
+ task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {}
+
+ service_down = status is None
+ health_fail = bool(status is not None and (status >= 500 or (not data.get("ok", False)) or (not data.get("db_ok", False))))
+ db_pool_exhausted = (
+ _safe_int(db_pool.get("pool_size"), 0) > 0
+ and _safe_int(db_pool.get("available"), 0) <= 0
+ and _safe_int(db_pool.get("in_use"), 0) >= _safe_int(db_pool.get("pool_size"), 0)
+ )
+ queue_backlog_high = _safe_int(task_queue.get("pending_total"), 0) >= max(1, int(args.queue_threshold))
+
+ rule_defs = [
+ ("service_down", service_down, 1),
+ ("health_fail", health_fail, 1),
+ ("db_pool_exhausted", db_pool_exhausted, max(1, int(args.db_pool_streak))),
+ ("queue_backlog_high", queue_backlog_high, max(1, int(args.queue_streak))),
+ ]
+
+ pending_notifications: list[tuple[str, str]] = []
+ for rule_name, bad, threshold in rule_defs:
+ streak = _inc_streak(state, rule_name, bad)
+ action = _rule_transition(
+ state,
+ rule_name=rule_name,
+ bad=bad,
+ streak=streak,
+ threshold=threshold,
+ remind_seconds=args.remind_seconds,
+ now_ts=now_ts,
+ )
+ if action:
+ pending_notifications.append((rule_name, action))
+
+ _save_state(args.state_file, state)
+
+ if not pending_notifications:
+ print(f"[monitor] {_now_text()} 正常,无需发送邮件")
+ return 0
+
+ common_lines = _build_common_lines(status, data, fetch_error)
+
+ for rule_name, action in pending_notifications:
+ level = "告警" if action == "alert" else "恢复"
+ subject = f"[zsglpt健康监控][{level}] {rule_name}"
+ body_lines = [
+ f"规则: {rule_name}",
+ f"状态: {level}",
+ "",
+ *common_lines,
+ ]
+ body = "\n".join(body_lines)
+
+ if args.dry_run:
+ print(f"[monitor][dry-run] subject={subject}\n{body}\n")
+ continue
+
+ ok, msg = _send_email_via_container(
+ container_name=args.container,
+ to_email=args.to,
+ subject=subject,
+ body=body,
+ )
+ if ok:
+ print(f"[monitor] 邮件已发送: {subject}")
+ else:
+ print(f"[monitor] 邮件发送失败: {subject} | {msg}")
+
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
+
diff --git a/services/kdocs_uploader.py b/services/kdocs_uploader.py
index 88d7f87..4d25fc5 100644
--- a/services/kdocs_uploader.py
+++ b/services/kdocs_uploader.py
@@ -243,6 +243,35 @@ class KDocsUploader:
except queue.Empty:
return {"success": False, "error": "操作超时"}
+ def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None:
+ response_queue = task.get("response")
+ if not response_queue:
+ return
+ try:
+ response_queue.put(result)
+ except Exception:
+ return
+
+ def _process_task(self, task: Dict[str, Any]) -> bool:
+ action = task.get("action")
+ payload = task.get("payload") or {}
+
+ if action == "shutdown":
+ return False
+ if action == "upload":
+ self._handle_upload(payload)
+ return True
+ if action == "qr":
+ self._put_task_response(task, self._handle_qr(payload))
+ return True
+ if action == "clear_login":
+ self._put_task_response(task, self._handle_clear_login())
+ return True
+ if action == "status":
+ self._put_task_response(task, self._handle_status_check())
+ return True
+ return True
+
def _run(self) -> None:
thread_id = self._thread_id
logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})")
@@ -261,34 +290,17 @@ class KDocsUploader:
# 更新最后活动时间
self._last_activity = time.time()
- action = task.get("action")
- if action == "shutdown":
- break
-
try:
- if action == "upload":
- self._handle_upload(task.get("payload") or {})
- elif action == "qr":
- result = self._handle_qr(task.get("payload") or {})
- task.get("response").put(result)
- elif action == "clear_login":
- result = self._handle_clear_login()
- task.get("response").put(result)
- elif action == "status":
- result = self._handle_status_check()
- task.get("response").put(result)
+ should_continue = self._process_task(task)
+ if not should_continue:
+ break
# 任务处理完成后更新活动时间
self._last_activity = time.time()
except Exception as e:
logger.warning(f"[KDocs] 处理任务失败: {e}")
- # 如果有响应队列,返回错误
- if "response" in task and task.get("response"):
- try:
- task["response"].put({"success": False, "error": str(e)})
- except Exception:
- pass
+ self._put_task_response(task, {"success": False, "error": str(e)})
except Exception as e:
logger.warning(f"[KDocs] 线程主循环异常: {e}")
@@ -830,18 +842,180 @@ class KDocsUploader:
except Exception as e:
logger.warning(f"[KDocs] 保存登录态失败: {e}")
+ def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str:
+ return (cfg.get("kdocs_doc_url") or "").strip()
+
+ def _ensure_doc_access(
+ self,
+ doc_url: str,
+ *,
+ fast: bool = False,
+ use_storage_state: bool = True,
+ ) -> Optional[str]:
+ if not self._ensure_playwright(use_storage_state=use_storage_state):
+ return self._last_error or "浏览器不可用"
+ if not self._open_document(doc_url, fast=fast):
+ return self._last_error or "打开文档失败"
+ return None
+
+ def _trigger_fast_login_dialog(self, timeout_ms: int) -> None:
+ self._ensure_login_dialog(
+ timeout_ms=timeout_ms,
+ frame_timeout_ms=timeout_ms,
+ quick=True,
+ )
+
+ def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]:
+ qr_image = None
+ invalid_qr = None
+ for attempt in range(10):
+ if attempt in (3, 7):
+ self._trigger_fast_login_dialog(fast_login_timeout)
+ candidate = self._capture_qr_image()
+ if candidate and self._is_valid_qr_image(candidate):
+ qr_image = candidate
+ break
+ if candidate:
+ invalid_qr = candidate
+ time.sleep(0.8) # 优化: 1 -> 0.8
+ return qr_image, invalid_qr
+
+ def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None:
+ try:
+ pages = self._iter_pages()
+ page_urls = [getattr(p, "url", "") for p in pages]
+ logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
+
+ ts = int(time.time())
+ saved = []
+ for idx, page in enumerate(pages[:3]):
+ try:
+ path = f"data/kdocs_debug_{ts}_{idx}.png"
+ page.screenshot(path=path, full_page=True)
+ saved.append(path)
+ except Exception:
+ continue
+
+ if saved:
+ logger.warning(f"[KDocs] 已保存调试截图: {saved}")
+
+ if invalid_qr:
+ try:
+ path = f"data/kdocs_invalid_qr_{ts}.png"
+ with open(path, "wb") as handle:
+ handle.write(invalid_qr)
+ logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None:
+ try:
+ log_to_client(f"表格上传失败: {message}", user_id, account_id)
+ except Exception:
+ pass
+
+ def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]:
+ account = None
+ prev_status = None
+ status_tracked = False
+ try:
+ account = safe_get_account(user_id, account_id)
+ if account and self._should_mark_upload(account):
+ prev_status = getattr(account, "status", None)
+ account.status = "上传截图"
+ self._emit_account_update(user_id, account)
+ status_tracked = True
+ except Exception:
+ prev_status = None
+ return account, prev_status, status_tracked
+
+ def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ unit = (payload.get("unit") or "").strip()
+ name = (payload.get("name") or "").strip()
+ image_path = payload.get("image_path")
+
+ if not unit or not name:
+ return None
+ if not image_path or not os.path.exists(image_path):
+ return None
+
+ return {
+ "unit": unit,
+ "name": name,
+ "image_path": image_path,
+ "user_id": payload.get("user_id"),
+ "account_id": payload.get("account_id"),
+ }
+
+ def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(),
+ "sheet_index": int(cfg.get("kdocs_sheet_index") or 0),
+ "unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(),
+ "image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(),
+ "row_start": int(cfg.get("kdocs_row_start") or 0),
+ "row_end": int(cfg.get("kdocs_row_end") or 0),
+ }
+
+ def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]:
+ sheet_cfg = self._resolve_upload_sheet_config(cfg)
+ success = False
+ error_msg = ""
+
+ for _ in range(2):
+ try:
+ if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]:
+ self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"])
+
+ row_num = self._find_person_with_unit(
+ unit,
+ name,
+ sheet_cfg["unit_col"],
+ row_start=sheet_cfg["row_start"],
+ row_end=sheet_cfg["row_end"],
+ )
+ if row_num < 0:
+ error_msg = f"未找到人员: {unit}-{name}"
+ break
+
+ success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"])
+ if success:
+ break
+ except Exception as e:
+ error_msg = str(e)
+
+ return success, error_msg
+
+ def _handle_upload_login_invalid(
+ self,
+ *,
+ unit: str,
+ name: str,
+ image_path: str,
+ user_id: Any,
+ account_id: Any,
+ ) -> None:
+ error_msg = "登录已失效,请管理员重新扫码登录"
+ self._login_required = True
+ self._last_login_ok = False
+ self._notify_admin(unit, name, image_path, error_msg)
+ self._log_upload_failure(error_msg, user_id, account_id)
+
def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
cfg = self._load_system_config()
- doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": False, "error": "未配置金山文档链接"}
+
force = bool(payload.get("force"))
if force:
self._handle_clear_login()
- if not self._ensure_playwright(use_storage_state=not force):
- return {"success": False, "error": self._last_error or "浏览器不可用"}
- if not self._open_document(doc_url, fast=True):
- return {"success": False, "error": self._last_error or "打开文档失败"}
+
+ doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force)
+ if doc_error:
+ return {"success": False, "error": doc_error}
if not force and self._has_saved_login_state() and self._is_logged_in():
self._login_required = False
@@ -850,54 +1024,12 @@ class KDocsUploader:
return {"success": True, "logged_in": True, "qr_image": ""}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
- self._ensure_login_dialog(
- timeout_ms=fast_login_timeout,
- frame_timeout_ms=fast_login_timeout,
- quick=True,
- )
- qr_image = None
- invalid_qr = None
- for attempt in range(10):
- if attempt in (3, 7):
- self._ensure_login_dialog(
- timeout_ms=fast_login_timeout,
- frame_timeout_ms=fast_login_timeout,
- quick=True,
- )
- candidate = self._capture_qr_image()
- if candidate and self._is_valid_qr_image(candidate):
- qr_image = candidate
- break
- if candidate:
- invalid_qr = candidate
- time.sleep(0.8) # 优化: 1 -> 0.8
+ self._trigger_fast_login_dialog(fast_login_timeout)
+ qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout)
+
if not qr_image:
self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
- try:
- pages = self._iter_pages()
- page_urls = [getattr(p, "url", "") for p in pages]
- logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
- ts = int(time.time())
- saved = []
- for idx, page in enumerate(pages[:3]):
- try:
- path = f"data/kdocs_debug_{ts}_{idx}.png"
- page.screenshot(path=path, full_page=True)
- saved.append(path)
- except Exception:
- continue
- if saved:
- logger.warning(f"[KDocs] 已保存调试截图: {saved}")
- if invalid_qr:
- try:
- path = f"data/kdocs_invalid_qr_{ts}.png"
- with open(path, "wb") as handle:
- handle.write(invalid_qr)
- logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
- except Exception:
- pass
- except Exception:
- pass
+ self._save_qr_debug_artifacts(invalid_qr)
return {"success": False, "error": self._last_error}
try:
@@ -933,24 +1065,22 @@ class KDocsUploader:
def _handle_status_check(self) -> Dict[str, Any]:
cfg = self._load_system_config()
- doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": True, "logged_in": False, "error": "未配置文档链接"}
- if not self._ensure_playwright():
- return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"}
- if not self._open_document(doc_url, fast=True):
- return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"}
+
+ doc_error = self._ensure_doc_access(doc_url, fast=True)
+ if doc_error:
+ return {"success": False, "logged_in": False, "error": doc_error}
+
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
- self._ensure_login_dialog(
- timeout_ms=fast_login_timeout,
- frame_timeout_ms=fast_login_timeout,
- quick=True,
- )
+ self._trigger_fast_login_dialog(fast_login_timeout)
self._try_confirm_login(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
+
logged_in = self._is_logged_in()
self._last_login_ok = logged_in
self._login_required = not logged_in
@@ -962,79 +1092,43 @@ class KDocsUploader:
cfg = self._load_system_config()
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
- doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+
+ doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return
- unit = (payload.get("unit") or "").strip()
- name = (payload.get("name") or "").strip()
- image_path = payload.get("image_path")
- user_id = payload.get("user_id")
- account_id = payload.get("account_id")
-
- if not unit or not name:
- return
- if not image_path or not os.path.exists(image_path):
+ upload_data = self._parse_upload_payload(payload)
+ if not upload_data:
return
- account = None
- prev_status = None
- status_tracked = False
+ unit = upload_data["unit"]
+ name = upload_data["name"]
+ image_path = upload_data["image_path"]
+ user_id = upload_data["user_id"]
+ account_id = upload_data["account_id"]
+
+ account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id)
try:
- try:
- account = safe_get_account(user_id, account_id)
- if account and self._should_mark_upload(account):
- prev_status = getattr(account, "status", None)
- account.status = "上传截图"
- self._emit_account_update(user_id, account)
- status_tracked = True
- except Exception:
- prev_status = None
-
- if not self._ensure_playwright():
- self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
- return
-
- if not self._open_document(doc_url):
- self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
+ doc_error = self._ensure_doc_access(doc_url)
+ if doc_error:
+ self._notify_admin(unit, name, image_path, doc_error)
return
if not self._is_logged_in():
- self._login_required = True
- self._last_login_ok = False
- self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录")
- try:
- log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id)
- except Exception:
- pass
+ self._handle_upload_login_invalid(
+ unit=unit,
+ name=name,
+ image_path=image_path,
+ user_id=user_id,
+ account_id=account_id,
+ )
return
+
self._login_required = False
self._last_login_ok = True
- sheet_name = (cfg.get("kdocs_sheet_name") or "").strip()
- sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
- unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
- image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
- row_start = int(cfg.get("kdocs_row_start") or 0)
- row_end = int(cfg.get("kdocs_row_end") or 0)
-
- success = False
- error_msg = ""
- for attempt in range(2):
- try:
- if sheet_name or sheet_index:
- self._select_sheet(sheet_name, sheet_index)
- row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
- if row_num < 0:
- error_msg = f"未找到人员: {unit}-{name}"
- break
- success = self._upload_image_to_cell(row_num, image_path, image_col)
- if success:
- break
- except Exception as e:
- error_msg = str(e)
-
+ success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path)
if success:
self._last_success_at = time.time()
self._last_error = None
@@ -1048,10 +1142,7 @@ class KDocsUploader:
error_msg = "上传失败"
self._last_error = error_msg
self._notify_admin(unit, name, image_path, error_msg)
- try:
- log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
- except Exception:
- pass
+ self._log_upload_failure(error_msg, user_id, account_id)
finally:
if status_tracked:
self._restore_account_status(user_id, account, prev_status)
diff --git a/services/maintenance.py b/services/maintenance.py
index 7dbfb90..fda72f1 100644
--- a/services/maintenance.py
+++ b/services/maintenance.py
@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
+import os
import threading
import time
from datetime import datetime
@@ -10,6 +11,8 @@ from app_config import get_config
from app_logger import get_logger
from services.state import (
cleanup_expired_ip_rate_limits,
+ cleanup_expired_ip_request_rates,
+ cleanup_expired_login_security_state,
safe_cleanup_expired_batches,
safe_cleanup_expired_captcha,
safe_cleanup_expired_pending_random,
@@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON
_kdocs_offline_notified: bool = False
+def _to_int(value, default: int = 0) -> int:
+ try:
+ return int(value)
+ except Exception:
+ return int(default)
+
+
+def _collect_active_user_ids() -> set[int]:
+ active_user_ids: set[int] = set()
+ for _, info in safe_iter_task_status_items():
+ user_id = info.get("user_id") if isinstance(info, dict) else None
+ if user_id is None:
+ continue
+ try:
+ active_user_ids.add(int(user_id))
+ except Exception:
+ continue
+ return active_user_ids
+
+
+def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]:
+ expired_users = []
+ for user_id, last_access in (safe_get_user_accounts_last_access_items() or []):
+ try:
+ user_id_int = int(user_id)
+ last_access_ts = float(last_access)
+ except Exception:
+ continue
+ if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS:
+ continue
+ if user_id_int in active_user_ids:
+ continue
+ if safe_has_user(user_id_int):
+ expired_users.append(user_id_int)
+ return expired_users
+
+
+def _find_completed_task_status_ids(current_time: float) -> list[str]:
+ completed_task_ids = []
+ for account_id, status_data in safe_iter_task_status_items():
+ status = status_data.get("status") if isinstance(status_data, dict) else None
+ if status not in ["已完成", "失败", "已停止"]:
+ continue
+
+ start_time = float(status_data.get("start_time", 0) or 0)
+ if (current_time - start_time) > 600: # 10分钟
+ completed_task_ids.append(account_id)
+ return completed_task_ids
+
+
+def _reap_zombie_processes() -> None:
+ while True:
+ try:
+ pid, _ = os.waitpid(-1, os.WNOHANG)
+ if pid == 0:
+ break
+ logger.debug(f"已回收僵尸进程: PID={pid}")
+ except ChildProcessError:
+ break
+ except Exception:
+ break
+
+
def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
current_time = time.time()
@@ -43,48 +109,36 @@ def cleanup_expired_data() -> None:
if deleted_ips:
logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录")
- expired_users = []
- last_access_items = safe_get_user_accounts_last_access_items()
- if last_access_items:
- task_items = safe_iter_task_status_items()
- active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")}
- for user_id, last_access in last_access_items:
- if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS:
- continue
- if int(user_id) in active_user_ids:
- continue
- if safe_has_user(user_id):
- expired_users.append(int(user_id))
+ deleted_ip_requests = cleanup_expired_ip_request_rates(current_time)
+ if deleted_ip_requests:
+ logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录")
+ login_cleanup_stats = cleanup_expired_login_security_state(current_time)
+ login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values())
+ if login_cleanup_total:
+ logger.debug(
+ "已清理登录风控缓存: "
+ f"失败计数={login_cleanup_stats.get('failures', 0)}, "
+ f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, "
+ f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, "
+ f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, "
+ f"告警状态={login_cleanup_stats.get('alerts', 0)}"
+ )
+
+ active_user_ids = _collect_active_user_ids()
+ expired_users = _find_expired_user_cache_ids(current_time, active_user_ids)
for user_id in expired_users:
safe_remove_user_accounts(user_id)
if expired_users:
logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存")
- completed_tasks = []
- for account_id, status_data in safe_iter_task_status_items():
- if status_data.get("status") in ["已完成", "失败", "已停止"]:
- start_time = float(status_data.get("start_time", 0) or 0)
- if (current_time - start_time) > 600: # 10分钟
- completed_tasks.append(account_id)
- for account_id in completed_tasks:
+ completed_task_ids = _find_completed_task_status_ids(current_time)
+ for account_id in completed_task_ids:
safe_remove_task_status(account_id)
- if completed_tasks:
- logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态")
+ if completed_task_ids:
+ logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态")
- try:
- import os
-
- while True:
- try:
- pid, status = os.waitpid(-1, os.WNOHANG)
- if pid == 0:
- break
- logger.debug(f"已回收僵尸进程: PID={pid}")
- except ChildProcessError:
- break
- except Exception:
- pass
+ _reap_zombie_processes()
deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time)
if deleted_batches:
@@ -95,52 +149,39 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
-def check_kdocs_online_status() -> None:
- """检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
- global _kdocs_offline_notified
+def _load_kdocs_monitor_config():
+ import database
+ cfg = database.get_system_config()
+ if not cfg:
+ return None
+
+ kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0)
+ if not kdocs_enabled:
+ return None
+
+ admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0)
+ admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip()
+ if (not admin_notify_enabled) or (not admin_notify_email):
+ return None
+
+ return admin_notify_email
+
+
+def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]:
+ login_required = bool(status.get("login_required", False))
+ last_login_ok = status.get("last_login_ok")
+ is_offline = login_required or (last_login_ok is False)
+ return is_offline, login_required, last_login_ok
+
+
+def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool:
try:
- import database
- from services.kdocs_uploader import get_kdocs_uploader
+ import email_service
- # 获取系统配置
- cfg = database.get_system_config()
- if not cfg:
- return
-
- # 检查是否启用了金山文档功能
- kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
- if not kdocs_enabled:
- return
-
- # 检查是否启用了管理员通知
- admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
- admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
- if not admin_notify_enabled or not admin_notify_email:
- return
-
- # 获取金山文档状态
- kdocs = get_kdocs_uploader()
- status = kdocs.get_status()
- login_required = status.get("login_required", False)
- last_login_ok = status.get("last_login_ok")
-
- # 如果需要登录或最后登录状态不是成功
- is_offline = login_required or (last_login_ok is False)
-
- if is_offline:
- # 已经通知过了,不再重复通知
- if _kdocs_offline_notified:
- logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
- return
-
- # 发送邮件通知
- try:
- import email_service
-
- now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- subject = "【金山文档离线告警】需要重新登录"
- body = f"""
+ now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ subject = "【金山文档离线告警】需要重新登录"
+ body = f"""
您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。
@@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None:
---
此邮件由系统自动发送,请勿直接回复。
"""
- email_service.send_email_async(
- to_email=admin_notify_email,
- subject=subject,
- body=body,
- email_type="kdocs_offline_alert",
- )
- _kdocs_offline_notified = True # 标记为已通知
- logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
- except Exception as e:
- logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
- else:
- # 恢复在线,重置通知状态
+ email_service.send_email_async(
+ to_email=admin_notify_email,
+ subject=subject,
+ body=body,
+ email_type="kdocs_offline_alert",
+ )
+ logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
+ return True
+ except Exception as e:
+ logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
+ return False
+
+
+def check_kdocs_online_status() -> None:
+ """检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
+ global _kdocs_offline_notified
+
+ try:
+ admin_notify_email = _load_kdocs_monitor_config()
+ if not admin_notify_email:
+ return
+
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ kdocs = get_kdocs_uploader()
+ status = kdocs.get_status() or {}
+ is_offline, login_required, last_login_ok = _is_kdocs_offline(status)
+
+ if is_offline:
if _kdocs_offline_notified:
- logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
- _kdocs_offline_notified = False
- logger.debug("[KDocs监控] 金山文档状态正常")
+ logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
+ return
+
+ if _send_kdocs_offline_alert(
+ admin_notify_email,
+ login_required=login_required,
+ last_login_ok=last_login_ok,
+ ):
+ _kdocs_offline_notified = True
+ return
+
+ if _kdocs_offline_notified:
+ logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
+ _kdocs_offline_notified = False
+ logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}")
-def start_cleanup_scheduler() -> None:
- """启动定期清理调度器"""
-
- def cleanup_loop():
+def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str):
+ def loop():
+ if startup_delay > 0:
+ time.sleep(startup_delay)
while True:
try:
- time.sleep(300) # 每5分钟执行一次清理
- cleanup_expired_data()
+ job()
+ time.sleep(interval_seconds)
except Exception as e:
- logger.error(f"清理任务执行失败: {e}")
+ logger.error(f"{error_tag}: {e}")
+ time.sleep(min(60.0, max(1.0, interval_seconds / 5.0)))
- cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler")
- cleanup_thread.start()
+ thread = threading.Thread(target=loop, daemon=True, name=name)
+ thread.start()
+ return thread
+
+
+def start_cleanup_scheduler() -> None:
+ """启动定期清理调度器"""
+ _start_daemon_loop(
+ "cleanup-scheduler",
+ startup_delay=300,
+ interval_seconds=300,
+ job=cleanup_expired_data,
+ error_tag="清理任务执行失败",
+ )
logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None:
"""启动金山文档状态监控"""
-
- def monitor_loop():
- # 启动后等待 60 秒再开始检测(给系统初始化的时间)
- time.sleep(60)
- while True:
- try:
- check_kdocs_online_status()
- time.sleep(300) # 每5分钟检测一次
- except Exception as e:
- logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
- time.sleep(60)
-
- monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
- monitor_thread.start()
+ _start_daemon_loop(
+ "kdocs-monitor",
+ startup_delay=60,
+ interval_seconds=300,
+ job=check_kdocs_online_status,
+ error_tag="[KDocs监控] 监控任务执行失败",
+ )
logger.info("[KDocs监控] 金山文档状态监控已启动(每5分钟检测一次)")
-
diff --git a/services/scheduler.py b/services/scheduler.py
index c562bb1..892b705 100644
--- a/services/scheduler.py
+++ b/services/scheduler.py
@@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now
logger = get_logger("app")
config = get_config()
+try:
+ _SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2"))
+except Exception:
+ _SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2
+_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS)
+
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
@@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str:
return f"{hour:02d}:{minute:02d}"
+def _safe_recompute_schedule_next_run(schedule_id: int) -> None:
+ try:
+ database.recompute_schedule_next_run(schedule_id)
+ except Exception:
+ pass
+
+
+def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]:
+ """批量加载用户账号快照。"""
+ user_accounts: dict[int, dict] = {}
+ account_ids: list[str] = []
+ for user in approved_users:
+ user_id = user["id"]
+ accounts = safe_get_user_accounts_snapshot(user_id)
+ if not accounts:
+ load_user_accounts(user_id)
+ accounts = safe_get_user_accounts_snapshot(user_id)
+ if accounts:
+ user_accounts[user_id] = accounts
+ account_ids.extend(list(accounts.keys()))
+ return user_accounts, account_ids
+
+
+def _should_skip_suspended_account(account_status_info, account, username: str) -> bool:
+ """判断是否应跳过暂停账号,并输出日志。"""
+ if not account_status_info:
+ return False
+
+ status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
+ if status != "suspended":
+ return False
+
+ fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0
+ logger.info(
+ f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码"
+ )
+ return True
+
+
+def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int):
+ import json
+
+ try:
+ account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
+ account_ids = json.loads(account_ids_raw)
+ except Exception as e:
+ logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
+ return []
+ if isinstance(account_ids, list):
+ return account_ids
+ return []
+
+
+def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None:
+ safe_create_batch(
+ batch_id,
+ {
+ "user_id": user_id,
+ "browse_type": browse_type,
+ "schedule_name": schedule_name,
+ "screenshots": [],
+ "total_accounts": 0,
+ "completed": 0,
+ "created_at": now_ts,
+ "updated_at": now_ts,
+ },
+ )
+
+
+def _build_user_schedule_done_callback(
+ *,
+ completion_lock: threading.Lock,
+ remaining: dict,
+ counters: dict,
+ execution_start_time: float,
+ log_id: int,
+ schedule_id: int,
+ total_accounts: int,
+):
+ def on_browse_done():
+ with completion_lock:
+ remaining["count"] -= 1
+ if remaining["done"] or remaining["count"] > 0:
+ return
+ remaining["done"] = True
+
+ execution_duration = int(time.time() - execution_start_time)
+ started_count = int(counters.get("started", 0) or 0)
+ database.update_schedule_execution_log(
+ log_id,
+ total_accounts=total_accounts,
+ success_accounts=started_count,
+ failed_accounts=total_accounts - started_count,
+ duration_seconds=execution_duration,
+ status="completed",
+ )
+ logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件")
+
+ return on_browse_done
+
+
+def _submit_user_schedule_accounts(
+ *,
+ user_id: int,
+ account_ids: list,
+ browse_type: str,
+ enable_screenshot,
+ task_source: str,
+ done_callback,
+ completion_lock: threading.Lock,
+ remaining: dict,
+ counters: dict,
+) -> tuple[int, int]:
+ started_count = 0
+ skipped_count = 0
+
+ for account_id in account_ids:
+ account = safe_get_account(user_id, account_id)
+ if (not account) or account.is_running:
+ skipped_count += 1
+ continue
+
+ with completion_lock:
+ remaining["count"] += 1
+ ok, msg = submit_account_task(
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ enable_screenshot=enable_screenshot,
+ source=task_source,
+ done_callback=done_callback,
+ )
+ if ok:
+ started_count += 1
+ counters["started"] = started_count
+ else:
+ with completion_lock:
+ remaining["count"] -= 1
+ skipped_count += 1
+ logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
+
+ return started_count, skipped_count
+
+
def run_scheduled_task(skip_weekday_check: bool = False) -> None:
"""执行所有账号的浏览任务(可被手动调用,过滤重复账号)"""
try:
@@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
- user_accounts = {}
- account_ids = []
- for user in approved_users:
- user_id = user["id"]
- accounts = safe_get_user_accounts_snapshot(user_id)
- if not accounts:
- load_user_accounts(user_id)
- accounts = safe_get_user_accounts_snapshot(user_id)
- if accounts:
- user_accounts[user_id] = accounts
- account_ids.extend(list(accounts.keys()))
+ user_accounts, account_ids = _load_accounts_for_users(approved_users)
account_statuses = database.get_account_status_batch(account_ids)
@@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
continue
account_status_info = account_statuses.get(str(account_id))
- if account_status_info:
- status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
- if status == "suspended":
- fail_count = (
- account_status_info["login_fail_count"]
- if "login_fail_count" in account_status_info.keys()
- else 0
- )
- logger.info(
- f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码"
- )
- continue
+ if _should_skip_suspended_account(account_status_info, account, user["username"]):
+ continue
if account.username in executed_usernames:
skipped_duplicates += 1
@@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
else:
logger.warning(f"[定时任务] 启动失败({account.username}): {msg}")
- time.sleep(2)
+ if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0:
+ time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS)
logger.info(
f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}"
@@ -198,15 +329,16 @@ def scheduled_task_worker() -> None:
deleted_screenshots = 0
if os.path.exists(SCREENSHOTS_DIR):
cutoff_time = time.time() - (7 * 24 * 60 * 60)
- for filename in os.listdir(SCREENSHOTS_DIR):
- if filename.lower().endswith((".png", ".jpg", ".jpeg")):
- filepath = os.path.join(SCREENSHOTS_DIR, filename)
+ with os.scandir(SCREENSHOTS_DIR) as entries:
+ for entry in entries:
+ if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))):
+ continue
try:
- if os.path.getmtime(filepath) < cutoff_time:
- os.remove(filepath)
+ if entry.stat().st_mtime < cutoff_time:
+ os.remove(entry.path)
deleted_screenshots += 1
except Exception as e:
- logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}")
+ logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}")
logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件")
logger.info("[定时清理] 清理完成!")
@@ -214,10 +346,97 @@ def scheduled_task_worker() -> None:
except Exception as e:
logger.exception(f"[定时清理] 清理任务出错: {str(e)}")
+ def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int):
+ weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
+ try:
+ return [int(d) for d in weekdays_str.split(",") if d.strip()]
+ except Exception as e:
+ logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
+ _safe_recompute_schedule_next_run(schedule_id)
+ return None
+
+ def _execute_due_user_schedule(schedule_config: dict) -> None:
+ schedule_name = schedule_config.get("name", "未命名任务")
+ schedule_id = schedule_config["id"]
+ user_id = schedule_config["user_id"]
+ browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
+ enable_screenshot = schedule_config.get("enable_screenshot", 1)
+
+ account_ids = _parse_schedule_account_ids(schedule_config, schedule_id)
+ if not account_ids:
+ _safe_recompute_schedule_next_run(schedule_id)
+ return
+
+ if not safe_get_user_accounts_snapshot(user_id):
+ load_user_accounts(user_id)
+
+ import uuid
+
+ execution_start_time = time.time()
+ log_id = database.create_schedule_execution_log(
+ schedule_id=schedule_id,
+ user_id=user_id,
+ schedule_name=schedule_name,
+ )
+
+ batch_id = f"batch_{uuid.uuid4().hex[:12]}"
+ now_ts = time.time()
+ _create_user_schedule_batch(
+ batch_id=batch_id,
+ user_id=user_id,
+ browse_type=browse_type,
+ schedule_name=schedule_name,
+ now_ts=now_ts,
+ )
+
+ completion_lock = threading.Lock()
+ remaining = {"count": 0, "done": False}
+ counters = {"started": 0}
+
+ on_browse_done = _build_user_schedule_done_callback(
+ completion_lock=completion_lock,
+ remaining=remaining,
+ counters=counters,
+ execution_start_time=execution_start_time,
+ log_id=log_id,
+ schedule_id=schedule_id,
+ total_accounts=len(account_ids),
+ )
+
+ task_source = f"user_scheduled:{batch_id}"
+ started_count, skipped_count = _submit_user_schedule_accounts(
+ user_id=user_id,
+ account_ids=account_ids,
+ browse_type=browse_type,
+ enable_screenshot=enable_screenshot,
+ task_source=task_source,
+ done_callback=on_browse_done,
+ completion_lock=completion_lock,
+ remaining=remaining,
+ counters=counters,
+ )
+
+ batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time())
+ if batch_info:
+ _send_batch_task_email_if_configured(batch_info)
+
+ database.update_schedule_last_run(schedule_id)
+
+ logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}")
+ if started_count <= 0:
+ database.update_schedule_execution_log(
+ log_id,
+ total_accounts=len(account_ids),
+ success_accounts=0,
+ failed_accounts=len(account_ids),
+ duration_seconds=0,
+ status="completed",
+ )
+ if started_count == 0 and len(account_ids) > 0:
+ logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态")
+
def check_user_schedules():
"""检查并执行用户定时任务(O-08:next_run_at 索引驱动)。"""
- import json
-
try:
now = get_beijing_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -226,145 +445,22 @@ def scheduled_task_worker() -> None:
due_schedules = database.get_due_user_schedules(now_str, limit=50) or []
for schedule_config in due_schedules:
- schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
+ schedule_name = schedule_config.get("name", "未命名任务")
- weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
- try:
- allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()]
- except Exception as e:
- logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
- try:
- database.recompute_schedule_next_run(schedule_id)
- except Exception:
- pass
+ allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id)
+ if allowed_weekdays is None:
continue
if current_weekday not in allowed_weekdays:
- try:
- database.recompute_schedule_next_run(schedule_id)
- except Exception:
- pass
+ _safe_recompute_schedule_next_run(schedule_id)
continue
- logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})")
-
- user_id = schedule_config["user_id"]
- schedule_id = schedule_config["id"]
- browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
- enable_screenshot = schedule_config.get("enable_screenshot", 1)
-
- try:
- account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
- account_ids = json.loads(account_ids_raw)
- except Exception as e:
- logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
- account_ids = []
-
- if not account_ids:
- try:
- database.recompute_schedule_next_run(schedule_id)
- except Exception:
- pass
- continue
-
- if not safe_get_user_accounts_snapshot(user_id):
- load_user_accounts(user_id)
-
- import time as time_mod
- import uuid
-
- execution_start_time = time_mod.time()
- log_id = database.create_schedule_execution_log(
- schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务")
+ logger.info(
+ f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 "
+ f"(next_run_at={schedule_config.get('next_run_at')})"
)
-
- batch_id = f"batch_{uuid.uuid4().hex[:12]}"
- now_ts = time_mod.time()
- safe_create_batch(
- batch_id,
- {
- "user_id": user_id,
- "browse_type": browse_type,
- "schedule_name": schedule_config.get("name", "未命名任务"),
- "screenshots": [],
- "total_accounts": 0,
- "completed": 0,
- "created_at": now_ts,
- "updated_at": now_ts,
- },
- )
-
- started_count = 0
- skipped_count = 0
- completion_lock = threading.Lock()
- remaining = {"count": 0, "done": False}
-
- def on_browse_done():
- with completion_lock:
- remaining["count"] -= 1
- if remaining["done"] or remaining["count"] > 0:
- return
- remaining["done"] = True
- execution_duration = int(time_mod.time() - execution_start_time)
- database.update_schedule_execution_log(
- log_id,
- total_accounts=len(account_ids),
- success_accounts=started_count,
- failed_accounts=len(account_ids) - started_count,
- duration_seconds=execution_duration,
- status="completed",
- )
- logger.info(
- f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件"
- )
-
- for account_id in account_ids:
- account = safe_get_account(user_id, account_id)
- if not account:
- skipped_count += 1
- continue
- if account.is_running:
- skipped_count += 1
- continue
-
- task_source = f"user_scheduled:{batch_id}"
- with completion_lock:
- remaining["count"] += 1
- ok, msg = submit_account_task(
- user_id=user_id,
- account_id=account_id,
- browse_type=browse_type,
- enable_screenshot=enable_screenshot,
- source=task_source,
- done_callback=on_browse_done,
- )
- if ok:
- started_count += 1
- else:
- with completion_lock:
- remaining["count"] -= 1
- skipped_count += 1
- logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
-
- batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time())
- if batch_info:
- _send_batch_task_email_if_configured(batch_info)
-
- database.update_schedule_last_run(schedule_id)
-
- logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}")
- if started_count <= 0:
- database.update_schedule_execution_log(
- log_id,
- total_accounts=len(account_ids),
- success_accounts=0,
- failed_accounts=len(account_ids),
- duration_seconds=0,
- status="completed",
- )
- if started_count == 0 and len(account_ids) > 0:
- logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态")
+ _execute_due_user_schedule(schedule_config)
except Exception as e:
logger.exception(f"[用户定时任务] 检查出错: {str(e)}")
diff --git a/services/screenshots.py b/services/screenshots.py
index 2b47b59..c870f4a 100644
--- a/services/screenshots.py
+++ b/services/screenshots.py
@@ -6,12 +6,14 @@ import os
import shutil
import subprocess
import time
+from urllib.parse import urlsplit
import database
import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config
from app_logger import get_logger
+from app_security import sanitize_filename
from browser_pool_worker import get_browser_worker_pool
from services.client_log import log_to_client
from services.runtime import get_socketio
@@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
+def _set_screenshot_running_status(user_id: int, account_id: str) -> None:
+ """更新账号状态为截图中。"""
+ acc = safe_get_account(user_id, account_id)
+ if not acc:
+ return
+ acc.status = "截图中"
+ safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
+ _emit("account_update", acc.to_dict(), room=f"user_{user_id}")
+
+
+def _get_worker_display_info(browser_instance) -> tuple[str, int]:
+ """获取截图 worker 的展示信息。"""
+ if isinstance(browser_instance, dict):
+ return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0)
+ return "?", 0
+
+
+def _get_proxy_context(account) -> tuple[dict | None, str | None]:
+ """提取截图阶段代理配置。"""
+ proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
+ proxy_server = proxy_config.get("server") if proxy_config else None
+ return proxy_config, proxy_server
+
+
+def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]:
+ """构建截图目标 URL 与页面脚本。"""
+ parsed = urlsplit(config.ZSGL_LOGIN_URL)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ if "注册前" in str(browse_type):
+ bz = 0
+ else:
+ bz = 0
+
+ target_url = f"{base}/admin/center.aspx?bz={bz}"
+ index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
+ run_script = (
+ "(function(){"
+ "function done(){window.status='ready';}"
+ "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
+ "function expandMenu(){"
+ "try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
+ "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
+ "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
+ "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
+ "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
+ "}"
+ "function navReady(){"
+ "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
+ "}"
+ "function frameReady(){"
+ "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
+ "}"
+ "function check(){"
+ "if(navReady() && frameReady()){done();return;}"
+ "setTimeout(check,300);"
+ "}"
+ "var f=document.getElementById('mainframe');"
+ "ensureNav();"
+ "expandMenu();"
+ "if(!f){done();return;}"
+ f"f.src='{target_url}';"
+ "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
+ "setTimeout(check,5000);"
+ "})();"
+ )
+ return index_url, target_url, run_script
+
+
+def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]:
+ """构建截图输出文件名与路径。"""
+ timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
+ login_account = account.remark if account.remark else account.username
+ raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
+ screenshot_filename = sanitize_filename(raw_filename)
+ return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename)
+
+
+def _ensure_screenshot_login_state(
+ *,
+ account,
+ proxy_config,
+ cookie_path: str,
+ attempt: int,
+ max_retries: int,
+ user_id: int,
+ account_id: str,
+ custom_log,
+) -> str:
+ """确保截图前登录态有效。返回: ok/retry/fail。"""
+ should_refresh_login = not is_cookie_jar_fresh(cookie_path)
+ if not should_refresh_login:
+ return "ok"
+
+ log_to_client("正在刷新登录态...", user_id, account_id)
+ if _ensure_login_cookies(account, proxy_config, custom_log):
+ return "ok"
+
+ if attempt > 1:
+ log_to_client("截图登录失败", user_id, account_id)
+ if attempt < max_retries:
+ log_to_client("将重试...", user_id, account_id)
+ time.sleep(2)
+ return "retry"
+
+ log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
+ return "fail"
+
+
+def _take_screenshot_once(
+ *,
+ index_url: str,
+ target_url: str,
+ screenshot_path: str,
+ cookie_path: str,
+ proxy_server: str | None,
+ run_script: str,
+ log_callback,
+) -> str:
+ """执行一次截图尝试并验证输出文件。返回: success/invalid/failed。"""
+ cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
+
+ attempts = [
+ {
+ "url": index_url,
+ "run_script": run_script,
+ "window_status": "ready",
+ },
+ {
+ "url": target_url,
+ "run_script": None,
+ "window_status": None,
+ },
+ ]
+
+ ok = False
+ for shot in attempts:
+ ok = take_screenshot_wkhtmltoimage(
+ shot["url"],
+ screenshot_path,
+ cookies_path=cookies_for_shot,
+ proxy_server=proxy_server,
+ run_script=shot["run_script"],
+ window_status=shot["window_status"],
+ log_callback=log_callback,
+ )
+ if ok:
+ break
+
+ if not ok:
+ return "failed"
+
+ if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
+ return "success"
+
+ if os.path.exists(screenshot_path):
+ os.remove(screenshot_path)
+ return "invalid"
+
+
+def _get_result_screenshot_path(result) -> str | None:
+ """从截图结果中提取截图文件绝对路径。"""
+ if result and result.get("success") and result.get("filename"):
+ return os.path.join(SCREENSHOTS_DIR, result["filename"])
+ return None
+
+
+def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None:
+ """按配置提交金山文档上传任务。"""
+ if not screenshot_path:
+ return
+
+ cfg = database.get_system_config() or {}
+ if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
+ return
+
+ doc_url = (cfg.get("kdocs_doc_url") or "").strip()
+ if not doc_url:
+ return
+
+ user_cfg = database.get_user_kdocs_settings(user_id) or {}
+ if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1:
+ return
+
+ unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
+ name = (account.remark or "").strip()
+ if not unit:
+ log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
+ return
+ if not name:
+ log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
+ return
+
+ from services.kdocs_uploader import get_kdocs_uploader
+
+ ok = get_kdocs_uploader().enqueue_upload(
+ user_id=user_id,
+ account_id=account_id,
+ unit=unit,
+ name=name,
+ image_path=screenshot_path,
+ )
+ if not ok:
+ log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
+
+
+def _dispatch_screenshot_result(
+ *,
+ user_id: int,
+ account_id: str,
+ source: str,
+ browse_type: str,
+ browse_result: dict,
+ result,
+ account,
+ user_info,
+) -> None:
+ """将截图结果发送到批次统计/邮件通知链路。"""
+ batch_id = _get_batch_id_from_source(source)
+ screenshot_path = _get_result_screenshot_path(result)
+ account_name = account.remark if account.remark else account.username
+
+ try:
+ if result and result.get("success") and screenshot_path:
+ _enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path)
+ except Exception as kdocs_error:
+ logger.warning(f"表格上传任务提交失败: {kdocs_error}")
+
+ if batch_id:
+ _batch_task_record_result(
+ batch_id=batch_id,
+ account_name=account_name,
+ screenshot_path=screenshot_path,
+ total_items=browse_result.get("total_items", 0),
+ total_attachments=browse_result.get("total_attachments", 0),
+ )
+ return
+
+ if source and source.startswith("user_scheduled"):
+ if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
+ email_service.send_task_complete_email_async(
+ user_id=user_id,
+ email=user_info["email"],
+ username=user_info["username"],
+ account_name=account_name,
+ browse_type=browse_type,
+ total_items=browse_result.get("total_items", 0),
+ total_attachments=browse_result.get("total_attachments", 0),
+ screenshot_path=screenshot_path,
+ log_callback=lambda msg: log_to_client(msg, user_id, account_id),
+ )
+
+
+def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None:
+ """截图回调的通用收尾状态变更。"""
+ account.is_running = False
+ account.status = "未开始"
+ safe_remove_task_status(account_id)
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+
+def _persist_browse_log_after_screenshot(
+ *,
+ user_id: int,
+ account_id: str,
+ account,
+ browse_type: str,
+ source: str,
+ task_start_time,
+ browse_result,
+) -> None:
+ """截图完成后写入任务日志(浏览完成日志)。"""
+ import time as time_module
+
+ total_elapsed = int(time_module.time() - task_start_time)
+ database.create_task_log(
+ user_id=user_id,
+ account_id=account_id,
+ username=account.username,
+ browse_type=browse_type,
+ status="success",
+ total_items=browse_result.get("total_items", 0),
+ total_attachments=browse_result.get("total_attachments", 0),
+ duration=total_elapsed,
+ source=source,
+ )
+
+
def take_screenshot_for_account(
user_id,
account_id,
@@ -213,21 +502,21 @@ def take_screenshot_for_account(
# 标记账号正在截图(防止重复提交截图任务)
account.is_running = True
+
+ user_info = database.get_user_by_id(user_id)
+ username_prefix = user_info["username"] if user_info else f"user{user_id}"
+
def screenshot_task(
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result
):
"""在worker线程中执行的截图任务"""
# ✅ 获得worker后,立即更新状态为"截图中"
- acc = safe_get_account(user_id, account_id)
- if acc:
- acc.status = "截图中"
- safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
- _emit("account_update", acc.to_dict(), room=f"user_{user_id}")
+ _set_screenshot_running_status(user_id, account_id)
max_retries = 3
- proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
- proxy_server = proxy_config.get("server") if proxy_config else None
+ proxy_config, proxy_server = _get_proxy_context(account)
cookie_path = get_cookie_jar_path(account.username)
+ index_url, target_url, run_script = _build_screenshot_targets(browse_type)
for attempt in range(1, max_retries + 1):
try:
@@ -239,8 +528,7 @@ def take_screenshot_for_account(
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
- worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
- use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
+ worker_id, use_count = _get_worker_display_info(browser_instance)
log_to_client(
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id,
@@ -250,99 +538,39 @@ def take_screenshot_for_account(
def custom_log(message: str):
log_to_client(message, user_id, account_id)
- # 智能登录状态检查:只在必要时才刷新登录
- should_refresh_login = not is_cookie_jar_fresh(cookie_path)
- if should_refresh_login and attempt > 1:
- # 重试时刷新登录(attempt > 1 表示第2次及以后的尝试)
- log_to_client("正在刷新登录态...", user_id, account_id)
- if not _ensure_login_cookies(account, proxy_config, custom_log):
- log_to_client("截图登录失败", user_id, account_id)
- if attempt < max_retries:
- log_to_client("将重试...", user_id, account_id)
- time.sleep(2)
- continue
- log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
- return {"success": False, "error": "登录失败"}
- elif should_refresh_login:
- # 首次尝试时快速检查登录状态
- log_to_client("正在刷新登录态...", user_id, account_id)
- if not _ensure_login_cookies(account, proxy_config, custom_log):
- log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
- return {"success": False, "error": "登录失败"}
+ login_state = _ensure_screenshot_login_state(
+ account=account,
+ proxy_config=proxy_config,
+ cookie_path=cookie_path,
+ attempt=attempt,
+ max_retries=max_retries,
+ user_id=user_id,
+ account_id=account_id,
+ custom_log=custom_log,
+ )
+ if login_state == "retry":
+ continue
+ if login_state == "fail":
+ return {"success": False, "error": "登录失败"}
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
- from urllib.parse import urlsplit
-
- parsed = urlsplit(config.ZSGL_LOGIN_URL)
- base = f"{parsed.scheme}://{parsed.netloc}"
- if "注册前" in str(browse_type):
- bz = 0
- else:
- bz = 0 # 应读(网站更新后 bz=0 为应读)
- target_url = f"{base}/admin/center.aspx?bz={bz}"
- index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
- run_script = (
- "(function(){"
- "function done(){window.status='ready';}"
- "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
- "function expandMenu(){"
- "try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
- "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
- "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
- "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
- "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
- "}"
- "function navReady(){"
- "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
- "}"
- "function frameReady(){"
- "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
- "}"
- "function check(){"
- "if(navReady() && frameReady()){done();return;}"
- "setTimeout(check,300);"
- "}"
- "var f=document.getElementById('mainframe');"
- "ensureNav();"
- "expandMenu();"
- "if(!f){done();return;}"
- f"f.src='{target_url}';"
- "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
- "setTimeout(check,5000);"
- "})();"
- )
-
- timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
-
- user_info = database.get_user_by_id(user_id)
- username_prefix = user_info["username"] if user_info else f"user{user_id}"
- login_account = account.remark if account.remark else account.username
- screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
- screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
-
- cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
- if take_screenshot_wkhtmltoimage(
- index_url,
- screenshot_path,
- cookies_path=cookies_for_shot,
+ screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type)
+ shot_state = _take_screenshot_once(
+ index_url=index_url,
+ target_url=target_url,
+ screenshot_path=screenshot_path,
+ cookie_path=cookie_path,
proxy_server=proxy_server,
run_script=run_script,
- window_status="ready",
log_callback=custom_log,
- ) or take_screenshot_wkhtmltoimage(
- target_url,
- screenshot_path,
- cookies_path=cookies_for_shot,
- proxy_server=proxy_server,
- log_callback=custom_log,
- ):
- if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
- log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
- return {"success": True, "filename": screenshot_filename}
+ )
+ if shot_state == "success":
+ log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
+ return {"success": True, "filename": screenshot_filename}
+
+ if shot_state == "invalid":
log_to_client("截图文件异常,将重试", user_id, account_id)
- if os.path.exists(screenshot_path):
- os.remove(screenshot_path)
else:
log_to_client("截图保存失败", user_id, account_id)
@@ -361,12 +589,7 @@ def take_screenshot_for_account(
def screenshot_callback(result, error):
"""截图完成回调"""
try:
- account.is_running = False
- account.status = "未开始"
-
- safe_remove_task_status(account_id)
-
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ _finalize_screenshot_callback_state(user_id, account_id, account)
if error:
log_to_client(f"❌ 截图失败: {error}", user_id, account_id)
@@ -375,84 +598,27 @@ def take_screenshot_for_account(
log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id)
if task_start_time and browse_result:
- import time as time_module
-
- total_elapsed = int(time_module.time() - task_start_time)
- database.create_task_log(
+ _persist_browse_log_after_screenshot(
user_id=user_id,
account_id=account_id,
- username=account.username,
+ account=account,
browse_type=browse_type,
- status="success",
- total_items=browse_result.get("total_items", 0),
- total_attachments=browse_result.get("total_attachments", 0),
- duration=total_elapsed,
source=source,
+ task_start_time=task_start_time,
+ browse_result=browse_result,
)
try:
- batch_id = _get_batch_id_from_source(source)
-
- screenshot_path = None
- if result and result.get("success") and result.get("filename"):
- screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"])
-
- account_name = account.remark if account.remark else account.username
-
- try:
- if screenshot_path and result and result.get("success"):
- cfg = database.get_system_config() or {}
- if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
- doc_url = (cfg.get("kdocs_doc_url") or "").strip()
- if doc_url:
- user_cfg = database.get_user_kdocs_settings(user_id) or {}
- if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
- unit = (
- user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or ""
- ).strip()
- name = (account.remark or "").strip()
- if unit and name:
- from services.kdocs_uploader import get_kdocs_uploader
-
- ok = get_kdocs_uploader().enqueue_upload(
- user_id=user_id,
- account_id=account_id,
- unit=unit,
- name=name,
- image_path=screenshot_path,
- )
- if not ok:
- log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
- else:
- if not unit:
- log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
- if not name:
- log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
- except Exception as kdocs_error:
- logger.warning(f"表格上传任务提交失败: {kdocs_error}")
-
- if batch_id:
- _batch_task_record_result(
- batch_id=batch_id,
- account_name=account_name,
- screenshot_path=screenshot_path,
- total_items=browse_result.get("total_items", 0),
- total_attachments=browse_result.get("total_attachments", 0),
- )
- elif source and source.startswith("user_scheduled"):
- user_info = database.get_user_by_id(user_id)
- if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
- email_service.send_task_complete_email_async(
- user_id=user_id,
- email=user_info["email"],
- username=user_info["username"],
- account_name=account_name,
- browse_type=browse_type,
- total_items=browse_result.get("total_items", 0),
- total_attachments=browse_result.get("total_attachments", 0),
- screenshot_path=screenshot_path,
- log_callback=lambda msg: log_to_client(msg, user_id, account_id),
- )
+ _dispatch_screenshot_result(
+ user_id=user_id,
+ account_id=account_id,
+ source=source,
+ browse_type=browse_type,
+ browse_result=browse_result,
+ result=result,
+ account=account,
+ user_info=user_info,
+ )
except Exception as email_error:
logger.warning(f"发送任务完成邮件失败: {email_error}")
except Exception as e:
diff --git a/services/state.py b/services/state.py
index 3725026..c3968c2 100644
--- a/services/state.py
+++ b/services/state.py
@@ -13,7 +13,7 @@ from __future__ import annotations
import threading
import time
import random
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from app_config import get_config
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
_log_cache_total_count = 0
+def _pop_oldest_log_for_user(uid: int) -> bool:
+ global _log_cache_total_count
+
+ logs = _log_cache.get(uid)
+ if not logs:
+ _log_cache.pop(uid, None)
+ return False
+
+ logs.pop(0)
+ _log_cache_total_count = max(0, _log_cache_total_count - 1)
+ if not logs:
+ _log_cache.pop(uid, None)
+ return True
+
+
+def _pop_oldest_log_from_largest_user() -> bool:
+ largest_uid = None
+ largest_size = 0
+
+ for uid, logs in _log_cache.items():
+ size = len(logs)
+ if size > largest_size:
+ largest_uid = uid
+ largest_size = size
+
+ if largest_uid is None or largest_size <= 0:
+ return False
+ return _pop_oldest_log_for_user(int(largest_uid))
+
+
def safe_add_log(
user_id: int,
log_entry: Dict[str, Any],
@@ -175,24 +205,17 @@ def safe_add_log(
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
with _log_cache_lock:
- if uid not in _log_cache:
- _log_cache[uid] = []
+ logs = _log_cache.setdefault(uid, [])
- if len(_log_cache[uid]) >= max_logs_per_user:
- _log_cache[uid].pop(0)
- _log_cache_total_count = max(0, _log_cache_total_count - 1)
+ if len(logs) >= max_logs_per_user:
+ _pop_oldest_log_for_user(uid)
+ logs = _log_cache.setdefault(uid, [])
- _log_cache[uid].append(dict(log_entry or {}))
+ logs.append(dict(log_entry or {}))
_log_cache_total_count += 1
while _log_cache_total_count > max_total_logs:
- if not _log_cache:
- break
- max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
- if _log_cache.get(max_user):
- _log_cache[max_user].pop(0)
- _log_cache_total_count -= 1
- else:
+ if not _pop_oldest_log_from_largest_user():
break
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
+def _format_wait_hint(remaining_seconds: int) -> str:
+ remaining = max(1, int(remaining_seconds or 0))
+ if remaining >= 60:
+ return f"{remaining // 60 + 1}分钟"
+ return f"{remaining}秒"
+
+
+def _check_and_increment_rate_bucket(
+ *,
+ buckets: Dict[str, Dict[str, Any]],
+ key: str,
+ now_ts: float,
+ max_requests: int,
+ window_seconds: int,
+) -> Tuple[bool, Optional[str]]:
+ if not key or int(max_requests) <= 0:
+ return True, None
+
+ data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
+ if int(data.get("count", 0) or 0) >= int(max_requests):
+ remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
+ return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
+
+ data["count"] = int(data.get("count", 0) or 0) + 1
+ buckets[key] = data
+ return True, None
+
+
def check_ip_request_rate(
ip_address: str,
action: str,
@@ -392,21 +443,13 @@ def check_ip_request_rate(
key = f"{action}:{ip_address}"
with _ip_request_rate_lock:
- data = _ip_request_rate.get(key)
- if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
- data = {"window_start": now_ts, "count": 0}
- _ip_request_rate[key] = data
-
- if int(data.get("count", 0) or 0) >= max_requests:
- remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
- if remaining >= 60:
- wait_hint = f"{remaining // 60 + 1}分钟"
- else:
- wait_hint = f"{remaining}秒"
- return False, f"请求过于频繁,请{wait_hint}后再试"
-
- data["count"] = int(data.get("count", 0) or 0) + 1
- return True, None
+ return _check_and_increment_rate_bucket(
+ buckets=_ip_request_rate,
+ key=key,
+ now_ts=now_ts,
+ max_requests=max_requests,
+ window_seconds=window_seconds,
+ )
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
data = _ip_request_rate.get(key) or {}
action = key.split(":", 1)[0]
_, window_seconds = _get_action_rate_limit(action)
- window_start = float(data.get("window_start", 0) or 0)
- if now_ts - window_start >= window_seconds:
+ if _is_bucket_expired(data, now_ts, window_seconds):
_ip_request_rate.pop(key, None)
removed += 1
return removed
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
return data
+def _is_bucket_expired(
+ data: Optional[Dict[str, Any]],
+ now_ts: float,
+ window_seconds: int,
+ *,
+ time_field: str = "window_start",
+) -> bool:
+ start_ts = float((data or {}).get(time_field, 0) or 0)
+ return (now_ts - start_ts) >= max(1, int(window_seconds))
+
+
+def _cleanup_map_entries(
+ store: Dict[Any, Dict[str, Any]],
+ should_remove: Callable[[Dict[str, Any]], bool],
+) -> int:
+ removed = 0
+ for key, value in list(store.items()):
+ item = value if isinstance(value, dict) else {}
+ if should_remove(item):
+ store.pop(key, None)
+ removed += 1
+ return removed
+
+
def record_login_username_attempt(ip_address: str, username: str) -> bool:
now_ts = time.time()
threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
user_key = _normalize_login_key("user", "", username)
ip_user_key = _normalize_login_key("ipuser", ip_address, username)
- def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
- if not key or max_requests <= 0:
- return True, None
- data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
- if int(data.get("count", 0) or 0) >= max_requests:
- remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
- wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
- return False, f"请求过于频繁,请{wait_hint}后再试"
- data["count"] = int(data.get("count", 0) or 0) + 1
- _login_rate_limits[key] = data
- return True, None
-
with _login_rate_limits_lock:
- allowed, msg = _check(ip_key, ip_max)
+ allowed, msg = _check_and_increment_rate_bucket(
+ buckets=_login_rate_limits,
+ key=ip_key,
+ now_ts=now_ts,
+ max_requests=ip_max,
+ window_seconds=window_seconds,
+ )
if not allowed:
return False, msg
- allowed, msg = _check(ip_user_key, ip_user_max)
+ allowed, msg = _check_and_increment_rate_bucket(
+ buckets=_login_rate_limits,
+ key=ip_user_key,
+ now_ts=now_ts,
+ max_requests=ip_user_max,
+ window_seconds=window_seconds,
+ )
if not allowed:
return False, msg
- allowed, msg = _check(user_key, user_max)
+ allowed, msg = _check_and_increment_rate_bucket(
+ buckets=_login_rate_limits,
+ key=user_key,
+ now_ts=now_ts,
+ max_requests=user_max,
+ window_seconds=window_seconds,
+ )
if not allowed:
return False, msg
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
ip_key = _normalize_login_key("ip", ip_address)
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
+ def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
+ if not data:
+ return False
+ if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
+ return False
+ return int(data.get("count", 0) or 0) >= max_failures
+
with _login_failures_lock:
- ip_data = _login_failures.get(ip_key)
- if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
- if int(ip_data.get("count", 0) or 0) >= max_failures:
- return True
- ip_user_data = _login_failures.get(ip_user_key)
- if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
- if int(ip_user_data.get("count", 0) or 0) >= max_failures:
- return True
+ if _is_over_threshold(_login_failures.get(ip_key)):
+ return True
+ if _is_over_threshold(_login_failures.get(ip_user_key)):
+ return True
if is_login_scan_locked(ip_address):
return True
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
return False
+def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
+ now_ts = float(now_ts if now_ts is not None else time.time())
+ _, captcha_window = _get_login_captcha_config()
+ _, _, _, rate_window = _get_login_rate_limit_config()
+ _, lock_window, _ = _get_login_lock_config()
+ _, scan_window, _ = _get_login_scan_config()
+ alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
+
+ with _login_failures_lock:
+ failures_removed = _cleanup_map_entries(
+ _login_failures,
+ lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window),
+ )
+
+ with _login_rate_limits_lock:
+ rate_removed = _cleanup_map_entries(
+ _login_rate_limits,
+ lambda data: _is_bucket_expired(data, now_ts, rate_window),
+ )
+
+ with _login_scan_lock:
+ scan_removed = _cleanup_map_entries(
+ _login_scan_state,
+ lambda data: (
+ (now_ts - float(data.get("first_seen", 0) or 0)) > scan_window
+ and now_ts >= float(data.get("scan_until", 0) or 0)
+ ),
+ )
+
+ with _login_ip_user_lock:
+ ip_user_locks_removed = _cleanup_map_entries(
+ _login_ip_user_locks,
+ lambda data: now_ts >= float(data.get("lock_until", 0) or 0),
+ )
+
+ with _login_alert_lock:
+ alerts_removed = _cleanup_map_entries(
+ _login_alert_state,
+ lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
+ )
+
+ return {
+ "failures": failures_removed,
+ "rate_limits": rate_removed,
+ "scan_states": scan_removed,
+ "ip_user_locks": ip_user_locks_removed,
+ "alerts": alerts_removed,
+ }
+
+
# ==================== 邮箱维度限流 ====================
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
key = f"{action}:{email_key}"
with _email_rate_limit_lock:
- data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
- if int(data.get("count", 0) or 0) >= max_requests:
- remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
- wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒"
- return False, f"请求过于频繁,请{wait_hint}后再试"
- data["count"] = int(data.get("count", 0) or 0) + 1
- _email_rate_limit[key] = data
- return True, None
+ return _check_and_increment_rate_bucket(
+ buckets=_email_rate_limit,
+ key=key,
+ now_ts=now_ts,
+ max_requests=max_requests,
+ window_seconds=window_seconds,
+ )
# ==================== Batch screenshots(批次任务截图收集) ====================
diff --git a/services/task_scheduler.py b/services/task_scheduler.py
new file mode 100644
index 0000000..cee84bd
--- /dev/null
+++ b/services/task_scheduler.py
@@ -0,0 +1,365 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import heapq
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, wait
+from dataclasses import dataclass
+
+import database
+from app_logger import get_logger
+from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
+from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
+
+logger = get_logger("app")
+
+# VIP优先级队列(仅用于可视化/调试)
+vip_task_queue = [] # VIP用户任务队列
+normal_task_queue = [] # 普通用户任务队列
+task_queue_lock = threading.Lock()
+
+@dataclass
+class _TaskRequest:
+ user_id: int
+ account_id: str
+ browse_type: str
+ enable_screenshot: bool
+ source: str
+ retry_count: int
+ submitted_at: float
+ is_vip: bool
+ seq: int
+ canceled: bool = False
+ done_callback: object = None
+
+
+class TaskScheduler:
+ """全局任务调度器:队列排队,不为每个任务单独创建线程。"""
+
+ def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
+ self.max_global = max(1, int(max_global))
+ self.max_per_user = max(1, int(max_per_user))
+ self.max_queue_size = max(1, int(max_queue_size))
+
+ self._cond = threading.Condition()
+ self._pending = [] # heap: (priority, submitted_at, seq, task)
+ self._pending_by_account = {} # {account_id: task}
+ self._seq = 0
+ self._known_account_ids = set()
+
+ self._running_global = 0
+ self._running_by_user = {} # {user_id: running_count}
+
+ self._executor_max_workers = self.max_global
+ self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
+
+ self._futures_lock = threading.Lock()
+ self._active_futures = set()
+
+ self._running = True
+ self._run_task_fn = run_task_fn
+ self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
+ self._dispatcher_thread.start()
+
+ def _track_future(self, future) -> None:
+ with self._futures_lock:
+ self._active_futures.add(future)
+ try:
+ future.add_done_callback(self._untrack_future)
+ except Exception:
+ pass
+
+ def _untrack_future(self, future) -> None:
+ with self._futures_lock:
+ self._active_futures.discard(future)
+
+ def shutdown(self, timeout: float = 5.0):
+ """停止调度器(用于进程退出清理)"""
+ with self._cond:
+ self._running = False
+ self._cond.notify_all()
+
+ try:
+ self._dispatcher_thread.join(timeout=timeout)
+ except Exception:
+ pass
+
+ # 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
+ try:
+ deadline = time.time() + max(0.0, float(timeout or 0))
+ while True:
+ with self._futures_lock:
+ pending = [f for f in self._active_futures if not f.done()]
+ if not pending:
+ break
+ remaining = deadline - time.time()
+ if remaining <= 0:
+ break
+ wait(pending, timeout=remaining)
+ except Exception:
+ pass
+
+ try:
+ self._executor.shutdown(wait=False)
+ except Exception:
+ pass
+
+ # 最后兜底:清理本调度器提交过的 active_task,避免测试/重启时被“任务已在运行中”误拦截
+ try:
+ with self._cond:
+ known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
+ self._pending.clear()
+ self._pending_by_account.clear()
+ self._cond.notify_all()
+ for account_id in known_ids:
+ safe_remove_task(account_id)
+ except Exception:
+ pass
+
+ def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
+ """动态更新并发/队列上限(不影响已在运行的任务)"""
+ with self._cond:
+ if max_per_user is not None:
+ self.max_per_user = max(1, int(max_per_user))
+ if max_queue_size is not None:
+ self.max_queue_size = max(1, int(max_queue_size))
+
+ if max_global is not None:
+ new_max_global = max(1, int(max_global))
+ self.max_global = new_max_global
+ if new_max_global > self._executor_max_workers:
+ # 立即关闭旧线程池,防止资源泄漏
+ old_executor = self._executor
+ self._executor_max_workers = new_max_global
+ self._executor = ThreadPoolExecutor(
+ max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
+ )
+ # 立即关闭旧线程池
+ try:
+ old_executor.shutdown(wait=False)
+ logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
+ except Exception as e:
+ logger.warning(f"关闭旧线程池失败: {e}")
+
+ self._cond.notify_all()
+
+ def get_queue_state_snapshot(self) -> dict:
+ """获取调度器队列/运行状态快照(用于前端展示/监控)。"""
+ with self._cond:
+ pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
+ pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
+
+ positions = {}
+ for idx, t in enumerate(pending_tasks):
+ positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
+
+ return {
+ "pending_total": len(pending_tasks),
+ "running_total": int(self._running_global),
+ "running_by_user": dict(self._running_by_user),
+ "positions": positions,
+ }
+
+ def submit_task(
+ self,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ enable_screenshot: bool = True,
+ source: str = "manual",
+ retry_count: int = 0,
+ is_vip: bool = None,
+ done_callback=None,
+ ):
+ """提交任务进入队列(返回: (ok, message))"""
+ if not user_id or not account_id:
+ return False, "参数错误"
+
+ submitted_at = time.time()
+ if is_vip is None:
+ try:
+ is_vip = bool(database.is_user_vip(user_id))
+ except Exception:
+ is_vip = False
+ else:
+ is_vip = bool(is_vip)
+
+ with self._cond:
+ if not self._running:
+ return False, "调度器未运行"
+ if len(self._pending_by_account) >= self.max_queue_size:
+ return False, "任务队列已满,请稍后再试"
+ if account_id in self._pending_by_account:
+ return False, "任务已在队列中"
+ if safe_get_task(account_id) is not None:
+ return False, "任务已在运行中"
+
+ self._seq += 1
+ task = _TaskRequest(
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ enable_screenshot=bool(enable_screenshot),
+ source=source,
+ retry_count=int(retry_count or 0),
+ submitted_at=submitted_at,
+ is_vip=is_vip,
+ seq=self._seq,
+ done_callback=done_callback,
+ )
+ self._pending_by_account[account_id] = task
+ self._known_account_ids.add(account_id)
+ priority = 0 if is_vip else 1
+ heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
+ self._cond.notify_all()
+
+ # 用于可视化/调试:记录队列
+ with task_queue_lock:
+ if is_vip:
+ vip_task_queue.append(account_id)
+ else:
+ normal_task_queue.append(account_id)
+
+ return True, "已加入队列"
+
+ def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
+ """取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
+ canceled_task = None
+ with self._cond:
+ task = self._pending_by_account.pop(account_id, None)
+ if not task:
+ return False
+ task.canceled = True
+ canceled_task = task
+ self._cond.notify_all()
+
+ # 从可视化队列移除
+ with task_queue_lock:
+ if account_id in vip_task_queue:
+ vip_task_queue.remove(account_id)
+ if account_id in normal_task_queue:
+ normal_task_queue.remove(account_id)
+
+ # 批次任务:取消也要推进完成计数,避免批次缓存常驻
+ try:
+ batch_id = _get_batch_id_from_source(canceled_task.source)
+ if batch_id:
+ acc = safe_get_account(user_id, account_id)
+ if acc:
+ account_name = acc.remark if acc.remark else acc.username
+ else:
+ account_name = account_id
+ _batch_task_record_result(
+ batch_id=batch_id,
+ account_name=account_name,
+ screenshot_path=None,
+ total_items=0,
+ total_attachments=0,
+ )
+ except Exception:
+ pass
+
+ return True
+
+ def _dispatch_loop(self):
+ while True:
+ task = None
+ with self._cond:
+ if not self._running:
+ return
+
+ if not self._pending or self._running_global >= self.max_global:
+ self._cond.wait(timeout=0.5)
+ continue
+
+ task = self._pop_next_runnable_locked()
+ if task is None:
+ self._cond.wait(timeout=0.5)
+ continue
+
+ self._running_global += 1
+ self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
+
+ # 从队列移除(可视化)
+ with task_queue_lock:
+ if task.account_id in vip_task_queue:
+ vip_task_queue.remove(task.account_id)
+ if task.account_id in normal_task_queue:
+ normal_task_queue.remove(task.account_id)
+
+ try:
+ future = self._executor.submit(self._run_task_wrapper, task)
+ self._track_future(future)
+ safe_set_task(task.account_id, future)
+ except Exception:
+ with self._cond:
+ self._running_global = max(0, self._running_global - 1)
+ # 使用默认值 0 与增加时保持一致
+ self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
+ if self._running_by_user.get(task.user_id) == 0:
+ self._running_by_user.pop(task.user_id, None)
+ self._cond.notify_all()
+
+ def _pop_next_runnable_locked(self):
+ """在锁内从优先队列取出“可运行”的任务,避免VIP任务占位阻塞普通任务。"""
+ if not self._pending:
+ return None
+
+ skipped = []
+ selected = None
+
+ while self._pending:
+ _, _, _, task = heapq.heappop(self._pending)
+
+ if task.canceled:
+ continue
+ if self._pending_by_account.get(task.account_id) is not task:
+ continue
+
+ running_for_user = self._running_by_user.get(task.user_id, 0)
+ if running_for_user >= self.max_per_user:
+ skipped.append(task)
+ continue
+
+ selected = task
+ break
+
+ for t in skipped:
+ priority = 0 if t.is_vip else 1
+ heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
+
+ if selected is None:
+ return None
+
+ self._pending_by_account.pop(selected.account_id, None)
+ return selected
+
+ def _run_task_wrapper(self, task: _TaskRequest):
+ try:
+ if callable(self._run_task_fn):
+ self._run_task_fn(
+ user_id=task.user_id,
+ account_id=task.account_id,
+ browse_type=task.browse_type,
+ enable_screenshot=task.enable_screenshot,
+ source=task.source,
+ retry_count=task.retry_count,
+ )
+ finally:
+ try:
+ if callable(task.done_callback):
+ task.done_callback()
+ except Exception:
+ pass
+ safe_remove_task(task.account_id)
+ with self._cond:
+ self._running_global = max(0, self._running_global - 1)
+ # 使用默认值 0 与增加时保持一致
+ self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
+ if self._running_by_user.get(task.user_id) == 0:
+ self._running_by_user.pop(task.user_id, None)
+ self._cond.notify_all()
+
+
diff --git a/services/tasks.py b/services/tasks.py
index bca2d0a..87255dc 100644
--- a/services/tasks.py
+++ b/services/tasks.py
@@ -2,12 +2,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
-import heapq
import os
import threading
import time
-from concurrent.futures import ThreadPoolExecutor, wait
-from dataclasses import dataclass
import database
import email_service
@@ -21,27 +18,22 @@ from services.runtime import get_socketio
from services.screenshots import take_screenshot_for_account
from services.state import (
safe_get_account,
- safe_get_task,
safe_remove_task,
safe_remove_task_status,
- safe_set_task,
safe_set_task_status,
safe_update_task_status,
)
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
+from services.task_scheduler import TaskScheduler
from task_checkpoint import TaskStage
logger = get_logger("app")
config = get_config()
-# VIP优先级队列(仅用于可视化/调试)
-vip_task_queue = [] # VIP用户任务队列
-normal_task_queue = [] # 普通用户任务队列
-task_queue_lock = threading.Lock()
-
# 并发默认值(启动后会由系统配置覆盖并调用 update_limits)
max_concurrent_per_account = config.MAX_CONCURRENT_PER_ACCOUNT
max_concurrent_global = config.MAX_CONCURRENT_GLOBAL
+_SOURCE_UNSET = object()
def _emit(event: str, data: object, *, room: str | None = None) -> None:
@@ -52,347 +44,6 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
-@dataclass
-class _TaskRequest:
- user_id: int
- account_id: str
- browse_type: str
- enable_screenshot: bool
- source: str
- retry_count: int
- submitted_at: float
- is_vip: bool
- seq: int
- canceled: bool = False
- done_callback: object = None
-
-
-class TaskScheduler:
- """全局任务调度器:队列排队,不为每个任务单独创建线程。"""
-
- def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000):
- self.max_global = max(1, int(max_global))
- self.max_per_user = max(1, int(max_per_user))
- self.max_queue_size = max(1, int(max_queue_size))
-
- self._cond = threading.Condition()
- self._pending = [] # heap: (priority, submitted_at, seq, task)
- self._pending_by_account = {} # {account_id: task}
- self._seq = 0
- self._known_account_ids = set()
-
- self._running_global = 0
- self._running_by_user = {} # {user_id: running_count}
-
- self._executor_max_workers = self.max_global
- self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
-
- self._futures_lock = threading.Lock()
- self._active_futures = set()
-
- self._running = True
- self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
- self._dispatcher_thread.start()
-
- def _track_future(self, future) -> None:
- with self._futures_lock:
- self._active_futures.add(future)
- try:
- future.add_done_callback(self._untrack_future)
- except Exception:
- pass
-
- def _untrack_future(self, future) -> None:
- with self._futures_lock:
- self._active_futures.discard(future)
-
- def shutdown(self, timeout: float = 5.0):
- """停止调度器(用于进程退出清理)"""
- with self._cond:
- self._running = False
- self._cond.notify_all()
-
- try:
- self._dispatcher_thread.join(timeout=timeout)
- except Exception:
- pass
-
- # 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
- try:
- deadline = time.time() + max(0.0, float(timeout or 0))
- while True:
- with self._futures_lock:
- pending = [f for f in self._active_futures if not f.done()]
- if not pending:
- break
- remaining = deadline - time.time()
- if remaining <= 0:
- break
- wait(pending, timeout=remaining)
- except Exception:
- pass
-
- try:
- self._executor.shutdown(wait=False)
- except Exception:
- pass
-
- # 最后兜底:清理本调度器提交过的 active_task,避免测试/重启时被“任务已在运行中”误拦截
- try:
- with self._cond:
- known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
- self._pending.clear()
- self._pending_by_account.clear()
- self._cond.notify_all()
- for account_id in known_ids:
- safe_remove_task(account_id)
- except Exception:
- pass
-
- def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
- """动态更新并发/队列上限(不影响已在运行的任务)"""
- with self._cond:
- if max_per_user is not None:
- self.max_per_user = max(1, int(max_per_user))
- if max_queue_size is not None:
- self.max_queue_size = max(1, int(max_queue_size))
-
- if max_global is not None:
- new_max_global = max(1, int(max_global))
- self.max_global = new_max_global
- if new_max_global > self._executor_max_workers:
- # 立即关闭旧线程池,防止资源泄漏
- old_executor = self._executor
- self._executor_max_workers = new_max_global
- self._executor = ThreadPoolExecutor(
- max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
- )
- # 立即关闭旧线程池
- try:
- old_executor.shutdown(wait=False)
- logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
- except Exception as e:
- logger.warning(f"关闭旧线程池失败: {e}")
-
- self._cond.notify_all()
-
- def get_queue_state_snapshot(self) -> dict:
- """获取调度器队列/运行状态快照(用于前端展示/监控)。"""
- with self._cond:
- pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
- pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
-
- positions = {}
- for idx, t in enumerate(pending_tasks):
- positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
-
- return {
- "pending_total": len(pending_tasks),
- "running_total": int(self._running_global),
- "running_by_user": dict(self._running_by_user),
- "positions": positions,
- }
-
- def submit_task(
- self,
- user_id: int,
- account_id: str,
- browse_type: str,
- enable_screenshot: bool = True,
- source: str = "manual",
- retry_count: int = 0,
- is_vip: bool = None,
- done_callback=None,
- ):
- """提交任务进入队列(返回: (ok, message))"""
- if not user_id or not account_id:
- return False, "参数错误"
-
- submitted_at = time.time()
- if is_vip is None:
- try:
- is_vip = bool(database.is_user_vip(user_id))
- except Exception:
- is_vip = False
- else:
- is_vip = bool(is_vip)
-
- with self._cond:
- if not self._running:
- return False, "调度器未运行"
- if len(self._pending_by_account) >= self.max_queue_size:
- return False, "任务队列已满,请稍后再试"
- if account_id in self._pending_by_account:
- return False, "任务已在队列中"
- if safe_get_task(account_id) is not None:
- return False, "任务已在运行中"
-
- self._seq += 1
- task = _TaskRequest(
- user_id=user_id,
- account_id=account_id,
- browse_type=browse_type,
- enable_screenshot=bool(enable_screenshot),
- source=source,
- retry_count=int(retry_count or 0),
- submitted_at=submitted_at,
- is_vip=is_vip,
- seq=self._seq,
- done_callback=done_callback,
- )
- self._pending_by_account[account_id] = task
- self._known_account_ids.add(account_id)
- priority = 0 if is_vip else 1
- heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
- self._cond.notify_all()
-
- # 用于可视化/调试:记录队列
- with task_queue_lock:
- if is_vip:
- vip_task_queue.append(account_id)
- else:
- normal_task_queue.append(account_id)
-
- return True, "已加入队列"
-
- def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
- """取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
- canceled_task = None
- with self._cond:
- task = self._pending_by_account.pop(account_id, None)
- if not task:
- return False
- task.canceled = True
- canceled_task = task
- self._cond.notify_all()
-
- # 从可视化队列移除
- with task_queue_lock:
- if account_id in vip_task_queue:
- vip_task_queue.remove(account_id)
- if account_id in normal_task_queue:
- normal_task_queue.remove(account_id)
-
- # 批次任务:取消也要推进完成计数,避免批次缓存常驻
- try:
- batch_id = _get_batch_id_from_source(canceled_task.source)
- if batch_id:
- acc = safe_get_account(user_id, account_id)
- if acc:
- account_name = acc.remark if acc.remark else acc.username
- else:
- account_name = account_id
- _batch_task_record_result(
- batch_id=batch_id,
- account_name=account_name,
- screenshot_path=None,
- total_items=0,
- total_attachments=0,
- )
- except Exception:
- pass
-
- return True
-
- def _dispatch_loop(self):
- while True:
- task = None
- with self._cond:
- if not self._running:
- return
-
- if not self._pending or self._running_global >= self.max_global:
- self._cond.wait(timeout=0.5)
- continue
-
- task = self._pop_next_runnable_locked()
- if task is None:
- self._cond.wait(timeout=0.5)
- continue
-
- self._running_global += 1
- self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
-
- # 从队列移除(可视化)
- with task_queue_lock:
- if task.account_id in vip_task_queue:
- vip_task_queue.remove(task.account_id)
- if task.account_id in normal_task_queue:
- normal_task_queue.remove(task.account_id)
-
- try:
- future = self._executor.submit(self._run_task_wrapper, task)
- self._track_future(future)
- safe_set_task(task.account_id, future)
- except Exception:
- with self._cond:
- self._running_global = max(0, self._running_global - 1)
- # 使用默认值 0 与增加时保持一致
- self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
- if self._running_by_user.get(task.user_id) == 0:
- self._running_by_user.pop(task.user_id, None)
- self._cond.notify_all()
-
- def _pop_next_runnable_locked(self):
- """在锁内从优先队列取出“可运行”的任务,避免VIP任务占位阻塞普通任务。"""
- if not self._pending:
- return None
-
- skipped = []
- selected = None
-
- while self._pending:
- _, _, _, task = heapq.heappop(self._pending)
-
- if task.canceled:
- continue
- if self._pending_by_account.get(task.account_id) is not task:
- continue
-
- running_for_user = self._running_by_user.get(task.user_id, 0)
- if running_for_user >= self.max_per_user:
- skipped.append(task)
- continue
-
- selected = task
- break
-
- for t in skipped:
- priority = 0 if t.is_vip else 1
- heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
-
- if selected is None:
- return None
-
- self._pending_by_account.pop(selected.account_id, None)
- return selected
-
- def _run_task_wrapper(self, task: _TaskRequest):
- try:
- run_task(
- user_id=task.user_id,
- account_id=task.account_id,
- browse_type=task.browse_type,
- enable_screenshot=task.enable_screenshot,
- source=task.source,
- retry_count=task.retry_count,
- )
- finally:
- try:
- if callable(task.done_callback):
- task.done_callback()
- except Exception:
- pass
- safe_remove_task(task.account_id)
- with self._cond:
- self._running_global = max(0, self._running_global - 1)
- # 使用默认值 0 与增加时保持一致
- self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
- if self._running_by_user.get(task.user_id) == 0:
- self._running_by_user.pop(task.user_id, None)
- self._cond.notify_all()
-
-
_task_scheduler = None
_task_scheduler_lock = threading.Lock()
@@ -410,6 +61,7 @@ def get_task_scheduler() -> TaskScheduler:
max_global=max_concurrent_global,
max_per_user=max_concurrent_per_account,
max_queue_size=max_queue_size,
+ run_task_fn=run_task,
)
return _task_scheduler
@@ -477,6 +129,584 @@ def submit_account_task(
return True, message
+def _account_display_name(account) -> str:
+ return account.remark if account.remark else account.username
+
+
+def _record_batch_result(batch_id, account, total_items: int, total_attachments: int) -> None:
+ if not batch_id:
+ return
+ _batch_task_record_result(
+ batch_id=batch_id,
+ account_name=_account_display_name(account),
+ screenshot_path=None,
+ total_items=total_items,
+ total_attachments=total_attachments,
+ )
+
+
+def _close_account_automation(account, *, on_error=None) -> bool:
+ automation = getattr(account, "automation", None)
+ if not automation:
+ return False
+
+ closed = False
+ try:
+ automation.close()
+ closed = True
+ except Exception as e:
+ if on_error:
+ try:
+ on_error(e)
+ except Exception:
+ pass
+ finally:
+ account.automation = None
+
+ return closed
+
+
+def _create_task_log(
+ *,
+ user_id: int,
+ account_id: str,
+ account,
+ browse_type: str,
+ status: str,
+ total_items: int,
+ total_attachments: int,
+ error_message: str,
+ duration: int,
+ source=_SOURCE_UNSET,
+) -> None:
+ payload = {
+ "user_id": user_id,
+ "account_id": account_id,
+ "username": account.username,
+ "browse_type": browse_type,
+ "status": status,
+ "total_items": total_items,
+ "total_attachments": total_attachments,
+ "error_message": error_message,
+ "duration": duration,
+ }
+ if source is not _SOURCE_UNSET:
+ payload["source"] = source
+ database.create_task_log(**payload)
+
+
+def _handle_stop_requested(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ batch_id,
+ remove_task_status: bool = False,
+ record_batch: bool = False,
+) -> bool:
+ if not account.should_stop:
+ return False
+
+ log_to_client("任务已取消", user_id, account_id)
+ account.status = "已停止"
+ account.is_running = False
+ if remove_task_status:
+ safe_remove_task_status(account_id)
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+ if record_batch and batch_id:
+ _record_batch_result(batch_id=batch_id, account=account, total_items=0, total_attachments=0)
+ return True
+
+
+def _resolve_proxy_config(user_id: int, account_id: str, account):
+ proxy_config = None
+ system_config = database.get_system_config()
+ if system_config.get("proxy_enabled") != 1:
+ return proxy_config
+
+ proxy_api_url = system_config.get("proxy_api_url", "").strip()
+ if not proxy_api_url:
+ log_to_client("⚠ 代理已启用但未配置API地址", user_id, account_id)
+ return proxy_config
+
+ log_to_client("正在获取代理IP...", user_id, account_id)
+ proxy_server = get_proxy_from_api(proxy_api_url, max_retries=3)
+ if proxy_server:
+ proxy_config = {"server": proxy_server}
+ log_to_client(f"[OK] 将使用代理: {proxy_server}", user_id, account_id)
+ account.proxy_config = proxy_config
+ else:
+ log_to_client("✗ 代理获取失败,将不使用代理继续", user_id, account_id)
+ return proxy_config
+
+
+def _refresh_account_remark(api_browser, account, user_id: int, account_id: str) -> None:
+ if account.remark:
+ return
+ try:
+ real_name = api_browser.get_real_name()
+ if not real_name:
+ return
+ account.remark = real_name
+ database.update_account_remark(account_id, real_name)
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ logger.info(f"[自动备注] 账号 {account.username} 自动设置备注为: {real_name}")
+ except Exception as e:
+ logger.warning(f"[自动备注] 获取姓名失败: {e}")
+
+
+def _handle_login_failed(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ task_start_time: float,
+ task_id: str,
+ checkpoint_mgr,
+ time_module,
+) -> None:
+ error_message = "登录失败"
+ log_to_client(f"❌ {error_message}", user_id, account_id)
+
+ is_suspended = database.increment_account_login_fail(account_id, error_message)
+ if is_suspended:
+ log_to_client("⚠ 该账号连续3次密码错误,已自动暂停", user_id, account_id)
+ log_to_client("请在前台修改密码后才能继续使用", user_id, account_id)
+
+ retry_action = checkpoint_mgr.record_error(task_id, error_message)
+ if retry_action == "paused":
+ logger.warning(f"[断点] 任务 {task_id} 已暂停(登录失败)")
+
+ account.status = "登录失败"
+ account.is_running = False
+ _create_task_log(
+ user_id=user_id,
+ account_id=account_id,
+ account=account,
+ browse_type=browse_type,
+ status="failed",
+ total_items=0,
+ total_attachments=0,
+ error_message=error_message,
+ duration=int(time_module.time() - task_start_time),
+ source=source,
+ )
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+
+def _execute_single_attempt(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ checkpoint_mgr,
+ task_id: str,
+ task_start_time: float,
+ time_module,
+):
+ proxy_config = _resolve_proxy_config(user_id=user_id, account_id=account_id, account=account)
+
+ checkpoint_mgr.update_stage(task_id, TaskStage.STARTING, progress_percent=10)
+
+ def custom_log(message: str):
+ log_to_client(message, user_id, account_id)
+
+ log_to_client("开始登录...", user_id, account_id)
+ safe_update_task_status(account_id, {"detail_status": "正在登录"})
+ checkpoint_mgr.update_stage(task_id, TaskStage.LOGGING_IN, progress_percent=25)
+
+ with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser:
+ if not api_browser.login(account.username, account.password):
+ _handle_login_failed(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ task_start_time=task_start_time,
+ task_id=task_id,
+ checkpoint_mgr=checkpoint_mgr,
+ time_module=time_module,
+ )
+ return "login_failed", None
+
+ log_to_client("[OK] 首次登录成功,刷新登录时间...", user_id, account_id)
+
+ if api_browser.login(account.username, account.password):
+ log_to_client("[OK] 二次登录成功!", user_id, account_id)
+ else:
+ log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id)
+
+ api_browser.save_cookies_for_screenshot(account.username)
+ database.reset_account_login_status(account_id)
+ _refresh_account_remark(api_browser, account, user_id, account_id)
+
+ safe_update_task_status(account_id, {"detail_status": "正在浏览"})
+ log_to_client(f"开始浏览 '{browse_type}' 内容...", user_id, account_id)
+ account.total_items = 0
+ safe_update_task_status(account_id, {"progress": {"items": 0, "attachments": 0}})
+
+ def should_stop():
+ return account.should_stop
+
+ def on_browse_progress(progress: dict):
+ try:
+ total_items = int(progress.get("total_items") or 0)
+ browsed_items = int(progress.get("browsed_items") or 0)
+ if total_items > 0:
+ account.total_items = total_items
+ safe_update_task_status(account_id, {"progress": {"items": browsed_items, "attachments": 0}})
+ except Exception:
+ pass
+
+ checkpoint_mgr.update_stage(task_id, TaskStage.BROWSING, progress_percent=50)
+ result = api_browser.browse_content(
+ browse_type=browse_type,
+ should_stop_callback=should_stop,
+ progress_callback=on_browse_progress,
+ )
+ return "ok", result
+
+
+def _record_success_without_screenshot(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ task_start_time: float,
+ result,
+ batch_id,
+ time_module,
+) -> bool:
+ _create_task_log(
+ user_id=user_id,
+ account_id=account_id,
+ account=account,
+ browse_type=browse_type,
+ status="success",
+ total_items=result.total_items,
+ total_attachments=result.total_attachments,
+ error_message="",
+ duration=int(time_module.time() - task_start_time),
+ source=source,
+ )
+ if batch_id:
+ _record_batch_result(
+ batch_id=batch_id,
+ account=account,
+ total_items=result.total_items,
+ total_attachments=result.total_attachments,
+ )
+ return True
+
+ if source and source.startswith("user_scheduled"):
+ try:
+ user_info = database.get_user_by_id(user_id)
+ if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
+ email_service.send_task_complete_email_async(
+ user_id=user_id,
+ email=user_info["email"],
+ username=user_info["username"],
+ account_name=_account_display_name(account),
+ browse_type=browse_type,
+ total_items=result.total_items,
+ total_attachments=result.total_attachments,
+ screenshot_path=None,
+ log_callback=lambda msg: log_to_client(msg, user_id, account_id),
+ )
+ except Exception as email_error:
+ logger.warning(f"发送任务完成邮件失败: {email_error}")
+
+ return False
+
+
+def _is_timeout_error(error_msg: str) -> bool:
+ return "Timeout" in error_msg or "timeout" in error_msg
+
+
+def _record_failed_task_log(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ total_items: int,
+ total_attachments: int,
+ error_message: str,
+ task_start_time: float,
+ time_module,
+) -> None:
+ account.status = "出错"
+ _create_task_log(
+ user_id=user_id,
+ account_id=account_id,
+ account=account,
+ browse_type=browse_type,
+ status="failed",
+ total_items=total_items,
+ total_attachments=total_attachments,
+ error_message=error_message,
+ duration=int(time_module.time() - task_start_time),
+ source=source,
+ )
+
+
+def _handle_failed_browse_result(
+ *,
+ account,
+ result,
+ error_msg: str,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ attempt: int,
+ max_attempts: int,
+ task_start_time: float,
+ time_module,
+) -> str:
+ if _is_timeout_error(error_msg):
+ log_to_client(f"⚠ 检测到超时错误: {error_msg}", user_id, account_id)
+
+ close_ok = _close_account_automation(
+ account,
+ on_error=lambda e: logger.debug(f"关闭超时浏览器实例失败: {e}"),
+ )
+ if close_ok:
+ log_to_client("已关闭超时的浏览器实例", user_id, account_id)
+
+ if attempt < max_attempts:
+ log_to_client(f"⚠ 代理可能速度过慢,将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id)
+ time_module.sleep(2)
+ return "continue"
+
+ log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id)
+ _record_failed_task_log(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ total_items=result.total_items,
+ total_attachments=result.total_attachments,
+ error_message=f"重试{max_attempts}次后仍失败: {error_msg}",
+ task_start_time=task_start_time,
+ time_module=time_module,
+ )
+ return "break"
+
+ log_to_client(f"浏览出错: {error_msg}", user_id, account_id)
+ _record_failed_task_log(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ total_items=result.total_items,
+ total_attachments=result.total_attachments,
+ error_message=error_msg,
+ task_start_time=task_start_time,
+ time_module=time_module,
+ )
+ return "break"
+
+
+def _handle_attempt_exception(
+ *,
+ account,
+ error_msg: str,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ attempt: int,
+ max_attempts: int,
+ task_start_time: float,
+ time_module,
+) -> str:
+ _close_account_automation(
+ account,
+ on_error=lambda e: logger.debug(f"关闭浏览器实例失败: {e}"),
+ )
+
+ if _is_timeout_error(error_msg):
+ log_to_client(f"⚠ 执行超时: {error_msg}", user_id, account_id)
+ if attempt < max_attempts:
+ log_to_client(f"⚠ 将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id)
+ time_module.sleep(2)
+ return "continue"
+
+ log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id)
+ _record_failed_task_log(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ total_items=account.total_items,
+ total_attachments=account.total_attachments,
+ error_message=f"重试{max_attempts}次后仍失败: {error_msg}",
+ task_start_time=task_start_time,
+ time_module=time_module,
+ )
+ return "break"
+
+ log_to_client(f"任务执行异常: {error_msg}", user_id, account_id)
+ _record_failed_task_log(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ total_items=account.total_items,
+ total_attachments=account.total_attachments,
+ error_message=error_msg,
+ task_start_time=task_start_time,
+ time_module=time_module,
+ )
+ return "break"
+
+
+def _schedule_auto_retry(
+ *,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ enable_screenshot: bool,
+ source: str,
+ retry_count: int,
+) -> None:
+ def delayed_retry_submit():
+ fresh_account = safe_get_account(user_id, account_id)
+ if not fresh_account:
+ log_to_client("自动重试取消: 账户不存在", user_id, account_id)
+ return
+ if fresh_account.should_stop:
+ log_to_client("自动重试取消: 任务已被停止", user_id, account_id)
+ return
+
+ log_to_client(f"🔄 开始第 {retry_count + 1} 次自动重试...", user_id, account_id)
+ ok, msg = submit_account_task(
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ enable_screenshot=enable_screenshot,
+ source=source,
+ retry_count=retry_count + 1,
+ )
+ if not ok:
+ log_to_client(f"自动重试提交失败: {msg}", user_id, account_id)
+
+ try:
+ threading.Timer(5, delayed_retry_submit).start()
+ except Exception:
+ delayed_retry_submit()
+
+
+def _finalize_task_run(
+ *,
+ account,
+ user_id: int,
+ account_id: str,
+ browse_type: str,
+ source: str,
+ enable_screenshot: bool,
+ retry_count: int,
+ max_auto_retry: int,
+ task_start_time: float,
+ result,
+ batch_id,
+ batch_recorded: bool,
+) -> None:
+ final_status = str(account.status or "")
+ account.is_running = False
+ screenshot_submitted = False
+
+ _close_account_automation(
+ account,
+ on_error=lambda e: log_to_client(f"关闭主任务浏览器时出错: {str(e)}", user_id, account_id),
+ )
+
+ safe_remove_task(account_id)
+ safe_remove_task_status(account_id)
+
+ if final_status == "已完成" and not account.should_stop:
+ account.status = "已完成"
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+ if enable_screenshot:
+ log_to_client("等待2秒后开始截图...", user_id, account_id)
+ account.status = "等待截图"
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+ safe_set_task_status(
+ account_id,
+ {
+ "user_id": user_id,
+ "username": account.username,
+ "status": "排队中",
+ "detail_status": "等待截图资源",
+ "browse_type": browse_type,
+ "start_time": time.time(),
+ "source": source,
+ "progress": {
+ "items": result.total_items if result else 0,
+ "attachments": result.total_attachments if result else 0,
+ },
+ },
+ )
+
+ browse_result_dict = {
+ "total_items": result.total_items if result else 0,
+ "total_attachments": result.total_attachments if result else 0,
+ }
+ screenshot_submitted = True
+ threading.Thread(
+ target=take_screenshot_for_account,
+ args=(user_id, account_id, browse_type, source, task_start_time, browse_result_dict),
+ daemon=True,
+ ).start()
+ else:
+ account.status = "未开始"
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ log_to_client("截图功能已禁用,跳过截图", user_id, account_id)
+ else:
+ if final_status == "出错" and retry_count < max_auto_retry:
+ log_to_client(f"⚠ 任务执行失败,5秒后自动重试 ({retry_count + 1}/{max_auto_retry})...", user_id, account_id)
+ account.status = "等待重试"
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ _schedule_auto_retry(
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ enable_screenshot=enable_screenshot,
+ source=source,
+ retry_count=retry_count,
+ )
+ elif final_status in ["登录失败", "出错"]:
+ account.status = final_status
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ else:
+ account.status = "未开始"
+ _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+
+ if batch_id and (not screenshot_submitted) and (not batch_recorded) and account.status != "等待重试":
+ _record_batch_result(
+ batch_id=batch_id,
+ account=account,
+ total_items=getattr(account, "total_items", 0) or 0,
+ total_attachments=getattr(account, "total_attachments", 0) or 0,
+ )
+
+
def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="manual", retry_count=0):
"""运行自动化任务
@@ -498,30 +728,27 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
import time as time_module
+ task_start_time = time_module.time()
+ result = None
+
try:
- if account.should_stop:
- log_to_client("任务已取消", user_id, account_id)
- account.status = "已停止"
- account.is_running = False
- safe_remove_task_status(account_id)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- if batch_id:
- account_name = account.remark if account.remark else account.username
- _batch_task_record_result(
- batch_id=batch_id,
- account_name=account_name,
- screenshot_path=None,
- total_items=0,
- total_attachments=0,
- )
+ if _handle_stop_requested(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ batch_id=batch_id,
+ remove_task_status=True,
+ record_batch=True,
+ ):
return
try:
- if account.should_stop:
- log_to_client("任务已取消", user_id, account_id)
- account.status = "已停止"
- account.is_running = False
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
+ if _handle_stop_requested(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ batch_id=batch_id,
+ ):
return
task_id = checkpoint_mgr.create_checkpoint(
@@ -546,111 +773,19 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次尝试(共{max_attempts}次)...", user_id, account_id)
- proxy_config = None
- config = database.get_system_config()
- if config.get("proxy_enabled") == 1:
- proxy_api_url = config.get("proxy_api_url", "").strip()
- if proxy_api_url:
- log_to_client("正在获取代理IP...", user_id, account_id)
- proxy_server = get_proxy_from_api(proxy_api_url, max_retries=3)
- if proxy_server:
- proxy_config = {"server": proxy_server}
- log_to_client(f"[OK] 将使用代理: {proxy_server}", user_id, account_id)
- account.proxy_config = proxy_config # 保存代理配置供截图使用
- else:
- log_to_client("✗ 代理获取失败,将不使用代理继续", user_id, account_id)
- else:
- log_to_client("⚠ 代理已启用但未配置API地址", user_id, account_id)
-
- checkpoint_mgr.update_stage(task_id, TaskStage.STARTING, progress_percent=10)
-
- def custom_log(message: str):
- log_to_client(message, user_id, account_id)
-
- log_to_client("开始登录...", user_id, account_id)
- safe_update_task_status(account_id, {"detail_status": "正在登录"})
- checkpoint_mgr.update_stage(task_id, TaskStage.LOGGING_IN, progress_percent=25)
-
- with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser:
- if api_browser.login(account.username, account.password):
- log_to_client("[OK] 首次登录成功,刷新登录时间...", user_id, account_id)
-
- # 二次登录:让"上次登录时间"变成刚才首次登录的时间
- # 这样截图时显示的"上次登录时间"就是几秒前而不是昨天
- if api_browser.login(account.username, account.password):
- log_to_client("[OK] 二次登录成功!", user_id, account_id)
- else:
- log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id)
-
- api_browser.save_cookies_for_screenshot(account.username)
- database.reset_account_login_status(account_id)
-
- if not account.remark:
- try:
- real_name = api_browser.get_real_name()
- if real_name:
- account.remark = real_name
- database.update_account_remark(account_id, real_name)
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- logger.info(f"[自动备注] 账号 {account.username} 自动设置备注为: {real_name}")
- except Exception as e:
- logger.warning(f"[自动备注] 获取姓名失败: {e}")
-
- safe_update_task_status(account_id, {"detail_status": "正在浏览"})
- log_to_client(f"开始浏览 '{browse_type}' 内容...", user_id, account_id)
- account.total_items = 0
- safe_update_task_status(account_id, {"progress": {"items": 0, "attachments": 0}})
-
- def should_stop():
- return account.should_stop
-
- def on_browse_progress(progress: dict):
- try:
- total_items = int(progress.get("total_items") or 0)
- browsed_items = int(progress.get("browsed_items") or 0)
- if total_items > 0:
- account.total_items = total_items
- safe_update_task_status(
- account_id, {"progress": {"items": browsed_items, "attachments": 0}}
- )
- except Exception:
- pass
-
- checkpoint_mgr.update_stage(task_id, TaskStage.BROWSING, progress_percent=50)
- result = api_browser.browse_content(
- browse_type=browse_type,
- should_stop_callback=should_stop,
- progress_callback=on_browse_progress,
- )
- else:
- error_message = "登录失败"
- log_to_client(f"❌ {error_message}", user_id, account_id)
-
- is_suspended = database.increment_account_login_fail(account_id, error_message)
- if is_suspended:
- log_to_client("⚠ 该账号连续3次密码错误,已自动暂停", user_id, account_id)
- log_to_client("请在前台修改密码后才能继续使用", user_id, account_id)
-
- retry_action = checkpoint_mgr.record_error(task_id, error_message)
- if retry_action == "paused":
- logger.warning(f"[断点] 任务 {task_id} 已暂停(登录失败)")
-
- account.status = "登录失败"
- account.is_running = False
- database.create_task_log(
- user_id=user_id,
- account_id=account_id,
- username=account.username,
- browse_type=browse_type,
- status="failed",
- total_items=0,
- total_attachments=0,
- error_message=error_message,
- duration=int(time_module.time() - task_start_time),
- source=source,
- )
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- return
+ attempt_status, result = _execute_single_attempt(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ checkpoint_mgr=checkpoint_mgr,
+ task_id=task_id,
+ task_start_time=task_start_time,
+ time_module=time_module,
+ )
+ if attempt_status == "login_failed":
+ return
account.total_items = result.total_items
account.total_attachments = result.total_attachments
@@ -674,152 +809,63 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
logger.info(f"[断点] 任务 {task_id} 已完成")
if not enable_screenshot:
- database.create_task_log(
+ batch_recorded = _record_success_without_screenshot(
+ account=account,
user_id=user_id,
account_id=account_id,
- username=account.username,
browse_type=browse_type,
- status="success",
- total_items=result.total_items,
- total_attachments=result.total_attachments,
- error_message="",
- duration=int(time_module.time() - task_start_time),
source=source,
+ task_start_time=task_start_time,
+ result=result,
+ batch_id=batch_id,
+ time_module=time_module,
)
- if batch_id:
- account_name = account.remark if account.remark else account.username
- _batch_task_record_result(
- batch_id=batch_id,
- account_name=account_name,
- screenshot_path=None,
- total_items=result.total_items,
- total_attachments=result.total_attachments,
- )
- batch_recorded = True
- elif source and source.startswith("user_scheduled"):
- try:
- user_info = database.get_user_by_id(user_id)
- if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
- account_name = account.remark if account.remark else account.username
- email_service.send_task_complete_email_async(
- user_id=user_id,
- email=user_info["email"],
- username=user_info["username"],
- account_name=account_name,
- browse_type=browse_type,
- total_items=result.total_items,
- total_attachments=result.total_attachments,
- screenshot_path=None,
- log_callback=lambda msg: log_to_client(msg, user_id, account_id),
- )
- except Exception as email_error:
- logger.warning(f"发送任务完成邮件失败: {email_error}")
break
error_msg = result.error_message
- if "Timeout" in error_msg or "timeout" in error_msg:
- log_to_client(f"⚠ 检测到超时错误: {error_msg}", user_id, account_id)
-
- if account.automation:
- try:
- account.automation.close()
- log_to_client("已关闭超时的浏览器实例", user_id, account_id)
- except Exception as e:
- logger.debug(f"关闭超时浏览器实例失败: {e}")
- account.automation = None
-
- if attempt < max_attempts:
- log_to_client(
- f"⚠ 代理可能速度过慢,将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id
- )
- time_module.sleep(2)
- continue
- log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id)
- account.status = "出错"
- database.create_task_log(
- user_id=user_id,
- account_id=account_id,
- username=account.username,
- browse_type=browse_type,
- status="failed",
- total_items=result.total_items,
- total_attachments=result.total_attachments,
- error_message=f"重试{max_attempts}次后仍失败: {error_msg}",
- duration=int(time_module.time() - task_start_time),
- )
- break
-
- log_to_client(f"浏览出错: {error_msg}", user_id, account_id)
- account.status = "出错"
- database.create_task_log(
+ action = _handle_failed_browse_result(
+ account=account,
+ result=result,
+ error_msg=error_msg,
user_id=user_id,
account_id=account_id,
- username=account.username,
browse_type=browse_type,
- status="failed",
- total_items=result.total_items,
- total_attachments=result.total_attachments,
- error_message=error_msg,
- duration=int(time_module.time() - task_start_time),
source=source,
+ attempt=attempt,
+ max_attempts=max_attempts,
+ task_start_time=task_start_time,
+ time_module=time_module,
)
+ if action == "continue":
+ continue
break
except Exception as retry_error:
error_msg = str(retry_error)
- if account.automation:
- try:
- account.automation.close()
- except Exception as e:
- logger.debug(f"关闭浏览器实例失败: {e}")
- account.automation = None
-
- if "Timeout" in error_msg or "timeout" in error_msg:
- log_to_client(f"⚠ 执行超时: {error_msg}", user_id, account_id)
- if attempt < max_attempts:
- log_to_client(f"⚠ 将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id)
- time_module.sleep(2)
- continue
- log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id)
- account.status = "出错"
- database.create_task_log(
- user_id=user_id,
- account_id=account_id,
- username=account.username,
- browse_type=browse_type,
- status="failed",
- total_items=account.total_items,
- total_attachments=account.total_attachments,
- error_message=f"重试{max_attempts}次后仍失败: {error_msg}",
- duration=int(time_module.time() - task_start_time),
- source=source,
- )
- break
-
- log_to_client(f"任务执行异常: {error_msg}", user_id, account_id)
- account.status = "出错"
- database.create_task_log(
+ action = _handle_attempt_exception(
+ account=account,
+ error_msg=error_msg,
user_id=user_id,
account_id=account_id,
- username=account.username,
browse_type=browse_type,
- status="failed",
- total_items=account.total_items,
- total_attachments=account.total_attachments,
- error_message=error_msg,
- duration=int(time_module.time() - task_start_time),
source=source,
+ attempt=attempt,
+ max_attempts=max_attempts,
+ task_start_time=task_start_time,
+ time_module=time_module,
)
+ if action == "continue":
+ continue
break
except Exception as e:
error_msg = str(e)
log_to_client(f"任务执行出错: {error_msg}", user_id, account_id)
account.status = "出错"
- database.create_task_log(
+ _create_task_log(
user_id=user_id,
account_id=account_id,
- username=account.username,
+ account=account,
browse_type=browse_type,
status="failed",
total_items=account.total_items,
@@ -830,108 +876,20 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m
)
finally:
- account.is_running = False
- screenshot_submitted = False
- if account.status not in ["已完成"]:
- account.status = "未开始"
-
- if account.automation:
- try:
- account.automation.close()
- except Exception as e:
- log_to_client(f"关闭主任务浏览器时出错: {str(e)}", user_id, account_id)
- finally:
- account.automation = None
-
- safe_remove_task(account_id)
- safe_remove_task_status(account_id)
-
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
-
- if account.status == "已完成" and not account.should_stop:
- if enable_screenshot:
- log_to_client("等待2秒后开始截图...", user_id, account_id)
- account.status = "等待截图"
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- import time as time_mod
-
- safe_set_task_status(
- account_id,
- {
- "user_id": user_id,
- "username": account.username,
- "status": "排队中",
- "detail_status": "等待截图资源",
- "browse_type": browse_type,
- "start_time": time_mod.time(),
- "source": source,
- "progress": {
- "items": result.total_items if result else 0,
- "attachments": result.total_attachments if result else 0,
- },
- },
- )
- browse_result_dict = {
- "total_items": result.total_items,
- "total_attachments": result.total_attachments,
- }
- screenshot_submitted = True
- threading.Thread(
- target=take_screenshot_for_account,
- args=(user_id, account_id, browse_type, source, task_start_time, browse_result_dict),
- daemon=True,
- ).start()
- else:
- account.status = "未开始"
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- log_to_client("截图功能已禁用,跳过截图", user_id, account_id)
- else:
- if account.status not in ["登录失败", "出错"]:
- account.status = "未开始"
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
- elif account.status == "出错" and retry_count < MAX_AUTO_RETRY:
- log_to_client(
- f"⚠ 任务执行失败,5秒后自动重试 ({retry_count + 1}/{MAX_AUTO_RETRY})...", user_id, account_id
- )
- account.status = "等待重试"
- _emit("account_update", account.to_dict(), room=f"user_{user_id}")
-
- def delayed_retry_submit():
- # 重新获取最新的账户对象,避免使用闭包中的旧对象
- fresh_account = safe_get_account(user_id, account_id)
- if not fresh_account:
- log_to_client("自动重试取消: 账户不存在", user_id, account_id)
- return
- if fresh_account.should_stop:
- log_to_client("自动重试取消: 任务已被停止", user_id, account_id)
- return
- log_to_client(f"🔄 开始第 {retry_count + 1} 次自动重试...", user_id, account_id)
- ok, msg = submit_account_task(
- user_id=user_id,
- account_id=account_id,
- browse_type=browse_type,
- enable_screenshot=enable_screenshot,
- source=source,
- retry_count=retry_count + 1,
- )
- if not ok:
- log_to_client(f"自动重试提交失败: {msg}", user_id, account_id)
-
- try:
- threading.Timer(5, delayed_retry_submit).start()
- except Exception:
- delayed_retry_submit()
-
- if batch_id and (not screenshot_submitted) and (not batch_recorded) and account.status != "等待重试":
- account_name = account.remark if account.remark else account.username
- _batch_task_record_result(
- batch_id=batch_id,
- account_name=account_name,
- screenshot_path=None,
- total_items=getattr(account, "total_items", 0) or 0,
- total_attachments=getattr(account, "total_attachments", 0) or 0,
- )
- batch_recorded = True
+ _finalize_task_run(
+ account=account,
+ user_id=user_id,
+ account_id=account_id,
+ browse_type=browse_type,
+ source=source,
+ enable_screenshot=enable_screenshot,
+ retry_count=retry_count,
+ max_auto_retry=MAX_AUTO_RETRY,
+ task_start_time=task_start_time,
+ result=result,
+ batch_id=batch_id,
+ batch_recorded=batch_recorded,
+ )
finally:
pass