#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 知识管理平台自动化工具 - 多用户版本 说明(P0–P3 优化后): - `app.py` 仅保留启动/装配层 - 路由拆分到 `routes/`(Blueprint) - 业务拆分到 `services/`(tasks/screenshots/scheduler/...) - SocketIO 事件拆分到 `realtime/` """ from __future__ import annotations import atexit import os import signal import sys import threading from flask import Flask, jsonify, redirect, request, send_from_directory, session, url_for from flask_login import LoginManager, current_user from flask_socketio import SocketIO import database import db_pool import email_service from app_config import get_config from app_logger import get_logger, init_logging from app_security import generate_csrf_token, is_safe_path, validate_csrf_token from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worker_pool from realtime.socketio_handlers import register_socketio_handlers from realtime.status_push import status_push_worker from routes import register_blueprints from security import init_security_middleware from services.checkpoints import init_checkpoint_manager from services.maintenance import start_cleanup_scheduler, start_kdocs_monitor from services.models import User from services.runtime import init_runtime from services.scheduler import scheduled_task_worker from services.state import safe_iter_user_accounts_items from services.tasks import get_task_scheduler # ==================== 进程级基础设置 ==================== # 设置时区为中国标准时间(CST, UTC+8) os.environ["TZ"] = "Asia/Shanghai" try: import time as _time _time.tzset() except Exception: pass def _sigchld_handler(signum, frame): """SIGCHLD 信号处理器 - 自动回收僵尸子进程(Docker PID 1 场景)""" while True: try: pid, _ = os.waitpid(-1, os.WNOHANG) if pid == 0: break except ChildProcessError: break except Exception: break if os.name != "nt": signal.signal(signal.SIGCHLD, _sigchld_handler) # ==================== Flask / SocketIO 装配 ==================== config = get_config() app = Flask(__name__) app.config.from_object(config) if not app.config.get("SECRET_KEY"): raise RuntimeError("SECRET_KEY未配置,请检查 app_config.py 或环境变量") cors_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "").strip() cors_allowed = [o.strip() for o in cors_origins.split(",") if o.strip()] if cors_origins else [] socketio = SocketIO( app, cors_allowed_origins=cors_allowed if cors_allowed else None, async_mode="threading", ping_timeout=60, ping_interval=25, logger=False, engineio_logger=False, ) init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE) logger = get_logger("app") init_runtime(socketio=socketio, logger=logger) # 初始化安全中间件(需在其他中间件/Blueprint 之前注册) init_security_middleware(app) # 注册 Blueprint(路由不变) register_blueprints(app) # 注册 SocketIO 事件 register_socketio_handlers(socketio) # ==================== Flask-Login 配置 ==================== login_manager = LoginManager() login_manager.init_app(app) login_manager.login_view = "pages.login_page" @login_manager.user_loader def load_user(user_id: str): user = database.get_user_by_id(int(user_id)) if user: return User(user["id"]) return None @login_manager.unauthorized_handler def unauthorized(): """未授权访问:API 返回 JSON,页面重定向登录页(行为保持不变)。""" if request.path.startswith("/api/") or request.path.startswith("/yuyx/api/"): return jsonify({"error": "请先登录", "code": "unauthorized"}), 401 return redirect(url_for("pages.login_page", next=request.url)) @app.before_request def enforce_csrf_protection(): if request.method in {"GET", "HEAD", "OPTIONS"}: return if request.path.startswith("/static/"): return # 登录相关路由豁免 CSRF 检查(登录本身就是建立 session 的过程) csrf_exempt_paths = {"/yuyx/api/login", "/api/login", "/api/auth/login"} if request.path in csrf_exempt_paths: return if not (current_user.is_authenticated or "admin_id" in session): return token = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token") if not token or not validate_csrf_token(token): return jsonify({"error": "CSRF token missing or invalid"}), 403 @app.after_request def ensure_csrf_cookie(response): if request.path.startswith("/static/"): return response token = session.get("csrf_token") if not token: token = generate_csrf_token() response.set_cookie( "csrf_token", token, httponly=False, secure=bool(config.SESSION_COOKIE_SECURE), samesite=config.SESSION_COOKIE_SAMESITE, ) return response # ==================== 静态文件(保持 endpoint 名称不变) ==================== @app.route("/static/") def serve_static(filename): if not is_safe_path("static", filename): return jsonify({"error": "非法路径"}), 403 response = send_from_directory("static", filename) response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0" response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response # ==================== 退出清理 ==================== def cleanup_on_exit(): logger.info("正在清理资源...") logger.info("- 停止运行中的任务...") try: for _, accounts in safe_iter_user_accounts_items(): for acc in accounts.values(): if getattr(acc, "is_running", False): acc.should_stop = True except Exception: pass logger.info("- 停止任务调度器...") try: scheduler = get_task_scheduler() scheduler.shutdown(timeout=5) except Exception: pass logger.info("- 关闭截图线程池...") try: shutdown_browser_worker_pool() except Exception: pass logger.info("- 关闭邮件队列...") try: email_service.shutdown_email_queue() except Exception: pass logger.info("- 关闭数据库连接池...") try: db_pool._pool.close_all() if db_pool._pool else None except Exception: pass logger.info("[OK] 资源清理完成") # ==================== 启动入口(保持 python app.py 可用) ==================== def _signal_handler(sig, frame): logger.info("收到退出信号,正在关闭...") cleanup_on_exit() sys.exit(0) if __name__ == "__main__": atexit.register(cleanup_on_exit) signal.signal(signal.SIGINT, _signal_handler) signal.signal(signal.SIGTERM, _signal_handler) logger.info("=" * 60) logger.info("知识管理平台自动化工具 - 多用户版") logger.info("=" * 60) database.init_database() init_checkpoint_manager() logger.info("[OK] 任务断点管理器已初始化") # 【新增】容器重启时清理遗留的任务状态 logger.info("清理遗留任务状态...") try: from services.state import safe_remove_task, safe_get_active_task_ids, safe_remove_task_status # 重置所有账号的运行状态 for _, accounts in safe_iter_user_accounts_items(): for acc in accounts.values(): if getattr(acc, "is_running", False): acc.is_running = False acc.should_stop = False acc.status = "未开始" # 清理活跃任务句柄 for account_id in list(safe_get_active_task_ids()): safe_remove_task(account_id) safe_remove_task_status(account_id) logger.info("[OK] 遗留任务状态已清理") except Exception as e: logger.warning(f"清理遗留任务状态失败: {e}") try: email_service.init_email_service() logger.info("[OK] 邮件服务已初始化") except Exception as e: logger.warning(f"警告: 邮件服务初始化失败: {e}") start_cleanup_scheduler() start_kdocs_monitor() try: system_config = database.get_system_config() or {} max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL)) max_concurrent_per_account = int(system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)) get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account) logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}") except Exception as e: logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}") logger.info("启动定时任务调度器...") threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start() logger.info("[OK] 定时任务调度器已启动") logger.info("[OK] 状态推送线程已启动(默认2秒/次)") threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start() logger.info("服务器启动中...") logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}") logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx") logger.info("默认管理员: admin (首次运行随机密码见日志)") logger.info("=" * 60) try: pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3)) except Exception: pool_size = 3 try: logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...") init_browser_worker_pool(pool_size=pool_size) logger.info("[OK] 截图线程池初始化完成") except Exception as e: logger.warning(f"警告: 截图线程池初始化失败: {e}") # 预热 API 连接(后台进行,不阻塞启动) logger.info("预热 API 连接...") try: from api_browser import warmup_api_connection import threading threading.Thread( target=warmup_api_connection, kwargs={"log_callback": lambda msg: logger.info(msg)}, daemon=True, name="api-warmup", ).start() except Exception as e: logger.warning(f"API 预热失败: {e}") socketio.run( app, host=config.SERVER_HOST, port=config.SERVER_PORT, debug=config.DEBUG, allow_unsafe_werkzeug=True, )