diff --git a/app.py b/app.py index 94dad79..4fcfee3 100644 --- a/app.py +++ b/app.py @@ -173,10 +173,28 @@ def serve_static(filename): if not is_safe_path("static", filename): return jsonify({"error": "非法路径"}), 403 - response = send_from_directory("static", filename) - response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0" - response.headers["Pragma"] = "no-cache" - response.headers["Expires"] = "0" + cache_ttl = 3600 + lowered = filename.lower() + if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")): + cache_ttl = 604800 # 7天 + + if request.args.get("v"): + cache_ttl = max(cache_ttl, 604800) + + response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True) + + # 协商缓存:确保存在 ETag,并基于 If-None-Match/If-Modified-Since 返回 304 + try: + response.add_etag(overwrite=False) + except Exception: + pass + try: + response.make_conditional(request) + except Exception: + pass + + response.headers.setdefault("Vary", "Accept-Encoding") + response.headers["Cache-Control"] = f"public, max-age={cache_ttl}" return response @@ -232,6 +250,93 @@ def _signal_handler(sig, frame): sys.exit(0) +def _cleanup_stale_task_state() -> None: + logger.info("清理遗留任务状态...") + try: + from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status + + for _, accounts in safe_iter_user_accounts_items(): + for acc in accounts.values(): + if not getattr(acc, "is_running", False): + continue + acc.is_running = False + acc.should_stop = False + acc.status = "未开始" + + for account_id in list(safe_get_active_task_ids()): + safe_remove_task(account_id) + safe_remove_task_status(account_id) + + logger.info("[OK] 遗留任务状态已清理") + except Exception as e: + logger.warning(f"清理遗留任务状态失败: {e}") + + +def _init_optional_email_service() -> None: + try: + email_service.init_email_service() + logger.info("[OK] 邮件服务已初始化") + except Exception as e: + logger.warning(f"警告: 邮件服务初始化失败: {e}") + + +def _load_and_apply_scheduler_limits() -> None: + try: + system_config = database.get_system_config() or {} + max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL)) + max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)) + get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account) + logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}") + except Exception as e: + logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}") + + +def _start_background_workers() -> None: + logger.info("启动定时任务调度器...") + threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start() + logger.info("[OK] 定时任务调度器已启动") + + logger.info("[OK] 状态推送线程已启动(默认2秒/次)") + threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start() + + +def _init_screenshot_worker_pool() -> None: + try: + pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3)) + except Exception: + pool_size = 3 + + try: + logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...") + init_browser_worker_pool(pool_size=pool_size) + logger.info("[OK] 截图线程池初始化完成") + except Exception as e: + logger.warning(f"警告: 截图线程池初始化失败: {e}") + + +def _warmup_api_connection() -> None: + logger.info("预热 API 连接...") + try: + from api_browser import warmup_api_connection + + threading.Thread( + target=warmup_api_connection, + kwargs={"log_callback": lambda msg: logger.info(msg)}, + daemon=True, + name="api-warmup", + ).start() + except Exception as e: + logger.warning(f"API 预热失败: {e}") + + +def _log_startup_urls() -> None: + logger.info("服务器启动中...") + logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}") + logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx") + logger.info("默认管理员: admin (首次运行随机密码见日志)") + logger.info("=" * 60) + + if __name__ == "__main__": atexit.register(cleanup_on_exit) signal.signal(signal.SIGINT, _signal_handler) @@ -245,81 +350,17 @@ if __name__ == "__main__": init_checkpoint_manager() logger.info("[OK] 任务断点管理器已初始化") - # 【新增】容器重启时清理遗留的任务状态 - logger.info("清理遗留任务状态...") - try: - from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status - # 重置所有账号的运行状态 - for _, accounts in safe_iter_user_accounts_items(): - for acc in accounts.values(): - if getattr(acc, "is_running", False): - acc.is_running = False - acc.should_stop = False - acc.status = "未开始" - # 清理活跃任务句柄 - for account_id in list(safe_get_active_task_ids()): - safe_remove_task(account_id) - safe_remove_task_status(account_id) - logger.info("[OK] 遗留任务状态已清理") - except Exception as e: - logger.warning(f"清理遗留任务状态失败: {e}") - - try: - email_service.init_email_service() - logger.info("[OK] 邮件服务已初始化") - except Exception as e: - logger.warning(f"警告: 邮件服务初始化失败: {e}") + _cleanup_stale_task_state() + _init_optional_email_service() start_cleanup_scheduler() start_kdocs_monitor() - try: - system_config = database.get_system_config() or {} - max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL)) - max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)) - get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account) - logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}") - except Exception as e: - logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}") - - logger.info("启动定时任务调度器...") - threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start() - logger.info("[OK] 定时任务调度器已启动") - - logger.info("[OK] 状态推送线程已启动(默认2秒/次)") - threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start() - - logger.info("服务器启动中...") - logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}") - logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx") - logger.info("默认管理员: admin (首次运行随机密码见日志)") - logger.info("=" * 60) - - try: - pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3)) - except Exception: - pool_size = 3 - try: - logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...") - init_browser_worker_pool(pool_size=pool_size) - logger.info("[OK] 截图线程池初始化完成") - except Exception as e: - logger.warning(f"警告: 截图线程池初始化失败: {e}") - - # 预热 API 连接(后台进行,不阻塞启动) - logger.info("预热 API 连接...") - try: - from api_browser import warmup_api_connection - import threading - - threading.Thread( - target=warmup_api_connection, - kwargs={"log_callback": lambda msg: logger.info(msg)}, - daemon=True, - name="api-warmup", - ).start() - except Exception as e: - logger.warning(f"API 预热失败: {e}") + _load_and_apply_scheduler_limits() + _start_background_workers() + _log_startup_urls() + _init_screenshot_worker_pool() + _warmup_api_connection() socketio.run( app, diff --git a/database.py b/database.py index d6a132a..56121e1 100644 --- a/database.py +++ b/database.py @@ -120,7 +120,7 @@ config = get_config() DB_FILE = config.DB_FILE # 数据库版本 (用于迁移管理) -DB_VERSION = 17 +DB_VERSION = 18 # ==================== 系统配置缓存(P1 / O-03) ==================== @@ -142,6 +142,37 @@ def invalidate_system_config_cache() -> None: _system_config_cache_loaded_at = 0.0 +def _normalize_system_config_value(value) -> dict: + try: + return dict(value or {}) + except Exception: + return {} + + +def _is_system_config_cache_valid(now_ts: float) -> bool: + if _system_config_cache_value is None: + return False + if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0: + return True + return (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS + + +def _read_system_config_cache(now_ts: float, *, ignore_ttl: bool = False) -> Optional[dict]: + with _system_config_cache_lock: + if _system_config_cache_value is None: + return None + if (not ignore_ttl) and (not _is_system_config_cache_valid(now_ts)): + return None + return dict(_system_config_cache_value) + + +def _write_system_config_cache(value: dict, now_ts: float) -> None: + global _system_config_cache_value, _system_config_cache_loaded_at + with _system_config_cache_lock: + _system_config_cache_value = dict(value) + _system_config_cache_loaded_at = now_ts + + def init_database(): """初始化数据库表结构 + 迁移(入口统一)。""" db_pool.init_pool(DB_FILE, pool_size=config.DB_POOL_SIZE) @@ -165,19 +196,21 @@ def migrate_database(): def get_system_config(): """获取系统配置(带进程内缓存)。""" - global _system_config_cache_value, _system_config_cache_loaded_at - now_ts = time.time() - with _system_config_cache_lock: - if _system_config_cache_value is not None: - if _SYSTEM_CONFIG_CACHE_TTL_SECONDS <= 0 or (now_ts - _system_config_cache_loaded_at) < _SYSTEM_CONFIG_CACHE_TTL_SECONDS: - return dict(_system_config_cache_value) - value = _get_system_config_raw() + cached_value = _read_system_config_cache(now_ts) + if cached_value is not None: + return cached_value - with _system_config_cache_lock: - _system_config_cache_value = dict(value) - _system_config_cache_loaded_at = now_ts + try: + value = _normalize_system_config_value(_get_system_config_raw()) + except Exception: + fallback_value = _read_system_config_cache(now_ts, ignore_ttl=True) + if fallback_value is not None: + return fallback_value + raise + + _write_system_config_cache(value, now_ts) return dict(value) diff --git a/db/accounts.py b/db/accounts.py index 85e4132..fb98789 100644 --- a/db/accounts.py +++ b/db/accounts.py @@ -6,19 +6,51 @@ import db_pool from crypto_utils import decrypt_password, encrypt_password from db.utils import get_cst_now_str +_ACCOUNT_STATUS_QUERY_SQL = """ + SELECT status, login_fail_count, last_login_error + FROM accounts + WHERE id = ? +""" + + +def _decode_account_password(account_dict: dict) -> dict: + account_dict["password"] = decrypt_password(account_dict.get("password", "")) + return account_dict + + +def _normalize_account_ids(account_ids) -> list[str]: + normalized = [] + seen = set() + for account_id in account_ids or []: + if not account_id: + continue + account_key = str(account_id) + if account_key in seen: + continue + seen.add(account_key) + normalized.append(account_key) + return normalized + def create_account(user_id, account_id, username, password, remember=True, remark=""): """创建账号(密码加密存储)""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_time = get_cst_now_str() encrypted_password = encrypt_password(password) cursor.execute( """ INSERT INTO accounts (id, user_id, username, password, remember, remark, created_at) VALUES (?, ?, ?, ?, ?, ?, ?) """, - (account_id, user_id, username, encrypted_password, 1 if remember else 0, remark, cst_time), + ( + account_id, + user_id, + username, + encrypted_password, + 1 if remember else 0, + remark, + get_cst_now_str(), + ), ) conn.commit() return cursor.lastrowid @@ -29,12 +61,7 @@ def get_user_accounts(user_id): with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM accounts WHERE user_id = ? ORDER BY created_at DESC", (user_id,)) - accounts = [] - for row in cursor.fetchall(): - account = dict(row) - account["password"] = decrypt_password(account.get("password", "")) - accounts.append(account) - return accounts + return [_decode_account_password(dict(row)) for row in cursor.fetchall()] def get_account(account_id): @@ -43,11 +70,9 @@ def get_account(account_id): cursor = conn.cursor() cursor.execute("SELECT * FROM accounts WHERE id = ?", (account_id,)) row = cursor.fetchone() - if row: - account = dict(row) - account["password"] = decrypt_password(account.get("password", "")) - return account - return None + if not row: + return None + return _decode_account_password(dict(row)) def update_account_remark(account_id, remark): @@ -78,33 +103,21 @@ def increment_account_login_fail(account_id, error_message): if not row: return False - fail_count = (row["login_fail_count"] or 0) + 1 - - if fail_count >= 3: - cursor.execute( - """ - UPDATE accounts - SET login_fail_count = ?, - last_login_error = ?, - status = 'suspended' - WHERE id = ? - """, - (fail_count, error_message, account_id), - ) - conn.commit() - return True + fail_count = int(row["login_fail_count"] or 0) + 1 + is_suspended = fail_count >= 3 cursor.execute( """ UPDATE accounts SET login_fail_count = ?, - last_login_error = ? + last_login_error = ?, + status = CASE WHEN ? = 1 THEN 'suspended' ELSE status END WHERE id = ? """, - (fail_count, error_message, account_id), + (fail_count, error_message, 1 if is_suspended else 0, account_id), ) conn.commit() - return False + return is_suspended def reset_account_login_status(account_id): @@ -129,29 +142,22 @@ def get_account_status(account_id): """获取账号状态信息""" with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute( - """ - SELECT status, login_fail_count, last_login_error - FROM accounts - WHERE id = ? - """, - (account_id,), - ) + cursor.execute(_ACCOUNT_STATUS_QUERY_SQL, (account_id,)) return cursor.fetchone() def get_account_status_batch(account_ids): """批量获取账号状态信息""" - account_ids = [str(account_id) for account_id in (account_ids or []) if account_id] - if not account_ids: + normalized_ids = _normalize_account_ids(account_ids) + if not normalized_ids: return {} results = {} chunk_size = 900 # 避免触发 SQLite 绑定参数上限 with db_pool.get_db() as conn: cursor = conn.cursor() - for idx in range(0, len(account_ids), chunk_size): - chunk = account_ids[idx : idx + chunk_size] + for idx in range(0, len(normalized_ids), chunk_size): + chunk = normalized_ids[idx : idx + chunk_size] placeholders = ",".join("?" for _ in chunk) cursor.execute( f""" diff --git a/db/admin.py b/db/admin.py index b087ad3..f3be7ca 100644 --- a/db/admin.py +++ b/db/admin.py @@ -3,9 +3,6 @@ from __future__ import annotations import sqlite3 -from datetime import datetime, timedelta - -import pytz import db_pool from db.utils import get_cst_now_str @@ -16,6 +13,99 @@ from password_utils import ( verify_password_sha256, ) +_DEFAULT_SYSTEM_CONFIG = { + "max_concurrent_global": 2, + "max_concurrent_per_account": 1, + "max_screenshot_concurrent": 3, + "schedule_enabled": 0, + "schedule_time": "02:00", + "schedule_browse_type": "应读", + "schedule_weekdays": "1,2,3,4,5,6,7", + "proxy_enabled": 0, + "proxy_api_url": "", + "proxy_expire_minutes": 3, + "enable_screenshot": 1, + "auto_approve_enabled": 0, + "auto_approve_hourly_limit": 10, + "auto_approve_vip_days": 7, + "kdocs_enabled": 0, + "kdocs_doc_url": "", + "kdocs_default_unit": "", + "kdocs_sheet_name": "", + "kdocs_sheet_index": 0, + "kdocs_unit_column": "A", + "kdocs_image_column": "D", + "kdocs_admin_notify_enabled": 0, + "kdocs_admin_notify_email": "", + "kdocs_row_start": 0, + "kdocs_row_end": 0, +} + +_SYSTEM_CONFIG_UPDATERS = ( + ("max_concurrent_global", "max_concurrent"), + ("schedule_enabled", "schedule_enabled"), + ("schedule_time", "schedule_time"), + ("schedule_browse_type", "schedule_browse_type"), + ("schedule_weekdays", "schedule_weekdays"), + ("max_concurrent_per_account", "max_concurrent_per_account"), + ("max_screenshot_concurrent", "max_screenshot_concurrent"), + ("enable_screenshot", "enable_screenshot"), + ("proxy_enabled", "proxy_enabled"), + ("proxy_api_url", "proxy_api_url"), + ("proxy_expire_minutes", "proxy_expire_minutes"), + ("auto_approve_enabled", "auto_approve_enabled"), + ("auto_approve_hourly_limit", "auto_approve_hourly_limit"), + ("auto_approve_vip_days", "auto_approve_vip_days"), + ("kdocs_enabled", "kdocs_enabled"), + ("kdocs_doc_url", "kdocs_doc_url"), + ("kdocs_default_unit", "kdocs_default_unit"), + ("kdocs_sheet_name", "kdocs_sheet_name"), + ("kdocs_sheet_index", "kdocs_sheet_index"), + ("kdocs_unit_column", "kdocs_unit_column"), + ("kdocs_image_column", "kdocs_image_column"), + ("kdocs_admin_notify_enabled", "kdocs_admin_notify_enabled"), + ("kdocs_admin_notify_email", "kdocs_admin_notify_email"), + ("kdocs_row_start", "kdocs_row_start"), + ("kdocs_row_end", "kdocs_row_end"), +) + + +def _count_scalar(cursor, sql: str, params=()) -> int: + cursor.execute(sql, params) + row = cursor.fetchone() + if not row: + return 0 + try: + if "count" in row.keys(): + return int(row["count"] or 0) + except Exception: + pass + try: + return int(row[0] or 0) + except Exception: + return 0 + + +def _table_exists(cursor, table_name: str) -> bool: + cursor.execute( + """ + SELECT name FROM sqlite_master + WHERE type='table' AND name=? + """, + (table_name,), + ) + return bool(cursor.fetchone()) + + +def _normalize_days(days, default: int = 30) -> int: + try: + value = int(days) + except Exception: + value = default + if value < 0: + return 0 + return value + def ensure_default_admin() -> bool: """确保存在默认管理员账号(行为保持不变)。""" @@ -24,10 +114,9 @@ def ensure_default_admin() -> bool: with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) as count FROM admins") - result = cursor.fetchone() + count = _count_scalar(cursor, "SELECT COUNT(*) as count FROM admins") - if result["count"] == 0: + if count == 0: alphabet = string.ascii_letters + string.digits random_password = "".join(secrets.choice(alphabet) for _ in range(12)) @@ -101,41 +190,33 @@ def get_system_stats() -> dict: with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) as count FROM users") - total_users = cursor.fetchone()["count"] - - cursor.execute("SELECT COUNT(*) as count FROM users WHERE status = 'approved'") - approved_users = cursor.fetchone()["count"] - - cursor.execute( + total_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users") + approved_users = _count_scalar(cursor, "SELECT COUNT(*) as count FROM users WHERE status = 'approved'") + new_users_today = _count_scalar( + cursor, """ SELECT COUNT(*) as count FROM users WHERE date(created_at) = date('now', 'localtime') - """ + """, ) - new_users_today = cursor.fetchone()["count"] - - cursor.execute( + new_users_7d = _count_scalar( + cursor, """ SELECT COUNT(*) as count FROM users WHERE datetime(created_at) >= datetime('now', 'localtime', '-7 days') - """ + """, ) - new_users_7d = cursor.fetchone()["count"] - - cursor.execute("SELECT COUNT(*) as count FROM accounts") - total_accounts = cursor.fetchone()["count"] - - cursor.execute( + total_accounts = _count_scalar(cursor, "SELECT COUNT(*) as count FROM accounts") + vip_users = _count_scalar( + cursor, """ SELECT COUNT(*) as count FROM users WHERE vip_expire_time IS NOT NULL AND datetime(vip_expire_time) > datetime('now', 'localtime') - """ + """, ) - vip_users = cursor.fetchone()["count"] return { "total_users": total_users, @@ -153,37 +234,9 @@ def get_system_config_raw() -> dict: cursor = conn.cursor() cursor.execute("SELECT * FROM system_config WHERE id = 1") row = cursor.fetchone() - if row: return dict(row) - - return { - "max_concurrent_global": 2, - "max_concurrent_per_account": 1, - "max_screenshot_concurrent": 3, - "schedule_enabled": 0, - "schedule_time": "02:00", - "schedule_browse_type": "应读", - "schedule_weekdays": "1,2,3,4,5,6,7", - "proxy_enabled": 0, - "proxy_api_url": "", - "proxy_expire_minutes": 3, - "enable_screenshot": 1, - "auto_approve_enabled": 0, - "auto_approve_hourly_limit": 10, - "auto_approve_vip_days": 7, - "kdocs_enabled": 0, - "kdocs_doc_url": "", - "kdocs_default_unit": "", - "kdocs_sheet_name": "", - "kdocs_sheet_index": 0, - "kdocs_unit_column": "A", - "kdocs_image_column": "D", - "kdocs_admin_notify_enabled": 0, - "kdocs_admin_notify_email": "", - "kdocs_row_start": 0, - "kdocs_row_end": 0, - } + return dict(_DEFAULT_SYSTEM_CONFIG) def update_system_config( @@ -215,127 +268,51 @@ def update_system_config( kdocs_row_end=None, ) -> bool: """更新系统配置(仅更新DB,不做缓存处理)。""" - allowed_fields = { - "max_concurrent_global", - "schedule_enabled", - "schedule_time", - "schedule_browse_type", - "schedule_weekdays", - "max_concurrent_per_account", - "max_screenshot_concurrent", - "enable_screenshot", - "proxy_enabled", - "proxy_api_url", - "proxy_expire_minutes", - "auto_approve_enabled", - "auto_approve_hourly_limit", - "auto_approve_vip_days", - "kdocs_enabled", - "kdocs_doc_url", - "kdocs_default_unit", - "kdocs_sheet_name", - "kdocs_sheet_index", - "kdocs_unit_column", - "kdocs_image_column", - "kdocs_admin_notify_enabled", - "kdocs_admin_notify_email", - "kdocs_row_start", - "kdocs_row_end", - "updated_at", + arg_values = { + "max_concurrent": max_concurrent, + "schedule_enabled": schedule_enabled, + "schedule_time": schedule_time, + "schedule_browse_type": schedule_browse_type, + "schedule_weekdays": schedule_weekdays, + "max_concurrent_per_account": max_concurrent_per_account, + "max_screenshot_concurrent": max_screenshot_concurrent, + "enable_screenshot": enable_screenshot, + "proxy_enabled": proxy_enabled, + "proxy_api_url": proxy_api_url, + "proxy_expire_minutes": proxy_expire_minutes, + "auto_approve_enabled": auto_approve_enabled, + "auto_approve_hourly_limit": auto_approve_hourly_limit, + "auto_approve_vip_days": auto_approve_vip_days, + "kdocs_enabled": kdocs_enabled, + "kdocs_doc_url": kdocs_doc_url, + "kdocs_default_unit": kdocs_default_unit, + "kdocs_sheet_name": kdocs_sheet_name, + "kdocs_sheet_index": kdocs_sheet_index, + "kdocs_unit_column": kdocs_unit_column, + "kdocs_image_column": kdocs_image_column, + "kdocs_admin_notify_enabled": kdocs_admin_notify_enabled, + "kdocs_admin_notify_email": kdocs_admin_notify_email, + "kdocs_row_start": kdocs_row_start, + "kdocs_row_end": kdocs_row_end, } + updates = [] + params = [] + for db_field, arg_name in _SYSTEM_CONFIG_UPDATERS: + value = arg_values.get(arg_name) + if value is None: + continue + updates.append(f"{db_field} = ?") + params.append(value) + + if not updates: + return False + + updates.append("updated_at = ?") + params.append(get_cst_now_str()) + with db_pool.get_db() as conn: cursor = conn.cursor() - updates = [] - params = [] - - if max_concurrent is not None: - updates.append("max_concurrent_global = ?") - params.append(max_concurrent) - if schedule_enabled is not None: - updates.append("schedule_enabled = ?") - params.append(schedule_enabled) - if schedule_time is not None: - updates.append("schedule_time = ?") - params.append(schedule_time) - if schedule_browse_type is not None: - updates.append("schedule_browse_type = ?") - params.append(schedule_browse_type) - if max_concurrent_per_account is not None: - updates.append("max_concurrent_per_account = ?") - params.append(max_concurrent_per_account) - if max_screenshot_concurrent is not None: - updates.append("max_screenshot_concurrent = ?") - params.append(max_screenshot_concurrent) - if enable_screenshot is not None: - updates.append("enable_screenshot = ?") - params.append(enable_screenshot) - if schedule_weekdays is not None: - updates.append("schedule_weekdays = ?") - params.append(schedule_weekdays) - if proxy_enabled is not None: - updates.append("proxy_enabled = ?") - params.append(proxy_enabled) - if proxy_api_url is not None: - updates.append("proxy_api_url = ?") - params.append(proxy_api_url) - if proxy_expire_minutes is not None: - updates.append("proxy_expire_minutes = ?") - params.append(proxy_expire_minutes) - if auto_approve_enabled is not None: - updates.append("auto_approve_enabled = ?") - params.append(auto_approve_enabled) - if auto_approve_hourly_limit is not None: - updates.append("auto_approve_hourly_limit = ?") - params.append(auto_approve_hourly_limit) - if auto_approve_vip_days is not None: - updates.append("auto_approve_vip_days = ?") - params.append(auto_approve_vip_days) - if kdocs_enabled is not None: - updates.append("kdocs_enabled = ?") - params.append(kdocs_enabled) - if kdocs_doc_url is not None: - updates.append("kdocs_doc_url = ?") - params.append(kdocs_doc_url) - if kdocs_default_unit is not None: - updates.append("kdocs_default_unit = ?") - params.append(kdocs_default_unit) - if kdocs_sheet_name is not None: - updates.append("kdocs_sheet_name = ?") - params.append(kdocs_sheet_name) - if kdocs_sheet_index is not None: - updates.append("kdocs_sheet_index = ?") - params.append(kdocs_sheet_index) - if kdocs_unit_column is not None: - updates.append("kdocs_unit_column = ?") - params.append(kdocs_unit_column) - if kdocs_image_column is not None: - updates.append("kdocs_image_column = ?") - params.append(kdocs_image_column) - if kdocs_admin_notify_enabled is not None: - updates.append("kdocs_admin_notify_enabled = ?") - params.append(kdocs_admin_notify_enabled) - if kdocs_admin_notify_email is not None: - updates.append("kdocs_admin_notify_email = ?") - params.append(kdocs_admin_notify_email) - if kdocs_row_start is not None: - updates.append("kdocs_row_start = ?") - params.append(kdocs_row_start) - if kdocs_row_end is not None: - updates.append("kdocs_row_end = ?") - params.append(kdocs_row_end) - - if not updates: - return False - - updates.append("updated_at = ?") - params.append(get_cst_now_str()) - - for update_clause in updates: - field_name = update_clause.split("=")[0].strip() - if field_name not in allowed_fields: - raise ValueError(f"非法字段名: {field_name}") - sql = f"UPDATE system_config SET {', '.join(updates)} WHERE id = 1" cursor.execute(sql, params) conn.commit() @@ -346,13 +323,13 @@ def get_hourly_registration_count() -> int: """获取最近一小时内的注册用户数""" with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute( + return _count_scalar( + cursor, """ - SELECT COUNT(*) FROM users + SELECT COUNT(*) as count FROM users WHERE created_at >= datetime('now', 'localtime', '-1 hour') - """ + """, ) - return cursor.fetchone()[0] # ==================== 密码重置(管理员) ==================== @@ -374,17 +351,12 @@ def admin_reset_user_password(user_id: int, new_password: str) -> bool: def clean_old_operation_logs(days: int = 30) -> int: """清理指定天数前的操作日志(如果存在operation_logs表)""" + safe_days = _normalize_days(days, default=30) + with db_pool.get_db() as conn: cursor = conn.cursor() - cursor.execute( - """ - SELECT name FROM sqlite_master - WHERE type='table' AND name='operation_logs' - """ - ) - - if not cursor.fetchone(): + if not _table_exists(cursor, "operation_logs"): return 0 try: @@ -393,11 +365,11 @@ def clean_old_operation_logs(days: int = 30) -> int: DELETE FROM operation_logs WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days') """, - (days,), + (safe_days,), ) deleted_count = cursor.rowcount conn.commit() - print(f"已清理 {deleted_count} 条旧操作日志 (>{days}天)") + print(f"已清理 {deleted_count} 条旧操作日志 (>{safe_days}天)") return deleted_count except Exception as e: print(f"清理旧操作日志失败: {e}") diff --git a/db/announcements.py b/db/announcements.py index c680816..937d84a 100644 --- a/db/announcements.py +++ b/db/announcements.py @@ -6,12 +6,38 @@ import db_pool from db.utils import get_cst_now_str +def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + parsed = max(minimum, parsed) + parsed = min(maximum, parsed) + return parsed + + +def _normalize_offset(value, default: int = 0) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(0, parsed) + + +def _normalize_announcement_payload(title, content, image_url): + normalized_title = str(title or "").strip() + normalized_content = str(content or "").strip() + normalized_image = str(image_url or "").strip() or None + return normalized_title, normalized_content, normalized_image + + +def _deactivate_all_active_announcements(cursor, cst_time: str) -> None: + cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,)) + + def create_announcement(title, content, image_url=None, is_active=True): """创建公告(默认启用;启用时会自动停用其他公告)""" - title = (title or "").strip() - content = (content or "").strip() - image_url = (image_url or "").strip() - image_url = image_url or None + title, content, image_url = _normalize_announcement_payload(title, content, image_url) if not title or not content: return None @@ -20,7 +46,7 @@ def create_announcement(title, content, image_url=None, is_active=True): cst_time = get_cst_now_str() if is_active: - cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,)) + _deactivate_all_active_announcements(cursor, cst_time) cursor.execute( """ @@ -44,6 +70,9 @@ def get_announcement_by_id(announcement_id): def get_announcements(limit=50, offset=0): """获取公告列表(管理员用)""" + safe_limit = _normalize_limit(limit, 50) + safe_offset = _normalize_offset(offset, 0) + with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( @@ -52,7 +81,7 @@ def get_announcements(limit=50, offset=0): ORDER BY created_at DESC, id DESC LIMIT ? OFFSET ? """, - (limit, offset), + (safe_limit, safe_offset), ) return [dict(row) for row in cursor.fetchall()] @@ -64,7 +93,7 @@ def set_announcement_active(announcement_id, is_active): cst_time = get_cst_now_str() if is_active: - cursor.execute("UPDATE announcements SET is_active = 0, updated_at = ? WHERE is_active = 1", (cst_time,)) + _deactivate_all_active_announcements(cursor, cst_time) cursor.execute( """ UPDATE announcements @@ -121,13 +150,12 @@ def dismiss_announcement_for_user(user_id, announcement_id): """用户永久关闭某条公告(幂等)""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_time = get_cst_now_str() cursor.execute( """ INSERT OR IGNORE INTO announcement_dismissals (user_id, announcement_id, dismissed_at) VALUES (?, ?, ?) """, - (user_id, announcement_id, cst_time), + (user_id, announcement_id, get_cst_now_str()), ) conn.commit() return cursor.rowcount >= 0 diff --git a/db/email.py b/db/email.py index 2716e50..b0efc4f 100644 --- a/db/email.py +++ b/db/email.py @@ -5,6 +5,27 @@ from __future__ import annotations import db_pool +def _to_bool_with_default(value, default: bool = True) -> bool: + if value is None: + return default + try: + return bool(int(value)) + except Exception: + try: + return bool(value) + except Exception: + return default + + +def _normalize_notify_enabled(enabled) -> int: + if isinstance(enabled, bool): + return 1 if enabled else 0 + try: + return 1 if int(enabled) else 0 + except Exception: + return 1 + + def get_user_by_email(email): """根据邮箱获取用户""" with db_pool.get_db() as conn: @@ -25,7 +46,7 @@ def update_user_email(user_id, email, verified=False): SET email = ?, email_verified = ? WHERE id = ? """, - (email, int(verified), user_id), + (email, 1 if verified else 0, user_id), ) conn.commit() return cursor.rowcount > 0 @@ -42,7 +63,7 @@ def update_user_email_notify(user_id, enabled): SET email_notify_enabled = ? WHERE id = ? """, - (int(enabled), user_id), + (_normalize_notify_enabled(enabled), user_id), ) conn.commit() return cursor.rowcount > 0 @@ -57,6 +78,6 @@ def get_user_email_notify(user_id): row = cursor.fetchone() if row is None: return True - return bool(row[0]) if row[0] is not None else True + return _to_bool_with_default(row[0], default=True) except Exception: return True diff --git a/db/feedbacks.py b/db/feedbacks.py index 828f96a..2245ff7 100644 --- a/db/feedbacks.py +++ b/db/feedbacks.py @@ -2,32 +2,73 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from datetime import datetime - -import pytz - import db_pool -from db.utils import escape_html +from db.utils import escape_html, get_cst_now_str + + +def _normalize_limit(value, default: int, *, minimum: int = 1, maximum: int = 500) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + parsed = max(minimum, parsed) + parsed = min(maximum, parsed) + return parsed + + +def _normalize_offset(value, default: int = 0) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(0, parsed) + + +def _safe_text(value) -> str: + if value is None: + return "" + text = str(value) + return escape_html(text) if text else "" + + +def _build_feedback_filter_sql(status_filter=None) -> tuple[str, list]: + where_clauses = ["1=1"] + params = [] + + if status_filter: + where_clauses.append("status = ?") + params.append(status_filter) + + return " AND ".join(where_clauses), params + + +def _normalize_feedback_stats_row(row) -> dict: + row_dict = dict(row) if row else {} + return { + "total": int(row_dict.get("total") or 0), + "pending": int(row_dict.get("pending") or 0), + "replied": int(row_dict.get("replied") or 0), + "closed": int(row_dict.get("closed") or 0), + } def create_bug_feedback(user_id, username, title, description, contact=""): """创建Bug反馈(带XSS防护)""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_tz = pytz.timezone("Asia/Shanghai") - cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S") - - safe_title = escape_html(title) if title else "" - safe_description = escape_html(description) if description else "" - safe_contact = escape_html(contact) if contact else "" - safe_username = escape_html(username) if username else "" - cursor.execute( """ INSERT INTO bug_feedbacks (user_id, username, title, description, contact, created_at) VALUES (?, ?, ?, ?, ?, ?) """, - (user_id, safe_username, safe_title, safe_description, safe_contact, cst_time), + ( + user_id, + _safe_text(username), + _safe_text(title), + _safe_text(description), + _safe_text(contact), + get_cst_now_str(), + ), ) conn.commit() @@ -36,25 +77,25 @@ def create_bug_feedback(user_id, username, title, description, contact=""): def get_bug_feedbacks(limit=100, offset=0, status_filter=None): """获取Bug反馈列表(管理员用)""" + safe_limit = _normalize_limit(limit, 100, minimum=1, maximum=1000) + safe_offset = _normalize_offset(offset, 0) + with db_pool.get_db() as conn: cursor = conn.cursor() - - sql = "SELECT * FROM bug_feedbacks WHERE 1=1" - params = [] - - if status_filter: - sql += " AND status = ?" - params.append(status_filter) - - sql += " ORDER BY created_at DESC LIMIT ? OFFSET ?" - params.extend([limit, offset]) - - cursor.execute(sql, params) + where_sql, params = _build_feedback_filter_sql(status_filter=status_filter) + sql = f""" + SELECT * FROM bug_feedbacks + WHERE {where_sql} + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """ + cursor.execute(sql, params + [safe_limit, safe_offset]) return [dict(row) for row in cursor.fetchall()] def get_user_feedbacks(user_id, limit=50): """获取用户自己的反馈列表""" + safe_limit = _normalize_limit(limit, 50, minimum=1, maximum=1000) with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( @@ -64,7 +105,7 @@ def get_user_feedbacks(user_id, limit=50): ORDER BY created_at DESC LIMIT ? """, - (user_id, limit), + (user_id, safe_limit), ) return [dict(row) for row in cursor.fetchall()] @@ -82,18 +123,13 @@ def reply_feedback(feedback_id, admin_reply): """管理员回复反馈(带XSS防护)""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_tz = pytz.timezone("Asia/Shanghai") - cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S") - - safe_reply = escape_html(admin_reply) if admin_reply else "" - cursor.execute( """ UPDATE bug_feedbacks SET admin_reply = ?, status = 'replied', replied_at = ? WHERE id = ? """, - (safe_reply, cst_time, feedback_id), + (_safe_text(admin_reply), get_cst_now_str(), feedback_id), ) conn.commit() @@ -139,6 +175,4 @@ def get_feedback_stats(): FROM bug_feedbacks """ ) - row = cursor.fetchone() - return dict(row) if row else {"total": 0, "pending": 0, "replied": 0, "closed": 0} - + return _normalize_feedback_stats_row(cursor.fetchone()) diff --git a/db/migrations.py b/db/migrations.py index 7d9347a..c1e88db 100644 --- a/db/migrations.py +++ b/db/migrations.py @@ -28,105 +28,136 @@ def set_current_version(conn, version: int) -> None: conn.commit() +def _table_exists(cursor, table_name: str) -> bool: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (str(table_name),)) + return cursor.fetchone() is not None + + +def _get_table_columns(cursor, table_name: str) -> set[str]: + cursor.execute(f"PRAGMA table_info({table_name})") + return {col[1] for col in cursor.fetchall()} + + +def _add_column_if_missing(cursor, table_name: str, columns: set[str], column_name: str, column_ddl: str, *, ok_message: str) -> bool: + if column_name in columns: + return False + cursor.execute(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_ddl}") + columns.add(column_name) + print(ok_message) + return True + + +def _read_row_value(row, key: str, index: int): + if isinstance(row, sqlite3.Row): + return row[key] + return row[index] + + +def _get_migration_steps(): + return [ + (1, _migrate_to_v1), + (2, _migrate_to_v2), + (3, _migrate_to_v3), + (4, _migrate_to_v4), + (5, _migrate_to_v5), + (6, _migrate_to_v6), + (7, _migrate_to_v7), + (8, _migrate_to_v8), + (9, _migrate_to_v9), + (10, _migrate_to_v10), + (11, _migrate_to_v11), + (12, _migrate_to_v12), + (13, _migrate_to_v13), + (14, _migrate_to_v14), + (15, _migrate_to_v15), + (16, _migrate_to_v16), + (17, _migrate_to_v17), + (18, _migrate_to_v18), + ] + + def migrate_database(conn, target_version: int) -> None: """数据库迁移:按版本增量升级(向前兼容)。""" cursor = conn.cursor() cursor.execute("INSERT OR IGNORE INTO db_version (id, version, updated_at) VALUES (1, 0, ?)", (get_cst_now_str(),)) conn.commit() + target_version = int(target_version) current_version = get_current_version(conn) - if current_version < 1: - _migrate_to_v1(conn) - current_version = 1 - if current_version < 2: - _migrate_to_v2(conn) - current_version = 2 - if current_version < 3: - _migrate_to_v3(conn) - current_version = 3 - if current_version < 4: - _migrate_to_v4(conn) - current_version = 4 - if current_version < 5: - _migrate_to_v5(conn) - current_version = 5 - if current_version < 6: - _migrate_to_v6(conn) - current_version = 6 - if current_version < 7: - _migrate_to_v7(conn) - current_version = 7 - if current_version < 8: - _migrate_to_v8(conn) - current_version = 8 - if current_version < 9: - _migrate_to_v9(conn) - current_version = 9 - if current_version < 10: - _migrate_to_v10(conn) - current_version = 10 - if current_version < 11: - _migrate_to_v11(conn) - current_version = 11 - if current_version < 12: - _migrate_to_v12(conn) - current_version = 12 - if current_version < 13: - _migrate_to_v13(conn) - current_version = 13 - if current_version < 14: - _migrate_to_v14(conn) - current_version = 14 - if current_version < 15: - _migrate_to_v15(conn) - current_version = 15 - if current_version < 16: - _migrate_to_v16(conn) - current_version = 16 - if current_version < 17: - _migrate_to_v17(conn) - current_version = 17 - if current_version < 18: - _migrate_to_v18(conn) - current_version = 18 + for version, migrate_fn in _get_migration_steps(): + if version > target_version or current_version >= version: + continue + migrate_fn(conn) + current_version = version - if current_version != int(target_version): - set_current_version(conn, int(target_version)) + if current_version != target_version: + set_current_version(conn, target_version) def _migrate_to_v1(conn): """迁移到版本1 - 添加缺失字段""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(system_config)") - columns = [col[1] for col in cursor.fetchall()] + system_columns = _get_table_columns(cursor, "system_config") + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "schedule_weekdays", + 'TEXT DEFAULT "1,2,3,4,5,6,7"', + ok_message=" [OK] 添加 schedule_weekdays 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "max_screenshot_concurrent", + "INTEGER DEFAULT 3", + ok_message=" [OK] 添加 max_screenshot_concurrent 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "max_concurrent_per_account", + "INTEGER DEFAULT 1", + ok_message=" [OK] 添加 max_concurrent_per_account 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "auto_approve_enabled", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 auto_approve_enabled 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "auto_approve_hourly_limit", + "INTEGER DEFAULT 10", + ok_message=" [OK] 添加 auto_approve_hourly_limit 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + system_columns, + "auto_approve_vip_days", + "INTEGER DEFAULT 7", + ok_message=" [OK] 添加 auto_approve_vip_days 字段", + ) - if "schedule_weekdays" not in columns: - cursor.execute('ALTER TABLE system_config ADD COLUMN schedule_weekdays TEXT DEFAULT "1,2,3,4,5,6,7"') - print(" [OK] 添加 schedule_weekdays 字段") - - if "max_screenshot_concurrent" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN max_screenshot_concurrent INTEGER DEFAULT 3") - print(" [OK] 添加 max_screenshot_concurrent 字段") - if "max_concurrent_per_account" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN max_concurrent_per_account INTEGER DEFAULT 1") - print(" [OK] 添加 max_concurrent_per_account 字段") - if "auto_approve_enabled" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_enabled INTEGER DEFAULT 0") - print(" [OK] 添加 auto_approve_enabled 字段") - if "auto_approve_hourly_limit" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_hourly_limit INTEGER DEFAULT 10") - print(" [OK] 添加 auto_approve_hourly_limit 字段") - if "auto_approve_vip_days" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN auto_approve_vip_days INTEGER DEFAULT 7") - print(" [OK] 添加 auto_approve_vip_days 字段") - - cursor.execute("PRAGMA table_info(task_logs)") - columns = [col[1] for col in cursor.fetchall()] - if "duration" not in columns: - cursor.execute("ALTER TABLE task_logs ADD COLUMN duration INTEGER") - print(" [OK] 添加 duration 字段到 task_logs") + task_log_columns = _get_table_columns(cursor, "task_logs") + _add_column_if_missing( + cursor, + "task_logs", + task_log_columns, + "duration", + "INTEGER", + ok_message=" [OK] 添加 duration 字段到 task_logs", + ) conn.commit() @@ -135,24 +166,39 @@ def _migrate_to_v2(conn): """迁移到版本2 - 添加代理配置字段""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(system_config)") - columns = [col[1] for col in cursor.fetchall()] - - if "proxy_enabled" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_enabled INTEGER DEFAULT 0") - print(" [OK] 添加 proxy_enabled 字段") - - if "proxy_api_url" not in columns: - cursor.execute('ALTER TABLE system_config ADD COLUMN proxy_api_url TEXT DEFAULT ""') - print(" [OK] 添加 proxy_api_url 字段") - - if "proxy_expire_minutes" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN proxy_expire_minutes INTEGER DEFAULT 3") - print(" [OK] 添加 proxy_expire_minutes 字段") - - if "enable_screenshot" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN enable_screenshot INTEGER DEFAULT 1") - print(" [OK] 添加 enable_screenshot 字段") + columns = _get_table_columns(cursor, "system_config") + _add_column_if_missing( + cursor, + "system_config", + columns, + "proxy_enabled", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 proxy_enabled 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + columns, + "proxy_api_url", + 'TEXT DEFAULT ""', + ok_message=" [OK] 添加 proxy_api_url 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + columns, + "proxy_expire_minutes", + "INTEGER DEFAULT 3", + ok_message=" [OK] 添加 proxy_expire_minutes 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + columns, + "enable_screenshot", + "INTEGER DEFAULT 1", + ok_message=" [OK] 添加 enable_screenshot 字段", + ) conn.commit() @@ -161,20 +207,31 @@ def _migrate_to_v3(conn): """迁移到版本3 - 添加账号状态和登录失败计数字段""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(accounts)") - columns = [col[1] for col in cursor.fetchall()] - - if "status" not in columns: - cursor.execute('ALTER TABLE accounts ADD COLUMN status TEXT DEFAULT "active"') - print(" [OK] 添加 accounts.status 字段 (账号状态)") - - if "login_fail_count" not in columns: - cursor.execute("ALTER TABLE accounts ADD COLUMN login_fail_count INTEGER DEFAULT 0") - print(" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)") - - if "last_login_error" not in columns: - cursor.execute("ALTER TABLE accounts ADD COLUMN last_login_error TEXT") - print(" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)") + columns = _get_table_columns(cursor, "accounts") + _add_column_if_missing( + cursor, + "accounts", + columns, + "status", + 'TEXT DEFAULT "active"', + ok_message=" [OK] 添加 accounts.status 字段 (账号状态)", + ) + _add_column_if_missing( + cursor, + "accounts", + columns, + "login_fail_count", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 accounts.login_fail_count 字段 (登录失败计数)", + ) + _add_column_if_missing( + cursor, + "accounts", + columns, + "last_login_error", + "TEXT", + ok_message=" [OK] 添加 accounts.last_login_error 字段 (最后登录错误)", + ) conn.commit() @@ -183,12 +240,15 @@ def _migrate_to_v4(conn): """迁移到版本4 - 添加任务来源字段""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(task_logs)") - columns = [col[1] for col in cursor.fetchall()] - - if "source" not in columns: - cursor.execute('ALTER TABLE task_logs ADD COLUMN source TEXT DEFAULT "manual"') - print(" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)") + columns = _get_table_columns(cursor, "task_logs") + _add_column_if_missing( + cursor, + "task_logs", + columns, + "source", + 'TEXT DEFAULT "manual"', + ok_message=" [OK] 添加 task_logs.source 字段 (任务来源: manual/scheduled/immediate)", + ) conn.commit() @@ -300,20 +360,17 @@ def _migrate_to_v6(conn): def _migrate_to_v7(conn): """迁移到版本7 - 统一存储北京时间(将历史UTC时间字段整体+8小时)""" cursor = conn.cursor() - - def table_exists(table_name: str) -> bool: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - return cursor.fetchone() is not None - - def column_exists(table_name: str, column_name: str) -> bool: - cursor.execute(f"PRAGMA table_info({table_name})") - return any(row[1] == column_name for row in cursor.fetchall()) + columns_cache: dict[str, set[str]] = {} def shift_utc_to_cst(table_name: str, column_name: str) -> None: - if not table_exists(table_name): + if not _table_exists(cursor, table_name): return - if not column_exists(table_name, column_name): + + if table_name not in columns_cache: + columns_cache[table_name] = _get_table_columns(cursor, table_name) + if column_name not in columns_cache[table_name]: return + cursor.execute( f""" UPDATE {table_name} @@ -329,10 +386,6 @@ def _migrate_to_v7(conn): ("accounts", "created_at"), ("password_reset_requests", "created_at"), ("password_reset_requests", "processed_at"), - ]: - shift_utc_to_cst(table, col) - - for table, col in [ ("smtp_configs", "created_at"), ("smtp_configs", "updated_at"), ("smtp_configs", "last_success_at"), @@ -340,10 +393,6 @@ def _migrate_to_v7(conn): ("email_tokens", "created_at"), ("email_logs", "created_at"), ("email_stats", "last_updated"), - ]: - shift_utc_to_cst(table, col) - - for table, col in [ ("task_checkpoints", "created_at"), ("task_checkpoints", "updated_at"), ("task_checkpoints", "completed_at"), @@ -359,15 +408,23 @@ def _migrate_to_v8(conn): cursor = conn.cursor() # 1) 增量字段:random_delay(旧库可能不存在) - cursor.execute("PRAGMA table_info(user_schedules)") - columns = [col[1] for col in cursor.fetchall()] - if "random_delay" not in columns: - cursor.execute("ALTER TABLE user_schedules ADD COLUMN random_delay INTEGER DEFAULT 0") - print(" [OK] 添加 user_schedules.random_delay 字段") - - if "next_run_at" not in columns: - cursor.execute("ALTER TABLE user_schedules ADD COLUMN next_run_at TIMESTAMP") - print(" [OK] 添加 user_schedules.next_run_at 字段") + columns = _get_table_columns(cursor, "user_schedules") + _add_column_if_missing( + cursor, + "user_schedules", + columns, + "random_delay", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 user_schedules.random_delay 字段", + ) + _add_column_if_missing( + cursor, + "user_schedules", + columns, + "next_run_at", + "TIMESTAMP", + ok_message=" [OK] 添加 user_schedules.next_run_at 字段", + ) cursor.execute("CREATE INDEX IF NOT EXISTS idx_user_schedules_next_run ON user_schedules(next_run_at)") conn.commit() @@ -392,12 +449,12 @@ def _migrate_to_v8(conn): fixed = 0 for row in rows: try: - schedule_id = row["id"] if isinstance(row, sqlite3.Row) else row[0] - schedule_time = row["schedule_time"] if isinstance(row, sqlite3.Row) else row[1] - weekdays = row["weekdays"] if isinstance(row, sqlite3.Row) else row[2] - random_delay = row["random_delay"] if isinstance(row, sqlite3.Row) else row[3] - last_run_at = row["last_run_at"] if isinstance(row, sqlite3.Row) else row[4] - next_run_at = row["next_run_at"] if isinstance(row, sqlite3.Row) else row[5] + schedule_id = _read_row_value(row, "id", 0) + schedule_time = _read_row_value(row, "schedule_time", 1) + weekdays = _read_row_value(row, "weekdays", 2) + random_delay = _read_row_value(row, "random_delay", 3) + last_run_at = _read_row_value(row, "last_run_at", 4) + next_run_at = _read_row_value(row, "next_run_at", 5) except Exception: continue @@ -430,27 +487,46 @@ def _migrate_to_v9(conn): """迁移到版本9 - 邮件设置字段迁移(清理 email_service scattered ALTER TABLE)""" cursor = conn.cursor() - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'") - if not cursor.fetchone(): + if not _table_exists(cursor, "email_settings"): # 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移 return - cursor.execute("PRAGMA table_info(email_settings)") - columns = [col[1] for col in cursor.fetchall()] + columns = _get_table_columns(cursor, "email_settings") changed = False - if "register_verify_enabled" not in columns: - cursor.execute("ALTER TABLE email_settings ADD COLUMN register_verify_enabled INTEGER DEFAULT 0") - print(" [OK] 添加 email_settings.register_verify_enabled 字段") - changed = True - if "base_url" not in columns: - cursor.execute("ALTER TABLE email_settings ADD COLUMN base_url TEXT DEFAULT ''") - print(" [OK] 添加 email_settings.base_url 字段") - changed = True - if "task_notify_enabled" not in columns: - cursor.execute("ALTER TABLE email_settings ADD COLUMN task_notify_enabled INTEGER DEFAULT 0") - print(" [OK] 添加 email_settings.task_notify_enabled 字段") - changed = True + changed = ( + _add_column_if_missing( + cursor, + "email_settings", + columns, + "register_verify_enabled", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 email_settings.register_verify_enabled 字段", + ) + or changed + ) + changed = ( + _add_column_if_missing( + cursor, + "email_settings", + columns, + "base_url", + "TEXT DEFAULT ''", + ok_message=" [OK] 添加 email_settings.base_url 字段", + ) + or changed + ) + changed = ( + _add_column_if_missing( + cursor, + "email_settings", + columns, + "task_notify_enabled", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 email_settings.task_notify_enabled 字段", + ) + or changed + ) if changed: conn.commit() @@ -459,18 +535,31 @@ def _migrate_to_v9(conn): def _migrate_to_v10(conn): """迁移到版本10 - users 邮箱字段迁移(避免运行时 ALTER TABLE)""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(users)") - columns = [col[1] for col in cursor.fetchall()] + columns = _get_table_columns(cursor, "users") changed = False - if "email_verified" not in columns: - cursor.execute("ALTER TABLE users ADD COLUMN email_verified INTEGER DEFAULT 0") - print(" [OK] 添加 users.email_verified 字段") - changed = True - if "email_notify_enabled" not in columns: - cursor.execute("ALTER TABLE users ADD COLUMN email_notify_enabled INTEGER DEFAULT 1") - print(" [OK] 添加 users.email_notify_enabled 字段") - changed = True + changed = ( + _add_column_if_missing( + cursor, + "users", + columns, + "email_verified", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 users.email_verified 字段", + ) + or changed + ) + changed = ( + _add_column_if_missing( + cursor, + "users", + columns, + "email_notify_enabled", + "INTEGER DEFAULT 1", + ok_message=" [OK] 添加 users.email_notify_enabled 字段", + ) + or changed + ) if changed: conn.commit() @@ -657,19 +746,24 @@ def _migrate_to_v15(conn): """迁移到版本15 - 邮件设置:新设备登录提醒全局开关""" cursor = conn.cursor() - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='email_settings'") - if not cursor.fetchone(): + if not _table_exists(cursor, "email_settings"): # 邮件表由 email_service.init_email_tables 创建;此处仅做增量字段迁移 return - cursor.execute("PRAGMA table_info(email_settings)") - columns = [col[1] for col in cursor.fetchall()] + columns = _get_table_columns(cursor, "email_settings") changed = False - if "login_alert_enabled" not in columns: - cursor.execute("ALTER TABLE email_settings ADD COLUMN login_alert_enabled INTEGER DEFAULT 1") - print(" [OK] 添加 email_settings.login_alert_enabled 字段") - changed = True + changed = ( + _add_column_if_missing( + cursor, + "email_settings", + columns, + "login_alert_enabled", + "INTEGER DEFAULT 1", + ok_message=" [OK] 添加 email_settings.login_alert_enabled 字段", + ) + or changed + ) try: cursor.execute("UPDATE email_settings SET login_alert_enabled = 1 WHERE login_alert_enabled IS NULL") @@ -686,22 +780,24 @@ def _migrate_to_v15(conn): def _migrate_to_v16(conn): """迁移到版本16 - 公告支持图片字段""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(announcements)") - columns = [col[1] for col in cursor.fetchall()] + columns = _get_table_columns(cursor, "announcements") - if "image_url" not in columns: - cursor.execute("ALTER TABLE announcements ADD COLUMN image_url TEXT") + if _add_column_if_missing( + cursor, + "announcements", + columns, + "image_url", + "TEXT", + ok_message=" [OK] 添加 announcements.image_url 字段", + ): conn.commit() - print(" [OK] 添加 announcements.image_url 字段") def _migrate_to_v17(conn): """迁移到版本17 - 金山文档上传配置与用户开关""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(system_config)") - columns = [col[1] for col in cursor.fetchall()] - + system_columns = _get_table_columns(cursor, "system_config") system_fields = [ ("kdocs_enabled", "INTEGER DEFAULT 0"), ("kdocs_doc_url", "TEXT DEFAULT ''"), @@ -714,21 +810,29 @@ def _migrate_to_v17(conn): ("kdocs_admin_notify_email", "TEXT DEFAULT ''"), ] for field, ddl in system_fields: - if field not in columns: - cursor.execute(f"ALTER TABLE system_config ADD COLUMN {field} {ddl}") - print(f" [OK] 添加 system_config.{field} 字段") - - cursor.execute("PRAGMA table_info(users)") - columns = [col[1] for col in cursor.fetchall()] + _add_column_if_missing( + cursor, + "system_config", + system_columns, + field, + ddl, + ok_message=f" [OK] 添加 system_config.{field} 字段", + ) + user_columns = _get_table_columns(cursor, "users") user_fields = [ ("kdocs_unit", "TEXT DEFAULT ''"), ("kdocs_auto_upload", "INTEGER DEFAULT 0"), ] for field, ddl in user_fields: - if field not in columns: - cursor.execute(f"ALTER TABLE users ADD COLUMN {field} {ddl}") - print(f" [OK] 添加 users.{field} 字段") + _add_column_if_missing( + cursor, + "users", + user_columns, + field, + ddl, + ok_message=f" [OK] 添加 users.{field} 字段", + ) conn.commit() @@ -737,15 +841,22 @@ def _migrate_to_v18(conn): """迁移到版本18 - 金山文档上传:有效行范围配置""" cursor = conn.cursor() - cursor.execute("PRAGMA table_info(system_config)") - columns = [col[1] for col in cursor.fetchall()] - - if "kdocs_row_start" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_start INTEGER DEFAULT 0") - print(" [OK] 添加 system_config.kdocs_row_start 字段") - - if "kdocs_row_end" not in columns: - cursor.execute("ALTER TABLE system_config ADD COLUMN kdocs_row_end INTEGER DEFAULT 0") - print(" [OK] 添加 system_config.kdocs_row_end 字段") + columns = _get_table_columns(cursor, "system_config") + _add_column_if_missing( + cursor, + "system_config", + columns, + "kdocs_row_start", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 system_config.kdocs_row_start 字段", + ) + _add_column_if_missing( + cursor, + "system_config", + columns, + "kdocs_row_end", + "INTEGER DEFAULT 0", + ok_message=" [OK] 添加 system_config.kdocs_row_end 字段", + ) conn.commit() diff --git a/db/schedules.py b/db/schedules.py index 4294169..d769804 100644 --- a/db/schedules.py +++ b/db/schedules.py @@ -2,12 +2,93 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from datetime import datetime +import json +from datetime import datetime, timedelta import db_pool from services.schedule_utils import compute_next_run_at, format_cst from services.time_utils import get_beijing_now +_SCHEDULE_DEFAULT_TIME = "08:00" +_SCHEDULE_DEFAULT_WEEKDAYS = "1,2,3,4,5" + +_ALLOWED_SCHEDULE_UPDATE_FIELDS = ( + "name", + "enabled", + "schedule_time", + "weekdays", + "browse_type", + "enable_screenshot", + "random_delay", + "account_ids", +) + +_ALLOWED_EXEC_LOG_UPDATE_FIELDS = ( + "total_accounts", + "success_accounts", + "failed_accounts", + "total_items", + "total_attachments", + "total_screenshots", + "duration_seconds", + "status", + "error_message", +) + + +def _normalize_limit(limit, default: int, *, minimum: int = 1) -> int: + try: + parsed = int(limit) + except Exception: + parsed = default + if parsed < minimum: + return minimum + return parsed + + +def _to_int(value, default: int = 0) -> int: + try: + return int(value) + except Exception: + return default + + +def _format_optional_datetime(dt: datetime | None) -> str | None: + if dt is None: + return None + return format_cst(dt) + + +def _serialize_account_ids(account_ids) -> str: + return json.dumps(account_ids) if account_ids else "[]" + + +def _compute_schedule_next_run_str( + *, + now_dt, + schedule_time, + weekdays, + random_delay, + last_run_at, +) -> str: + next_dt = compute_next_run_at( + now=now_dt, + schedule_time=str(schedule_time or _SCHEDULE_DEFAULT_TIME), + weekdays=str(weekdays or _SCHEDULE_DEFAULT_WEEKDAYS), + random_delay=_to_int(random_delay, 0), + last_run_at=str(last_run_at or "") if last_run_at else None, + ) + return format_cst(next_dt) + + +def _map_schedule_log_row(row) -> dict: + log = dict(row) + log["created_at"] = log.get("execute_time") + log["success_count"] = log.get("success_accounts", 0) + log["failed_count"] = log.get("failed_accounts", 0) + log["duration"] = log.get("duration_seconds", 0) + return log + def get_user_schedules(user_id): """获取用户的所有定时任务""" @@ -44,14 +125,10 @@ def create_user_schedule( account_ids=None, ): """创建用户定时任务""" - import json - with db_pool.get_db() as conn: cursor = conn.cursor() cst_time = format_cst(get_beijing_now()) - account_ids_str = json.dumps(account_ids) if account_ids else "[]" - cursor.execute( """ INSERT INTO user_schedules ( @@ -66,8 +143,8 @@ def create_user_schedule( weekdays, browse_type, enable_screenshot, - int(random_delay or 0), - account_ids_str, + _to_int(random_delay, 0), + _serialize_account_ids(account_ids), cst_time, cst_time, ), @@ -79,28 +156,11 @@ def create_user_schedule( def update_user_schedule(schedule_id, **kwargs): """更新用户定时任务""" - import json - with db_pool.get_db() as conn: cursor = conn.cursor() now_dt = get_beijing_now() now_str = format_cst(now_dt) - updates = [] - params = [] - - allowed_fields = [ - "name", - "enabled", - "schedule_time", - "weekdays", - "browse_type", - "enable_screenshot", - "random_delay", - "account_ids", - ] - - # 读取旧值,用于决定是否需要重算 next_run_at cursor.execute( """ SELECT enabled, schedule_time, weekdays, random_delay, last_run_at @@ -112,10 +172,11 @@ def update_user_schedule(schedule_id, **kwargs): current = cursor.fetchone() if not current: return False - current_enabled = int(current[0] or 0) + + current_enabled = _to_int(current[0], 0) current_time = current[1] current_weekdays = current[2] - current_random_delay = int(current[3] or 0) + current_random_delay = _to_int(current[3], 0) current_last_run_at = current[4] will_enabled = current_enabled @@ -123,21 +184,28 @@ def update_user_schedule(schedule_id, **kwargs): next_weekdays = current_weekdays next_random_delay = current_random_delay - for field in allowed_fields: - if field in kwargs: - value = kwargs[field] - if field == "account_ids" and isinstance(value, list): - value = json.dumps(value) - if field == "enabled": - will_enabled = 1 if value else 0 - if field == "schedule_time": - next_time = value - if field == "weekdays": - next_weekdays = value - if field == "random_delay": - next_random_delay = int(value or 0) - updates.append(f"{field} = ?") - params.append(value) + updates = [] + params = [] + + for field in _ALLOWED_SCHEDULE_UPDATE_FIELDS: + if field not in kwargs: + continue + + value = kwargs[field] + if field == "account_ids" and isinstance(value, list): + value = json.dumps(value) + + if field == "enabled": + will_enabled = 1 if value else 0 + if field == "schedule_time": + next_time = value + if field == "weekdays": + next_weekdays = value + if field == "random_delay": + next_random_delay = int(value or 0) + + updates.append(f"{field} = ?") + params.append(value) if not updates: return False @@ -145,30 +213,26 @@ def update_user_schedule(schedule_id, **kwargs): updates.append("updated_at = ?") params.append(now_str) - # 关键字段变更后重算 next_run_at,确保索引驱动不会跑偏 - # - # 需求:当用户修改“执行时间/执行日期/随机±15分钟”后,即使今天已经执行过,也允许按新配置在今天再次触发。 - # 做法:这些关键字段发生变更时,重算 next_run_at 时忽略 last_run_at 的“同日仅一次”限制。 - config_changed = any(key in kwargs for key in ["schedule_time", "weekdays", "random_delay"]) + config_changed = any(key in kwargs for key in ("schedule_time", "weekdays", "random_delay")) enabled_toggled = "enabled" in kwargs should_recompute_next = config_changed or (enabled_toggled and will_enabled == 1) + if should_recompute_next: - next_dt = compute_next_run_at( - now=now_dt, - schedule_time=str(next_time or "08:00"), - weekdays=str(next_weekdays or "1,2,3,4,5"), - random_delay=int(next_random_delay or 0), - last_run_at=None if config_changed else (str(current_last_run_at or "") if current_last_run_at else None), + next_run_at = _compute_schedule_next_run_str( + now_dt=now_dt, + schedule_time=next_time, + weekdays=next_weekdays, + random_delay=next_random_delay, + last_run_at=None if config_changed else current_last_run_at, ) updates.append("next_run_at = ?") - params.append(format_cst(next_dt)) + params.append(next_run_at) - # 若本次显式禁用任务,则 next_run_at 清空(与 toggle 行为保持一致) if enabled_toggled and will_enabled == 0: updates.append("next_run_at = ?") params.append(None) - params.append(schedule_id) + params.append(schedule_id) sql = f"UPDATE user_schedules SET {', '.join(updates)} WHERE id = ?" cursor.execute(sql, params) conn.commit() @@ -203,28 +267,19 @@ def toggle_user_schedule(schedule_id, enabled): ) row = cursor.fetchone() if row: - schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = ( - row[0], - row[1], - row[2], - row[3], - row[4], - ) - + schedule_time, weekdays, random_delay, last_run_at, existing_next_run_at = row existing_next_run_at = str(existing_next_run_at or "").strip() or None - # 若 next_run_at 已经被“修改配置”逻辑预先计算好且仍在未来,则优先沿用, - # 避免 last_run_at 的“同日仅一次”限制阻塞用户把任务调整到今天再次触发。 + if existing_next_run_at and existing_next_run_at > now_str: next_run_at = existing_next_run_at else: - next_dt = compute_next_run_at( - now=now_dt, - schedule_time=str(schedule_time or "08:00"), - weekdays=str(weekdays or "1,2,3,4,5"), - random_delay=int(random_delay or 0), - last_run_at=str(last_run_at or "") if last_run_at else None, + next_run_at = _compute_schedule_next_run_str( + now_dt=now_dt, + schedule_time=schedule_time, + weekdays=weekdays, + random_delay=random_delay, + last_run_at=last_run_at, ) - next_run_at = format_cst(next_dt) cursor.execute( """ @@ -272,16 +327,15 @@ def update_schedule_last_run(schedule_id): row = cursor.fetchone() if not row: return False - schedule_time, weekdays, random_delay = row[0], row[1], row[2] - next_dt = compute_next_run_at( - now=now_dt, - schedule_time=str(schedule_time or "08:00"), - weekdays=str(weekdays or "1,2,3,4,5"), - random_delay=int(random_delay or 0), + schedule_time, weekdays, random_delay = row + next_run_at = _compute_schedule_next_run_str( + now_dt=now_dt, + schedule_time=schedule_time, + weekdays=weekdays, + random_delay=random_delay, last_run_at=now_str, ) - next_run_at = format_cst(next_dt) cursor.execute( """ @@ -305,7 +359,11 @@ def update_schedule_next_run(schedule_id: int, next_run_at: str) -> bool: SET next_run_at = ?, updated_at = ? WHERE id = ? """, - (str(next_run_at or "").strip() or None, format_cst(get_beijing_now()), int(schedule_id)), + ( + str(next_run_at or "").strip() or None, + format_cst(get_beijing_now()), + int(schedule_id), + ), ) conn.commit() return cursor.rowcount > 0 @@ -328,15 +386,15 @@ def recompute_schedule_next_run(schedule_id: int, *, now_dt=None) -> bool: if not row: return False - schedule_time, weekdays, random_delay, last_run_at = row[0], row[1], row[2], row[3] - next_dt = compute_next_run_at( - now=now_dt, - schedule_time=str(schedule_time or "08:00"), - weekdays=str(weekdays or "1,2,3,4,5"), - random_delay=int(random_delay or 0), - last_run_at=str(last_run_at or "") if last_run_at else None, + schedule_time, weekdays, random_delay, last_run_at = row + next_run_at = _compute_schedule_next_run_str( + now_dt=now_dt, + schedule_time=schedule_time, + weekdays=weekdays, + random_delay=random_delay, + last_run_at=last_run_at, ) - return update_schedule_next_run(int(schedule_id), format_cst(next_dt)) + return update_schedule_next_run(int(schedule_id), next_run_at) def get_due_user_schedules(now_cst: str, limit: int = 50): @@ -345,6 +403,8 @@ def get_due_user_schedules(now_cst: str, limit: int = 50): if not now_cst: now_cst = format_cst(get_beijing_now()) + safe_limit = _normalize_limit(limit, 50, minimum=1) + with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( @@ -358,7 +418,7 @@ def get_due_user_schedules(now_cst: str, limit: int = 50): ORDER BY us.next_run_at ASC LIMIT ? """, - (now_cst, int(limit)), + (now_cst, safe_limit), ) return [dict(row) for row in cursor.fetchall()] @@ -370,15 +430,13 @@ def create_schedule_execution_log(schedule_id, user_id, schedule_name): """创建定时任务执行日志""" with db_pool.get_db() as conn: cursor = conn.cursor() - execute_time = format_cst(get_beijing_now()) - cursor.execute( """ INSERT INTO schedule_execution_logs ( schedule_id, user_id, schedule_name, execute_time, status ) VALUES (?, ?, ?, ?, 'running') """, - (schedule_id, user_id, schedule_name, execute_time), + (schedule_id, user_id, schedule_name, format_cst(get_beijing_now())), ) conn.commit() @@ -393,22 +451,11 @@ def update_schedule_execution_log(log_id, **kwargs): updates = [] params = [] - allowed_fields = [ - "total_accounts", - "success_accounts", - "failed_accounts", - "total_items", - "total_attachments", - "total_screenshots", - "duration_seconds", - "status", - "error_message", - ] - - for field in allowed_fields: - if field in kwargs: - updates.append(f"{field} = ?") - params.append(kwargs[field]) + for field in _ALLOWED_EXEC_LOG_UPDATE_FIELDS: + if field not in kwargs: + continue + updates.append(f"{field} = ?") + params.append(kwargs[field]) if not updates: return False @@ -424,6 +471,7 @@ def update_schedule_execution_log(log_id, **kwargs): def get_schedule_execution_logs(schedule_id, limit=10): """获取定时任务执行日志""" try: + safe_limit = _normalize_limit(limit, 10, minimum=1) with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( @@ -433,24 +481,16 @@ def get_schedule_execution_logs(schedule_id, limit=10): ORDER BY execute_time DESC LIMIT ? """, - (schedule_id, limit), + (schedule_id, safe_limit), ) logs = [] - rows = cursor.fetchall() - - for row in rows: + for row in cursor.fetchall(): try: - log = dict(row) - log["created_at"] = log.get("execute_time") - log["success_count"] = log.get("success_accounts", 0) - log["failed_count"] = log.get("failed_accounts", 0) - log["duration"] = log.get("duration_seconds", 0) - logs.append(log) + logs.append(_map_schedule_log_row(row)) except Exception as e: print(f"[数据库] 处理日志行时出错: {e}") continue - return logs except Exception as e: print(f"[数据库] 查询定时任务日志时出错: {e}") @@ -462,6 +502,7 @@ def get_schedule_execution_logs(schedule_id, limit=10): def get_user_all_schedule_logs(user_id, limit=50): """获取用户所有定时任务的执行日志""" + safe_limit = _normalize_limit(limit, 50, minimum=1) with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( @@ -471,7 +512,7 @@ def get_user_all_schedule_logs(user_id, limit=50): ORDER BY execute_time DESC LIMIT ? """, - (user_id, limit), + (user_id, safe_limit), ) return [dict(row) for row in cursor.fetchall()] @@ -493,14 +534,21 @@ def delete_schedule_logs(schedule_id, user_id): def clean_old_schedule_logs(days=30): """清理指定天数前的定时任务执行日志""" + safe_days = _to_int(days, 30) + if safe_days < 0: + safe_days = 0 + + cutoff_dt = get_beijing_now() - timedelta(days=safe_days) + cutoff_str = format_cst(cutoff_dt) + with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute( """ DELETE FROM schedule_execution_logs - WHERE execute_time < datetime('now', 'localtime', '-' || ? || ' days') + WHERE execute_time < ? """, - (days,), + (cutoff_str,), ) conn.commit() return cursor.rowcount diff --git a/db/schema.py b/db/schema.py index 59108a7..14ec136 100644 --- a/db/schema.py +++ b/db/schema.py @@ -362,6 +362,8 @@ def ensure_schema(conn) -> None: cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_status ON users(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_vip_expire ON users(vip_expire_time)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_fingerprints_user ON login_fingerprints(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_login_ips_user ON login_ips(user_id)") @@ -391,6 +393,8 @@ def ensure_schema(conn) -> None: cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_id ON task_logs(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_status ON task_logs(status)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_created_at ON task_logs(created_at)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source ON task_logs(source)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_source_created_at ON task_logs(source, created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_task_logs_user_date ON task_logs(user_id, created_at)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_bug_feedbacks_user_id ON bug_feedbacks(user_id)") @@ -409,6 +413,9 @@ def ensure_schema(conn) -> None: cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_id ON schedule_execution_logs(schedule_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_id ON schedule_execution_logs(user_id)") cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_status ON schedule_execution_logs(status)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_execute_time ON schedule_execution_logs(execute_time)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_schedule_time ON schedule_execution_logs(schedule_id, execute_time)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_schedule_execution_logs_user_time ON schedule_execution_logs(user_id, execute_time)") # 初始化VIP配置(幂等) try: diff --git a/db/security.py b/db/security.py index 79ad0f3..677627f 100644 --- a/db/security.py +++ b/db/security.py @@ -3,13 +3,82 @@ from __future__ import annotations from datetime import timedelta -from typing import Any, Optional -from typing import Dict +from typing import Any, Dict, Optional import db_pool from db.utils import get_cst_now, get_cst_now_str +_THREAT_EVENT_SELECT_COLUMNS = """ + id, + threat_type, + score, + rule, + field_name, + matched, + value_preview, + ip, + user_id, + request_method, + request_path, + user_agent, + created_at +""" + + +def _normalize_page(page: int) -> int: + try: + page_i = int(page) + except Exception: + page_i = 1 + return max(1, page_i) + + +def _normalize_per_page(per_page: int, default: int = 20) -> int: + try: + value = int(per_page) + except Exception: + value = default + return max(1, min(200, value)) + + +def _normalize_limit(limit: int, default: int = 50) -> int: + try: + value = int(limit) + except Exception: + value = default + return max(1, min(200, value)) + + +def _row_value(row, key: str, index: int = 0, default=None): + if row is None: + return default + try: + return row[key] + except Exception: + try: + return row[index] + except Exception: + return default + + +def _fetch_threat_events_history(where_clause: str, params: tuple[Any, ...], limit_i: int) -> list[dict]: + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + f""" + SELECT + {_THREAT_EVENT_SELECT_COLUMNS} + FROM threat_events + WHERE {where_clause} + ORDER BY created_at DESC, id DESC + LIMIT ? + """, + tuple(params) + (limit_i,), + ) + return [dict(r) for r in cursor.fetchall()] + + def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict[str, bool]: """记录登录环境信息,返回是否新设备/新IP。""" user_id = int(user_id) @@ -36,7 +105,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict SET last_seen = ?, last_ip = ? WHERE id = ? """, - (now_str, ip_text, row["id"] if isinstance(row, dict) else row[0]), + (now_str, ip_text, _row_value(row, "id", 0)), ) else: cursor.execute( @@ -61,7 +130,7 @@ def record_login_context(user_id: int, ip_address: str, user_agent: str) -> Dict SET last_seen = ? WHERE id = ? """, - (now_str, row["id"] if isinstance(row, dict) else row[0]), + (now_str, _row_value(row, "id", 0)), ) else: cursor.execute( @@ -166,15 +235,8 @@ def _build_threat_events_where_clause(filters: Optional[dict]) -> tuple[str, lis def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = None) -> dict: """分页获取威胁事件。""" - try: - page_i = max(1, int(page)) - except Exception: - page_i = 1 - try: - per_page_i = int(per_page) - except Exception: - per_page_i = 20 - per_page_i = max(1, min(200, per_page_i)) + page_i = _normalize_page(page) + per_page_i = _normalize_per_page(per_page, default=20) where_sql, params = _build_threat_events_where_clause(filters) offset = (page_i - 1) * per_page_i @@ -188,19 +250,7 @@ def get_threat_events_list(page: int, per_page: int, filters: Optional[dict] = N cursor.execute( f""" SELECT - id, - threat_type, - score, - rule, - field_name, - matched, - value_preview, - ip, - user_id, - request_method, - request_path, - user_agent, - created_at + {_THREAT_EVENT_SELECT_COLUMNS} FROM threat_events {where_sql} ORDER BY created_at DESC, id DESC @@ -218,75 +268,20 @@ def get_ip_threat_history(ip: str, limit: int = 50) -> list[dict]: ip_text = str(ip or "").strip()[:64] if not ip_text: return [] - try: - limit_i = max(1, min(200, int(limit))) - except Exception: - limit_i = 50 - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute( - """ - SELECT - id, - threat_type, - score, - rule, - field_name, - matched, - value_preview, - ip, - user_id, - request_method, - request_path, - user_agent, - created_at - FROM threat_events - WHERE ip = ? - ORDER BY created_at DESC, id DESC - LIMIT ? - """, - (ip_text, limit_i), - ) - return [dict(r) for r in cursor.fetchall()] + limit_i = _normalize_limit(limit, default=50) + return _fetch_threat_events_history("ip = ?", (ip_text,), limit_i) def get_user_threat_history(user_id: int, limit: int = 50) -> list[dict]: """获取用户的威胁历史(最近limit条)。""" if user_id is None: return [] + try: user_id_int = int(user_id) except Exception: return [] - try: - limit_i = max(1, min(200, int(limit))) - except Exception: - limit_i = 50 - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute( - """ - SELECT - id, - threat_type, - score, - rule, - field_name, - matched, - value_preview, - ip, - user_id, - request_method, - request_path, - user_agent, - created_at - FROM threat_events - WHERE user_id = ? - ORDER BY created_at DESC, id DESC - LIMIT ? - """, - (user_id_int, limit_i), - ) - return [dict(r) for r in cursor.fetchall()] + limit_i = _normalize_limit(limit, default=50) + return _fetch_threat_events_history("user_id = ?", (user_id_int,), limit_i) diff --git a/db/tasks.py b/db/tasks.py index 8b761cf..f28e280 100644 --- a/db/tasks.py +++ b/db/tasks.py @@ -2,12 +2,135 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from datetime import datetime - -import pytz +from datetime import datetime, timedelta import db_pool -from db.utils import sanitize_sql_like_pattern +from db.utils import get_cst_now, get_cst_now_str, sanitize_sql_like_pattern + +_TASK_STATS_SELECT_SQL = """ + SELECT + COUNT(*) as total_tasks, + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks, + SUM(total_items) as total_items, + SUM(total_attachments) as total_attachments + FROM task_logs +""" + +_USER_RUN_STATS_SELECT_SQL = """ + SELECT + SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed, + SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed, + SUM(total_items) as total_items, + SUM(total_attachments) as total_attachments + FROM task_logs +""" + + +def _build_day_bounds(date_filter: str) -> tuple[str | None, str | None]: + """将 YYYY-MM-DD 转换为 [day_start, day_end) 区间。""" + try: + day_start = datetime.strptime(str(date_filter), "%Y-%m-%d") + except Exception: + return None, None + + day_end = day_start + timedelta(days=1) + return day_start.strftime("%Y-%m-%d %H:%M:%S"), day_end.strftime("%Y-%m-%d %H:%M:%S") + + +def _normalize_int(value, default: int, *, minimum: int | None = None) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + if minimum is not None and parsed < minimum: + return minimum + return parsed + + +def _stat_value(row, key: str) -> int: + try: + value = row[key] if row else 0 + except Exception: + value = 0 + return int(value or 0) + + +def _build_task_logs_where_sql( + *, + date_filter=None, + status_filter=None, + source_filter=None, + user_id_filter=None, + account_filter=None, +) -> tuple[str, list]: + where_clauses = ["1=1"] + params = [] + + if date_filter: + day_start, day_end = _build_day_bounds(date_filter) + if day_start and day_end: + where_clauses.append("tl.created_at >= ? AND tl.created_at < ?") + params.extend([day_start, day_end]) + else: + where_clauses.append("date(tl.created_at) = ?") + params.append(date_filter) + + if status_filter: + where_clauses.append("tl.status = ?") + params.append(status_filter) + + if source_filter: + source_filter = str(source_filter or "").strip() + if source_filter == "user_scheduled": + where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'") + params.append("user_scheduled:%") + elif source_filter.endswith("*"): + prefix = source_filter[:-1] + safe_prefix = sanitize_sql_like_pattern(prefix) + where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'") + params.append(f"{safe_prefix}%") + else: + where_clauses.append("tl.source = ?") + params.append(source_filter) + + if user_id_filter: + where_clauses.append("tl.user_id = ?") + params.append(user_id_filter) + + if account_filter: + safe_filter = sanitize_sql_like_pattern(account_filter) + where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'") + params.append(f"%{safe_filter}%") + + return " AND ".join(where_clauses), params + + +def _fetch_task_stats_row(cursor, *, where_clause: str = "", params: tuple | list = ()) -> dict: + sql = _TASK_STATS_SELECT_SQL + if where_clause: + sql = f"{sql}\nWHERE {where_clause}" + cursor.execute(sql, params) + row = cursor.fetchone() + return { + "total_tasks": _stat_value(row, "total_tasks"), + "success_tasks": _stat_value(row, "success_tasks"), + "failed_tasks": _stat_value(row, "failed_tasks"), + "total_items": _stat_value(row, "total_items"), + "total_attachments": _stat_value(row, "total_attachments"), + } + + +def _fetch_user_run_stats_row(cursor, *, where_clause: str, params: tuple | list) -> dict: + sql = f"{_USER_RUN_STATS_SELECT_SQL}\nWHERE {where_clause}" + cursor.execute(sql, params) + row = cursor.fetchone() + return { + "completed": _stat_value(row, "completed"), + "failed": _stat_value(row, "failed"), + "total_items": _stat_value(row, "total_items"), + "total_attachments": _stat_value(row, "total_attachments"), + } def create_task_log( @@ -25,8 +148,6 @@ def create_task_log( """创建任务日志记录""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_tz = pytz.timezone("Asia/Shanghai") - cst_time = datetime.now(cst_tz).strftime("%Y-%m-%d %H:%M:%S") cursor.execute( """ @@ -45,7 +166,7 @@ def create_task_log( total_attachments, error_message, duration, - cst_time, + get_cst_now_str(), source, ), ) @@ -64,54 +185,27 @@ def get_task_logs( account_filter=None, ): """获取任务日志列表(支持分页和多种筛选)""" + limit = _normalize_int(limit, 100, minimum=1) + offset = _normalize_int(offset, 0, minimum=0) + with db_pool.get_db() as conn: cursor = conn.cursor() - where_clauses = ["1=1"] - params = [] - - if date_filter: - where_clauses.append("date(tl.created_at) = ?") - params.append(date_filter) - - if status_filter: - where_clauses.append("tl.status = ?") - params.append(status_filter) - - if source_filter: - source_filter = str(source_filter or "").strip() - # 兼容“虚拟来源”:用于筛选 user_scheduled:batch_xxx 这类动态值 - if source_filter == "user_scheduled": - where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'") - params.append("user_scheduled:%") - elif source_filter.endswith("*"): - prefix = source_filter[:-1] - safe_prefix = sanitize_sql_like_pattern(prefix) - where_clauses.append("tl.source LIKE ? ESCAPE '\\\\'") - params.append(f"{safe_prefix}%") - else: - where_clauses.append("tl.source = ?") - params.append(source_filter) - - if user_id_filter: - where_clauses.append("tl.user_id = ?") - params.append(user_id_filter) - - if account_filter: - safe_filter = sanitize_sql_like_pattern(account_filter) - where_clauses.append("tl.username LIKE ? ESCAPE '\\\\'") - params.append(f"%{safe_filter}%") - - where_sql = " AND ".join(where_clauses) + where_sql, params = _build_task_logs_where_sql( + date_filter=date_filter, + status_filter=status_filter, + source_filter=source_filter, + user_id_filter=user_id_filter, + account_filter=account_filter, + ) count_sql = f""" SELECT COUNT(*) as total FROM task_logs tl - LEFT JOIN users u ON tl.user_id = u.id WHERE {where_sql} """ cursor.execute(count_sql, params) - total = cursor.fetchone()["total"] + total = _stat_value(cursor.fetchone(), "total") data_sql = f""" SELECT @@ -123,9 +217,10 @@ def get_task_logs( ORDER BY tl.created_at DESC LIMIT ? OFFSET ? """ - params.extend([limit, offset]) + data_params = list(params) + data_params.extend([limit, offset]) - cursor.execute(data_sql, params) + cursor.execute(data_sql, data_params) logs = [dict(row) for row in cursor.fetchall()] return {"logs": logs, "total": total} @@ -133,61 +228,39 @@ def get_task_logs( def get_task_stats(date_filter=None): """获取任务统计信息""" + if date_filter is None: + date_filter = get_cst_now().strftime("%Y-%m-%d") + + day_start, day_end = _build_day_bounds(date_filter) + with db_pool.get_db() as conn: cursor = conn.cursor() - cst_tz = pytz.timezone("Asia/Shanghai") - if date_filter is None: - date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d") + if day_start and day_end: + today_stats = _fetch_task_stats_row( + cursor, + where_clause="created_at >= ? AND created_at < ?", + params=(day_start, day_end), + ) + else: + today_stats = _fetch_task_stats_row( + cursor, + where_clause="date(created_at) = ?", + params=(date_filter,), + ) - cursor.execute( - """ - SELECT - COUNT(*) as total_tasks, - SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks, - SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks, - SUM(total_items) as total_items, - SUM(total_attachments) as total_attachments - FROM task_logs - WHERE date(created_at) = ? - """, - (date_filter,), - ) - today_stats = cursor.fetchone() + total_stats = _fetch_task_stats_row(cursor) - cursor.execute( - """ - SELECT - COUNT(*) as total_tasks, - SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as success_tasks, - SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed_tasks, - SUM(total_items) as total_items, - SUM(total_attachments) as total_attachments - FROM task_logs - """ - ) - total_stats = cursor.fetchone() - - return { - "today": { - "total_tasks": today_stats["total_tasks"] or 0, - "success_tasks": today_stats["success_tasks"] or 0, - "failed_tasks": today_stats["failed_tasks"] or 0, - "total_items": today_stats["total_items"] or 0, - "total_attachments": today_stats["total_attachments"] or 0, - }, - "total": { - "total_tasks": total_stats["total_tasks"] or 0, - "success_tasks": total_stats["success_tasks"] or 0, - "failed_tasks": total_stats["failed_tasks"] or 0, - "total_items": total_stats["total_items"] or 0, - "total_attachments": total_stats["total_attachments"] or 0, - }, - } + return {"today": today_stats, "total": total_stats} def delete_old_task_logs(days=30, batch_size=1000): """删除N天前的任务日志(分批删除,避免长时间锁表)""" + days = _normalize_int(days, 30, minimum=0) + batch_size = _normalize_int(batch_size, 1000, minimum=1) + + cutoff = (get_cst_now() - timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S") + total_deleted = 0 while True: with db_pool.get_db() as conn: @@ -197,16 +270,16 @@ def delete_old_task_logs(days=30, batch_size=1000): DELETE FROM task_logs WHERE rowid IN ( SELECT rowid FROM task_logs - WHERE created_at < datetime('now', 'localtime', '-' || ? || ' days') + WHERE created_at < ? LIMIT ? ) """, - (days, batch_size), + (cutoff, batch_size), ) deleted = cursor.rowcount conn.commit() - if deleted == 0: + if deleted <= 0: break total_deleted += deleted @@ -215,31 +288,23 @@ def delete_old_task_logs(days=30, batch_size=1000): def get_user_run_stats(user_id, date_filter=None): """获取用户的运行统计信息""" + if date_filter is None: + date_filter = get_cst_now().strftime("%Y-%m-%d") + + day_start, day_end = _build_day_bounds(date_filter) + with db_pool.get_db() as conn: - cst_tz = pytz.timezone("Asia/Shanghai") cursor = conn.cursor() - if date_filter is None: - date_filter = datetime.now(cst_tz).strftime("%Y-%m-%d") + if day_start and day_end: + return _fetch_user_run_stats_row( + cursor, + where_clause="user_id = ? AND created_at >= ? AND created_at < ?", + params=(user_id, day_start, day_end), + ) - cursor.execute( - """ - SELECT - SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) as completed, - SUM(CASE WHEN status = 'failed' THEN 1 ELSE 0 END) as failed, - SUM(total_items) as total_items, - SUM(total_attachments) as total_attachments - FROM task_logs - WHERE user_id = ? AND date(created_at) = ? - """, - (user_id, date_filter), + return _fetch_user_run_stats_row( + cursor, + where_clause="user_id = ? AND date(created_at) = ?", + params=(user_id, date_filter), ) - - stats = cursor.fetchone() - - return { - "completed": stats["completed"] or 0, - "failed": stats["failed"] or 0, - "total_items": stats["total_items"] or 0, - "total_attachments": stats["total_attachments"] or 0, - } diff --git a/db/users.py b/db/users.py index 42423a5..c5cf5d1 100644 --- a/db/users.py +++ b/db/users.py @@ -16,8 +16,41 @@ from password_utils import ( verify_password_bcrypt, verify_password_sha256, ) + logger = get_logger(__name__) +_CST_TZ = pytz.timezone("Asia/Shanghai") +_PERMANENT_VIP_EXPIRE = "2099-12-31 23:59:59" + + +def _row_to_dict(row): + return dict(row) if row else None + + +def _get_user_by_field(field_name: str, field_value): + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute(f"SELECT * FROM users WHERE {field_name} = ?", (field_value,)) + return _row_to_dict(cursor.fetchone()) + + +def _parse_cst_datetime(datetime_str: str | None): + if not datetime_str: + return None + try: + naive_dt = datetime.strptime(str(datetime_str), "%Y-%m-%d %H:%M:%S") + return _CST_TZ.localize(naive_dt) + except Exception: + return None + + +def _format_vip_expire(days: int, *, base_dt: datetime | None = None) -> str: + if int(days) == 999999: + return _PERMANENT_VIP_EXPIRE + if base_dt is None: + base_dt = datetime.now(_CST_TZ) + return (base_dt + timedelta(days=int(days))).strftime("%Y-%m-%d %H:%M:%S") + def get_vip_config(): """获取VIP配置""" @@ -32,13 +65,12 @@ def set_default_vip_days(days): """设置默认VIP天数""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_time = get_cst_now_str() cursor.execute( """ INSERT OR REPLACE INTO vip_config (id, default_vip_days, updated_at) VALUES (1, ?, ?) """, - (days, cst_time), + (days, get_cst_now_str()), ) conn.commit() return True @@ -47,14 +79,8 @@ def set_default_vip_days(days): def set_user_vip(user_id, days): """设置用户VIP - days: 7=一周, 30=一个月, 365=一年, 999999=永久""" with db_pool.get_db() as conn: - cst_tz = pytz.timezone("Asia/Shanghai") cursor = conn.cursor() - - if days == 999999: - expire_time = "2099-12-31 23:59:59" - else: - expire_time = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S") - + expire_time = _format_vip_expire(days) cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (expire_time, user_id)) conn.commit() return cursor.rowcount > 0 @@ -63,29 +89,26 @@ def set_user_vip(user_id, days): def extend_user_vip(user_id, days): """延长用户VIP时间""" user = get_user_by_id(user_id) - cst_tz = pytz.timezone("Asia/Shanghai") - if not user: return False + current_expire = user.get("vip_expire_time") + now_dt = datetime.now(_CST_TZ) + + if current_expire and current_expire != _PERMANENT_VIP_EXPIRE: + expire_time = _parse_cst_datetime(current_expire) + if expire_time is not None: + if expire_time < now_dt: + expire_time = now_dt + new_expire = _format_vip_expire(days, base_dt=expire_time) + else: + logger.warning("解析VIP过期时间失败,使用当前时间") + new_expire = _format_vip_expire(days, base_dt=now_dt) + else: + new_expire = _format_vip_expire(days, base_dt=now_dt) + with db_pool.get_db() as conn: cursor = conn.cursor() - current_expire = user.get("vip_expire_time") - - if current_expire and current_expire != "2099-12-31 23:59:59": - try: - expire_time_naive = datetime.strptime(current_expire, "%Y-%m-%d %H:%M:%S") - expire_time = cst_tz.localize(expire_time_naive) - now = datetime.now(cst_tz) - if expire_time < now: - expire_time = now - new_expire = (expire_time + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S") - except (ValueError, AttributeError) as e: - logger.warning(f"解析VIP过期时间失败: {e}, 使用当前时间") - new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S") - else: - new_expire = (datetime.now(cst_tz) + timedelta(days=days)).strftime("%Y-%m-%d %H:%M:%S") - cursor.execute("UPDATE users SET vip_expire_time = ? WHERE id = ?", (new_expire, user_id)) conn.commit() return cursor.rowcount > 0 @@ -105,45 +128,49 @@ def is_user_vip(user_id): 注意:数据库中存储的时间统一使用CST(Asia/Shanghai)时区 """ - cst_tz = pytz.timezone("Asia/Shanghai") user = get_user_by_id(user_id) - - if not user or not user.get("vip_expire_time"): + if not user: return False - try: - expire_time_naive = datetime.strptime(user["vip_expire_time"], "%Y-%m-%d %H:%M:%S") - expire_time = cst_tz.localize(expire_time_naive) - now = datetime.now(cst_tz) - return now < expire_time - except (ValueError, AttributeError) as e: - logger.warning(f"检查VIP状态失败 (user_id={user_id}): {e}") + vip_expire_time = user.get("vip_expire_time") + if not vip_expire_time: return False + expire_time = _parse_cst_datetime(vip_expire_time) + if expire_time is None: + logger.warning(f"检查VIP状态失败 (user_id={user_id}): 无法解析时间") + return False + + return datetime.now(_CST_TZ) < expire_time + def get_user_vip_info(user_id): """获取用户VIP信息""" - cst_tz = pytz.timezone("Asia/Shanghai") user = get_user_by_id(user_id) - if not user: return {"is_vip": False, "expire_time": None, "days_left": 0, "username": ""} vip_expire_time = user.get("vip_expire_time") + username = user.get("username", "") + if not vip_expire_time: - return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")} + return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username} - try: - expire_time_naive = datetime.strptime(vip_expire_time, "%Y-%m-%d %H:%M:%S") - expire_time = cst_tz.localize(expire_time_naive) - now = datetime.now(cst_tz) - is_vip = now < expire_time - days_left = (expire_time - now).days if is_vip else 0 + expire_time = _parse_cst_datetime(vip_expire_time) + if expire_time is None: + logger.warning("VIP信息获取错误: 无法解析过期时间") + return {"is_vip": False, "expire_time": None, "days_left": 0, "username": username} - return {"username": user.get("username", ""), "is_vip": is_vip, "expire_time": vip_expire_time, "days_left": max(0, days_left)} - except Exception as e: - logger.warning(f"VIP信息获取错误: {e}") - return {"is_vip": False, "expire_time": None, "days_left": 0, "username": user.get("username", "")} + now_dt = datetime.now(_CST_TZ) + is_vip = now_dt < expire_time + days_left = (expire_time - now_dt).days if is_vip else 0 + + return { + "username": username, + "is_vip": is_vip, + "expire_time": vip_expire_time, + "days_left": max(0, days_left), + } # ==================== 用户相关 ==================== @@ -151,8 +178,6 @@ def get_user_vip_info(user_id): def create_user(username, password, email=""): """创建新用户(默认直接通过,赠送默认VIP)""" - cst_tz = pytz.timezone("Asia/Shanghai") - with db_pool.get_db() as conn: cursor = conn.cursor() password_hash = hash_password_bcrypt(password) @@ -160,12 +185,8 @@ def create_user(username, password, email=""): default_vip_days = get_vip_config()["default_vip_days"] vip_expire_time = None - - if default_vip_days > 0: - if default_vip_days == 999999: - vip_expire_time = "2099-12-31 23:59:59" - else: - vip_expire_time = (datetime.now(cst_tz) + timedelta(days=default_vip_days)).strftime("%Y-%m-%d %H:%M:%S") + if int(default_vip_days or 0) > 0: + vip_expire_time = _format_vip_expire(int(default_vip_days)) try: cursor.execute( @@ -210,28 +231,28 @@ def verify_user(username, password): def get_user_by_id(user_id): """根据ID获取用户""" - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) - user = cursor.fetchone() - return dict(user) if user else None + return _get_user_by_field("id", user_id) def get_user_kdocs_settings(user_id): """获取用户的金山文档配置""" - user = get_user_by_id(user_id) - if not user: - return None - return { - "kdocs_unit": user.get("kdocs_unit") or "", - "kdocs_auto_upload": 1 if user.get("kdocs_auto_upload") else 0, - } + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute("SELECT kdocs_unit, kdocs_auto_upload FROM users WHERE id = ?", (user_id,)) + row = cursor.fetchone() + if not row: + return None + return { + "kdocs_unit": (row["kdocs_unit"] or "") if isinstance(row, dict) else (row[0] or ""), + "kdocs_auto_upload": 1 if ((row["kdocs_auto_upload"] if isinstance(row, dict) else row[1]) or 0) else 0, + } def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=None) -> bool: """更新用户的金山文档配置""" updates = [] params = [] + if kdocs_unit is not None: updates.append("kdocs_unit = ?") params.append(kdocs_unit) @@ -252,11 +273,7 @@ def update_user_kdocs_settings(user_id, *, kdocs_unit=None, kdocs_auto_upload=No def get_user_by_username(username): """根据用户名获取用户""" - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) - user = cursor.fetchone() - return dict(user) if user else None + return _get_user_by_field("username", username) def get_all_users(): @@ -279,14 +296,13 @@ def approve_user(user_id): """审核通过用户""" with db_pool.get_db() as conn: cursor = conn.cursor() - cst_time = get_cst_now_str() cursor.execute( """ UPDATE users SET status = 'approved', approved_at = ? WHERE id = ? """, - (cst_time, user_id), + (get_cst_now_str(), user_id), ) conn.commit() return cursor.rowcount > 0 @@ -315,5 +331,5 @@ def get_user_stats(user_id): with db_pool.get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT COUNT(*) as count FROM accounts WHERE user_id = ?", (user_id,)) - account_count = cursor.fetchone()["count"] - return {"account_count": account_count} + row = cursor.fetchone() + return {"account_count": int((row["count"] if row else 0) or 0)} diff --git a/db_pool.py b/db_pool.py index 3ce43fb..b5f4fe9 100755 --- a/db_pool.py +++ b/db_pool.py @@ -7,8 +7,12 @@ import sqlite3 import threading -from queue import Queue, Empty -import time +from queue import Empty, Full, Queue + +from app_logger import get_logger + + +logger = get_logger("database") class ConnectionPool: @@ -44,12 +48,55 @@ class ConnectionPool: """创建新的数据库连接""" conn = sqlite3.connect(self.database, check_same_thread=False) conn.row_factory = sqlite3.Row + # 启用外键约束,确保 ON DELETE CASCADE 等约束生效 + conn.execute("PRAGMA foreign_keys=ON") # 设置WAL模式提高并发性能 conn.execute("PRAGMA journal_mode=WAL") + # 在WAL模式下使用NORMAL同步,兼顾性能与可靠性 + conn.execute("PRAGMA synchronous=NORMAL") # 设置合理的超时时间 conn.execute("PRAGMA busy_timeout=5000") return conn + def _close_connection(self, conn) -> None: + if conn is None: + return + try: + conn.close() + except Exception as e: + logger.warning(f"关闭连接失败: {e}") + + def _is_connection_healthy(self, conn) -> bool: + if conn is None: + return False + try: + conn.rollback() + conn.execute("SELECT 1") + return True + except sqlite3.Error as e: + logger.warning(f"连接健康检查失败(数据库错误): {e}") + except Exception as e: + logger.warning(f"连接健康检查失败(未知错误): {e}") + return False + + def _replenish_pool_if_needed(self) -> None: + with self._lock: + if self._pool.qsize() >= self.pool_size: + return + + new_conn = None + try: + new_conn = self._create_connection() + self._pool.put(new_conn, block=False) + self._created_connections += 1 + except Full: + if new_conn: + self._close_connection(new_conn) + except Exception as e: + if new_conn: + self._close_connection(new_conn) + logger.warning(f"重建连接失败: {e}") + def get_connection(self): """ 从连接池获取连接 @@ -70,66 +117,20 @@ class ConnectionPool: Args: conn: 要归还的连接 """ - import sqlite3 - from queue import Full - if conn is None: return - connection_healthy = False - try: - # 回滚任何未提交的事务 - conn.rollback() - # 安全修复:验证连接是否健康,防止损坏的连接污染连接池 - conn.execute("SELECT 1") - connection_healthy = True - except sqlite3.Error as e: - # 数据库相关错误,连接可能损坏 - print(f"连接健康检查失败(数据库错误): {e}") - except Exception as e: - print(f"连接健康检查失败(未知错误): {e}") - - if connection_healthy: + if self._is_connection_healthy(conn): try: self._pool.put(conn, block=False) - return # 成功归还 + return except Full: - # 队列已满(不应该发生,但处理它) - print(f"警告: 连接池已满,关闭多余连接") - connection_healthy = False # 标记为需要关闭 + logger.warning("连接池已满,关闭多余连接") + self._close_connection(conn) + return - # 连接不健康或队列已满,关闭它 - try: - conn.close() - except Exception as close_error: - print(f"关闭连接失败: {close_error}") - - # 如果连接不健康,尝试创建新连接补充池 - if not connection_healthy: - with self._lock: - # 双重检查:确保池确实需要补充 - if self._pool.qsize() < self.pool_size: - new_conn = None - try: - new_conn = self._create_connection() - self._pool.put(new_conn, block=False) - # 只有成功放入池后才增加计数 - self._created_connections += 1 - except Full: - # 在获取锁期间池被填满了,关闭新建的连接 - if new_conn: - try: - new_conn.close() - except Exception: - pass - except Exception as create_error: - # 创建连接失败,确保关闭已创建的连接 - if new_conn: - try: - new_conn.close() - except Exception: - pass - print(f"重建连接失败: {create_error}") + self._close_connection(conn) + self._replenish_pool_if_needed() def close_all(self): """关闭所有连接""" @@ -138,7 +139,7 @@ class ConnectionPool: conn = self._pool.get(block=False) conn.close() except Exception as e: - print(f"关闭连接失败: {e}") + logger.warning(f"关闭连接失败: {e}") def get_stats(self): """获取连接池统计信息""" @@ -175,14 +176,14 @@ class PooledConnection: if exc_type is not None: # 发生异常,回滚事务 self._conn.rollback() - print(f"数据库事务已回滚: {exc_type.__name__}") + logger.warning(f"数据库事务已回滚: {exc_type.__name__}") # 注意: 不自动commit,要求用户显式调用conn.commit() if self._cursor: self._cursor.close() self._cursor = None except Exception as e: - print(f"关闭游标失败: {e}") + logger.warning(f"关闭游标失败: {e}") finally: # 归还连接 self._pool.return_connection(self._conn) @@ -254,7 +255,7 @@ def init_pool(database, pool_size=5): with _pool_lock: if _pool is None: _pool = ConnectionPool(database, pool_size) - print(f"[OK] 数据库连接池已初始化 (大小: {pool_size})") + logger.info(f"[OK] 数据库连接池已初始化 (大小: {pool_size})") def get_db(): diff --git a/email_service.py b/email_service.py index c02629c..32fcdb8 100644 --- a/email_service.py +++ b/email_service.py @@ -34,15 +34,14 @@ def get_beijing_now_str(): from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from email.mime.base import MIMEBase -from email.mime.image import MIMEImage from email import encoders from email.header import Header from email.utils import formataddr -from typing import Optional, List, Dict, Any, Callable +from typing import Optional, List, Dict, Any, Callable, Tuple from io import BytesIO import db_pool -from crypto_utils import encrypt_password, decrypt_password, is_encrypted +from crypto_utils import encrypt_password, decrypt_password from app_logger import get_logger logger = get_logger("email_service") @@ -102,6 +101,333 @@ QUEUE_MAX_SIZE = int(os.environ.get('EMAIL_QUEUE_MAX_SIZE', '100')) MAX_ATTACHMENT_SIZE = int(os.environ.get('EMAIL_MAX_ATTACHMENT_SIZE', str(10 * 1024 * 1024))) # 10MB +def _resolve_base_url(base_url: Optional[str] = None) -> str: + """解析系统基础URL,按参数 -> 邮件设置 -> 配置文件 -> 默认值顺序。""" + if base_url: + return str(base_url).strip() or 'http://localhost:51233' + + try: + settings = get_email_settings() + configured_url = (settings or {}).get('base_url', '') + if configured_url: + return configured_url + except Exception: + pass + + try: + from app_config import Config + configured_url = getattr(Config, 'BASE_URL', '') + if configured_url: + return configured_url + except Exception: + pass + + return 'http://localhost:51233' + + +def _load_email_template(template_filename: str, fallback_html: str) -> str: + """读取邮件模板,读取失败时回退到内置模板。""" + template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', template_filename) + try: + with open(template_path, 'r', encoding='utf-8') as f: + return f.read() + except FileNotFoundError: + return fallback_html + + +def _render_template(template: str, values: Dict[str, Any]) -> str: + """替换模板变量,兼容 {{ key }} 和 {{key}} 两种写法。""" + rendered = template + for key, value in values.items(): + value_text = '' if value is None else str(value) + rendered = rendered.replace(f'{{{{ {key} }}}}', value_text) + rendered = rendered.replace(f'{{{{{key}}}}}', value_text) + return rendered + + +def _mark_unused_tokens_as_used(user_id: int, token_type: str) -> None: + """使指定类型的旧 token 失效。""" + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE email_tokens SET used = 1 + WHERE user_id = ? AND token_type = ? AND used = 0 + """, + (user_id, token_type), + ) + conn.commit() + + +def _send_token_email( + *, + email: str, + username: str, + user_id: int, + base_url: Optional[str], + token_type: str, + rate_limit_error: str, + subject: str, + template_filename: str, + fallback_html: str, + url_path: str, + url_template_key: str, + text_template: str, + invalidate_existing_tokens: bool = False, +) -> Dict[str, Any]: + """通用 token 邮件发送流程(注册/重置/绑定)。""" + if not check_rate_limit(email, token_type): + return {'success': False, 'error': rate_limit_error, 'token': None} + + if invalidate_existing_tokens: + _mark_unused_tokens_as_used(user_id, token_type) + + token = generate_email_token(email, token_type, user_id) + resolved_base_url = _resolve_base_url(base_url) + normalized_path = url_path if url_path.startswith('/') else f'/{url_path}' + action_url = f"{resolved_base_url.rstrip('/')}{normalized_path.format(token=token)}" + + html_template = _load_email_template(template_filename, fallback_html) + html_body = _render_template( + html_template, + { + 'username': username, + url_template_key: action_url, + }, + ) + + text_body = text_template.format(username=username, action_url=action_url) + + result = send_email( + to_email=email, + subject=subject, + body=text_body, + html_body=html_body, + email_type=token_type, + user_id=user_id, + ) + if result.get('success'): + return {'success': True, 'error': '', 'token': token} + return {'success': False, 'error': result.get('error', '发送失败'), 'token': None} + + +def _task_notify_precheck(email: str, *, require_screenshots: bool = False, screenshots: Optional[List] = None) -> str: + """任务通知发送前置检查,返回空字符串表示通过。""" + settings = get_email_settings() + if not settings.get('enabled', False): + return '邮件功能未启用' + if not settings.get('task_notify_enabled', False): + return '任务通知功能未启用' + if not email: + return '用户未设置邮箱' + if require_screenshots and not screenshots: + return '没有截图需要发送' + return '' + + +def _load_screenshot_data(screenshot_path: Optional[str], log_callback: Optional[Callable] = None) -> Tuple[Optional[bytes], Optional[str]]: + """读取截图数据,失败时仅记录日志并继续。""" + if not screenshot_path or not os.path.exists(screenshot_path): + return None, None + try: + with open(screenshot_path, 'rb') as f: + return f.read(), os.path.basename(screenshot_path) + except Exception as e: + if log_callback: + log_callback(f"[邮件] 读取截图文件失败: {e}") + return None, None + + +def _render_task_complete_bodies( + *, + html_template: str, + username: str, + account_name: str, + browse_type: str, + total_items: int, + total_attachments: int, + complete_time: str, + batch_info: str, + screenshot_text: str, +) -> Tuple[str, str]: + """渲染任务完成通知邮件内容。""" + html_body = _render_template( + html_template, + { + 'username': username, + 'account_name': account_name, + 'browse_type': browse_type, + 'total_items': total_items, + 'total_attachments': total_attachments, + 'complete_time': complete_time, + 'batch_info': batch_info, + }, + ) + + text_body = f""" +您好,{username}! + +您的浏览任务已完成。 + +账号:{account_name} +浏览类型:{browse_type} +浏览条目:{total_items} 条 +附件数量:{total_attachments} 个 +完成时间:{complete_time} + +{screenshot_text} +""" + return html_body, text_body + + +def _build_batch_accounts_html_rows(screenshots: List[Dict[str, Any]]) -> str: + """构建批次任务邮件中的账号详情表格。""" + rows = [] + for item in screenshots: + rows.append( + f""" + + {item.get('account_name', '未知')} + {item.get('items', 0)} + {item.get('attachments', 0)} + + """ + ) + return ''.join(rows) + + +def _collect_existing_screenshot_paths(screenshots: List[Dict[str, Any]]) -> List[Tuple[str, str]]: + """收集存在的截图路径与压缩包内文件名。""" + screenshot_paths: List[Tuple[str, str]] = [] + for item in screenshots: + path = item.get('path') + if path and os.path.exists(path): + arcname = f"{item.get('account_name', 'screenshot')}_{os.path.basename(path)}" + screenshot_paths.append((path, arcname)) + return screenshot_paths + + +def _build_zip_attachment_from_paths(screenshot_paths: List[Tuple[str, str]]) -> Tuple[Optional[bytes], Optional[str], str]: + """从路径列表构建 ZIP 附件,返回 (zip_data, zip_filename, note)。""" + if not screenshot_paths: + return None, None, '本次无可用截图文件(可能截图失败或未启用截图)。' + + import tempfile + + zip_path = None + try: + with tempfile.NamedTemporaryFile(prefix='screenshots_', suffix='.zip', delete=False) as tmp: + zip_path = tmp.name + + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf: + for file_path, arcname in screenshot_paths: + try: + zf.write(file_path, arcname=arcname) + except Exception as e: + logger.warning(f"[邮件] 写入ZIP失败: {e}") + + zip_size = os.path.getsize(zip_path) if zip_path and os.path.exists(zip_path) else 0 + if zip_size <= 0: + return None, None, '本次无可用截图文件(可能截图失败或文件不存在)。' + if zip_size > MAX_ATTACHMENT_SIZE: + return None, None, f'截图打包文件过大({zip_size} bytes),本次不附加附件。' + + with open(zip_path, 'rb') as f: + zip_data = f.read() + zip_filename = f"screenshots_{datetime.now(BEIJING_TZ).strftime('%Y%m%d_%H%M%S')}.zip" + return zip_data, zip_filename, '截图已打包为ZIP附件,请查收。' + except Exception as e: + logger.warning(f"[邮件] 打包截图失败: {e}") + return None, None, '截图打包失败,本次不附加附件。' + finally: + if zip_path and os.path.exists(zip_path): + try: + os.remove(zip_path) + except Exception: + pass + + +def _reset_smtp_daily_quota_if_needed(cursor, today: Optional[str] = None) -> int: + """在日期切换后重置 SMTP 每日计数,返回更新记录数。""" + reset_day = today or get_beijing_today() + cursor.execute( + """ + UPDATE smtp_configs + SET daily_sent = 0, daily_reset_date = ? + WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = '' + """, + (reset_day, reset_day), + ) + return int(cursor.rowcount or 0) + + +def _build_smtp_admin_config(row, include_password: bool = False) -> Dict[str, Any]: + """将 smtp_configs 全字段查询行转换为管理接口配置字典。""" + encrypted_password = row[8] + password_text = '' + if encrypted_password: + if include_password: + password_text = decrypt_password(encrypted_password) + else: + password_text = '******' + + config = { + 'id': row[0], + 'name': row[1], + 'enabled': bool(row[2]), + 'is_primary': bool(row[3]), + 'priority': row[4], + 'host': row[5], + 'port': row[6], + 'username': row[7], + 'password': password_text, + 'has_password': bool(encrypted_password), + 'use_ssl': bool(row[9]), + 'use_tls': bool(row[10]), + 'sender_name': row[11], + 'sender_email': row[12], + 'daily_limit': row[13], + 'daily_sent': row[14], + 'daily_reset_date': row[15], + 'last_success_at': row[16], + 'last_error': row[17], + 'success_count': row[18], + 'fail_count': row[19], + 'created_at': row[20], + 'updated_at': row[21], + } + total = (config['success_count'] or 0) + (config['fail_count'] or 0) + config['success_rate'] = round((config['success_count'] or 0) / total * 100, 1) if total > 0 else 0 + return config + + +def _build_smtp_runtime_config(row) -> Dict[str, Any]: + """将 smtp_configs 运行时查询行转换为发送字典(含每日限额字段)。""" + return { + 'id': row[0], + 'name': row[1], + 'host': row[2], + 'port': row[3], + 'username': row[4], + 'password': decrypt_password(row[5]) if row[5] else '', + 'use_ssl': bool(row[6]), + 'use_tls': bool(row[7]), + 'sender_name': row[8], + 'sender_email': row[9], + 'daily_limit': row[10] or 0, + 'daily_sent': row[11] or 0, + 'is_primary': bool(row[12]), + } + + +def _strip_smtp_quota_fields(runtime_config: Dict[str, Any]) -> Dict[str, Any]: + """移除内部限额字段,返回 send_email 使用的 SMTP 配置字典。""" + config = dict(runtime_config) + config.pop('daily_limit', None) + config.pop('daily_sent', None) + return config + + # ============ 数据库操作 ============ def init_email_tables(): @@ -317,12 +643,7 @@ def get_smtp_configs(include_password: bool = False) -> List[Dict[str, Any]]: with db_pool.get_db() as conn: cursor = conn.cursor() # 确保每天的配额在日期切换后能及时重置(即使当天没有触发邮件发送) - today = get_beijing_today() - cursor.execute(""" - UPDATE smtp_configs - SET daily_sent = 0, daily_reset_date = ? - WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = '' - """, (today, today)) + _reset_smtp_daily_quota_if_needed(cursor) conn.commit() cursor.execute(""" @@ -335,53 +656,14 @@ def get_smtp_configs(include_password: bool = False) -> List[Dict[str, Any]]: ORDER BY is_primary DESC, priority ASC, id ASC """) - configs = [] - for row in cursor.fetchall(): - config = { - 'id': row[0], - 'name': row[1], - 'enabled': bool(row[2]), - 'is_primary': bool(row[3]), - 'priority': row[4], - 'host': row[5], - 'port': row[6], - 'username': row[7], - 'password': '******' if row[8] and not include_password else (decrypt_password(row[8]) if include_password and row[8] else ''), - 'has_password': bool(row[8]), - 'use_ssl': bool(row[9]), - 'use_tls': bool(row[10]), - 'sender_name': row[11], - 'sender_email': row[12], - 'daily_limit': row[13], - 'daily_sent': row[14], - 'daily_reset_date': row[15], - 'last_success_at': row[16], - 'last_error': row[17], - 'success_count': row[18], - 'fail_count': row[19], - 'created_at': row[20], - 'updated_at': row[21] - } - - # 计算成功率 - total = config['success_count'] + config['fail_count'] - config['success_rate'] = round(config['success_count'] / total * 100, 1) if total > 0 else 0 - - configs.append(config) - - return configs + return [_build_smtp_admin_config(row, include_password=include_password) for row in cursor.fetchall()] def get_smtp_config(config_id: int, include_password: bool = False) -> Optional[Dict[str, Any]]: """获取单个SMTP配置""" with db_pool.get_db() as conn: cursor = conn.cursor() - today = get_beijing_today() - cursor.execute(""" - UPDATE smtp_configs - SET daily_sent = 0, daily_reset_date = ? - WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = '' - """, (today, today)) + _reset_smtp_daily_quota_if_needed(cursor) conn.commit() cursor.execute(""" @@ -397,31 +679,7 @@ def get_smtp_config(config_id: int, include_password: bool = False) -> Optional[ if not row: return None - return { - 'id': row[0], - 'name': row[1], - 'enabled': bool(row[2]), - 'is_primary': bool(row[3]), - 'priority': row[4], - 'host': row[5], - 'port': row[6], - 'username': row[7], - 'password': '******' if row[8] and not include_password else (decrypt_password(row[8]) if include_password and row[8] else ''), - 'has_password': bool(row[8]), - 'use_ssl': bool(row[9]), - 'use_tls': bool(row[10]), - 'sender_name': row[11], - 'sender_email': row[12], - 'daily_limit': row[13], - 'daily_sent': row[14], - 'daily_reset_date': row[15], - 'last_success_at': row[16], - 'last_error': row[17], - 'success_count': row[18], - 'fail_count': row[19], - 'created_at': row[20], - 'updated_at': row[21] - } + return _build_smtp_admin_config(row, include_password=include_password) def create_smtp_config(data: Dict[str, Any]) -> int: @@ -556,61 +814,50 @@ def clear_primary_smtp_config() -> bool: return True -def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]: - """ - 获取可用的SMTP配置 - 优先级: 主配置 > 按priority排序的启用配置 - """ - today = get_beijing_today() +def _fetch_candidate_smtp_configs(cursor, *, exclude_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]: + """获取可用 SMTP 候选配置(已过滤不可用/超额配置)。""" + excluded = [int(i) for i in (exclude_ids or []) if i is not None] + sql = """ + SELECT id, name, host, port, username, password, use_ssl, use_tls, + sender_name, sender_email, daily_limit, daily_sent, is_primary + FROM smtp_configs + WHERE enabled = 1 + """ + params: List[Any] = [] + if excluded: + placeholders = ",".join(["?" for _ in excluded]) + sql += f" AND id NOT IN ({placeholders})" + params.extend(excluded) + + sql += " ORDER BY is_primary DESC, priority ASC, id ASC" + cursor.execute(sql, params) + + candidates: List[Dict[str, Any]] = [] + for row in cursor.fetchall() or []: + runtime_config = _build_smtp_runtime_config(row) + daily_limit = int(runtime_config.get('daily_limit') or 0) + daily_sent = int(runtime_config.get('daily_sent') or 0) + if daily_limit > 0 and daily_sent >= daily_limit: + continue + candidates.append(_strip_smtp_quota_fields(runtime_config)) + + return candidates + + +def _get_smtp_candidates(*, exclude_ids: Optional[List[int]] = None) -> List[Dict[str, Any]]: + """读取并返回当前可发送的 SMTP 候选配置列表。""" with db_pool.get_db() as conn: cursor = conn.cursor() - - # 先重置过期的每日计数 - cursor.execute(""" - UPDATE smtp_configs - SET daily_sent = 0, daily_reset_date = ? - WHERE daily_reset_date != ? OR daily_reset_date IS NULL OR daily_reset_date = '' - """, (today, today)) + _reset_smtp_daily_quota_if_needed(cursor) conn.commit() + return _fetch_candidate_smtp_configs(cursor, exclude_ids=exclude_ids) - # 获取所有启用的配置,按优先级排序 - cursor.execute(""" - SELECT id, name, host, port, username, password, use_ssl, use_tls, - sender_name, sender_email, daily_limit, daily_sent, is_primary - FROM smtp_configs - WHERE enabled = 1 - ORDER BY is_primary DESC, priority ASC, id ASC - """) - configs = cursor.fetchall() - - for row in configs: - config_id, name, host, port, username, password, use_ssl, use_tls, \ - sender_name, sender_email, daily_limit, daily_sent, is_primary = row - - # 检查每日限额 - if daily_limit > 0 and daily_sent >= daily_limit: - continue # 超过限额,跳过此配置 - - # 解密密码 - decrypted_password = decrypt_password(password) if password else '' - - return { - 'id': config_id, - 'name': name, - 'host': host, - 'port': port, - 'username': username, - 'password': decrypted_password, - 'use_ssl': bool(use_ssl), - 'use_tls': bool(use_tls), - 'sender_name': sender_name, - 'sender_email': sender_email, - 'is_primary': bool(is_primary) - } - - return None +def _get_available_smtp_config(failover: bool = True) -> Optional[Dict[str, Any]]: + """获取首个可用 SMTP 配置。""" + candidates = _get_smtp_candidates() + return candidates[0] if candidates else None def _update_smtp_stats(config_id: int, success: bool, error: str = ''): @@ -751,6 +998,44 @@ class EmailSender: raise +def _create_email_sender_from_config(config: Dict[str, Any]) -> EmailSender: + """从 SMTP 配置创建发送器。""" + return EmailSender( + { + 'name': config.get('name', ''), + 'host': config.get('host', ''), + 'port': config.get('port', 465), + 'username': config.get('username', ''), + 'password': config.get('password', ''), + 'use_ssl': bool(config.get('use_ssl', True)), + 'use_tls': bool(config.get('use_tls', False)), + 'sender_name': config.get('sender_name', '自动化学习'), + 'sender_email': config.get('sender_email', ''), + } + ) + + +def _send_with_smtp_config( + *, + config: Dict[str, Any], + to_email: str, + subject: str, + body: str, + html_body: str = None, + attachments: Optional[List[Dict[str, Any]]] = None, +) -> Tuple[bool, str]: + """使用指定 SMTP 配置发送一封邮件。""" + sender = _create_email_sender_from_config(config) + try: + sender.connect() + sender.send(to_email, subject, body, html_body, attachments) + sender.disconnect() + return True, '' + except Exception as e: + sender.disconnect() + return False, str(e) + + def send_email( to_email: str, subject: str, @@ -767,103 +1052,59 @@ def send_email( Returns: {'success': bool, 'error': str, 'config_id': int} """ - # 检查全局开关 settings = get_email_settings() if not settings['enabled']: return {'success': False, 'error': '邮件功能未启用', 'config_id': None} - # 获取可用配置 - config = _get_available_smtp_config(settings['failover_enabled']) - if not config: + candidates = _get_smtp_candidates() + if not settings['failover_enabled']: + candidates = candidates[:1] + if not candidates: return {'success': False, 'error': '没有可用的SMTP配置', 'config_id': None} - tried_configs = [] + tried_configs: List[int] = [] last_error = '' - while config: - tried_configs.append(config['id']) - sender = EmailSender(config) + for config in candidates: + config_id = int(config.get('id') or 0) + tried_configs.append(config_id) - try: - sender.connect() - sender.send(to_email, subject, body, html_body, attachments) - sender.disconnect() + ok, error = _send_with_smtp_config( + config=config, + to_email=to_email, + subject=subject, + body=body, + html_body=html_body, + attachments=attachments, + ) - # 更新统计 - _update_smtp_stats(config['id'], True) - _log_email_send(user_id, config['id'], to_email, email_type, subject, 'success', '', attachments) + if ok: + _update_smtp_stats(config_id, True) + _log_email_send(user_id, config_id, to_email, email_type, subject, 'success', '', attachments) _update_email_stats(email_type, True) if log_callback: - log_callback(f"[邮件服务] 发送成功: {to_email} (使用: {config['name']})") + log_callback(f"[邮件服务] 发送成功: {to_email} (使用: {config.get('name', '')})") - return {'success': True, 'error': '', 'config_id': config['id']} + return {'success': True, 'error': '', 'config_id': config_id} - except Exception as e: - last_error = str(e) - sender.disconnect() + last_error = error + _update_smtp_stats(config_id, False, last_error) - # 更新失败统计 - _update_smtp_stats(config['id'], False, last_error) + if log_callback: + log_callback(f"[邮件服务] 发送失败 [{config.get('name', '')}]: {last_error}") - if log_callback: - log_callback(f"[邮件服务] 发送失败 [{config['name']}]: {e}") - - # 故障转移:尝试下一个配置 - if settings['failover_enabled']: - config = _get_next_available_smtp_config(tried_configs) - else: - config = None - - # 所有配置都失败 - _log_email_send(user_id, tried_configs[0] if tried_configs else None, - to_email, email_type, subject, 'failed', last_error, attachments) + first_config_id = tried_configs[0] if tried_configs else None + _log_email_send(user_id, first_config_id, to_email, email_type, subject, 'failed', last_error, attachments) _update_email_stats(email_type, False) - return {'success': False, 'error': last_error, 'config_id': tried_configs[0] if tried_configs else None} + return {'success': False, 'error': last_error, 'config_id': first_config_id} def _get_next_available_smtp_config(exclude_ids: List[int]) -> Optional[Dict[str, Any]]: - """获取下一个可用的SMTP配置(排除已尝试的)""" - today = get_beijing_today() - - with db_pool.get_db() as conn: - cursor = conn.cursor() - - placeholders = ','.join(['?' for _ in exclude_ids]) - cursor.execute(f""" - SELECT id, name, host, port, username, password, use_ssl, use_tls, - sender_name, sender_email, daily_limit, daily_sent, is_primary - FROM smtp_configs - WHERE enabled = 1 AND id NOT IN ({placeholders}) - ORDER BY is_primary DESC, priority ASC, id ASC - LIMIT 1 - """, exclude_ids) - - row = cursor.fetchone() - if not row: - return None - - config_id, name, host, port, username, password, use_ssl, use_tls, \ - sender_name, sender_email, daily_limit, daily_sent, is_primary = row - - # 检查每日限额 - if daily_limit > 0 and daily_sent >= daily_limit: - return _get_next_available_smtp_config(exclude_ids + [config_id]) - - return { - 'id': config_id, - 'name': name, - 'host': host, - 'port': port, - 'username': username, - 'password': decrypt_password(password) if password else '', - 'use_ssl': bool(use_ssl), - 'use_tls': bool(use_tls), - 'sender_name': sender_name, - 'sender_email': sender_email, - 'is_primary': bool(is_primary) - } + """获取下一个可用的SMTP配置(排除已尝试的)。""" + candidates = _get_smtp_candidates(exclude_ids=exclude_ids) + return candidates[0] if candidates else None def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]: @@ -872,28 +1113,14 @@ def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]: if not config: return {'success': False, 'error': '配置不存在'} - sender = EmailSender({ - 'name': config['name'], - 'host': config['host'], - 'port': config['port'], - 'username': config['username'], - 'password': config['password'], - 'use_ssl': config['use_ssl'], - 'use_tls': config['use_tls'], - 'sender_name': config['sender_name'], - 'sender_email': config['sender_email'] - }) + ok, error = _send_with_smtp_config( + config=config, + to_email=test_email, + subject='自动化学习 - SMTP配置测试', + body=f'这是一封测试邮件。\n\n配置名称: {config["name"]}\nSMTP服务器: {config["host"]}:{config["port"]}\n\n如果您收到此邮件,说明SMTP配置正确。', + ) - try: - sender.connect() - sender.send( - test_email, - '自动化学习 - SMTP配置测试', - f'这是一封测试邮件。\n\n配置名称: {config["name"]}\nSMTP服务器: {config["host"]}:{config["port"]}\n\n如果您收到此邮件,说明SMTP配置正确。', - None, - None - ) - sender.disconnect() + if ok: # 更新最后成功时间 with db_pool.get_db() as conn: @@ -908,16 +1135,18 @@ def test_smtp_config(config_id: int, test_email: str) -> Dict[str, Any]: return {'success': True, 'error': ''} - except Exception as e: - # 记录错误 - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE smtp_configs SET last_error = ? WHERE id = ? - """, (str(e)[:500], config_id)) - conn.commit() + # 记录错误 + with db_pool.get_db() as conn: + cursor = conn.cursor() + cursor.execute( + """ + UPDATE smtp_configs SET last_error = ? WHERE id = ? + """, + (error[:500], config_id), + ) + conn.commit() - return {'success': False, 'error': str(e)} + return {'success': False, 'error': error} # ============ 邮件日志 ============ @@ -1248,84 +1477,43 @@ def send_register_verification_email( Returns: {'success': bool, 'error': str, 'token': str} """ - # 检查发送频率限制 - if not check_rate_limit(email, EMAIL_TYPE_REGISTER): - return { - 'success': False, - 'error': '发送太频繁,请稍后再试', - 'token': None - } - - # 生成验证Token - token = generate_email_token(email, EMAIL_TYPE_REGISTER, user_id) - - # 获取base_url - if not base_url: - settings = get_email_settings() - base_url = settings.get('base_url', '') - - if not base_url: - # 尝试从配置获取 - try: - from app_config import Config - base_url = Config.BASE_URL - except: - base_url = 'http://localhost:51233' - - # 生成验证链接 - verify_url = f"{base_url.rstrip('/')}/api/verify-email/{token}" - - # 读取邮件模板 - template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'register.html') - try: - with open(template_path, 'r', encoding='utf-8') as f: - html_template = f.read() - except FileNotFoundError: - # 使用简单的HTML模板 - html_template = """ - - -

邮箱验证

-

您好,{{ username }}!

-

请点击下面的链接验证您的邮箱地址:

-

{{ verify_url }}

-

此链接24小时内有效。

- - - """ - - # 替换模板变量 - html_body = html_template.replace('{{ username }}', username) - html_body = html_body.replace('{{ verify_url }}', verify_url) - - # 纯文本版本 - text_body = f""" + fallback_html = """ + + +

邮箱验证

+

您好,{{ username }}!

+

请点击下面的链接验证您的邮箱地址:

+

{{ verify_url }}

+

此链接24小时内有效。

+ + + """ + text_template = """ 您好,{username}! 感谢您注册自动化学习。请点击下面的链接验证您的邮箱地址: -{verify_url} +{action_url} 此链接24小时内有效。 如果您没有注册过账号,请忽略此邮件。 """ - - # 发送邮件 - result = send_email( - to_email=email, + return _send_token_email( + email=email, + username=username, + user_id=user_id, + base_url=base_url, + token_type=EMAIL_TYPE_REGISTER, + rate_limit_error='发送太频繁,请稍后再试', subject='【自动化学习】邮箱验证', - body=text_body, - html_body=html_body, - email_type=EMAIL_TYPE_REGISTER, - user_id=user_id + template_filename='register.html', + fallback_html=fallback_html, + url_path='/api/verify-email/{token}', + url_template_key='verify_url', + text_template=text_template, ) - if result['success']: - return {'success': True, 'error': '', 'token': token} - else: - return {'success': False, 'error': result['error'], 'token': None} - def resend_register_verification_email(user_id: int, email: str, username: str) -> Dict[str, Any]: """ @@ -1340,14 +1528,7 @@ def resend_register_verification_email(user_id: int, email: str, username: str) {'success': bool, 'error': str} """ # 检查是否有未过期的token - with db_pool.get_db() as conn: - cursor = conn.cursor() - # 先使旧token失效 - cursor.execute(""" - UPDATE email_tokens SET used = 1 - WHERE user_id = ? AND token_type = ? AND used = 0 - """, (user_id, EMAIL_TYPE_REGISTER)) - conn.commit() + _mark_unused_tokens_as_used(user_id, EMAIL_TYPE_REGISTER) # 发送新的验证邮件 return send_register_verification_email(email, username, user_id) @@ -1373,91 +1554,44 @@ def send_password_reset_email( Returns: {'success': bool, 'error': str, 'token': str} """ - # 检查发送频率限制(密码重置限制5分钟) - if not check_rate_limit(email, EMAIL_TYPE_RESET): - return { - 'success': False, - 'error': '发送太频繁,请5分钟后再试', - 'token': None - } - - # 使旧的重置token失效 - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE email_tokens SET used = 1 - WHERE user_id = ? AND token_type = ? AND used = 0 - """, (user_id, EMAIL_TYPE_RESET)) - conn.commit() - - # 生成新的验证Token - token = generate_email_token(email, EMAIL_TYPE_RESET, user_id) - - # 获取base_url - if not base_url: - settings = get_email_settings() - base_url = settings.get('base_url', '') - - if not base_url: - try: - from app_config import Config - base_url = Config.BASE_URL - except: - base_url = 'http://localhost:51233' - - # 生成重置链接 - reset_url = f"{base_url.rstrip('/')}/reset-password/{token}" - - # 读取邮件模板 - template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'reset_password.html') - try: - with open(template_path, 'r', encoding='utf-8') as f: - html_template = f.read() - except FileNotFoundError: - html_template = """ - - -

密码重置

-

您好,{{ username }}!

-

请点击下面的链接重置您的密码:

-

{{ reset_url }}

-

此链接30分钟内有效。

- - - """ - - # 替换模板变量 - html_body = html_template.replace('{{ username }}', username) - html_body = html_body.replace('{{ reset_url }}', reset_url) - - # 纯文本版本 - text_body = f""" + fallback_html = """ + + +

密码重置

+

您好,{{ username }}!

+

请点击下面的链接重置您的密码:

+

{{ reset_url }}

+

此链接30分钟内有效。

+ + + """ + text_template = """ 您好,{username}! 我们收到了您的密码重置请求。请点击下面的链接重置您的密码: -{reset_url} +{action_url} 此链接30分钟内有效。 如果您没有申请过密码重置,请忽略此邮件。 """ - - # 发送邮件 - result = send_email( - to_email=email, + return _send_token_email( + email=email, + username=username, + user_id=user_id, + base_url=base_url, + token_type=EMAIL_TYPE_RESET, + rate_limit_error='发送太频繁,请5分钟后再试', subject='【自动化学习】密码重置', - body=text_body, - html_body=html_body, - email_type=EMAIL_TYPE_RESET, - user_id=user_id + template_filename='reset_password.html', + fallback_html=fallback_html, + url_path='/reset-password/{token}', + url_template_key='reset_url', + text_template=text_template, + invalidate_existing_tokens=True, ) - if result['success']: - return {'success': True, 'error': '', 'token': token} - else: - return {'success': False, 'error': result['error'], 'token': None} - def verify_password_reset_token(token: str) -> Optional[Dict[str, Any]]: """ @@ -1533,91 +1667,44 @@ def send_bind_email_verification( Returns: {'success': bool, 'error': str, 'token': str} """ - # 检查发送频率限制(绑定邮件限制1分钟) - if not check_rate_limit(email, EMAIL_TYPE_BIND): - return { - 'success': False, - 'error': '发送太频繁,请1分钟后再试', - 'token': None - } - - # 使旧的绑定token失效 - with db_pool.get_db() as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE email_tokens SET used = 1 - WHERE user_id = ? AND token_type = ? AND used = 0 - """, (user_id, EMAIL_TYPE_BIND)) - conn.commit() - - # 生成新的验证Token - token = generate_email_token(email, EMAIL_TYPE_BIND, user_id) - - # 获取base_url - if not base_url: - settings = get_email_settings() - base_url = settings.get('base_url', '') - - if not base_url: - try: - from app_config import Config - base_url = Config.BASE_URL - except: - base_url = 'http://localhost:51233' - - # 生成验证链接 - verify_url = f"{base_url.rstrip('/')}/api/verify-bind-email/{token}" - - # 读取邮件模板 - template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'bind_email.html') - try: - with open(template_path, 'r', encoding='utf-8') as f: - html_template = f.read() - except FileNotFoundError: - html_template = """ - - -

邮箱绑定验证

-

您好,{{ username }}!

-

请点击下面的链接完成邮箱绑定:

-

{{ verify_url }}

-

此链接1小时内有效。

- - - """ - - # 替换模板变量 - html_body = html_template.replace('{{ username }}', username) - html_body = html_body.replace('{{ verify_url }}', verify_url) - - # 纯文本版本 - text_body = f""" + fallback_html = """ + + +

邮箱绑定验证

+

您好,{{ username }}!

+

请点击下面的链接完成邮箱绑定:

+

{{ verify_url }}

+

此链接1小时内有效。

+ + + """ + text_template = """ 您好,{username}! 您正在绑定此邮箱到您的账号。请点击下面的链接完成验证: -{verify_url} +{action_url} 此链接1小时内有效。 如果这不是您的操作,请忽略此邮件。 """ - - # 发送邮件 - result = send_email( - to_email=email, + return _send_token_email( + email=email, + username=username, + user_id=user_id, + base_url=base_url, + token_type=EMAIL_TYPE_BIND, + rate_limit_error='发送太频繁,请1分钟后再试', subject='【自动化学习】邮箱绑定验证', - body=text_body, - html_body=html_body, - email_type=EMAIL_TYPE_BIND, - user_id=user_id + template_filename='bind_email.html', + fallback_html=fallback_html, + url_path='/api/verify-bind-email/{token}', + url_template_key='verify_url', + text_template=text_template, + invalidate_existing_tokens=True, ) - if result['success']: - return {'success': True, 'error': '', 'token': token} - else: - return {'success': False, 'error': result['error'], 'token': None} - def verify_bind_email_token(token: str) -> Optional[Dict[str, Any]]: """ @@ -1897,6 +1984,37 @@ def create_zip_attachment(files: List[Dict[str, Any]], zip_filename: str = 'scre # ============ 任务完成通知邮件 ============ +def _send_email_and_collect( + *, + to_email: str, + subject: str, + body: str, + html_body: Optional[str] = None, + attachments: Optional[List[Dict[str, Any]]] = None, + email_type: str, + user_id: Optional[int], + log_callback: Optional[Callable] = None, + success_log: Optional[str] = None, +) -> Tuple[bool, str]: + result = send_email( + to_email=to_email, + subject=subject, + body=body, + html_body=html_body, + attachments=attachments, + email_type=email_type, + user_id=user_id, + log_callback=log_callback, + ) + + if result.get('success'): + if success_log and log_callback: + log_callback(success_log) + return True, '' + + return False, str(result.get('error') or '发送失败') + + def send_task_complete_email( user_id: int, email: str, @@ -1925,39 +2043,16 @@ def send_task_complete_email( Returns: {'success': bool, 'error': str, 'emails_sent': int} """ - # 检查邮件功能是否启用 - settings = get_email_settings() - if not settings.get('enabled', False): - return {'success': False, 'error': '邮件功能未启用', 'emails_sent': 0} + precheck_error = _task_notify_precheck(email) + if precheck_error: + return {'success': False, 'error': precheck_error, 'emails_sent': 0} - if not settings.get('task_notify_enabled', False): - return {'success': False, 'error': '任务通知功能未启用', 'emails_sent': 0} - - if not email: - return {'success': False, 'error': '用户未设置邮箱', 'emails_sent': 0} - - # 获取完成时间 complete_time = get_beijing_now_str() + screenshot_data, screenshot_filename = _load_screenshot_data(screenshot_path, log_callback) - # 读取截图文件 - screenshot_data = None - screenshot_filename = None - if screenshot_path and os.path.exists(screenshot_path): - try: - with open(screenshot_path, 'rb') as f: - screenshot_data = f.read() - screenshot_filename = os.path.basename(screenshot_path) - except Exception as e: - if log_callback: - log_callback(f"[邮件] 读取截图文件失败: {e}") - - # 读取邮件模板 - template_path = os.path.join(os.path.dirname(__file__), 'templates', 'email', 'task_complete.html') - try: - with open(template_path, 'r', encoding='utf-8') as f: - html_template = f.read() - except FileNotFoundError: - html_template = """ + html_template = _load_email_template( + 'task_complete.html', + """

任务完成通知

@@ -1970,108 +2065,72 @@ def send_task_complete_email(

完成时间:{{ complete_time }}

- """ + """, + ) - # 准备发送 emails_sent = 0 last_error = '' + is_oversized_attachment = bool(screenshot_data and len(screenshot_data) > MAX_ATTACHMENT_SIZE) - # 检查附件大小,决定是否需要分批发送 - if screenshot_data and len(screenshot_data) > MAX_ATTACHMENT_SIZE: - # 附件超过限制,不压缩直接发送(大图片压缩效果不大) - # 这种情况很少见,但作为容错处理 - batch_info = '(由于文件较大,截图将单独发送)' + if is_oversized_attachment: + html_body, text_body = _render_task_complete_bodies( + html_template=html_template, + username=username, + account_name=account_name, + browse_type=browse_type, + total_items=total_items, + total_attachments=total_attachments, + complete_time=complete_time, + batch_info='(由于文件较大,截图将单独发送)', + screenshot_text='截图将在下一封邮件中发送。', + ) - # 先发送不带附件的通知邮件 - html_body = html_template.replace('{{ username }}', username) - html_body = html_body.replace('{{ account_name }}', account_name) - html_body = html_body.replace('{{ browse_type }}', browse_type) - html_body = html_body.replace('{{ total_items }}', str(total_items)) - html_body = html_body.replace('{{ total_attachments }}', str(total_attachments)) - html_body = html_body.replace('{{ complete_time }}', complete_time) - html_body = html_body.replace('{{ batch_info }}', batch_info) - - text_body = f""" -您好,{username}! - -您的浏览任务已完成。 - -账号:{account_name} -浏览类型:{browse_type} -浏览条目:{total_items} 条 -附件数量:{total_attachments} 个 -完成时间:{complete_time} - -截图将在下一封邮件中发送。 -""" - - result = send_email( + ok, error = _send_email_and_collect( to_email=email, subject=f'【自动化学习】任务完成 - {account_name}', body=text_body, html_body=html_body, email_type=EMAIL_TYPE_TASK_COMPLETE, user_id=user_id, - log_callback=log_callback + log_callback=log_callback, + success_log='[邮件] 任务通知已发送', ) - - if result['success']: + if ok: emails_sent += 1 - if log_callback: - log_callback(f"[邮件] 任务通知已发送") else: - last_error = result['error'] + last_error = error - # 单独发送截图附件 attachment = [{'filename': screenshot_filename, 'data': screenshot_data}] - result2 = send_email( + ok, error = _send_email_and_collect( to_email=email, subject=f'【自动化学习】任务截图 - {account_name}', body=f'这是 {account_name} 的任务截图。', attachments=attachment, email_type=EMAIL_TYPE_TASK_COMPLETE, user_id=user_id, - log_callback=log_callback + log_callback=log_callback, + success_log='[邮件] 截图附件已发送', ) - - if result2['success']: + if ok: emails_sent += 1 - if log_callback: - log_callback(f"[邮件] 截图附件已发送") else: - last_error = result2['error'] + last_error = error else: - # 正常情况:附件大小在限制内,一次性发送 - batch_info = '' - attachments = None + attachments = [{'filename': screenshot_filename, 'data': screenshot_data}] if screenshot_data else None + html_body, text_body = _render_task_complete_bodies( + html_template=html_template, + username=username, + account_name=account_name, + browse_type=browse_type, + total_items=total_items, + total_attachments=total_attachments, + complete_time=complete_time, + batch_info='', + screenshot_text='截图已附在邮件中。' if screenshot_data else '', + ) - if screenshot_data: - attachments = [{'filename': screenshot_filename, 'data': screenshot_data}] - - html_body = html_template.replace('{{ username }}', username) - html_body = html_body.replace('{{ account_name }}', account_name) - html_body = html_body.replace('{{ browse_type }}', browse_type) - html_body = html_body.replace('{{ total_items }}', str(total_items)) - html_body = html_body.replace('{{ total_attachments }}', str(total_attachments)) - html_body = html_body.replace('{{ complete_time }}', complete_time) - html_body = html_body.replace('{{ batch_info }}', batch_info) - - text_body = f""" -您好,{username}! - -您的浏览任务已完成。 - -账号:{account_name} -浏览类型:{browse_type} -浏览条目:{total_items} 条 -附件数量:{total_attachments} 个 -完成时间:{complete_time} - -{'截图已附在邮件中。' if screenshot_data else ''} -""" - - result = send_email( + ok, error = _send_email_and_collect( to_email=email, subject=f'【自动化学习】任务完成 - {account_name}', body=text_body, @@ -2079,15 +2138,13 @@ def send_task_complete_email( attachments=attachments, email_type=EMAIL_TYPE_TASK_COMPLETE, user_id=user_id, - log_callback=log_callback + log_callback=log_callback, + success_log='[邮件] 任务通知已发送', ) - - if result['success']: + if ok: emails_sent += 1 - if log_callback: - log_callback(f"[邮件] 任务通知已发送") else: - last_error = result['error'] + last_error = error return { 'success': emails_sent > 0, @@ -2118,63 +2175,25 @@ def send_task_complete_email_async( log_callback("[邮件] 邮件队列已满,任务通知未发送") -def send_batch_task_complete_email( - user_id: int, - email: str, +def _summarize_batch_screenshots(screenshots: List[Dict[str, Any]]) -> Tuple[int, int, int]: + total_items_sum = sum(int(s.get('items', 0) or 0) for s in screenshots) + total_attachments_sum = sum(int(s.get('attachments', 0) or 0) for s in screenshots) + account_count = len(screenshots) + return total_items_sum, total_attachments_sum, account_count + + +def _render_batch_task_complete_html( + *, username: str, schedule_name: str, browse_type: str, - screenshots: List[Dict[str, Any]] -) -> Dict[str, Any]: - """ - 发送批次任务完成通知邮件(多账号截图打包) - - Args: - user_id: 用户ID - email: 收件人邮箱 - username: 用户名 - schedule_name: 定时任务名称 - browse_type: 浏览类型 - screenshots: 截图列表 [{'account_name': x, 'path': y, 'items': n, 'attachments': m}, ...] - - Returns: - {'success': bool, 'error': str} - """ - # 检查邮件功能是否启用 - settings = get_email_settings() - if not settings.get('enabled', False): - return {'success': False, 'error': '邮件功能未启用'} - - if not settings.get('task_notify_enabled', False): - return {'success': False, 'error': '任务通知功能未启用'} - - if not email: - return {'success': False, 'error': '用户未设置邮箱'} - - if not screenshots: - return {'success': False, 'error': '没有截图需要发送'} - - # 获取完成时间 - complete_time = get_beijing_now_str() - - # 统计信息 - total_items_sum = sum(s.get('items', 0) for s in screenshots) - total_attachments_sum = sum(s.get('attachments', 0) for s in screenshots) - account_count = len(screenshots) - - # 构建账号详情HTML - accounts_html = "" - for s in screenshots: - accounts_html += f""" - - {s.get('account_name', '未知')} - {s.get('items', 0)} - {s.get('attachments', 0)} - - """ - - # 构建HTML邮件内容 - html_content = f""" + complete_time: str, + account_count: int, + total_items_sum: int, + total_attachments_sum: int, + accounts_html: str, +) -> str: + return f"""
@@ -2208,52 +2227,50 @@ def send_batch_task_complete_email( """ - # 收集可用截图文件路径(避免把所有图片一次性读入内存) - screenshot_paths = [] - for s in screenshots: - path = s.get('path') - if path and os.path.exists(path): - arcname = f"{s.get('account_name', 'screenshot')}_{os.path.basename(path)}" - screenshot_paths.append((path, arcname)) - # 如果有截图,优先落盘打包ZIP,再按大小决定是否附加(降低内存峰值) - zip_data = None - zip_filename = None - attachment_note = "" - if screenshot_paths: - import tempfile - zip_path = None - try: - with tempfile.NamedTemporaryFile(prefix="screenshots_", suffix=".zip", delete=False) as tmp: - zip_path = tmp.name - with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf: - for file_path, arcname in screenshot_paths: - try: - zf.write(file_path, arcname=arcname) - except Exception as e: - logger.warning(f"[邮件] 写入ZIP失败: {e}") +def send_batch_task_complete_email( + user_id: int, + email: str, + username: str, + schedule_name: str, + browse_type: str, + screenshots: List[Dict[str, Any]] +) -> Dict[str, Any]: + """ + 发送批次任务完成通知邮件(多账号截图打包) - zip_size = os.path.getsize(zip_path) if zip_path and os.path.exists(zip_path) else 0 - if zip_size <= 0: - attachment_note = "本次无可用截图文件(可能截图失败或文件不存在)。" - elif zip_size > MAX_ATTACHMENT_SIZE: - attachment_note = f"截图打包文件过大({zip_size} bytes),本次不附加附件。" - else: - with open(zip_path, 'rb') as f: - zip_data = f.read() - zip_filename = f"screenshots_{datetime.now(BEIJING_TZ).strftime('%Y%m%d_%H%M%S')}.zip" - attachment_note = "截图已打包为ZIP附件,请查收。" - except Exception as e: - logger.warning(f"[邮件] 打包截图失败: {e}") - attachment_note = "截图打包失败,本次不附加附件。" - finally: - if zip_path and os.path.exists(zip_path): - try: - os.remove(zip_path) - except Exception: - pass - else: - attachment_note = "本次无可用截图文件(可能截图失败或未启用截图)。" + Args: + user_id: 用户ID + email: 收件人邮箱 + username: 用户名 + schedule_name: 定时任务名称 + browse_type: 浏览类型 + screenshots: 截图列表 [{'account_name': x, 'path': y, 'items': n, 'attachments': m}, ...] + + Returns: + {'success': bool, 'error': str} + """ + precheck_error = _task_notify_precheck(email, require_screenshots=True, screenshots=screenshots) + if precheck_error: + return {'success': False, 'error': precheck_error} + + complete_time = get_beijing_now_str() + total_items_sum, total_attachments_sum, account_count = _summarize_batch_screenshots(screenshots) + accounts_html = _build_batch_accounts_html_rows(screenshots) + + html_content = _render_batch_task_complete_html( + username=username, + schedule_name=schedule_name, + browse_type=browse_type, + complete_time=complete_time, + account_count=account_count, + total_items_sum=total_items_sum, + total_attachments_sum=total_attachments_sum, + accounts_html=accounts_html, + ) + + screenshot_paths = _collect_existing_screenshot_paths(screenshots) + zip_data, zip_filename, attachment_note = _build_zip_attachment_from_paths(screenshot_paths) # 将附件说明写入邮件内容 html_content = html_content.replace("截图已打包为ZIP附件,请查收。", attachment_note) @@ -2267,7 +2284,7 @@ def send_batch_task_complete_email( 'mime_type': 'application/zip' }) - result = send_email( + ok, error = _send_email_and_collect( to_email=email, subject=f'【自动化学习】定时任务完成 - {schedule_name}', body='', @@ -2277,10 +2294,9 @@ def send_batch_task_complete_email( user_id=user_id, ) - if result['success']: + if ok: return {'success': True} - else: - return {'success': False, 'error': result.get('error', '发送失败')} + return {'success': False, 'error': error or '发送失败'} def send_batch_task_complete_email_async( diff --git a/realtime/status_push.py b/realtime/status_push.py index 665dfdf..481d672 100644 --- a/realtime/status_push.py +++ b/realtime/status_push.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import json import os import time @@ -9,8 +10,40 @@ from services.runtime import get_logger, get_socketio from services.state import safe_get_account, safe_iter_task_status_items +def _to_int(value, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + +def _payload_signature(payload: dict) -> str: + try: + return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":"), default=str) + except Exception: + return repr(payload) + + +def _should_emit( + *, + last_sig: str | None, + last_ts: float, + new_sig: str, + now_ts: float, + min_interval: float, + force_interval: float, +) -> bool: + if last_sig is None: + return True + if (now_ts - last_ts) >= force_interval: + return True + if new_sig != last_sig and (now_ts - last_ts) >= min_interval: + return True + return False + + def status_push_worker() -> None: - """后台线程:按间隔推送排队/运行中任务的状态更新(可节流)。""" + """后台线程:按间隔推送排队/运行中任务状态(变更驱动+心跳兜底)。""" logger = get_logger() try: push_interval = float(os.environ.get("STATUS_PUSH_INTERVAL_SECONDS", "1")) @@ -18,18 +51,41 @@ def status_push_worker() -> None: push_interval = 1.0 push_interval = max(0.5, push_interval) + try: + queue_min_interval = float(os.environ.get("STATUS_PUSH_MIN_QUEUE_INTERVAL_SECONDS", str(push_interval))) + except Exception: + queue_min_interval = push_interval + queue_min_interval = max(push_interval, queue_min_interval) + + try: + progress_min_interval = float( + os.environ.get("STATUS_PUSH_MIN_PROGRESS_INTERVAL_SECONDS", str(push_interval)) + ) + except Exception: + progress_min_interval = push_interval + progress_min_interval = max(push_interval, progress_min_interval) + + try: + force_interval = float(os.environ.get("STATUS_PUSH_FORCE_INTERVAL_SECONDS", "10")) + except Exception: + force_interval = 10.0 + force_interval = max(push_interval, force_interval) + socketio = get_socketio() from services.tasks import get_task_scheduler scheduler = get_task_scheduler() + emitted_state: dict[str, dict] = {} while True: try: + now_ts = time.time() queue_snapshot = scheduler.get_queue_state_snapshot() pending_total = int(queue_snapshot.get("pending_total", 0) or 0) running_total = int(queue_snapshot.get("running_total", 0) or 0) running_by_user = queue_snapshot.get("running_by_user") or {} positions = queue_snapshot.get("positions") or {} + active_account_ids = set() status_items = safe_iter_task_status_items() for account_id, status_info in status_items: @@ -39,11 +95,15 @@ def status_push_worker() -> None: user_id = status_info.get("user_id") if not user_id: continue + + active_account_ids.add(str(account_id)) account = safe_get_account(user_id, account_id) if not account: continue + + user_id_int = _to_int(user_id) account_data = account.to_dict() - pos = positions.get(account_id) or {} + pos = positions.get(account_id) or positions.get(str(account_id)) or {} account_data.update( { "queue_pending_total": pending_total, @@ -51,10 +111,23 @@ def status_push_worker() -> None: "queue_ahead": pos.get("queue_ahead"), "queue_position": pos.get("queue_position"), "queue_is_vip": pos.get("is_vip"), - "queue_running_user": int(running_by_user.get(int(user_id), 0) or 0), + "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))), } ) - socketio.emit("account_update", account_data, room=f"user_{user_id}") + + cache_entry = emitted_state.setdefault(str(account_id), {}) + account_sig = _payload_signature(account_data) + if _should_emit( + last_sig=cache_entry.get("account_sig"), + last_ts=float(cache_entry.get("account_ts", 0) or 0), + new_sig=account_sig, + now_ts=now_ts, + min_interval=queue_min_interval, + force_interval=force_interval, + ): + socketio.emit("account_update", account_data, room=f"user_{user_id}") + cache_entry["account_sig"] = account_sig + cache_entry["account_ts"] = now_ts if status != "运行中": continue @@ -74,9 +147,26 @@ def status_push_worker() -> None: "queue_running_total": running_total, "queue_ahead": pos.get("queue_ahead"), "queue_position": pos.get("queue_position"), - "queue_running_user": int(running_by_user.get(int(user_id), 0) or 0), + "queue_running_user": _to_int(running_by_user.get(user_id_int, running_by_user.get(str(user_id_int), 0))), } - socketio.emit("task_progress", progress_data, room=f"user_{user_id}") + + progress_sig = _payload_signature(progress_data) + if _should_emit( + last_sig=cache_entry.get("progress_sig"), + last_ts=float(cache_entry.get("progress_ts", 0) or 0), + new_sig=progress_sig, + now_ts=now_ts, + min_interval=progress_min_interval, + force_interval=force_interval, + ): + socketio.emit("task_progress", progress_data, room=f"user_{user_id}") + cache_entry["progress_sig"] = progress_sig + cache_entry["progress_ts"] = now_ts + + if emitted_state: + stale_ids = [account_id for account_id in emitted_state.keys() if account_id not in active_account_ids] + for account_id in stale_ids: + emitted_state.pop(account_id, None) time.sleep(push_interval) except Exception as e: diff --git a/routes/admin_api/__init__.py b/routes/admin_api/__init__.py index ea19e72..32cd44b 100644 --- a/routes/admin_api/__init__.py +++ b/routes/admin_api/__init__.py @@ -8,6 +8,15 @@ admin_api_bp = Blueprint("admin_api", __name__, url_prefix="/yuyx/api") # Import side effects: register routes on blueprint from routes.admin_api import core as _core # noqa: F401 +from routes.admin_api import system_config_api as _system_config_api # noqa: F401 +from routes.admin_api import operations_api as _operations_api # noqa: F401 +from routes.admin_api import announcements_api as _announcements_api # noqa: F401 +from routes.admin_api import users_api as _users_api # noqa: F401 +from routes.admin_api import account_api as _account_api # noqa: F401 +from routes.admin_api import feedback_api as _feedback_api # noqa: F401 +from routes.admin_api import infra_api as _infra_api # noqa: F401 +from routes.admin_api import tasks_api as _tasks_api # noqa: F401 +from routes.admin_api import email_api as _email_api # noqa: F401 # Export security blueprint for app registration from routes.admin_api.security import security_bp # noqa: F401 diff --git a/routes/admin_api/account_api.py b/routes/admin_api/account_api.py new file mode 100644 index 0000000..6a7d3ba --- /dev/null +++ b/routes/admin_api/account_api.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import database +from app_security import validate_password +from flask import jsonify, request, session +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required + +# ==================== 密码重置 / 反馈(管理员) ==================== + + +@admin_api_bp.route("/admin/password", methods=["PUT"]) +@admin_required +def update_admin_password(): + """修改管理员密码""" + data = request.json or {} + new_password = (data.get("new_password") or "").strip() + + if not new_password: + return jsonify({"error": "密码不能为空"}), 400 + + username = session.get("admin_username") + if database.update_admin_password(username, new_password): + return jsonify({"success": True}) + return jsonify({"error": "修改失败"}), 400 + + +@admin_api_bp.route("/admin/username", methods=["PUT"]) +@admin_required +def update_admin_username(): + """修改管理员用户名""" + data = request.json or {} + new_username = (data.get("new_username") or "").strip() + + if not new_username: + return jsonify({"error": "用户名不能为空"}), 400 + + old_username = session.get("admin_username") + if database.update_admin_username(old_username, new_username): + session["admin_username"] = new_username + return jsonify({"success": True}) + return jsonify({"error": "修改失败,用户名可能已存在"}), 400 + + +@admin_api_bp.route("/users//reset_password", methods=["POST"]) +@admin_required +def admin_reset_password_route(user_id): + """管理员直接重置用户密码(无需审核)""" + data = request.json or {} + new_password = (data.get("new_password") or "").strip() + + if not new_password: + return jsonify({"error": "新密码不能为空"}), 400 + + is_valid, error_msg = validate_password(new_password) + if not is_valid: + return jsonify({"error": error_msg}), 400 + + if database.admin_reset_user_password(user_id, new_password): + return jsonify({"message": "密码重置成功"}) + return jsonify({"error": "重置失败,用户不存在"}), 400 diff --git a/routes/admin_api/announcements_api.py b/routes/admin_api/announcements_api.py new file mode 100644 index 0000000..70947f4 --- /dev/null +++ b/routes/admin_api/announcements_api.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import os +import posixpath +import secrets +import time + +import database +from app_config import get_config +from app_logger import get_logger +from app_security import is_safe_path, sanitize_filename +from flask import current_app, jsonify, request, url_for +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required + +logger = get_logger("app") +config = get_config() + + +def _get_upload_dir(): + rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements") + if not is_safe_path(current_app.root_path, rel_dir): + rel_dir = "static/announcements" + abs_dir = os.path.join(current_app.root_path, rel_dir) + os.makedirs(abs_dir, exist_ok=True) + return abs_dir, rel_dir + + +def _get_file_size(file_storage): + try: + file_storage.stream.seek(0, os.SEEK_END) + size = file_storage.stream.tell() + file_storage.stream.seek(0) + return size + except Exception: + return None + + +# ==================== 公告管理API(管理员) ==================== + + +@admin_api_bp.route("/announcements/upload_image", methods=["POST"]) +@admin_required +def admin_upload_announcement_image(): + """上传公告图片(返回可访问URL)""" + file = request.files.get("file") + if not file or not file.filename: + return jsonify({"error": "请选择图片"}), 400 + + filename = sanitize_filename(file.filename) + ext = os.path.splitext(filename)[1].lower() + allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"}) + if not ext or ext not in allowed_exts: + return jsonify({"error": "不支持的图片格式"}), 400 + if file.mimetype and not str(file.mimetype).startswith("image/"): + return jsonify({"error": "文件类型无效"}), 400 + + size = _get_file_size(file) + max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024)) + if size is not None and size > max_size: + max_mb = max_size // 1024 // 1024 + return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400 + + abs_dir, rel_dir = _get_upload_dir() + token = secrets.token_hex(6) + name = f"announcement_{int(time.time())}_{token}{ext}" + save_path = os.path.join(abs_dir, name) + file.save(save_path) + + static_root = os.path.join(current_app.root_path, "static") + rel_to_static = os.path.relpath(abs_dir, static_root) + if rel_to_static.startswith(".."): + rel_to_static = "announcements" + url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name) + return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)}) + + +@admin_api_bp.route("/announcements", methods=["GET"]) +@admin_required +def admin_get_announcements(): + """获取公告列表""" + try: + limit = int(request.args.get("limit", 50)) + offset = int(request.args.get("offset", 0)) + except (TypeError, ValueError): + limit, offset = 50, 0 + + limit = max(1, min(200, limit)) + offset = max(0, offset) + return jsonify(database.get_announcements(limit=limit, offset=offset)) + + +@admin_api_bp.route("/announcements", methods=["POST"]) +@admin_required +def admin_create_announcement(): + """创建公告(默认启用并替换旧公告)""" + data = request.json or {} + title = (data.get("title") or "").strip() + content = (data.get("content") or "").strip() + image_url = (data.get("image_url") or "").strip() + is_active = bool(data.get("is_active", True)) + + if image_url and len(image_url) > 1000: + return jsonify({"error": "图片地址过长"}), 400 + + announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active) + if not announcement_id: + return jsonify({"error": "标题和内容不能为空"}), 400 + + return jsonify({"success": True, "id": announcement_id}) + + +@admin_api_bp.route("/announcements//activate", methods=["POST"]) +@admin_required +def admin_activate_announcement(announcement_id): + """启用公告(会自动停用其他公告)""" + if not database.get_announcement_by_id(announcement_id): + return jsonify({"error": "公告不存在"}), 404 + ok = database.set_announcement_active(announcement_id, True) + return jsonify({"success": ok}) + + +@admin_api_bp.route("/announcements//deactivate", methods=["POST"]) +@admin_required +def admin_deactivate_announcement(announcement_id): + """停用公告""" + if not database.get_announcement_by_id(announcement_id): + return jsonify({"error": "公告不存在"}), 404 + ok = database.set_announcement_active(announcement_id, False) + return jsonify({"success": ok}) + + +@admin_api_bp.route("/announcements/", methods=["DELETE"]) +@admin_required +def admin_delete_announcement(announcement_id): + """删除公告""" + if not database.get_announcement_by_id(announcement_id): + return jsonify({"error": "公告不存在"}), 404 + ok = database.delete_announcement(announcement_id) + return jsonify({"success": ok}) + + diff --git a/routes/admin_api/core.py b/routes/admin_api/core.py index a744c4d..3449169 100644 --- a/routes/admin_api/core.py +++ b/routes/admin_api/core.py @@ -3,39 +3,19 @@ from __future__ import annotations import os -import posixpath -import secrets -import threading import time -from datetime import datetime import database -import email_service -import requests from app_config import get_config from app_logger import get_logger -from app_security import ( - get_rate_limit_ip, - is_safe_outbound_url, - is_safe_path, - require_ip_not_locked, - sanitize_filename, - validate_email, - validate_password, -) +from app_security import get_rate_limit_ip, require_ip_not_locked from flask import current_app, jsonify, redirect, request, session, url_for from routes.admin_api import admin_api_bp from routes.decorators import admin_required from services.accounts_service import load_user_accounts -from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type from services.checkpoints import get_checkpoint_mgr -from services.scheduler import run_scheduled_task from services.state import ( - safe_clear_user_logs, - safe_get_account, safe_get_user_accounts_snapshot, - safe_iter_task_status_items, - safe_remove_user_accounts, safe_verify_and_consume_captcha, check_login_ip_user_locked, check_login_rate_limits, @@ -46,42 +26,11 @@ from services.state import ( clear_login_failures, record_login_failure, ) -from services.tasks import get_task_scheduler, submit_account_task -from services.time_utils import BEIJING_TZ, get_beijing_now +from services.tasks import submit_account_task logger = get_logger("app") config = get_config() -_server_cpu_percent_lock = threading.Lock() -_server_cpu_percent_last: float | None = None -_server_cpu_percent_last_ts = 0.0 - - -def _get_server_cpu_percent() -> float: - import psutil - - global _server_cpu_percent_last, _server_cpu_percent_last_ts - - now = time.time() - with _server_cpu_percent_lock: - if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5: - return _server_cpu_percent_last - - try: - if _server_cpu_percent_last is None: - cpu_percent = float(psutil.cpu_percent(interval=0.1)) - else: - cpu_percent = float(psutil.cpu_percent(interval=None)) - except Exception: - cpu_percent = float(_server_cpu_percent_last or 0.0) - - if cpu_percent < 0: - cpu_percent = 0.0 - - _server_cpu_percent_last = cpu_percent - _server_cpu_percent_last_ts = now - return cpu_percent - def _admin_reauth_required() -> bool: try: @@ -95,25 +44,6 @@ def _require_admin_reauth(): return jsonify({"error": "需要二次确认", "code": "reauth_required"}), 401 return None -def _get_upload_dir(): - rel_dir = getattr(config, "ANNOUNCEMENT_IMAGE_DIR", "static/announcements") - if not is_safe_path(current_app.root_path, rel_dir): - rel_dir = "static/announcements" - abs_dir = os.path.join(current_app.root_path, rel_dir) - os.makedirs(abs_dir, exist_ok=True) - return abs_dir, rel_dir - - -def _get_file_size(file_storage): - try: - file_storage.stream.seek(0, os.SEEK_END) - size = file_storage.stream.tell() - file_storage.stream.seek(0) - return size - except Exception: - return None - - @admin_api_bp.route("/debug-config", methods=["GET"]) @admin_required def debug_config(): @@ -247,948 +177,6 @@ def admin_reauth(): session.modified = True return jsonify({"success": True, "expires_in": int(config.ADMIN_REAUTH_WINDOW_SECONDS)}) - -# ==================== 公告管理API(管理员) ==================== - - -@admin_api_bp.route("/announcements/upload_image", methods=["POST"]) -@admin_required -def admin_upload_announcement_image(): - """上传公告图片(返回可访问URL)""" - file = request.files.get("file") - if not file or not file.filename: - return jsonify({"error": "请选择图片"}), 400 - - filename = sanitize_filename(file.filename) - ext = os.path.splitext(filename)[1].lower() - allowed_exts = getattr(config, "ALLOWED_ANNOUNCEMENT_IMAGE_EXTENSIONS", {".png", ".jpg", ".jpeg"}) - if not ext or ext not in allowed_exts: - return jsonify({"error": "不支持的图片格式"}), 400 - if file.mimetype and not str(file.mimetype).startswith("image/"): - return jsonify({"error": "文件类型无效"}), 400 - - size = _get_file_size(file) - max_size = int(getattr(config, "MAX_ANNOUNCEMENT_IMAGE_SIZE", 5 * 1024 * 1024)) - if size is not None and size > max_size: - max_mb = max_size // 1024 // 1024 - return jsonify({"error": f"图片大小不能超过{max_mb}MB"}), 400 - - abs_dir, rel_dir = _get_upload_dir() - token = secrets.token_hex(6) - name = f"announcement_{int(time.time())}_{token}{ext}" - save_path = os.path.join(abs_dir, name) - file.save(save_path) - - static_root = os.path.join(current_app.root_path, "static") - rel_to_static = os.path.relpath(abs_dir, static_root) - if rel_to_static.startswith(".."): - rel_to_static = "announcements" - url_path = posixpath.join(rel_to_static.replace(os.sep, "/"), name) - return jsonify({"success": True, "url": url_for("serve_static", filename=url_path)}) - - -@admin_api_bp.route("/announcements", methods=["GET"]) -@admin_required -def admin_get_announcements(): - """获取公告列表""" - try: - limit = int(request.args.get("limit", 50)) - offset = int(request.args.get("offset", 0)) - except (TypeError, ValueError): - limit, offset = 50, 0 - - limit = max(1, min(200, limit)) - offset = max(0, offset) - return jsonify(database.get_announcements(limit=limit, offset=offset)) - - -@admin_api_bp.route("/announcements", methods=["POST"]) -@admin_required -def admin_create_announcement(): - """创建公告(默认启用并替换旧公告)""" - data = request.json or {} - title = (data.get("title") or "").strip() - content = (data.get("content") or "").strip() - image_url = (data.get("image_url") or "").strip() - is_active = bool(data.get("is_active", True)) - - if image_url and len(image_url) > 1000: - return jsonify({"error": "图片地址过长"}), 400 - - announcement_id = database.create_announcement(title, content, image_url=image_url, is_active=is_active) - if not announcement_id: - return jsonify({"error": "标题和内容不能为空"}), 400 - - return jsonify({"success": True, "id": announcement_id}) - - -@admin_api_bp.route("/announcements//activate", methods=["POST"]) -@admin_required -def admin_activate_announcement(announcement_id): - """启用公告(会自动停用其他公告)""" - if not database.get_announcement_by_id(announcement_id): - return jsonify({"error": "公告不存在"}), 404 - ok = database.set_announcement_active(announcement_id, True) - return jsonify({"success": ok}) - - -@admin_api_bp.route("/announcements//deactivate", methods=["POST"]) -@admin_required -def admin_deactivate_announcement(announcement_id): - """停用公告""" - if not database.get_announcement_by_id(announcement_id): - return jsonify({"error": "公告不存在"}), 404 - ok = database.set_announcement_active(announcement_id, False) - return jsonify({"success": ok}) - - -@admin_api_bp.route("/announcements/", methods=["DELETE"]) -@admin_required -def admin_delete_announcement(announcement_id): - """删除公告""" - if not database.get_announcement_by_id(announcement_id): - return jsonify({"error": "公告不存在"}), 404 - ok = database.delete_announcement(announcement_id) - return jsonify({"success": ok}) - - -# ==================== 用户管理/统计(管理员) ==================== - - -@admin_api_bp.route("/users", methods=["GET"]) -@admin_required -def get_all_users(): - """获取所有用户""" - users = database.get_all_users() - return jsonify(users) - - -@admin_api_bp.route("/users/pending", methods=["GET"]) -@admin_required -def get_pending_users(): - """获取待审核用户""" - users = database.get_pending_users() - return jsonify(users) - - -@admin_api_bp.route("/users//approve", methods=["POST"]) -@admin_required -def approve_user_route(user_id): - """审核通过用户""" - if database.approve_user(user_id): - return jsonify({"success": True}) - return jsonify({"error": "审核失败"}), 400 - - -@admin_api_bp.route("/users//reject", methods=["POST"]) -@admin_required -def reject_user_route(user_id): - """拒绝用户""" - if database.reject_user(user_id): - return jsonify({"success": True}) - return jsonify({"error": "拒绝失败"}), 400 - - -@admin_api_bp.route("/users/", methods=["DELETE"]) -@admin_required -def delete_user_route(user_id): - """删除用户""" - if database.delete_user(user_id): - safe_remove_user_accounts(user_id) - safe_clear_user_logs(user_id) - return jsonify({"success": True}) - return jsonify({"error": "删除失败"}), 400 - - -@admin_api_bp.route("/stats", methods=["GET"]) -@admin_required -def get_system_stats(): - """获取系统统计""" - stats = database.get_system_stats() - stats["admin_username"] = session.get("admin_username", "admin") - return jsonify(stats) - - -@admin_api_bp.route("/browser_pool/stats", methods=["GET"]) -@admin_required -def get_browser_pool_stats(): - """获取截图线程池状态""" - try: - from browser_pool_worker import get_browser_worker_pool - - pool = get_browser_worker_pool() - stats = pool.get_stats() or {} - - worker_details = [] - for w in stats.get("workers") or []: - last_ts = float(w.get("last_active_ts") or 0) - last_active_at = None - if last_ts > 0: - try: - last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S") - except Exception: - last_active_at = None - - created_ts = w.get("browser_created_at") - created_at = None - if created_ts: - try: - created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S") - except Exception: - created_at = None - - worker_details.append( - { - "worker_id": w.get("worker_id"), - "idle": bool(w.get("idle")), - "has_browser": bool(w.get("has_browser")), - "total_tasks": int(w.get("total_tasks") or 0), - "failed_tasks": int(w.get("failed_tasks") or 0), - "browser_use_count": int(w.get("browser_use_count") or 0), - "browser_created_at": created_at, - "browser_created_ts": created_ts, - "last_active_at": last_active_at, - "last_active_ts": last_ts, - "thread_alive": bool(w.get("thread_alive")), - } - ) - - total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0) - return jsonify( - { - "total_workers": total_workers, - "active_workers": int(stats.get("busy_workers") or 0), - "idle_workers": int(stats.get("idle_workers") or 0), - "queue_size": int(stats.get("queue_size") or 0), - "workers": worker_details, - "summary": { - "total_tasks": int(stats.get("total_tasks") or 0), - "failed_tasks": int(stats.get("failed_tasks") or 0), - "success_rate": stats.get("success_rate"), - }, - "server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"), - } - ) - except Exception as e: - logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}") - return jsonify({"error": "获取截图线程池状态失败"}), 500 - - -@admin_api_bp.route("/docker_stats", methods=["GET"]) -@admin_required -def get_docker_stats(): - """获取Docker容器运行状态""" - import subprocess - - docker_status = { - "running": False, - "container_name": "N/A", - "uptime": "N/A", - "memory_usage": "N/A", - "memory_limit": "N/A", - "memory_percent": "N/A", - "cpu_percent": "N/A", - "status": "Unknown", - } - - try: - if os.path.exists("/.dockerenv"): - docker_status["running"] = True - - try: - with open("/etc/hostname", "r") as f: - docker_status["container_name"] = f.read().strip() - except Exception as e: - logger.debug(f"读取容器名称失败: {e}") - - try: - if os.path.exists("/sys/fs/cgroup/memory.current"): - with open("/sys/fs/cgroup/memory.current", "r") as f: - mem_total = int(f.read().strip()) - - cache = 0 - if os.path.exists("/sys/fs/cgroup/memory.stat"): - with open("/sys/fs/cgroup/memory.stat", "r") as f: - for line in f: - if line.startswith("inactive_file "): - cache = int(line.split()[1]) - break - - mem_bytes = mem_total - cache - docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) - - if os.path.exists("/sys/fs/cgroup/memory.max"): - with open("/sys/fs/cgroup/memory.max", "r") as f: - limit_str = f.read().strip() - if limit_str != "max": - limit_bytes = int(limit_str) - docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) - docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) - elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"): - with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f: - mem_bytes = int(f.read().strip()) - docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) - - with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f: - limit_bytes = int(f.read().strip()) - if limit_bytes < 1e18: - docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) - docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) - except Exception as e: - logger.debug(f"读取内存信息失败: {e}") - - try: - if os.path.exists("/sys/fs/cgroup/cpu.stat"): - cpu_usage = 0 - with open("/sys/fs/cgroup/cpu.stat", "r") as f: - for line in f: - if line.startswith("usage_usec"): - cpu_usage = int(line.split()[1]) - break - - time.sleep(0.1) - cpu_usage2 = 0 - with open("/sys/fs/cgroup/cpu.stat", "r") as f: - for line in f: - if line.startswith("usage_usec"): - cpu_usage2 = int(line.split()[1]) - break - - cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100 - docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent) - elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"): - with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f: - cpu_usage = int(f.read().strip()) - - time.sleep(0.1) - with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f: - cpu_usage2 = int(f.read().strip()) - - cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100 - docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent) - except Exception as e: - logger.debug(f"读取CPU信息失败: {e}") - - try: - # 读取系统运行时间 - with open('/proc/uptime', 'r') as f: - system_uptime = float(f.read().split()[0]) - - # 读取 PID 1 的启动时间 (jiffies) - with open('/proc/1/stat', 'r') as f: - stat = f.read().split() - starttime_jiffies = int(stat[21]) - - # 获取 CLK_TCK (通常是 100) - clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK']) - - # 计算容器运行时长(秒) - container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck) - - # 格式化为可读字符串 - days = int(container_uptime_seconds // 86400) - hours = int((container_uptime_seconds % 86400) // 3600) - minutes = int((container_uptime_seconds % 3600) // 60) - - if days > 0: - docker_status["uptime"] = f"{days}天{hours}小时{minutes}分钟" - elif hours > 0: - docker_status["uptime"] = f"{hours}小时{minutes}分钟" - else: - docker_status["uptime"] = f"{minutes}分钟" - except Exception as e: - logger.debug(f"获取容器运行时间失败: {e}") - - docker_status["status"] = "Running" - - else: - docker_status["status"] = "Not in Docker" - except Exception as e: - docker_status["status"] = f"Error: {str(e)}" - - return jsonify(docker_status) - - -# ==================== VIP 管理(管理员) ==================== - - -@admin_api_bp.route("/vip/config", methods=["GET"]) -@admin_required -def get_vip_config_api(): - """获取VIP配置""" - config = database.get_vip_config() - return jsonify(config) - - -@admin_api_bp.route("/vip/config", methods=["POST"]) -@admin_required -def set_vip_config_api(): - """设置默认VIP天数""" - data = request.json or {} - days = data.get("default_vip_days", 0) - - if not isinstance(days, int) or days < 0: - return jsonify({"error": "VIP天数必须是非负整数"}), 400 - - database.set_default_vip_days(days) - return jsonify({"message": "VIP配置已更新", "default_vip_days": days}) - - -@admin_api_bp.route("/users//vip", methods=["POST"]) -@admin_required -def set_user_vip_api(user_id): - """设置用户VIP""" - data = request.json or {} - days = data.get("days", 30) - - valid_days = [7, 30, 365, 999999] - if days not in valid_days: - return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400 - - if database.set_user_vip(user_id, days): - vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days] - return jsonify({"message": f"VIP设置成功: {vip_type}"}) - return jsonify({"error": "设置失败,用户不存在"}), 400 - - -@admin_api_bp.route("/users//vip", methods=["DELETE"]) -@admin_required -def remove_user_vip_api(user_id): - """移除用户VIP""" - if database.remove_user_vip(user_id): - return jsonify({"message": "VIP已移除"}) - return jsonify({"error": "移除失败"}), 400 - - -@admin_api_bp.route("/users//vip", methods=["GET"]) -@admin_required -def get_user_vip_info_api(user_id): - """获取用户VIP信息(管理员)""" - vip_info = database.get_user_vip_info(user_id) - return jsonify(vip_info) - - -# ==================== 系统配置 / 定时 / 代理(管理员) ==================== - - -@admin_api_bp.route("/system/config", methods=["GET"]) -@admin_required -def get_system_config_api(): - """获取系统配置""" - return jsonify(database.get_system_config()) - - -@admin_api_bp.route("/system/config", methods=["POST"]) -@admin_required -def update_system_config_api(): - """更新系统配置""" - data = request.json or {} - - max_concurrent = data.get("max_concurrent_global") - schedule_enabled = data.get("schedule_enabled") - schedule_time = data.get("schedule_time") - schedule_browse_type = data.get("schedule_browse_type") - schedule_weekdays = data.get("schedule_weekdays") - new_max_concurrent_per_account = data.get("max_concurrent_per_account") - new_max_screenshot_concurrent = data.get("max_screenshot_concurrent") - enable_screenshot = data.get("enable_screenshot") - auto_approve_enabled = data.get("auto_approve_enabled") - auto_approve_hourly_limit = data.get("auto_approve_hourly_limit") - auto_approve_vip_days = data.get("auto_approve_vip_days") - kdocs_enabled = data.get("kdocs_enabled") - kdocs_doc_url = data.get("kdocs_doc_url") - kdocs_default_unit = data.get("kdocs_default_unit") - kdocs_sheet_name = data.get("kdocs_sheet_name") - kdocs_sheet_index = data.get("kdocs_sheet_index") - kdocs_unit_column = data.get("kdocs_unit_column") - kdocs_image_column = data.get("kdocs_image_column") - kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled") - kdocs_admin_notify_email = data.get("kdocs_admin_notify_email") - kdocs_row_start = data.get("kdocs_row_start") - kdocs_row_end = data.get("kdocs_row_end") - - if max_concurrent is not None: - if not isinstance(max_concurrent, int) or max_concurrent < 1: - return jsonify({"error": "全局并发数必须大于0(建议:小型服务器2-5,中型5-10,大型10-20)"}), 400 - - if new_max_concurrent_per_account is not None: - if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1: - return jsonify({"error": "单账号并发数必须大于0(建议设为1,避免同一用户任务相互影响)"}), 400 - - if new_max_screenshot_concurrent is not None: - if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1: - return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400 - - if enable_screenshot is not None: - if isinstance(enable_screenshot, bool): - enable_screenshot = 1 if enable_screenshot else 0 - if enable_screenshot not in (0, 1): - return jsonify({"error": "截图开关必须是0或1"}), 400 - - if schedule_time is not None: - import re - - if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time): - return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400 - - if schedule_browse_type is not None: - normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ) - if not normalized: - return jsonify({"error": "浏览类型无效"}), 400 - schedule_browse_type = normalized - - if schedule_weekdays is not None: - try: - days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()] - if not all(1 <= d <= 7 for d in days): - return jsonify({"error": "星期数字必须在1-7之间"}), 400 - except (ValueError, AttributeError): - return jsonify({"error": "星期格式错误"}), 400 - - if auto_approve_hourly_limit is not None: - if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1: - return jsonify({"error": "每小时注册限制必须大于0"}), 400 - - if auto_approve_vip_days is not None: - if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0: - return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400 - - if kdocs_enabled is not None: - if isinstance(kdocs_enabled, bool): - kdocs_enabled = 1 if kdocs_enabled else 0 - if kdocs_enabled not in (0, 1): - return jsonify({"error": "表格上传开关必须是0或1"}), 400 - - if kdocs_doc_url is not None: - kdocs_doc_url = str(kdocs_doc_url or "").strip() - if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url): - return jsonify({"error": "文档链接格式不正确"}), 400 - - if kdocs_default_unit is not None: - kdocs_default_unit = str(kdocs_default_unit or "").strip() - if len(kdocs_default_unit) > 50: - return jsonify({"error": "默认县区长度不能超过50"}), 400 - - if kdocs_sheet_name is not None: - kdocs_sheet_name = str(kdocs_sheet_name or "").strip() - if len(kdocs_sheet_name) > 50: - return jsonify({"error": "Sheet名称长度不能超过50"}), 400 - - if kdocs_sheet_index is not None: - try: - kdocs_sheet_index = int(kdocs_sheet_index) - except Exception: - return jsonify({"error": "Sheet序号必须是数字"}), 400 - if kdocs_sheet_index < 0: - return jsonify({"error": "Sheet序号不能为负数"}), 400 - - if kdocs_unit_column is not None: - kdocs_unit_column = str(kdocs_unit_column or "").strip().upper() - if not kdocs_unit_column: - return jsonify({"error": "县区列不能为空"}), 400 - import re - - if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column): - return jsonify({"error": "县区列格式错误"}), 400 - - if kdocs_image_column is not None: - kdocs_image_column = str(kdocs_image_column or "").strip().upper() - if not kdocs_image_column: - return jsonify({"error": "图片列不能为空"}), 400 - import re - - if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column): - return jsonify({"error": "图片列格式错误"}), 400 - - if kdocs_admin_notify_enabled is not None: - if isinstance(kdocs_admin_notify_enabled, bool): - kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0 - if kdocs_admin_notify_enabled not in (0, 1): - return jsonify({"error": "管理员通知开关必须是0或1"}), 400 - - if kdocs_admin_notify_email is not None: - kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip() - if kdocs_admin_notify_email: - is_valid, error_msg = validate_email(kdocs_admin_notify_email) - if not is_valid: - return jsonify({"error": error_msg}), 400 - - if kdocs_row_start is not None: - try: - kdocs_row_start = int(kdocs_row_start) - except (ValueError, TypeError): - return jsonify({"error": "起始行必须是数字"}), 400 - if kdocs_row_start < 0: - return jsonify({"error": "起始行不能为负数"}), 400 - - if kdocs_row_end is not None: - try: - kdocs_row_end = int(kdocs_row_end) - except (ValueError, TypeError): - return jsonify({"error": "结束行必须是数字"}), 400 - if kdocs_row_end < 0: - return jsonify({"error": "结束行不能为负数"}), 400 - - old_config = database.get_system_config() or {} - - if not database.update_system_config( - max_concurrent=max_concurrent, - schedule_enabled=schedule_enabled, - schedule_time=schedule_time, - schedule_browse_type=schedule_browse_type, - schedule_weekdays=schedule_weekdays, - max_concurrent_per_account=new_max_concurrent_per_account, - max_screenshot_concurrent=new_max_screenshot_concurrent, - enable_screenshot=enable_screenshot, - auto_approve_enabled=auto_approve_enabled, - auto_approve_hourly_limit=auto_approve_hourly_limit, - auto_approve_vip_days=auto_approve_vip_days, - kdocs_enabled=kdocs_enabled, - kdocs_doc_url=kdocs_doc_url, - kdocs_default_unit=kdocs_default_unit, - kdocs_sheet_name=kdocs_sheet_name, - kdocs_sheet_index=kdocs_sheet_index, - kdocs_unit_column=kdocs_unit_column, - kdocs_image_column=kdocs_image_column, - kdocs_admin_notify_enabled=kdocs_admin_notify_enabled, - kdocs_admin_notify_email=kdocs_admin_notify_email, - kdocs_row_start=kdocs_row_start, - kdocs_row_end=kdocs_row_end, - ): - return jsonify({"error": "更新失败"}), 400 - - try: - new_config = database.get_system_config() or {} - scheduler = get_task_scheduler() - scheduler.update_limits( - max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))), - max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))), - ) - if new_max_screenshot_concurrent is not None: - try: - from browser_pool_worker import resize_browser_worker_pool - - if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))): - logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}") - except Exception as pool_error: - logger.warning(f"截图线程池并发更新失败: {pool_error}") - except Exception: - pass - - if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"): - logger.info(f"全局并发数已更新为: {max_concurrent}") - if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"): - logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}") - if new_max_screenshot_concurrent is not None: - logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}") - - return jsonify({"message": "系统配置已更新"}) - - -@admin_api_bp.route("/kdocs/status", methods=["GET"]) -@admin_required -def get_kdocs_status_api(): - """获取金山文档上传状态""" - try: - from services.kdocs_uploader import get_kdocs_uploader - - uploader = get_kdocs_uploader() - status = uploader.get_status() - live = str(request.args.get("live", "")).lower() in ("1", "true", "yes") - if live: - live_status = uploader.refresh_login_status() - if live_status.get("success"): - logged_in = bool(live_status.get("logged_in")) - status["logged_in"] = logged_in - status["last_login_ok"] = logged_in - status["login_required"] = not logged_in - if live_status.get("error"): - status["last_error"] = live_status.get("error") - else: - status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None - if status.get("last_login_ok") is True and status.get("last_error") == "操作超时": - status["last_error"] = None - return jsonify(status) - except Exception as e: - return jsonify({"error": f"获取状态失败: {e}"}), 500 - - -@admin_api_bp.route("/kdocs/qr", methods=["POST"]) -@admin_required -def get_kdocs_qr_api(): - """获取金山文档登录二维码""" - try: - from services.kdocs_uploader import get_kdocs_uploader - - uploader = get_kdocs_uploader() - data = request.get_json(silent=True) or {} - force = bool(data.get("force")) - if not force: - force = str(request.args.get("force", "")).lower() in ("1", "true", "yes") - result = uploader.request_qr(force=force) - if not result.get("success"): - return jsonify({"error": result.get("error", "获取二维码失败")}), 400 - return jsonify(result) - except Exception as e: - return jsonify({"error": f"获取二维码失败: {e}"}), 500 - - -@admin_api_bp.route("/kdocs/clear-login", methods=["POST"]) -@admin_required -def clear_kdocs_login_api(): - """清除金山文档登录态""" - try: - from services.kdocs_uploader import get_kdocs_uploader - - uploader = get_kdocs_uploader() - result = uploader.clear_login() - if not result.get("success"): - return jsonify({"error": result.get("error", "清除失败")}), 400 - return jsonify({"success": True}) - except Exception as e: - return jsonify({"error": f"清除失败: {e}"}), 500 - - -@admin_api_bp.route("/schedule/execute", methods=["POST"]) -@admin_required -def execute_schedule_now(): - """立即执行定时任务(无视定时时间和星期限制)""" - try: - threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start() - logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)") - return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"}) - except Exception as e: - logger.error(f"[立即执行定时任务] 启动失败: {str(e)}") - return jsonify({"error": f"启动失败: {str(e)}"}), 500 - - -@admin_api_bp.route("/proxy/config", methods=["GET"]) -@admin_required -def get_proxy_config_api(): - """获取代理配置""" - config_data = database.get_system_config() - return jsonify( - { - "proxy_enabled": config_data.get("proxy_enabled", 0), - "proxy_api_url": config_data.get("proxy_api_url", ""), - "proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3), - } - ) - - -@admin_api_bp.route("/proxy/config", methods=["POST"]) -@admin_required -def update_proxy_config_api(): - """更新代理配置""" - data = request.json or {} - proxy_enabled = data.get("proxy_enabled") - proxy_api_url = (data.get("proxy_api_url", "") or "").strip() - proxy_expire_minutes = data.get("proxy_expire_minutes") - - if proxy_enabled is not None and proxy_enabled not in [0, 1]: - return jsonify({"error": "proxy_enabled必须是0或1"}), 400 - - if proxy_expire_minutes is not None: - if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1: - return jsonify({"error": "代理有效期必须是大于0的整数"}), 400 - - if database.update_system_config( - proxy_enabled=proxy_enabled, - proxy_api_url=proxy_api_url, - proxy_expire_minutes=proxy_expire_minutes, - ): - return jsonify({"message": "代理配置已更新"}) - return jsonify({"error": "更新失败"}), 400 - - -@admin_api_bp.route("/proxy/test", methods=["POST"]) -@admin_required -def test_proxy_api(): - """测试代理连接""" - data = request.json or {} - api_url = (data.get("api_url") or "").strip() - - if not api_url: - return jsonify({"error": "请提供API地址"}), 400 - - if not is_safe_outbound_url(api_url): - return jsonify({"error": "API地址不可用或不安全"}), 400 - - try: - response = requests.get(api_url, timeout=10) - if response.status_code == 200: - ip_port = response.text.strip() - if ip_port and ":" in ip_port: - return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"}) - return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400 - return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400 - except Exception as e: - return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500 - - -@admin_api_bp.route("/server/info", methods=["GET"]) -@admin_required -def get_server_info_api(): - """获取服务器信息""" - import psutil - - cpu_percent = _get_server_cpu_percent() - - memory = psutil.virtual_memory() - memory_total = f"{memory.total / (1024**3):.1f}GB" - memory_used = f"{memory.used / (1024**3):.1f}GB" - memory_percent = memory.percent - - disk = psutil.disk_usage("/") - disk_total = f"{disk.total / (1024**3):.1f}GB" - disk_used = f"{disk.used / (1024**3):.1f}GB" - disk_percent = disk.percent - - boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ) - uptime_delta = get_beijing_now() - boot_time - days = uptime_delta.days - hours = uptime_delta.seconds // 3600 - uptime = f"{days}天{hours}小时" - - return jsonify( - { - "cpu_percent": cpu_percent, - "memory_total": memory_total, - "memory_used": memory_used, - "memory_percent": memory_percent, - "disk_total": disk_total, - "disk_used": disk_used, - "disk_percent": disk_percent, - "uptime": uptime, - } - ) - - -# ==================== 任务统计与日志(管理员) ==================== - - -@admin_api_bp.route("/task/stats", methods=["GET"]) -@admin_required -def get_task_stats_api(): - """获取任务统计数据""" - date_filter = request.args.get("date") - stats = database.get_task_stats(date_filter) - return jsonify(stats) - - -@admin_api_bp.route("/task/running", methods=["GET"]) -@admin_required -def get_running_tasks_api(): - """获取当前运行中和排队中的任务""" - import time as time_mod - - current_time = time_mod.time() - running = [] - queuing = [] - - for account_id, info in safe_iter_task_status_items(): - elapsed = int(current_time - info.get("start_time", current_time)) - - user = database.get_user_by_id(info.get("user_id")) - user_username = user["username"] if user else "N/A" - - progress = info.get("progress", {"items": 0, "attachments": 0}) - task_info = { - "account_id": account_id, - "user_id": info.get("user_id"), - "user_username": user_username, - "username": info.get("username"), - "browse_type": info.get("browse_type"), - "source": info.get("source", "manual"), - "detail_status": info.get("detail_status", "未知"), - "progress_items": progress.get("items", 0), - "progress_attachments": progress.get("attachments", 0), - "elapsed_seconds": elapsed, - "elapsed_display": f"{elapsed // 60}分{elapsed % 60}秒" if elapsed >= 60 else f"{elapsed}秒", - } - - if info.get("status") == "运行中": - running.append(task_info) - else: - queuing.append(task_info) - - running.sort(key=lambda x: x["elapsed_seconds"], reverse=True) - queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True) - - try: - max_concurrent = int(get_task_scheduler().max_global) - except Exception: - max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2)) - - return jsonify( - { - "running": running, - "queuing": queuing, - "running_count": len(running), - "queuing_count": len(queuing), - "max_concurrent": max_concurrent, - } - ) - - -@admin_api_bp.route("/task/logs", methods=["GET"]) -@admin_required -def get_task_logs_api(): - """获取任务日志列表(支持分页和多种筛选)""" - try: - limit = int(request.args.get("limit", 20)) - limit = max(1, min(limit, 200)) # 限制 1-200 条 - except (ValueError, TypeError): - limit = 20 - - try: - offset = int(request.args.get("offset", 0)) - offset = max(0, offset) - except (ValueError, TypeError): - offset = 0 - - date_filter = request.args.get("date") - status_filter = request.args.get("status") - source_filter = request.args.get("source") - user_id_filter = request.args.get("user_id") - account_filter = (request.args.get("account") or "").strip() - - if user_id_filter: - try: - user_id_filter = int(user_id_filter) - except (ValueError, TypeError): - user_id_filter = None - - try: - result = database.get_task_logs( - limit=limit, - offset=offset, - date_filter=date_filter, - status_filter=status_filter, - source_filter=source_filter, - user_id_filter=user_id_filter, - account_filter=account_filter if account_filter else None, - ) - return jsonify(result) - except Exception as e: - logger.error(f"获取任务日志失败: {e}") - return jsonify({"logs": [], "total": 0, "error": "查询失败"}) - - -@admin_api_bp.route("/task/logs/clear", methods=["POST"]) -@admin_required -def clear_old_task_logs_api(): - """清理旧的任务日志""" - data = request.json or {} - days = data.get("days", 30) - - if not isinstance(days, int) or days < 1: - return jsonify({"error": "天数必须是大于0的整数"}), 400 - - deleted_count = database.delete_old_task_logs(days) - return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"}) - - @admin_api_bp.route("/docker/restart", methods=["POST"]) @admin_required def restart_docker_container(): @@ -1216,7 +204,7 @@ os._exit(0) f.write(restart_script) subprocess.Popen( - ["python", "/tmp/restart_container.py"], + ["python3", "/tmp/restart_container.py"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True, @@ -1228,112 +216,6 @@ os._exit(0) return jsonify({"error": f"重启失败: {str(e)}"}), 500 -# ==================== 密码重置 / 反馈(管理员) ==================== - - -@admin_api_bp.route("/admin/password", methods=["PUT"]) -@admin_required -def update_admin_password(): - """修改管理员密码""" - data = request.json or {} - new_password = (data.get("new_password") or "").strip() - - if not new_password: - return jsonify({"error": "密码不能为空"}), 400 - - username = session.get("admin_username") - if database.update_admin_password(username, new_password): - return jsonify({"success": True}) - return jsonify({"error": "修改失败"}), 400 - - -@admin_api_bp.route("/admin/username", methods=["PUT"]) -@admin_required -def update_admin_username(): - """修改管理员用户名""" - data = request.json or {} - new_username = (data.get("new_username") or "").strip() - - if not new_username: - return jsonify({"error": "用户名不能为空"}), 400 - - old_username = session.get("admin_username") - if database.update_admin_username(old_username, new_username): - session["admin_username"] = new_username - return jsonify({"success": True}) - return jsonify({"error": "修改失败,用户名可能已存在"}), 400 - - -@admin_api_bp.route("/users//reset_password", methods=["POST"]) -@admin_required -def admin_reset_password_route(user_id): - """管理员直接重置用户密码(无需审核)""" - data = request.json or {} - new_password = (data.get("new_password") or "").strip() - - if not new_password: - return jsonify({"error": "新密码不能为空"}), 400 - - is_valid, error_msg = validate_password(new_password) - if not is_valid: - return jsonify({"error": error_msg}), 400 - - if database.admin_reset_user_password(user_id, new_password): - return jsonify({"message": "密码重置成功"}) - return jsonify({"error": "重置失败,用户不存在"}), 400 - - -@admin_api_bp.route("/feedbacks", methods=["GET"]) -@admin_required -def get_all_feedbacks(): - """管理员获取所有反馈""" - status = request.args.get("status") - try: - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - limit = min(max(1, limit), 1000) - offset = max(0, offset) - except (ValueError, TypeError): - return jsonify({"error": "无效的分页参数"}), 400 - - feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status) - stats = database.get_feedback_stats() - return jsonify({"feedbacks": feedbacks, "stats": stats}) - - -@admin_api_bp.route("/feedbacks//reply", methods=["POST"]) -@admin_required -def reply_to_feedback(feedback_id): - """管理员回复反馈""" - data = request.get_json() or {} - reply = (data.get("reply") or "").strip() - - if not reply: - return jsonify({"error": "回复内容不能为空"}), 400 - - if database.reply_feedback(feedback_id, reply): - return jsonify({"message": "回复成功"}) - return jsonify({"error": "反馈不存在"}), 404 - - -@admin_api_bp.route("/feedbacks//close", methods=["POST"]) -@admin_required -def close_feedback_api(feedback_id): - """管理员关闭反馈""" - if database.close_feedback(feedback_id): - return jsonify({"message": "已关闭"}) - return jsonify({"error": "反馈不存在"}), 404 - - -@admin_api_bp.route("/feedbacks/", methods=["DELETE"]) -@admin_required -def delete_feedback_api(feedback_id): - """管理员删除反馈""" - if database.delete_feedback(feedback_id): - return jsonify({"message": "已删除"}) - return jsonify({"error": "反馈不存在"}), 404 - - # ==================== 断点续传(管理员) ==================== @@ -1388,208 +270,3 @@ def checkpoint_abandon(task_id): return jsonify({"success": False}), 404 except Exception as e: return jsonify({"success": False, "message": str(e)}), 500 - - -# ==================== 邮件服务(管理员) ==================== - - -@admin_api_bp.route("/email/settings", methods=["GET"]) -@admin_required -def get_email_settings_api(): - """获取全局邮件设置""" - try: - settings = email_service.get_email_settings() - return jsonify(settings) - except Exception as e: - logger.error(f"获取邮件设置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/email/settings", methods=["POST"]) -@admin_required -def update_email_settings_api(): - """更新全局邮件设置""" - try: - data = request.json or {} - enabled = data.get("enabled", False) - failover_enabled = data.get("failover_enabled", True) - register_verify_enabled = data.get("register_verify_enabled") - login_alert_enabled = data.get("login_alert_enabled") - base_url = data.get("base_url") - task_notify_enabled = data.get("task_notify_enabled") - - email_service.update_email_settings( - enabled=enabled, - failover_enabled=failover_enabled, - register_verify_enabled=register_verify_enabled, - login_alert_enabled=login_alert_enabled, - base_url=base_url, - task_notify_enabled=task_notify_enabled, - ) - return jsonify({"success": True}) - except Exception as e: - logger.error(f"更新邮件设置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs", methods=["GET"]) -@admin_required -def get_smtp_configs_api(): - """获取所有SMTP配置列表""" - try: - configs = email_service.get_smtp_configs(include_password=False) - return jsonify(configs) - except Exception as e: - logger.error(f"获取SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs", methods=["POST"]) -@admin_required -def create_smtp_config_api(): - """创建SMTP配置""" - try: - data = request.json or {} - if not data.get("host"): - return jsonify({"error": "SMTP服务器地址不能为空"}), 400 - if not data.get("username"): - return jsonify({"error": "SMTP用户名不能为空"}), 400 - - config_id = email_service.create_smtp_config(data) - return jsonify({"success": True, "id": config_id}) - except Exception as e: - logger.error(f"创建SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs/", methods=["GET"]) -@admin_required -def get_smtp_config_api(config_id): - """获取单个SMTP配置详情""" - try: - config_data = email_service.get_smtp_config(config_id, include_password=False) - if not config_data: - return jsonify({"error": "配置不存在"}), 404 - return jsonify(config_data) - except Exception as e: - logger.error(f"获取SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs/", methods=["PUT"]) -@admin_required -def update_smtp_config_api(config_id): - """更新SMTP配置""" - try: - data = request.json or {} - if email_service.update_smtp_config(config_id, data): - return jsonify({"success": True}) - return jsonify({"error": "更新失败"}), 400 - except Exception as e: - logger.error(f"更新SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs/", methods=["DELETE"]) -@admin_required -def delete_smtp_config_api(config_id): - """删除SMTP配置""" - try: - if email_service.delete_smtp_config(config_id): - return jsonify({"success": True}) - return jsonify({"error": "删除失败"}), 400 - except Exception as e: - logger.error(f"删除SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs//test", methods=["POST"]) -@admin_required -def test_smtp_config_api(config_id): - """测试SMTP配置""" - try: - data = request.json or {} - test_email = str(data.get("email", "") or "").strip() - if not test_email: - return jsonify({"error": "请提供测试邮箱"}), 400 - - is_valid, error_msg = validate_email(test_email) - if not is_valid: - return jsonify({"error": error_msg}), 400 - - result = email_service.test_smtp_config(config_id, test_email) - return jsonify(result) - except Exception as e: - logger.error(f"测试SMTP配置失败: {e}") - return jsonify({"success": False, "error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs//primary", methods=["POST"]) -@admin_required -def set_primary_smtp_config_api(config_id): - """设置主SMTP配置""" - try: - if email_service.set_primary_smtp_config(config_id): - return jsonify({"success": True}) - return jsonify({"error": "设置失败"}), 400 - except Exception as e: - logger.error(f"设置主SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"]) -@admin_required -def clear_primary_smtp_config_api(): - """取消主SMTP配置""" - try: - email_service.clear_primary_smtp_config() - return jsonify({"success": True}) - except Exception as e: - logger.error(f"取消主SMTP配置失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/email/stats", methods=["GET"]) -@admin_required -def get_email_stats_api(): - """获取邮件发送统计""" - try: - stats = email_service.get_email_stats() - return jsonify(stats) - except Exception as e: - logger.error(f"获取邮件统计失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/email/logs", methods=["GET"]) -@admin_required -def get_email_logs_api(): - """获取邮件发送日志""" - try: - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - email_type = request.args.get("type", None) - status = request.args.get("status", None) - - page_size = min(max(page_size, 10), 100) - result = email_service.get_email_logs(page, page_size, email_type, status) - return jsonify(result) - except Exception as e: - logger.error(f"获取邮件日志失败: {e}") - return jsonify({"error": str(e)}), 500 - - -@admin_api_bp.route("/email/logs/cleanup", methods=["POST"]) -@admin_required -def cleanup_email_logs_api(): - """清理过期邮件日志""" - try: - data = request.json or {} - days = data.get("days", 30) - days = min(max(days, 7), 365) - - deleted = email_service.cleanup_email_logs(days) - return jsonify({"success": True, "deleted": deleted}) - except Exception as e: - logger.error(f"清理邮件日志失败: {e}") - return jsonify({"error": str(e)}), 500 diff --git a/routes/admin_api/email_api.py b/routes/admin_api/email_api.py new file mode 100644 index 0000000..d4aaa5f --- /dev/null +++ b/routes/admin_api/email_api.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import email_service +from app_logger import get_logger +from app_security import validate_email +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required + +logger = get_logger("app") + + +@admin_api_bp.route("/email/settings", methods=["GET"]) +@admin_required +def get_email_settings_api(): + """获取全局邮件设置""" + try: + settings = email_service.get_email_settings() + return jsonify(settings) + except Exception as e: + logger.error(f"获取邮件设置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/email/settings", methods=["POST"]) +@admin_required +def update_email_settings_api(): + """更新全局邮件设置""" + try: + data = request.json or {} + enabled = data.get("enabled", False) + failover_enabled = data.get("failover_enabled", True) + register_verify_enabled = data.get("register_verify_enabled") + login_alert_enabled = data.get("login_alert_enabled") + base_url = data.get("base_url") + task_notify_enabled = data.get("task_notify_enabled") + + email_service.update_email_settings( + enabled=enabled, + failover_enabled=failover_enabled, + register_verify_enabled=register_verify_enabled, + login_alert_enabled=login_alert_enabled, + base_url=base_url, + task_notify_enabled=task_notify_enabled, + ) + return jsonify({"success": True}) + except Exception as e: + logger.error(f"更新邮件设置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs", methods=["GET"]) +@admin_required +def get_smtp_configs_api(): + """获取所有SMTP配置列表""" + try: + configs = email_service.get_smtp_configs(include_password=False) + return jsonify(configs) + except Exception as e: + logger.error(f"获取SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs", methods=["POST"]) +@admin_required +def create_smtp_config_api(): + """创建SMTP配置""" + try: + data = request.json or {} + if not data.get("host"): + return jsonify({"error": "SMTP服务器地址不能为空"}), 400 + if not data.get("username"): + return jsonify({"error": "SMTP用户名不能为空"}), 400 + + config_id = email_service.create_smtp_config(data) + return jsonify({"success": True, "id": config_id}) + except Exception as e: + logger.error(f"创建SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs/", methods=["GET"]) +@admin_required +def get_smtp_config_api(config_id): + """获取单个SMTP配置详情""" + try: + config_data = email_service.get_smtp_config(config_id, include_password=False) + if not config_data: + return jsonify({"error": "配置不存在"}), 404 + return jsonify(config_data) + except Exception as e: + logger.error(f"获取SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs/", methods=["PUT"]) +@admin_required +def update_smtp_config_api(config_id): + """更新SMTP配置""" + try: + data = request.json or {} + if email_service.update_smtp_config(config_id, data): + return jsonify({"success": True}) + return jsonify({"error": "更新失败"}), 400 + except Exception as e: + logger.error(f"更新SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs/", methods=["DELETE"]) +@admin_required +def delete_smtp_config_api(config_id): + """删除SMTP配置""" + try: + if email_service.delete_smtp_config(config_id): + return jsonify({"success": True}) + return jsonify({"error": "删除失败"}), 400 + except Exception as e: + logger.error(f"删除SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs//test", methods=["POST"]) +@admin_required +def test_smtp_config_api(config_id): + """测试SMTP配置""" + try: + data = request.json or {} + test_email = str(data.get("email", "") or "").strip() + if not test_email: + return jsonify({"error": "请提供测试邮箱"}), 400 + + is_valid, error_msg = validate_email(test_email) + if not is_valid: + return jsonify({"error": error_msg}), 400 + + result = email_service.test_smtp_config(config_id, test_email) + return jsonify(result) + except Exception as e: + logger.error(f"测试SMTP配置失败: {e}") + return jsonify({"success": False, "error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs//primary", methods=["POST"]) +@admin_required +def set_primary_smtp_config_api(config_id): + """设置主SMTP配置""" + try: + if email_service.set_primary_smtp_config(config_id): + return jsonify({"success": True}) + return jsonify({"error": "设置失败"}), 400 + except Exception as e: + logger.error(f"设置主SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/smtp/configs/primary/clear", methods=["POST"]) +@admin_required +def clear_primary_smtp_config_api(): + """取消主SMTP配置""" + try: + email_service.clear_primary_smtp_config() + return jsonify({"success": True}) + except Exception as e: + logger.error(f"取消主SMTP配置失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/email/stats", methods=["GET"]) +@admin_required +def get_email_stats_api(): + """获取邮件发送统计""" + try: + stats = email_service.get_email_stats() + return jsonify(stats) + except Exception as e: + logger.error(f"获取邮件统计失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/email/logs", methods=["GET"]) +@admin_required +def get_email_logs_api(): + """获取邮件发送日志""" + try: + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + email_type = request.args.get("type", None) + status = request.args.get("status", None) + + page_size = min(max(page_size, 10), 100) + result = email_service.get_email_logs(page, page_size, email_type, status) + return jsonify(result) + except Exception as e: + logger.error(f"获取邮件日志失败: {e}") + return jsonify({"error": str(e)}), 500 + + +@admin_api_bp.route("/email/logs/cleanup", methods=["POST"]) +@admin_required +def cleanup_email_logs_api(): + """清理过期邮件日志""" + try: + data = request.json or {} + days = data.get("days", 30) + days = min(max(days, 7), 365) + + deleted = email_service.cleanup_email_logs(days) + return jsonify({"success": True, "deleted": deleted}) + except Exception as e: + logger.error(f"清理邮件日志失败: {e}") + return jsonify({"error": str(e)}), 500 diff --git a/routes/admin_api/feedback_api.py b/routes/admin_api/feedback_api.py new file mode 100644 index 0000000..e921563 --- /dev/null +++ b/routes/admin_api/feedback_api.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import database +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required + +@admin_api_bp.route("/feedbacks", methods=["GET"]) +@admin_required +def get_all_feedbacks(): + """管理员获取所有反馈""" + status = request.args.get("status") + try: + limit = int(request.args.get("limit", 100)) + offset = int(request.args.get("offset", 0)) + limit = min(max(1, limit), 1000) + offset = max(0, offset) + except (ValueError, TypeError): + return jsonify({"error": "无效的分页参数"}), 400 + + feedbacks = database.get_bug_feedbacks(limit=limit, offset=offset, status_filter=status) + stats = database.get_feedback_stats() + return jsonify({"feedbacks": feedbacks, "stats": stats}) + + +@admin_api_bp.route("/feedbacks//reply", methods=["POST"]) +@admin_required +def reply_to_feedback(feedback_id): + """管理员回复反馈""" + data = request.get_json() or {} + reply = (data.get("reply") or "").strip() + + if not reply: + return jsonify({"error": "回复内容不能为空"}), 400 + + if database.reply_feedback(feedback_id, reply): + return jsonify({"message": "回复成功"}) + return jsonify({"error": "反馈不存在"}), 404 + + +@admin_api_bp.route("/feedbacks//close", methods=["POST"]) +@admin_required +def close_feedback_api(feedback_id): + """管理员关闭反馈""" + if database.close_feedback(feedback_id): + return jsonify({"message": "已关闭"}) + return jsonify({"error": "反馈不存在"}), 404 + + +@admin_api_bp.route("/feedbacks/", methods=["DELETE"]) +@admin_required +def delete_feedback_api(feedback_id): + """管理员删除反馈""" + if database.delete_feedback(feedback_id): + return jsonify({"message": "已删除"}) + return jsonify({"error": "反馈不存在"}), 404 diff --git a/routes/admin_api/infra_api.py b/routes/admin_api/infra_api.py new file mode 100644 index 0000000..fd8371b --- /dev/null +++ b/routes/admin_api/infra_api.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import os +import time +from datetime import datetime + +import database +from app_logger import get_logger +from flask import jsonify, session +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required +from services.time_utils import BEIJING_TZ, get_beijing_now + +logger = get_logger("app") + + +@admin_api_bp.route("/stats", methods=["GET"]) +@admin_required +def get_system_stats(): + """获取系统统计""" + stats = database.get_system_stats() + stats["admin_username"] = session.get("admin_username", "admin") + return jsonify(stats) + + +@admin_api_bp.route("/browser_pool/stats", methods=["GET"]) +@admin_required +def get_browser_pool_stats(): + """获取截图线程池状态""" + try: + from browser_pool_worker import get_browser_worker_pool + + pool = get_browser_worker_pool() + stats = pool.get_stats() or {} + + worker_details = [] + for w in stats.get("workers") or []: + last_ts = float(w.get("last_active_ts") or 0) + last_active_at = None + if last_ts > 0: + try: + last_active_at = datetime.fromtimestamp(last_ts, tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S") + except Exception: + last_active_at = None + + created_ts = w.get("browser_created_at") + created_at = None + if created_ts: + try: + created_at = datetime.fromtimestamp(float(created_ts), tz=BEIJING_TZ).strftime("%Y-%m-%d %H:%M:%S") + except Exception: + created_at = None + + worker_details.append( + { + "worker_id": w.get("worker_id"), + "idle": bool(w.get("idle")), + "has_browser": bool(w.get("has_browser")), + "total_tasks": int(w.get("total_tasks") or 0), + "failed_tasks": int(w.get("failed_tasks") or 0), + "browser_use_count": int(w.get("browser_use_count") or 0), + "browser_created_at": created_at, + "browser_created_ts": created_ts, + "last_active_at": last_active_at, + "last_active_ts": last_ts, + "thread_alive": bool(w.get("thread_alive")), + } + ) + + total_workers = len(worker_details) if worker_details else int(stats.get("pool_size") or 0) + return jsonify( + { + "total_workers": total_workers, + "active_workers": int(stats.get("busy_workers") or 0), + "idle_workers": int(stats.get("idle_workers") or 0), + "queue_size": int(stats.get("queue_size") or 0), + "workers": worker_details, + "summary": { + "total_tasks": int(stats.get("total_tasks") or 0), + "failed_tasks": int(stats.get("failed_tasks") or 0), + "success_rate": stats.get("success_rate"), + }, + "server_time_cst": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"), + } + ) + except Exception as e: + logger.exception(f"[AdminAPI] 获取截图线程池状态失败: {e}") + return jsonify({"error": "获取截图线程池状态失败"}), 500 + + +@admin_api_bp.route("/docker_stats", methods=["GET"]) +@admin_required +def get_docker_stats(): + """获取Docker容器运行状态""" + import subprocess + + docker_status = { + "running": False, + "container_name": "N/A", + "uptime": "N/A", + "memory_usage": "N/A", + "memory_limit": "N/A", + "memory_percent": "N/A", + "cpu_percent": "N/A", + "status": "Unknown", + } + + try: + if os.path.exists("/.dockerenv"): + docker_status["running"] = True + + try: + with open("/etc/hostname", "r") as f: + docker_status["container_name"] = f.read().strip() + except Exception as e: + logger.debug(f"读取容器名称失败: {e}") + + try: + if os.path.exists("/sys/fs/cgroup/memory.current"): + with open("/sys/fs/cgroup/memory.current", "r") as f: + mem_total = int(f.read().strip()) + + cache = 0 + if os.path.exists("/sys/fs/cgroup/memory.stat"): + with open("/sys/fs/cgroup/memory.stat", "r") as f: + for line in f: + if line.startswith("inactive_file "): + cache = int(line.split()[1]) + break + + mem_bytes = mem_total - cache + docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) + + if os.path.exists("/sys/fs/cgroup/memory.max"): + with open("/sys/fs/cgroup/memory.max", "r") as f: + limit_str = f.read().strip() + if limit_str != "max": + limit_bytes = int(limit_str) + docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) + docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) + elif os.path.exists("/sys/fs/cgroup/memory/memory.usage_in_bytes"): + with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f: + mem_bytes = int(f.read().strip()) + docker_status["memory_usage"] = "{:.2f} MB".format(mem_bytes / 1024 / 1024) + + with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f: + limit_bytes = int(f.read().strip()) + if limit_bytes < 1e18: + docker_status["memory_limit"] = "{:.2f} GB".format(limit_bytes / 1024 / 1024 / 1024) + docker_status["memory_percent"] = "{:.2f}%".format(mem_bytes / limit_bytes * 100) + except Exception as e: + logger.debug(f"读取内存信息失败: {e}") + + try: + if os.path.exists("/sys/fs/cgroup/cpu.stat"): + cpu_usage = 0 + with open("/sys/fs/cgroup/cpu.stat", "r") as f: + for line in f: + if line.startswith("usage_usec"): + cpu_usage = int(line.split()[1]) + break + + time.sleep(0.1) + cpu_usage2 = 0 + with open("/sys/fs/cgroup/cpu.stat", "r") as f: + for line in f: + if line.startswith("usage_usec"): + cpu_usage2 = int(line.split()[1]) + break + + cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e6 * 100 + docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent) + elif os.path.exists("/sys/fs/cgroup/cpu/cpuacct.usage"): + with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f: + cpu_usage = int(f.read().strip()) + + time.sleep(0.1) + with open("/sys/fs/cgroup/cpu/cpuacct.usage", "r") as f: + cpu_usage2 = int(f.read().strip()) + + cpu_percent = (cpu_usage2 - cpu_usage) / 0.1 / 1e9 * 100 + docker_status["cpu_percent"] = "{:.2f}%".format(cpu_percent) + except Exception as e: + logger.debug(f"读取CPU信息失败: {e}") + + try: + # 读取系统运行时间 + with open('/proc/uptime', 'r') as f: + system_uptime = float(f.read().split()[0]) + + # 读取 PID 1 的启动时间 (jiffies) + with open('/proc/1/stat', 'r') as f: + stat = f.read().split() + starttime_jiffies = int(stat[21]) + + # 获取 CLK_TCK (通常是 100) + clk_tck = os.sysconf(os.sysconf_names['SC_CLK_TCK']) + + # 计算容器运行时长(秒) + container_uptime_seconds = system_uptime - (starttime_jiffies / clk_tck) + + # 格式化为可读字符串 + days = int(container_uptime_seconds // 86400) + hours = int((container_uptime_seconds % 86400) // 3600) + minutes = int((container_uptime_seconds % 3600) // 60) + + if days > 0: + docker_status["uptime"] = f"{days}天{hours}小时{minutes}分钟" + elif hours > 0: + docker_status["uptime"] = f"{hours}小时{minutes}分钟" + else: + docker_status["uptime"] = f"{minutes}分钟" + except Exception as e: + logger.debug(f"获取容器运行时间失败: {e}") + + docker_status["status"] = "Running" + + else: + docker_status["status"] = "Not in Docker" + except Exception as e: + docker_status["status"] = f"Error: {str(e)}" + + return jsonify(docker_status) + diff --git a/routes/admin_api/operations_api.py b/routes/admin_api/operations_api.py new file mode 100644 index 0000000..7f35d8c --- /dev/null +++ b/routes/admin_api/operations_api.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import threading +import time +from datetime import datetime + +import database +import requests +from app_logger import get_logger +from app_security import is_safe_outbound_url +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required +from services.scheduler import run_scheduled_task +from services.time_utils import BEIJING_TZ, get_beijing_now + +logger = get_logger("app") + +_server_cpu_percent_lock = threading.Lock() +_server_cpu_percent_last: float | None = None +_server_cpu_percent_last_ts = 0.0 + + +def _get_server_cpu_percent() -> float: + import psutil + + global _server_cpu_percent_last, _server_cpu_percent_last_ts + + now = time.time() + with _server_cpu_percent_lock: + if _server_cpu_percent_last is not None and (now - _server_cpu_percent_last_ts) < 0.5: + return _server_cpu_percent_last + + try: + if _server_cpu_percent_last is None: + cpu_percent = float(psutil.cpu_percent(interval=0.1)) + else: + cpu_percent = float(psutil.cpu_percent(interval=None)) + except Exception: + cpu_percent = float(_server_cpu_percent_last or 0.0) + + if cpu_percent < 0: + cpu_percent = 0.0 + + _server_cpu_percent_last = cpu_percent + _server_cpu_percent_last_ts = now + return cpu_percent + + +@admin_api_bp.route("/kdocs/status", methods=["GET"]) +@admin_required +def get_kdocs_status_api(): + """获取金山文档上传状态""" + try: + from services.kdocs_uploader import get_kdocs_uploader + + uploader = get_kdocs_uploader() + status = uploader.get_status() + live = str(request.args.get("live", "")).lower() in ("1", "true", "yes") + if live: + live_status = uploader.refresh_login_status() + if live_status.get("success"): + logged_in = bool(live_status.get("logged_in")) + status["logged_in"] = logged_in + status["last_login_ok"] = logged_in + status["login_required"] = not logged_in + if live_status.get("error"): + status["last_error"] = live_status.get("error") + else: + status["logged_in"] = True if status.get("last_login_ok") else False if status.get("last_login_ok") is False else None + if status.get("last_login_ok") is True and status.get("last_error") == "操作超时": + status["last_error"] = None + return jsonify(status) + except Exception as e: + return jsonify({"error": f"获取状态失败: {e}"}), 500 + + +@admin_api_bp.route("/kdocs/qr", methods=["POST"]) +@admin_required +def get_kdocs_qr_api(): + """获取金山文档登录二维码""" + try: + from services.kdocs_uploader import get_kdocs_uploader + + uploader = get_kdocs_uploader() + data = request.get_json(silent=True) or {} + force = bool(data.get("force")) + if not force: + force = str(request.args.get("force", "")).lower() in ("1", "true", "yes") + result = uploader.request_qr(force=force) + if not result.get("success"): + return jsonify({"error": result.get("error", "获取二维码失败")}), 400 + return jsonify(result) + except Exception as e: + return jsonify({"error": f"获取二维码失败: {e}"}), 500 + + +@admin_api_bp.route("/kdocs/clear-login", methods=["POST"]) +@admin_required +def clear_kdocs_login_api(): + """清除金山文档登录态""" + try: + from services.kdocs_uploader import get_kdocs_uploader + + uploader = get_kdocs_uploader() + result = uploader.clear_login() + if not result.get("success"): + return jsonify({"error": result.get("error", "清除失败")}), 400 + return jsonify({"success": True}) + except Exception as e: + return jsonify({"error": f"清除失败: {e}"}), 500 + + +@admin_api_bp.route("/schedule/execute", methods=["POST"]) +@admin_required +def execute_schedule_now(): + """立即执行定时任务(无视定时时间和星期限制)""" + try: + threading.Thread(target=run_scheduled_task, args=(True,), daemon=True).start() + logger.info("[立即执行定时任务] 管理员手动触发定时任务执行(跳过星期检查)") + return jsonify({"message": "定时任务已开始执行,请查看任务列表获取进度"}) + except Exception as e: + logger.error(f"[立即执行定时任务] 启动失败: {str(e)}") + return jsonify({"error": f"启动失败: {str(e)}"}), 500 + + +@admin_api_bp.route("/proxy/config", methods=["GET"]) +@admin_required +def get_proxy_config_api(): + """获取代理配置""" + config_data = database.get_system_config() + return jsonify( + { + "proxy_enabled": config_data.get("proxy_enabled", 0), + "proxy_api_url": config_data.get("proxy_api_url", ""), + "proxy_expire_minutes": config_data.get("proxy_expire_minutes", 3), + } + ) + + +@admin_api_bp.route("/proxy/config", methods=["POST"]) +@admin_required +def update_proxy_config_api(): + """更新代理配置""" + data = request.json or {} + proxy_enabled = data.get("proxy_enabled") + proxy_api_url = (data.get("proxy_api_url", "") or "").strip() + proxy_expire_minutes = data.get("proxy_expire_minutes") + + if proxy_enabled is not None and proxy_enabled not in [0, 1]: + return jsonify({"error": "proxy_enabled必须是0或1"}), 400 + + if proxy_expire_minutes is not None: + if not isinstance(proxy_expire_minutes, int) or proxy_expire_minutes < 1: + return jsonify({"error": "代理有效期必须是大于0的整数"}), 400 + + if database.update_system_config( + proxy_enabled=proxy_enabled, + proxy_api_url=proxy_api_url, + proxy_expire_minutes=proxy_expire_minutes, + ): + return jsonify({"message": "代理配置已更新"}) + return jsonify({"error": "更新失败"}), 400 + + +@admin_api_bp.route("/proxy/test", methods=["POST"]) +@admin_required +def test_proxy_api(): + """测试代理连接""" + data = request.json or {} + api_url = (data.get("api_url") or "").strip() + + if not api_url: + return jsonify({"error": "请提供API地址"}), 400 + + if not is_safe_outbound_url(api_url): + return jsonify({"error": "API地址不可用或不安全"}), 400 + + try: + response = requests.get(api_url, timeout=10) + if response.status_code == 200: + ip_port = response.text.strip() + if ip_port and ":" in ip_port: + return jsonify({"success": True, "proxy": ip_port, "message": f"代理获取成功: {ip_port}"}) + return jsonify({"success": False, "message": f"代理格式错误: {ip_port}"}), 400 + return jsonify({"success": False, "message": f"HTTP错误: {response.status_code}"}), 400 + except Exception as e: + return jsonify({"success": False, "message": f"连接失败: {str(e)}"}), 500 + + +@admin_api_bp.route("/server/info", methods=["GET"]) +@admin_required +def get_server_info_api(): + """获取服务器信息""" + import psutil + + cpu_percent = _get_server_cpu_percent() + + memory = psutil.virtual_memory() + memory_total = f"{memory.total / (1024**3):.1f}GB" + memory_used = f"{memory.used / (1024**3):.1f}GB" + memory_percent = memory.percent + + disk = psutil.disk_usage("/") + disk_total = f"{disk.total / (1024**3):.1f}GB" + disk_used = f"{disk.used / (1024**3):.1f}GB" + disk_percent = disk.percent + + boot_time = datetime.fromtimestamp(psutil.boot_time(), tz=BEIJING_TZ) + uptime_delta = get_beijing_now() - boot_time + days = uptime_delta.days + hours = uptime_delta.seconds // 3600 + uptime = f"{days}天{hours}小时" + + return jsonify( + { + "cpu_percent": cpu_percent, + "memory_total": memory_total, + "memory_used": memory_used, + "memory_percent": memory_percent, + "disk_total": disk_total, + "disk_used": disk_used, + "disk_percent": disk_percent, + "uptime": uptime, + } + ) diff --git a/routes/admin_api/security.py b/routes/admin_api/security.py index f6a4caa..9399f7c 100644 --- a/routes/admin_api/security.py +++ b/routes/admin_api/security.py @@ -62,6 +62,19 @@ def _parse_bool(value: Any) -> bool: return text in {"1", "true", "yes", "y", "on"} +def _parse_int(value: Any, *, default: int | None = None, min_value: int | None = None) -> int | None: + try: + parsed = int(value) + except Exception: + parsed = default + + if parsed is None: + return None + if min_value is not None: + parsed = max(int(min_value), parsed) + return parsed + + def _sanitize_threat_event(event: dict) -> dict: return { "id": event.get("id"), @@ -199,10 +212,7 @@ def ban_ip(): if not reason: return jsonify({"error": "reason不能为空"}), 400 - try: - duration_hours = max(1, int(duration_hours_raw)) - except Exception: - duration_hours = 24 + duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24 ok = blacklist.ban_ip(ip, reason, duration_hours=duration_hours, permanent=permanent) if not ok: @@ -235,20 +245,14 @@ def ban_user(): duration_hours_raw = data.get("duration_hours", 24) permanent = _parse_bool(data.get("permanent", False)) - try: - user_id = int(user_id_raw) - except Exception: - user_id = None + user_id = _parse_int(user_id_raw) if user_id is None: return jsonify({"error": "user_id不能为空"}), 400 if not reason: return jsonify({"error": "reason不能为空"}), 400 - try: - duration_hours = max(1, int(duration_hours_raw)) - except Exception: - duration_hours = 24 + duration_hours = _parse_int(duration_hours_raw, default=24, min_value=1) or 24 ok = blacklist._ban_user_internal(user_id, reason=reason, duration_hours=duration_hours, permanent=permanent) if not ok: @@ -262,10 +266,7 @@ def unban_user(): """解除用户封禁""" data = _parse_json() user_id_raw = data.get("user_id") - try: - user_id = int(user_id_raw) - except Exception: - user_id = None + user_id = _parse_int(user_id_raw) if user_id is None: return jsonify({"error": "user_id不能为空"}), 400 diff --git a/routes/admin_api/system_config_api.py b/routes/admin_api/system_config_api.py new file mode 100644 index 0000000..478b950 --- /dev/null +++ b/routes/admin_api/system_config_api.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import database +from app_logger import get_logger +from app_security import is_safe_outbound_url, validate_email +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required +from services.browse_types import BROWSE_TYPE_SHOULD_READ, validate_browse_type +from services.tasks import get_task_scheduler + +logger = get_logger("app") + + +@admin_api_bp.route("/system/config", methods=["GET"]) +@admin_required +def get_system_config_api(): + """获取系统配置""" + return jsonify(database.get_system_config()) + + +@admin_api_bp.route("/system/config", methods=["POST"]) +@admin_required +def update_system_config_api(): + """更新系统配置""" + data = request.json or {} + + max_concurrent = data.get("max_concurrent_global") + schedule_enabled = data.get("schedule_enabled") + schedule_time = data.get("schedule_time") + schedule_browse_type = data.get("schedule_browse_type") + schedule_weekdays = data.get("schedule_weekdays") + new_max_concurrent_per_account = data.get("max_concurrent_per_account") + new_max_screenshot_concurrent = data.get("max_screenshot_concurrent") + enable_screenshot = data.get("enable_screenshot") + auto_approve_enabled = data.get("auto_approve_enabled") + auto_approve_hourly_limit = data.get("auto_approve_hourly_limit") + auto_approve_vip_days = data.get("auto_approve_vip_days") + kdocs_enabled = data.get("kdocs_enabled") + kdocs_doc_url = data.get("kdocs_doc_url") + kdocs_default_unit = data.get("kdocs_default_unit") + kdocs_sheet_name = data.get("kdocs_sheet_name") + kdocs_sheet_index = data.get("kdocs_sheet_index") + kdocs_unit_column = data.get("kdocs_unit_column") + kdocs_image_column = data.get("kdocs_image_column") + kdocs_admin_notify_enabled = data.get("kdocs_admin_notify_enabled") + kdocs_admin_notify_email = data.get("kdocs_admin_notify_email") + kdocs_row_start = data.get("kdocs_row_start") + kdocs_row_end = data.get("kdocs_row_end") + + if max_concurrent is not None: + if not isinstance(max_concurrent, int) or max_concurrent < 1: + return jsonify({"error": "全局并发数必须大于0(建议:小型服务器2-5,中型5-10,大型10-20)"}), 400 + + if new_max_concurrent_per_account is not None: + if not isinstance(new_max_concurrent_per_account, int) or new_max_concurrent_per_account < 1: + return jsonify({"error": "单账号并发数必须大于0(建议设为1,避免同一用户任务相互影响)"}), 400 + + if new_max_screenshot_concurrent is not None: + if not isinstance(new_max_screenshot_concurrent, int) or new_max_screenshot_concurrent < 1: + return jsonify({"error": "截图并发数必须大于0(建议根据服务器配置设置,wkhtmltoimage 资源占用较低)"}), 400 + + if enable_screenshot is not None: + if isinstance(enable_screenshot, bool): + enable_screenshot = 1 if enable_screenshot else 0 + if enable_screenshot not in (0, 1): + return jsonify({"error": "截图开关必须是0或1"}), 400 + + if schedule_time is not None: + import re + + if not re.match(r"^([01]\\d|2[0-3]):([0-5]\\d)$", schedule_time): + return jsonify({"error": "时间格式错误,应为 HH:MM"}), 400 + + if schedule_browse_type is not None: + normalized = validate_browse_type(schedule_browse_type, default=BROWSE_TYPE_SHOULD_READ) + if not normalized: + return jsonify({"error": "浏览类型无效"}), 400 + schedule_browse_type = normalized + + if schedule_weekdays is not None: + try: + days = [int(d.strip()) for d in schedule_weekdays.split(",") if d.strip()] + if not all(1 <= d <= 7 for d in days): + return jsonify({"error": "星期数字必须在1-7之间"}), 400 + except (ValueError, AttributeError): + return jsonify({"error": "星期格式错误"}), 400 + + if auto_approve_hourly_limit is not None: + if not isinstance(auto_approve_hourly_limit, int) or auto_approve_hourly_limit < 1: + return jsonify({"error": "每小时注册限制必须大于0"}), 400 + + if auto_approve_vip_days is not None: + if not isinstance(auto_approve_vip_days, int) or auto_approve_vip_days < 0: + return jsonify({"error": "注册赠送VIP天数不能为负数"}), 400 + + if kdocs_enabled is not None: + if isinstance(kdocs_enabled, bool): + kdocs_enabled = 1 if kdocs_enabled else 0 + if kdocs_enabled not in (0, 1): + return jsonify({"error": "表格上传开关必须是0或1"}), 400 + + if kdocs_doc_url is not None: + kdocs_doc_url = str(kdocs_doc_url or "").strip() + if kdocs_doc_url and not is_safe_outbound_url(kdocs_doc_url): + return jsonify({"error": "文档链接格式不正确"}), 400 + + if kdocs_default_unit is not None: + kdocs_default_unit = str(kdocs_default_unit or "").strip() + if len(kdocs_default_unit) > 50: + return jsonify({"error": "默认县区长度不能超过50"}), 400 + + if kdocs_sheet_name is not None: + kdocs_sheet_name = str(kdocs_sheet_name or "").strip() + if len(kdocs_sheet_name) > 50: + return jsonify({"error": "Sheet名称长度不能超过50"}), 400 + + if kdocs_sheet_index is not None: + try: + kdocs_sheet_index = int(kdocs_sheet_index) + except Exception: + return jsonify({"error": "Sheet序号必须是数字"}), 400 + if kdocs_sheet_index < 0: + return jsonify({"error": "Sheet序号不能为负数"}), 400 + + if kdocs_unit_column is not None: + kdocs_unit_column = str(kdocs_unit_column or "").strip().upper() + if not kdocs_unit_column: + return jsonify({"error": "县区列不能为空"}), 400 + import re + + if not re.match(r"^[A-Z]{1,3}$", kdocs_unit_column): + return jsonify({"error": "县区列格式错误"}), 400 + + if kdocs_image_column is not None: + kdocs_image_column = str(kdocs_image_column or "").strip().upper() + if not kdocs_image_column: + return jsonify({"error": "图片列不能为空"}), 400 + import re + + if not re.match(r"^[A-Z]{1,3}$", kdocs_image_column): + return jsonify({"error": "图片列格式错误"}), 400 + + if kdocs_admin_notify_enabled is not None: + if isinstance(kdocs_admin_notify_enabled, bool): + kdocs_admin_notify_enabled = 1 if kdocs_admin_notify_enabled else 0 + if kdocs_admin_notify_enabled not in (0, 1): + return jsonify({"error": "管理员通知开关必须是0或1"}), 400 + + if kdocs_admin_notify_email is not None: + kdocs_admin_notify_email = str(kdocs_admin_notify_email or "").strip() + if kdocs_admin_notify_email: + is_valid, error_msg = validate_email(kdocs_admin_notify_email) + if not is_valid: + return jsonify({"error": error_msg}), 400 + + if kdocs_row_start is not None: + try: + kdocs_row_start = int(kdocs_row_start) + except (ValueError, TypeError): + return jsonify({"error": "起始行必须是数字"}), 400 + if kdocs_row_start < 0: + return jsonify({"error": "起始行不能为负数"}), 400 + + if kdocs_row_end is not None: + try: + kdocs_row_end = int(kdocs_row_end) + except (ValueError, TypeError): + return jsonify({"error": "结束行必须是数字"}), 400 + if kdocs_row_end < 0: + return jsonify({"error": "结束行不能为负数"}), 400 + + old_config = database.get_system_config() or {} + + if not database.update_system_config( + max_concurrent=max_concurrent, + schedule_enabled=schedule_enabled, + schedule_time=schedule_time, + schedule_browse_type=schedule_browse_type, + schedule_weekdays=schedule_weekdays, + max_concurrent_per_account=new_max_concurrent_per_account, + max_screenshot_concurrent=new_max_screenshot_concurrent, + enable_screenshot=enable_screenshot, + auto_approve_enabled=auto_approve_enabled, + auto_approve_hourly_limit=auto_approve_hourly_limit, + auto_approve_vip_days=auto_approve_vip_days, + kdocs_enabled=kdocs_enabled, + kdocs_doc_url=kdocs_doc_url, + kdocs_default_unit=kdocs_default_unit, + kdocs_sheet_name=kdocs_sheet_name, + kdocs_sheet_index=kdocs_sheet_index, + kdocs_unit_column=kdocs_unit_column, + kdocs_image_column=kdocs_image_column, + kdocs_admin_notify_enabled=kdocs_admin_notify_enabled, + kdocs_admin_notify_email=kdocs_admin_notify_email, + kdocs_row_start=kdocs_row_start, + kdocs_row_end=kdocs_row_end, + ): + return jsonify({"error": "更新失败"}), 400 + + try: + new_config = database.get_system_config() or {} + scheduler = get_task_scheduler() + scheduler.update_limits( + max_global=int(new_config.get("max_concurrent_global", old_config.get("max_concurrent_global", 2))), + max_per_user=int(new_config.get("max_concurrent_per_account", old_config.get("max_concurrent_per_account", 1))), + ) + if new_max_screenshot_concurrent is not None: + try: + from browser_pool_worker import resize_browser_worker_pool + + if resize_browser_worker_pool(int(new_config.get("max_screenshot_concurrent", new_max_screenshot_concurrent))): + logger.info(f"截图线程池并发已更新为: {new_config.get('max_screenshot_concurrent')}") + except Exception as pool_error: + logger.warning(f"截图线程池并发更新失败: {pool_error}") + except Exception: + pass + + if max_concurrent is not None and max_concurrent != old_config.get("max_concurrent_global"): + logger.info(f"全局并发数已更新为: {max_concurrent}") + if new_max_concurrent_per_account is not None and new_max_concurrent_per_account != old_config.get("max_concurrent_per_account"): + logger.info(f"单用户并发数已更新为: {new_max_concurrent_per_account}") + if new_max_screenshot_concurrent is not None: + logger.info(f"截图并发数已更新为: {new_max_screenshot_concurrent}") + + return jsonify({"message": "系统配置已更新"}) diff --git a/routes/admin_api/tasks_api.py b/routes/admin_api/tasks_api.py new file mode 100644 index 0000000..3814717 --- /dev/null +++ b/routes/admin_api/tasks_api.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import database +from app_logger import get_logger +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required +from services.state import safe_iter_task_status_items +from services.tasks import get_task_scheduler + +logger = get_logger("app") + + +def _parse_page_int(name: str, default: int, *, minimum: int, maximum: int) -> int: + try: + value = int(request.args.get(name, default)) + return max(minimum, min(value, maximum)) + except (ValueError, TypeError): + return default + + +@admin_api_bp.route("/task/stats", methods=["GET"]) +@admin_required +def get_task_stats_api(): + """获取任务统计数据""" + date_filter = request.args.get("date") + stats = database.get_task_stats(date_filter) + return jsonify(stats) + + +@admin_api_bp.route("/task/running", methods=["GET"]) +@admin_required +def get_running_tasks_api(): + """获取当前运行中和排队中的任务""" + import time as time_mod + + current_time = time_mod.time() + running = [] + queuing = [] + user_cache = {} + + for account_id, info in safe_iter_task_status_items(): + elapsed = int(current_time - info.get("start_time", current_time)) + + info_user_id = info.get("user_id") + if info_user_id not in user_cache: + user_cache[info_user_id] = database.get_user_by_id(info_user_id) + user = user_cache.get(info_user_id) + user_username = user["username"] if user else "N/A" + + progress = info.get("progress", {"items": 0, "attachments": 0}) + task_info = { + "account_id": account_id, + "user_id": info.get("user_id"), + "user_username": user_username, + "username": info.get("username"), + "browse_type": info.get("browse_type"), + "source": info.get("source", "manual"), + "detail_status": info.get("detail_status", "未知"), + "progress_items": progress.get("items", 0), + "progress_attachments": progress.get("attachments", 0), + "elapsed_seconds": elapsed, + "elapsed_display": f"{elapsed // 60}分{elapsed % 60}秒" if elapsed >= 60 else f"{elapsed}秒", + } + + if info.get("status") == "运行中": + running.append(task_info) + else: + queuing.append(task_info) + + running.sort(key=lambda x: x["elapsed_seconds"], reverse=True) + queuing.sort(key=lambda x: x["elapsed_seconds"], reverse=True) + + try: + max_concurrent = int(get_task_scheduler().max_global) + except Exception: + max_concurrent = int((database.get_system_config() or {}).get("max_concurrent_global", 2)) + + return jsonify( + { + "running": running, + "queuing": queuing, + "running_count": len(running), + "queuing_count": len(queuing), + "max_concurrent": max_concurrent, + } + ) + + +@admin_api_bp.route("/task/logs", methods=["GET"]) +@admin_required +def get_task_logs_api(): + """获取任务日志列表(支持分页和多种筛选)""" + limit = _parse_page_int("limit", 20, minimum=1, maximum=200) + offset = _parse_page_int("offset", 0, minimum=0, maximum=10**9) + + date_filter = request.args.get("date") + status_filter = request.args.get("status") + source_filter = request.args.get("source") + user_id_filter = request.args.get("user_id") + account_filter = (request.args.get("account") or "").strip() + + if user_id_filter: + try: + user_id_filter = int(user_id_filter) + except (ValueError, TypeError): + user_id_filter = None + + try: + result = database.get_task_logs( + limit=limit, + offset=offset, + date_filter=date_filter, + status_filter=status_filter, + source_filter=source_filter, + user_id_filter=user_id_filter, + account_filter=account_filter if account_filter else None, + ) + return jsonify(result) + except Exception as e: + logger.error(f"获取任务日志失败: {e}") + return jsonify({"logs": [], "total": 0, "error": "查询失败"}) + + +@admin_api_bp.route("/task/logs/clear", methods=["POST"]) +@admin_required +def clear_old_task_logs_api(): + """清理旧的任务日志""" + data = request.json or {} + days = data.get("days", 30) + + if not isinstance(days, int) or days < 1: + return jsonify({"error": "天数必须是大于0的整数"}), 400 + + deleted_count = database.delete_old_task_logs(days) + return jsonify({"message": f"已删除{days}天前的{deleted_count}条日志"}) diff --git a/routes/admin_api/users_api.py b/routes/admin_api/users_api.py new file mode 100644 index 0000000..8357d43 --- /dev/null +++ b/routes/admin_api/users_api.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + + +import database +from flask import jsonify, request +from routes.admin_api import admin_api_bp +from routes.decorators import admin_required +from services.state import safe_clear_user_logs, safe_remove_user_accounts + + +# ==================== 用户管理/统计(管理员) ==================== + + +@admin_api_bp.route("/users", methods=["GET"]) +@admin_required +def get_all_users(): + """获取所有用户""" + users = database.get_all_users() + return jsonify(users) + + +@admin_api_bp.route("/users/pending", methods=["GET"]) +@admin_required +def get_pending_users(): + """获取待审核用户""" + users = database.get_pending_users() + return jsonify(users) + + +@admin_api_bp.route("/users//approve", methods=["POST"]) +@admin_required +def approve_user_route(user_id): + """审核通过用户""" + if database.approve_user(user_id): + return jsonify({"success": True}) + return jsonify({"error": "审核失败"}), 400 + + +@admin_api_bp.route("/users//reject", methods=["POST"]) +@admin_required +def reject_user_route(user_id): + """拒绝用户""" + if database.reject_user(user_id): + return jsonify({"success": True}) + return jsonify({"error": "拒绝失败"}), 400 + + +@admin_api_bp.route("/users/", methods=["DELETE"]) +@admin_required +def delete_user_route(user_id): + """删除用户""" + if database.delete_user(user_id): + safe_remove_user_accounts(user_id) + safe_clear_user_logs(user_id) + return jsonify({"success": True}) + return jsonify({"error": "删除失败"}), 400 + + +# ==================== VIP 管理(管理员) ==================== + + +@admin_api_bp.route("/vip/config", methods=["GET"]) +@admin_required +def get_vip_config_api(): + """获取VIP配置""" + config = database.get_vip_config() + return jsonify(config) + + +@admin_api_bp.route("/vip/config", methods=["POST"]) +@admin_required +def set_vip_config_api(): + """设置默认VIP天数""" + data = request.json or {} + days = data.get("default_vip_days", 0) + + if not isinstance(days, int) or days < 0: + return jsonify({"error": "VIP天数必须是非负整数"}), 400 + + database.set_default_vip_days(days) + return jsonify({"message": "VIP配置已更新", "default_vip_days": days}) + + +@admin_api_bp.route("/users//vip", methods=["POST"]) +@admin_required +def set_user_vip_api(user_id): + """设置用户VIP""" + data = request.json or {} + days = data.get("days", 30) + + valid_days = [7, 30, 365, 999999] + if days not in valid_days: + return jsonify({"error": "VIP天数必须是 7/30/365/999999 之一"}), 400 + + if database.set_user_vip(user_id, days): + vip_type = {7: "一周", 30: "一个月", 365: "一年", 999999: "永久"}[days] + return jsonify({"message": f"VIP设置成功: {vip_type}"}) + return jsonify({"error": "设置失败,用户不存在"}), 400 + + +@admin_api_bp.route("/users//vip", methods=["DELETE"]) +@admin_required +def remove_user_vip_api(user_id): + """移除用户VIP""" + if database.remove_user_vip(user_id): + return jsonify({"message": "VIP已移除"}) + return jsonify({"error": "移除失败"}), 400 + + +@admin_api_bp.route("/users//vip", methods=["GET"]) +@admin_required +def get_user_vip_info_api(user_id): + """获取用户VIP信息(管理员)""" + vip_info = database.get_user_vip_info(user_id) + return jsonify(vip_info) diff --git a/routes/api_accounts.py b/routes/api_accounts.py index 2bb5402..0c733d9 100644 --- a/routes/api_accounts.py +++ b/routes/api_accounts.py @@ -40,6 +40,48 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None: pass +def _emit_account_update(user_id: int, account) -> None: + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + +def _request_json(default=None): + if default is None: + default = {} + data = request.get_json(silent=True) + return data if isinstance(data, dict) else default + + +def _ensure_accounts_loaded(user_id: int) -> dict: + accounts = safe_get_user_accounts_snapshot(user_id) + if accounts: + return accounts + load_user_accounts(user_id) + return safe_get_user_accounts_snapshot(user_id) + + +def _get_user_account(user_id: int, account_id: str, *, refresh_if_missing: bool = False): + account = safe_get_account(user_id, account_id) + if account or (not refresh_if_missing): + return account + load_user_accounts(user_id) + return safe_get_account(user_id, account_id) + + +def _validate_browse_type_input(raw_browse_type, *, default=BROWSE_TYPE_SHOULD_READ): + browse_type = validate_browse_type(raw_browse_type, default=default) + if not browse_type: + return None, (jsonify({"error": "浏览类型无效"}), 400) + return browse_type, None + + +def _cancel_pending_account_task(user_id: int, account_id: str) -> bool: + try: + scheduler = get_task_scheduler() + return bool(scheduler.cancel_pending_task(user_id=user_id, account_id=account_id)) + except Exception: + return False + + @api_accounts_bp.route("/api/accounts", methods=["GET"]) @login_required def get_accounts(): @@ -49,8 +91,7 @@ def get_accounts(): accounts = safe_get_user_accounts_snapshot(user_id) if refresh or not accounts: - load_user_accounts(user_id) - accounts = safe_get_user_accounts_snapshot(user_id) + accounts = _ensure_accounts_loaded(user_id) return jsonify([acc.to_dict() for acc in accounts.values()]) @@ -63,20 +104,18 @@ def add_account(): current_count = len(database.get_user_accounts(user_id)) is_vip = database.is_user_vip(user_id) - if not is_vip and current_count >= 3: + if (not is_vip) and current_count >= 3: return jsonify({"error": "普通用户最多添加3个账号,升级VIP可无限添加"}), 403 - data = request.json - username = data.get("username", "").strip() - password = data.get("password", "").strip() - remark = data.get("remark", "").strip()[:200] + + data = _request_json() + username = str(data.get("username", "")).strip() + password = str(data.get("password", "")).strip() + remark = str(data.get("remark", "")).strip()[:200] if not username or not password: return jsonify({"error": "用户名和密码不能为空"}), 400 - accounts = safe_get_user_accounts_snapshot(user_id) - if not accounts: - load_user_accounts(user_id) - accounts = safe_get_user_accounts_snapshot(user_id) + accounts = _ensure_accounts_loaded(user_id) for acc in accounts.values(): if acc.username == username: return jsonify({"error": f"账号 '{username}' 已存在"}), 400 @@ -92,7 +131,7 @@ def add_account(): safe_set_account(user_id, account_id, account) log_to_client(f"添加账号: {username}", user_id) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") + _emit_account_update(user_id, account) return jsonify(account.to_dict()) @@ -103,15 +142,15 @@ def update_account(account_id): """更新账号信息(密码等)""" user_id = current_user.id - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: return jsonify({"error": "账号不存在"}), 404 if account.is_running: return jsonify({"error": "账号正在运行中,请先停止"}), 400 - data = request.json - new_password = data.get("password", "").strip() + data = _request_json() + new_password = str(data.get("password", "")).strip() new_remember = data.get("remember", account.remember) if not new_password: @@ -147,7 +186,7 @@ def delete_account(account_id): """删除账号""" user_id = current_user.id - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: return jsonify({"error": "账号不存在"}), 404 @@ -159,7 +198,6 @@ def delete_account(account_id): username = account.username database.delete_account(account_id) - safe_remove_account(user_id, account_id) log_to_client(f"删除账号: {username}", user_id) @@ -196,12 +234,12 @@ def update_remark(account_id): """更新备注""" user_id = current_user.id - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: return jsonify({"error": "账号不存在"}), 404 - data = request.json - remark = data.get("remark", "").strip()[:200] + data = _request_json() + remark = str(data.get("remark", "")).strip()[:200] database.update_account_remark(account_id, remark) @@ -217,17 +255,18 @@ def start_account(account_id): """启动账号任务""" user_id = current_user.id - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: return jsonify({"error": "账号不存在"}), 404 if account.is_running: return jsonify({"error": "任务已在运行中"}), 400 - data = request.json or {} - browse_type = validate_browse_type(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ) - if not browse_type: - return jsonify({"error": "浏览类型无效"}), 400 + data = _request_json() + browse_type, browse_error = _validate_browse_type_input(data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ) + if browse_error: + return browse_error + enable_screenshot = data.get("enable_screenshot", True) ok, message = submit_account_task( user_id=user_id, @@ -249,7 +288,7 @@ def stop_account(account_id): """停止账号任务""" user_id = current_user.id - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: return jsonify({"error": "账号不存在"}), 404 @@ -259,20 +298,16 @@ def stop_account(account_id): account.should_stop = True account.status = "正在停止" - try: - scheduler = get_task_scheduler() - if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id): - account.status = "已停止" - account.is_running = False - safe_remove_task_status(account_id) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - log_to_client(f"任务已取消: {account.username}", user_id) - return jsonify({"success": True, "canceled": True}) - except Exception: - pass + if _cancel_pending_account_task(user_id, account_id): + account.status = "已停止" + account.is_running = False + safe_remove_task_status(account_id) + _emit_account_update(user_id, account) + log_to_client(f"任务已取消: {account.username}", user_id) + return jsonify({"success": True, "canceled": True}) log_to_client(f"停止任务: {account.username}", user_id) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") + _emit_account_update(user_id, account) return jsonify({"success": True}) @@ -283,23 +318,20 @@ def manual_screenshot(account_id): """手动为指定账号截图""" user_id = current_user.id - account = safe_get_account(user_id, account_id) - if not account: - load_user_accounts(user_id) - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id, refresh_if_missing=True) if not account: return jsonify({"error": "账号不存在"}), 404 if account.is_running: return jsonify({"error": "任务运行中,无法截图"}), 400 - data = request.json or {} + data = _request_json() requested_browse_type = data.get("browse_type", None) if requested_browse_type is None: browse_type = normalize_browse_type(account.last_browse_type) else: - browse_type = validate_browse_type(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ) - if not browse_type: - return jsonify({"error": "浏览类型无效"}), 400 + browse_type, browse_error = _validate_browse_type_input(requested_browse_type, default=BROWSE_TYPE_SHOULD_READ) + if browse_error: + return browse_error account.last_browse_type = browse_type @@ -317,12 +349,16 @@ def manual_screenshot(account_id): def batch_start_accounts(): """批量启动账号""" user_id = current_user.id - data = request.json or {} + data = _request_json() account_ids = data.get("account_ids", []) - browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ) - if not browse_type: - return jsonify({"error": "浏览类型无效"}), 400 + browse_type, browse_error = _validate_browse_type_input( + data.get("browse_type", BROWSE_TYPE_SHOULD_READ), + default=BROWSE_TYPE_SHOULD_READ, + ) + if browse_error: + return browse_error + enable_screenshot = data.get("enable_screenshot", True) if not account_ids: @@ -331,11 +367,10 @@ def batch_start_accounts(): started = [] failed = [] - if not safe_get_user_accounts_snapshot(user_id): - load_user_accounts(user_id) + _ensure_accounts_loaded(user_id) for account_id in account_ids: - account = safe_get_account(user_id, account_id) + account = _get_user_account(user_id, account_id) if not account: failed.append({"id": account_id, "reason": "账号不存在"}) continue @@ -357,7 +392,13 @@ def batch_start_accounts(): failed.append({"id": account_id, "reason": msg}) return jsonify( - {"success": True, "started_count": len(started), "failed_count": len(failed), "started": started, "failed": failed} + { + "success": True, + "started_count": len(started), + "failed_count": len(failed), + "started": started, + "failed": failed, + } ) @@ -366,39 +407,29 @@ def batch_start_accounts(): def batch_stop_accounts(): """批量停止账号""" user_id = current_user.id - data = request.json + data = _request_json() account_ids = data.get("account_ids", []) - if not account_ids: return jsonify({"error": "请选择要停止的账号"}), 400 stopped = [] - - if not safe_get_user_accounts_snapshot(user_id): - load_user_accounts(user_id) + _ensure_accounts_loaded(user_id) for account_id in account_ids: - account = safe_get_account(user_id, account_id) - if not account: - continue - - if not account.is_running: + account = _get_user_account(user_id, account_id) + if (not account) or (not account.is_running): continue account.should_stop = True account.status = "正在停止" stopped.append(account_id) - try: - scheduler = get_task_scheduler() - if scheduler.cancel_pending_task(user_id=user_id, account_id=account_id): - account.status = "已停止" - account.is_running = False - safe_remove_task_status(account_id) - except Exception: - pass + if _cancel_pending_account_task(user_id, account_id): + account.status = "已停止" + account.is_running = False + safe_remove_task_status(account_id) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") + _emit_account_update(user_id, account) return jsonify({"success": True, "stopped_count": len(stopped), "stopped": stopped}) diff --git a/routes/api_auth.py b/routes/api_auth.py index 3621892..2e65443 100644 --- a/routes/api_auth.py +++ b/routes/api_auth.py @@ -2,16 +2,20 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import base64 import random import secrets +import threading import time +import uuid +from io import BytesIO import database import email_service from app_config import get_config from app_logger import get_logger from app_security import get_rate_limit_ip, require_ip_not_locked, validate_email, validate_password, validate_username -from flask import Blueprint, jsonify, redirect, render_template, request, url_for +from flask import Blueprint, jsonify, request from flask_login import login_required, login_user, logout_user from routes.pages import render_app_spa_or_legacy from services.accounts_service import load_user_accounts @@ -39,12 +43,162 @@ config = get_config() api_auth_bp = Blueprint("api_auth", __name__) +_CAPTCHA_FONT_PATHS = [ + "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf", + "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", + "/usr/share/fonts/truetype/freefont/FreeSansBold.ttf", +] +_CAPTCHA_FONT = None +_CAPTCHA_FONT_LOCK = threading.Lock() + + +def _get_json_payload() -> dict: + data = request.get_json(silent=True) + return data if isinstance(data, dict) else {} + + +def _load_captcha_font(image_font_module): + global _CAPTCHA_FONT + + if _CAPTCHA_FONT is not None: + return _CAPTCHA_FONT + + with _CAPTCHA_FONT_LOCK: + if _CAPTCHA_FONT is not None: + return _CAPTCHA_FONT + + for font_path in _CAPTCHA_FONT_PATHS: + try: + _CAPTCHA_FONT = image_font_module.truetype(font_path, 42) + break + except Exception: + continue + + if _CAPTCHA_FONT is None: + _CAPTCHA_FONT = image_font_module.load_default() + + return _CAPTCHA_FONT + + +def _generate_captcha_image_data_uri(code: str) -> str: + from PIL import Image, ImageDraw, ImageFont + + width, height = 160, 60 + image = Image.new("RGB", (width, height), color=(255, 255, 255)) + draw = ImageDraw.Draw(image) + + for _ in range(6): + x1 = random.randint(0, width) + y1 = random.randint(0, height) + x2 = random.randint(0, width) + y2 = random.randint(0, height) + draw.line( + [(x1, y1), (x2, y2)], + fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)), + width=1, + ) + + for _ in range(80): + x = random.randint(0, width) + y = random.randint(0, height) + draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200))) + + font = _load_captcha_font(ImageFont) + for i, char in enumerate(code): + x = 12 + i * 35 + random.randint(-3, 3) + y = random.randint(5, 12) + color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150)) + draw.text((x, y), char, font=font, fill=color) + + buffer = BytesIO() + image.save(buffer, format="PNG") + img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + return f"data:image/png;base64,{img_base64}" + + +def _with_vip_suffix(message: str, auto_approve_enabled: bool, auto_approve_vip_days: int) -> str: + if auto_approve_enabled and auto_approve_vip_days > 0: + return f"{message},赠送{auto_approve_vip_days}天VIP" + return message + + +def _verify_common_captcha(client_ip: str, captcha_session: str, captcha_code: str): + success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) + if success: + return True, None + + is_locked = record_failed_captcha(client_ip) + if is_locked: + return False, (jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429) + return False, (jsonify({"error": message}), 400) + + +def _verify_login_captcha_if_needed( + *, + captcha_required: bool, + captcha_session: str, + captcha_code: str, + client_ip: str, + username_key: str, +): + if not captcha_required: + return True, None + + if not captcha_session or not captcha_code: + return False, (jsonify({"error": "请填写验证码", "need_captcha": True}), 400) + + success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) + if success: + return True, None + + record_login_failure(client_ip, username_key) + return False, (jsonify({"error": message, "need_captcha": True}), 400) + + +def _send_password_reset_email_if_possible(email: str, username: str, user_id: int) -> None: + result = email_service.send_password_reset_email(email=email, username=username, user_id=user_id) + if not result["success"]: + logger.error(f"密码重置邮件发送失败: {result['error']}") + + +def _send_login_security_alert_if_needed(user: dict, username: str, client_ip: str) -> None: + try: + user_agent = request.headers.get("User-Agent", "") + context = database.record_login_context(user["id"], client_ip, user_agent) + if not context or (not context.get("new_ip") and not context.get("new_device")): + return + + if not config.LOGIN_ALERT_ENABLED: + return + if not should_send_login_alert(user["id"], client_ip): + return + if not email_service.get_email_settings().get("login_alert_enabled", True): + return + + user_info = database.get_user_by_id(user["id"]) or {} + if (not user_info.get("email")) or (not user_info.get("email_verified")): + return + if not database.get_user_email_notify(user["id"]): + return + + email_service.send_security_alert_email( + email=user_info.get("email"), + username=user_info.get("username") or username, + ip_address=client_ip, + user_agent=user_agent, + new_ip=context.get("new_ip", False), + new_device=context.get("new_device", False), + user_id=user["id"], + ) + except Exception: + pass + @api_auth_bp.route("/api/register", methods=["POST"]) @require_ip_not_locked def register(): """用户注册""" - data = request.json or {} + data = _get_json_payload() username = data.get("username", "").strip() password = data.get("password", "").strip() email = data.get("email", "").strip().lower() @@ -67,12 +221,9 @@ def register(): if not allowed: return jsonify({"error": error_msg}), 429 - success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) - if not success: - is_locked = record_failed_captcha(client_ip) - if is_locked: - return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429 - return jsonify({"error": message}), 400 + captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code) + if not captcha_ok: + return captcha_error_response email_settings = email_service.get_email_settings() email_verify_enabled = email_settings.get("register_verify_enabled", False) and email_settings.get("enabled", False) @@ -105,20 +256,22 @@ def register(): if email_verify_enabled and email: result = email_service.send_register_verification_email(email=email, username=username, user_id=user_id) if result["success"]: - message = "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)" - if auto_approve_enabled and auto_approve_vip_days > 0: - message += f",赠送{auto_approve_vip_days}天VIP" + message = _with_vip_suffix( + "注册成功!验证邮件已发送(可直接登录,建议完成邮箱验证)", + auto_approve_enabled, + auto_approve_vip_days, + ) return jsonify({"success": True, "message": message, "need_verify": True}) logger.error(f"注册验证邮件发送失败: {result['error']}") - message = f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录" - if auto_approve_enabled and auto_approve_vip_days > 0: - message += f",赠送{auto_approve_vip_days}天VIP" + message = _with_vip_suffix( + f"注册成功,但验证邮件发送失败({result['error']})。你仍可直接登录", + auto_approve_enabled, + auto_approve_vip_days, + ) return jsonify({"success": True, "message": message, "need_verify": True}) - message = "注册成功!可直接登录" - if auto_approve_enabled and auto_approve_vip_days > 0: - message += f",赠送{auto_approve_vip_days}天VIP" + message = _with_vip_suffix("注册成功!可直接登录", auto_approve_enabled, auto_approve_vip_days) return jsonify({"success": True, "message": message}) return jsonify({"error": "用户名已存在"}), 400 @@ -175,7 +328,7 @@ def verify_email(token): @require_ip_not_locked def resend_verify_email(): """重发验证邮件""" - data = request.json or {} + data = _get_json_payload() email = data.get("email", "").strip().lower() captcha_session = data.get("captcha_session", "") captcha_code = data.get("captcha", "").strip() @@ -195,12 +348,9 @@ def resend_verify_email(): if not allowed: return jsonify({"error": error_msg}), 429 - success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) - if not success: - is_locked = record_failed_captcha(client_ip) - if is_locked: - return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429 - return jsonify({"error": message}), 400 + captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code) + if not captcha_ok: + return captcha_error_response user = database.get_user_by_email(email) if not user: @@ -235,7 +385,7 @@ def get_email_verify_status(): @require_ip_not_locked def forgot_password(): """发送密码重置邮件""" - data = request.json or {} + data = _get_json_payload() email = data.get("email", "").strip().lower() username = data.get("username", "").strip() captcha_session = data.get("captcha_session", "") @@ -263,12 +413,9 @@ def forgot_password(): if not allowed: return jsonify({"error": error_msg}), 429 - success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) - if not success: - is_locked = record_failed_captcha(client_ip) - if is_locked: - return jsonify({"error": "验证码错误次数过多,IP已被锁定1小时"}), 429 - return jsonify({"error": message}), 400 + captcha_ok, captcha_error_response = _verify_common_captcha(client_ip, captcha_session, captcha_code) + if not captcha_ok: + return captcha_error_response email_settings = email_service.get_email_settings() if not email_settings.get("enabled", False): @@ -293,20 +440,16 @@ def forgot_password(): if not allowed: return jsonify({"error": error_msg}), 429 - result = email_service.send_password_reset_email( + _send_password_reset_email_if_possible( email=bound_email, username=user["username"], user_id=user["id"], ) - if not result["success"]: - logger.error(f"密码重置邮件发送失败: {result['error']}") return jsonify({"success": True, "message": "如果该账号已绑定邮箱,您将收到密码重置邮件"}) user = database.get_user_by_email(email) if user and user.get("status") == "approved": - result = email_service.send_password_reset_email(email=email, username=user["username"], user_id=user["id"]) - if not result["success"]: - logger.error(f"密码重置邮件发送失败: {result['error']}") + _send_password_reset_email_if_possible(email=email, username=user["username"], user_id=user["id"]) return jsonify({"success": True, "message": "如果该邮箱已注册,您将收到密码重置邮件"}) @@ -331,7 +474,7 @@ def reset_password_page(token): @api_auth_bp.route("/api/reset-password-confirm", methods=["POST"]) def reset_password_confirm(): """确认密码重置""" - data = request.json or {} + data = _get_json_payload() token = data.get("token", "").strip() new_password = data.get("new_password", "").strip() @@ -356,67 +499,15 @@ def reset_password_confirm(): @api_auth_bp.route("/api/generate_captcha", methods=["POST"]) def generate_captcha(): """生成4位数字验证码图片""" - import base64 - import uuid - from io import BytesIO - session_id = str(uuid.uuid4()) - - code = "".join([str(secrets.randbelow(10)) for _ in range(4)]) + code = "".join(str(secrets.randbelow(10)) for _ in range(4)) safe_set_captcha(session_id, {"code": code, "expire_time": time.time() + 300, "failed_attempts": 0}) safe_cleanup_expired_captcha() try: - from PIL import Image, ImageDraw, ImageFont - import io - - width, height = 160, 60 - image = Image.new("RGB", (width, height), color=(255, 255, 255)) - draw = ImageDraw.Draw(image) - - for _ in range(6): - x1 = random.randint(0, width) - y1 = random.randint(0, height) - x2 = random.randint(0, width) - y2 = random.randint(0, height) - draw.line( - [(x1, y1), (x2, y2)], - fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)), - width=1, - ) - - for _ in range(80): - x = random.randint(0, width) - y = random.randint(0, height) - draw.point((x, y), fill=(random.randint(0, 200), random.randint(0, 200), random.randint(0, 200))) - - font = None - font_paths = [ - "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf", - "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", - "/usr/share/fonts/truetype/freefont/FreeSansBold.ttf", - ] - for font_path in font_paths: - try: - font = ImageFont.truetype(font_path, 42) - break - except Exception: - continue - if font is None: - font = ImageFont.load_default() - - for i, char in enumerate(code): - x = 12 + i * 35 + random.randint(-3, 3) - y = random.randint(5, 12) - color = (random.randint(0, 150), random.randint(0, 150), random.randint(0, 150)) - draw.text((x, y), char, font=font, fill=color) - - buffer = io.BytesIO() - image.save(buffer, format="PNG") - img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - return jsonify({"session_id": session_id, "captcha_image": f"data:image/png;base64,{img_base64}"}) + captcha_image = _generate_captcha_image_data_uri(code) + return jsonify({"session_id": session_id, "captcha_image": captcha_image}) except ImportError as e: logger.error(f"PIL库未安装,验证码功能不可用: {e}") safe_delete_captcha(session_id) @@ -427,7 +518,7 @@ def generate_captcha(): @require_ip_not_locked def login(): """用户登录""" - data = request.json or {} + data = _get_json_payload() username = data.get("username", "").strip() password = data.get("password", "").strip() captcha_session = data.get("captcha_session", "") @@ -452,13 +543,15 @@ def login(): return jsonify({"error": error_msg, "need_captcha": True}), 429 captcha_required = check_login_captcha_required(client_ip, username_key) or scan_locked or bool(need_captcha) - if captcha_required: - if not captcha_session or not captcha_code: - return jsonify({"error": "请填写验证码", "need_captcha": True}), 400 - success, message = safe_verify_and_consume_captcha(captcha_session, captcha_code) - if not success: - record_login_failure(client_ip, username_key) - return jsonify({"error": message, "need_captcha": True}), 400 + captcha_ok, captcha_error_response = _verify_login_captcha_if_needed( + captcha_required=captcha_required, + captcha_session=captcha_session, + captcha_code=captcha_code, + client_ip=client_ip, + username_key=username_key, + ) + if not captcha_ok: + return captcha_error_response user = database.verify_user(username, password) if not user: @@ -476,29 +569,7 @@ def login(): login_user(user_obj) load_user_accounts(user["id"]) - try: - user_agent = request.headers.get("User-Agent", "") - context = database.record_login_context(user["id"], client_ip, user_agent) - if context and (context.get("new_ip") or context.get("new_device")): - if ( - config.LOGIN_ALERT_ENABLED - and should_send_login_alert(user["id"], client_ip) - and email_service.get_email_settings().get("login_alert_enabled", True) - ): - user_info = database.get_user_by_id(user["id"]) or {} - if user_info.get("email") and user_info.get("email_verified"): - if database.get_user_email_notify(user["id"]): - email_service.send_security_alert_email( - email=user_info.get("email"), - username=user_info.get("username") or username, - ip_address=client_ip, - user_agent=user_agent, - new_ip=context.get("new_ip", False), - new_device=context.get("new_device", False), - user_id=user["id"], - ) - except Exception: - pass + _send_login_security_alert_if_needed(user=user, username=username, client_ip=client_ip) return jsonify({"success": True}) diff --git a/routes/api_schedules.py b/routes/api_schedules.py index 1d81026..14844e5 100644 --- a/routes/api_schedules.py +++ b/routes/api_schedules.py @@ -2,7 +2,11 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import json import re +import threading +import time as time_mod +import uuid import database from flask import Blueprint, jsonify, request @@ -17,6 +21,13 @@ api_schedules_bp = Blueprint("api_schedules", __name__) _HHMM_RE = re.compile(r"^(\d{1,2}):(\d{2})$") +def _request_json(default=None): + if default is None: + default = {} + data = request.get_json(silent=True) + return data if isinstance(data, dict) else default + + def _normalize_hhmm(value: object) -> str | None: match = _HHMM_RE.match(str(value or "").strip()) if not match: @@ -28,18 +39,53 @@ def _normalize_hhmm(value: object) -> str | None: return f"{hour:02d}:{minute:02d}" +def _normalize_random_delay(value) -> tuple[int | None, str | None]: + try: + normalized = int(value or 0) + except Exception: + return None, "random_delay必须是0或1" + if normalized not in (0, 1): + return None, "random_delay必须是0或1" + return normalized, None + + +def _parse_schedule_account_ids(raw_value) -> list: + try: + parsed = json.loads(raw_value or "[]") + except (json.JSONDecodeError, TypeError): + return [] + return parsed if isinstance(parsed, list) else [] + + +def _get_owned_schedule_or_error(schedule_id: int): + schedule = database.get_schedule_by_id(schedule_id) + if not schedule: + return None, (jsonify({"error": "定时任务不存在"}), 404) + if schedule.get("user_id") != current_user.id: + return None, (jsonify({"error": "无权访问"}), 403) + return schedule, None + + +def _ensure_user_accounts_loaded(user_id: int) -> None: + if safe_get_user_accounts_snapshot(user_id): + return + load_user_accounts(user_id) + + +def _parse_browse_type_or_error(raw_value, *, default=BROWSE_TYPE_SHOULD_READ): + browse_type = validate_browse_type(raw_value, default=default) + if not browse_type: + return None, (jsonify({"error": "浏览类型无效"}), 400) + return browse_type, None + + @api_schedules_bp.route("/api/schedules", methods=["GET"]) @login_required def get_user_schedules_api(): """获取当前用户的所有定时任务""" schedules = database.get_user_schedules(current_user.id) - import json - - for s in schedules: - try: - s["account_ids"] = json.loads(s.get("account_ids", "[]") or "[]") - except (json.JSONDecodeError, TypeError): - s["account_ids"] = [] + for schedule in schedules: + schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids")) return jsonify(schedules) @@ -47,23 +93,26 @@ def get_user_schedules_api(): @login_required def create_user_schedule_api(): """创建用户定时任务""" - data = request.json or {} + data = _request_json() name = data.get("name", "我的定时任务") schedule_time = data.get("schedule_time", "08:00") weekdays = data.get("weekdays", "1,2,3,4,5") - browse_type = validate_browse_type(data.get("browse_type", BROWSE_TYPE_SHOULD_READ), default=BROWSE_TYPE_SHOULD_READ) - if not browse_type: - return jsonify({"error": "浏览类型无效"}), 400 + + browse_type, browse_error = _parse_browse_type_or_error(data.get("browse_type", BROWSE_TYPE_SHOULD_READ)) + if browse_error: + return browse_error + enable_screenshot = data.get("enable_screenshot", 1) - random_delay = int(data.get("random_delay", 0) or 0) + random_delay, delay_error = _normalize_random_delay(data.get("random_delay", 0)) + if delay_error: + return jsonify({"error": delay_error}), 400 + account_ids = data.get("account_ids", []) normalized_time = _normalize_hhmm(schedule_time) if not normalized_time: return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400 - if random_delay not in (0, 1): - return jsonify({"error": "random_delay必须是0或1"}), 400 schedule_id = database.create_user_schedule( user_id=current_user.id, @@ -85,18 +134,11 @@ def create_user_schedule_api(): @login_required def get_schedule_detail_api(schedule_id): """获取定时任务详情""" - schedule = database.get_schedule_by_id(schedule_id) - if not schedule: - return jsonify({"error": "定时任务不存在"}), 404 - if schedule["user_id"] != current_user.id: - return jsonify({"error": "无权访问"}), 403 + schedule, error_response = _get_owned_schedule_or_error(schedule_id) + if error_response: + return error_response - import json - - try: - schedule["account_ids"] = json.loads(schedule.get("account_ids", "[]") or "[]") - except (json.JSONDecodeError, TypeError): - schedule["account_ids"] = [] + schedule["account_ids"] = _parse_schedule_account_ids(schedule.get("account_ids")) return jsonify(schedule) @@ -104,14 +146,12 @@ def get_schedule_detail_api(schedule_id): @login_required def update_schedule_api(schedule_id): """更新定时任务""" - schedule = database.get_schedule_by_id(schedule_id) - if not schedule: - return jsonify({"error": "定时任务不存在"}), 404 - if schedule["user_id"] != current_user.id: - return jsonify({"error": "无权访问"}), 403 + _, error_response = _get_owned_schedule_or_error(schedule_id) + if error_response: + return error_response - data = request.json or {} - allowed_fields = [ + data = _request_json() + allowed_fields = { "name", "schedule_time", "weekdays", @@ -120,27 +160,26 @@ def update_schedule_api(schedule_id): "random_delay", "account_ids", "enabled", - ] - - update_data = {k: v for k, v in data.items() if k in allowed_fields} + } + update_data = {key: value for key, value in data.items() if key in allowed_fields} if "schedule_time" in update_data: normalized_time = _normalize_hhmm(update_data["schedule_time"]) if not normalized_time: return jsonify({"error": "时间格式不正确,应为 HH:MM"}), 400 update_data["schedule_time"] = normalized_time + if "random_delay" in update_data: - try: - update_data["random_delay"] = int(update_data.get("random_delay") or 0) - except Exception: - return jsonify({"error": "random_delay必须是0或1"}), 400 - if update_data["random_delay"] not in (0, 1): - return jsonify({"error": "random_delay必须是0或1"}), 400 + random_delay, delay_error = _normalize_random_delay(update_data.get("random_delay")) + if delay_error: + return jsonify({"error": delay_error}), 400 + update_data["random_delay"] = random_delay + if "browse_type" in update_data: - normalized = validate_browse_type(update_data.get("browse_type"), default=BROWSE_TYPE_SHOULD_READ) - if not normalized: - return jsonify({"error": "浏览类型无效"}), 400 - update_data["browse_type"] = normalized + normalized_browse_type, browse_error = _parse_browse_type_or_error(update_data.get("browse_type")) + if browse_error: + return browse_error + update_data["browse_type"] = normalized_browse_type success = database.update_user_schedule(schedule_id, **update_data) if success: @@ -152,11 +191,9 @@ def update_schedule_api(schedule_id): @login_required def delete_schedule_api(schedule_id): """删除定时任务""" - schedule = database.get_schedule_by_id(schedule_id) - if not schedule: - return jsonify({"error": "定时任务不存在"}), 404 - if schedule["user_id"] != current_user.id: - return jsonify({"error": "无权访问"}), 403 + _, error_response = _get_owned_schedule_or_error(schedule_id) + if error_response: + return error_response success = database.delete_user_schedule(schedule_id) if success: @@ -168,13 +205,11 @@ def delete_schedule_api(schedule_id): @login_required def toggle_schedule_api(schedule_id): """启用/禁用定时任务""" - schedule = database.get_schedule_by_id(schedule_id) - if not schedule: - return jsonify({"error": "定时任务不存在"}), 404 - if schedule["user_id"] != current_user.id: - return jsonify({"error": "无权访问"}), 403 + schedule, error_response = _get_owned_schedule_or_error(schedule_id) + if error_response: + return error_response - data = request.json + data = _request_json() enabled = data.get("enabled", not schedule["enabled"]) success = database.toggle_user_schedule(schedule_id, enabled) @@ -187,22 +222,11 @@ def toggle_schedule_api(schedule_id): @login_required def run_schedule_now_api(schedule_id): """立即执行定时任务""" - import json - import threading - import time as time_mod - import uuid - - schedule = database.get_schedule_by_id(schedule_id) - if not schedule: - return jsonify({"error": "定时任务不存在"}), 404 - if schedule["user_id"] != current_user.id: - return jsonify({"error": "无权访问"}), 403 - - try: - account_ids = json.loads(schedule.get("account_ids", "[]") or "[]") - except (json.JSONDecodeError, TypeError): - account_ids = [] + schedule, error_response = _get_owned_schedule_or_error(schedule_id) + if error_response: + return error_response + account_ids = _parse_schedule_account_ids(schedule.get("account_ids")) if not account_ids: return jsonify({"error": "没有配置账号"}), 400 @@ -210,8 +234,7 @@ def run_schedule_now_api(schedule_id): browse_type = normalize_browse_type(schedule.get("browse_type", BROWSE_TYPE_SHOULD_READ)) enable_screenshot = schedule["enable_screenshot"] - if not safe_get_user_accounts_snapshot(user_id): - load_user_accounts(user_id) + _ensure_user_accounts_loaded(user_id) from services.state import safe_create_batch, safe_finalize_batch_after_dispatch from services.task_batches import _send_batch_task_email_if_configured @@ -250,6 +273,7 @@ def run_schedule_now_api(schedule_id): if remaining["done"] or remaining["count"] > 0: return remaining["done"] = True + execution_duration = int(time_mod.time() - execution_start_time) database.update_schedule_execution_log( log_id, @@ -260,19 +284,17 @@ def run_schedule_now_api(schedule_id): status="completed", ) + task_source = f"user_scheduled:{batch_id}" for account_id in account_ids: account = safe_get_account(user_id, account_id) - if not account: - skipped_count += 1 - continue - if account.is_running: + if (not account) or account.is_running: skipped_count += 1 continue - task_source = f"user_scheduled:{batch_id}" with completion_lock: remaining["count"] += 1 - ok, msg = submit_account_task( + + ok, _ = submit_account_task( user_id=user_id, account_id=account_id, browse_type=browse_type, diff --git a/routes/api_screenshots.py b/routes/api_screenshots.py index 0df1ad3..008e67b 100644 --- a/routes/api_screenshots.py +++ b/routes/api_screenshots.py @@ -4,6 +4,7 @@ from __future__ import annotations import os from datetime import datetime +from typing import Iterator import database from app_config import get_config @@ -15,41 +16,67 @@ from services.time_utils import BEIJING_TZ config = get_config() SCREENSHOTS_DIR = config.SCREENSHOTS_DIR +_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg") api_screenshots_bp = Blueprint("api_screenshots", __name__) +def _get_user_prefix(user_id: int) -> str: + user_info = database.get_user_by_id(user_id) + return user_info["username"] if user_info else f"user{user_id}" + + +def _is_user_screenshot(filename: str, username_prefix: str) -> bool: + return filename.startswith(username_prefix + "_") and filename.lower().endswith(_IMAGE_EXTENSIONS) + + +def _iter_user_screenshot_entries(username_prefix: str) -> Iterator[os.DirEntry]: + if not os.path.exists(SCREENSHOTS_DIR): + return + + with os.scandir(SCREENSHOTS_DIR) as entries: + for entry in entries: + if (not entry.is_file()) or (not _is_user_screenshot(entry.name, username_prefix)): + continue + yield entry + + +def _build_display_name(filename: str) -> str: + base_name, ext = filename.rsplit(".", 1) + parts = base_name.split("_", 1) + if len(parts) > 1: + return f"{parts[1]}.{ext}" + return filename + + @api_screenshots_bp.route("/api/screenshots", methods=["GET"]) @login_required def get_screenshots(): """获取当前用户的截图列表""" user_id = current_user.id - user_info = database.get_user_by_id(user_id) - username_prefix = user_info["username"] if user_info else f"user{user_id}" + username_prefix = _get_user_prefix(user_id) try: screenshots = [] - if os.path.exists(SCREENSHOTS_DIR): - for filename in os.listdir(SCREENSHOTS_DIR): - if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"): - filepath = os.path.join(SCREENSHOTS_DIR, filename) - stat = os.stat(filepath) - created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ) - parts = filename.rsplit(".", 1)[0].split("_", 1) - if len(parts) > 1: - display_name = parts[1] + "." + filename.rsplit(".", 1)[1] - else: - display_name = filename + for entry in _iter_user_screenshot_entries(username_prefix): + filename = entry.name + stat = entry.stat() + created_time = datetime.fromtimestamp(stat.st_mtime, tz=BEIJING_TZ) + + screenshots.append( + { + "filename": filename, + "display_name": _build_display_name(filename), + "size": stat.st_size, + "created": created_time.strftime("%Y-%m-%d %H:%M:%S"), + "_created_ts": stat.st_mtime, + } + ) + + screenshots.sort(key=lambda item: item.get("_created_ts", 0), reverse=True) + for item in screenshots: + item.pop("_created_ts", None) - screenshots.append( - { - "filename": filename, - "display_name": display_name, - "size": stat.st_size, - "created": created_time.strftime("%Y-%m-%d %H:%M:%S"), - } - ) - screenshots.sort(key=lambda x: x["created"], reverse=True) return jsonify(screenshots) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -60,10 +87,9 @@ def get_screenshots(): def serve_screenshot(filename): """提供截图文件访问""" user_id = current_user.id - user_info = database.get_user_by_id(user_id) - username_prefix = user_info["username"] if user_info else f"user{user_id}" + username_prefix = _get_user_prefix(user_id) - if not filename.startswith(username_prefix + "_"): + if not _is_user_screenshot(filename, username_prefix): return jsonify({"error": "无权访问"}), 403 if not is_safe_path(SCREENSHOTS_DIR, filename): @@ -77,12 +103,14 @@ def serve_screenshot(filename): def delete_screenshot(filename): """删除指定截图""" user_id = current_user.id - user_info = database.get_user_by_id(user_id) - username_prefix = user_info["username"] if user_info else f"user{user_id}" + username_prefix = _get_user_prefix(user_id) - if not filename.startswith(username_prefix + "_"): + if not _is_user_screenshot(filename, username_prefix): return jsonify({"error": "无权删除"}), 403 + if not is_safe_path(SCREENSHOTS_DIR, filename): + return jsonify({"error": "非法路径"}), 403 + try: filepath = os.path.join(SCREENSHOTS_DIR, filename) if os.path.exists(filepath): @@ -99,19 +127,15 @@ def delete_screenshot(filename): def clear_all_screenshots(): """清空当前用户的所有截图""" user_id = current_user.id - user_info = database.get_user_by_id(user_id) - username_prefix = user_info["username"] if user_info else f"user{user_id}" + username_prefix = _get_user_prefix(user_id) try: deleted_count = 0 - if os.path.exists(SCREENSHOTS_DIR): - for filename in os.listdir(SCREENSHOTS_DIR): - if filename.lower().endswith((".png", ".jpg", ".jpeg")) and filename.startswith(username_prefix + "_"): - filepath = os.path.join(SCREENSHOTS_DIR, filename) - os.remove(filepath) - deleted_count += 1 + for entry in _iter_user_screenshot_entries(username_prefix): + os.remove(entry.path) + deleted_count += 1 + log_to_client(f"清理了 {deleted_count} 个截图文件", user_id) return jsonify({"success": True, "deleted": deleted_count}) except Exception as e: return jsonify({"error": str(e)}), 500 - diff --git a/routes/api_user.py b/routes/api_user.py index cb40502..f5bcb2e 100644 --- a/routes/api_user.py +++ b/routes/api_user.py @@ -10,12 +10,96 @@ from flask import Blueprint, jsonify, request from flask_login import current_user, login_required from routes.pages import render_app_spa_or_legacy from services.state import check_email_rate_limit, check_ip_request_rate, safe_iter_task_status_items +from services.tasks import get_task_scheduler logger = get_logger("app") api_user_bp = Blueprint("api_user", __name__) +def _get_current_user_record(): + return database.get_user_by_id(current_user.id) + + +def _get_current_user_or_404(): + user = _get_current_user_record() + if user: + return user, None + return None, (jsonify({"error": "用户不存在"}), 404) + + +def _get_current_username(*, fallback: str) -> str: + user = _get_current_user_record() + username = (user or {}).get("username", "") + return username if username else fallback + + +def _coerce_binary_flag(value, *, field_label: str): + if isinstance(value, bool): + value = 1 if value else 0 + try: + value = int(value) + except Exception: + return None, f"{field_label}必须是0或1" + if value not in (0, 1): + return None, f"{field_label}必须是0或1" + return value, None + + +def _check_bind_email_rate_limits(email: str): + client_ip = get_rate_limit_ip() + allowed, error_msg = check_ip_request_rate(client_ip, "email") + if not allowed: + return False, error_msg, 429 + allowed, error_msg = check_email_rate_limit(email, "bind_email") + if not allowed: + return False, error_msg, 429 + return True, "", 200 + + +def _render_verify_bind_failed(*, title: str, error_message: str): + spa_initial_state = { + "page": "verify_result", + "success": False, + "title": title, + "error_message": error_message, + "primary_label": "返回登录", + "primary_url": "/login", + } + return render_app_spa_or_legacy( + "verify_failed.html", + legacy_context={"error_message": error_message}, + spa_initial_state=spa_initial_state, + ) + + +def _render_verify_bind_success(email: str): + spa_initial_state = { + "page": "verify_result", + "success": True, + "title": "邮箱绑定成功", + "message": f"邮箱 {email} 已成功绑定到您的账号!", + "primary_label": "返回登录", + "primary_url": "/login", + "redirect_url": "/login", + "redirect_seconds": 5, + } + return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state) + + +def _get_current_running_count(user_id: int) -> int: + try: + queue_snapshot = get_task_scheduler().get_queue_state_snapshot() or {} + running_by_user = queue_snapshot.get("running_by_user") or {} + return int(running_by_user.get(int(user_id), running_by_user.get(str(user_id), 0)) or 0) + except Exception: + current_running = 0 + for _, info in safe_iter_task_status_items(): + if info.get("user_id") == user_id and info.get("status") == "运行中": + current_running += 1 + return current_running + + @api_user_bp.route("/api/announcements/active", methods=["GET"]) @login_required def get_active_announcement(): @@ -77,8 +161,7 @@ def submit_feedback(): if len(description) > 2000: return jsonify({"error": "描述不能超过2000个字符"}), 400 - user_info = database.get_user_by_id(current_user.id) - username = user_info["username"] if user_info else f"用户{current_user.id}" + username = _get_current_username(fallback=f"用户{current_user.id}") feedback_id = database.create_bug_feedback( user_id=current_user.id, @@ -104,8 +187,7 @@ def get_my_feedbacks(): def get_current_user_vip(): """获取当前用户VIP信息""" vip_info = database.get_user_vip_info(current_user.id) - user_info = database.get_user_by_id(current_user.id) - vip_info["username"] = user_info["username"] if user_info else "Unknown" + vip_info["username"] = _get_current_username(fallback="Unknown") return jsonify(vip_info) @@ -124,9 +206,9 @@ def change_user_password(): if not is_valid: return jsonify({"error": error_msg}), 400 - user = database.get_user_by_id(current_user.id) - if not user: - return jsonify({"error": "用户不存在"}), 404 + user, error_response = _get_current_user_or_404() + if error_response: + return error_response username = user.get("username", "") if not username or not database.verify_user(username, current_password): @@ -141,9 +223,9 @@ def change_user_password(): @login_required def get_user_email(): """获取当前用户的邮箱信息""" - user = database.get_user_by_id(current_user.id) - if not user: - return jsonify({"error": "用户不存在"}), 404 + user, error_response = _get_current_user_or_404() + if error_response: + return error_response return jsonify({"email": user.get("email", ""), "email_verified": user.get("email_verified", False)}) @@ -172,14 +254,9 @@ def update_user_kdocs_settings(): return jsonify({"error": "县区长度不能超过50"}), 400 if kdocs_auto_upload is not None: - if isinstance(kdocs_auto_upload, bool): - kdocs_auto_upload = 1 if kdocs_auto_upload else 0 - try: - kdocs_auto_upload = int(kdocs_auto_upload) - except Exception: - return jsonify({"error": "自动上传开关必须是0或1"}), 400 - if kdocs_auto_upload not in (0, 1): - return jsonify({"error": "自动上传开关必须是0或1"}), 400 + kdocs_auto_upload, parse_error = _coerce_binary_flag(kdocs_auto_upload, field_label="自动上传开关") + if parse_error: + return jsonify({"error": parse_error}), 400 if not database.update_user_kdocs_settings( current_user.id, @@ -207,13 +284,9 @@ def bind_user_email(): if not is_valid: return jsonify({"error": error_msg}), 400 - client_ip = get_rate_limit_ip() - allowed, error_msg = check_ip_request_rate(client_ip, "email") + allowed, error_msg, status_code = _check_bind_email_rate_limits(email) if not allowed: - return jsonify({"error": error_msg}), 429 - allowed, error_msg = check_email_rate_limit(email, "bind_email") - if not allowed: - return jsonify({"error": error_msg}), 429 + return jsonify({"error": error_msg}), status_code settings = email_service.get_email_settings() if not settings.get("enabled", False): @@ -223,9 +296,9 @@ def bind_user_email(): if existing_user and existing_user["id"] != current_user.id: return jsonify({"error": "该邮箱已被其他用户绑定"}), 400 - user = database.get_user_by_id(current_user.id) - if not user: - return jsonify({"error": "用户不存在"}), 404 + user, error_response = _get_current_user_or_404() + if error_response: + return error_response if user.get("email") == email and user.get("email_verified"): return jsonify({"error": "该邮箱已绑定并验证"}), 400 @@ -247,56 +320,20 @@ def verify_bind_email(token): email = result["email"] if database.update_user_email(user_id, email, verified=True): - spa_initial_state = { - "page": "verify_result", - "success": True, - "title": "邮箱绑定成功", - "message": f"邮箱 {email} 已成功绑定到您的账号!", - "primary_label": "返回登录", - "primary_url": "/login", - "redirect_url": "/login", - "redirect_seconds": 5, - } - return render_app_spa_or_legacy("verify_success.html", spa_initial_state=spa_initial_state) + return _render_verify_bind_success(email) - error_message = "邮箱绑定失败,请重试" - spa_initial_state = { - "page": "verify_result", - "success": False, - "title": "绑定失败", - "error_message": error_message, - "primary_label": "返回登录", - "primary_url": "/login", - } - return render_app_spa_or_legacy( - "verify_failed.html", - legacy_context={"error_message": error_message}, - spa_initial_state=spa_initial_state, - ) + return _render_verify_bind_failed(title="绑定失败", error_message="邮箱绑定失败,请重试") - error_message = "验证链接已过期或无效,请重新发送验证邮件" - spa_initial_state = { - "page": "verify_result", - "success": False, - "title": "链接无效", - "error_message": error_message, - "primary_label": "返回登录", - "primary_url": "/login", - } - return render_app_spa_or_legacy( - "verify_failed.html", - legacy_context={"error_message": error_message}, - spa_initial_state=spa_initial_state, - ) + return _render_verify_bind_failed(title="链接无效", error_message="验证链接已过期或无效,请重新发送验证邮件") @api_user_bp.route("/api/user/unbind-email", methods=["POST"]) @login_required def unbind_user_email(): """解绑用户邮箱""" - user = database.get_user_by_id(current_user.id) - if not user: - return jsonify({"error": "用户不存在"}), 404 + user, error_response = _get_current_user_or_404() + if error_response: + return error_response if not user.get("email"): return jsonify({"error": "当前未绑定邮箱"}), 400 @@ -334,10 +371,7 @@ def get_run_stats(): stats = database.get_user_run_stats(user_id) - current_running = 0 - for _, info in safe_iter_task_status_items(): - if info.get("user_id") == user_id and info.get("status") == "运行中": - current_running += 1 + current_running = _get_current_running_count(user_id) return jsonify( { diff --git a/routes/decorators.py b/routes/decorators.py index f99f9c3..5267db4 100644 --- a/routes/decorators.py +++ b/routes/decorators.py @@ -31,7 +31,7 @@ def admin_required(f): if is_api: return jsonify({"error": "需要管理员权限"}), 403 return redirect(url_for("pages.admin_login_page")) - logger.info(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}") + logger.debug(f"[admin_required] 管理员 {session.get('admin_username')} 访问 {request.path}") return f(*args, **kwargs) return decorated_function diff --git a/routes/health.py b/routes/health.py index 57abdc4..c1dc202 100644 --- a/routes/health.py +++ b/routes/health.py @@ -2,12 +2,62 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import os +import time + from flask import Blueprint, jsonify import database +import db_pool from services.time_utils import get_beijing_now health_bp = Blueprint("health", __name__) +_PROCESS_START_TS = time.time() + + +def _build_runtime_metrics() -> dict: + metrics = { + "uptime_seconds": max(0, int(time.time() - _PROCESS_START_TS)), + } + + try: + pool_stats = db_pool.get_pool_stats() or {} + metrics["db_pool"] = { + "pool_size": int(pool_stats.get("pool_size", 0) or 0), + "available": int(pool_stats.get("available", 0) or 0), + "in_use": int(pool_stats.get("in_use", 0) or 0), + } + except Exception: + pass + + try: + import psutil + + proc = psutil.Process(os.getpid()) + with proc.oneshot(): + mem_info = proc.memory_info() + metrics["process"] = { + "rss_mb": round(float(mem_info.rss) / 1024 / 1024, 2), + "cpu_percent": round(float(proc.cpu_percent(interval=None)), 2), + "threads": int(proc.num_threads()), + } + except Exception: + pass + + try: + from services import tasks as tasks_module + + scheduler = getattr(tasks_module, "_task_scheduler", None) + if scheduler is not None: + queue_snapshot = scheduler.get_queue_state_snapshot() or {} + metrics["task_queue"] = { + "pending_total": int(queue_snapshot.get("pending_total", 0) or 0), + "running_total": int(queue_snapshot.get("running_total", 0) or 0), + } + except Exception: + pass + + return metrics @health_bp.route("/health", methods=["GET"]) @@ -26,6 +76,6 @@ def health_check(): "time": get_beijing_now().strftime("%Y-%m-%d %H:%M:%S"), "db_ok": db_ok, "db_error": db_error, + "metrics": _build_runtime_metrics(), } return jsonify(payload), (200 if db_ok else 500) - diff --git a/scripts/HEALTH_MONITOR_README.md b/scripts/HEALTH_MONITOR_README.md new file mode 100644 index 0000000..e40fc26 --- /dev/null +++ b/scripts/HEALTH_MONITOR_README.md @@ -0,0 +1,60 @@ +# 健康监控(邮件版) + +本目录提供 `health_email_monitor.py`,通过调用 `/health` 接口并使用**容器内已有邮件配置**发告警邮件。 + +## 1) 快速试跑 + +```bash +cd /root/zsglpt +python3 scripts/health_email_monitor.py \ + --to 你的告警邮箱@example.com \ + --container knowledge-automation-multiuser \ + --url http://127.0.0.1:51232/health \ + --dry-run +``` + +去掉 `--dry-run` 即会实际发邮件。 + +## 2) 建议 cron(每分钟) + +```bash +* * * * * cd /root/zsglpt && /usr/bin/python3 scripts/health_email_monitor.py \ + --to 你的告警邮箱@example.com \ + --container knowledge-automation-multiuser \ + --url http://127.0.0.1:51232/health \ + >> /root/zsglpt/logs/health_monitor.log 2>&1 +``` + +## 3) 支持的规则 + +- `service_down`:健康接口请求失败(立即告警) +- `health_fail`:返回 `ok/db_ok` 异常或 HTTP 5xx(立即告警) +- `db_pool_exhausted`:连接池耗尽(默认连续 3 次才告警) +- `queue_backlog_high`:任务堆积过高(默认 `pending_total >= 50` 且连续 5 次) + +脚本支持恢复通知(规则恢复正常会发“恢复”邮件)。 + +## 4) 常用参数 + +- `--to`:收件人(必填) +- `--container`:Docker 容器名(默认 `knowledge-automation-multiuser`) +- `--url`:健康地址(默认 `http://127.0.0.1:51232/health`) +- `--state-file`:状态文件路径(默认 `/tmp/zsglpt_health_monitor_state.json`) +- `--remind-seconds`:重复告警间隔(默认 3600 秒) +- `--queue-threshold`:队列告警阈值(默认 50) +- `--queue-streak`:队列连续次数阈值(默认 5) +- `--db-pool-streak`:连接池连续次数阈值(默认 3) + +## 5) 环境变量方式(可选) + +也可不用命令行参数,改用环境变量: + +- `MONITOR_EMAIL_TO` +- `MONITOR_DOCKER_CONTAINER` +- `HEALTH_URL` +- `MONITOR_STATE_FILE` +- `MONITOR_REMIND_SECONDS` +- `MONITOR_QUEUE_THRESHOLD` +- `MONITOR_QUEUE_STREAK` +- `MONITOR_DB_POOL_STREAK` + diff --git a/scripts/health_email_monitor.py b/scripts/health_email_monitor.py new file mode 100644 index 0000000..0abfb37 --- /dev/null +++ b/scripts/health_email_monitor.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import time +from datetime import datetime +from typing import Any, Dict, Tuple +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +DEFAULT_STATE_FILE = "/tmp/zsglpt_health_monitor_state.json" + + +def _now_text() -> str: + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except Exception: + return float(default) + + +def _load_state(path: str) -> Dict[str, Any]: + if not path or not os.path.exists(path): + return { + "version": 1, + "rules": {}, + "counters": {}, + } + try: + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + if not isinstance(raw, dict): + raise ValueError("state is not dict") + raw.setdefault("version", 1) + raw.setdefault("rules", {}) + raw.setdefault("counters", {}) + return raw + except Exception: + return { + "version": 1, + "rules": {}, + "counters": {}, + } + + +def _save_state(path: str, state: Dict[str, Any]) -> None: + if not path: + return + state_dir = os.path.dirname(path) + if state_dir: + os.makedirs(state_dir, exist_ok=True) + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(state, f, ensure_ascii=False, indent=2) + os.replace(tmp_path, path) + + +def _fetch_health(url: str, timeout: int) -> Tuple[int | None, Dict[str, Any], str | None]: + req = Request( + url, + headers={ + "User-Agent": "zsglpt-health-email-monitor/1.0", + "Accept": "application/json", + }, + method="GET", + ) + try: + with urlopen(req, timeout=max(1, int(timeout))) as resp: + status = int(resp.getcode()) + body = resp.read().decode("utf-8", errors="ignore") + except HTTPError as e: + status = int(getattr(e, "code", 0) or 0) + body = "" + try: + body = e.read().decode("utf-8", errors="ignore") + except Exception: + pass + data = {} + if body: + try: + data = json.loads(body) + if not isinstance(data, dict): + data = {} + except Exception: + data = {} + return status, data, f"HTTPError: {e}" + except URLError as e: + return None, {}, f"URLError: {e}" + except Exception as e: + return None, {}, f"RequestError: {e}" + + data: Dict[str, Any] = {} + if body: + try: + loaded = json.loads(body) + if isinstance(loaded, dict): + data = loaded + except Exception: + data = {} + + return status, data, None + + +def _inc_streak(state: Dict[str, Any], key: str, bad: bool) -> int: + counters = state.setdefault("counters", {}) + current = _safe_int(counters.get(key), 0) + current = (current + 1) if bad else 0 + counters[key] = current + return current + + +def _rule_transition( + state: Dict[str, Any], + *, + rule_name: str, + bad: bool, + streak: int, + threshold: int, + remind_seconds: int, + now_ts: float, +) -> str | None: + rules = state.setdefault("rules", {}) + rule_state = rules.setdefault(rule_name, {"active": False, "last_sent": 0}) + + is_active = bool(rule_state.get("active", False)) + last_sent = _safe_float(rule_state.get("last_sent", 0), 0.0) + threshold = max(1, int(threshold)) + remind_seconds = max(60, int(remind_seconds)) + + if bad and streak >= threshold: + if not is_active: + rule_state["active"] = True + rule_state["last_sent"] = now_ts + return "alert" + if (now_ts - last_sent) >= remind_seconds: + rule_state["last_sent"] = now_ts + return "alert" + return None + + if is_active and (not bad): + rule_state["active"] = False + rule_state["last_sent"] = now_ts + return "recover" + + return None + + +def _send_email_via_container( + *, + container_name: str, + to_email: str, + subject: str, + body: str, + timeout_seconds: int = 45, +) -> Tuple[bool, str]: + code = ( + "import sys,email_service;" + "res=email_service.send_email(to_email=sys.argv[1],subject=sys.argv[2],body=sys.argv[3],email_type='health_monitor');" + "ok=bool(res.get('success'));" + "print('ok' if ok else ('error:'+str(res.get('error'))));" + "raise SystemExit(0 if ok else 2)" + ) + cmd = [ + "docker", + "exec", + container_name, + "python", + "-c", + code, + to_email, + subject, + body, + ] + try: + proc = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=max(5, int(timeout_seconds)), + check=False, + ) + except Exception as e: + return False, str(e) + + output = (proc.stdout or "") + (proc.stderr or "") + output = output.strip() + return proc.returncode == 0, output + + +def _build_common_lines(status: int | None, data: Dict[str, Any], fetch_error: str | None) -> list[str]: + metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {} + db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {} + process = metrics.get("process") if isinstance(metrics.get("process"), dict) else {} + task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {} + + lines = [ + f"时间: {_now_text()}", + f"健康地址: {data.get('_monitor_url', '')}", + f"HTTP状态: {status if status is not None else '请求失败'}", + f"ok/db_ok: {data.get('ok')} / {data.get('db_ok')}", + ] + if fetch_error: + lines.append(f"请求错误: {fetch_error}") + lines.extend( + [ + f"队列: pending={task_queue.get('pending_total', 'N/A')}, running={task_queue.get('running_total', 'N/A')}", + f"连接池: size={db_pool.get('pool_size', 'N/A')}, available={db_pool.get('available', 'N/A')}, in_use={db_pool.get('in_use', 'N/A')}", + f"进程: rss_mb={process.get('rss_mb', 'N/A')}, cpu%={process.get('cpu_percent', 'N/A')}, threads={process.get('threads', 'N/A')}", + f"运行时长: {metrics.get('uptime_seconds', 'N/A')} 秒", + ] + ) + return lines + + +def main() -> int: + parser = argparse.ArgumentParser(description="zsglpt 邮件健康监控(基于 /health)") + parser.add_argument("--url", default=os.environ.get("HEALTH_URL", "http://127.0.0.1:51232/health")) + parser.add_argument("--to", default=os.environ.get("MONITOR_EMAIL_TO", "")) + parser.add_argument( + "--container", + default=os.environ.get("MONITOR_DOCKER_CONTAINER", "knowledge-automation-multiuser"), + ) + parser.add_argument("--state-file", default=os.environ.get("MONITOR_STATE_FILE", DEFAULT_STATE_FILE)) + parser.add_argument("--timeout", type=int, default=_safe_int(os.environ.get("MONITOR_TIMEOUT", 8), 8)) + parser.add_argument( + "--remind-seconds", + type=int, + default=_safe_int(os.environ.get("MONITOR_REMIND_SECONDS", 3600), 3600), + ) + parser.add_argument( + "--queue-threshold", + type=int, + default=_safe_int(os.environ.get("MONITOR_QUEUE_THRESHOLD", 50), 50), + ) + parser.add_argument( + "--queue-streak", + type=int, + default=_safe_int(os.environ.get("MONITOR_QUEUE_STREAK", 5), 5), + ) + parser.add_argument( + "--db-pool-streak", + type=int, + default=_safe_int(os.environ.get("MONITOR_DB_POOL_STREAK", 3), 3), + ) + parser.add_argument("--dry-run", action="store_true", help="仅打印,不实际发邮件") + args = parser.parse_args() + + if not args.to: + print("[monitor] 缺少收件人,请设置 --to 或 MONITOR_EMAIL_TO", flush=True) + return 2 + + state = _load_state(args.state_file) + now_ts = time.time() + + status, data, fetch_error = _fetch_health(args.url, args.timeout) + if not isinstance(data, dict): + data = {} + data["_monitor_url"] = args.url + + metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else {} + db_pool = metrics.get("db_pool") if isinstance(metrics.get("db_pool"), dict) else {} + task_queue = metrics.get("task_queue") if isinstance(metrics.get("task_queue"), dict) else {} + + service_down = status is None + health_fail = bool(status is not None and (status >= 500 or (not data.get("ok", False)) or (not data.get("db_ok", False)))) + db_pool_exhausted = ( + _safe_int(db_pool.get("pool_size"), 0) > 0 + and _safe_int(db_pool.get("available"), 0) <= 0 + and _safe_int(db_pool.get("in_use"), 0) >= _safe_int(db_pool.get("pool_size"), 0) + ) + queue_backlog_high = _safe_int(task_queue.get("pending_total"), 0) >= max(1, int(args.queue_threshold)) + + rule_defs = [ + ("service_down", service_down, 1), + ("health_fail", health_fail, 1), + ("db_pool_exhausted", db_pool_exhausted, max(1, int(args.db_pool_streak))), + ("queue_backlog_high", queue_backlog_high, max(1, int(args.queue_streak))), + ] + + pending_notifications: list[tuple[str, str]] = [] + for rule_name, bad, threshold in rule_defs: + streak = _inc_streak(state, rule_name, bad) + action = _rule_transition( + state, + rule_name=rule_name, + bad=bad, + streak=streak, + threshold=threshold, + remind_seconds=args.remind_seconds, + now_ts=now_ts, + ) + if action: + pending_notifications.append((rule_name, action)) + + _save_state(args.state_file, state) + + if not pending_notifications: + print(f"[monitor] {_now_text()} 正常,无需发送邮件") + return 0 + + common_lines = _build_common_lines(status, data, fetch_error) + + for rule_name, action in pending_notifications: + level = "告警" if action == "alert" else "恢复" + subject = f"[zsglpt健康监控][{level}] {rule_name}" + body_lines = [ + f"规则: {rule_name}", + f"状态: {level}", + "", + *common_lines, + ] + body = "\n".join(body_lines) + + if args.dry_run: + print(f"[monitor][dry-run] subject={subject}\n{body}\n") + continue + + ok, msg = _send_email_via_container( + container_name=args.container, + to_email=args.to, + subject=subject, + body=body, + ) + if ok: + print(f"[monitor] 邮件已发送: {subject}") + else: + print(f"[monitor] 邮件发送失败: {subject} | {msg}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/services/kdocs_uploader.py b/services/kdocs_uploader.py index 88d7f87..4d25fc5 100644 --- a/services/kdocs_uploader.py +++ b/services/kdocs_uploader.py @@ -243,6 +243,35 @@ class KDocsUploader: except queue.Empty: return {"success": False, "error": "操作超时"} + def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None: + response_queue = task.get("response") + if not response_queue: + return + try: + response_queue.put(result) + except Exception: + return + + def _process_task(self, task: Dict[str, Any]) -> bool: + action = task.get("action") + payload = task.get("payload") or {} + + if action == "shutdown": + return False + if action == "upload": + self._handle_upload(payload) + return True + if action == "qr": + self._put_task_response(task, self._handle_qr(payload)) + return True + if action == "clear_login": + self._put_task_response(task, self._handle_clear_login()) + return True + if action == "status": + self._put_task_response(task, self._handle_status_check()) + return True + return True + def _run(self) -> None: thread_id = self._thread_id logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})") @@ -261,34 +290,17 @@ class KDocsUploader: # 更新最后活动时间 self._last_activity = time.time() - action = task.get("action") - if action == "shutdown": - break - try: - if action == "upload": - self._handle_upload(task.get("payload") or {}) - elif action == "qr": - result = self._handle_qr(task.get("payload") or {}) - task.get("response").put(result) - elif action == "clear_login": - result = self._handle_clear_login() - task.get("response").put(result) - elif action == "status": - result = self._handle_status_check() - task.get("response").put(result) + should_continue = self._process_task(task) + if not should_continue: + break # 任务处理完成后更新活动时间 self._last_activity = time.time() except Exception as e: logger.warning(f"[KDocs] 处理任务失败: {e}") - # 如果有响应队列,返回错误 - if "response" in task and task.get("response"): - try: - task["response"].put({"success": False, "error": str(e)}) - except Exception: - pass + self._put_task_response(task, {"success": False, "error": str(e)}) except Exception as e: logger.warning(f"[KDocs] 线程主循环异常: {e}") @@ -830,18 +842,180 @@ class KDocsUploader: except Exception as e: logger.warning(f"[KDocs] 保存登录态失败: {e}") + def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str: + return (cfg.get("kdocs_doc_url") or "").strip() + + def _ensure_doc_access( + self, + doc_url: str, + *, + fast: bool = False, + use_storage_state: bool = True, + ) -> Optional[str]: + if not self._ensure_playwright(use_storage_state=use_storage_state): + return self._last_error or "浏览器不可用" + if not self._open_document(doc_url, fast=fast): + return self._last_error or "打开文档失败" + return None + + def _trigger_fast_login_dialog(self, timeout_ms: int) -> None: + self._ensure_login_dialog( + timeout_ms=timeout_ms, + frame_timeout_ms=timeout_ms, + quick=True, + ) + + def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]: + qr_image = None + invalid_qr = None + for attempt in range(10): + if attempt in (3, 7): + self._trigger_fast_login_dialog(fast_login_timeout) + candidate = self._capture_qr_image() + if candidate and self._is_valid_qr_image(candidate): + qr_image = candidate + break + if candidate: + invalid_qr = candidate + time.sleep(0.8) # 优化: 1 -> 0.8 + return qr_image, invalid_qr + + def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None: + try: + pages = self._iter_pages() + page_urls = [getattr(p, "url", "") for p in pages] + logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}") + + ts = int(time.time()) + saved = [] + for idx, page in enumerate(pages[:3]): + try: + path = f"data/kdocs_debug_{ts}_{idx}.png" + page.screenshot(path=path, full_page=True) + saved.append(path) + except Exception: + continue + + if saved: + logger.warning(f"[KDocs] 已保存调试截图: {saved}") + + if invalid_qr: + try: + path = f"data/kdocs_invalid_qr_{ts}.png" + with open(path, "wb") as handle: + handle.write(invalid_qr) + logger.warning(f"[KDocs] 已保存无效二维码截图: {path}") + except Exception: + pass + except Exception: + pass + + def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None: + try: + log_to_client(f"表格上传失败: {message}", user_id, account_id) + except Exception: + pass + + def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]: + account = None + prev_status = None + status_tracked = False + try: + account = safe_get_account(user_id, account_id) + if account and self._should_mark_upload(account): + prev_status = getattr(account, "status", None) + account.status = "上传截图" + self._emit_account_update(user_id, account) + status_tracked = True + except Exception: + prev_status = None + return account, prev_status, status_tracked + + def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]: + unit = (payload.get("unit") or "").strip() + name = (payload.get("name") or "").strip() + image_path = payload.get("image_path") + + if not unit or not name: + return None + if not image_path or not os.path.exists(image_path): + return None + + return { + "unit": unit, + "name": name, + "image_path": image_path, + "user_id": payload.get("user_id"), + "account_id": payload.get("account_id"), + } + + def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]: + return { + "sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(), + "sheet_index": int(cfg.get("kdocs_sheet_index") or 0), + "unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(), + "image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(), + "row_start": int(cfg.get("kdocs_row_start") or 0), + "row_end": int(cfg.get("kdocs_row_end") or 0), + } + + def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]: + sheet_cfg = self._resolve_upload_sheet_config(cfg) + success = False + error_msg = "" + + for _ in range(2): + try: + if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]: + self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"]) + + row_num = self._find_person_with_unit( + unit, + name, + sheet_cfg["unit_col"], + row_start=sheet_cfg["row_start"], + row_end=sheet_cfg["row_end"], + ) + if row_num < 0: + error_msg = f"未找到人员: {unit}-{name}" + break + + success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"]) + if success: + break + except Exception as e: + error_msg = str(e) + + return success, error_msg + + def _handle_upload_login_invalid( + self, + *, + unit: str, + name: str, + image_path: str, + user_id: Any, + account_id: Any, + ) -> None: + error_msg = "登录已失效,请管理员重新扫码登录" + self._login_required = True + self._last_login_ok = False + self._notify_admin(unit, name, image_path, error_msg) + self._log_upload_failure(error_msg, user_id, account_id) + def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]: cfg = self._load_system_config() - doc_url = (cfg.get("kdocs_doc_url") or "").strip() + doc_url = self._resolve_doc_url(cfg) if not doc_url: return {"success": False, "error": "未配置金山文档链接"} + force = bool(payload.get("force")) if force: self._handle_clear_login() - if not self._ensure_playwright(use_storage_state=not force): - return {"success": False, "error": self._last_error or "浏览器不可用"} - if not self._open_document(doc_url, fast=True): - return {"success": False, "error": self._last_error or "打开文档失败"} + + doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force) + if doc_error: + return {"success": False, "error": doc_error} if not force and self._has_saved_login_state() and self._is_logged_in(): self._login_required = False @@ -850,54 +1024,12 @@ class KDocsUploader: return {"success": True, "logged_in": True, "qr_image": ""} fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300")) - self._ensure_login_dialog( - timeout_ms=fast_login_timeout, - frame_timeout_ms=fast_login_timeout, - quick=True, - ) - qr_image = None - invalid_qr = None - for attempt in range(10): - if attempt in (3, 7): - self._ensure_login_dialog( - timeout_ms=fast_login_timeout, - frame_timeout_ms=fast_login_timeout, - quick=True, - ) - candidate = self._capture_qr_image() - if candidate and self._is_valid_qr_image(candidate): - qr_image = candidate - break - if candidate: - invalid_qr = candidate - time.sleep(0.8) # 优化: 1 -> 0.8 + self._trigger_fast_login_dialog(fast_login_timeout) + qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout) + if not qr_image: self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败" - try: - pages = self._iter_pages() - page_urls = [getattr(p, "url", "") for p in pages] - logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}") - ts = int(time.time()) - saved = [] - for idx, page in enumerate(pages[:3]): - try: - path = f"data/kdocs_debug_{ts}_{idx}.png" - page.screenshot(path=path, full_page=True) - saved.append(path) - except Exception: - continue - if saved: - logger.warning(f"[KDocs] 已保存调试截图: {saved}") - if invalid_qr: - try: - path = f"data/kdocs_invalid_qr_{ts}.png" - with open(path, "wb") as handle: - handle.write(invalid_qr) - logger.warning(f"[KDocs] 已保存无效二维码截图: {path}") - except Exception: - pass - except Exception: - pass + self._save_qr_debug_artifacts(invalid_qr) return {"success": False, "error": self._last_error} try: @@ -933,24 +1065,22 @@ class KDocsUploader: def _handle_status_check(self) -> Dict[str, Any]: cfg = self._load_system_config() - doc_url = (cfg.get("kdocs_doc_url") or "").strip() + doc_url = self._resolve_doc_url(cfg) if not doc_url: return {"success": True, "logged_in": False, "error": "未配置文档链接"} - if not self._ensure_playwright(): - return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"} - if not self._open_document(doc_url, fast=True): - return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"} + + doc_error = self._ensure_doc_access(doc_url, fast=True) + if doc_error: + return {"success": False, "logged_in": False, "error": doc_error} + fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300")) - self._ensure_login_dialog( - timeout_ms=fast_login_timeout, - frame_timeout_ms=fast_login_timeout, - quick=True, - ) + self._trigger_fast_login_dialog(fast_login_timeout) self._try_confirm_login( timeout_ms=fast_login_timeout, frame_timeout_ms=fast_login_timeout, quick=True, ) + logged_in = self._is_logged_in() self._last_login_ok = logged_in self._login_required = not logged_in @@ -962,79 +1092,43 @@ class KDocsUploader: cfg = self._load_system_config() if int(cfg.get("kdocs_enabled", 0) or 0) != 1: return - doc_url = (cfg.get("kdocs_doc_url") or "").strip() + + doc_url = self._resolve_doc_url(cfg) if not doc_url: return - unit = (payload.get("unit") or "").strip() - name = (payload.get("name") or "").strip() - image_path = payload.get("image_path") - user_id = payload.get("user_id") - account_id = payload.get("account_id") - - if not unit or not name: - return - if not image_path or not os.path.exists(image_path): + upload_data = self._parse_upload_payload(payload) + if not upload_data: return - account = None - prev_status = None - status_tracked = False + unit = upload_data["unit"] + name = upload_data["name"] + image_path = upload_data["image_path"] + user_id = upload_data["user_id"] + account_id = upload_data["account_id"] + + account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id) try: - try: - account = safe_get_account(user_id, account_id) - if account and self._should_mark_upload(account): - prev_status = getattr(account, "status", None) - account.status = "上传截图" - self._emit_account_update(user_id, account) - status_tracked = True - except Exception: - prev_status = None - - if not self._ensure_playwright(): - self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用") - return - - if not self._open_document(doc_url): - self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败") + doc_error = self._ensure_doc_access(doc_url) + if doc_error: + self._notify_admin(unit, name, image_path, doc_error) return if not self._is_logged_in(): - self._login_required = True - self._last_login_ok = False - self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录") - try: - log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id) - except Exception: - pass + self._handle_upload_login_invalid( + unit=unit, + name=name, + image_path=image_path, + user_id=user_id, + account_id=account_id, + ) return + self._login_required = False self._last_login_ok = True - sheet_name = (cfg.get("kdocs_sheet_name") or "").strip() - sheet_index = int(cfg.get("kdocs_sheet_index") or 0) - unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper() - image_col = (cfg.get("kdocs_image_column") or "D").strip().upper() - row_start = int(cfg.get("kdocs_row_start") or 0) - row_end = int(cfg.get("kdocs_row_end") or 0) - - success = False - error_msg = "" - for attempt in range(2): - try: - if sheet_name or sheet_index: - self._select_sheet(sheet_name, sheet_index) - row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end) - if row_num < 0: - error_msg = f"未找到人员: {unit}-{name}" - break - success = self._upload_image_to_cell(row_num, image_path, image_col) - if success: - break - except Exception as e: - error_msg = str(e) - + success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path) if success: self._last_success_at = time.time() self._last_error = None @@ -1048,10 +1142,7 @@ class KDocsUploader: error_msg = "上传失败" self._last_error = error_msg self._notify_admin(unit, name, image_path, error_msg) - try: - log_to_client(f"表格上传失败: {error_msg}", user_id, account_id) - except Exception: - pass + self._log_upload_failure(error_msg, user_id, account_id) finally: if status_tracked: self._restore_account_status(user_id, account, prev_status) diff --git a/services/maintenance.py b/services/maintenance.py index 7dbfb90..fda72f1 100644 --- a/services/maintenance.py +++ b/services/maintenance.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import os import threading import time from datetime import datetime @@ -10,6 +11,8 @@ from app_config import get_config from app_logger import get_logger from services.state import ( cleanup_expired_ip_rate_limits, + cleanup_expired_ip_request_rates, + cleanup_expired_login_security_state, safe_cleanup_expired_batches, safe_cleanup_expired_captcha, safe_cleanup_expired_pending_random, @@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON _kdocs_offline_notified: bool = False +def _to_int(value, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + +def _collect_active_user_ids() -> set[int]: + active_user_ids: set[int] = set() + for _, info in safe_iter_task_status_items(): + user_id = info.get("user_id") if isinstance(info, dict) else None + if user_id is None: + continue + try: + active_user_ids.add(int(user_id)) + except Exception: + continue + return active_user_ids + + +def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]: + expired_users = [] + for user_id, last_access in (safe_get_user_accounts_last_access_items() or []): + try: + user_id_int = int(user_id) + last_access_ts = float(last_access) + except Exception: + continue + if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS: + continue + if user_id_int in active_user_ids: + continue + if safe_has_user(user_id_int): + expired_users.append(user_id_int) + return expired_users + + +def _find_completed_task_status_ids(current_time: float) -> list[str]: + completed_task_ids = [] + for account_id, status_data in safe_iter_task_status_items(): + status = status_data.get("status") if isinstance(status_data, dict) else None + if status not in ["已完成", "失败", "已停止"]: + continue + + start_time = float(status_data.get("start_time", 0) or 0) + if (current_time - start_time) > 600: # 10分钟 + completed_task_ids.append(account_id) + return completed_task_ids + + +def _reap_zombie_processes() -> None: + while True: + try: + pid, _ = os.waitpid(-1, os.WNOHANG) + if pid == 0: + break + logger.debug(f"已回收僵尸进程: PID={pid}") + except ChildProcessError: + break + except Exception: + break + + def cleanup_expired_data() -> None: """定期清理过期数据,防止内存泄漏(逻辑保持不变)。""" current_time = time.time() @@ -43,48 +109,36 @@ def cleanup_expired_data() -> None: if deleted_ips: logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录") - expired_users = [] - last_access_items = safe_get_user_accounts_last_access_items() - if last_access_items: - task_items = safe_iter_task_status_items() - active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")} - for user_id, last_access in last_access_items: - if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS: - continue - if int(user_id) in active_user_ids: - continue - if safe_has_user(user_id): - expired_users.append(int(user_id)) + deleted_ip_requests = cleanup_expired_ip_request_rates(current_time) + if deleted_ip_requests: + logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录") + login_cleanup_stats = cleanup_expired_login_security_state(current_time) + login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values()) + if login_cleanup_total: + logger.debug( + "已清理登录风控缓存: " + f"失败计数={login_cleanup_stats.get('failures', 0)}, " + f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, " + f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, " + f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, " + f"告警状态={login_cleanup_stats.get('alerts', 0)}" + ) + + active_user_ids = _collect_active_user_ids() + expired_users = _find_expired_user_cache_ids(current_time, active_user_ids) for user_id in expired_users: safe_remove_user_accounts(user_id) if expired_users: logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存") - completed_tasks = [] - for account_id, status_data in safe_iter_task_status_items(): - if status_data.get("status") in ["已完成", "失败", "已停止"]: - start_time = float(status_data.get("start_time", 0) or 0) - if (current_time - start_time) > 600: # 10分钟 - completed_tasks.append(account_id) - for account_id in completed_tasks: + completed_task_ids = _find_completed_task_status_ids(current_time) + for account_id in completed_task_ids: safe_remove_task_status(account_id) - if completed_tasks: - logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态") + if completed_task_ids: + logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态") - try: - import os - - while True: - try: - pid, status = os.waitpid(-1, os.WNOHANG) - if pid == 0: - break - logger.debug(f"已回收僵尸进程: PID={pid}") - except ChildProcessError: - break - except Exception: - pass + _reap_zombie_processes() deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time) if deleted_batches: @@ -95,52 +149,39 @@ def cleanup_expired_data() -> None: logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务") -def check_kdocs_online_status() -> None: - """检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)""" - global _kdocs_offline_notified +def _load_kdocs_monitor_config(): + import database + cfg = database.get_system_config() + if not cfg: + return None + + kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0) + if not kdocs_enabled: + return None + + admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0) + admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip() + if (not admin_notify_enabled) or (not admin_notify_email): + return None + + return admin_notify_email + + +def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]: + login_required = bool(status.get("login_required", False)) + last_login_ok = status.get("last_login_ok") + is_offline = login_required or (last_login_ok is False) + return is_offline, login_required, last_login_ok + + +def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool: try: - import database - from services.kdocs_uploader import get_kdocs_uploader + import email_service - # 获取系统配置 - cfg = database.get_system_config() - if not cfg: - return - - # 检查是否启用了金山文档功能 - kdocs_enabled = int(cfg.get("kdocs_enabled") or 0) - if not kdocs_enabled: - return - - # 检查是否启用了管理员通知 - admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0) - admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip() - if not admin_notify_enabled or not admin_notify_email: - return - - # 获取金山文档状态 - kdocs = get_kdocs_uploader() - status = kdocs.get_status() - login_required = status.get("login_required", False) - last_login_ok = status.get("last_login_ok") - - # 如果需要登录或最后登录状态不是成功 - is_offline = login_required or (last_login_ok is False) - - if is_offline: - # 已经通知过了,不再重复通知 - if _kdocs_offline_notified: - logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知") - return - - # 发送邮件通知 - try: - import email_service - - now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - subject = "【金山文档离线告警】需要重新登录" - body = f""" + now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + subject = "【金山文档离线告警】需要重新登录" + body = f""" 您好, 系统检测到金山文档上传功能已离线,需要重新扫码登录。 @@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None: --- 此邮件由系统自动发送,请勿直接回复。 """ - email_service.send_email_async( - to_email=admin_notify_email, - subject=subject, - body=body, - email_type="kdocs_offline_alert", - ) - _kdocs_offline_notified = True # 标记为已通知 - logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}") - except Exception as e: - logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}") - else: - # 恢复在线,重置通知状态 + email_service.send_email_async( + to_email=admin_notify_email, + subject=subject, + body=body, + email_type="kdocs_offline_alert", + ) + logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}") + return True + except Exception as e: + logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}") + return False + + +def check_kdocs_online_status() -> None: + """检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)""" + global _kdocs_offline_notified + + try: + admin_notify_email = _load_kdocs_monitor_config() + if not admin_notify_email: + return + + from services.kdocs_uploader import get_kdocs_uploader + + kdocs = get_kdocs_uploader() + status = kdocs.get_status() or {} + is_offline, login_required, last_login_ok = _is_kdocs_offline(status) + + if is_offline: if _kdocs_offline_notified: - logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态") - _kdocs_offline_notified = False - logger.debug("[KDocs监控] 金山文档状态正常") + logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知") + return + + if _send_kdocs_offline_alert( + admin_notify_email, + login_required=login_required, + last_login_ok=last_login_ok, + ): + _kdocs_offline_notified = True + return + + if _kdocs_offline_notified: + logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态") + _kdocs_offline_notified = False + logger.debug("[KDocs监控] 金山文档状态正常") except Exception as e: logger.error(f"[KDocs监控] 检测失败: {e}") -def start_cleanup_scheduler() -> None: - """启动定期清理调度器""" - - def cleanup_loop(): +def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str): + def loop(): + if startup_delay > 0: + time.sleep(startup_delay) while True: try: - time.sleep(300) # 每5分钟执行一次清理 - cleanup_expired_data() + job() + time.sleep(interval_seconds) except Exception as e: - logger.error(f"清理任务执行失败: {e}") + logger.error(f"{error_tag}: {e}") + time.sleep(min(60.0, max(1.0, interval_seconds / 5.0))) - cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler") - cleanup_thread.start() + thread = threading.Thread(target=loop, daemon=True, name=name) + thread.start() + return thread + + +def start_cleanup_scheduler() -> None: + """启动定期清理调度器""" + _start_daemon_loop( + "cleanup-scheduler", + startup_delay=300, + interval_seconds=300, + job=cleanup_expired_data, + error_tag="清理任务执行失败", + ) logger.info("内存清理调度器已启动") def start_kdocs_monitor() -> None: """启动金山文档状态监控""" - - def monitor_loop(): - # 启动后等待 60 秒再开始检测(给系统初始化的时间) - time.sleep(60) - while True: - try: - check_kdocs_online_status() - time.sleep(300) # 每5分钟检测一次 - except Exception as e: - logger.error(f"[KDocs监控] 监控任务执行失败: {e}") - time.sleep(60) - - monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor") - monitor_thread.start() + _start_daemon_loop( + "kdocs-monitor", + startup_delay=60, + interval_seconds=300, + job=check_kdocs_online_status, + error_tag="[KDocs监控] 监控任务执行失败", + ) logger.info("[KDocs监控] 金山文档状态监控已启动(每5分钟检测一次)") - diff --git a/services/scheduler.py b/services/scheduler.py index c562bb1..892b705 100644 --- a/services/scheduler.py +++ b/services/scheduler.py @@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now logger = get_logger("app") config = get_config() +try: + _SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2")) +except Exception: + _SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2 +_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS) + SCREENSHOTS_DIR = config.SCREENSHOTS_DIR os.makedirs(SCREENSHOTS_DIR, exist_ok=True) @@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str: return f"{hour:02d}:{minute:02d}" +def _safe_recompute_schedule_next_run(schedule_id: int) -> None: + try: + database.recompute_schedule_next_run(schedule_id) + except Exception: + pass + + +def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]: + """批量加载用户账号快照。""" + user_accounts: dict[int, dict] = {} + account_ids: list[str] = [] + for user in approved_users: + user_id = user["id"] + accounts = safe_get_user_accounts_snapshot(user_id) + if not accounts: + load_user_accounts(user_id) + accounts = safe_get_user_accounts_snapshot(user_id) + if accounts: + user_accounts[user_id] = accounts + account_ids.extend(list(accounts.keys())) + return user_accounts, account_ids + + +def _should_skip_suspended_account(account_status_info, account, username: str) -> bool: + """判断是否应跳过暂停账号,并输出日志。""" + if not account_status_info: + return False + + status = account_status_info["status"] if "status" in account_status_info.keys() else "active" + if status != "suspended": + return False + + fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0 + logger.info( + f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码" + ) + return True + + +def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int): + import json + + try: + account_ids_raw = schedule_config.get("account_ids", "[]") or "[]" + account_ids = json.loads(account_ids_raw) + except Exception as e: + logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}") + return [] + if isinstance(account_ids, list): + return account_ids + return [] + + +def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None: + safe_create_batch( + batch_id, + { + "user_id": user_id, + "browse_type": browse_type, + "schedule_name": schedule_name, + "screenshots": [], + "total_accounts": 0, + "completed": 0, + "created_at": now_ts, + "updated_at": now_ts, + }, + ) + + +def _build_user_schedule_done_callback( + *, + completion_lock: threading.Lock, + remaining: dict, + counters: dict, + execution_start_time: float, + log_id: int, + schedule_id: int, + total_accounts: int, +): + def on_browse_done(): + with completion_lock: + remaining["count"] -= 1 + if remaining["done"] or remaining["count"] > 0: + return + remaining["done"] = True + + execution_duration = int(time.time() - execution_start_time) + started_count = int(counters.get("started", 0) or 0) + database.update_schedule_execution_log( + log_id, + total_accounts=total_accounts, + success_accounts=started_count, + failed_accounts=total_accounts - started_count, + duration_seconds=execution_duration, + status="completed", + ) + logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件") + + return on_browse_done + + +def _submit_user_schedule_accounts( + *, + user_id: int, + account_ids: list, + browse_type: str, + enable_screenshot, + task_source: str, + done_callback, + completion_lock: threading.Lock, + remaining: dict, + counters: dict, +) -> tuple[int, int]: + started_count = 0 + skipped_count = 0 + + for account_id in account_ids: + account = safe_get_account(user_id, account_id) + if (not account) or account.is_running: + skipped_count += 1 + continue + + with completion_lock: + remaining["count"] += 1 + ok, msg = submit_account_task( + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + enable_screenshot=enable_screenshot, + source=task_source, + done_callback=done_callback, + ) + if ok: + started_count += 1 + counters["started"] = started_count + else: + with completion_lock: + remaining["count"] -= 1 + skipped_count += 1 + logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}") + + return started_count, skipped_count + + def run_scheduled_task(skip_weekday_check: bool = False) -> None: """执行所有账号的浏览任务(可被手动调用,过滤重复账号)""" try: @@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None: cfg = database.get_system_config() enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1 - user_accounts = {} - account_ids = [] - for user in approved_users: - user_id = user["id"] - accounts = safe_get_user_accounts_snapshot(user_id) - if not accounts: - load_user_accounts(user_id) - accounts = safe_get_user_accounts_snapshot(user_id) - if accounts: - user_accounts[user_id] = accounts - account_ids.extend(list(accounts.keys())) + user_accounts, account_ids = _load_accounts_for_users(approved_users) account_statuses = database.get_account_status_batch(account_ids) @@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None: continue account_status_info = account_statuses.get(str(account_id)) - if account_status_info: - status = account_status_info["status"] if "status" in account_status_info.keys() else "active" - if status == "suspended": - fail_count = ( - account_status_info["login_fail_count"] - if "login_fail_count" in account_status_info.keys() - else 0 - ) - logger.info( - f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码" - ) - continue + if _should_skip_suspended_account(account_status_info, account, user["username"]): + continue if account.username in executed_usernames: skipped_duplicates += 1 @@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None: else: logger.warning(f"[定时任务] 启动失败({account.username}): {msg}") - time.sleep(2) + if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0: + time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS) logger.info( f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}" @@ -198,15 +329,16 @@ def scheduled_task_worker() -> None: deleted_screenshots = 0 if os.path.exists(SCREENSHOTS_DIR): cutoff_time = time.time() - (7 * 24 * 60 * 60) - for filename in os.listdir(SCREENSHOTS_DIR): - if filename.lower().endswith((".png", ".jpg", ".jpeg")): - filepath = os.path.join(SCREENSHOTS_DIR, filename) + with os.scandir(SCREENSHOTS_DIR) as entries: + for entry in entries: + if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))): + continue try: - if os.path.getmtime(filepath) < cutoff_time: - os.remove(filepath) + if entry.stat().st_mtime < cutoff_time: + os.remove(entry.path) deleted_screenshots += 1 except Exception as e: - logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}") + logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}") logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件") logger.info("[定时清理] 清理完成!") @@ -214,10 +346,97 @@ def scheduled_task_worker() -> None: except Exception as e: logger.exception(f"[定时清理] 清理任务出错: {str(e)}") + def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int): + weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5") + try: + return [int(d) for d in weekdays_str.split(",") if d.strip()] + except Exception as e: + logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}") + _safe_recompute_schedule_next_run(schedule_id) + return None + + def _execute_due_user_schedule(schedule_config: dict) -> None: + schedule_name = schedule_config.get("name", "未命名任务") + schedule_id = schedule_config["id"] + user_id = schedule_config["user_id"] + browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ)) + enable_screenshot = schedule_config.get("enable_screenshot", 1) + + account_ids = _parse_schedule_account_ids(schedule_config, schedule_id) + if not account_ids: + _safe_recompute_schedule_next_run(schedule_id) + return + + if not safe_get_user_accounts_snapshot(user_id): + load_user_accounts(user_id) + + import uuid + + execution_start_time = time.time() + log_id = database.create_schedule_execution_log( + schedule_id=schedule_id, + user_id=user_id, + schedule_name=schedule_name, + ) + + batch_id = f"batch_{uuid.uuid4().hex[:12]}" + now_ts = time.time() + _create_user_schedule_batch( + batch_id=batch_id, + user_id=user_id, + browse_type=browse_type, + schedule_name=schedule_name, + now_ts=now_ts, + ) + + completion_lock = threading.Lock() + remaining = {"count": 0, "done": False} + counters = {"started": 0} + + on_browse_done = _build_user_schedule_done_callback( + completion_lock=completion_lock, + remaining=remaining, + counters=counters, + execution_start_time=execution_start_time, + log_id=log_id, + schedule_id=schedule_id, + total_accounts=len(account_ids), + ) + + task_source = f"user_scheduled:{batch_id}" + started_count, skipped_count = _submit_user_schedule_accounts( + user_id=user_id, + account_ids=account_ids, + browse_type=browse_type, + enable_screenshot=enable_screenshot, + task_source=task_source, + done_callback=on_browse_done, + completion_lock=completion_lock, + remaining=remaining, + counters=counters, + ) + + batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time()) + if batch_info: + _send_batch_task_email_if_configured(batch_info) + + database.update_schedule_last_run(schedule_id) + + logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}") + if started_count <= 0: + database.update_schedule_execution_log( + log_id, + total_accounts=len(account_ids), + success_accounts=0, + failed_accounts=len(account_ids), + duration_seconds=0, + status="completed", + ) + if started_count == 0 and len(account_ids) > 0: + logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态") + def check_user_schedules(): """检查并执行用户定时任务(O-08:next_run_at 索引驱动)。""" - import json - try: now = get_beijing_now() now_str = now.strftime("%Y-%m-%d %H:%M:%S") @@ -226,145 +445,22 @@ def scheduled_task_worker() -> None: due_schedules = database.get_due_user_schedules(now_str, limit=50) or [] for schedule_config in due_schedules: - schedule_name = schedule_config.get("name", "未命名任务") schedule_id = schedule_config["id"] + schedule_name = schedule_config.get("name", "未命名任务") - weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5") - try: - allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()] - except Exception as e: - logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}") - try: - database.recompute_schedule_next_run(schedule_id) - except Exception: - pass + allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id) + if allowed_weekdays is None: continue if current_weekday not in allowed_weekdays: - try: - database.recompute_schedule_next_run(schedule_id) - except Exception: - pass + _safe_recompute_schedule_next_run(schedule_id) continue - logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})") - - user_id = schedule_config["user_id"] - schedule_id = schedule_config["id"] - browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ)) - enable_screenshot = schedule_config.get("enable_screenshot", 1) - - try: - account_ids_raw = schedule_config.get("account_ids", "[]") or "[]" - account_ids = json.loads(account_ids_raw) - except Exception as e: - logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}") - account_ids = [] - - if not account_ids: - try: - database.recompute_schedule_next_run(schedule_id) - except Exception: - pass - continue - - if not safe_get_user_accounts_snapshot(user_id): - load_user_accounts(user_id) - - import time as time_mod - import uuid - - execution_start_time = time_mod.time() - log_id = database.create_schedule_execution_log( - schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务") + logger.info( + f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 " + f"(next_run_at={schedule_config.get('next_run_at')})" ) - - batch_id = f"batch_{uuid.uuid4().hex[:12]}" - now_ts = time_mod.time() - safe_create_batch( - batch_id, - { - "user_id": user_id, - "browse_type": browse_type, - "schedule_name": schedule_config.get("name", "未命名任务"), - "screenshots": [], - "total_accounts": 0, - "completed": 0, - "created_at": now_ts, - "updated_at": now_ts, - }, - ) - - started_count = 0 - skipped_count = 0 - completion_lock = threading.Lock() - remaining = {"count": 0, "done": False} - - def on_browse_done(): - with completion_lock: - remaining["count"] -= 1 - if remaining["done"] or remaining["count"] > 0: - return - remaining["done"] = True - execution_duration = int(time_mod.time() - execution_start_time) - database.update_schedule_execution_log( - log_id, - total_accounts=len(account_ids), - success_accounts=started_count, - failed_accounts=len(account_ids) - started_count, - duration_seconds=execution_duration, - status="completed", - ) - logger.info( - f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件" - ) - - for account_id in account_ids: - account = safe_get_account(user_id, account_id) - if not account: - skipped_count += 1 - continue - if account.is_running: - skipped_count += 1 - continue - - task_source = f"user_scheduled:{batch_id}" - with completion_lock: - remaining["count"] += 1 - ok, msg = submit_account_task( - user_id=user_id, - account_id=account_id, - browse_type=browse_type, - enable_screenshot=enable_screenshot, - source=task_source, - done_callback=on_browse_done, - ) - if ok: - started_count += 1 - else: - with completion_lock: - remaining["count"] -= 1 - skipped_count += 1 - logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}") - - batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time()) - if batch_info: - _send_batch_task_email_if_configured(batch_info) - - database.update_schedule_last_run(schedule_id) - - logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号,批次ID: {batch_id}") - if started_count <= 0: - database.update_schedule_execution_log( - log_id, - total_accounts=len(account_ids), - success_accounts=0, - failed_accounts=len(account_ids), - duration_seconds=0, - status="completed", - ) - if started_count == 0 and len(account_ids) > 0: - logger.warning("[用户定时任务] ⚠️ 警告:所有账号都被跳过了!请检查user_accounts状态") + _execute_due_user_schedule(schedule_config) except Exception as e: logger.exception(f"[用户定时任务] 检查出错: {str(e)}") diff --git a/services/screenshots.py b/services/screenshots.py index 2b47b59..c870f4a 100644 --- a/services/screenshots.py +++ b/services/screenshots.py @@ -6,12 +6,14 @@ import os import shutil import subprocess import time +from urllib.parse import urlsplit import database import email_service from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh from app_config import get_config from app_logger import get_logger +from app_security import sanitize_filename from browser_pool_worker import get_browser_worker_pool from services.client_log import log_to_client from services.runtime import get_socketio @@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None: pass +def _set_screenshot_running_status(user_id: int, account_id: str) -> None: + """更新账号状态为截图中。""" + acc = safe_get_account(user_id, account_id) + if not acc: + return + acc.status = "截图中" + safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"}) + _emit("account_update", acc.to_dict(), room=f"user_{user_id}") + + +def _get_worker_display_info(browser_instance) -> tuple[str, int]: + """获取截图 worker 的展示信息。""" + if isinstance(browser_instance, dict): + return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0) + return "?", 0 + + +def _get_proxy_context(account) -> tuple[dict | None, str | None]: + """提取截图阶段代理配置。""" + proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None + proxy_server = proxy_config.get("server") if proxy_config else None + return proxy_config, proxy_server + + +def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]: + """构建截图目标 URL 与页面脚本。""" + parsed = urlsplit(config.ZSGL_LOGIN_URL) + base = f"{parsed.scheme}://{parsed.netloc}" + if "注册前" in str(browse_type): + bz = 0 + else: + bz = 0 + + target_url = f"{base}/admin/center.aspx?bz={bz}" + index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx" + run_script = ( + "(function(){" + "function done(){window.status='ready';}" + "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}" + "function expandMenu(){" + "try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}" + "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}" + "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}" + "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}" + "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}" + "}" + "function navReady(){" + "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}" + "}" + "function frameReady(){" + "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}" + "}" + "function check(){" + "if(navReady() && frameReady()){done();return;}" + "setTimeout(check,300);" + "}" + "var f=document.getElementById('mainframe');" + "ensureNav();" + "expandMenu();" + "if(!f){done();return;}" + f"f.src='{target_url}';" + "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};" + "setTimeout(check,5000);" + "})();" + ) + return index_url, target_url, run_script + + +def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]: + """构建截图输出文件名与路径。""" + timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S") + login_account = account.remark if account.remark else account.username + raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg" + screenshot_filename = sanitize_filename(raw_filename) + return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename) + + +def _ensure_screenshot_login_state( + *, + account, + proxy_config, + cookie_path: str, + attempt: int, + max_retries: int, + user_id: int, + account_id: str, + custom_log, +) -> str: + """确保截图前登录态有效。返回: ok/retry/fail。""" + should_refresh_login = not is_cookie_jar_fresh(cookie_path) + if not should_refresh_login: + return "ok" + + log_to_client("正在刷新登录态...", user_id, account_id) + if _ensure_login_cookies(account, proxy_config, custom_log): + return "ok" + + if attempt > 1: + log_to_client("截图登录失败", user_id, account_id) + if attempt < max_retries: + log_to_client("将重试...", user_id, account_id) + time.sleep(2) + return "retry" + + log_to_client("❌ 截图失败: 登录失败", user_id, account_id) + return "fail" + + +def _take_screenshot_once( + *, + index_url: str, + target_url: str, + screenshot_path: str, + cookie_path: str, + proxy_server: str | None, + run_script: str, + log_callback, +) -> str: + """执行一次截图尝试并验证输出文件。返回: success/invalid/failed。""" + cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None + + attempts = [ + { + "url": index_url, + "run_script": run_script, + "window_status": "ready", + }, + { + "url": target_url, + "run_script": None, + "window_status": None, + }, + ] + + ok = False + for shot in attempts: + ok = take_screenshot_wkhtmltoimage( + shot["url"], + screenshot_path, + cookies_path=cookies_for_shot, + proxy_server=proxy_server, + run_script=shot["run_script"], + window_status=shot["window_status"], + log_callback=log_callback, + ) + if ok: + break + + if not ok: + return "failed" + + if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000: + return "success" + + if os.path.exists(screenshot_path): + os.remove(screenshot_path) + return "invalid" + + +def _get_result_screenshot_path(result) -> str | None: + """从截图结果中提取截图文件绝对路径。""" + if result and result.get("success") and result.get("filename"): + return os.path.join(SCREENSHOTS_DIR, result["filename"]) + return None + + +def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None: + """按配置提交金山文档上传任务。""" + if not screenshot_path: + return + + cfg = database.get_system_config() or {} + if int(cfg.get("kdocs_enabled", 0) or 0) != 1: + return + + doc_url = (cfg.get("kdocs_doc_url") or "").strip() + if not doc_url: + return + + user_cfg = database.get_user_kdocs_settings(user_id) or {} + if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1: + return + + unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip() + name = (account.remark or "").strip() + if not unit: + log_to_client("表格上传跳过: 未配置县区", user_id, account_id) + return + if not name: + log_to_client("表格上传跳过: 账号备注为空", user_id, account_id) + return + + from services.kdocs_uploader import get_kdocs_uploader + + ok = get_kdocs_uploader().enqueue_upload( + user_id=user_id, + account_id=account_id, + unit=unit, + name=name, + image_path=screenshot_path, + ) + if not ok: + log_to_client("表格上传排队失败: 队列已满", user_id, account_id) + + +def _dispatch_screenshot_result( + *, + user_id: int, + account_id: str, + source: str, + browse_type: str, + browse_result: dict, + result, + account, + user_info, +) -> None: + """将截图结果发送到批次统计/邮件通知链路。""" + batch_id = _get_batch_id_from_source(source) + screenshot_path = _get_result_screenshot_path(result) + account_name = account.remark if account.remark else account.username + + try: + if result and result.get("success") and screenshot_path: + _enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path) + except Exception as kdocs_error: + logger.warning(f"表格上传任务提交失败: {kdocs_error}") + + if batch_id: + _batch_task_record_result( + batch_id=batch_id, + account_name=account_name, + screenshot_path=screenshot_path, + total_items=browse_result.get("total_items", 0), + total_attachments=browse_result.get("total_attachments", 0), + ) + return + + if source and source.startswith("user_scheduled"): + if user_info and user_info.get("email") and database.get_user_email_notify(user_id): + email_service.send_task_complete_email_async( + user_id=user_id, + email=user_info["email"], + username=user_info["username"], + account_name=account_name, + browse_type=browse_type, + total_items=browse_result.get("total_items", 0), + total_attachments=browse_result.get("total_attachments", 0), + screenshot_path=screenshot_path, + log_callback=lambda msg: log_to_client(msg, user_id, account_id), + ) + + +def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None: + """截图回调的通用收尾状态变更。""" + account.is_running = False + account.status = "未开始" + safe_remove_task_status(account_id) + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + +def _persist_browse_log_after_screenshot( + *, + user_id: int, + account_id: str, + account, + browse_type: str, + source: str, + task_start_time, + browse_result, +) -> None: + """截图完成后写入任务日志(浏览完成日志)。""" + import time as time_module + + total_elapsed = int(time_module.time() - task_start_time) + database.create_task_log( + user_id=user_id, + account_id=account_id, + username=account.username, + browse_type=browse_type, + status="success", + total_items=browse_result.get("total_items", 0), + total_attachments=browse_result.get("total_attachments", 0), + duration=total_elapsed, + source=source, + ) + + def take_screenshot_for_account( user_id, account_id, @@ -213,21 +502,21 @@ def take_screenshot_for_account( # 标记账号正在截图(防止重复提交截图任务) account.is_running = True + + user_info = database.get_user_by_id(user_id) + username_prefix = user_info["username"] if user_info else f"user{user_id}" + def screenshot_task( browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result ): """在worker线程中执行的截图任务""" # ✅ 获得worker后,立即更新状态为"截图中" - acc = safe_get_account(user_id, account_id) - if acc: - acc.status = "截图中" - safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"}) - _emit("account_update", acc.to_dict(), room=f"user_{user_id}") + _set_screenshot_running_status(user_id, account_id) max_retries = 3 - proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None - proxy_server = proxy_config.get("server") if proxy_config else None + proxy_config, proxy_server = _get_proxy_context(account) cookie_path = get_cookie_jar_path(account.username) + index_url, target_url, run_script = _build_screenshot_targets(browse_type) for attempt in range(1, max_retries + 1): try: @@ -239,8 +528,7 @@ def take_screenshot_for_account( if attempt > 1: log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id) - worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?" - use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0 + worker_id, use_count = _get_worker_display_info(browser_instance) log_to_client( f"使用Worker-{worker_id}执行截图(已执行{use_count}次)", user_id, @@ -250,99 +538,39 @@ def take_screenshot_for_account( def custom_log(message: str): log_to_client(message, user_id, account_id) - # 智能登录状态检查:只在必要时才刷新登录 - should_refresh_login = not is_cookie_jar_fresh(cookie_path) - if should_refresh_login and attempt > 1: - # 重试时刷新登录(attempt > 1 表示第2次及以后的尝试) - log_to_client("正在刷新登录态...", user_id, account_id) - if not _ensure_login_cookies(account, proxy_config, custom_log): - log_to_client("截图登录失败", user_id, account_id) - if attempt < max_retries: - log_to_client("将重试...", user_id, account_id) - time.sleep(2) - continue - log_to_client("❌ 截图失败: 登录失败", user_id, account_id) - return {"success": False, "error": "登录失败"} - elif should_refresh_login: - # 首次尝试时快速检查登录状态 - log_to_client("正在刷新登录态...", user_id, account_id) - if not _ensure_login_cookies(account, proxy_config, custom_log): - log_to_client("❌ 截图失败: 登录失败", user_id, account_id) - return {"success": False, "error": "登录失败"} + login_state = _ensure_screenshot_login_state( + account=account, + proxy_config=proxy_config, + cookie_path=cookie_path, + attempt=attempt, + max_retries=max_retries, + user_id=user_id, + account_id=account_id, + custom_log=custom_log, + ) + if login_state == "retry": + continue + if login_state == "fail": + return {"success": False, "error": "登录失败"} log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id) - from urllib.parse import urlsplit - - parsed = urlsplit(config.ZSGL_LOGIN_URL) - base = f"{parsed.scheme}://{parsed.netloc}" - if "注册前" in str(browse_type): - bz = 0 - else: - bz = 0 # 应读(网站更新后 bz=0 为应读) - target_url = f"{base}/admin/center.aspx?bz={bz}" - index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx" - run_script = ( - "(function(){" - "function done(){window.status='ready';}" - "function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}" - "function expandMenu(){" - "try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}" - "try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}" - "try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}" - "try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}" - "try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}" - "}" - "function navReady(){" - "try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}" - "}" - "function frameReady(){" - "try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}" - "}" - "function check(){" - "if(navReady() && frameReady()){done();return;}" - "setTimeout(check,300);" - "}" - "var f=document.getElementById('mainframe');" - "ensureNav();" - "expandMenu();" - "if(!f){done();return;}" - f"f.src='{target_url}';" - "f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};" - "setTimeout(check,5000);" - "})();" - ) - - timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S") - - user_info = database.get_user_by_id(user_id) - username_prefix = user_info["username"] if user_info else f"user{user_id}" - login_account = account.remark if account.remark else account.username - screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg" - screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename) - - cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None - if take_screenshot_wkhtmltoimage( - index_url, - screenshot_path, - cookies_path=cookies_for_shot, + screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type) + shot_state = _take_screenshot_once( + index_url=index_url, + target_url=target_url, + screenshot_path=screenshot_path, + cookie_path=cookie_path, proxy_server=proxy_server, run_script=run_script, - window_status="ready", log_callback=custom_log, - ) or take_screenshot_wkhtmltoimage( - target_url, - screenshot_path, - cookies_path=cookies_for_shot, - proxy_server=proxy_server, - log_callback=custom_log, - ): - if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000: - log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id) - return {"success": True, "filename": screenshot_filename} + ) + if shot_state == "success": + log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id) + return {"success": True, "filename": screenshot_filename} + + if shot_state == "invalid": log_to_client("截图文件异常,将重试", user_id, account_id) - if os.path.exists(screenshot_path): - os.remove(screenshot_path) else: log_to_client("截图保存失败", user_id, account_id) @@ -361,12 +589,7 @@ def take_screenshot_for_account( def screenshot_callback(result, error): """截图完成回调""" try: - account.is_running = False - account.status = "未开始" - - safe_remove_task_status(account_id) - - _emit("account_update", account.to_dict(), room=f"user_{user_id}") + _finalize_screenshot_callback_state(user_id, account_id, account) if error: log_to_client(f"❌ 截图失败: {error}", user_id, account_id) @@ -375,84 +598,27 @@ def take_screenshot_for_account( log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id) if task_start_time and browse_result: - import time as time_module - - total_elapsed = int(time_module.time() - task_start_time) - database.create_task_log( + _persist_browse_log_after_screenshot( user_id=user_id, account_id=account_id, - username=account.username, + account=account, browse_type=browse_type, - status="success", - total_items=browse_result.get("total_items", 0), - total_attachments=browse_result.get("total_attachments", 0), - duration=total_elapsed, source=source, + task_start_time=task_start_time, + browse_result=browse_result, ) try: - batch_id = _get_batch_id_from_source(source) - - screenshot_path = None - if result and result.get("success") and result.get("filename"): - screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"]) - - account_name = account.remark if account.remark else account.username - - try: - if screenshot_path and result and result.get("success"): - cfg = database.get_system_config() or {} - if int(cfg.get("kdocs_enabled", 0) or 0) == 1: - doc_url = (cfg.get("kdocs_doc_url") or "").strip() - if doc_url: - user_cfg = database.get_user_kdocs_settings(user_id) or {} - if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1: - unit = ( - user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "" - ).strip() - name = (account.remark or "").strip() - if unit and name: - from services.kdocs_uploader import get_kdocs_uploader - - ok = get_kdocs_uploader().enqueue_upload( - user_id=user_id, - account_id=account_id, - unit=unit, - name=name, - image_path=screenshot_path, - ) - if not ok: - log_to_client("表格上传排队失败: 队列已满", user_id, account_id) - else: - if not unit: - log_to_client("表格上传跳过: 未配置县区", user_id, account_id) - if not name: - log_to_client("表格上传跳过: 账号备注为空", user_id, account_id) - except Exception as kdocs_error: - logger.warning(f"表格上传任务提交失败: {kdocs_error}") - - if batch_id: - _batch_task_record_result( - batch_id=batch_id, - account_name=account_name, - screenshot_path=screenshot_path, - total_items=browse_result.get("total_items", 0), - total_attachments=browse_result.get("total_attachments", 0), - ) - elif source and source.startswith("user_scheduled"): - user_info = database.get_user_by_id(user_id) - if user_info and user_info.get("email") and database.get_user_email_notify(user_id): - email_service.send_task_complete_email_async( - user_id=user_id, - email=user_info["email"], - username=user_info["username"], - account_name=account_name, - browse_type=browse_type, - total_items=browse_result.get("total_items", 0), - total_attachments=browse_result.get("total_attachments", 0), - screenshot_path=screenshot_path, - log_callback=lambda msg: log_to_client(msg, user_id, account_id), - ) + _dispatch_screenshot_result( + user_id=user_id, + account_id=account_id, + source=source, + browse_type=browse_type, + browse_result=browse_result, + result=result, + account=account, + user_info=user_info, + ) except Exception as email_error: logger.warning(f"发送任务完成邮件失败: {email_error}") except Exception as e: diff --git a/services/state.py b/services/state.py index 3725026..c3968c2 100644 --- a/services/state.py +++ b/services/state.py @@ -13,7 +13,7 @@ from __future__ import annotations import threading import time import random -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from app_config import get_config @@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock() _log_cache_total_count = 0 +def _pop_oldest_log_for_user(uid: int) -> bool: + global _log_cache_total_count + + logs = _log_cache.get(uid) + if not logs: + _log_cache.pop(uid, None) + return False + + logs.pop(0) + _log_cache_total_count = max(0, _log_cache_total_count - 1) + if not logs: + _log_cache.pop(uid, None) + return True + + +def _pop_oldest_log_from_largest_user() -> bool: + largest_uid = None + largest_size = 0 + + for uid, logs in _log_cache.items(): + size = len(logs) + if size > largest_size: + largest_uid = uid + largest_size = size + + if largest_uid is None or largest_size <= 0: + return False + return _pop_oldest_log_for_user(int(largest_uid)) + + def safe_add_log( user_id: int, log_entry: Dict[str, Any], @@ -175,24 +205,17 @@ def safe_add_log( max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS) with _log_cache_lock: - if uid not in _log_cache: - _log_cache[uid] = [] + logs = _log_cache.setdefault(uid, []) - if len(_log_cache[uid]) >= max_logs_per_user: - _log_cache[uid].pop(0) - _log_cache_total_count = max(0, _log_cache_total_count - 1) + if len(logs) >= max_logs_per_user: + _pop_oldest_log_for_user(uid) + logs = _log_cache.setdefault(uid, []) - _log_cache[uid].append(dict(log_entry or {})) + logs.append(dict(log_entry or {})) _log_cache_total_count += 1 while _log_cache_total_count > max_total_logs: - if not _log_cache: - break - max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u])) - if _log_cache.get(max_user): - _log_cache[max_user].pop(0) - _log_cache_total_count -= 1 - else: + if not _pop_oldest_log_from_largest_user(): break @@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]: return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS) +def _format_wait_hint(remaining_seconds: int) -> str: + remaining = max(1, int(remaining_seconds or 0)) + if remaining >= 60: + return f"{remaining // 60 + 1}分钟" + return f"{remaining}秒" + + +def _check_and_increment_rate_bucket( + *, + buckets: Dict[str, Dict[str, Any]], + key: str, + now_ts: float, + max_requests: int, + window_seconds: int, +) -> Tuple[bool, Optional[str]]: + if not key or int(max_requests) <= 0: + return True, None + + data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds) + if int(data.get("count", 0) or 0) >= int(max_requests): + remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) + return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试" + + data["count"] = int(data.get("count", 0) or 0) + 1 + buckets[key] = data + return True, None + + def check_ip_request_rate( ip_address: str, action: str, @@ -392,21 +443,13 @@ def check_ip_request_rate( key = f"{action}:{ip_address}" with _ip_request_rate_lock: - data = _ip_request_rate.get(key) - if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds: - data = {"window_start": now_ts, "count": 0} - _ip_request_rate[key] = data - - if int(data.get("count", 0) or 0) >= max_requests: - remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) - if remaining >= 60: - wait_hint = f"{remaining // 60 + 1}分钟" - else: - wait_hint = f"{remaining}秒" - return False, f"请求过于频繁,请{wait_hint}后再试" - - data["count"] = int(data.get("count", 0) or 0) + 1 - return True, None + return _check_and_increment_rate_bucket( + buckets=_ip_request_rate, + key=key, + now_ts=now_ts, + max_requests=max_requests, + window_seconds=window_seconds, + ) def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int: @@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int: data = _ip_request_rate.get(key) or {} action = key.split(":", 1)[0] _, window_seconds = _get_action_rate_limit(action) - window_start = float(data.get("window_start", 0) or 0) - if now_ts - window_start >= window_seconds: + if _is_bucket_expired(data, now_ts, window_seconds): _ip_request_rate.pop(key, None) removed += 1 return removed @@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s return data +def _is_bucket_expired( + data: Optional[Dict[str, Any]], + now_ts: float, + window_seconds: int, + *, + time_field: str = "window_start", +) -> bool: + start_ts = float((data or {}).get(time_field, 0) or 0) + return (now_ts - start_ts) >= max(1, int(window_seconds)) + + +def _cleanup_map_entries( + store: Dict[Any, Dict[str, Any]], + should_remove: Callable[[Dict[str, Any]], bool], +) -> int: + removed = 0 + for key, value in list(store.items()): + item = value if isinstance(value, dict) else {} + if should_remove(item): + store.pop(key, None) + removed += 1 + return removed + + def record_login_username_attempt(ip_address: str, username: str) -> bool: now_ts = time.time() threshold, window_seconds, cooldown_seconds = _get_login_scan_config() @@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio user_key = _normalize_login_key("user", "", username) ip_user_key = _normalize_login_key("ipuser", ip_address, username) - def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]: - if not key or max_requests <= 0: - return True, None - data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds) - if int(data.get("count", 0) or 0) >= max_requests: - remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) - wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒" - return False, f"请求过于频繁,请{wait_hint}后再试" - data["count"] = int(data.get("count", 0) or 0) + 1 - _login_rate_limits[key] = data - return True, None - with _login_rate_limits_lock: - allowed, msg = _check(ip_key, ip_max) + allowed, msg = _check_and_increment_rate_bucket( + buckets=_login_rate_limits, + key=ip_key, + now_ts=now_ts, + max_requests=ip_max, + window_seconds=window_seconds, + ) if not allowed: return False, msg - allowed, msg = _check(ip_user_key, ip_user_max) + allowed, msg = _check_and_increment_rate_bucket( + buckets=_login_rate_limits, + key=ip_user_key, + now_ts=now_ts, + max_requests=ip_user_max, + window_seconds=window_seconds, + ) if not allowed: return False, msg - allowed, msg = _check(user_key, user_max) + allowed, msg = _check_and_increment_rate_bucket( + buckets=_login_rate_limits, + key=user_key, + now_ts=now_ts, + max_requests=user_max, + window_seconds=window_seconds, + ) if not allowed: return False, msg @@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None ip_key = _normalize_login_key("ip", ip_address) ip_user_key = _normalize_login_key("ipuser", ip_address, username or "") + def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool: + if not data: + return False + if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds: + return False + return int(data.get("count", 0) or 0) >= max_failures + with _login_failures_lock: - ip_data = _login_failures.get(ip_key) - if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds: - if int(ip_data.get("count", 0) or 0) >= max_failures: - return True - ip_user_data = _login_failures.get(ip_user_key) - if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds: - if int(ip_user_data.get("count", 0) or 0) >= max_failures: - return True + if _is_over_threshold(_login_failures.get(ip_key)): + return True + if _is_over_threshold(_login_failures.get(ip_user_key)): + return True if is_login_scan_locked(ip_address): return True @@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool: return False +def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]: + now_ts = float(now_ts if now_ts is not None else time.time()) + _, captcha_window = _get_login_captcha_config() + _, _, _, rate_window = _get_login_rate_limit_config() + _, lock_window, _ = _get_login_lock_config() + _, scan_window, _ = _get_login_scan_config() + alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3) + + with _login_failures_lock: + failures_removed = _cleanup_map_entries( + _login_failures, + lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window), + ) + + with _login_rate_limits_lock: + rate_removed = _cleanup_map_entries( + _login_rate_limits, + lambda data: _is_bucket_expired(data, now_ts, rate_window), + ) + + with _login_scan_lock: + scan_removed = _cleanup_map_entries( + _login_scan_state, + lambda data: ( + (now_ts - float(data.get("first_seen", 0) or 0)) > scan_window + and now_ts >= float(data.get("scan_until", 0) or 0) + ), + ) + + with _login_ip_user_lock: + ip_user_locks_removed = _cleanup_map_entries( + _login_ip_user_locks, + lambda data: now_ts >= float(data.get("lock_until", 0) or 0), + ) + + with _login_alert_lock: + alerts_removed = _cleanup_map_entries( + _login_alert_state, + lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds, + ) + + return { + "failures": failures_removed, + "rate_limits": rate_removed, + "scan_states": scan_removed, + "ip_user_locks": ip_user_locks_removed, + "alerts": alerts_removed, + } + + # ==================== 邮箱维度限流 ==================== _email_rate_limit: Dict[str, Dict[str, Any]] = {} @@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str] key = f"{action}:{email_key}" with _email_rate_limit_lock: - data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds) - if int(data.get("count", 0) or 0) >= max_requests: - remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0)))) - wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}秒" - return False, f"请求过于频繁,请{wait_hint}后再试" - data["count"] = int(data.get("count", 0) or 0) + 1 - _email_rate_limit[key] = data - return True, None + return _check_and_increment_rate_bucket( + buckets=_email_rate_limit, + key=key, + now_ts=now_ts, + max_requests=max_requests, + window_seconds=window_seconds, + ) # ==================== Batch screenshots(批次任务截图收集) ==================== diff --git a/services/task_scheduler.py b/services/task_scheduler.py new file mode 100644 index 0000000..cee84bd --- /dev/null +++ b/services/task_scheduler.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from __future__ import annotations + +import heapq +import threading +import time +from concurrent.futures import ThreadPoolExecutor, wait +from dataclasses import dataclass + +import database +from app_logger import get_logger +from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task +from services.task_batches import _batch_task_record_result, _get_batch_id_from_source + +logger = get_logger("app") + +# VIP优先级队列(仅用于可视化/调试) +vip_task_queue = [] # VIP用户任务队列 +normal_task_queue = [] # 普通用户任务队列 +task_queue_lock = threading.Lock() + +@dataclass +class _TaskRequest: + user_id: int + account_id: str + browse_type: str + enable_screenshot: bool + source: str + retry_count: int + submitted_at: float + is_vip: bool + seq: int + canceled: bool = False + done_callback: object = None + + +class TaskScheduler: + """全局任务调度器:队列排队,不为每个任务单独创建线程。""" + + def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None): + self.max_global = max(1, int(max_global)) + self.max_per_user = max(1, int(max_per_user)) + self.max_queue_size = max(1, int(max_queue_size)) + + self._cond = threading.Condition() + self._pending = [] # heap: (priority, submitted_at, seq, task) + self._pending_by_account = {} # {account_id: task} + self._seq = 0 + self._known_account_ids = set() + + self._running_global = 0 + self._running_by_user = {} # {user_id: running_count} + + self._executor_max_workers = self.max_global + self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker") + + self._futures_lock = threading.Lock() + self._active_futures = set() + + self._running = True + self._run_task_fn = run_task_fn + self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher") + self._dispatcher_thread.start() + + def _track_future(self, future) -> None: + with self._futures_lock: + self._active_futures.add(future) + try: + future.add_done_callback(self._untrack_future) + except Exception: + pass + + def _untrack_future(self, future) -> None: + with self._futures_lock: + self._active_futures.discard(future) + + def shutdown(self, timeout: float = 5.0): + """停止调度器(用于进程退出清理)""" + with self._cond: + self._running = False + self._cond.notify_all() + + try: + self._dispatcher_thread.join(timeout=timeout) + except Exception: + pass + + # 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试 + try: + deadline = time.time() + max(0.0, float(timeout or 0)) + while True: + with self._futures_lock: + pending = [f for f in self._active_futures if not f.done()] + if not pending: + break + remaining = deadline - time.time() + if remaining <= 0: + break + wait(pending, timeout=remaining) + except Exception: + pass + + try: + self._executor.shutdown(wait=False) + except Exception: + pass + + # 最后兜底:清理本调度器提交过的 active_task,避免测试/重启时被“任务已在运行中”误拦截 + try: + with self._cond: + known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys()) + self._pending.clear() + self._pending_by_account.clear() + self._cond.notify_all() + for account_id in known_ids: + safe_remove_task(account_id) + except Exception: + pass + + def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None): + """动态更新并发/队列上限(不影响已在运行的任务)""" + with self._cond: + if max_per_user is not None: + self.max_per_user = max(1, int(max_per_user)) + if max_queue_size is not None: + self.max_queue_size = max(1, int(max_queue_size)) + + if max_global is not None: + new_max_global = max(1, int(max_global)) + self.max_global = new_max_global + if new_max_global > self._executor_max_workers: + # 立即关闭旧线程池,防止资源泄漏 + old_executor = self._executor + self._executor_max_workers = new_max_global + self._executor = ThreadPoolExecutor( + max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker" + ) + # 立即关闭旧线程池 + try: + old_executor.shutdown(wait=False) + logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}") + except Exception as e: + logger.warning(f"关闭旧线程池失败: {e}") + + self._cond.notify_all() + + def get_queue_state_snapshot(self) -> dict: + """获取调度器队列/运行状态快照(用于前端展示/监控)。""" + with self._cond: + pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled] + pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq)) + + positions = {} + for idx, t in enumerate(pending_tasks): + positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)} + + return { + "pending_total": len(pending_tasks), + "running_total": int(self._running_global), + "running_by_user": dict(self._running_by_user), + "positions": positions, + } + + def submit_task( + self, + user_id: int, + account_id: str, + browse_type: str, + enable_screenshot: bool = True, + source: str = "manual", + retry_count: int = 0, + is_vip: bool = None, + done_callback=None, + ): + """提交任务进入队列(返回: (ok, message))""" + if not user_id or not account_id: + return False, "参数错误" + + submitted_at = time.time() + if is_vip is None: + try: + is_vip = bool(database.is_user_vip(user_id)) + except Exception: + is_vip = False + else: + is_vip = bool(is_vip) + + with self._cond: + if not self._running: + return False, "调度器未运行" + if len(self._pending_by_account) >= self.max_queue_size: + return False, "任务队列已满,请稍后再试" + if account_id in self._pending_by_account: + return False, "任务已在队列中" + if safe_get_task(account_id) is not None: + return False, "任务已在运行中" + + self._seq += 1 + task = _TaskRequest( + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + enable_screenshot=bool(enable_screenshot), + source=source, + retry_count=int(retry_count or 0), + submitted_at=submitted_at, + is_vip=is_vip, + seq=self._seq, + done_callback=done_callback, + ) + self._pending_by_account[account_id] = task + self._known_account_ids.add(account_id) + priority = 0 if is_vip else 1 + heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task)) + self._cond.notify_all() + + # 用于可视化/调试:记录队列 + with task_queue_lock: + if is_vip: + vip_task_queue.append(account_id) + else: + normal_task_queue.append(account_id) + + return True, "已加入队列" + + def cancel_pending_task(self, user_id: int, account_id: str) -> bool: + """取消尚未开始的排队任务(已运行的任务由 should_stop 控制)""" + canceled_task = None + with self._cond: + task = self._pending_by_account.pop(account_id, None) + if not task: + return False + task.canceled = True + canceled_task = task + self._cond.notify_all() + + # 从可视化队列移除 + with task_queue_lock: + if account_id in vip_task_queue: + vip_task_queue.remove(account_id) + if account_id in normal_task_queue: + normal_task_queue.remove(account_id) + + # 批次任务:取消也要推进完成计数,避免批次缓存常驻 + try: + batch_id = _get_batch_id_from_source(canceled_task.source) + if batch_id: + acc = safe_get_account(user_id, account_id) + if acc: + account_name = acc.remark if acc.remark else acc.username + else: + account_name = account_id + _batch_task_record_result( + batch_id=batch_id, + account_name=account_name, + screenshot_path=None, + total_items=0, + total_attachments=0, + ) + except Exception: + pass + + return True + + def _dispatch_loop(self): + while True: + task = None + with self._cond: + if not self._running: + return + + if not self._pending or self._running_global >= self.max_global: + self._cond.wait(timeout=0.5) + continue + + task = self._pop_next_runnable_locked() + if task is None: + self._cond.wait(timeout=0.5) + continue + + self._running_global += 1 + self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1 + + # 从队列移除(可视化) + with task_queue_lock: + if task.account_id in vip_task_queue: + vip_task_queue.remove(task.account_id) + if task.account_id in normal_task_queue: + normal_task_queue.remove(task.account_id) + + try: + future = self._executor.submit(self._run_task_wrapper, task) + self._track_future(future) + safe_set_task(task.account_id, future) + except Exception: + with self._cond: + self._running_global = max(0, self._running_global - 1) + # 使用默认值 0 与增加时保持一致 + self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1) + if self._running_by_user.get(task.user_id) == 0: + self._running_by_user.pop(task.user_id, None) + self._cond.notify_all() + + def _pop_next_runnable_locked(self): + """在锁内从优先队列取出“可运行”的任务,避免VIP任务占位阻塞普通任务。""" + if not self._pending: + return None + + skipped = [] + selected = None + + while self._pending: + _, _, _, task = heapq.heappop(self._pending) + + if task.canceled: + continue + if self._pending_by_account.get(task.account_id) is not task: + continue + + running_for_user = self._running_by_user.get(task.user_id, 0) + if running_for_user >= self.max_per_user: + skipped.append(task) + continue + + selected = task + break + + for t in skipped: + priority = 0 if t.is_vip else 1 + heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t)) + + if selected is None: + return None + + self._pending_by_account.pop(selected.account_id, None) + return selected + + def _run_task_wrapper(self, task: _TaskRequest): + try: + if callable(self._run_task_fn): + self._run_task_fn( + user_id=task.user_id, + account_id=task.account_id, + browse_type=task.browse_type, + enable_screenshot=task.enable_screenshot, + source=task.source, + retry_count=task.retry_count, + ) + finally: + try: + if callable(task.done_callback): + task.done_callback() + except Exception: + pass + safe_remove_task(task.account_id) + with self._cond: + self._running_global = max(0, self._running_global - 1) + # 使用默认值 0 与增加时保持一致 + self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1) + if self._running_by_user.get(task.user_id) == 0: + self._running_by_user.pop(task.user_id, None) + self._cond.notify_all() + + diff --git a/services/tasks.py b/services/tasks.py index bca2d0a..87255dc 100644 --- a/services/tasks.py +++ b/services/tasks.py @@ -2,12 +2,9 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import heapq import os import threading import time -from concurrent.futures import ThreadPoolExecutor, wait -from dataclasses import dataclass import database import email_service @@ -21,27 +18,22 @@ from services.runtime import get_socketio from services.screenshots import take_screenshot_for_account from services.state import ( safe_get_account, - safe_get_task, safe_remove_task, safe_remove_task_status, - safe_set_task, safe_set_task_status, safe_update_task_status, ) from services.task_batches import _batch_task_record_result, _get_batch_id_from_source +from services.task_scheduler import TaskScheduler from task_checkpoint import TaskStage logger = get_logger("app") config = get_config() -# VIP优先级队列(仅用于可视化/调试) -vip_task_queue = [] # VIP用户任务队列 -normal_task_queue = [] # 普通用户任务队列 -task_queue_lock = threading.Lock() - # 并发默认值(启动后会由系统配置覆盖并调用 update_limits) max_concurrent_per_account = config.MAX_CONCURRENT_PER_ACCOUNT max_concurrent_global = config.MAX_CONCURRENT_GLOBAL +_SOURCE_UNSET = object() def _emit(event: str, data: object, *, room: str | None = None) -> None: @@ -52,347 +44,6 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None: pass -@dataclass -class _TaskRequest: - user_id: int - account_id: str - browse_type: str - enable_screenshot: bool - source: str - retry_count: int - submitted_at: float - is_vip: bool - seq: int - canceled: bool = False - done_callback: object = None - - -class TaskScheduler: - """全局任务调度器:队列排队,不为每个任务单独创建线程。""" - - def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000): - self.max_global = max(1, int(max_global)) - self.max_per_user = max(1, int(max_per_user)) - self.max_queue_size = max(1, int(max_queue_size)) - - self._cond = threading.Condition() - self._pending = [] # heap: (priority, submitted_at, seq, task) - self._pending_by_account = {} # {account_id: task} - self._seq = 0 - self._known_account_ids = set() - - self._running_global = 0 - self._running_by_user = {} # {user_id: running_count} - - self._executor_max_workers = self.max_global - self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker") - - self._futures_lock = threading.Lock() - self._active_futures = set() - - self._running = True - self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher") - self._dispatcher_thread.start() - - def _track_future(self, future) -> None: - with self._futures_lock: - self._active_futures.add(future) - try: - future.add_done_callback(self._untrack_future) - except Exception: - pass - - def _untrack_future(self, future) -> None: - with self._futures_lock: - self._active_futures.discard(future) - - def shutdown(self, timeout: float = 5.0): - """停止调度器(用于进程退出清理)""" - with self._cond: - self._running = False - self._cond.notify_all() - - try: - self._dispatcher_thread.join(timeout=timeout) - except Exception: - pass - - # 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试 - try: - deadline = time.time() + max(0.0, float(timeout or 0)) - while True: - with self._futures_lock: - pending = [f for f in self._active_futures if not f.done()] - if not pending: - break - remaining = deadline - time.time() - if remaining <= 0: - break - wait(pending, timeout=remaining) - except Exception: - pass - - try: - self._executor.shutdown(wait=False) - except Exception: - pass - - # 最后兜底:清理本调度器提交过的 active_task,避免测试/重启时被“任务已在运行中”误拦截 - try: - with self._cond: - known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys()) - self._pending.clear() - self._pending_by_account.clear() - self._cond.notify_all() - for account_id in known_ids: - safe_remove_task(account_id) - except Exception: - pass - - def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None): - """动态更新并发/队列上限(不影响已在运行的任务)""" - with self._cond: - if max_per_user is not None: - self.max_per_user = max(1, int(max_per_user)) - if max_queue_size is not None: - self.max_queue_size = max(1, int(max_queue_size)) - - if max_global is not None: - new_max_global = max(1, int(max_global)) - self.max_global = new_max_global - if new_max_global > self._executor_max_workers: - # 立即关闭旧线程池,防止资源泄漏 - old_executor = self._executor - self._executor_max_workers = new_max_global - self._executor = ThreadPoolExecutor( - max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker" - ) - # 立即关闭旧线程池 - try: - old_executor.shutdown(wait=False) - logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}") - except Exception as e: - logger.warning(f"关闭旧线程池失败: {e}") - - self._cond.notify_all() - - def get_queue_state_snapshot(self) -> dict: - """获取调度器队列/运行状态快照(用于前端展示/监控)。""" - with self._cond: - pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled] - pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq)) - - positions = {} - for idx, t in enumerate(pending_tasks): - positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)} - - return { - "pending_total": len(pending_tasks), - "running_total": int(self._running_global), - "running_by_user": dict(self._running_by_user), - "positions": positions, - } - - def submit_task( - self, - user_id: int, - account_id: str, - browse_type: str, - enable_screenshot: bool = True, - source: str = "manual", - retry_count: int = 0, - is_vip: bool = None, - done_callback=None, - ): - """提交任务进入队列(返回: (ok, message))""" - if not user_id or not account_id: - return False, "参数错误" - - submitted_at = time.time() - if is_vip is None: - try: - is_vip = bool(database.is_user_vip(user_id)) - except Exception: - is_vip = False - else: - is_vip = bool(is_vip) - - with self._cond: - if not self._running: - return False, "调度器未运行" - if len(self._pending_by_account) >= self.max_queue_size: - return False, "任务队列已满,请稍后再试" - if account_id in self._pending_by_account: - return False, "任务已在队列中" - if safe_get_task(account_id) is not None: - return False, "任务已在运行中" - - self._seq += 1 - task = _TaskRequest( - user_id=user_id, - account_id=account_id, - browse_type=browse_type, - enable_screenshot=bool(enable_screenshot), - source=source, - retry_count=int(retry_count or 0), - submitted_at=submitted_at, - is_vip=is_vip, - seq=self._seq, - done_callback=done_callback, - ) - self._pending_by_account[account_id] = task - self._known_account_ids.add(account_id) - priority = 0 if is_vip else 1 - heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task)) - self._cond.notify_all() - - # 用于可视化/调试:记录队列 - with task_queue_lock: - if is_vip: - vip_task_queue.append(account_id) - else: - normal_task_queue.append(account_id) - - return True, "已加入队列" - - def cancel_pending_task(self, user_id: int, account_id: str) -> bool: - """取消尚未开始的排队任务(已运行的任务由 should_stop 控制)""" - canceled_task = None - with self._cond: - task = self._pending_by_account.pop(account_id, None) - if not task: - return False - task.canceled = True - canceled_task = task - self._cond.notify_all() - - # 从可视化队列移除 - with task_queue_lock: - if account_id in vip_task_queue: - vip_task_queue.remove(account_id) - if account_id in normal_task_queue: - normal_task_queue.remove(account_id) - - # 批次任务:取消也要推进完成计数,避免批次缓存常驻 - try: - batch_id = _get_batch_id_from_source(canceled_task.source) - if batch_id: - acc = safe_get_account(user_id, account_id) - if acc: - account_name = acc.remark if acc.remark else acc.username - else: - account_name = account_id - _batch_task_record_result( - batch_id=batch_id, - account_name=account_name, - screenshot_path=None, - total_items=0, - total_attachments=0, - ) - except Exception: - pass - - return True - - def _dispatch_loop(self): - while True: - task = None - with self._cond: - if not self._running: - return - - if not self._pending or self._running_global >= self.max_global: - self._cond.wait(timeout=0.5) - continue - - task = self._pop_next_runnable_locked() - if task is None: - self._cond.wait(timeout=0.5) - continue - - self._running_global += 1 - self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1 - - # 从队列移除(可视化) - with task_queue_lock: - if task.account_id in vip_task_queue: - vip_task_queue.remove(task.account_id) - if task.account_id in normal_task_queue: - normal_task_queue.remove(task.account_id) - - try: - future = self._executor.submit(self._run_task_wrapper, task) - self._track_future(future) - safe_set_task(task.account_id, future) - except Exception: - with self._cond: - self._running_global = max(0, self._running_global - 1) - # 使用默认值 0 与增加时保持一致 - self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1) - if self._running_by_user.get(task.user_id) == 0: - self._running_by_user.pop(task.user_id, None) - self._cond.notify_all() - - def _pop_next_runnable_locked(self): - """在锁内从优先队列取出“可运行”的任务,避免VIP任务占位阻塞普通任务。""" - if not self._pending: - return None - - skipped = [] - selected = None - - while self._pending: - _, _, _, task = heapq.heappop(self._pending) - - if task.canceled: - continue - if self._pending_by_account.get(task.account_id) is not task: - continue - - running_for_user = self._running_by_user.get(task.user_id, 0) - if running_for_user >= self.max_per_user: - skipped.append(task) - continue - - selected = task - break - - for t in skipped: - priority = 0 if t.is_vip else 1 - heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t)) - - if selected is None: - return None - - self._pending_by_account.pop(selected.account_id, None) - return selected - - def _run_task_wrapper(self, task: _TaskRequest): - try: - run_task( - user_id=task.user_id, - account_id=task.account_id, - browse_type=task.browse_type, - enable_screenshot=task.enable_screenshot, - source=task.source, - retry_count=task.retry_count, - ) - finally: - try: - if callable(task.done_callback): - task.done_callback() - except Exception: - pass - safe_remove_task(task.account_id) - with self._cond: - self._running_global = max(0, self._running_global - 1) - # 使用默认值 0 与增加时保持一致 - self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1) - if self._running_by_user.get(task.user_id) == 0: - self._running_by_user.pop(task.user_id, None) - self._cond.notify_all() - - _task_scheduler = None _task_scheduler_lock = threading.Lock() @@ -410,6 +61,7 @@ def get_task_scheduler() -> TaskScheduler: max_global=max_concurrent_global, max_per_user=max_concurrent_per_account, max_queue_size=max_queue_size, + run_task_fn=run_task, ) return _task_scheduler @@ -477,6 +129,584 @@ def submit_account_task( return True, message +def _account_display_name(account) -> str: + return account.remark if account.remark else account.username + + +def _record_batch_result(batch_id, account, total_items: int, total_attachments: int) -> None: + if not batch_id: + return + _batch_task_record_result( + batch_id=batch_id, + account_name=_account_display_name(account), + screenshot_path=None, + total_items=total_items, + total_attachments=total_attachments, + ) + + +def _close_account_automation(account, *, on_error=None) -> bool: + automation = getattr(account, "automation", None) + if not automation: + return False + + closed = False + try: + automation.close() + closed = True + except Exception as e: + if on_error: + try: + on_error(e) + except Exception: + pass + finally: + account.automation = None + + return closed + + +def _create_task_log( + *, + user_id: int, + account_id: str, + account, + browse_type: str, + status: str, + total_items: int, + total_attachments: int, + error_message: str, + duration: int, + source=_SOURCE_UNSET, +) -> None: + payload = { + "user_id": user_id, + "account_id": account_id, + "username": account.username, + "browse_type": browse_type, + "status": status, + "total_items": total_items, + "total_attachments": total_attachments, + "error_message": error_message, + "duration": duration, + } + if source is not _SOURCE_UNSET: + payload["source"] = source + database.create_task_log(**payload) + + +def _handle_stop_requested( + *, + account, + user_id: int, + account_id: str, + batch_id, + remove_task_status: bool = False, + record_batch: bool = False, +) -> bool: + if not account.should_stop: + return False + + log_to_client("任务已取消", user_id, account_id) + account.status = "已停止" + account.is_running = False + if remove_task_status: + safe_remove_task_status(account_id) + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + if record_batch and batch_id: + _record_batch_result(batch_id=batch_id, account=account, total_items=0, total_attachments=0) + return True + + +def _resolve_proxy_config(user_id: int, account_id: str, account): + proxy_config = None + system_config = database.get_system_config() + if system_config.get("proxy_enabled") != 1: + return proxy_config + + proxy_api_url = system_config.get("proxy_api_url", "").strip() + if not proxy_api_url: + log_to_client("⚠ 代理已启用但未配置API地址", user_id, account_id) + return proxy_config + + log_to_client("正在获取代理IP...", user_id, account_id) + proxy_server = get_proxy_from_api(proxy_api_url, max_retries=3) + if proxy_server: + proxy_config = {"server": proxy_server} + log_to_client(f"[OK] 将使用代理: {proxy_server}", user_id, account_id) + account.proxy_config = proxy_config + else: + log_to_client("✗ 代理获取失败,将不使用代理继续", user_id, account_id) + return proxy_config + + +def _refresh_account_remark(api_browser, account, user_id: int, account_id: str) -> None: + if account.remark: + return + try: + real_name = api_browser.get_real_name() + if not real_name: + return + account.remark = real_name + database.update_account_remark(account_id, real_name) + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + logger.info(f"[自动备注] 账号 {account.username} 自动设置备注为: {real_name}") + except Exception as e: + logger.warning(f"[自动备注] 获取姓名失败: {e}") + + +def _handle_login_failed( + *, + account, + user_id: int, + account_id: str, + browse_type: str, + source: str, + task_start_time: float, + task_id: str, + checkpoint_mgr, + time_module, +) -> None: + error_message = "登录失败" + log_to_client(f"❌ {error_message}", user_id, account_id) + + is_suspended = database.increment_account_login_fail(account_id, error_message) + if is_suspended: + log_to_client("⚠ 该账号连续3次密码错误,已自动暂停", user_id, account_id) + log_to_client("请在前台修改密码后才能继续使用", user_id, account_id) + + retry_action = checkpoint_mgr.record_error(task_id, error_message) + if retry_action == "paused": + logger.warning(f"[断点] 任务 {task_id} 已暂停(登录失败)") + + account.status = "登录失败" + account.is_running = False + _create_task_log( + user_id=user_id, + account_id=account_id, + account=account, + browse_type=browse_type, + status="failed", + total_items=0, + total_attachments=0, + error_message=error_message, + duration=int(time_module.time() - task_start_time), + source=source, + ) + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + +def _execute_single_attempt( + *, + account, + user_id: int, + account_id: str, + browse_type: str, + source: str, + checkpoint_mgr, + task_id: str, + task_start_time: float, + time_module, +): + proxy_config = _resolve_proxy_config(user_id=user_id, account_id=account_id, account=account) + + checkpoint_mgr.update_stage(task_id, TaskStage.STARTING, progress_percent=10) + + def custom_log(message: str): + log_to_client(message, user_id, account_id) + + log_to_client("开始登录...", user_id, account_id) + safe_update_task_status(account_id, {"detail_status": "正在登录"}) + checkpoint_mgr.update_stage(task_id, TaskStage.LOGGING_IN, progress_percent=25) + + with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser: + if not api_browser.login(account.username, account.password): + _handle_login_failed( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + task_start_time=task_start_time, + task_id=task_id, + checkpoint_mgr=checkpoint_mgr, + time_module=time_module, + ) + return "login_failed", None + + log_to_client("[OK] 首次登录成功,刷新登录时间...", user_id, account_id) + + if api_browser.login(account.username, account.password): + log_to_client("[OK] 二次登录成功!", user_id, account_id) + else: + log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id) + + api_browser.save_cookies_for_screenshot(account.username) + database.reset_account_login_status(account_id) + _refresh_account_remark(api_browser, account, user_id, account_id) + + safe_update_task_status(account_id, {"detail_status": "正在浏览"}) + log_to_client(f"开始浏览 '{browse_type}' 内容...", user_id, account_id) + account.total_items = 0 + safe_update_task_status(account_id, {"progress": {"items": 0, "attachments": 0}}) + + def should_stop(): + return account.should_stop + + def on_browse_progress(progress: dict): + try: + total_items = int(progress.get("total_items") or 0) + browsed_items = int(progress.get("browsed_items") or 0) + if total_items > 0: + account.total_items = total_items + safe_update_task_status(account_id, {"progress": {"items": browsed_items, "attachments": 0}}) + except Exception: + pass + + checkpoint_mgr.update_stage(task_id, TaskStage.BROWSING, progress_percent=50) + result = api_browser.browse_content( + browse_type=browse_type, + should_stop_callback=should_stop, + progress_callback=on_browse_progress, + ) + return "ok", result + + +def _record_success_without_screenshot( + *, + account, + user_id: int, + account_id: str, + browse_type: str, + source: str, + task_start_time: float, + result, + batch_id, + time_module, +) -> bool: + _create_task_log( + user_id=user_id, + account_id=account_id, + account=account, + browse_type=browse_type, + status="success", + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message="", + duration=int(time_module.time() - task_start_time), + source=source, + ) + if batch_id: + _record_batch_result( + batch_id=batch_id, + account=account, + total_items=result.total_items, + total_attachments=result.total_attachments, + ) + return True + + if source and source.startswith("user_scheduled"): + try: + user_info = database.get_user_by_id(user_id) + if user_info and user_info.get("email") and database.get_user_email_notify(user_id): + email_service.send_task_complete_email_async( + user_id=user_id, + email=user_info["email"], + username=user_info["username"], + account_name=_account_display_name(account), + browse_type=browse_type, + total_items=result.total_items, + total_attachments=result.total_attachments, + screenshot_path=None, + log_callback=lambda msg: log_to_client(msg, user_id, account_id), + ) + except Exception as email_error: + logger.warning(f"发送任务完成邮件失败: {email_error}") + + return False + + +def _is_timeout_error(error_msg: str) -> bool: + return "Timeout" in error_msg or "timeout" in error_msg + + +def _record_failed_task_log( + *, + account, + user_id: int, + account_id: str, + browse_type: str, + source: str, + total_items: int, + total_attachments: int, + error_message: str, + task_start_time: float, + time_module, +) -> None: + account.status = "出错" + _create_task_log( + user_id=user_id, + account_id=account_id, + account=account, + browse_type=browse_type, + status="failed", + total_items=total_items, + total_attachments=total_attachments, + error_message=error_message, + duration=int(time_module.time() - task_start_time), + source=source, + ) + + +def _handle_failed_browse_result( + *, + account, + result, + error_msg: str, + user_id: int, + account_id: str, + browse_type: str, + source: str, + attempt: int, + max_attempts: int, + task_start_time: float, + time_module, +) -> str: + if _is_timeout_error(error_msg): + log_to_client(f"⚠ 检测到超时错误: {error_msg}", user_id, account_id) + + close_ok = _close_account_automation( + account, + on_error=lambda e: logger.debug(f"关闭超时浏览器实例失败: {e}"), + ) + if close_ok: + log_to_client("已关闭超时的浏览器实例", user_id, account_id) + + if attempt < max_attempts: + log_to_client(f"⚠ 代理可能速度过慢,将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id) + time_module.sleep(2) + return "continue" + + log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) + _record_failed_task_log( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message=f"重试{max_attempts}次后仍失败: {error_msg}", + task_start_time=task_start_time, + time_module=time_module, + ) + return "break" + + log_to_client(f"浏览出错: {error_msg}", user_id, account_id) + _record_failed_task_log( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + total_items=result.total_items, + total_attachments=result.total_attachments, + error_message=error_msg, + task_start_time=task_start_time, + time_module=time_module, + ) + return "break" + + +def _handle_attempt_exception( + *, + account, + error_msg: str, + user_id: int, + account_id: str, + browse_type: str, + source: str, + attempt: int, + max_attempts: int, + task_start_time: float, + time_module, +) -> str: + _close_account_automation( + account, + on_error=lambda e: logger.debug(f"关闭浏览器实例失败: {e}"), + ) + + if _is_timeout_error(error_msg): + log_to_client(f"⚠ 执行超时: {error_msg}", user_id, account_id) + if attempt < max_attempts: + log_to_client(f"⚠ 将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id) + time_module.sleep(2) + return "continue" + + log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) + _record_failed_task_log( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + total_items=account.total_items, + total_attachments=account.total_attachments, + error_message=f"重试{max_attempts}次后仍失败: {error_msg}", + task_start_time=task_start_time, + time_module=time_module, + ) + return "break" + + log_to_client(f"任务执行异常: {error_msg}", user_id, account_id) + _record_failed_task_log( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + total_items=account.total_items, + total_attachments=account.total_attachments, + error_message=error_msg, + task_start_time=task_start_time, + time_module=time_module, + ) + return "break" + + +def _schedule_auto_retry( + *, + user_id: int, + account_id: str, + browse_type: str, + enable_screenshot: bool, + source: str, + retry_count: int, +) -> None: + def delayed_retry_submit(): + fresh_account = safe_get_account(user_id, account_id) + if not fresh_account: + log_to_client("自动重试取消: 账户不存在", user_id, account_id) + return + if fresh_account.should_stop: + log_to_client("自动重试取消: 任务已被停止", user_id, account_id) + return + + log_to_client(f"🔄 开始第 {retry_count + 1} 次自动重试...", user_id, account_id) + ok, msg = submit_account_task( + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + enable_screenshot=enable_screenshot, + source=source, + retry_count=retry_count + 1, + ) + if not ok: + log_to_client(f"自动重试提交失败: {msg}", user_id, account_id) + + try: + threading.Timer(5, delayed_retry_submit).start() + except Exception: + delayed_retry_submit() + + +def _finalize_task_run( + *, + account, + user_id: int, + account_id: str, + browse_type: str, + source: str, + enable_screenshot: bool, + retry_count: int, + max_auto_retry: int, + task_start_time: float, + result, + batch_id, + batch_recorded: bool, +) -> None: + final_status = str(account.status or "") + account.is_running = False + screenshot_submitted = False + + _close_account_automation( + account, + on_error=lambda e: log_to_client(f"关闭主任务浏览器时出错: {str(e)}", user_id, account_id), + ) + + safe_remove_task(account_id) + safe_remove_task_status(account_id) + + if final_status == "已完成" and not account.should_stop: + account.status = "已完成" + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + if enable_screenshot: + log_to_client("等待2秒后开始截图...", user_id, account_id) + account.status = "等待截图" + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + safe_set_task_status( + account_id, + { + "user_id": user_id, + "username": account.username, + "status": "排队中", + "detail_status": "等待截图资源", + "browse_type": browse_type, + "start_time": time.time(), + "source": source, + "progress": { + "items": result.total_items if result else 0, + "attachments": result.total_attachments if result else 0, + }, + }, + ) + + browse_result_dict = { + "total_items": result.total_items if result else 0, + "total_attachments": result.total_attachments if result else 0, + } + screenshot_submitted = True + threading.Thread( + target=take_screenshot_for_account, + args=(user_id, account_id, browse_type, source, task_start_time, browse_result_dict), + daemon=True, + ).start() + else: + account.status = "未开始" + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + log_to_client("截图功能已禁用,跳过截图", user_id, account_id) + else: + if final_status == "出错" and retry_count < max_auto_retry: + log_to_client(f"⚠ 任务执行失败,5秒后自动重试 ({retry_count + 1}/{max_auto_retry})...", user_id, account_id) + account.status = "等待重试" + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + _schedule_auto_retry( + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + enable_screenshot=enable_screenshot, + source=source, + retry_count=retry_count, + ) + elif final_status in ["登录失败", "出错"]: + account.status = final_status + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + else: + account.status = "未开始" + _emit("account_update", account.to_dict(), room=f"user_{user_id}") + + if batch_id and (not screenshot_submitted) and (not batch_recorded) and account.status != "等待重试": + _record_batch_result( + batch_id=batch_id, + account=account, + total_items=getattr(account, "total_items", 0) or 0, + total_attachments=getattr(account, "total_attachments", 0) or 0, + ) + + def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="manual", retry_count=0): """运行自动化任务 @@ -498,30 +728,27 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m import time as time_module + task_start_time = time_module.time() + result = None + try: - if account.should_stop: - log_to_client("任务已取消", user_id, account_id) - account.status = "已停止" - account.is_running = False - safe_remove_task_status(account_id) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - if batch_id: - account_name = account.remark if account.remark else account.username - _batch_task_record_result( - batch_id=batch_id, - account_name=account_name, - screenshot_path=None, - total_items=0, - total_attachments=0, - ) + if _handle_stop_requested( + account=account, + user_id=user_id, + account_id=account_id, + batch_id=batch_id, + remove_task_status=True, + record_batch=True, + ): return try: - if account.should_stop: - log_to_client("任务已取消", user_id, account_id) - account.status = "已停止" - account.is_running = False - _emit("account_update", account.to_dict(), room=f"user_{user_id}") + if _handle_stop_requested( + account=account, + user_id=user_id, + account_id=account_id, + batch_id=batch_id, + ): return task_id = checkpoint_mgr.create_checkpoint( @@ -546,111 +773,19 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m if attempt > 1: log_to_client(f"🔄 第 {attempt} 次尝试(共{max_attempts}次)...", user_id, account_id) - proxy_config = None - config = database.get_system_config() - if config.get("proxy_enabled") == 1: - proxy_api_url = config.get("proxy_api_url", "").strip() - if proxy_api_url: - log_to_client("正在获取代理IP...", user_id, account_id) - proxy_server = get_proxy_from_api(proxy_api_url, max_retries=3) - if proxy_server: - proxy_config = {"server": proxy_server} - log_to_client(f"[OK] 将使用代理: {proxy_server}", user_id, account_id) - account.proxy_config = proxy_config # 保存代理配置供截图使用 - else: - log_to_client("✗ 代理获取失败,将不使用代理继续", user_id, account_id) - else: - log_to_client("⚠ 代理已启用但未配置API地址", user_id, account_id) - - checkpoint_mgr.update_stage(task_id, TaskStage.STARTING, progress_percent=10) - - def custom_log(message: str): - log_to_client(message, user_id, account_id) - - log_to_client("开始登录...", user_id, account_id) - safe_update_task_status(account_id, {"detail_status": "正在登录"}) - checkpoint_mgr.update_stage(task_id, TaskStage.LOGGING_IN, progress_percent=25) - - with APIBrowser(log_callback=custom_log, proxy_config=proxy_config) as api_browser: - if api_browser.login(account.username, account.password): - log_to_client("[OK] 首次登录成功,刷新登录时间...", user_id, account_id) - - # 二次登录:让"上次登录时间"变成刚才首次登录的时间 - # 这样截图时显示的"上次登录时间"就是几秒前而不是昨天 - if api_browser.login(account.username, account.password): - log_to_client("[OK] 二次登录成功!", user_id, account_id) - else: - log_to_client("⚠ 二次登录失败,继续使用首次登录状态", user_id, account_id) - - api_browser.save_cookies_for_screenshot(account.username) - database.reset_account_login_status(account_id) - - if not account.remark: - try: - real_name = api_browser.get_real_name() - if real_name: - account.remark = real_name - database.update_account_remark(account_id, real_name) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - logger.info(f"[自动备注] 账号 {account.username} 自动设置备注为: {real_name}") - except Exception as e: - logger.warning(f"[自动备注] 获取姓名失败: {e}") - - safe_update_task_status(account_id, {"detail_status": "正在浏览"}) - log_to_client(f"开始浏览 '{browse_type}' 内容...", user_id, account_id) - account.total_items = 0 - safe_update_task_status(account_id, {"progress": {"items": 0, "attachments": 0}}) - - def should_stop(): - return account.should_stop - - def on_browse_progress(progress: dict): - try: - total_items = int(progress.get("total_items") or 0) - browsed_items = int(progress.get("browsed_items") or 0) - if total_items > 0: - account.total_items = total_items - safe_update_task_status( - account_id, {"progress": {"items": browsed_items, "attachments": 0}} - ) - except Exception: - pass - - checkpoint_mgr.update_stage(task_id, TaskStage.BROWSING, progress_percent=50) - result = api_browser.browse_content( - browse_type=browse_type, - should_stop_callback=should_stop, - progress_callback=on_browse_progress, - ) - else: - error_message = "登录失败" - log_to_client(f"❌ {error_message}", user_id, account_id) - - is_suspended = database.increment_account_login_fail(account_id, error_message) - if is_suspended: - log_to_client("⚠ 该账号连续3次密码错误,已自动暂停", user_id, account_id) - log_to_client("请在前台修改密码后才能继续使用", user_id, account_id) - - retry_action = checkpoint_mgr.record_error(task_id, error_message) - if retry_action == "paused": - logger.warning(f"[断点] 任务 {task_id} 已暂停(登录失败)") - - account.status = "登录失败" - account.is_running = False - database.create_task_log( - user_id=user_id, - account_id=account_id, - username=account.username, - browse_type=browse_type, - status="failed", - total_items=0, - total_attachments=0, - error_message=error_message, - duration=int(time_module.time() - task_start_time), - source=source, - ) - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - return + attempt_status, result = _execute_single_attempt( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + checkpoint_mgr=checkpoint_mgr, + task_id=task_id, + task_start_time=task_start_time, + time_module=time_module, + ) + if attempt_status == "login_failed": + return account.total_items = result.total_items account.total_attachments = result.total_attachments @@ -674,152 +809,63 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m logger.info(f"[断点] 任务 {task_id} 已完成") if not enable_screenshot: - database.create_task_log( + batch_recorded = _record_success_without_screenshot( + account=account, user_id=user_id, account_id=account_id, - username=account.username, browse_type=browse_type, - status="success", - total_items=result.total_items, - total_attachments=result.total_attachments, - error_message="", - duration=int(time_module.time() - task_start_time), source=source, + task_start_time=task_start_time, + result=result, + batch_id=batch_id, + time_module=time_module, ) - if batch_id: - account_name = account.remark if account.remark else account.username - _batch_task_record_result( - batch_id=batch_id, - account_name=account_name, - screenshot_path=None, - total_items=result.total_items, - total_attachments=result.total_attachments, - ) - batch_recorded = True - elif source and source.startswith("user_scheduled"): - try: - user_info = database.get_user_by_id(user_id) - if user_info and user_info.get("email") and database.get_user_email_notify(user_id): - account_name = account.remark if account.remark else account.username - email_service.send_task_complete_email_async( - user_id=user_id, - email=user_info["email"], - username=user_info["username"], - account_name=account_name, - browse_type=browse_type, - total_items=result.total_items, - total_attachments=result.total_attachments, - screenshot_path=None, - log_callback=lambda msg: log_to_client(msg, user_id, account_id), - ) - except Exception as email_error: - logger.warning(f"发送任务完成邮件失败: {email_error}") break error_msg = result.error_message - if "Timeout" in error_msg or "timeout" in error_msg: - log_to_client(f"⚠ 检测到超时错误: {error_msg}", user_id, account_id) - - if account.automation: - try: - account.automation.close() - log_to_client("已关闭超时的浏览器实例", user_id, account_id) - except Exception as e: - logger.debug(f"关闭超时浏览器实例失败: {e}") - account.automation = None - - if attempt < max_attempts: - log_to_client( - f"⚠ 代理可能速度过慢,将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id - ) - time_module.sleep(2) - continue - log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) - account.status = "出错" - database.create_task_log( - user_id=user_id, - account_id=account_id, - username=account.username, - browse_type=browse_type, - status="failed", - total_items=result.total_items, - total_attachments=result.total_attachments, - error_message=f"重试{max_attempts}次后仍失败: {error_msg}", - duration=int(time_module.time() - task_start_time), - ) - break - - log_to_client(f"浏览出错: {error_msg}", user_id, account_id) - account.status = "出错" - database.create_task_log( + action = _handle_failed_browse_result( + account=account, + result=result, + error_msg=error_msg, user_id=user_id, account_id=account_id, - username=account.username, browse_type=browse_type, - status="failed", - total_items=result.total_items, - total_attachments=result.total_attachments, - error_message=error_msg, - duration=int(time_module.time() - task_start_time), source=source, + attempt=attempt, + max_attempts=max_attempts, + task_start_time=task_start_time, + time_module=time_module, ) + if action == "continue": + continue break except Exception as retry_error: error_msg = str(retry_error) - if account.automation: - try: - account.automation.close() - except Exception as e: - logger.debug(f"关闭浏览器实例失败: {e}") - account.automation = None - - if "Timeout" in error_msg or "timeout" in error_msg: - log_to_client(f"⚠ 执行超时: {error_msg}", user_id, account_id) - if attempt < max_attempts: - log_to_client(f"⚠ 将换新IP重试 ({attempt}/{max_attempts})", user_id, account_id) - time_module.sleep(2) - continue - log_to_client(f"❌ 已达到最大重试次数({max_attempts}),任务失败", user_id, account_id) - account.status = "出错" - database.create_task_log( - user_id=user_id, - account_id=account_id, - username=account.username, - browse_type=browse_type, - status="failed", - total_items=account.total_items, - total_attachments=account.total_attachments, - error_message=f"重试{max_attempts}次后仍失败: {error_msg}", - duration=int(time_module.time() - task_start_time), - source=source, - ) - break - - log_to_client(f"任务执行异常: {error_msg}", user_id, account_id) - account.status = "出错" - database.create_task_log( + action = _handle_attempt_exception( + account=account, + error_msg=error_msg, user_id=user_id, account_id=account_id, - username=account.username, browse_type=browse_type, - status="failed", - total_items=account.total_items, - total_attachments=account.total_attachments, - error_message=error_msg, - duration=int(time_module.time() - task_start_time), source=source, + attempt=attempt, + max_attempts=max_attempts, + task_start_time=task_start_time, + time_module=time_module, ) + if action == "continue": + continue break except Exception as e: error_msg = str(e) log_to_client(f"任务执行出错: {error_msg}", user_id, account_id) account.status = "出错" - database.create_task_log( + _create_task_log( user_id=user_id, account_id=account_id, - username=account.username, + account=account, browse_type=browse_type, status="failed", total_items=account.total_items, @@ -830,108 +876,20 @@ def run_task(user_id, account_id, browse_type, enable_screenshot=True, source="m ) finally: - account.is_running = False - screenshot_submitted = False - if account.status not in ["已完成"]: - account.status = "未开始" - - if account.automation: - try: - account.automation.close() - except Exception as e: - log_to_client(f"关闭主任务浏览器时出错: {str(e)}", user_id, account_id) - finally: - account.automation = None - - safe_remove_task(account_id) - safe_remove_task_status(account_id) - - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - - if account.status == "已完成" and not account.should_stop: - if enable_screenshot: - log_to_client("等待2秒后开始截图...", user_id, account_id) - account.status = "等待截图" - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - import time as time_mod - - safe_set_task_status( - account_id, - { - "user_id": user_id, - "username": account.username, - "status": "排队中", - "detail_status": "等待截图资源", - "browse_type": browse_type, - "start_time": time_mod.time(), - "source": source, - "progress": { - "items": result.total_items if result else 0, - "attachments": result.total_attachments if result else 0, - }, - }, - ) - browse_result_dict = { - "total_items": result.total_items, - "total_attachments": result.total_attachments, - } - screenshot_submitted = True - threading.Thread( - target=take_screenshot_for_account, - args=(user_id, account_id, browse_type, source, task_start_time, browse_result_dict), - daemon=True, - ).start() - else: - account.status = "未开始" - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - log_to_client("截图功能已禁用,跳过截图", user_id, account_id) - else: - if account.status not in ["登录失败", "出错"]: - account.status = "未开始" - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - elif account.status == "出错" and retry_count < MAX_AUTO_RETRY: - log_to_client( - f"⚠ 任务执行失败,5秒后自动重试 ({retry_count + 1}/{MAX_AUTO_RETRY})...", user_id, account_id - ) - account.status = "等待重试" - _emit("account_update", account.to_dict(), room=f"user_{user_id}") - - def delayed_retry_submit(): - # 重新获取最新的账户对象,避免使用闭包中的旧对象 - fresh_account = safe_get_account(user_id, account_id) - if not fresh_account: - log_to_client("自动重试取消: 账户不存在", user_id, account_id) - return - if fresh_account.should_stop: - log_to_client("自动重试取消: 任务已被停止", user_id, account_id) - return - log_to_client(f"🔄 开始第 {retry_count + 1} 次自动重试...", user_id, account_id) - ok, msg = submit_account_task( - user_id=user_id, - account_id=account_id, - browse_type=browse_type, - enable_screenshot=enable_screenshot, - source=source, - retry_count=retry_count + 1, - ) - if not ok: - log_to_client(f"自动重试提交失败: {msg}", user_id, account_id) - - try: - threading.Timer(5, delayed_retry_submit).start() - except Exception: - delayed_retry_submit() - - if batch_id and (not screenshot_submitted) and (not batch_recorded) and account.status != "等待重试": - account_name = account.remark if account.remark else account.username - _batch_task_record_result( - batch_id=batch_id, - account_name=account_name, - screenshot_path=None, - total_items=getattr(account, "total_items", 0) or 0, - total_attachments=getattr(account, "total_attachments", 0) or 0, - ) - batch_recorded = True + _finalize_task_run( + account=account, + user_id=user_id, + account_id=account_id, + browse_type=browse_type, + source=source, + enable_screenshot=enable_screenshot, + retry_count=retry_count, + max_auto_retry=MAX_AUTO_RETRY, + task_start_time=task_start_time, + result=result, + batch_id=batch_id, + batch_recorded=batch_recorded, + ) finally: pass