refactor: optimize structure, stability and runtime performance

This commit is contained in:
2026-02-07 00:35:11 +08:00
parent fae21329d7
commit bf29ac1924
44 changed files with 6894 additions and 4792 deletions

View File

@@ -243,6 +243,35 @@ class KDocsUploader:
except queue.Empty:
return {"success": False, "error": "操作超时"}
def _put_task_response(self, task: Dict[str, Any], result: Dict[str, Any]) -> None:
response_queue = task.get("response")
if not response_queue:
return
try:
response_queue.put(result)
except Exception:
return
def _process_task(self, task: Dict[str, Any]) -> bool:
action = task.get("action")
payload = task.get("payload") or {}
if action == "shutdown":
return False
if action == "upload":
self._handle_upload(payload)
return True
if action == "qr":
self._put_task_response(task, self._handle_qr(payload))
return True
if action == "clear_login":
self._put_task_response(task, self._handle_clear_login())
return True
if action == "status":
self._put_task_response(task, self._handle_status_check())
return True
return True
def _run(self) -> None:
thread_id = self._thread_id
logger.info(f"[KDocs] 上传线程启动 (ID={thread_id})")
@@ -261,34 +290,17 @@ class KDocsUploader:
# 更新最后活动时间
self._last_activity = time.time()
action = task.get("action")
if action == "shutdown":
break
try:
if action == "upload":
self._handle_upload(task.get("payload") or {})
elif action == "qr":
result = self._handle_qr(task.get("payload") or {})
task.get("response").put(result)
elif action == "clear_login":
result = self._handle_clear_login()
task.get("response").put(result)
elif action == "status":
result = self._handle_status_check()
task.get("response").put(result)
should_continue = self._process_task(task)
if not should_continue:
break
# 任务处理完成后更新活动时间
self._last_activity = time.time()
except Exception as e:
logger.warning(f"[KDocs] 处理任务失败: {e}")
# 如果有响应队列,返回错误
if "response" in task and task.get("response"):
try:
task["response"].put({"success": False, "error": str(e)})
except Exception:
pass
self._put_task_response(task, {"success": False, "error": str(e)})
except Exception as e:
logger.warning(f"[KDocs] 线程主循环异常: {e}")
@@ -830,18 +842,180 @@ class KDocsUploader:
except Exception as e:
logger.warning(f"[KDocs] 保存登录态失败: {e}")
def _resolve_doc_url(self, cfg: Dict[str, Any]) -> str:
return (cfg.get("kdocs_doc_url") or "").strip()
def _ensure_doc_access(
self,
doc_url: str,
*,
fast: bool = False,
use_storage_state: bool = True,
) -> Optional[str]:
if not self._ensure_playwright(use_storage_state=use_storage_state):
return self._last_error or "浏览器不可用"
if not self._open_document(doc_url, fast=fast):
return self._last_error or "打开文档失败"
return None
def _trigger_fast_login_dialog(self, timeout_ms: int) -> None:
self._ensure_login_dialog(
timeout_ms=timeout_ms,
frame_timeout_ms=timeout_ms,
quick=True,
)
def _capture_qr_with_retry(self, fast_login_timeout: int) -> Tuple[Optional[bytes], Optional[bytes]]:
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._trigger_fast_login_dialog(fast_login_timeout)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
return qr_image, invalid_qr
def _save_qr_debug_artifacts(self, invalid_qr: Optional[bytes]) -> None:
try:
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
def _log_upload_failure(self, message: str, user_id: Any, account_id: Any) -> None:
try:
log_to_client(f"表格上传失败: {message}", user_id, account_id)
except Exception:
pass
def _mark_upload_tracking(self, user_id: Any, account_id: Any) -> Tuple[Any, Optional[str], bool]:
account = None
prev_status = None
status_tracked = False
try:
account = safe_get_account(user_id, account_id)
if account and self._should_mark_upload(account):
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
return account, prev_status, status_tracked
def _parse_upload_payload(self, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unit = (payload.get("unit") or "").strip()
name = (payload.get("name") or "").strip()
image_path = payload.get("image_path")
if not unit or not name:
return None
if not image_path or not os.path.exists(image_path):
return None
return {
"unit": unit,
"name": name,
"image_path": image_path,
"user_id": payload.get("user_id"),
"account_id": payload.get("account_id"),
}
def _resolve_upload_sheet_config(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
return {
"sheet_name": (cfg.get("kdocs_sheet_name") or "").strip(),
"sheet_index": int(cfg.get("kdocs_sheet_index") or 0),
"unit_col": (cfg.get("kdocs_unit_column") or "A").strip().upper(),
"image_col": (cfg.get("kdocs_image_column") or "D").strip().upper(),
"row_start": int(cfg.get("kdocs_row_start") or 0),
"row_end": int(cfg.get("kdocs_row_end") or 0),
}
def _try_upload_to_sheet(self, cfg: Dict[str, Any], unit: str, name: str, image_path: str) -> Tuple[bool, str]:
sheet_cfg = self._resolve_upload_sheet_config(cfg)
success = False
error_msg = ""
for _ in range(2):
try:
if sheet_cfg["sheet_name"] or sheet_cfg["sheet_index"]:
self._select_sheet(sheet_cfg["sheet_name"], sheet_cfg["sheet_index"])
row_num = self._find_person_with_unit(
unit,
name,
sheet_cfg["unit_col"],
row_start=sheet_cfg["row_start"],
row_end=sheet_cfg["row_end"],
)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, sheet_cfg["image_col"])
if success:
break
except Exception as e:
error_msg = str(e)
return success, error_msg
def _handle_upload_login_invalid(
self,
*,
unit: str,
name: str,
image_path: str,
user_id: Any,
account_id: Any,
) -> None:
error_msg = "登录已失效,请管理员重新扫码登录"
self._login_required = True
self._last_login_ok = False
self._notify_admin(unit, name, image_path, error_msg)
self._log_upload_failure(error_msg, user_id, account_id)
def _handle_qr(self, payload: Dict[str, Any]) -> Dict[str, Any]:
cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": False, "error": "未配置金山文档链接"}
force = bool(payload.get("force"))
if force:
self._handle_clear_login()
if not self._ensure_playwright(use_storage_state=not force):
return {"success": False, "error": self._last_error or "浏览器不可用"}
if not self._open_document(doc_url, fast=True):
return {"success": False, "error": self._last_error or "打开文档失败"}
doc_error = self._ensure_doc_access(doc_url, fast=True, use_storage_state=not force)
if doc_error:
return {"success": False, "error": doc_error}
if not force and self._has_saved_login_state() and self._is_logged_in():
self._login_required = False
@@ -850,54 +1024,12 @@ class KDocsUploader:
return {"success": True, "logged_in": True, "qr_image": ""}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
qr_image = None
invalid_qr = None
for attempt in range(10):
if attempt in (3, 7):
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
candidate = self._capture_qr_image()
if candidate and self._is_valid_qr_image(candidate):
qr_image = candidate
break
if candidate:
invalid_qr = candidate
time.sleep(0.8) # 优化: 1 -> 0.8
self._trigger_fast_login_dialog(fast_login_timeout)
qr_image, invalid_qr = self._capture_qr_with_retry(fast_login_timeout)
if not qr_image:
self._last_error = "二维码识别异常" if invalid_qr else "二维码获取失败"
try:
pages = self._iter_pages()
page_urls = [getattr(p, "url", "") for p in pages]
logger.warning(f"[KDocs] 二维码未捕获,页面: {page_urls}")
ts = int(time.time())
saved = []
for idx, page in enumerate(pages[:3]):
try:
path = f"data/kdocs_debug_{ts}_{idx}.png"
page.screenshot(path=path, full_page=True)
saved.append(path)
except Exception:
continue
if saved:
logger.warning(f"[KDocs] 已保存调试截图: {saved}")
if invalid_qr:
try:
path = f"data/kdocs_invalid_qr_{ts}.png"
with open(path, "wb") as handle:
handle.write(invalid_qr)
logger.warning(f"[KDocs] 已保存无效二维码截图: {path}")
except Exception:
pass
except Exception:
pass
self._save_qr_debug_artifacts(invalid_qr)
return {"success": False, "error": self._last_error}
try:
@@ -933,24 +1065,22 @@ class KDocsUploader:
def _handle_status_check(self) -> Dict[str, Any]:
cfg = self._load_system_config()
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return {"success": True, "logged_in": False, "error": "未配置文档链接"}
if not self._ensure_playwright():
return {"success": False, "logged_in": False, "error": self._last_error or "浏览器不可用"}
if not self._open_document(doc_url, fast=True):
return {"success": False, "logged_in": False, "error": self._last_error or "打开文档失败"}
doc_error = self._ensure_doc_access(doc_url, fast=True)
if doc_error:
return {"success": False, "logged_in": False, "error": doc_error}
fast_login_timeout = int(os.environ.get("KDOCS_FAST_LOGIN_TIMEOUT_MS", "300"))
self._ensure_login_dialog(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
self._trigger_fast_login_dialog(fast_login_timeout)
self._try_confirm_login(
timeout_ms=fast_login_timeout,
frame_timeout_ms=fast_login_timeout,
quick=True,
)
logged_in = self._is_logged_in()
self._last_login_ok = logged_in
self._login_required = not logged_in
@@ -962,79 +1092,43 @@ class KDocsUploader:
cfg = self._load_system_config()
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
doc_url = self._resolve_doc_url(cfg)
if not doc_url:
return
unit = (payload.get("unit") or "").strip()
name = (payload.get("name") or "").strip()
image_path = payload.get("image_path")
user_id = payload.get("user_id")
account_id = payload.get("account_id")
if not unit or not name:
return
if not image_path or not os.path.exists(image_path):
upload_data = self._parse_upload_payload(payload)
if not upload_data:
return
account = None
prev_status = None
status_tracked = False
unit = upload_data["unit"]
name = upload_data["name"]
image_path = upload_data["image_path"]
user_id = upload_data["user_id"]
account_id = upload_data["account_id"]
account, prev_status, status_tracked = self._mark_upload_tracking(user_id, account_id)
try:
try:
account = safe_get_account(user_id, account_id)
if account and self._should_mark_upload(account):
prev_status = getattr(account, "status", None)
account.status = "上传截图"
self._emit_account_update(user_id, account)
status_tracked = True
except Exception:
prev_status = None
if not self._ensure_playwright():
self._notify_admin(unit, name, image_path, self._last_error or "浏览器不可用")
return
if not self._open_document(doc_url):
self._notify_admin(unit, name, image_path, self._last_error or "打开文档失败")
doc_error = self._ensure_doc_access(doc_url)
if doc_error:
self._notify_admin(unit, name, image_path, doc_error)
return
if not self._is_logged_in():
self._login_required = True
self._last_login_ok = False
self._notify_admin(unit, name, image_path, "登录已失效,请管理员重新扫码登录")
try:
log_to_client("表格上传失败: 登录已失效,请管理员重新扫码登录", user_id, account_id)
except Exception:
pass
self._handle_upload_login_invalid(
unit=unit,
name=name,
image_path=image_path,
user_id=user_id,
account_id=account_id,
)
return
self._login_required = False
self._last_login_ok = True
sheet_name = (cfg.get("kdocs_sheet_name") or "").strip()
sheet_index = int(cfg.get("kdocs_sheet_index") or 0)
unit_col = (cfg.get("kdocs_unit_column") or "A").strip().upper()
image_col = (cfg.get("kdocs_image_column") or "D").strip().upper()
row_start = int(cfg.get("kdocs_row_start") or 0)
row_end = int(cfg.get("kdocs_row_end") or 0)
success = False
error_msg = ""
for attempt in range(2):
try:
if sheet_name or sheet_index:
self._select_sheet(sheet_name, sheet_index)
row_num = self._find_person_with_unit(unit, name, unit_col, row_start=row_start, row_end=row_end)
if row_num < 0:
error_msg = f"未找到人员: {unit}-{name}"
break
success = self._upload_image_to_cell(row_num, image_path, image_col)
if success:
break
except Exception as e:
error_msg = str(e)
success, error_msg = self._try_upload_to_sheet(cfg, unit, name, image_path)
if success:
self._last_success_at = time.time()
self._last_error = None
@@ -1048,10 +1142,7 @@ class KDocsUploader:
error_msg = "上传失败"
self._last_error = error_msg
self._notify_admin(unit, name, image_path, error_msg)
try:
log_to_client(f"表格上传失败: {error_msg}", user_id, account_id)
except Exception:
pass
self._log_upload_failure(error_msg, user_id, account_id)
finally:
if status_tracked:
self._restore_account_status(user_id, account, prev_status)

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import threading
import time
from datetime import datetime
@@ -10,6 +11,8 @@ from app_config import get_config
from app_logger import get_logger
from services.state import (
cleanup_expired_ip_rate_limits,
cleanup_expired_ip_request_rates,
cleanup_expired_login_security_state,
safe_cleanup_expired_batches,
safe_cleanup_expired_captcha,
safe_cleanup_expired_pending_random,
@@ -31,6 +34,69 @@ PENDING_RANDOM_EXPIRE_SECONDS = int(getattr(config, "PENDING_RANDOM_EXPIRE_SECON
_kdocs_offline_notified: bool = False
def _to_int(value, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _collect_active_user_ids() -> set[int]:
active_user_ids: set[int] = set()
for _, info in safe_iter_task_status_items():
user_id = info.get("user_id") if isinstance(info, dict) else None
if user_id is None:
continue
try:
active_user_ids.add(int(user_id))
except Exception:
continue
return active_user_ids
def _find_expired_user_cache_ids(current_time: float, active_user_ids: set[int]) -> list[int]:
expired_users = []
for user_id, last_access in (safe_get_user_accounts_last_access_items() or []):
try:
user_id_int = int(user_id)
last_access_ts = float(last_access)
except Exception:
continue
if (current_time - last_access_ts) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if user_id_int in active_user_ids:
continue
if safe_has_user(user_id_int):
expired_users.append(user_id_int)
return expired_users
def _find_completed_task_status_ids(current_time: float) -> list[str]:
completed_task_ids = []
for account_id, status_data in safe_iter_task_status_items():
status = status_data.get("status") if isinstance(status_data, dict) else None
if status not in ["已完成", "失败", "已停止"]:
continue
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_task_ids.append(account_id)
return completed_task_ids
def _reap_zombie_processes() -> None:
while True:
try:
pid, _ = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
break
def cleanup_expired_data() -> None:
"""定期清理过期数据,防止内存泄漏(逻辑保持不变)。"""
current_time = time.time()
@@ -43,48 +109,36 @@ def cleanup_expired_data() -> None:
if deleted_ips:
logger.debug(f"已清理 {deleted_ips} 个过期IP限流记录")
expired_users = []
last_access_items = safe_get_user_accounts_last_access_items()
if last_access_items:
task_items = safe_iter_task_status_items()
active_user_ids = {int(info.get("user_id")) for _, info in task_items if info.get("user_id")}
for user_id, last_access in last_access_items:
if (current_time - float(last_access)) <= USER_ACCOUNTS_EXPIRE_SECONDS:
continue
if int(user_id) in active_user_ids:
continue
if safe_has_user(user_id):
expired_users.append(int(user_id))
deleted_ip_requests = cleanup_expired_ip_request_rates(current_time)
if deleted_ip_requests:
logger.debug(f"已清理 {deleted_ip_requests} 个过期IP请求频率记录")
login_cleanup_stats = cleanup_expired_login_security_state(current_time)
login_cleanup_total = sum(int(v or 0) for v in login_cleanup_stats.values())
if login_cleanup_total:
logger.debug(
"已清理登录风控缓存: "
f"失败计数={login_cleanup_stats.get('failures', 0)}, "
f"限流桶={login_cleanup_stats.get('rate_limits', 0)}, "
f"扫描状态={login_cleanup_stats.get('scan_states', 0)}, "
f"短时锁={login_cleanup_stats.get('ip_user_locks', 0)}, "
f"告警状态={login_cleanup_stats.get('alerts', 0)}"
)
active_user_ids = _collect_active_user_ids()
expired_users = _find_expired_user_cache_ids(current_time, active_user_ids)
for user_id in expired_users:
safe_remove_user_accounts(user_id)
if expired_users:
logger.debug(f"已清理 {len(expired_users)} 个过期用户账号缓存")
completed_tasks = []
for account_id, status_data in safe_iter_task_status_items():
if status_data.get("status") in ["已完成", "失败", "已停止"]:
start_time = float(status_data.get("start_time", 0) or 0)
if (current_time - start_time) > 600: # 10分钟
completed_tasks.append(account_id)
for account_id in completed_tasks:
completed_task_ids = _find_completed_task_status_ids(current_time)
for account_id in completed_task_ids:
safe_remove_task_status(account_id)
if completed_tasks:
logger.debug(f"已清理 {len(completed_tasks)} 个已完成任务状态")
if completed_task_ids:
logger.debug(f"已清理 {len(completed_task_ids)} 个已完成任务状态")
try:
import os
while True:
try:
pid, status = os.waitpid(-1, os.WNOHANG)
if pid == 0:
break
logger.debug(f"已回收僵尸进程: PID={pid}")
except ChildProcessError:
break
except Exception:
pass
_reap_zombie_processes()
deleted_batches = safe_cleanup_expired_batches(BATCH_TASK_EXPIRE_SECONDS, current_time)
if deleted_batches:
@@ -95,52 +149,39 @@ def cleanup_expired_data() -> None:
logger.debug(f"已清理 {deleted_random} 个过期随机延迟任务")
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
def _load_kdocs_monitor_config():
import database
cfg = database.get_system_config()
if not cfg:
return None
kdocs_enabled = _to_int(cfg.get("kdocs_enabled"), 0)
if not kdocs_enabled:
return None
admin_notify_enabled = _to_int(cfg.get("kdocs_admin_notify_enabled"), 0)
admin_notify_email = str(cfg.get("kdocs_admin_notify_email") or "").strip()
if (not admin_notify_enabled) or (not admin_notify_email):
return None
return admin_notify_email
def _is_kdocs_offline(status: dict) -> tuple[bool, bool, bool | None]:
login_required = bool(status.get("login_required", False))
last_login_ok = status.get("last_login_ok")
is_offline = login_required or (last_login_ok is False)
return is_offline, login_required, last_login_ok
def _send_kdocs_offline_alert(admin_notify_email: str, *, login_required: bool, last_login_ok) -> bool:
try:
import database
from services.kdocs_uploader import get_kdocs_uploader
import email_service
# 获取系统配置
cfg = database.get_system_config()
if not cfg:
return
# 检查是否启用了金山文档功能
kdocs_enabled = int(cfg.get("kdocs_enabled") or 0)
if not kdocs_enabled:
return
# 检查是否启用了管理员通知
admin_notify_enabled = int(cfg.get("kdocs_admin_notify_enabled") or 0)
admin_notify_email = (cfg.get("kdocs_admin_notify_email") or "").strip()
if not admin_notify_enabled or not admin_notify_email:
return
# 获取金山文档状态
kdocs = get_kdocs_uploader()
status = kdocs.get_status()
login_required = status.get("login_required", False)
last_login_ok = status.get("last_login_ok")
# 如果需要登录或最后登录状态不是成功
is_offline = login_required or (last_login_ok is False)
if is_offline:
# 已经通知过了,不再重复通知
if _kdocs_offline_notified:
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
# 发送邮件通知
try:
import email_service
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
now_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
subject = "【金山文档离线告警】需要重新登录"
body = f"""
您好,
系统检测到金山文档上传功能已离线,需要重新扫码登录。
@@ -155,58 +196,92 @@ def check_kdocs_online_status() -> None:
---
此邮件由系统自动发送,请勿直接回复。
"""
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
_kdocs_offline_notified = True # 标记为已通知
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
else:
# 恢复在线,重置通知状态
email_service.send_email_async(
to_email=admin_notify_email,
subject=subject,
body=body,
email_type="kdocs_offline_alert",
)
logger.warning(f"[KDocs监控] 金山文档离线,已发送通知邮件到 {admin_notify_email}")
return True
except Exception as e:
logger.error(f"[KDocs监控] 发送离线通知邮件失败: {e}")
return False
def check_kdocs_online_status() -> None:
"""检测金山文档登录状态,如果离线则发送邮件通知管理员(每次掉线只通知一次)"""
global _kdocs_offline_notified
try:
admin_notify_email = _load_kdocs_monitor_config()
if not admin_notify_email:
return
from services.kdocs_uploader import get_kdocs_uploader
kdocs = get_kdocs_uploader()
status = kdocs.get_status() or {}
is_offline, login_required, last_login_ok = _is_kdocs_offline(status)
if is_offline:
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
logger.debug("[KDocs监控] 金山文档离线,已通知过,跳过重复通知")
return
if _send_kdocs_offline_alert(
admin_notify_email,
login_required=login_required,
last_login_ok=last_login_ok,
):
_kdocs_offline_notified = True
return
if _kdocs_offline_notified:
logger.info("[KDocs监控] 金山文档已恢复在线,重置通知状态")
_kdocs_offline_notified = False
logger.debug("[KDocs监控] 金山文档状态正常")
except Exception as e:
logger.error(f"[KDocs监控] 检测失败: {e}")
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
def cleanup_loop():
def _start_daemon_loop(name: str, *, startup_delay: float, interval_seconds: float, job, error_tag: str):
def loop():
if startup_delay > 0:
time.sleep(startup_delay)
while True:
try:
time.sleep(300) # 每5分钟执行一次清理
cleanup_expired_data()
job()
time.sleep(interval_seconds)
except Exception as e:
logger.error(f"清理任务执行失败: {e}")
logger.error(f"{error_tag}: {e}")
time.sleep(min(60.0, max(1.0, interval_seconds / 5.0)))
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True, name="cleanup-scheduler")
cleanup_thread.start()
thread = threading.Thread(target=loop, daemon=True, name=name)
thread.start()
return thread
def start_cleanup_scheduler() -> None:
"""启动定期清理调度器"""
_start_daemon_loop(
"cleanup-scheduler",
startup_delay=300,
interval_seconds=300,
job=cleanup_expired_data,
error_tag="清理任务执行失败",
)
logger.info("内存清理调度器已启动")
def start_kdocs_monitor() -> None:
"""启动金山文档状态监控"""
def monitor_loop():
# 启动后等待 60 秒再开始检测(给系统初始化的时间)
time.sleep(60)
while True:
try:
check_kdocs_online_status()
time.sleep(300) # 每5分钟检测一次
except Exception as e:
logger.error(f"[KDocs监控] 监控任务执行失败: {e}")
time.sleep(60)
monitor_thread = threading.Thread(target=monitor_loop, daemon=True, name="kdocs-monitor")
monitor_thread.start()
_start_daemon_loop(
"kdocs-monitor",
startup_delay=60,
interval_seconds=300,
job=check_kdocs_online_status,
error_tag="[KDocs监控] 监控任务执行失败",
)
logger.info("[KDocs监控] 金山文档状态监控已启动每5分钟检测一次")

View File

@@ -27,6 +27,12 @@ from services.time_utils import get_beijing_now
logger = get_logger("app")
config = get_config()
try:
_SCHEDULE_SUBMIT_DELAY_SECONDS = float(os.environ.get("SCHEDULE_SUBMIT_DELAY_SECONDS", "0.2"))
except Exception:
_SCHEDULE_SUBMIT_DELAY_SECONDS = 0.2
_SCHEDULE_SUBMIT_DELAY_SECONDS = max(0.0, _SCHEDULE_SUBMIT_DELAY_SECONDS)
SCREENSHOTS_DIR = config.SCREENSHOTS_DIR
os.makedirs(SCREENSHOTS_DIR, exist_ok=True)
@@ -55,6 +61,150 @@ def _normalize_hhmm(value: object, *, default: str) -> str:
return f"{hour:02d}:{minute:02d}"
def _safe_recompute_schedule_next_run(schedule_id: int) -> None:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
def _load_accounts_for_users(approved_users: list[dict]) -> tuple[dict[int, dict], list[str]]:
"""批量加载用户账号快照。"""
user_accounts: dict[int, dict] = {}
account_ids: list[str] = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
return user_accounts, account_ids
def _should_skip_suspended_account(account_status_info, account, username: str) -> bool:
"""判断是否应跳过暂停账号,并输出日志。"""
if not account_status_info:
return False
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status != "suspended":
return False
fail_count = account_status_info["login_fail_count"] if "login_fail_count" in account_status_info.keys() else 0
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{username}) - 连续{fail_count}次密码错误,需修改密码"
)
return True
def _parse_schedule_account_ids(schedule_config: dict, schedule_id: int):
import json
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
return []
if isinstance(account_ids, list):
return account_ids
return []
def _create_user_schedule_batch(*, batch_id: str, user_id: int, browse_type: str, schedule_name: str, now_ts: float) -> None:
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_name,
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
def _build_user_schedule_done_callback(
*,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
execution_start_time: float,
log_id: int,
schedule_id: int,
total_accounts: int,
):
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time.time() - execution_start_time)
started_count = int(counters.get("started", 0) or 0)
database.update_schedule_execution_log(
log_id,
total_accounts=total_accounts,
success_accounts=started_count,
failed_accounts=total_accounts - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件")
return on_browse_done
def _submit_user_schedule_accounts(
*,
user_id: int,
account_ids: list,
browse_type: str,
enable_screenshot,
task_source: str,
done_callback,
completion_lock: threading.Lock,
remaining: dict,
counters: dict,
) -> tuple[int, int]:
started_count = 0
skipped_count = 0
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if (not account) or account.is_running:
skipped_count += 1
continue
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=done_callback,
)
if ok:
started_count += 1
counters["started"] = started_count
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
return started_count, skipped_count
def run_scheduled_task(skip_weekday_check: bool = False) -> None:
"""执行所有账号的浏览任务(可被手动调用,过滤重复账号)"""
try:
@@ -87,17 +237,7 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
cfg = database.get_system_config()
enable_screenshot_scheduled = cfg.get("enable_screenshot", 0) == 1
user_accounts = {}
account_ids = []
for user in approved_users:
user_id = user["id"]
accounts = safe_get_user_accounts_snapshot(user_id)
if not accounts:
load_user_accounts(user_id)
accounts = safe_get_user_accounts_snapshot(user_id)
if accounts:
user_accounts[user_id] = accounts
account_ids.extend(list(accounts.keys()))
user_accounts, account_ids = _load_accounts_for_users(approved_users)
account_statuses = database.get_account_status_batch(account_ids)
@@ -113,18 +253,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
continue
account_status_info = account_statuses.get(str(account_id))
if account_status_info:
status = account_status_info["status"] if "status" in account_status_info.keys() else "active"
if status == "suspended":
fail_count = (
account_status_info["login_fail_count"]
if "login_fail_count" in account_status_info.keys()
else 0
)
logger.info(
f"[定时任务] 跳过暂停账号: {account.username} (用户:{user['username']}) - 连续{fail_count}次密码错误,需修改密码"
)
continue
if _should_skip_suspended_account(account_status_info, account, user["username"]):
continue
if account.username in executed_usernames:
skipped_duplicates += 1
@@ -149,7 +279,8 @@ def run_scheduled_task(skip_weekday_check: bool = False) -> None:
else:
logger.warning(f"[定时任务] 启动失败({account.username}): {msg}")
time.sleep(2)
if _SCHEDULE_SUBMIT_DELAY_SECONDS > 0:
time.sleep(_SCHEDULE_SUBMIT_DELAY_SECONDS)
logger.info(
f"[定时任务] 执行完成 - 总账号数:{total_accounts}, 已执行:{executed_accounts}, 跳过重复:{skipped_duplicates}"
@@ -198,15 +329,16 @@ def scheduled_task_worker() -> None:
deleted_screenshots = 0
if os.path.exists(SCREENSHOTS_DIR):
cutoff_time = time.time() - (7 * 24 * 60 * 60)
for filename in os.listdir(SCREENSHOTS_DIR):
if filename.lower().endswith((".png", ".jpg", ".jpeg")):
filepath = os.path.join(SCREENSHOTS_DIR, filename)
with os.scandir(SCREENSHOTS_DIR) as entries:
for entry in entries:
if (not entry.is_file()) or (not entry.name.lower().endswith((".png", ".jpg", ".jpeg"))):
continue
try:
if os.path.getmtime(filepath) < cutoff_time:
os.remove(filepath)
if entry.stat().st_mtime < cutoff_time:
os.remove(entry.path)
deleted_screenshots += 1
except Exception as e:
logger.warning(f"[定时清理] 删除截图失败 {filename}: {str(e)}")
logger.warning(f"[定时清理] 删除截图失败 {entry.name}: {str(e)}")
logger.info(f"[定时清理] 已删除 {deleted_screenshots} 个截图文件")
logger.info("[定时清理] 清理完成!")
@@ -214,10 +346,97 @@ def scheduled_task_worker() -> None:
except Exception as e:
logger.exception(f"[定时清理] 清理任务出错: {str(e)}")
def _parse_due_schedule_weekdays(schedule_config: dict, schedule_id: int):
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
try:
return [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
_safe_recompute_schedule_next_run(schedule_id)
return None
def _execute_due_user_schedule(schedule_config: dict) -> None:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
user_id = schedule_config["user_id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
account_ids = _parse_schedule_account_ids(schedule_config, schedule_id)
if not account_ids:
_safe_recompute_schedule_next_run(schedule_id)
return
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import uuid
execution_start_time = time.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id,
user_id=user_id,
schedule_name=schedule_name,
)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time.time()
_create_user_schedule_batch(
batch_id=batch_id,
user_id=user_id,
browse_type=browse_type,
schedule_name=schedule_name,
now_ts=now_ts,
)
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
counters = {"started": 0}
on_browse_done = _build_user_schedule_done_callback(
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
execution_start_time=execution_start_time,
log_id=log_id,
schedule_id=schedule_id,
total_accounts=len(account_ids),
)
task_source = f"user_scheduled:{batch_id}"
started_count, skipped_count = _submit_user_schedule_accounts(
user_id=user_id,
account_ids=account_ids,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
task_source=task_source,
done_callback=on_browse_done,
completion_lock=completion_lock,
remaining=remaining,
counters=counters,
)
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
def check_user_schedules():
"""检查并执行用户定时任务O-08next_run_at 索引驱动)。"""
import json
try:
now = get_beijing_now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -226,145 +445,22 @@ def scheduled_task_worker() -> None:
due_schedules = database.get_due_user_schedules(now_str, limit=50) or []
for schedule_config in due_schedules:
schedule_name = schedule_config.get("name", "未命名任务")
schedule_id = schedule_config["id"]
schedule_name = schedule_config.get("name", "未命名任务")
weekdays_str = schedule_config.get("weekdays", "1,2,3,4,5")
try:
allowed_weekdays = [int(d) for d in weekdays_str.split(",") if d.strip()]
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析weekdays失败: {e}")
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
allowed_weekdays = _parse_due_schedule_weekdays(schedule_config, schedule_id)
if allowed_weekdays is None:
continue
if current_weekday not in allowed_weekdays:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
_safe_recompute_schedule_next_run(schedule_id)
continue
logger.info(f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 (next_run_at={schedule_config.get('next_run_at')})")
user_id = schedule_config["user_id"]
schedule_id = schedule_config["id"]
browse_type = normalize_browse_type(schedule_config.get("browse_type", BROWSE_TYPE_SHOULD_READ))
enable_screenshot = schedule_config.get("enable_screenshot", 1)
try:
account_ids_raw = schedule_config.get("account_ids", "[]") or "[]"
account_ids = json.loads(account_ids_raw)
except Exception as e:
logger.warning(f"[定时任务] 任务#{schedule_id} 解析account_ids失败: {e}")
account_ids = []
if not account_ids:
try:
database.recompute_schedule_next_run(schedule_id)
except Exception:
pass
continue
if not safe_get_user_accounts_snapshot(user_id):
load_user_accounts(user_id)
import time as time_mod
import uuid
execution_start_time = time_mod.time()
log_id = database.create_schedule_execution_log(
schedule_id=schedule_id, user_id=user_id, schedule_name=schedule_config.get("name", "未命名任务")
logger.info(
f"[用户定时任务] 任务#{schedule_id} '{schedule_name}' 到期,开始执行 "
f"(next_run_at={schedule_config.get('next_run_at')})"
)
batch_id = f"batch_{uuid.uuid4().hex[:12]}"
now_ts = time_mod.time()
safe_create_batch(
batch_id,
{
"user_id": user_id,
"browse_type": browse_type,
"schedule_name": schedule_config.get("name", "未命名任务"),
"screenshots": [],
"total_accounts": 0,
"completed": 0,
"created_at": now_ts,
"updated_at": now_ts,
},
)
started_count = 0
skipped_count = 0
completion_lock = threading.Lock()
remaining = {"count": 0, "done": False}
def on_browse_done():
with completion_lock:
remaining["count"] -= 1
if remaining["done"] or remaining["count"] > 0:
return
remaining["done"] = True
execution_duration = int(time_mod.time() - execution_start_time)
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=started_count,
failed_accounts=len(account_ids) - started_count,
duration_seconds=execution_duration,
status="completed",
)
logger.info(
f"[用户定时任务] 任务#{schedule_id}浏览阶段完成,耗时{execution_duration}秒,等待截图完成后发送邮件"
)
for account_id in account_ids:
account = safe_get_account(user_id, account_id)
if not account:
skipped_count += 1
continue
if account.is_running:
skipped_count += 1
continue
task_source = f"user_scheduled:{batch_id}"
with completion_lock:
remaining["count"] += 1
ok, msg = submit_account_task(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=enable_screenshot,
source=task_source,
done_callback=on_browse_done,
)
if ok:
started_count += 1
else:
with completion_lock:
remaining["count"] -= 1
skipped_count += 1
logger.warning(f"[用户定时任务] 账号 {account.username} 启动失败: {msg}")
batch_info = safe_finalize_batch_after_dispatch(batch_id, started_count, now_ts=time_mod.time())
if batch_info:
_send_batch_task_email_if_configured(batch_info)
database.update_schedule_last_run(schedule_id)
logger.info(f"[用户定时任务] 已启动 {started_count} 个账号,跳过 {skipped_count} 个账号批次ID: {batch_id}")
if started_count <= 0:
database.update_schedule_execution_log(
log_id,
total_accounts=len(account_ids),
success_accounts=0,
failed_accounts=len(account_ids),
duration_seconds=0,
status="completed",
)
if started_count == 0 and len(account_ids) > 0:
logger.warning("[用户定时任务] ⚠️ 警告所有账号都被跳过了请检查user_accounts状态")
_execute_due_user_schedule(schedule_config)
except Exception as e:
logger.exception(f"[用户定时任务] 检查出错: {str(e)}")

View File

@@ -6,12 +6,14 @@ import os
import shutil
import subprocess
import time
from urllib.parse import urlsplit
import database
import email_service
from api_browser import APIBrowser, get_cookie_jar_path, is_cookie_jar_fresh
from app_config import get_config
from app_logger import get_logger
from app_security import sanitize_filename
from browser_pool_worker import get_browser_worker_pool
from services.client_log import log_to_client
from services.runtime import get_socketio
@@ -194,6 +196,293 @@ def _emit(event: str, data: object, *, room: str | None = None) -> None:
pass
def _set_screenshot_running_status(user_id: int, account_id: str) -> None:
"""更新账号状态为截图中。"""
acc = safe_get_account(user_id, account_id)
if not acc:
return
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
def _get_worker_display_info(browser_instance) -> tuple[str, int]:
"""获取截图 worker 的展示信息。"""
if isinstance(browser_instance, dict):
return str(browser_instance.get("worker_id", "?")), int(browser_instance.get("use_count", 0) or 0)
return "?", 0
def _get_proxy_context(account) -> tuple[dict | None, str | None]:
"""提取截图阶段代理配置。"""
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
return proxy_config, proxy_server
def _build_screenshot_targets(browse_type: str) -> tuple[str, str, str]:
"""构建截图目标 URL 与页面脚本。"""
parsed = urlsplit(config.ZSGL_LOGIN_URL)
base = f"{parsed.scheme}://{parsed.netloc}"
if "注册前" in str(browse_type):
bz = 0
else:
bz = 0
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
return index_url, target_url, run_script
def _build_screenshot_output_path(username_prefix: str, account, browse_type: str) -> tuple[str, str]:
"""构建截图输出文件名与路径。"""
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
login_account = account.remark if account.remark else account.username
raw_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_filename = sanitize_filename(raw_filename)
return screenshot_filename, os.path.join(SCREENSHOTS_DIR, screenshot_filename)
def _ensure_screenshot_login_state(
*,
account,
proxy_config,
cookie_path: str,
attempt: int,
max_retries: int,
user_id: int,
account_id: str,
custom_log,
) -> str:
"""确保截图前登录态有效。返回: ok/retry/fail。"""
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
if not should_refresh_login:
return "ok"
log_to_client("正在刷新登录态...", user_id, account_id)
if _ensure_login_cookies(account, proxy_config, custom_log):
return "ok"
if attempt > 1:
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
return "retry"
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return "fail"
def _take_screenshot_once(
*,
index_url: str,
target_url: str,
screenshot_path: str,
cookie_path: str,
proxy_server: str | None,
run_script: str,
log_callback,
) -> str:
"""执行一次截图尝试并验证输出文件。返回: success/invalid/failed。"""
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
attempts = [
{
"url": index_url,
"run_script": run_script,
"window_status": "ready",
},
{
"url": target_url,
"run_script": None,
"window_status": None,
},
]
ok = False
for shot in attempts:
ok = take_screenshot_wkhtmltoimage(
shot["url"],
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
run_script=shot["run_script"],
window_status=shot["window_status"],
log_callback=log_callback,
)
if ok:
break
if not ok:
return "failed"
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
return "success"
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
return "invalid"
def _get_result_screenshot_path(result) -> str | None:
"""从截图结果中提取截图文件绝对路径。"""
if result and result.get("success") and result.get("filename"):
return os.path.join(SCREENSHOTS_DIR, result["filename"])
return None
def _enqueue_kdocs_upload_if_needed(user_id: int, account_id: str, account, screenshot_path: str | None) -> None:
"""按配置提交金山文档上传任务。"""
if not screenshot_path:
return
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) != 1:
return
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if not doc_url:
return
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) != 1:
return
unit = (user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or "").strip()
name = (account.remark or "").strip()
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
return
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
return
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
def _dispatch_screenshot_result(
*,
user_id: int,
account_id: str,
source: str,
browse_type: str,
browse_result: dict,
result,
account,
user_info,
) -> None:
"""将截图结果发送到批次统计/邮件通知链路。"""
batch_id = _get_batch_id_from_source(source)
screenshot_path = _get_result_screenshot_path(result)
account_name = account.remark if account.remark else account.username
try:
if result and result.get("success") and screenshot_path:
_enqueue_kdocs_upload_if_needed(user_id, account_id, account, screenshot_path)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
return
if source and source.startswith("user_scheduled"):
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
def _finalize_screenshot_callback_state(user_id: int, account_id: str, account) -> None:
"""截图回调的通用收尾状态变更。"""
account.is_running = False
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
def _persist_browse_log_after_screenshot(
*,
user_id: int,
account_id: str,
account,
browse_type: str,
source: str,
task_start_time,
browse_result,
) -> None:
"""截图完成后写入任务日志(浏览完成日志)。"""
import time as time_module
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
user_id=user_id,
account_id=account_id,
username=account.username,
browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source,
)
def take_screenshot_for_account(
user_id,
account_id,
@@ -213,21 +502,21 @@ def take_screenshot_for_account(
# 标记账号正在截图(防止重复提交截图任务)
account.is_running = True
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
def screenshot_task(
browser_instance, user_id, account_id, account, browse_type, source, task_start_time, browse_result
):
"""在worker线程中执行的截图任务"""
# ✅ 获得worker后立即更新状态为"截图中"
acc = safe_get_account(user_id, account_id)
if acc:
acc.status = "截图中"
safe_update_task_status(account_id, {"status": "运行中", "detail_status": "正在截图"})
_emit("account_update", acc.to_dict(), room=f"user_{user_id}")
_set_screenshot_running_status(user_id, account_id)
max_retries = 3
proxy_config = account.proxy_config if hasattr(account, "proxy_config") else None
proxy_server = proxy_config.get("server") if proxy_config else None
proxy_config, proxy_server = _get_proxy_context(account)
cookie_path = get_cookie_jar_path(account.username)
index_url, target_url, run_script = _build_screenshot_targets(browse_type)
for attempt in range(1, max_retries + 1):
try:
@@ -239,8 +528,7 @@ def take_screenshot_for_account(
if attempt > 1:
log_to_client(f"🔄 第 {attempt} 次截图尝试...", user_id, account_id)
worker_id = browser_instance.get("worker_id", "?") if isinstance(browser_instance, dict) else "?"
use_count = browser_instance.get("use_count", 0) if isinstance(browser_instance, dict) else 0
worker_id, use_count = _get_worker_display_info(browser_instance)
log_to_client(
f"使用Worker-{worker_id}执行截图(已执行{use_count}次)",
user_id,
@@ -250,99 +538,39 @@ def take_screenshot_for_account(
def custom_log(message: str):
log_to_client(message, user_id, account_id)
# 智能登录状态检查:只在必要时才刷新登录
should_refresh_login = not is_cookie_jar_fresh(cookie_path)
if should_refresh_login and attempt > 1:
# 重试时刷新登录attempt > 1 表示第2次及以后的尝试
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("截图登录失败", user_id, account_id)
if attempt < max_retries:
log_to_client("将重试...", user_id, account_id)
time.sleep(2)
continue
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return {"success": False, "error": "登录失败"}
elif should_refresh_login:
# 首次尝试时快速检查登录状态
log_to_client("正在刷新登录态...", user_id, account_id)
if not _ensure_login_cookies(account, proxy_config, custom_log):
log_to_client("❌ 截图失败: 登录失败", user_id, account_id)
return {"success": False, "error": "登录失败"}
login_state = _ensure_screenshot_login_state(
account=account,
proxy_config=proxy_config,
cookie_path=cookie_path,
attempt=attempt,
max_retries=max_retries,
user_id=user_id,
account_id=account_id,
custom_log=custom_log,
)
if login_state == "retry":
continue
if login_state == "fail":
return {"success": False, "error": "登录失败"}
log_to_client(f"导航到 '{browse_type}' 页面...", user_id, account_id)
from urllib.parse import urlsplit
parsed = urlsplit(config.ZSGL_LOGIN_URL)
base = f"{parsed.scheme}://{parsed.netloc}"
if "注册前" in str(browse_type):
bz = 0
else:
bz = 0 # 应读(网站更新后 bz=0 为应读)
target_url = f"{base}/admin/center.aspx?bz={bz}"
index_url = config.ZSGL_INDEX_URL or f"{base}/admin/index.aspx"
run_script = (
"(function(){"
"function done(){window.status='ready';}"
"function ensureNav(){try{if(typeof loadMenuTree==='function'){loadMenuTree(true);}}catch(e){}}"
"function expandMenu(){"
"try{var body=document.body;if(body&&body.classList.contains('lay-mini')){body.classList.remove('lay-mini');}}catch(e){}"
"try{if(typeof mainPageResize==='function'){mainPageResize();}}catch(e){}"
"try{if(typeof toggleMainMenu==='function' && document.body && document.body.classList.contains('lay-mini')){toggleMainMenu();}}catch(e){}"
"try{var navRight=document.querySelector('.nav-right');if(navRight){navRight.style.display='block';}}catch(e){}"
"try{var mainNav=document.getElementById('main-nav');if(mainNav){mainNav.style.display='block';}}catch(e){}"
"}"
"function navReady(){"
"try{var nav=document.getElementById('sidebar-nav');return nav && nav.querySelectorAll('a').length>0;}catch(e){return false;}"
"}"
"function frameReady(){"
"try{var f=document.getElementById('mainframe');return f && f.contentDocument && f.contentDocument.readyState==='complete';}catch(e){return false;}"
"}"
"function check(){"
"if(navReady() && frameReady()){done();return;}"
"setTimeout(check,300);"
"}"
"var f=document.getElementById('mainframe');"
"ensureNav();"
"expandMenu();"
"if(!f){done();return;}"
f"f.src='{target_url}';"
"f.onload=function(){ensureNav();expandMenu();setTimeout(check,300);};"
"setTimeout(check,5000);"
"})();"
)
timestamp = get_beijing_now().strftime("%Y%m%d_%H%M%S")
user_info = database.get_user_by_id(user_id)
username_prefix = user_info["username"] if user_info else f"user{user_id}"
login_account = account.remark if account.remark else account.username
screenshot_filename = f"{username_prefix}_{login_account}_{browse_type}_{timestamp}.jpg"
screenshot_path = os.path.join(SCREENSHOTS_DIR, screenshot_filename)
cookies_for_shot = cookie_path if is_cookie_jar_fresh(cookie_path) else None
if take_screenshot_wkhtmltoimage(
index_url,
screenshot_path,
cookies_path=cookies_for_shot,
screenshot_filename, screenshot_path = _build_screenshot_output_path(username_prefix, account, browse_type)
shot_state = _take_screenshot_once(
index_url=index_url,
target_url=target_url,
screenshot_path=screenshot_path,
cookie_path=cookie_path,
proxy_server=proxy_server,
run_script=run_script,
window_status="ready",
log_callback=custom_log,
) or take_screenshot_wkhtmltoimage(
target_url,
screenshot_path,
cookies_path=cookies_for_shot,
proxy_server=proxy_server,
log_callback=custom_log,
):
if os.path.exists(screenshot_path) and os.path.getsize(screenshot_path) > 1000:
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
)
if shot_state == "success":
log_to_client(f"[OK] 截图成功: {screenshot_filename}", user_id, account_id)
return {"success": True, "filename": screenshot_filename}
if shot_state == "invalid":
log_to_client("截图文件异常,将重试", user_id, account_id)
if os.path.exists(screenshot_path):
os.remove(screenshot_path)
else:
log_to_client("截图保存失败", user_id, account_id)
@@ -361,12 +589,7 @@ def take_screenshot_for_account(
def screenshot_callback(result, error):
"""截图完成回调"""
try:
account.is_running = False
account.status = "未开始"
safe_remove_task_status(account_id)
_emit("account_update", account.to_dict(), room=f"user_{user_id}")
_finalize_screenshot_callback_state(user_id, account_id, account)
if error:
log_to_client(f"❌ 截图失败: {error}", user_id, account_id)
@@ -375,84 +598,27 @@ def take_screenshot_for_account(
log_to_client(f"❌ 截图失败: {error_msg}", user_id, account_id)
if task_start_time and browse_result:
import time as time_module
total_elapsed = int(time_module.time() - task_start_time)
database.create_task_log(
_persist_browse_log_after_screenshot(
user_id=user_id,
account_id=account_id,
username=account.username,
account=account,
browse_type=browse_type,
status="success",
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
duration=total_elapsed,
source=source,
task_start_time=task_start_time,
browse_result=browse_result,
)
try:
batch_id = _get_batch_id_from_source(source)
screenshot_path = None
if result and result.get("success") and result.get("filename"):
screenshot_path = os.path.join(SCREENSHOTS_DIR, result["filename"])
account_name = account.remark if account.remark else account.username
try:
if screenshot_path and result and result.get("success"):
cfg = database.get_system_config() or {}
if int(cfg.get("kdocs_enabled", 0) or 0) == 1:
doc_url = (cfg.get("kdocs_doc_url") or "").strip()
if doc_url:
user_cfg = database.get_user_kdocs_settings(user_id) or {}
if int(user_cfg.get("kdocs_auto_upload", 0) or 0) == 1:
unit = (
user_cfg.get("kdocs_unit") or cfg.get("kdocs_default_unit") or ""
).strip()
name = (account.remark or "").strip()
if unit and name:
from services.kdocs_uploader import get_kdocs_uploader
ok = get_kdocs_uploader().enqueue_upload(
user_id=user_id,
account_id=account_id,
unit=unit,
name=name,
image_path=screenshot_path,
)
if not ok:
log_to_client("表格上传排队失败: 队列已满", user_id, account_id)
else:
if not unit:
log_to_client("表格上传跳过: 未配置县区", user_id, account_id)
if not name:
log_to_client("表格上传跳过: 账号备注为空", user_id, account_id)
except Exception as kdocs_error:
logger.warning(f"表格上传任务提交失败: {kdocs_error}")
if batch_id:
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=screenshot_path,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
)
elif source and source.startswith("user_scheduled"):
user_info = database.get_user_by_id(user_id)
if user_info and user_info.get("email") and database.get_user_email_notify(user_id):
email_service.send_task_complete_email_async(
user_id=user_id,
email=user_info["email"],
username=user_info["username"],
account_name=account_name,
browse_type=browse_type,
total_items=browse_result.get("total_items", 0),
total_attachments=browse_result.get("total_attachments", 0),
screenshot_path=screenshot_path,
log_callback=lambda msg: log_to_client(msg, user_id, account_id),
)
_dispatch_screenshot_result(
user_id=user_id,
account_id=account_id,
source=source,
browse_type=browse_type,
browse_result=browse_result,
result=result,
account=account,
user_info=user_info,
)
except Exception as email_error:
logger.warning(f"发送任务完成邮件失败: {email_error}")
except Exception as e:

View File

@@ -13,7 +13,7 @@ from __future__ import annotations
import threading
import time
import random
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from app_config import get_config
@@ -161,6 +161,36 @@ _log_cache_lock = threading.RLock()
_log_cache_total_count = 0
def _pop_oldest_log_for_user(uid: int) -> bool:
global _log_cache_total_count
logs = _log_cache.get(uid)
if not logs:
_log_cache.pop(uid, None)
return False
logs.pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if not logs:
_log_cache.pop(uid, None)
return True
def _pop_oldest_log_from_largest_user() -> bool:
largest_uid = None
largest_size = 0
for uid, logs in _log_cache.items():
size = len(logs)
if size > largest_size:
largest_uid = uid
largest_size = size
if largest_uid is None or largest_size <= 0:
return False
return _pop_oldest_log_for_user(int(largest_uid))
def safe_add_log(
user_id: int,
log_entry: Dict[str, Any],
@@ -175,24 +205,17 @@ def safe_add_log(
max_total_logs = int(max_total_logs or config.MAX_TOTAL_LOGS)
with _log_cache_lock:
if uid not in _log_cache:
_log_cache[uid] = []
logs = _log_cache.setdefault(uid, [])
if len(_log_cache[uid]) >= max_logs_per_user:
_log_cache[uid].pop(0)
_log_cache_total_count = max(0, _log_cache_total_count - 1)
if len(logs) >= max_logs_per_user:
_pop_oldest_log_for_user(uid)
logs = _log_cache.setdefault(uid, [])
_log_cache[uid].append(dict(log_entry or {}))
logs.append(dict(log_entry or {}))
_log_cache_total_count += 1
while _log_cache_total_count > max_total_logs:
if not _log_cache:
break
max_user = max(_log_cache.keys(), key=lambda u: len(_log_cache[u]))
if _log_cache.get(max_user):
_log_cache[max_user].pop(0)
_log_cache_total_count -= 1
else:
if not _pop_oldest_log_from_largest_user():
break
@@ -378,6 +401,34 @@ def _get_action_rate_limit(action: str) -> Tuple[int, int]:
return int(config.IP_RATE_LIMIT_LOGIN_MAX), int(config.IP_RATE_LIMIT_LOGIN_WINDOW_SECONDS)
def _format_wait_hint(remaining_seconds: int) -> str:
remaining = max(1, int(remaining_seconds or 0))
if remaining >= 60:
return f"{remaining // 60 + 1}分钟"
return f"{remaining}"
def _check_and_increment_rate_bucket(
*,
buckets: Dict[str, Dict[str, Any]],
key: str,
now_ts: float,
max_requests: int,
window_seconds: int,
) -> Tuple[bool, Optional[str]]:
if not key or int(max_requests) <= 0:
return True, None
data = _get_or_reset_bucket(buckets.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= int(max_requests):
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
return False, f"请求过于频繁,请{_format_wait_hint(remaining)}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
buckets[key] = data
return True, None
def check_ip_request_rate(
ip_address: str,
action: str,
@@ -392,21 +443,13 @@ def check_ip_request_rate(
key = f"{action}:{ip_address}"
with _ip_request_rate_lock:
data = _ip_request_rate.get(key)
if not data or (now_ts - float(data.get("window_start", 0) or 0)) >= window_seconds:
data = {"window_start": now_ts, "count": 0}
_ip_request_rate[key] = data
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
if remaining >= 60:
wait_hint = f"{remaining // 60 + 1}分钟"
else:
wait_hint = f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
return True, None
return _check_and_increment_rate_bucket(
buckets=_ip_request_rate,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
@@ -417,8 +460,7 @@ def cleanup_expired_ip_request_rates(now_ts: Optional[float] = None) -> int:
data = _ip_request_rate.get(key) or {}
action = key.split(":", 1)[0]
_, window_seconds = _get_action_rate_limit(action)
window_start = float(data.get("window_start", 0) or 0)
if now_ts - window_start >= window_seconds:
if _is_bucket_expired(data, now_ts, window_seconds):
_ip_request_rate.pop(key, None)
removed += 1
return removed
@@ -487,6 +529,30 @@ def _get_or_reset_bucket(data: Optional[Dict[str, Any]], now_ts: float, window_s
return data
def _is_bucket_expired(
data: Optional[Dict[str, Any]],
now_ts: float,
window_seconds: int,
*,
time_field: str = "window_start",
) -> bool:
start_ts = float((data or {}).get(time_field, 0) or 0)
return (now_ts - start_ts) >= max(1, int(window_seconds))
def _cleanup_map_entries(
store: Dict[Any, Dict[str, Any]],
should_remove: Callable[[Dict[str, Any]], bool],
) -> int:
removed = 0
for key, value in list(store.items()):
item = value if isinstance(value, dict) else {}
if should_remove(item):
store.pop(key, None)
removed += 1
return removed
def record_login_username_attempt(ip_address: str, username: str) -> bool:
now_ts = time.time()
threshold, window_seconds, cooldown_seconds = _get_login_scan_config()
@@ -527,26 +593,32 @@ def check_login_rate_limits(ip_address: str, username: str) -> Tuple[bool, Optio
user_key = _normalize_login_key("user", "", username)
ip_user_key = _normalize_login_key("ipuser", ip_address, username)
def _check(key: str, max_requests: int) -> Tuple[bool, Optional[str]]:
if not key or max_requests <= 0:
return True, None
data = _get_or_reset_bucket(_login_rate_limits.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_login_rate_limits[key] = data
return True, None
with _login_rate_limits_lock:
allowed, msg = _check(ip_key, ip_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_key,
now_ts=now_ts,
max_requests=ip_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(ip_user_key, ip_user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=ip_user_key,
now_ts=now_ts,
max_requests=ip_user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
allowed, msg = _check(user_key, user_max)
allowed, msg = _check_and_increment_rate_bucket(
buckets=_login_rate_limits,
key=user_key,
now_ts=now_ts,
max_requests=user_max,
window_seconds=window_seconds,
)
if not allowed:
return False, msg
@@ -622,15 +694,18 @@ def check_login_captcha_required(ip_address: str, username: Optional[str] = None
ip_key = _normalize_login_key("ip", ip_address)
ip_user_key = _normalize_login_key("ipuser", ip_address, username or "")
def _is_over_threshold(data: Optional[Dict[str, Any]]) -> bool:
if not data:
return False
if (now_ts - float(data.get("first_failed", 0) or 0)) > window_seconds:
return False
return int(data.get("count", 0) or 0) >= max_failures
with _login_failures_lock:
ip_data = _login_failures.get(ip_key)
if ip_data and (now_ts - float(ip_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_data.get("count", 0) or 0) >= max_failures:
return True
ip_user_data = _login_failures.get(ip_user_key)
if ip_user_data and (now_ts - float(ip_user_data.get("first_failed", 0) or 0)) <= window_seconds:
if int(ip_user_data.get("count", 0) or 0) >= max_failures:
return True
if _is_over_threshold(_login_failures.get(ip_key)):
return True
if _is_over_threshold(_login_failures.get(ip_user_key)):
return True
if is_login_scan_locked(ip_address):
return True
@@ -685,6 +760,56 @@ def should_send_login_alert(user_id: int, ip_address: str) -> bool:
return False
def cleanup_expired_login_security_state(now_ts: Optional[float] = None) -> Dict[str, int]:
now_ts = float(now_ts if now_ts is not None else time.time())
_, captcha_window = _get_login_captcha_config()
_, _, _, rate_window = _get_login_rate_limit_config()
_, lock_window, _ = _get_login_lock_config()
_, scan_window, _ = _get_login_scan_config()
alert_expire_seconds = max(3600, int(config.LOGIN_ALERT_MIN_INTERVAL_SECONDS) * 3)
with _login_failures_lock:
failures_removed = _cleanup_map_entries(
_login_failures,
lambda data: (now_ts - float(data.get("first_failed", 0) or 0)) > max(captcha_window, lock_window),
)
with _login_rate_limits_lock:
rate_removed = _cleanup_map_entries(
_login_rate_limits,
lambda data: _is_bucket_expired(data, now_ts, rate_window),
)
with _login_scan_lock:
scan_removed = _cleanup_map_entries(
_login_scan_state,
lambda data: (
(now_ts - float(data.get("first_seen", 0) or 0)) > scan_window
and now_ts >= float(data.get("scan_until", 0) or 0)
),
)
with _login_ip_user_lock:
ip_user_locks_removed = _cleanup_map_entries(
_login_ip_user_locks,
lambda data: now_ts >= float(data.get("lock_until", 0) or 0),
)
with _login_alert_lock:
alerts_removed = _cleanup_map_entries(
_login_alert_state,
lambda data: (now_ts - float(data.get("last_sent", 0) or 0)) > alert_expire_seconds,
)
return {
"failures": failures_removed,
"rate_limits": rate_removed,
"scan_states": scan_removed,
"ip_user_locks": ip_user_locks_removed,
"alerts": alerts_removed,
}
# ==================== 邮箱维度限流 ====================
_email_rate_limit: Dict[str, Dict[str, Any]] = {}
@@ -701,14 +826,13 @@ def check_email_rate_limit(email: str, action: str) -> Tuple[bool, Optional[str]
key = f"{action}:{email_key}"
with _email_rate_limit_lock:
data = _get_or_reset_bucket(_email_rate_limit.get(key), now_ts, window_seconds)
if int(data.get("count", 0) or 0) >= max_requests:
remaining = max(1, int(window_seconds - (now_ts - float(data.get("window_start", 0) or 0))))
wait_hint = f"{remaining // 60 + 1}分钟" if remaining >= 60 else f"{remaining}"
return False, f"请求过于频繁,请{wait_hint}后再试"
data["count"] = int(data.get("count", 0) or 0) + 1
_email_rate_limit[key] = data
return True, None
return _check_and_increment_rate_bucket(
buckets=_email_rate_limit,
key=key,
now_ts=now_ts,
max_requests=max_requests,
window_seconds=window_seconds,
)
# ==================== Batch screenshots批次任务截图收集 ====================

365
services/task_scheduler.py Normal file
View File

@@ -0,0 +1,365 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import heapq
import threading
import time
from concurrent.futures import ThreadPoolExecutor, wait
from dataclasses import dataclass
import database
from app_logger import get_logger
from services.state import safe_get_account, safe_get_task, safe_remove_task, safe_set_task
from services.task_batches import _batch_task_record_result, _get_batch_id_from_source
logger = get_logger("app")
# VIP优先级队列仅用于可视化/调试)
vip_task_queue = [] # VIP用户任务队列
normal_task_queue = [] # 普通用户任务队列
task_queue_lock = threading.Lock()
@dataclass
class _TaskRequest:
user_id: int
account_id: str
browse_type: str
enable_screenshot: bool
source: str
retry_count: int
submitted_at: float
is_vip: bool
seq: int
canceled: bool = False
done_callback: object = None
class TaskScheduler:
"""全局任务调度器:队列排队,不为每个任务单独创建线程。"""
def __init__(self, max_global: int, max_per_user: int, max_queue_size: int = 1000, run_task_fn=None):
self.max_global = max(1, int(max_global))
self.max_per_user = max(1, int(max_per_user))
self.max_queue_size = max(1, int(max_queue_size))
self._cond = threading.Condition()
self._pending = [] # heap: (priority, submitted_at, seq, task)
self._pending_by_account = {} # {account_id: task}
self._seq = 0
self._known_account_ids = set()
self._running_global = 0
self._running_by_user = {} # {user_id: running_count}
self._executor_max_workers = self.max_global
self._executor = ThreadPoolExecutor(max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker")
self._futures_lock = threading.Lock()
self._active_futures = set()
self._running = True
self._run_task_fn = run_task_fn
self._dispatcher_thread = threading.Thread(target=self._dispatch_loop, daemon=True, name="TaskDispatcher")
self._dispatcher_thread.start()
def _track_future(self, future) -> None:
with self._futures_lock:
self._active_futures.add(future)
try:
future.add_done_callback(self._untrack_future)
except Exception:
pass
def _untrack_future(self, future) -> None:
with self._futures_lock:
self._active_futures.discard(future)
def shutdown(self, timeout: float = 5.0):
"""停止调度器(用于进程退出清理)"""
with self._cond:
self._running = False
self._cond.notify_all()
try:
self._dispatcher_thread.join(timeout=timeout)
except Exception:
pass
# 等待已提交的任务收尾(最多等待 timeout 秒),避免遗留 active_task 干扰后续调度/测试
try:
deadline = time.time() + max(0.0, float(timeout or 0))
while True:
with self._futures_lock:
pending = [f for f in self._active_futures if not f.done()]
if not pending:
break
remaining = deadline - time.time()
if remaining <= 0:
break
wait(pending, timeout=remaining)
except Exception:
pass
try:
self._executor.shutdown(wait=False)
except Exception:
pass
# 最后兜底:清理本调度器提交过的 active_task避免测试/重启时被“任务已在运行中”误拦截
try:
with self._cond:
known_ids = set(self._known_account_ids) | set(self._pending_by_account.keys())
self._pending.clear()
self._pending_by_account.clear()
self._cond.notify_all()
for account_id in known_ids:
safe_remove_task(account_id)
except Exception:
pass
def update_limits(self, max_global: int = None, max_per_user: int = None, max_queue_size: int = None):
"""动态更新并发/队列上限(不影响已在运行的任务)"""
with self._cond:
if max_per_user is not None:
self.max_per_user = max(1, int(max_per_user))
if max_queue_size is not None:
self.max_queue_size = max(1, int(max_queue_size))
if max_global is not None:
new_max_global = max(1, int(max_global))
self.max_global = new_max_global
if new_max_global > self._executor_max_workers:
# 立即关闭旧线程池,防止资源泄漏
old_executor = self._executor
self._executor_max_workers = new_max_global
self._executor = ThreadPoolExecutor(
max_workers=self._executor_max_workers, thread_name_prefix="TaskWorker"
)
# 立即关闭旧线程池
try:
old_executor.shutdown(wait=False)
logger.info(f"线程池已扩容:{old_executor._max_workers} -> {self._executor_max_workers}")
except Exception as e:
logger.warning(f"关闭旧线程池失败: {e}")
self._cond.notify_all()
def get_queue_state_snapshot(self) -> dict:
"""获取调度器队列/运行状态快照(用于前端展示/监控)。"""
with self._cond:
pending_tasks = [t for t in self._pending_by_account.values() if t and not t.canceled]
pending_tasks.sort(key=lambda t: (0 if t.is_vip else 1, t.submitted_at, t.seq))
positions = {}
for idx, t in enumerate(pending_tasks):
positions[t.account_id] = {"queue_position": idx + 1, "queue_ahead": idx, "is_vip": bool(t.is_vip)}
return {
"pending_total": len(pending_tasks),
"running_total": int(self._running_global),
"running_by_user": dict(self._running_by_user),
"positions": positions,
}
def submit_task(
self,
user_id: int,
account_id: str,
browse_type: str,
enable_screenshot: bool = True,
source: str = "manual",
retry_count: int = 0,
is_vip: bool = None,
done_callback=None,
):
"""提交任务进入队列(返回: (ok, message)"""
if not user_id or not account_id:
return False, "参数错误"
submitted_at = time.time()
if is_vip is None:
try:
is_vip = bool(database.is_user_vip(user_id))
except Exception:
is_vip = False
else:
is_vip = bool(is_vip)
with self._cond:
if not self._running:
return False, "调度器未运行"
if len(self._pending_by_account) >= self.max_queue_size:
return False, "任务队列已满,请稍后再试"
if account_id in self._pending_by_account:
return False, "任务已在队列中"
if safe_get_task(account_id) is not None:
return False, "任务已在运行中"
self._seq += 1
task = _TaskRequest(
user_id=user_id,
account_id=account_id,
browse_type=browse_type,
enable_screenshot=bool(enable_screenshot),
source=source,
retry_count=int(retry_count or 0),
submitted_at=submitted_at,
is_vip=is_vip,
seq=self._seq,
done_callback=done_callback,
)
self._pending_by_account[account_id] = task
self._known_account_ids.add(account_id)
priority = 0 if is_vip else 1
heapq.heappush(self._pending, (priority, task.submitted_at, task.seq, task))
self._cond.notify_all()
# 用于可视化/调试:记录队列
with task_queue_lock:
if is_vip:
vip_task_queue.append(account_id)
else:
normal_task_queue.append(account_id)
return True, "已加入队列"
def cancel_pending_task(self, user_id: int, account_id: str) -> bool:
"""取消尚未开始的排队任务(已运行的任务由 should_stop 控制)"""
canceled_task = None
with self._cond:
task = self._pending_by_account.pop(account_id, None)
if not task:
return False
task.canceled = True
canceled_task = task
self._cond.notify_all()
# 从可视化队列移除
with task_queue_lock:
if account_id in vip_task_queue:
vip_task_queue.remove(account_id)
if account_id in normal_task_queue:
normal_task_queue.remove(account_id)
# 批次任务:取消也要推进完成计数,避免批次缓存常驻
try:
batch_id = _get_batch_id_from_source(canceled_task.source)
if batch_id:
acc = safe_get_account(user_id, account_id)
if acc:
account_name = acc.remark if acc.remark else acc.username
else:
account_name = account_id
_batch_task_record_result(
batch_id=batch_id,
account_name=account_name,
screenshot_path=None,
total_items=0,
total_attachments=0,
)
except Exception:
pass
return True
def _dispatch_loop(self):
while True:
task = None
with self._cond:
if not self._running:
return
if not self._pending or self._running_global >= self.max_global:
self._cond.wait(timeout=0.5)
continue
task = self._pop_next_runnable_locked()
if task is None:
self._cond.wait(timeout=0.5)
continue
self._running_global += 1
self._running_by_user[task.user_id] = self._running_by_user.get(task.user_id, 0) + 1
# 从队列移除(可视化)
with task_queue_lock:
if task.account_id in vip_task_queue:
vip_task_queue.remove(task.account_id)
if task.account_id in normal_task_queue:
normal_task_queue.remove(task.account_id)
try:
future = self._executor.submit(self._run_task_wrapper, task)
self._track_future(future)
safe_set_task(task.account_id, future)
except Exception:
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()
def _pop_next_runnable_locked(self):
"""在锁内从优先队列取出“可运行”的任务避免VIP任务占位阻塞普通任务。"""
if not self._pending:
return None
skipped = []
selected = None
while self._pending:
_, _, _, task = heapq.heappop(self._pending)
if task.canceled:
continue
if self._pending_by_account.get(task.account_id) is not task:
continue
running_for_user = self._running_by_user.get(task.user_id, 0)
if running_for_user >= self.max_per_user:
skipped.append(task)
continue
selected = task
break
for t in skipped:
priority = 0 if t.is_vip else 1
heapq.heappush(self._pending, (priority, t.submitted_at, t.seq, t))
if selected is None:
return None
self._pending_by_account.pop(selected.account_id, None)
return selected
def _run_task_wrapper(self, task: _TaskRequest):
try:
if callable(self._run_task_fn):
self._run_task_fn(
user_id=task.user_id,
account_id=task.account_id,
browse_type=task.browse_type,
enable_screenshot=task.enable_screenshot,
source=task.source,
retry_count=task.retry_count,
)
finally:
try:
if callable(task.done_callback):
task.done_callback()
except Exception:
pass
safe_remove_task(task.account_id)
with self._cond:
self._running_global = max(0, self._running_global - 1)
# 使用默认值 0 与增加时保持一致
self._running_by_user[task.user_id] = max(0, self._running_by_user.get(task.user_id, 0) - 1)
if self._running_by_user.get(task.user_id) == 0:
self._running_by_user.pop(task.user_id, None)
self._cond.notify_all()

File diff suppressed because it is too large Load Diff