#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Global task scheduler.

Tasks are queued in a priority heap (VIP users first, then FIFO by submission
time) and executed on a shared, bounded ``ThreadPoolExecutor`` by a single
dispatcher thread, subject to a global and a per-user concurrency limit.
"""

from __future__ import annotations

import heapq
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass

import database
from app_logger import get_logger
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source

logger = get_logger("app")

# Lists of queued account ids split by priority. They mirror the scheduler's
# pending queue for visualization/debugging only; scheduling never reads them.
vip_task_queue = []  # account ids queued by VIP users
normal_task_queue = []  # account ids queued by regular users
task_queue_lock = threading.Lock()


def _visual_queue_add(account_id: str, is_vip: bool) -> None:
    """Record a newly queued account id in the visualization lists."""
    with task_queue_lock:
        if is_vip:
            vip_task_queue.append(account_id)
        else:
            normal_task_queue.append(account_id)


def _visual_queue_remove(account_id: str) -> None:
    """Drop an account id from the visualization lists (no-op if absent)."""
    with task_queue_lock:
        if account_id in vip_task_queue:
            vip_task_queue.remove(account_id)
        if account_id in normal_task_queue:
            normal_task_queue.remove(account_id)


@dataclass
class _TaskRequest:
    """One queued execution request for a single account."""

    user_id: int
    account_id: str
    browse_type: str
    enable_screenshot: bool
    source: str
    retry_count: int
    submitted_at: float  # time.time() at submission; FIFO order within a priority
    is_vip: bool
    seq: int  # strictly increasing, so heap keys are unique and tasks are never compared
    canceled: bool = False  # set by cancel_pending_task; its heap entry is skipped lazily
    done_callback: object = None  # optional zero-argument callable run after the task finishes

    @property
    def priority(self) -> int:
        """Heap priority: 0 for VIP users, 1 otherwise (lower runs first)."""
        return 0 if self.is_vip else 1

    def heap_entry(self) -> tuple:
        """Return the ``(priority, submitted_at, seq, task)`` tuple stored in the heap."""
        return (self.priority, self.submitted_at, self.seq, self)


class TaskScheduler:
    """Global task scheduler: tasks wait in a queue instead of each getting its own thread.

    A single dispatcher thread pops runnable tasks (VIP first, then oldest) and
    submits them to a shared thread pool while keeping at most ``max_global``
    tasks running overall and at most ``max_per_user`` per user.
    """

    def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
        """Create the scheduler and start its dispatcher thread.

        Args:
            max_global: maximum number of concurrently running tasks (clamped to >= 1).
            max_per_user: maximum concurrently running tasks per user (clamped to >= 1).
            max_queue_size: maximum number of pending tasks (clamped to >= 1).
            run_task_fn: called with keyword arguments ``user_id``, ``account_id``,
                ``browse_type``, ``enable_screenshot``, ``source`` and ``retry_count``
                to execute a task; when not callable, tasks do nothing.
        """
        self.max_global = max(1, int(max_global))
        self.max_per_user = max(1, int(max_per_user))
        self.max_queue_size = max(1, int(max_queue_size))

        # Scheduling state below is mutated under self._cond; the futures set
        # has its own lock (self._futures_lock).
        self._cond = threading.Condition()
        self._pending = []  # heap of (priority, submitted_at, seq, task); may hold stale entries
        self._pending_by_account = {}  # {account_id: task}: the authoritative pending set
        self._seq = 0
        # Every account id ever submitted; shutdown() clears their active-task entries.
        self._known_account_ids = set()

        self._running_global = 0
        self._running_by_user = {}  # {user_id: running_count}

        self._executor_max_workers = self.max_global
        self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
        self._futures_lock = threading.Lock()
        self._active_futures = set()

        self._running = True
        self._run_task_fn = run_task_fn
        self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
        self._dispatcher_thread.start()

    def _track_future(self, future) -> None:
        """Remember a submitted future until it completes (used by shutdown())."""
        with self._futures_lock:
            self._active_futures.add(future)
        try:
            future.add_done_callback(self._untrack_future)
        except Exception:
            pass

    def _untrack_future(self, future) -> None:
        """Done-callback: forget a completed future."""
        with self._futures_lock:
            self._active_futures.discard(future)

    def _acquire_slot_locked(self, user_id) -> None:
        """Count one more running task for ``user_id`` (caller holds self._cond)."""
        self._running_global += 1
        self._running_by_user[user_id] = self._running_by_user.get(user_id, 0) + 1

    def _release_slot_locked(self, user_id) -> None:
        """Count one fewer running task for ``user_id``, never below zero (caller holds self._cond)."""
        self._running_global = max(0, self._running_global - 1)
        remaining = max(0, self._running_by_user.get(user_id, 0) - 1)
        if remaining:
            self._running_by_user[user_id] = remaining
        else:
            self._running_by_user.pop(user_id, None)

    def shutdown(self, timeout: float = 5.0):
        """Stop the scheduler (cleanup on process exit).

        Stops the dispatcher, waits for it and then for already-submitted tasks
        (each wait bounded by ``timeout`` seconds), shuts the thread pool down
        without waiting, drops all queued tasks and clears the active-task entry
        of every account this scheduler has seen. Every step is best-effort.
        """
        with self._cond:
            self._running = False
            self._cond.notify_all()
        try:
            self._dispatcher_thread.join(timeout=timeout)
        except Exception:
            pass

        # Let submitted tasks finish (up to `timeout` seconds) so leftover
        # active-task entries do not disturb later scheduling/tests.
        try:
            deadline = time.time() + max(0.0, float(timeout or 0))
            while True:
                with self._futures_lock:
                    pending = [f for f in self._active_futures if not f.done()]
                if not pending:
                    break
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                wait(pending, timeout=remaining)
        except Exception:
            pass

        try:
            self._executor.shutdown(wait=False)
        except Exception:
            pass

        # Last resort: clear the active-task entry of every account submitted
        # through this scheduler, so a test/restart is not rejected with
        # "task already running".
        try:
            with self._cond:
                known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
                self._pending.clear()
                self._pending_by_account.clear()
                self._cond.notify_all()
            for account_id in known_ids:
                safe_remove_task(account_id)
        except Exception:
            pass

    def update_limits(self, max_global: int | None = None, max_per_user: int | None = None, max_queue_size: int | None = None):
        """Update concurrency/queue limits at runtime (already running tasks are unaffected).

        ``None`` leaves a limit unchanged; values are clamped to >= 1. Raising
        ``max_global`` above the current pool size replaces the thread pool with
        a larger one and shuts the old pool down without waiting, so tasks
        already running on it continue. Lowering ``max_global`` only throttles
        dispatching; the pool is never shrunk.
        """
        with self._cond:
            if max_per_user is not None:
                self.max_per_user = max(1, int(max_per_user))
            if max_queue_size is not None:
                self.max_queue_size = max(1, int(max_queue_size))
            if max_global is not None:
                new_max_global = max(1, int(max_global))
                self.max_global = new_max_global
                if new_max_global > self._executor_max_workers:
                    old_executor = self._executor
                    old_workers = self._executor_max_workers
                    self._executor_max_workers = new_max_global
                    self._executor = ThreadPoolExecutor(
                        max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
                    )
                    # Release the old pool's idle threads right away to avoid leaking them.
                    try:
                        old_executor.shutdown(wait=False)
                        logger.info(f"线程池已扩容:{old_workers} -> {self._executor_max_workers}")
                    except Exception as e:
                        logger.warning(f"关闭旧线程池失败: {e}")
            self._cond.notify_all()

    def get_queue_state_snapshot(self) -> dict:
        """Return a snapshot of queue/running state (for front-end display/monitoring).

        Keys: ``pending_total``, ``running_total``, ``running_by_user``
        ({user_id: count}) and ``positions`` ({account_id: {"queue_position",
        "queue_ahead", "is_vip"}}) in dispatch-priority order.
        """
        with self._cond:
            pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
            pending_tasks.sort(key=lambda t: (t.priority, t.submitted_at, t.seq))
            positions = {
                t.account_id: {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
                for idx, t in enumerate(pending_tasks)
            }
            return {
                "pending_total": len(pending_tasks),
                "running_total": int(self._running_global),
                "running_by_user": dict(self._running_by_user),
                "positions": positions,
            }

    def submit_task(
        self,
        user_id: int,
        account_id: str,
        browse_type: str,
        enable_screenshot: bool = True,
        source: str = "manual",
        retry_count: int = 0,
        is_vip: bool | None = None,
        done_callback=None,
    ):
        """Queue a task for an account.

        Args:
            is_vip: priority flag; when None it is looked up via
                ``database.is_user_vip`` (lookup errors count as non-VIP).
            done_callback: optional zero-argument callable run after the task finishes.

        Returns:
            ``(ok, message)``. The task is rejected when an argument is missing,
            the scheduler is stopped, the queue is full, or the account is
            already queued or running.
        """
        if not user_id or not account_id:
            return False, "参数错误"

        submitted_at = time.time()
        if is_vip is None:
            try:
                is_vip = bool(database.is_user_vip(user_id))
            except Exception:
                is_vip = False
        else:
            is_vip = bool(is_vip)

        with self._cond:
            if not self._running:
                return False, "调度器未运行"
            if len(self._pending_by_account) >= self.max_queue_size:
                return False, "任务队列已满,请稍后再试"
            if account_id in self._pending_by_account:
                return False, "任务已在队列中"
            if safe_get_task(account_id) is not None:
                return False, "任务已在运行中"

            self._seq += 1
            task = _TaskRequest(
                user_id=user_id,
                account_id=account_id,
                browse_type=browse_type,
                enable_screenshot=bool(enable_screenshot),
                source=source,
                retry_count=int(retry_count or 0),
                submitted_at=submitted_at,
                is_vip=is_vip,
                seq=self._seq,
                done_callback=done_callback,
            )
            self._pending_by_account[account_id] = task
            self._known_account_ids.add(account_id)
            heapq.heappush(self._pending, task.heap_entry())
            # Update the debug view before releasing the lock: the dispatcher
            # removes entries under this lock too, so it can never see the
            # task first and leave a stale id behind.
            _visual_queue_add(account_id, is_vip)
            self._cond.notify_all()

        return True, "已加入队列"

    def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
        """Cancel a task that has not started yet (running tasks are controlled by should_stop).

        Returns True if a pending task for ``account_id`` was found and canceled.
        """
        with self._cond:
            task = self._pending_by_account.pop(account_id, None)
            if not task:
                return False
            # Lazy deletion: the heap entry stays and is skipped by _pop_next_runnable_locked.
            task.canceled = True
            _visual_queue_remove(account_id)
            self._cond.notify_all()

        # Batch tasks: a cancellation still advances the batch's completion
        # count, otherwise the batch cache would stay resident.
        try:
            batch_id = _get_batch_id_from_source(task.source)
            if batch_id:
                acc = safe_get_account(user_id, account_id)
                account_name = (acc.remark or acc.username) if acc else account_id
                _batch_task_record_result(
                    batch_id=batch_id,
                    account_name=account_name,
                    screenshot_path=None,
                    total_items=0,
                    total_attachments=0,
                )
        except Exception as e:
            logger.warning(f"取消任务后更新批次进度失败: {e}")
        return True

    def _dispatch_loop(self):
        """Dispatcher thread: move runnable tasks from the queue onto the thread pool."""
        while True:
            with self._cond:
                if not self._running:
                    return
                if not self._pending or self._running_global >= self.max_global:
                    self._cond.wait(timeout=0.5)
                    continue
                task = self._pop_next_runnable_locked()
                if task is None:
                    self._cond.wait(timeout=0.5)
                    continue
                _visual_queue_remove(task.account_id)
                self._start_task_locked(task)

    def _start_task_locked(self, task: _TaskRequest) -> None:
        """Submit a dequeued task to the pool and register it as running (caller holds self._cond).

        Doing submit + count + registration inside one critical section ensures:
          * update_limits() cannot swap/shut down the executor between reading
            ``self._executor`` and ``submit()`` (which used to drop the task);
          * submit_task() never observes the account as neither pending nor running;
          * _run_task_wrapper(), which unregisters under self._cond, can never run
            its cleanup before the registration here and leave a stale entry.
        """
        try:
            future = self._executor.submit(self._run_task_wrapper, task)
        except Exception:
            logger.exception(f"任务提交到线程池失败,任务已丢弃: account_id={task.account_id}")
            return

        # The wrapper releases this slot under self._cond, which we hold, so
        # counting after submit() cannot race with the release.
        self._acquire_slot_locked(task.user_id)
        self._track_future(future)
        try:
            safe_set_task(task.account_id, future)
        except Exception:
            # The task still runs and releases its own slot; do not release it twice here.
            logger.exception(f"登记运行中任务失败: account_id={task.account_id}")

    def _pop_next_runnable_locked(self):
        """Pop the best task whose user is below max_per_user (caller holds self._cond).

        Tasks of users already at their limit are pushed back, so one saturated
        user (e.g. a VIP) cannot block other users' tasks. Canceled or
        superseded heap entries are discarded. Returns None if nothing is runnable.
        """
        skipped = []
        selected = None
        while self._pending:
            _, _, _, task = heapq.heappop(self._pending)
            if task.canceled or self._pending_by_account.get(task.account_id) is not task:
                continue  # stale heap entry
            if self._running_by_user.get(task.user_id, 0) >= self.max_per_user:
                skipped.append(task)
                continue
            selected = task
            break

        for t in skipped:
            heapq.heappush(self._pending, t.heap_entry())

        if selected is not None:
            self._pending_by_account.pop(selected.account_id, None)
        return selected

    def _run_task_wrapper(self, task: _TaskRequest):
        """Worker-thread body: run one task, then free its slot and active-task entry."""
        try:
            if callable(self._run_task_fn):
                self._run_task_fn(
                    user_id=task.user_id,
                    account_id=task.account_id,
                    browse_type=task.browse_type,
                    enable_screenshot=task.enable_screenshot,
                    source=task.source,
                    retry_count=task.retry_count,
                )
        except Exception:
            # Nothing in this module retrieves the future's result, so log the
            # failure here; re-raise so the future still records the exception.
            logger.exception(f"任务执行异常: account_id={task.account_id}")
            raise
        finally:
            try:
                if callable(task.done_callback):
                    task.done_callback()
            except Exception as e:
                logger.warning(f"任务完成回调执行失败: {e}")
            with self._cond:
                # Unregister under self._cond: the dispatcher registers the
                # future while holding it, so this cannot run first.
                try:
                    safe_remove_task(task.account_id)
                finally:
                    self._release_slot_locked(task.user_id)
                    self._cond.notify_all()