"""
Task checkpoint / resume module.

Features:
1. Record task execution progress (the stage and counters of each task).
2. Save a checkpoint (error info, paused state) when a task hits an error.
3. Allow unfinished (paused) tasks to be listed and resumed after a restart.
4. Retry mechanism: errors are retried until ``max_retries`` is reached,
   after which the task is paused for manual handling.
"""

import time
import json
from datetime import datetime
from enum import Enum

import db_pool


class TaskStage(Enum):
    """Execution stage of a task (stored in the ``stage`` column)."""

    QUEUED = 'queued'            # waiting in the queue
    STARTING = 'starting'        # launching the browser
    LOGGING_IN = 'logging_in'    # logging in
    BROWSING = 'browsing'        # browsing
    DOWNLOADING = 'downloading'  # downloading
    COMPLETING = 'completing'    # finishing up
    COMPLETED = 'completed'      # completed
    FAILED = 'failed'            # failed
    PAUSED = 'paused'            # paused (waiting to be resumed)


class TaskCheckpoint:
    """Task checkpoint manager backed by the ``task_checkpoints`` table."""

    # Columns returned by get_checkpoint(), in SELECT order; also the dict keys.
    _CHECKPOINT_COLUMNS = (
        'task_id', 'user_id', 'account_id', 'username', 'browse_type',
        'stage', 'status', 'progress_percent',
        'current_page', 'total_pages', 'processed_items', 'downloaded_files',
        'retry_count', 'max_retries', 'last_error', 'error_count',
        'checkpoint_data',
        'created_at', 'updated_at', 'completed_at',
    )

    # Columns returned by get_paused_tasks(), in SELECT order; also the dict keys.
    _PAUSED_TASK_COLUMNS = (
        'task_id', 'user_id', 'account_id', 'username', 'browse_type',
        'stage', 'progress_percent', 'last_error', 'retry_count', 'updated_at',
    )

    def __init__(self):
        """Initialise the manager (uses the global ``db_pool`` connection pool)."""
        self._init_table()

    def _init_table(self):
        """Create the task progress table and its indexes if they do not exist."""
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS task_checkpoints (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT UNIQUE NOT NULL,  -- 任务唯一ID (user_id:account_id:timestamp)
                    user_id INTEGER NOT NULL,
                    account_id TEXT NOT NULL,
                    username TEXT NOT NULL,
                    browse_type TEXT NOT NULL,

                    -- 任务状态
                    stage TEXT NOT NULL,  -- 当前阶段
                    status TEXT NOT NULL,  -- running/paused/completed/failed
                    progress_percent INTEGER DEFAULT 0,  -- 进度百分比

                    -- 进度详情
                    current_page INTEGER DEFAULT 0,  -- 当前浏览到第几页
                    total_pages INTEGER DEFAULT 0,  -- 总页数(如果已知)
                    processed_items INTEGER DEFAULT 0,  -- 已处理条目数
                    downloaded_files INTEGER DEFAULT 0,  -- 已下载文件数

                    -- 错误处理
                    retry_count INTEGER DEFAULT 0,  -- 重试次数
                    max_retries INTEGER DEFAULT 3,  -- 最大重试次数
                    last_error TEXT,  -- 最后一次错误信息
                    error_count INTEGER DEFAULT 0,  -- 累计错误次数

                    -- 断点数据(JSON格式存储上下文)
                    checkpoint_data TEXT,  -- 断点上下文数据

                    -- 时间戳
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    completed_at TIMESTAMP,

                    FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
                )
            """)

            # Indexes to speed up lookups by status/stage and by user/account.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_task_status
                ON task_checkpoints(status, stage)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_task_user
                ON task_checkpoints(user_id, account_id)
            """)

            conn.commit()

    def create_checkpoint(self, user_id, account_id, username, browse_type):
        """Create a new task checkpoint row and return its task_id.

        The task_id has the form ``"{user_id}:{account_id}:{timestamp}"``, where
        timestamp is ``int(time.time())``. Because ``task_id`` is UNIQUE, two
        checkpoints for the same user/account created within the same second
        would collide. In that case the timestamp component is bumped until an
        unused id is found, so it may run slightly ahead of the real creation
        time (``created_at`` keeps the actual time).

        The new row starts in stage ``queued`` with status ``running``.

        Returns:
            The newly created task_id string.
        """
        timestamp = int(time.time())
        with db_pool.get_db() as conn:
            cursor = conn.cursor()

            # Probe for a free id instead of letting the INSERT fail on the
            # UNIQUE constraint when the same account starts twice in 1 second.
            while True:
                task_id = f"{user_id}:{account_id}:{timestamp}"
                cursor.execute(
                    "SELECT 1 FROM task_checkpoints WHERE task_id = ?",
                    (task_id,),
                )
                if cursor.fetchone() is None:
                    break
                timestamp += 1

            cursor.execute("""
                INSERT INTO task_checkpoints
                (task_id, user_id, account_id, username, browse_type, stage, status)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (task_id, user_id, account_id, username, browse_type,
                  TaskStage.QUEUED.value, 'running'))
            conn.commit()
        return task_id

    def update_stage(self, task_id, stage, progress_percent=None, checkpoint_data=None):
        """Update a task's stage and, optionally, its progress and context data.

        Args:
            task_id: Task identifier.
            stage: A ``TaskStage`` member or its raw string value.
            progress_percent: New progress value; left unchanged when None.
            checkpoint_data: JSON-serialisable context stored in
                ``checkpoint_data``; left unchanged when None.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            updates = ['stage = ?', 'updated_at = CURRENT_TIMESTAMP']
            params = [stage.value if isinstance(stage, TaskStage) else stage]

            if progress_percent is not None:
                updates.append('progress_percent = ?')
                params.append(progress_percent)

            if checkpoint_data is not None:
                updates.append('checkpoint_data = ?')
                params.append(json.dumps(checkpoint_data, ensure_ascii=False))

            params.append(task_id)

            # Only fixed column fragments are interpolated; values are bound.
            cursor.execute(f"""
                UPDATE task_checkpoints
                SET {', '.join(updates)}
                WHERE task_id = ?
            """, params)
            conn.commit()

    def update_progress(self, task_id, **kwargs):
        """Update a task's progress counters.

        Args:
            task_id: Task identifier.
            current_page: Current page number.
            total_pages: Total number of pages.
            processed_items: Number of processed items.
            downloaded_files: Number of downloaded files.

        Unknown keyword arguments are ignored. ``progress_percent`` is
        recomputed (capped at 100) only when both ``current_page`` and a
        positive ``total_pages`` are passed in the same call.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            updates = ['updated_at = CURRENT_TIMESTAMP']
            params = []

            # Whitelist of columns; keeps arbitrary kwargs out of the SQL.
            for key in ['current_page', 'total_pages', 'processed_items', 'downloaded_files']:
                if key in kwargs:
                    updates.append(f'{key} = ?')
                    params.append(kwargs[key])

            # Derive the progress percentage from the page counters.
            if 'current_page' in kwargs and 'total_pages' in kwargs and kwargs['total_pages'] > 0:
                progress = int((kwargs['current_page'] / kwargs['total_pages']) * 100)
                updates.append('progress_percent = ?')
                params.append(min(progress, 100))

            params.append(task_id)

            cursor.execute(f"""
                UPDATE task_checkpoints
                SET {', '.join(updates)}
                WHERE task_id = ?
            """, params)
            conn.commit()

    def record_error(self, task_id, error_message, pause=False):
        """Record an error for a task and decide whether it should be retried.

        Increments ``retry_count`` and ``error_count`` and stores
        ``error_message`` in ``last_error``. If the incremented retry count
        reaches ``max_retries``, or ``pause`` is true, the task is moved to
        status and stage ``paused`` so it can be handled manually.

        With the column defaults (``retry_count`` 0, ``max_retries`` 3), the
        first two errors return 'retry' and the third returns 'paused'.

        Args:
            task_id: Task identifier.
            error_message: Error text to store in ``last_error``.
            pause: Force the task into the paused state regardless of retries.

        Returns:
            'paused' if the task was paused, 'retry' if it may be retried, or
            'unknown' if no checkpoint exists for ``task_id``.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()

            # Read the current retry counters.
            cursor.execute("""
                SELECT retry_count, max_retries, error_count
                FROM task_checkpoints
                WHERE task_id = ?
            """, (task_id,))
            result = cursor.fetchone()

            if result:
                retry_count, max_retries, error_count = result
                retry_count += 1
                error_count += 1

                if retry_count >= max_retries or pause:
                    # Retries exhausted (or pause requested): park the task.
                    # NOTE(review): this overwrites `stage`, so the stage in
                    # which the error happened is not kept in the row.
                    cursor.execute("""
                        UPDATE task_checkpoints
                        SET status = 'paused',
                            stage = ?,
                            retry_count = ?,
                            error_count = ?,
                            last_error = ?,
                            updated_at = CURRENT_TIMESTAMP
                        WHERE task_id = ?
                    """, (TaskStage.PAUSED.value, retry_count, error_count, error_message, task_id))
                    conn.commit()
                    return 'paused'
                else:
                    # Still within the retry budget.
                    cursor.execute("""
                        UPDATE task_checkpoints
                        SET retry_count = ?,
                            error_count = ?,
                            last_error = ?,
                            updated_at = CURRENT_TIMESTAMP
                        WHERE task_id = ?
                    """, (retry_count, error_count, error_message, task_id))
                    conn.commit()
                    return 'retry'

            return 'unknown'

    def complete_task(self, task_id, success=True):
        """Mark a task as finished.

        Sets status/stage to completed (``success=True``) or failed
        (``success=False``) and stamps ``completed_at``.
        NOTE(review): ``progress_percent`` is set to 100 even on failure.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE task_checkpoints
                SET status = ?,
                    stage = ?,
                    progress_percent = 100,
                    completed_at = CURRENT_TIMESTAMP,
                    updated_at = CURRENT_TIMESTAMP
                WHERE task_id = ?
            """, ('completed' if success else 'failed',
                  TaskStage.COMPLETED.value if success else TaskStage.FAILED.value,
                  task_id))
            conn.commit()

    def get_checkpoint(self, task_id):
        """Return the full checkpoint of a task as a dict, or None if absent.

        The dict keys are the column names in ``_CHECKPOINT_COLUMNS``;
        ``checkpoint_data`` is JSON-decoded (None when empty).
        """
        columns = self._CHECKPOINT_COLUMNS
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            # Column names come from a fixed class constant, not user input.
            cursor.execute(
                f"SELECT {', '.join(columns)} FROM task_checkpoints WHERE task_id = ?",
                (task_id,),
            )
            row = cursor.fetchone()

            if row is None:
                return None

            checkpoint = dict(zip(columns, row))
            raw_data = checkpoint['checkpoint_data']
            checkpoint['checkpoint_data'] = json.loads(raw_data) if raw_data else None
            return checkpoint

    def get_paused_tasks(self, user_id=None):
        """List paused (resumable) tasks, most recently updated first.

        Args:
            user_id: When not None, only that user's tasks are returned;
                otherwise paused tasks of all users are returned. (A falsy
                but valid id such as 0 is treated as a real filter.)

        Returns:
            A list of dicts keyed by the names in ``_PAUSED_TASK_COLUMNS``.
        """
        columns = self._PAUSED_TASK_COLUMNS
        sql = f"SELECT {', '.join(columns)} FROM task_checkpoints WHERE status = 'paused'"
        params = ()
        if user_id is not None:
            sql += " AND user_id = ?"
            params = (user_id,)
        sql += " ORDER BY updated_at DESC"

        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def resume_task(self, task_id):
        """Resume a paused task: set status back to running and reset retries.

        Only affects rows whose status is 'paused'. The ``stage`` column is
        left unchanged.

        Returns:
            True if a paused task was updated, False otherwise.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE task_checkpoints
                SET status = 'running',
                    retry_count = 0,
                    updated_at = CURRENT_TIMESTAMP
                WHERE task_id = ? AND status = 'paused'
            """, (task_id,))
            conn.commit()
            return cursor.rowcount > 0

    def abandon_task(self, task_id):
        """Abandon a paused task by marking it failed.

        Only affects rows whose status is 'paused'.

        Returns:
            True if a paused task was updated, False otherwise.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE task_checkpoints
                SET status = 'failed',
                    stage = ?,
                    completed_at = CURRENT_TIMESTAMP,
                    updated_at = CURRENT_TIMESTAMP
                WHERE task_id = ? AND status = 'paused'
            """, (TaskStage.FAILED.value, task_id))
            conn.commit()
            return cursor.rowcount > 0

    def cleanup_old_checkpoints(self, days=7):
        """Delete completed/failed checkpoints finished more than ``days`` days ago.

        Running and paused tasks are never deleted.

        Returns:
            The number of deleted rows.
        """
        with db_pool.get_db() as conn:
            cursor = conn.cursor()
            # Builds an SQLite modifier such as '-7 days' from the bound value.
            cursor.execute("""
                DELETE FROM task_checkpoints
                WHERE status IN ('completed', 'failed')
                AND datetime(completed_at) < datetime('now', '-' || ? || ' days')
            """, (days,))
            deleted = cursor.rowcount
            conn.commit()
            return deleted


# Module-level singleton instance (created lazily).
_checkpoint_manager = None


def get_checkpoint_manager():
    """Return the shared TaskCheckpoint instance, creating it on first use."""
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = TaskCheckpoint()
    return _checkpoint_manager