326 lines
10 KiB
Python
326 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
知识管理平台自动化工具 - 多用户版本
|
||
|
||
说明(P0–P3 优化后):
|
||
- `app.py` 仅保留启动/装配层
|
||
- 路由拆分到 `routes/`(Blueprint)
|
||
- 业务拆分到 `services/`(tasks/screenshots/scheduler/...)
|
||
- SocketIO 事件拆分到 `realtime/`
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import atexit
|
||
import os
|
||
import signal
|
||
import sys
|
||
import threading
|
||
|
||
from flask import Flask, jsonify, redirect, request, send_from_directory, session, url_for
|
||
from flask_login import LoginManager, current_user
|
||
from flask_socketio import SocketIO
|
||
|
||
import database
|
||
import db_pool
|
||
import email_service
|
||
from app_config import get_config
|
||
from app_logger import get_logger, init_logging
|
||
from app_security import generate_csrf_token, is_safe_path, validate_csrf_token
|
||
from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worker_pool
|
||
from realtime.socketio_handlers import register_socketio_handlers
|
||
from realtime.status_push import status_push_worker
|
||
from routes import register_blueprints
|
||
from security import init_security_middleware
|
||
from services.checkpoints import init_checkpoint_manager
|
||
from services.maintenance import start_cleanup_scheduler
|
||
from services.models import User
|
||
from services.runtime import init_runtime
|
||
from services.scheduler import scheduled_task_worker
|
||
from services.state import safe_iter_user_accounts_items
|
||
from services.tasks import get_task_scheduler
|
||
|
||
|
||
# ==================== Process-level setup ====================

# Run the whole process on China Standard Time (CST, UTC+8).
os.environ["TZ"] = "Asia/Shanghai"

import time as _time

# time.tzset() makes the C runtime re-read TZ. It is only available on Unix;
# elsewhere there is nothing to call, which is the one failure this setup
# expects. Any other error is no longer silently swallowed.
if hasattr(_time, "tzset"):
    _time.tzset()
|
||
|
||
|
||
def _sigchld_handler(signum, frame):
    """Reap every already-exited child process without blocking.

    Installed as the SIGCHLD handler so that, when this process runs as PID 1
    inside a Docker container, finished children do not linger as zombies.
    The loop stops as soon as no exited child remains (``waitpid`` returns
    pid 0) or ``waitpid`` raises (e.g. ChildProcessError when there are no
    children at all).

    NOTE(review): ``waitpid(-1, ...)`` reaps *any* child, including ones that
    subprocess/asyncio code later tries to ``wait()`` on itself; confirm this
    does not hide exit codes of browser/worker subprocesses.
    """
    while True:
        try:
            reaped_pid, _status = os.waitpid(-1, os.WNOHANG)
        except Exception:  # ChildProcessError (no children) or anything unexpected
            return
        if reaped_pid == 0:  # children exist, but none has exited yet
            return


# SIGCHLD does not exist on Windows, so only register the handler elsewhere.
if os.name != "nt":
    signal.signal(signal.SIGCHLD, _sigchld_handler)
|
||
|
||
|
||
# ==================== Flask / SocketIO assembly ====================

config = get_config()

# Build the Flask app from the project configuration object.
app = Flask(__name__)
app.config.from_object(config)

# Fail fast: the session (which also holds the CSRF token) needs a signing key.
if not app.config.get("SECRET_KEY"):
    raise RuntimeError("SECRET_KEY未配置,请检查 app_config.py 或环境变量")

# CORS_ALLOWED_ORIGINS is a comma-separated origin list. Surrounding whitespace
# and empty entries are dropped, so an unset or blank variable yields [].
cors_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "").strip()
cors_allowed = list(filter(None, map(str.strip, cors_origins.split(",")))) if cors_origins else []

# An empty list is handed over as None so Socket.IO applies its default policy.
# NOTE(review): in python-engineio, None selects the default (same-origin)
# policy while [] disables CORS checking entirely; confirm for the installed version.
socketio = SocketIO(
    app,
    async_mode="threading",
    cors_allowed_origins=cors_allowed or None,
    ping_interval=25,
    ping_timeout=60,
    engineio_logger=False,
    logger=False,
)
|
||
|
||
# Configure logging, then hand the shared Socket.IO server and the app logger
# to the runtime service.
init_logging(log_file=config.LOG_FILE, log_level=config.LOG_LEVEL)
logger = get_logger("app")
init_runtime(logger=logger, socketio=socketio)

# The security middleware must be installed before any other middleware or
# blueprint, so it is registered first.
init_security_middleware(app)

# HTTP routes live in blueprints under routes/ (URLs unchanged).
register_blueprints(app)

# Socket.IO event handlers live under realtime/.
register_socketio_handlers(socketio)
|
||
|
||
|
||
# ==================== Flask-Login configuration ====================

# Passing the app to the constructor performs init_app(app) right away.
login_manager = LoginManager(app)
# Endpoint name of the login page.
login_manager.login_view = "pages.login_page"
|
||
|
||
|
||
@login_manager.user_loader
def load_user(user_id: str):
    """Reload the logged-in user from the ID stored in the session.

    Flask-Login requires this callback to return ``None`` (not raise) when an
    ID cannot be resolved, so a non-numeric or otherwise malformed ID is
    treated exactly like an unknown user.

    Args:
        user_id: The session-stored user ID, as passed by Flask-Login.

    Returns:
        A ``User`` built from the database row's ``id``, or ``None`` if the ID
        is malformed or no such user exists.
    """
    try:
        numeric_id = int(user_id)
    except (TypeError, ValueError):
        # Malformed/stale session value: behave as "not logged in" instead of a 500.
        return None
    user = database.get_user_by_id(numeric_id)
    if not user:
        return None
    return User(user["id"])
|
||
|
||
|
||
@login_manager.unauthorized_handler
def unauthorized():
    """Answer a request that needs a login but has none.

    API paths (``/api/...`` and ``/yuyx/api/...``) get a 401 JSON error; any
    other path is redirected to the login page, carrying the requested URL
    in ``next``.
    """
    is_api_request = request.path.startswith(("/api/", "/yuyx/api/"))
    if not is_api_request:
        return redirect(url_for("pages.login_page", next=request.url))
    return jsonify({"error": "请先登录", "code": "unauthorized"}), 401
|
||
|
||
|
||
@app.before_request
def enforce_csrf_protection():
    """Reject state-changing requests from logged-in clients without a valid CSRF token.

    The check is skipped for GET/HEAD/OPTIONS, for anything under ``/static/``,
    and for anonymous requests (no authenticated Flask-Login user and no
    ``admin_id`` in the session). The token is taken from the ``X-CSRF-Token``
    header, falling back to the ``csrf_token`` form field.

    Returns:
        ``None`` to let the request continue, or a 403 JSON response.
    """
    exempt = (
        request.method in {"GET", "HEAD", "OPTIONS"}
        or request.path.startswith("/static/")
        or not (current_user.is_authenticated or "admin_id" in session)
    )
    if exempt:
        return None

    submitted = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token")
    if submitted and validate_csrf_token(submitted):
        return None
    return jsonify({"error": "CSRF token missing or invalid"}), 403
|
||
|
||
|
||
@app.after_request
def ensure_csrf_cookie(response):
    """Attach the CSRF token as a ``csrf_token`` cookie to every non-static response.

    Reuses the token already stored in the session, otherwise obtains one from
    ``generate_csrf_token()``. The cookie is deliberately not HttpOnly,
    presumably so front-end JS can copy it into the ``X-CSRF-Token`` header;
    its Secure/SameSite flags follow the session cookie settings.
    """
    if request.path.startswith("/static/"):
        return response

    csrf_value = session.get("csrf_token") or generate_csrf_token()
    response.set_cookie(
        "csrf_token",
        csrf_value,
        secure=bool(config.SESSION_COOKIE_SECURE),
        samesite=config.SESSION_COOKIE_SAMESITE,
        httponly=False,
    )
    return response
|
||
|
||
|
||
# ==================== Static files (endpoint name kept unchanged) ====================


@app.route("/static/<path:filename>")
def serve_static(filename):
    """Serve a file from the ``static`` directory with client caching disabled.

    Paths rejected by ``is_safe_path`` receive a 403 JSON error instead.
    """
    if is_safe_path("static", filename):
        response = send_from_directory("static", filename)
        # Make browsers and proxies re-fetch the asset on every request.
        for header_name, header_value in (
            ("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0"),
            ("Pragma", "no-cache"),
            ("Expires", "0"),
        ):
            response.headers[header_name] = header_value
        return response
    return jsonify({"error": "非法路径"}), 403
|
||
|
||
|
||
# ==================== Exit cleanup ====================

# Becomes True the first time cleanup_on_exit() runs. On SIGINT/SIGTERM the
# signal handler calls cleanup_on_exit() and then sys.exit(), which also fires
# the atexit registration, so every call after the first must be a no-op.
_cleanup_started = False


def _run_cleanup_step(description, action):
    """Log *description*, run *action*, and log (never raise) any failure.

    Each shutdown step is best-effort so that one failing component cannot
    prevent the remaining resources from being released.
    """
    logger.info(description)
    try:
        action()
    except Exception as e:
        logger.warning(f"清理步骤失败 ({description.strip()}): {e}")


def _stop_running_accounts():
    """Set ``should_stop`` on every account whose ``is_running`` flag is set."""
    for _, accounts in safe_iter_user_accounts_items():
        for acc in accounts.values():
            if getattr(acc, "is_running", False):
                acc.should_stop = True


def _close_db_pool():
    """Close all pooled database connections, if the pool was ever created."""
    pool = db_pool._pool
    if pool:
        pool.close_all()


def cleanup_on_exit():
    """Release process-wide resources before the process exits.

    Steps, in order:
    1. Ask running account tasks to stop.
    2. Shut down the task scheduler (``timeout=5``).
    3. Shut down the screenshot worker pool.
    4. Shut down the email queue.
    5. Close the database connection pool.

    Every step is independent and best-effort; failures are logged as warnings.
    Only the first call does any work; later calls return immediately.
    """
    global _cleanup_started
    if _cleanup_started:
        return
    _cleanup_started = True

    logger.info("正在清理资源...")
    _run_cleanup_step("- 停止运行中的任务...", _stop_running_accounts)
    _run_cleanup_step("- 停止任务调度器...", lambda: get_task_scheduler().shutdown(timeout=5))
    _run_cleanup_step("- 关闭截图线程池...", shutdown_browser_worker_pool)
    # Attribute lookup stays inside the step so a missing function is logged, not raised.
    _run_cleanup_step("- 关闭邮件队列...", lambda: email_service.shutdown_email_queue())
    _run_cleanup_step("- 关闭数据库连接池...", _close_db_pool)
    logger.info("✓ 资源清理完成")
|
||
|
||
|
||
# ==================== Entry point (keeps `python app.py` working) ====================


def _signal_handler(sig, frame):
    """SIGINT/SIGTERM handler: release resources, then exit with status 0.

    NOTE(review): cleanup_on_exit is also registered with atexit in the
    __main__ block, and the SystemExit raised here triggers atexit callbacks,
    so cleanup can be invoked twice unless it guards against re-entry.
    """
    logger.info("收到退出信号,正在关闭...")
    cleanup_on_exit()
    raise SystemExit(0)
|
||
|
||
|
||
def _reset_stale_task_state():
    """Clear task bookkeeping left over from a previous run (e.g. a container restart).

    Accounts still flagged ``is_running`` are reset to not running / not
    stopping / status "未开始". Every active task handle and its status entry
    is then removed. Failures are logged as warnings, never raised.
    """
    logger.info("清理遗留任务状态...")
    try:
        from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status

        for _, accounts in safe_iter_user_accounts_items():
            for acc in accounts.values():
                if getattr(acc, "is_running", False):
                    acc.is_running = False
                    acc.should_stop = False
                    acc.status = "未开始"

        # Snapshot the IDs first; the removals presumably mutate the underlying collection.
        for account_id in list(safe_get_active_task_ids()):
            safe_remove_task(account_id)
            safe_remove_task_status(account_id)
        logger.info("✓ 遗留任务状态已清理")
    except Exception as e:
        logger.warning(f"清理遗留任务状态失败: {e}")


def _load_concurrency_limits():
    """Apply the persisted global/per-account concurrency limits to the task scheduler.

    Values from ``database.get_system_config()`` override the ``config``
    defaults. On any failure the scheduler keeps its current limits and a
    warning is logged.
    """
    try:
        system_config = database.get_system_config() or {}
        max_concurrent_global = int(system_config.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
        max_concurrent_per_account = int(
            system_config.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)
        )
        # NOTE(review): the per-account limit is passed as ``max_per_user``; confirm the scheduler's naming.
        get_task_scheduler().update_limits(max_global=max_concurrent_global, max_per_user=max_concurrent_per_account)
        logger.info(f"✓ 已加载并发配置: 全局={max_concurrent_global}, 单账号={max_concurrent_per_account}")
    except Exception as e:
        logger.warning(f"警告: 加载并发配置失败,使用默认值: {e}")


def _init_screenshot_pool(default_size=3):
    """Start the screenshot worker pool sized from system config (``default_size`` on any read error)."""
    try:
        pool_size = int((database.get_system_config() or {}).get("max_screenshot_concurrent", default_size))
    except Exception:
        pool_size = default_size
    try:
        logger.info(f"初始化截图线程池({pool_size}个worker,按需启动执行环境,空闲5分钟后自动释放)...")
        init_browser_worker_pool(pool_size=pool_size)
        logger.info("✓ 截图线程池初始化完成")
    except Exception as e:
        logger.warning(f"警告: 截图线程池初始化失败: {e}")


def _start_api_warmup():
    """Warm up the API connection on a daemon thread so startup is not blocked."""
    logger.info("预热 API 连接...")
    try:
        from api_browser import warmup_api_connection

        threading.Thread(
            target=warmup_api_connection,
            kwargs={"log_callback": logger.info},
            daemon=True,
            name="api-warmup",
        ).start()
    except Exception as e:
        logger.warning(f"API 预热失败: {e}")


def main():
    """Run the platform as a standalone server (``python app.py``).

    The sequence is:
    1. Register exit cleanup and the SIGINT/SIGTERM handlers.
    2. Initialise the database, checkpoint manager and email service, and
       clear leftover task state.
    3. Start the background schedulers/workers and apply concurrency limits.
    4. Start the screenshot pool and API warmup.
    5. Serve the app via Socket.IO (blocks until shutdown).
    """
    atexit.register(cleanup_on_exit)
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    logger.info("=" * 60)
    logger.info("知识管理平台自动化工具 - 多用户版")
    logger.info("=" * 60)

    database.init_database()
    init_checkpoint_manager()
    logger.info("✓ 任务断点管理器已初始化")

    _reset_stale_task_state()

    try:
        email_service.init_email_service()
        logger.info("✓ 邮件服务已初始化")
    except Exception as e:
        logger.warning(f"警告: 邮件服务初始化失败: {e}")

    start_cleanup_scheduler()
    _load_concurrency_limits()

    logger.info("启动定时任务调度器...")
    threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
    logger.info("✓ 定时任务调度器已启动")

    # Log only after the thread has actually been started.
    threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
    logger.info("✓ 状态推送线程已启动(默认2秒/次)")

    logger.info("服务器启动中...")
    logger.info(f"用户访问地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}")
    logger.info(f"后台管理地址: http://{config.SERVER_HOST}:{config.SERVER_PORT}/yuyx")
    logger.info("默认管理员: admin (首次运行随机密码见日志)")
    logger.info("=" * 60)

    _init_screenshot_pool()
    _start_api_warmup()

    # NOTE(review): allow_unsafe_werkzeug=True runs on Werkzeug's development
    # server; with config.DEBUG enabled the reloader may also re-execute this
    # startup sequence in a child process. Confirm both are intended in production.
    socketio.run(
        app,
        host=config.SERVER_HOST,
        port=config.SERVER_PORT,
        debug=config.DEBUG,
        allow_unsafe_werkzeug=True,
    )


if __name__ == "__main__":
    main()
|