Files
zsglpt/app.py

509 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
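Knowledge-management platform automation tool - multi-user edition.

Notes (after the P0-P3 refactor):
- `app.py` keeps only the startup / wiring layer
- Routes are split into `routes/` (Blueprints)
- Business logic is split into `services/` (tasks/screenshots/scheduler/...)
- SocketIO events are split into `realtime/`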
"""
from __future__ import annotations
import atexit
import os
import re
import signal
import sys
import threading
import time
from flask import Flask, g, jsonify, redirect, request, send_from_directory, session, url_for
from flask_login import LoginManager, current_user
from flask_socketio import SocketIO
import database
import db_pool
import email_service
from app_config import get_config
from app_logger import get_logger, init_logging
from app_security import generate_csrf_token, is_safe_path, validate_csrf_token
from browser_pool_worker import init_browser_worker_pool, shutdown_browser_worker_pool
from realtime.socketio_handlers import register_socketio_handlers
from realtime.status_push import status_push_worker
from routes import register_blueprints
from security import init_security_middleware
from services.checkpoints import init_checkpoint_manager
from services.maintenance import start_cleanup_scheduler, start_database_maintenance_scheduler, start_kdocs_monitor
from services.request_metrics import record_request_metric
from services.models import User
from services.runtime import init_runtime
from services.scheduler import scheduled_task_worker
from services.state import safe_iter_user_accounts_items
from services.tasks import get_task_scheduler
# ==================== Process-level setup ====================
# Pin the process time zone to China Standard Time (CST, UTC+8).
# Any failure is stored in _TZSET_ERROR instead of being raised, so it can be
# reported once logging has been configured further down in this module.
_TZSET_ERROR = None
os.environ["TZ"] = "Asia/Shanghai"
try:
    import time as _time

    # Re-read TZ from the environment (time.tzset is not available on every platform).
    _time.tzset()
except Exception as tz_exc:
    _TZSET_ERROR = tz_exc
def _sigchld_handler(signum, frame):
    """SIGCHLD handler: reap exited child processes (Docker PID 1 scenario).

    Calls ``os.waitpid(-1, os.WNOHANG)`` repeatedly until it reports that no
    further child has exited (pid 0) or raises (for example
    ``ChildProcessError`` when there are no children at all).

    Args:
        signum: Signal number passed by the interpreter (unused).
        frame: Interrupted stack frame passed by the interpreter (unused).
    """
    reaped_pid = -1
    while reaped_pid != 0:
        try:
            reaped_pid, _status = os.waitpid(-1, os.WNOHANG)
        except Exception:
            # ChildProcessError (no children left) or anything else: stop reaping.
            return


# Install the reaper everywhere except on Windows (os.name == "nt").
if os.name != "nt":
    signal.signal(signal.SIGCHLD, _sigchld_handler)
# ==================== Flask / SocketIO wiring ====================
config = get_config()
app = Flask(__name__)
app.config.from_object(config)
if not app.config.get("SECRET_KEY"):
    raise RuntimeError("SECRET_KEY未配置请检查 app_config.py 或环境变量")

# Comma-separated origin list; an empty/unset variable yields an empty list.
cors_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "").strip()
cors_allowed = [origin.strip() for origin in cors_origins.split(",") if origin.strip()]

# "" or "auto" means "let Flask-SocketIO choose"; the default request is eventlet.
_socketio_preferred_mode = (os.environ.get("SOCKETIO_ASYNC_MODE", "eventlet") or "").strip().lower()
if _socketio_preferred_mode in {"", "auto"}:
    _socketio_preferred_mode = None

# Options shared by the preferred attempt and the threading fallback.
_socketio_options = dict(
    cors_allowed_origins=cors_allowed or None,
    ping_timeout=60,
    ping_interval=25,
    logger=False,
    engineio_logger=False,
)
_socketio_fallback_reason = None
try:
    socketio = SocketIO(app, async_mode=_socketio_preferred_mode, **_socketio_options)
except Exception as socketio_error:
    # The preferred async mode could not be initialised: remember why and
    # retry with the "threading" mode.
    _socketio_fallback_reason = str(socketio_error)
    socketio = SocketIO(app, async_mode="threading", **_socketio_options)

init_logging(log_level=config.LOG_LEVEL, log_file=config.LOG_FILE)
logger = get_logger("app")

# Report problems that occurred before the logger existed.
if _TZSET_ERROR is not None:
    logger.warning(f"设置时区失败,将继续使用系统默认时区: {_TZSET_ERROR}")
if _socketio_fallback_reason:
    logger.warning(f"[SocketIO] 初始化失败,已回退 threading 模式: {_socketio_fallback_reason}")
logger.info(f"[SocketIO] 当前 async_mode: {socketio.async_mode}")

init_runtime(socketio=socketio, logger=logger)
# ==================== API diagnostics (opt-in through the environment) ====================
# API_DIAGNOSTIC_LOG switches on warning logs for slow / 5xx API requests.
_API_DIAGNOSTIC_LOG = str(os.environ.get("API_DIAGNOSTIC_LOG", "0")).strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}


def _parse_api_diagnostic_slow_ms(raw):
    """Parse the slow-request threshold (milliseconds) from its raw env value.

    Args:
        raw: Raw string from the environment; ``None`` or ``""`` count as 0.

    Returns:
        The threshold clamped to be non-negative, or ``None`` when ``raw``
        cannot be parsed as a number.
    """
    try:
        return max(0.0, float(raw or 0.0))
    except (TypeError, ValueError):
        return None


# A malformed diagnostic setting must not prevent the application from
# starting: fall back to 0 (slow-request reporting disabled) and warn.
_api_diagnostic_slow_ms_raw = os.environ.get("API_DIAGNOSTIC_SLOW_MS", "0")
_API_DIAGNOSTIC_SLOW_MS = _parse_api_diagnostic_slow_ms(_api_diagnostic_slow_ms_raw)
if _API_DIAGNOSTIC_SLOW_MS is None:
    logger.warning(f"API_DIAGNOSTIC_SLOW_MS 配置无效({_api_diagnostic_slow_ms_raw!r}),已禁用慢请求诊断")
    _API_DIAGNOSTIC_SLOW_MS = 0.0
def _is_api_or_health_path(path: str) -> bool:
    """Return True for API paths and the health-check path.

    Args:
        path: Request path; falsy values are treated as an empty string.

    Returns:
        True when the path is exactly ``/health`` or starts with ``/api/`` or
        ``/yuyx/api/``; False otherwise.
    """
    normalized = str(path or "")
    if normalized == "/health":
        return True
    return normalized.startswith(("/api/", "/yuyx/api/"))
def _request_uses_https() -> bool:
    """Tell whether the current request arrived over HTTPS.

    Checks ``request.is_secure`` first, then the first entry of the
    ``X-Forwarded-Proto`` header. A failure in either check is logged at debug
    level and treated as "not HTTPS" for that check.

    Returns:
        True if either check indicates HTTPS, otherwise False.
    """
    try:
        direct_tls = bool(request.is_secure)
    except Exception as exc:
        logger.debug(f"检查 request.is_secure 失败: {exc}")
        direct_tls = False
    if direct_tls:
        return True

    try:
        header_value = str(request.headers.get("X-Forwarded-Proto", "") or "")
        # The header may hold a comma-separated list; only the first entry counts.
        first_entry = header_value.partition(",")[0]
        return first_entry.strip().lower() == "https"
    except Exception as exc:
        logger.debug(f"检查 X-Forwarded-Proto 失败: {exc}")
        return False
# Default security headers; applied with setdefault() on every response.
_SECURITY_RESPONSE_HEADERS = dict(
    [
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "SAMEORIGIN"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()"),
    ]
)
# Optional Content-Security-Policy value; an empty string means "send no CSP header".
_SECURITY_CSP_HEADER = (os.environ.get("SECURITY_CONTENT_SECURITY_POLICY") or "").strip()
# File names ending in "-<8+ chars of [a-z0-9_-]>.<ext>" are treated as content-hashed assets.
_HASHED_STATIC_ASSET_RE = re.compile(
    r".*-[a-z0-9_-]{8,}\.(?:js|css|woff2?|ttf|svg|png|jpe?g|webp)$",
    flags=re.IGNORECASE,
)
# Security middleware first: it has to be registered before any other
# middleware / Blueprint.
init_security_middleware(app)

# Blueprints (routes are unchanged).
register_blueprints(app)

# SocketIO event handlers.
register_socketio_handlers(socketio)

# ==================== Flask-Login configuration ====================
login_manager = LoginManager()
# Endpoint that unauthenticated page requests are sent to.
login_manager.login_view = "pages.login_page"
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str):
    """Flask-Login user loader: map a session user id to a ``User`` object.

    Args:
        user_id: User id as stored in the session by Flask-Login.

    Returns:
        ``User`` built from the database row's ``id``, or ``None`` when the id
        is not a valid integer or no matching row is found.
    """
    # A user loader must return None for an invalid id rather than raise;
    # a tampered or stale session value would otherwise produce a 500.
    try:
        numeric_id = int(user_id)
    except (TypeError, ValueError):
        return None

    user = database.get_user_by_id(numeric_id)
    if user:
        return User(user["id"])
    return None
@login_manager.unauthorized_handler
def unauthorized():
    """Handle unauthenticated access.

    Returns:
        A JSON error with status 401 for paths under ``/api/`` or
        ``/yuyx/api/``; otherwise a redirect to the login page carrying the
        current URL in the ``next`` parameter.
    """
    wants_json = request.path.startswith(("/api/", "/yuyx/api/"))
    if not wants_json:
        return redirect(url_for("pages.login_page", next=request.url))
    return jsonify({"error": "请先登录", "code": "unauthorized"}), 401
@app.before_request
def track_request_start_time():
    """Store a high-resolution start timestamp on ``g`` for request metrics."""
    started_at = time.perf_counter()
    g.request_start_perf = started_at
@app.before_request
def enforce_csrf_protection():
    """Reject state-changing requests from logged-in sessions without a valid CSRF token.

    The check is skipped for GET/HEAD/OPTIONS, for ``/static/`` paths, for the
    login endpoints listed below, and for requests that have neither an
    authenticated user nor an ``admin_id`` in the session.

    Returns:
        ``None`` to let the request proceed, or a JSON error with status 403
        when the token is missing or fails validation.
    """
    safe_method = request.method in {"GET", "HEAD", "OPTIONS"}
    if safe_method or request.path.startswith("/static/"):
        return None

    # Login-related routes are exempt: logging in is what establishes the session.
    csrf_exempt_paths = {
        "/yuyx/api/login",
        "/api/login",
        "/api/auth/login",
        "/yuyx/api/passkeys/login/options",
        "/yuyx/api/passkeys/login/verify",
        "/api/passkeys/login/options",
        "/api/passkeys/login/verify",
    }
    if request.path in csrf_exempt_paths:
        return None

    if not current_user.is_authenticated and "admin_id" not in session:
        return None

    # The token may come from the header or from a form field.
    submitted = request.headers.get("X-CSRF-Token") or request.form.get("csrf_token")
    if submitted and validate_csrf_token(submitted):
        return None
    return jsonify({"error": "CSRF token missing or invalid"}), 403
def _record_request_metric_after_response(response) -> None:
    """Record timing/status metrics for the current request (best effort).

    Does nothing when no start timestamp was stored on ``g``. When API
    diagnostics are enabled, also logs a warning for API/health requests that
    are slow or ended with a 5xx status. Any exception is logged at debug
    level and swallowed.

    Args:
        response: The outgoing response; only ``status_code`` is read.
    """
    try:
        begun = float(getattr(g, "request_start_perf", 0.0) or 0.0)
        if begun <= 0:
            return

        elapsed_ms = max(0.0, (time.perf_counter() - begun) * 1000.0)
        req_path = request.path or "/"
        req_method = request.method or "GET"
        status = int(getattr(response, "status_code", 0) or 0)
        api_request = _is_api_or_health_path(req_path)

        record_request_metric(
            path=req_path,
            method=req_method,
            status_code=status,
            duration_ms=elapsed_ms,
            is_api=api_request,
        )

        if not (_API_DIAGNOSTIC_LOG and api_request):
            return
        # A threshold of 0 disables the slow-request check.
        too_slow = _API_DIAGNOSTIC_SLOW_MS > 0 and elapsed_ms >= _API_DIAGNOSTIC_SLOW_MS
        if too_slow or status >= 500:
            logger.warning(
                f"[API-DIAG] {req_method} {req_path} -> {status} ({elapsed_ms:.1f}ms)"
            )
    except Exception as exc:
        logger.debug(f"记录请求指标失败: {exc}")
@app.after_request
def ensure_csrf_cookie(response):
    """Finalize every response: CSRF cookie, security headers, request metrics.

    For non-static paths the CSRF token (taken from the session, or freshly
    generated when absent) is written to a script-readable ``csrf_token``
    cookie. Security headers are added with ``setdefault`` so values already
    present on the response are kept.

    Args:
        response: The outgoing response.

    Returns:
        The same response object.
    """
    serving_static = request.path.startswith("/static/")
    if not serving_static:
        csrf_value = session.get("csrf_token") or generate_csrf_token()
        response.set_cookie(
            "csrf_token",
            csrf_value,
            # Must be readable by page scripts, hence not HttpOnly.
            httponly=False,
            secure=bool(config.SESSION_COOKIE_SECURE),
            samesite=config.SESSION_COOKIE_SAMESITE,
        )

    headers = response.headers
    for name in _SECURITY_RESPONSE_HEADERS:
        headers.setdefault(name, _SECURITY_RESPONSE_HEADERS[name])
    # HSTS is only sent when the request itself came in over HTTPS.
    if _request_uses_https():
        headers.setdefault("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
    if _SECURITY_CSP_HEADER:
        headers.setdefault("Content-Security-Policy", _SECURITY_CSP_HEADER)

    _record_request_metric_after_response(response)
    return response
# ==================== Static files (endpoint name kept unchanged) ====================
@app.route("/static/<path:filename>")
def serve_static(filename):
    """Serve a file from ``static/`` with cache headers chosen by file type.

    Cache lifetime: 365 days (``immutable``) for names matching the
    hashed-asset pattern, 7 days for asset files (under ``/assets/`` or with a
    js/css/font/svg suffix), 1 hour otherwise; a ``v`` query parameter raises
    the lifetime to at least 7 days.

    Args:
        filename: Path below the static directory, taken from the URL.

    Returns:
        The file response, or a JSON error with status 403 when the path is
        rejected by ``is_safe_path``.
    """
    if not is_safe_path("static", filename):
        return jsonify({"error": "非法路径"}), 403

    name_lc = filename.lower()
    asset_suffixes = (".js", ".css", ".woff", ".woff2", ".ttf", ".svg")
    is_asset_file = "/assets/" in name_lc or name_lc.endswith(asset_suffixes)
    is_hashed_asset = _HASHED_STATIC_ASSET_RE.match(name_lc) is not None

    one_week = 604800  # 7 days
    if is_hashed_asset:
        cache_ttl, extra_directive = 31536000, ", immutable"  # 365 days
    elif is_asset_file:
        cache_ttl, extra_directive = one_week, ", stale-while-revalidate=60"
    else:
        cache_ttl, extra_directive = 3600, ""
    # An explicit version query string guarantees at least one week.
    if request.args.get("v") and cache_ttl < one_week:
        cache_ttl = one_week

    response = send_from_directory("static", filename, max_age=cache_ttl, conditional=True)

    # Conditional requests: make sure an ETag exists, then let the response
    # answer If-None-Match / If-Modified-Since with 304. Both steps are best effort.
    try:
        response.add_etag(overwrite=False)
    except Exception as exc:
        logger.debug(f"静态资源 ETag 设置失败({filename}): {exc}")
    try:
        response.make_conditional(request)
    except Exception as exc:
        logger.debug(f"静态资源协商缓存处理失败({filename}): {exc}")

    response.headers.setdefault("Vary", "Accept-Encoding")
    response.headers["Cache-Control"] = f"public, max-age={cache_ttl}{extra_directive}"
    return response
# ==================== Exit cleanup ====================
# One-shot latch: taken (and never released) by the first cleanup_on_exit() call.
_cleanup_once_latch = threading.Lock()


def cleanup_on_exit():
    """Release runtime resources on shutdown; later calls are no-ops.

    This function is reachable twice during one shutdown: the signal handler
    calls it and then exits, which makes the ``atexit`` hook call it again. A
    second signal arriving mid-cleanup would also re-enter it. The latch is
    acquired without blocking so a repeated or re-entrant call returns
    immediately instead of repeating the shutdown steps or deadlocking.

    Every step is best effort: a failure is logged as a warning and the
    remaining steps still run.
    """
    if not _cleanup_once_latch.acquire(blocking=False):
        return

    def _flag_running_accounts_to_stop():
        # Ask every running account task to stop.
        for _, accounts in safe_iter_user_accounts_items():
            for acc in accounts.values():
                if getattr(acc, "is_running", False):
                    acc.should_stop = True

    def _shutdown_task_scheduler():
        get_task_scheduler().shutdown(timeout=5)

    def _close_db_pool():
        # The pool object may not have been created.
        pool = db_pool._pool
        if pool:
            pool.close_all()

    # (progress message, failure-message prefix, action), in execution order.
    steps = (
        ("- 停止运行中的任务...", "停止运行中任务失败", _flag_running_accounts_to_stop),
        ("- 停止任务调度器...", "停止任务调度器失败", _shutdown_task_scheduler),
        ("- 关闭截图线程池...", "关闭截图线程池失败", shutdown_browser_worker_pool),
        ("- 关闭邮件队列...", "关闭邮件队列失败", email_service.shutdown_email_queue),
        ("- 关闭数据库连接池...", "关闭数据库连接池失败", _close_db_pool),
    )

    logger.info("正在清理资源...")
    for progress_message, failure_prefix, action in steps:
        logger.info(progress_message)
        try:
            action()
        except Exception as exc:
            logger.warning(f"{failure_prefix}: {exc}")
    logger.info("[OK] 资源清理完成")
# ==================== Startup entry (keeps `python app.py` working) ====================
def _signal_handler(sig, frame):
    """Handle a termination signal: clean up, then exit with status 0.

    Args:
        sig: Signal number passed by the interpreter (unused).
        frame: Interrupted stack frame passed by the interpreter (unused).

    Raises:
        SystemExit: Always, with exit code 0, after cleanup has run.
    """
    logger.info("收到退出信号,正在关闭...")
    cleanup_on_exit()
    raise SystemExit(0)
def _cleanup_stale_task_state() -> None:
    """Reset task state left over from a previous run (best effort).

    Accounts still flagged as running are reset to idle, and every active
    task id is removed together with its status entry. Any exception is
    logged as a warning and swallowed.
    """
    logger.info("清理遗留任务状态...")
    try:
        from services.state import safe_get_active_task_ids, safe_remove_task, safe_remove_task_status

        for _, account_map in safe_iter_user_accounts_items():
            for account in account_map.values():
                if getattr(account, "is_running", False):
                    account.is_running = False
                    account.should_stop = False
                    account.status = "未开始"

        # Snapshot the ids first: entries are removed while looping.
        stale_task_ids = list(safe_get_active_task_ids())
        for stale_id in stale_task_ids:
            safe_remove_task(stale_id)
            safe_remove_task_status(stale_id)

        logger.info("[OK] 遗留任务状态已清理")
    except Exception as exc:
        logger.warning(f"清理遗留任务状态失败: {exc}")
def _init_optional_email_service() -> None:
    """Initialise the email service; a failure is logged and startup continues."""
    try:
        email_service.init_email_service()
    except Exception as exc:
        logger.warning(f"警告: 邮件服务初始化失败: {exc}")
    else:
        logger.info("[OK] 邮件服务已初始化")
def _load_and_apply_scheduler_limits() -> None:
    """Read concurrency limits from the system config and apply them to the task scheduler.

    Values missing from the stored config fall back to ``config``. Any failure
    is logged as a warning and the scheduler is left untouched.
    """
    try:
        stored = database.get_system_config() or {}
        global_limit = int(stored.get("max_concurrent_global", config.MAX_CONCURRENT_GLOBAL))
        per_account_limit = int(
            stored.get("max_concurrent_per_account", config.MAX_CONCURRENT_PER_ACCOUNT)
        )
        task_scheduler = get_task_scheduler()
        task_scheduler.update_limits(max_global=global_limit, max_per_user=per_account_limit)
        logger.info(f"[OK] 已加载并发配置: 全局={global_limit}, 单账号={per_account_limit}")
    except Exception as exc:
        logger.warning(f"警告: 加载并发配置失败,使用默认值: {exc}")
def _start_background_workers() -> None:
    """Start the scheduled-task worker and the status-push worker as daemon threads.

    Each "[OK]" message is logged only after the corresponding thread has
    actually been started, so the log never reports a start that did not happen.
    """
    logger.info("启动定时任务调度器...")
    threading.Thread(target=scheduled_task_worker, daemon=True, name="scheduled-task-worker").start()
    logger.info("[OK] 定时任务调度器已启动")

    threading.Thread(target=status_push_worker, daemon=True, name="status-push-worker").start()
    logger.info("[OK] 状态推送线程已启动默认2秒/次)")
def _init_screenshot_worker_pool() -> None:
    """Create the screenshot worker pool, sized from the system config (best effort).

    The size comes from ``max_screenshot_concurrent`` and defaults to 3 when
    the key is missing or the config cannot be read. A failure while creating
    the pool is logged as a warning.
    """
    default_size = 3
    try:
        stored = database.get_system_config() or {}
        pool_size = int(stored.get("max_screenshot_concurrent", default_size))
    except Exception:
        pool_size = default_size

    try:
        logger.info(f"初始化截图线程池({pool_size}个worker按需启动执行环境空闲5分钟后自动释放...")
        init_browser_worker_pool(pool_size=pool_size)
        logger.info("[OK] 截图线程池初始化完成")
    except Exception as exc:
        logger.warning(f"警告: 截图线程池初始化失败: {exc}")
def _warmup_api_connection() -> None:
    """Warm up the API connection on a daemon thread (best effort).

    The import happens inside the function, so an import failure is logged
    as a warning like any other failure here.
    """
    logger.info("预热 API 连接...")
    try:
        from api_browser import warmup_api_connection

        def _forward_to_logger(msg):
            # Route warm-up progress messages into the application log.
            logger.info(msg)

        warmup_thread = threading.Thread(
            target=warmup_api_connection,
            kwargs={"log_callback": _forward_to_logger},
            daemon=True,
            name="api-warmup",
        )
        warmup_thread.start()
    except Exception as exc:
        logger.warning(f"API 预热失败: {exc}")
def _log_startup_urls() -> None:
    """Log the user and admin URLs plus the default-admin hint at startup."""
    base_url = f"http://{config.SERVER_HOST}:{config.SERVER_PORT}"
    startup_lines = (
        "服务器启动中...",
        f"用户访问地址: {base_url}",
        f"后台管理地址: {base_url}/yuyx",
        "默认管理员: admin (首次运行密码写入 data/default_admin_credentials.txt)",
        "=" * 60,
    )
    for line in startup_lines:
        logger.info(line)
if __name__ == "__main__":
    # Make sure resources are released on normal exit and on SIGINT / SIGTERM.
    atexit.register(cleanup_on_exit)
    for _exit_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_exit_signal, _signal_handler)

    _banner_rule = "=" * 60
    logger.info(_banner_rule)
    logger.info("知识管理平台自动化工具 - 多用户版")
    logger.info(_banner_rule)

    # Storage and persisted state first.
    database.init_database()
    init_checkpoint_manager()
    logger.info("[OK] 任务断点管理器已初始化")
    _cleanup_stale_task_state()
    _init_optional_email_service()

    # Background maintenance, in the original order.
    for _start_maintenance in (
        start_cleanup_scheduler,
        start_database_maintenance_scheduler,
        start_kdocs_monitor,
    ):
        _start_maintenance()

    _load_and_apply_scheduler_limits()
    _start_background_workers()
    _log_startup_urls()
    _init_screenshot_worker_pool()
    _warmup_api_connection()

    run_kwargs = dict(
        host=config.SERVER_HOST,
        port=config.SERVER_PORT,
        debug=config.DEBUG,
    )
    # In threading mode the Werkzeug server is used and has to be allowed explicitly.
    if str(socketio.async_mode) == "threading":
        run_kwargs["allow_unsafe_werkzeug"] = True
    socketio.run(app, **run_kwargs)