refactor: optimize structure, stability and runtime performance
This commit is contained in:
365
services/task_scheduler.py
Normal file
365
services/task_scheduler.py
Normal file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from dataclasses import dataclass
|
||||
|
||||
import database
|
||||
from app_logger import get_logger
|
||||
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
|
||||
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
|
||||
|
||||
logger = get_logger("app")
|
||||
|
||||
# VIP priority queues (used only for visualization/debugging).
vip_task_queue = []  # account ids of queued tasks submitted for VIP users
normal_task_queue = []  # account ids of queued tasks submitted for regular users
task_queue_lock = threading.Lock()  # guards both debug lists above
|
||||
|
||||
@dataclass
class _TaskRequest:
    """One queued unit of work handled by ``TaskScheduler``.

    Instances are created in ``TaskScheduler.submit_task`` and stored both in
    the scheduler's priority heap and in its per-account pending index.
    """

    # Owner of the task; used for the per-user concurrency limit.
    user_id: int
    # Account the task operates on; at most one pending task per account.
    account_id: str
    # Forwarded unchanged to the run function.
    browse_type: str
    # Forwarded unchanged to the run function.
    enable_screenshot: bool
    # Origin of the task ("manual" by default); also used to derive a batch id.
    source: str
    # Forwarded unchanged to the run function.
    retry_count: int
    # ``time.time()`` at submission; second sort key in the heap.
    submitted_at: float
    # VIP tasks get heap priority 0, all others priority 1.
    is_vip: bool
    # Scheduler-wide increasing sequence number; final tie-breaker in the heap.
    seq: int
    # Set by ``cancel_pending_task``; canceled entries are skipped when popped.
    canceled: bool = False
    # Optional zero-argument callable invoked after the task has run.
    done_callback: object = None
|
||||
|
||||
|
||||
class TaskScheduler:
    """Global task scheduler: tasks wait in a queue instead of each spawning a thread.

    A single daemon dispatcher thread pops pending tasks from a priority heap
    (VIP first, then submission time, then sequence number) and hands them to
    a shared ``ThreadPoolExecutor`` while honouring a global and a per-user
    concurrency limit. Scheduler state is guarded by ``self._cond``.
    """

    def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
        """Create the scheduler and start its dispatcher thread.

        Args:
            max_global: Maximum number of tasks running at once (clamped to >= 1).
            max_per_user: Maximum number of running tasks per user (clamped to >= 1).
            max_queue_size: Maximum number of pending tasks (clamped to >= 1).
            run_task_fn: Callable invoked with the task fields as keyword
                arguments; when it is not callable, tasks complete immediately.
        """
        self.max_global = max(1, int(max_global))
        self.max_per_user = max(1, int(max_per_user))
        self.max_queue_size = max(1, int(max_queue_size))

        self._cond = threading.Condition()
        self._pending = []  # heap of (priority, submitted_at, seq, task)
        self._pending_by_account = {}  # {account_id: task}
        self._seq = 0
        self._known_account_ids = set()  # every account id ever accepted

        self._running_global = 0
        self._running_by_user = {}  # {user_id: running_count}

        self._executor_max_workers = self.max_global
        self._executor = ThreadPoolExecutor(
            max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
        )

        self._futures_lock = threading.Lock()
        self._active_futures = set()

        self._running = True
        self._run_task_fn = run_task_fn
        self._dispatcher_thread = threading.Thread(
            target=self._dispatch_loop, daemon=True, name="TaskDispatcher"
        )
        self._dispatcher_thread.start()

    @staticmethod
    def _priority_of(task) -> int:
        """Return the heap priority of *task*: 0 for VIP, 1 otherwise."""
        return 0 if task.is_vip else 1

    @staticmethod
    def _drop_from_debug_queues(account_id) -> None:
        """Remove one occurrence of *account_id* from the module-level debug queues."""
        with task_queue_lock:
            if account_id in vip_task_queue:
                vip_task_queue.remove(account_id)
            if account_id in normal_task_queue:
                normal_task_queue.remove(account_id)

    def _release_slot_locked(self, user_id) -> None:
        """Give back one global and one per-user slot; caller must hold ``self._cond``.

        Counters never go below zero and a user whose count reaches zero is
        removed from the per-user map. Waiters are notified so the dispatcher
        can use the freed slot.
        """
        self._running_global = max(0, self._running_global - 1)
        remaining = max(0, self._running_by_user.get(user_id, 0) - 1)
        if remaining:
            self._running_by_user[user_id] = remaining
        else:
            self._running_by_user.pop(user_id, None)
        self._cond.notify_all()

    def _track_future(self, future) -> None:
        """Remember *future* until it completes so ``shutdown`` can wait for it."""
        with self._futures_lock:
            self._active_futures.add(future)
        try:
            future.add_done_callback(self._untrack_future)
        except Exception as e:
            logger.warning(f"注册任务完成回调失败: {e}")

    def _untrack_future(self, future) -> None:
        """Forget a completed *future* (used as its done-callback)."""
        with self._futures_lock:
            self._active_futures.discard(future)

    def shutdown(self, timeout: float = 5.0):
        """Stop the scheduler (used for process-exit cleanup).

        Stops the dispatcher, waits up to *timeout* seconds for in-flight
        tasks, shuts the executor down without blocking, drops all pending
        tasks and finally clears the active-task entry of every account this
        scheduler ever accepted. Every step is best effort; failures are
        logged and do not abort the remaining steps.
        """
        with self._cond:
            self._running = False
            self._cond.notify_all()

        try:
            self._dispatcher_thread.join(timeout=timeout)
        except Exception as e:
            logger.warning(f"等待调度线程退出失败: {e}")

        # Wait (at most `timeout` seconds) for submitted tasks to finish so no
        # active-task entry is left behind to confuse later scheduling/tests.
        try:
            # Monotonic clock: the deadline must not move with wall-clock changes.
            deadline = time.monotonic() + max(0.0, float(timeout or 0))
            while True:
                with self._futures_lock:
                    unfinished = [f for f in self._active_futures if not f.done()]
                if not unfinished:
                    break
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                wait(unfinished, timeout=remaining)
        except Exception as e:
            logger.warning(f"等待运行中任务收尾失败: {e}")

        try:
            self._executor.shutdown(wait=False)
        except Exception as e:
            logger.warning(f"关闭线程池失败: {e}")

        # Last resort: clear active-task entries created through this scheduler
        # so a restart/test is not rejected with "task already running".
        try:
            with self._cond:
                known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
                self._pending.clear()
                self._pending_by_account.clear()
                self._cond.notify_all()
            for account_id in known_ids:
                safe_remove_task(account_id)
        except Exception as e:
            logger.warning(f"清理任务状态失败: {e}")

    def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
        """Update concurrency/queue limits at runtime; running tasks are not affected.

        Each argument left as ``None`` keeps its current value; given values
        are clamped to >= 1. Raising ``max_global`` above the current pool
        size replaces the executor with a larger one; lowering it only changes
        the dispatch limit and keeps the existing pool.
        """
        with self._cond:
            if max_per_user is not None:
                self.max_per_user = max(1, int(max_per_user))
            if max_queue_size is not None:
                self.max_queue_size = max(1, int(max_queue_size))

            if max_global is not None:
                new_max_global = max(1, int(max_global))
                self.max_global = new_max_global
                if new_max_global > self._executor_max_workers:
                    previous_workers = self._executor_max_workers
                    old_executor = self._executor
                    # Build the new pool first so the recorded size only
                    # changes once the replacement actually exists.
                    self._executor = ThreadPoolExecutor(
                        max_workers=new_max_global, thread_name_prefix="TaskWorker"
                    )
                    self._executor_max_workers = new_max_global
                    # Retire the old pool right away; shutdown(wait=False)
                    # lets already-submitted work finish.
                    try:
                        old_executor.shutdown(wait=False)
                        logger.info(f"线程池已扩容:{previous_workers} -> {self._executor_max_workers}")
                    except Exception as e:
                        logger.warning(f"关闭旧线程池失败: {e}")

            self._cond.notify_all()

    def get_queue_state_snapshot(self) -> dict:
        """Return a snapshot of queue/running state (for front-end display/monitoring).

        Returns:
            A dict with ``pending_total``, ``running_total``,
            ``running_by_user`` (copy) and ``positions``, which maps each
            pending account id to its 1-based ``queue_position``, the number
            of tasks ahead of it (``queue_ahead``) and its ``is_vip`` flag.
        """
        with self._cond:
            pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
            # Same ordering as the dispatch heap.
            pending_tasks.sort(key=lambda t: (self._priority_of(t), t.submitted_at, t.seq))

            positions = {
                t.account_id: {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
                for idx, t in enumerate(pending_tasks)
            }

            return {
                "pending_total": len(pending_tasks),
                "running_total": int(self._running_global),
                "running_by_user": dict(self._running_by_user),
                "positions": positions,
            }

    def submit_task(
        self,
        user_id: int,
        account_id: str,
        browse_type: str,
        enable_screenshot: bool = True,
        source: str = "manual",
        retry_count: int = 0,
        is_vip: bool = None,
        done_callback=None,
    ):
        """Queue a task for execution.

        Args:
            user_id: Owner of the task; must be truthy.
            account_id: Account to run for; must be truthy and have no task
                pending or running.
            browse_type, enable_screenshot, source, retry_count: Forwarded to
                the run function.
            is_vip: Explicit priority flag; when ``None`` it is looked up via
                ``database.is_user_vip`` (falling back to ``False`` on error).
            done_callback: Optional zero-argument callable invoked after the
                task has run.

        Returns:
            ``(ok, message)`` where ``ok`` tells whether the task was queued.
        """
        if not user_id or not account_id:
            return False, "参数错误"

        submitted_at = time.time()
        if is_vip is None:
            try:
                is_vip = bool(database.is_user_vip(user_id))
            except Exception as e:
                logger.warning(f"查询VIP状态失败,按普通用户处理: user_id={user_id}, error={e}")
                is_vip = False
        else:
            is_vip = bool(is_vip)

        with self._cond:
            if not self._running:
                return False, "调度器未运行"
            if len(self._pending_by_account) >= self.max_queue_size:
                return False, "任务队列已满,请稍后再试"
            if account_id in self._pending_by_account:
                return False, "任务已在队列中"
            if safe_get_task(account_id) is not None:
                return False, "任务已在运行中"

            self._seq += 1
            task = _TaskRequest(
                user_id=user_id,
                account_id=account_id,
                browse_type=browse_type,
                enable_screenshot=bool(enable_screenshot),
                source=source,
                retry_count=int(retry_count or 0),
                submitted_at=submitted_at,
                is_vip=is_vip,
                seq=self._seq,
                done_callback=done_callback,
            )
            self._pending_by_account[account_id] = task
            self._known_account_ids.add(account_id)
            heapq.heappush(self._pending, (self._priority_of(task), task.submitted_at, task.seq, task))
            self._cond.notify_all()

        # Visualization/debugging only: record the queued account.
        # NOTE(review): this runs after the scheduler lock is released, so the
        # dispatcher may already have started the task and a stale debug entry
        # is possible; the debug lists are not used for scheduling decisions.
        with task_queue_lock:
            if is_vip:
                vip_task_queue.append(account_id)
            else:
                normal_task_queue.append(account_id)

        return True, "已加入队列"

    def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
        """Cancel a queued task that has not started yet.

        Tasks that are already running are not affected by this method.

        NOTE(review): ``user_id`` is only used for the account lookup of batch
        tasks; ownership of the pending task is not verified here — confirm
        that callers check it.

        Returns:
            ``True`` if a pending task for *account_id* was canceled,
            ``False`` if none was queued.
        """
        with self._cond:
            canceled_task = self._pending_by_account.pop(account_id, None)
            if not canceled_task:
                return False
            # The heap entry stays in place; it is skipped when popped.
            canceled_task.canceled = True
            self._cond.notify_all()

        self._drop_from_debug_queues(account_id)

        # Batch tasks: a cancel must still advance the batch's completion
        # count, otherwise the batch cache entry would never be released.
        try:
            batch_id = _get_batch_id_from_source(canceled_task.source)
            if batch_id:
                acc = safe_get_account(user_id, account_id)
                if acc:
                    account_name = acc.remark if acc.remark else acc.username
                else:
                    account_name = account_id
                _batch_task_record_result(
                    batch_id=batch_id,
                    account_name=account_name,
                    screenshot_path=None,
                    total_items=0,
                    total_attachments=0,
                )
        except Exception as e:
            logger.warning(f"取消任务时更新批次状态失败: account_id={account_id}, error={e}")

        return True

    def _dispatch_loop(self):
        """Dispatcher thread body: move runnable pending tasks onto the executor.

        Reserving the slot, submitting to the executor and registering the
        active task all happen while ``self._cond`` is held. That keeps a
        concurrent ``update_limits`` from retiring the executor between lookup
        and submit, and guarantees the active-task entry is written before
        ``_run_task_wrapper`` (which removes it under the same lock) can
        delete it, so no stale "running" entry can be left behind.
        """
        while True:
            future = None
            submit_error = None
            register_error = None

            with self._cond:
                if not self._running:
                    return

                if not self._pending or self._running_global >= self.max_global:
                    self._cond.wait(timeout=0.5)
                    continue

                task = self._pop_next_runnable_locked()
                if task is None:
                    self._cond.wait(timeout=0.5)
                    continue

                self._running_global += 1
                self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1

                try:
                    future = self._executor.submit(self._run_task_wrapper, task)
                except Exception as e:
                    # The task never started: give the reserved slot back.
                    submit_error = e
                    self._release_slot_locked(task.user_id)
                else:
                    try:
                        safe_set_task(task.account_id, future)
                    except Exception as e:
                        # The task IS running; its wrapper releases the slot,
                        # so the counters must not be rolled back here.
                        register_error = e

            # Visualization only: the task has left the queue.
            self._drop_from_debug_queues(task.account_id)

            if future is None:
                # Same outcome as before (the task is dropped), but no longer silent.
                logger.warning(f"任务提交到线程池失败: account_id={task.account_id}, error={submit_error}")
                continue

            if register_error is not None:
                logger.warning(f"登记运行中任务失败: account_id={task.account_id}, error={register_error}")

            self._track_future(future)

    def _pop_next_runnable_locked(self):
        """Pop the best pending task that may run now; caller must hold ``self._cond``.

        Canceled or superseded heap entries are discarded. Tasks whose user is
        already at ``max_per_user`` are skipped and pushed back, so a blocked
        VIP task cannot hold up runnable regular tasks.

        Returns:
            The selected task (already removed from the pending index), or
            ``None`` when nothing is runnable.
        """
        if not self._pending:
            return None

        skipped = []
        selected = None

        while self._pending:
            task = heapq.heappop(self._pending)[3]

            if task.canceled:
                continue
            # Stale entry: the index no longer points at this task object.
            if self._pending_by_account.get(task.account_id) is not task:
                continue

            if self._running_by_user.get(task.user_id, 0) >= self.max_per_user:
                skipped.append(task)
                continue

            selected = task
            break

        for t in skipped:
            heapq.heappush(self._pending, (self._priority_of(t), t.submitted_at, t.seq, t))

        if selected is None:
            return None

        self._pending_by_account.pop(selected.account_id, None)
        return selected

    def _run_task_wrapper(self, task: "_TaskRequest"):
        """Run one task on a worker thread and always release its bookkeeping.

        An exception from the run function is logged and re-raised so it stays
        visible on the future. Afterwards, regardless of the outcome, the done
        callback is invoked (its errors are logged, not raised), the
        active-task entry is removed and the concurrency slot is released.
        """
        try:
            if callable(self._run_task_fn):
                self._run_task_fn(
                    user_id=task.user_id,
                    account_id=task.account_id,
                    browse_type=task.browse_type,
                    enable_screenshot=task.enable_screenshot,
                    source=task.source,
                    retry_count=task.retry_count,
                )
        except Exception as e:
            logger.warning(f"任务执行异常: account_id={task.account_id}, error={e}")
            raise
        finally:
            try:
                if callable(task.done_callback):
                    task.done_callback()
            except Exception as e:
                logger.warning(f"任务完成回调执行失败: account_id={task.account_id}, error={e}")

            # Removal happens under the scheduler lock so it is ordered after
            # the dispatcher's registration of this task; the inner `finally`
            # guarantees the slot is released even if the removal fails.
            with self._cond:
                try:
                    safe_remove_task(task.account_id)
                finally:
                    self._release_slot_locked(task.user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user