#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 知识管理平台自动化工具 - 多用户版本 说明(P0–P3 优化后): - `app.py` 仅保留启动/装配层 - 路由拆分到 `routes/`(Blueprint) - 业务拆分到 `services/`(tasks/screenshots/scheduler/...) - SocketIO 事件拆分到 `realtime/` """ from __future__ import annotations import atexit import os import signal import sys import threading import time from flask import Flask, g, jsonify, redirect, request, send_from_directory, session, url_for from flask_login import LoginManager, current_user from flask_socketio import SocketIO import database import db_pool import email_service from app_config import get_config from app_logger import get_logger, init_logging from app_security import generate_csrf_token, is_safe_path, validate_csrf_token from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worker_pool from realtime.socketio_handlers import register_socketio_handlers from realtime.status_push import status_push_worker from routes import register_blueprints from security import init_security_middleware from services.checkpoints import init_checkpoint_manager from services.maintenance import start_cleanup_scheduler, start_database_maintenance_scheduler, start_kdocs_monitor from services.request_metrics import record_request_metric from services.models import User from services.runtime import init_runtime from services.scheduler import scheduled_task_worker from services.state import safe_iter_user_accounts_items from services.tasks import get_task_scheduler # ==================== 进程级基础设置 ==================== # 设置时区为中国标准时间(CST, UTC+8) os.environ["TZ"] = "Asia/Shanghai" try: import time as _time _time.tzset() except Exception: pass def _sigchld_handler(signum, frame): """SIGCHLD 信号处理器 - 自动回收僵尸子进程(Docker PID 1 场景)""" while True: try: pid, _ = os.waitpid(-1, os.WNOHANG) if pid == 0: break except ChildProcessError: break except Exception: break if os.name != "nt": signal.signal(signal.SIGCHLD, _sigchld_handler) # ==================== Flask / SocketIO 装配 
# ==================== Flask / SocketIO wiring ====================
config = get_config()
app = Flask(__name__)
app.config.from_object(config)

# Refuse to start without a SECRET_KEY.
if not app.config.get("SECRET_KEY"):
    raise RuntimeError("SECRET_KEY未配置,请检查 app_config.py 或环境变量")

# Comma-separated allow-list read from the environment; blank entries are
# dropped, so an unset/empty variable yields an empty list.
cors_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "").strip()
cors_allowed = [o.strip() for o in cors_origins.split(",") if o.strip()]

socketio = SocketIO(
    app,
    # None is passed (instead of an empty list) when nothing is configured.
    cors_allowed_origins=cors_allowed or None,
    async_mode="threading",
    ping_timeout=60,
    ping_interval=25,
    logger=False,
    engineio_logger=False,
)

init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
logger = get_logger("app")
init_runtime(socketio=socketio, logger=logger)

# Opt-in diagnostic logging for API requests, controlled by the environment.
_API_DIAGNOSTIC_LOG = str(os.environ.get("API_DIAGNOSTIC_LOG", "0")).strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Threshold in milliseconds; 0 disables the "slow request" check.
_API_DIAGNOSTIC_SLOW_MS = max(0.0, float(os.environ.get("API_DIAGNOSTIC_SLOW_MS", "0") or 0.0))

# Path prefixes that identify JSON API endpoints.
_API_PATH_PREFIXES = ("/api/", "/yuyx/api/")
# HTTP methods that never require a CSRF token.
_CSRF_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
# Login routes are exempt from CSRF checks (logging in is what creates the session).
_CSRF_EXEMPT_PATHS = frozenset({"/yuyx/api/login", "/api/login", "/api/auth/login"})


def _is_api_or_health_path(path: str) -> bool:
    """Return True when *path* is an API endpoint or the ``/health`` probe.

    Args:
        path: Request path; ``None`` or other falsy values are treated as "".
    """
    raw = str(path or "")
    return raw.startswith(_API_PATH_PREFIXES) or raw == "/health"


# Security middleware must be registered before other middleware / Blueprints.
init_security_middleware(app)

# Register Blueprints (routes unchanged).
register_blueprints(app)

# Register SocketIO event handlers.
register_socketio_handlers(socketio)


# ==================== Flask-Login configuration ====================
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "pages.login_page"


@login_manager.user_loader
def load_user(user_id: str):
    """Load the ``User`` for the id stored in the login session.

    Args:
        user_id: User id as stored by Flask-Login (a string).

    Returns:
        A ``User`` wrapping the database id, or ``None`` when the id is not
        numeric or no such user exists.
    """
    try:
        numeric_id = int(user_id)
    except (TypeError, ValueError):
        # A malformed id means "no such user"; it must not raise from here.
        return None
    user = database.get_user_by_id(numeric_id)
    return User(user["id"]) if user else None


@login_manager.unauthorized_handler
def unauthorized():
    """Handle unauthenticated access.

    API paths get a JSON 401 response; every other path is redirected to the
    login page with the original URL passed as ``next``.
    """
    if request.path.startswith(_API_PATH_PREFIXES):
        return jsonify({"error": "请先登录", "code": "unauthorized"}), 401
    return redirect(url_for("pages.login_page", next=request.url))


@app.before_request
def track_request_start_time():
    """Store a high-resolution start timestamp on ``g`` for request metrics."""
    g.request_start_perf = time.perf_counter()


@app.before_request
def enforce_csrf_protection():
    """Reject state-changing requests from logged-in clients without a valid CSRF token.

    Skipped for safe methods, static files, the login routes, and requests
    that are neither an authenticated user nor an admin session.

    Returns:
        ``None`` to let the request proceed, or a JSON 403 response when the
        token is missing or fails ``validate_csrf_token``.
    """
    if request.method in _CSRF_SAFE_METHODS:
        return None
    if request.path.startswith("/static/"):
        return None
    if request.path in _CSRF_EXEMPT_PATHS:
        return None
    if not (current_user.is_authenticated or "admin_id" in session):
        return None

    # The header takes precedence; the form field is the fallback.
    token = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token")
    if not token or not validate_csrf_token(token):
        return jsonify({"error": "CSRF token missing or invalid"}), 403
    return None


def _record_request_metric_after_response(response) -> None:
    """Record timing/status metrics for the current request (best effort).

    Does nothing when no start timestamp was stored on ``g``. When API
    diagnostics are enabled, slow API requests and 5xx API responses are
    additionally logged as warnings. Any exception is swallowed so that
    metrics can never break the response.

    Args:
        response: The outgoing response; only ``status_code`` is read.
    """
    try:
        started = float(getattr(g, "request_start_perf", 0.0) or 0.0)
        if started <= 0:
            return

        duration_ms = max(0.0, (time.perf_counter() - started) * 1000.0)
        path = request.path or "/"
        method = request.method or "GET"
        status_code = int(getattr(response, "status_code", 0) or 0)
        is_api = _is_api_or_health_path(path)

        record_request_metric(
            path=path,
            method=method,
            status_code=status_code,
            duration_ms=duration_ms,
            is_api=is_api,
        )

        if not (_API_DIAGNOSTIC_LOG and is_api):
            return
        is_slow = _API_DIAGNOSTIC_SLOW_MS > 0 and duration_ms >= _API_DIAGNOSTIC_SLOW_MS
        is_server_error = status_code >= 500
        if is_slow or is_server_error:
            logger.warning(
                f"[API-DIAG] {method} {path} -> {status_code} ({duration_ms:.1f}ms)"
            )
    except Exception:
        # Deliberate: metrics are diagnostics only and must not affect responses.
        pass


@app.after_request
def ensure_csrf_cookie(response):
    """Attach the CSRF cookie to non-static responses and record request metrics.

    The cookie value is the session's ``csrf_token`` or, when absent, a token
    from ``generate_csrf_token()``. The cookie is readable by scripts
    (``httponly=False``).

    Args:
        response: The outgoing Flask response.

    Returns:
        The same response object.
    """
    if not request.path.startswith("/static/"):
        token = session.get("csrf_token") or generate_csrf_token()
        response.set_cookie(
            "csrf_token",
            token,
            httponly=False,
            secure=bool(config.SESSION_COOKIE_SECURE),
            samesite=config.SESSION_COOKIE_SAMESITE,
        )

    _record_request_metric_after_response(response)
    return response


# ==================== Static files (endpoint name kept unchanged) ====================
# The `<path:filename>` converter supplies the view's `filename` argument;
# without it the view cannot be called.
# NOTE(review): Flask normally also registers its built-in `static` endpoint on
# this URL pattern - confirm which of the two actually serves requests.
@app.route("/static/<path:filename>")
def serve_static(filename):
    """Serve a file from ``static`` with long-lived cache headers.

    Args:
        filename: Path of the requested file relative to the static directory.

    Returns:
        The file response (or a conditional 304), or a JSON 403 when
        ``is_safe_path`` rejects the path.
    """
    if not is_safe_path("static", filename):
        return jsonify({"error": "非法路径"}), 403

    cache_ttl = 3600
    lowered = filename.lower()
    if "/assets/" in lowered or lowered.endswith((".js", ".css", ".woff", ".woff2", ".ttf", ".svg")):
        cache_ttl = 604800  # 7 days
    if request.args.get("v"):
        # Versioned URLs get at least the 7-day lifetime.
        cache_ttl = max(cache_ttl, 604800)

    response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)

    # Conditional caching: make sure an ETag exists and answer
    # If-None-Match / If-Modified-Since with 304 where possible.
    try:
        response.add_etag(overwrite=False)
    except Exception:
        pass
    try:
        response.make_conditional(request)
    except Exception:
        pass

    response.headers.setdefault("Vary", "Accept-Encoding")
    response.headers["Cache-Control"] = f"public, max-age={cache_ttl}"
    return response


# ==================== Exit cleanup ====================
# Set once cleanup has started so that the signal handler and the atexit hook
# do not both run the full cleanup sequence.
_cleanup_done = False


def cleanup_on_exit():
    """Release runtime resources on shutdown (idempotent, best effort).

    Steps, each isolated so that one failure does not block the others:
    flag running accounts to stop, shut down the task scheduler, the browser
    worker pool and the email queue, and close the database connection pool.
    Subsequent calls return immediately.
    """
    global _cleanup_done
    if _cleanup_done:
        return
    _cleanup_done = True

    logger.info("正在清理资源...")

    logger.info("- 停止运行中的任务...")
    try:
        for _, accounts in safe_iter_user_accounts_items():
            for acc in accounts.values():
                if getattr(acc, "is_running", False):
                    acc.should_stop = True
    except Exception:
        pass

    logger.info("- 停止任务调度器...")
    try:
        get_task_scheduler().shutdown(timeout=5)
    except Exception:
        pass

    logger.info("- 关闭截图线程池...")
    try:
        shutdown_browser_worker_pool()
    except Exception:
        pass

    logger.info("- 关闭邮件队列...")
    try:
        email_service.shutdown_email_queue()
    except Exception:
        pass

    logger.info("- 关闭数据库连接池...")
    try:
        if db_pool._pool:
            db_pool._pool.close_all()
    except Exception:
        pass

    logger.info("[OK] 资源清理完成")


# ==================== Startup entry point (keeps `python app.py` working) ====================
def _signal_handler(sig, frame):
    """SIGINT/SIGTERM handler: run cleanup, then exit with status 0.

    Args:
        sig: Signal number (unused).
        frame: Interrupted stack frame (unused).
    """
    logger.info("收到退出信号,正在关闭...")
    cleanup_on_exit()
    sys.exit(0)


def _cleanup_stale_task_state() -> None:
    """Reset task state left over from a previous run.

    Accounts still flagged as running are reset to the idle state, and every
    active task id is removed together with its status entry. Failures are
    logged as warnings and do not abort startup.
    """
    logger.info("清理遗留任务状态...")
    try:
        from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status

        for _, accounts in safe_iter_user_accounts_items():
            for acc in accounts.values():
                if not getattr(acc, "is_running", False):
                    continue
                acc.is_running = False
                acc.should_stop = False
                acc.status = "未开始"

        # Snapshot the ids first: entries are removed while looping.
        for account_id in list(safe_get_active_task_ids()):
            safe_remove_task(account_id)
            safe_remove_task_status(account_id)

        logger.info("[OK] 遗留任务状态已清理")
    except Exception as e:
        logger.warning(f"清理遗留任务状态失败: {e}")


def _init_optional_email_service() -> None:
    """Initialise the email service; a failure is logged and otherwise ignored."""
    try:
        email_service.init_email_service()
        logger.info("[OK] 邮件服务已初始化")
    except Exception as e:
        logger.warning(f"警告: 邮件服务初始化失败: {e}")


def _load_and_apply_scheduler_limits() -> None:
    """Apply concurrency limits from the system config to the task scheduler.

    Values missing from the stored system config fall back to the defaults on
    ``config``. On any failure a warning is logged and the scheduler keeps its
    current limits.
    """
    try:
        system_config = database.get_system_config() or {}
        max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
        max_concurrent_per_account = int(
            system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)
        )
        get_task_scheduler().update_limits(
            max_global=max_concurrent_global,
            max_per_user=max_concurrent_per_account,
        )
        logger.info(f"[OK] 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
    except Exception as e:
        logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")


def _start_background_workers() -> None:
    """Start the scheduled-task worker and the status-push worker as daemon threads."""
    logger.info("启动定时任务调度器...")
    threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
    logger.info("[OK] 定时任务调度器已启动")

    # Report success only after the thread has actually been started.
    threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
    logger.info("[OK] 状态推送线程已启动(默认2秒/次)")


def _init_screenshot_worker_pool() -> None:
    """Initialise the browser worker pool used for screenshots.

    The pool size comes from ``max_screenshot_concurrent`` in the system
    config and falls back to 3 when it cannot be read. An initialisation
    failure is logged as a warning and does not abort startup.
    """
    try:
        pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", 3))
    except Exception:
        pool_size = 3

    try:
        logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
        init_browser_worker_pool(pool_size=pool_size)
        logger.info("[OK] 截图线程池初始化完成")
    except Exception as e:
        logger.warning(f"警告: 截图线程池初始化失败: {e}")


def _warmup_api_connection() -> None:
    """Kick off API connection warm-up in a daemon thread.

    ``api_browser`` is imported lazily so that an import failure only produces
    a warning instead of preventing startup.
    """
    logger.info("预热 API 连接...")
    try:
        from api_browser import warmup_api_connection

        threading.Thread(
            target=warmup_api_connection,
            kwargs={"log_callback": logger.info},
            daemon=True,
            name="api-warmup",
        ).start()
    except Exception as e:
        logger.warning(f"API 预热失败: {e}")


def _log_startup_urls() -> None:
    """Log the user and admin URLs plus the default administrator hint."""
    logger.info("服务器启动中...")
    logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
    logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
    logger.info("默认管理员: admin (首次运行随机密码见日志)")
    logger.info("=" * 60)


if __name__ == "__main__":
    # cleanup_on_exit is idempotent, so registering it with atexit and also
    # calling it from the signal handler is safe.
    atexit.register(cleanup_on_exit)
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    logger.info("=" * 60)
    logger.info("知识管理平台自动化工具 - 多用户版")
    logger.info("=" * 60)

    database.init_database()
    init_checkpoint_manager()
    logger.info("[OK] 任务断点管理器已初始化")

    _cleanup_stale_task_state()
    _init_optional_email_service()

    start_cleanup_scheduler()
    start_database_maintenance_scheduler()
    start_kdocs_monitor()

    _load_and_apply_scheduler_limits()
    _start_background_workers()
    _log_startup_urls()
    _init_screenshot_worker_pool()
    _warmup_api_connection()

    socketio.run(
        app,
        host=config.SERVER_HOST,
        port=config.SERVER_PORT,
        debug=config.DEBUG,
        allow_unsafe_werkzeug=True,
    )