"""
|
||
任务管理器
|
||
负责管理后台任务、日志队列和 WebSocket 推送
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import threading
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import Dict, Optional, List, Callable, Any
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 全局线程池(支持最多 50 个并发注册任务)
|
||
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
|
||
|
||
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
|
||
_meta_lock = threading.Lock()
|
||
|
||
# 任务日志队列 (task_uuid -> list of logs)
|
||
_log_queues: Dict[str, List[str]] = defaultdict(list)
|
||
_log_locks: Dict[str, threading.Lock] = {}
|
||
|
||
# WebSocket 连接管理 (task_uuid -> list of websockets)
|
||
_ws_connections: Dict[str, List] = defaultdict(list)
|
||
_ws_lock = threading.Lock()
|
||
|
||
# WebSocket 已发送日志索引 (task_uuid -> {websocket: sent_count})
|
||
_ws_sent_index: Dict[str, Dict] = defaultdict(dict)
|
||
|
||
# 任务状态
|
||
_task_status: Dict[str, dict] = {}
|
||
|
||
# 任务取消标志
|
||
_task_cancelled: Dict[str, bool] = {}
|
||
|
||
# 批量任务状态 (batch_id -> dict)
|
||
_batch_status: Dict[str, dict] = {}
|
||
_batch_logs: Dict[str, List[str]] = defaultdict(list)
|
||
_batch_locks: Dict[str, threading.Lock] = {}
|
||
|
||
|
||
def _get_log_lock(task_uuid: str) -> threading.Lock:
    """Return the log lock for *task_uuid*, creating it on first use.

    Uses the double-checked pattern: the cheap unlocked lookup handles the
    common case, and `_meta_lock` serializes first-time creation so two
    threads cannot race to install different Lock objects.
    """
    lock = _log_locks.get(task_uuid)
    if lock is None:
        with _meta_lock:
            lock = _log_locks.setdefault(task_uuid, threading.Lock())
    return lock
||
def _get_batch_lock(batch_id: str) -> threading.Lock:
    """Return the batch-log lock for *batch_id*, creating it on first use.

    Same double-checked scheme as `_get_log_lock`: fast unlocked lookup,
    with `_meta_lock` guarding the one-time creation of the entry.
    """
    lock = _batch_locks.get(batch_id)
    if lock is None:
        with _meta_lock:
            lock = _batch_locks.setdefault(batch_id, threading.Lock())
    return lock
||
class TaskManager:
    """Coordinate background tasks, their log queues and WebSocket push.

    All mutable state lives in module-level dicts so that worker threads
    (running in the shared ThreadPoolExecutor) and the asyncio event loop
    can both reach it. `_ws_lock` guards the connection registry, and the
    per-task/per-batch log locks guard the corresponding log queues.
    """

    def __init__(self):
        # Shared module-level thread pool for registration work.
        self.executor = _executor
        # Event loop of the web app; populated via set_loop() at startup.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Store the event loop (called at FastAPI startup)."""
        self._loop = loop

    def get_loop(self) -> Optional[asyncio.AbstractEventLoop]:
        """Return the stored event loop, or None if not set yet."""
        return self._loop

    def is_cancelled(self, task_uuid: str) -> bool:
        """Return True if the task has been flagged for cancellation."""
        return _task_cancelled.get(task_uuid, False)

    def cancel_task(self, task_uuid: str):
        """Flag a task as cancelled; workers poll is_cancelled() to stop."""
        _task_cancelled[task_uuid] = True
        logger.info(f"任务 {task_uuid} 已标记为取消")

    def add_log(self, task_uuid: str, log_message: str):
        """Append a log line and push it to WebSocket clients (thread-safe).

        The broadcast is scheduled BEFORE the line is appended to the queue
        so that get_unsent_logs() cannot also return this line (which would
        duplicate it for a freshly registered client).
        """
        # Broadcast to WebSocket first to guarantee real-time delivery.
        if self._loop and self._loop.is_running():
            try:
                # Safe to call from worker threads: hands the coroutine to
                # the running event loop.
                asyncio.run_coroutine_threadsafe(
                    self._broadcast_log(task_uuid, log_message),
                    self._loop
                )
            except Exception as e:
                logger.warning(f"推送日志到 WebSocket 失败: {e}")

        # Only after scheduling the broadcast is the line queued.
        with _get_log_lock(task_uuid):
            _log_queues[task_uuid].append(log_message)

    async def _broadcast_log(self, task_uuid: str, log_message: str):
        """Broadcast one log line to every WebSocket registered for the task."""
        with _ws_lock:
            # Copy so the send loop runs without holding the lock.
            connections = _ws_connections.get(task_uuid, []).copy()
            # NOTE: sent_index is NOT updated here, because the line is
            # appended to the queue by add_log; sent_index should only be
            # advanced in get_unsent_logs or after an actual send below.
            # This avoids a race between broadcast and history replay.

        for ws in connections:
            try:
                await ws.send_json({
                    "type": "log",
                    "task_uuid": task_uuid,
                    "message": log_message,
                    # NOTE(review): naive UTC timestamp; datetime.utcnow()
                    # is deprecated since Python 3.12 — confirm clients
                    # tolerate a tz-aware ISO string before changing.
                    "timestamp": datetime.utcnow().isoformat()
                })
                # Advance sent_index only after a successful send, and only
                # if this socket already has a recorded baseline.
                with _ws_lock:
                    ws_id = id(ws)
                    if task_uuid in _ws_sent_index and ws_id in _ws_sent_index[task_uuid]:
                        _ws_sent_index[task_uuid][ws_id] += 1
            except Exception as e:
                logger.warning(f"WebSocket 发送失败: {e}")

    async def broadcast_status(self, task_uuid: str, status: str, **kwargs):
        """Broadcast a task status update to all registered WebSockets."""
        with _ws_lock:
            connections = _ws_connections.get(task_uuid, []).copy()

        message = {
            "type": "status",
            "task_uuid": task_uuid,
            "status": status,
            "timestamp": datetime.utcnow().isoformat(),
            **kwargs
        }

        for ws in connections:
            try:
                await ws.send_json(message)
            except Exception as e:
                logger.warning(f"WebSocket 发送状态失败: {e}")

    def register_websocket(self, task_uuid: str, websocket: Any):
        """Register a WebSocket connection for a task (idempotent per socket)."""
        with _ws_lock:
            if task_uuid not in _ws_connections:
                _ws_connections[task_uuid] = []
            # Skip duplicate registration of the same connection object.
            if websocket not in _ws_connections[task_uuid]:
                _ws_connections[task_uuid].append(websocket)
                # Record how many log lines already exist so history replay
                # (get_unsent_logs) does not resend them.
                with _get_log_lock(task_uuid):
                    _ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
                logger.info(f"WebSocket 连接已注册: {task_uuid}")
            else:
                logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")

    def get_unsent_logs(self, task_uuid: str, websocket: Any) -> List[str]:
        """Return the log lines not yet delivered to *websocket*."""
        with _ws_lock:
            ws_id = id(websocket)
            sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)

        with _get_log_lock(task_uuid):
            all_logs = _log_queues.get(task_uuid, [])
            unsent_logs = all_logs[sent_count:]
            # Mark everything as delivered.
            # NOTE(review): this write to _ws_sent_index happens under the
            # per-task log lock rather than _ws_lock — confirm intended.
            _ws_sent_index[task_uuid][ws_id] = len(all_logs)
            return unsent_logs

    def unregister_websocket(self, task_uuid: str, websocket: Any):
        """Unregister a WebSocket connection and drop its sent-index entry."""
        with _ws_lock:
            if task_uuid in _ws_connections:
                try:
                    _ws_connections[task_uuid].remove(websocket)
                except ValueError:
                    # Already removed (or never registered) — nothing to do.
                    pass
            # Clean up the delivered-count entry for this socket.
            if task_uuid in _ws_sent_index:
                _ws_sent_index[task_uuid].pop(id(websocket), None)
            logger.info(f"WebSocket 连接已注销: {task_uuid}")

    def get_logs(self, task_uuid: str) -> List[str]:
        """Return a snapshot copy of all log lines for the task."""
        with _get_log_lock(task_uuid):
            return _log_queues.get(task_uuid, []).copy()

    def update_status(self, task_uuid: str, status: str, **kwargs):
        """Update the task status dict (extra fields merged via kwargs)."""
        if task_uuid not in _task_status:
            _task_status[task_uuid] = {}

        _task_status[task_uuid]["status"] = status
        _task_status[task_uuid].update(kwargs)

    def get_status(self, task_uuid: str) -> Optional[dict]:
        """Return the task status dict, or None if unknown."""
        return _task_status.get(task_uuid)

    def cleanup_task(self, task_uuid: str):
        """Drop the task's cancellation flag.

        The log queue is intentionally kept for a while so it can still
        be queried after the task finishes.
        """
        if task_uuid in _task_cancelled:
            del _task_cancelled[task_uuid]

    # ============== Batch task management ==============

    def init_batch(self, batch_id: str, total: int):
        """Initialize the status record for a batch of *total* tasks."""
        _batch_status[batch_id] = {
            "status": "running",
            "total": total,
            "completed": 0,
            "success": 0,
            "failed": 0,
            "skipped": 0,
            "current_index": 0,
            "finished": False
        }
        logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")

    def add_batch_log(self, batch_id: str, log_message: str):
        """Append a batch log line and push it to WebSocket clients.

        Mirrors add_log(): broadcast first, queue afterwards, so history
        replay cannot duplicate this line.
        """
        if self._loop and self._loop.is_running():
            try:
                asyncio.run_coroutine_threadsafe(
                    self._broadcast_batch_log(batch_id, log_message),
                    self._loop
                )
            except Exception as e:
                logger.warning(f"推送批量日志到 WebSocket 失败: {e}")

        # Queue only after the broadcast has been scheduled.
        with _get_batch_lock(batch_id):
            _batch_logs[batch_id].append(log_message)

    async def _broadcast_batch_log(self, batch_id: str, log_message: str):
        """Broadcast one batch log line to the batch's WebSocket channel."""
        # Batch channels are namespaced in the shared registry.
        key = f"batch_{batch_id}"
        with _ws_lock:
            connections = _ws_connections.get(key, []).copy()
            # NOTE: sent_index is not updated here (same race-avoidance
            # rationale as _broadcast_log).

        for ws in connections:
            try:
                await ws.send_json({
                    "type": "log",
                    "batch_id": batch_id,
                    "message": log_message,
                    "timestamp": datetime.utcnow().isoformat()
                })
                # Advance sent_index only after a successful send.
                with _ws_lock:
                    ws_id = id(ws)
                    if key in _ws_sent_index and ws_id in _ws_sent_index[key]:
                        _ws_sent_index[key][ws_id] += 1
            except Exception as e:
                logger.warning(f"WebSocket 发送批量日志失败: {e}")

    def update_batch_status(self, batch_id: str, **kwargs):
        """Merge kwargs into the batch status and broadcast the update."""
        if batch_id not in _batch_status:
            logger.warning(f"批量任务 {batch_id} 不存在")
            return

        _batch_status[batch_id].update(kwargs)

        # Broadcast the status update asynchronously on the event loop.
        if self._loop and self._loop.is_running():
            try:
                asyncio.run_coroutine_threadsafe(
                    self._broadcast_batch_status(batch_id),
                    self._loop
                )
            except Exception as e:
                logger.warning(f"广播批量状态失败: {e}")

    async def _broadcast_batch_status(self, batch_id: str):
        """Broadcast the current batch status to its WebSocket channel."""
        with _ws_lock:
            connections = _ws_connections.get(f"batch_{batch_id}", []).copy()

        status = _batch_status.get(batch_id, {})

        for ws in connections:
            try:
                await ws.send_json({
                    "type": "status",
                    "batch_id": batch_id,
                    "timestamp": datetime.utcnow().isoformat(),
                    **status
                })
            except Exception as e:
                logger.warning(f"WebSocket 发送批量状态失败: {e}")

    def get_batch_status(self, batch_id: str) -> Optional[dict]:
        """Return the batch status dict, or None if unknown."""
        return _batch_status.get(batch_id)

    def get_batch_logs(self, batch_id: str) -> List[str]:
        """Return a snapshot copy of all log lines for the batch."""
        with _get_batch_lock(batch_id):
            return _batch_logs.get(batch_id, []).copy()

    def is_batch_cancelled(self, batch_id: str) -> bool:
        """Return True if the batch has been flagged for cancellation."""
        status = _batch_status.get(batch_id, {})
        return status.get("cancelled", False)

    def cancel_batch(self, batch_id: str):
        """Flag a batch as cancelled; workers poll is_batch_cancelled()."""
        if batch_id in _batch_status:
            _batch_status[batch_id]["cancelled"] = True
            _batch_status[batch_id]["status"] = "cancelling"
            logger.info(f"批量任务 {batch_id} 已标记为取消")

    def register_batch_websocket(self, batch_id: str, websocket: Any):
        """Register a WebSocket for the batch channel (idempotent per socket)."""
        key = f"batch_{batch_id}"
        with _ws_lock:
            if key not in _ws_connections:
                _ws_connections[key] = []
            # Skip duplicate registration of the same connection object.
            if websocket not in _ws_connections[key]:
                _ws_connections[key].append(websocket)
                # Record current log length so history replay skips old lines.
                with _get_batch_lock(batch_id):
                    _ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
                logger.info(f"批量任务 WebSocket 连接已注册: {batch_id}")
            else:
                logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")

    def get_unsent_batch_logs(self, batch_id: str, websocket: Any) -> List[str]:
        """Return the batch log lines not yet delivered to *websocket*."""
        key = f"batch_{batch_id}"
        with _ws_lock:
            ws_id = id(websocket)
            sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)

        with _get_batch_lock(batch_id):
            all_logs = _batch_logs.get(batch_id, [])
            unsent_logs = all_logs[sent_count:]
            # Mark everything as delivered (see NOTE in get_unsent_logs
            # about which lock guards this write).
            _ws_sent_index[key][ws_id] = len(all_logs)
            return unsent_logs

    def unregister_batch_websocket(self, batch_id: str, websocket: Any):
        """Unregister a WebSocket from the batch channel."""
        key = f"batch_{batch_id}"
        with _ws_lock:
            if key in _ws_connections:
                try:
                    _ws_connections[key].remove(websocket)
                except ValueError:
                    pass
            # Clean up the delivered-count entry for this socket.
            if key in _ws_sent_index:
                _ws_sent_index[key].pop(id(websocket), None)
            logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")

    def create_log_callback(self, task_uuid: str, prefix: str = "", batch_id: str = "") -> Callable[[str], None]:
        """Build a log callback that optionally prefixes messages and
        mirrors them onto the batch channel."""
        def callback(msg: str):
            full_msg = f"{prefix} {msg}" if prefix else msg
            self.add_log(task_uuid, full_msg)
            # If part of a batch, mirror onto the batch channel so the
            # frontend sees detailed steps in the combined log view.
            if batch_id:
                self.add_batch_log(batch_id, full_msg)
        return callback

    def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
        """Build a zero-arg callback that reports the task's cancel flag."""
        def callback() -> bool:
            return self.is_cancelled(task_uuid)
        return callback
# Module-level singleton shared by the rest of the application.
task_manager = TaskManager()