#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ZSGLPT Update-Agent宿主机运行
职责:
- 定期检查 Git 远端是否有新版本(写入 data/update/status.json
- 接收后台写入的 data/update/request.json 请求check/update
- 执行 git reset --hard origin/<branch> + docker compose build/up
- 更新前备份数据库 data/app_data.db
- 写入 data/update/result.json 与 data/update/jobs/<job_id>.log
仅使用标准库,便于在宿主机直接运行。
"""
from __future__ import annotations
import argparse
import fnmatch
import json
import os
import re
import shutil
import subprocess
import sys
import time
import urllib.error
import urllib.request
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple


def ts_str() -> str:
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def json_load(path: Path) -> Tuple[dict, Optional[str]]:
try:
with open(path, "r", encoding="utf-8") as f:
return dict(json.load(f) or {}), None
except FileNotFoundError:
return {}, None
except Exception as e:
return {}, f"{type(e).__name__}: {e}"


def json_dump_atomic(path: Path, data: dict) -> None:
    """Write JSON via a temp file, fsync and os.replace(), so readers never see a partial file."""
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(f"{path.suffix}.tmp.{os.getpid()}.{int(time.time() * 1000)}")
with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2, sort_keys=True)
f.flush()
os.fsync(f.fileno())
os.replace(tmp, path)


def sanitize_job_id(value: object) -> str:
text = str(value or "").strip()
if not text:
return f"job_{uuid.uuid4().hex[:8]}"
if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9_.-]{0,63}", text):
return f"job_{uuid.uuid4().hex[:8]}"
return text
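
# Illustrative behavior of sanitize_job_id():
#   "job-2024.01"   -> kept as-is (matches [A-Za-z0-9][A-Za-z0-9_.-]{0,63})
#   "../etc/passwd" -> "job_<8 hex chars>" (unsafe as a file name component)
#   ""              -> "job_<8 hex chars>"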


def _as_bool(value: object) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return value != 0
text = str(value or "").strip().lower()
return text in ("1", "true", "yes", "y", "on")


def _run(cmd: list[str], *, cwd: Path, log_fp, env: Optional[dict] = None, check: bool = True) -> subprocess.CompletedProcess:
log_fp.write(f"[{ts_str()}] $ {' '.join(cmd)}\n")
log_fp.flush()
merged_env = os.environ.copy()
if env:
merged_env.update(env)
return subprocess.run(
cmd,
cwd=str(cwd),
env=merged_env,
stdout=log_fp,
stderr=log_fp,
text=True,
check=check,
)


def _git_rev_parse(ref: str, *, cwd: Path) -> str:
out = subprocess.check_output(["git", "rev-parse", ref], cwd=str(cwd), text=True).strip()
return out


def _git_has_tracked_changes(*, cwd: Path) -> bool:
    """Return True if tracked files have uncommitted changes (including staged ones)."""
for cmd in (["git", "diff", "--quiet"], ["git", "diff", "--cached", "--quiet"]):
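        # `git diff --quiet` exits 1 when differences exist, 0 when clean,
        # and another non-zero code (e.g. 128) on real errors.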
proc = subprocess.run(cmd, cwd=str(cwd))
if proc.returncode == 1:
return True
if proc.returncode != 0:
raise RuntimeError(f"{' '.join(cmd)} failed with code {proc.returncode}")
return False


def _normalize_prefixes(prefixes: Tuple[str, ...]) -> Tuple[str, ...]:
normalized = []
for p in prefixes:
text = str(p or "").strip()
if not text:
continue
if not text.endswith("/"):
text += "/"
normalized.append(text)
return tuple(normalized)


def _git_has_untracked_changes(*, cwd: Path, ignore_prefixes: Tuple[str, ...]) -> Tuple[bool, int, list[str]]:
    """Check untracked files (respecting .gitignore), ignoring the given directory prefixes."""
    return _git_has_untracked_changes_v2(cwd=cwd, ignore_prefixes=ignore_prefixes, ignore_globs=())


def _normalize_globs(globs: Tuple[str, ...]) -> Tuple[str, ...]:
normalized = []
for g in globs:
text = str(g or "").strip()
if not text:
continue
normalized.append(text)
return tuple(normalized)


def _git_has_untracked_changes_v2(
    *, cwd: Path, ignore_prefixes: Tuple[str, ...], ignore_globs: Tuple[str, ...]
) -> Tuple[bool, int, list[str]]:
    """Check untracked files (respecting .gitignore), ignoring the given prefixes/globs."""
ignore_prefixes = _normalize_prefixes(ignore_prefixes)
ignore_globs = _normalize_globs(ignore_globs)
out = subprocess.check_output(["git", "ls-files", "--others", "--exclude-standard"], cwd=str(cwd), text=True)
paths = [line.strip() for line in out.splitlines() if line.strip()]
filtered = []
for p in paths:
if ignore_prefixes and any(p.startswith(prefix) for prefix in ignore_prefixes):
continue
if ignore_globs and any(fnmatch.fnmatch(p, pattern) for pattern in ignore_globs):
continue
filtered.append(p)
samples = filtered[:20]
return (len(filtered) > 0), len(filtered), samples


def _git_is_dirty(
    *,
    cwd: Path,
    ignore_untracked_prefixes: Tuple[str, ...] = ("data/",),
    ignore_untracked_globs: Tuple[str, ...] = ("*.bak.*", "*.tmp.*", "*.backup.*"),
) -> dict:
    """
    Decide whether the working tree is "dirty":
    - Changes to tracked files (including the index) always count as dirty.
    - Untracked files under data/ are ignored by default (runtime data
      directory; avoids the backend warning about it indefinitely).
    """
tracked_dirty = False
untracked_dirty = False
untracked_count = 0
untracked_samples: list[str] = []
try:
tracked_dirty = _git_has_tracked_changes(cwd=cwd)
except Exception:
        # If the diff check itself fails, fall back to the conservative choice: treat as dirty.
tracked_dirty = True
try:
untracked_dirty, untracked_count, untracked_samples = _git_has_untracked_changes_v2(
cwd=cwd, ignore_prefixes=ignore_untracked_prefixes, ignore_globs=ignore_untracked_globs
)
except Exception:
        # If the untracked check fails, do not let it block updates: leave it out of "dirty".
untracked_dirty = False
untracked_count = 0
untracked_samples = []
return {
"dirty": bool(tracked_dirty or untracked_dirty),
"dirty_tracked": bool(tracked_dirty),
"dirty_untracked": bool(untracked_dirty),
"dirty_ignore_untracked_prefixes": list(_normalize_prefixes(ignore_untracked_prefixes)),
"dirty_ignore_untracked_globs": list(_normalize_globs(ignore_untracked_globs)),
"untracked_count": int(untracked_count),
"untracked_samples": list(untracked_samples),
}


def _compose_cmd() -> list[str]:
    # Prefer `docker compose` (v2); fall back to the legacy docker-compose binary.
try:
subprocess.check_output(["docker", "compose", "version"], stderr=subprocess.STDOUT, text=True)
return ["docker", "compose"]
except Exception:
return ["docker-compose"]


def _http_healthcheck(url: str, *, timeout: float = 5.0) -> Tuple[bool, str]:
try:
req = urllib.request.Request(url, headers={"User-Agent": "zsglpt-update-agent/1.0"})
with urllib.request.urlopen(req, timeout=timeout) as resp:
code = int(getattr(resp, "status", 200) or 200)
if 200 <= code < 400:
return True, f"HTTP {code}"
return False, f"HTTP {code}"
except urllib.error.HTTPError as e:
return False, f"HTTPError {e.code}"
except Exception as e:
return False, f"{type(e).__name__}: {e}"


@dataclass
class Paths:
repo_dir: Path
data_dir: Path
update_dir: Path
status_path: Path
request_path: Path
result_path: Path
jobs_dir: Path


def build_paths(repo_dir: Path, data_dir: Optional[Path] = None) -> Paths:
repo_dir = repo_dir.resolve()
data_dir = (data_dir or (repo_dir / "data")).resolve()
update_dir = data_dir / "update"
return Paths(
repo_dir=repo_dir,
data_dir=data_dir,
update_dir=update_dir,
status_path=update_dir / "status.json",
request_path=update_dir / "request.json",
result_path=update_dir / "result.json",
jobs_dir=update_dir / "jobs",
)
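

# On-disk layout under <data_dir>, as produced by build_paths():
#   update/status.json     latest check result (also refreshed periodically)
#   update/request.json    pending command from the backend (consumed on read)
#   update/result.json     outcome of the most recent job
#   update/jobs/<id>.log   per-job execution logs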


def ensure_dirs(paths: Paths) -> None:
paths.jobs_dir.mkdir(parents=True, exist_ok=True)


def check_updates(*, paths: Paths, branch: str, log_fp=None) -> dict:
env = {"GIT_TERMINAL_PROMPT": "0"}
err = ""
local = ""
remote = ""
dirty_info: dict = {}
try:
if log_fp:
_run(["git", "fetch", "origin", branch], cwd=paths.repo_dir, log_fp=log_fp, env=env)
else:
subprocess.run(["git", "fetch", "origin", branch], cwd=str(paths.repo_dir), env={**os.environ, **env}, check=True)
local = _git_rev_parse("HEAD", cwd=paths.repo_dir)
remote = _git_rev_parse(f"origin/{branch}", cwd=paths.repo_dir)
dirty_info = _git_is_dirty(cwd=paths.repo_dir, ignore_untracked_prefixes=("data/",))
except Exception as e:
err = f"{type(e).__name__}: {e}"
update_available = bool(local and remote and local != remote) if not err else False
return {
"branch": branch,
"checked_at": ts_str(),
"local_commit": local,
"remote_commit": remote,
"update_available": update_available,
**(dirty_info or {"dirty": False}),
"error": err,
}
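
# Sketch of a resulting status.json (keys from check_updates() plus the
# dirty-info dict from _git_is_dirty(); values illustrative):
#
#   {
#     "branch": "master",
#     "checked_at": "2024-01-01 12:00:00",
#     "local_commit": "<40-char sha>",
#     "remote_commit": "<40-char sha>",
#     "update_available": true,
#     "dirty": false,
#     "dirty_tracked": false,
#     "dirty_untracked": false,
#     "dirty_ignore_untracked_prefixes": ["data/"],
#     "dirty_ignore_untracked_globs": ["*.bak.*", "*.tmp.*", "*.backup.*"],
#     "untracked_count": 0,
#     "untracked_samples": [],
#     "error": ""
#   }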


def backup_db(*, paths: Paths, log_fp, keep: int = 20) -> str:
db_path = paths.data_dir / "app_data.db"
backups_dir = paths.data_dir / "backups"
backups_dir.mkdir(parents=True, exist_ok=True)
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = backups_dir / f"app_data.db.{stamp}.bak"
if db_path.exists():
log_fp.write(f"[{ts_str()}] backup db: {db_path} -> {backup_path}\n")
log_fp.flush()
shutil.copy2(db_path, backup_path)
else:
log_fp.write(f"[{ts_str()}] backup skipped: db not found: {db_path}\n")
log_fp.flush()
    # Simple retention policy: sort by file name and keep only the newest `keep` backups.
try:
items = sorted([p for p in backups_dir.glob("app_data.db.*.bak") if p.is_file()], key=lambda p: p.name)
if len(items) > keep:
for p in items[: len(items) - keep]:
try:
p.unlink()
except Exception:
pass
except Exception:
pass
return str(backup_path)


def write_result(paths: Paths, data: dict) -> None:
json_dump_atomic(paths.result_path, data)


def consume_request(paths: Paths) -> Tuple[dict, Optional[str]]:
data, err = json_load(paths.request_path)
if err:
        # Avoid a parse-failure loop: move the bad file out of the way.
try:
bad_name = f"request.bad.{datetime.now().strftime('%Y%m%d_%H%M%S')}.{uuid.uuid4().hex[:6]}.json"
bad_path = paths.update_dir / bad_name
paths.request_path.rename(bad_path)
except Exception:
try:
paths.request_path.unlink(missing_ok=True) # type: ignore[arg-type]
except Exception:
pass
return {}, err
if not data:
return {}, None
try:
paths.request_path.unlink(missing_ok=True) # type: ignore[arg-type]
except Exception:
try:
os.remove(paths.request_path)
except Exception:
pass
return data, None
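

# Update job pipeline (mirrors result["stage"] in handle_update_job below):
#   start -> backup -> git_fetch -> git_reset -> docker_build -> docker_up
#   -> health_check -> done
# On failure the job is marked "failed" and keeps the stage it died in.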
def handle_update_job(
*,
paths: Paths,
branch: str,
health_url: str,
job_id: str,
requested_by: str,
build_no_cache: bool = False,
build_pull: bool = False,
) -> None:
ensure_dirs(paths)
log_path = paths.jobs_dir / f"{job_id}.log"
with open(log_path, "a", encoding="utf-8") as log_fp:
log_fp.write(f"[{ts_str()}] job start: {job_id}, branch={branch}, by={requested_by}\n")
log_fp.flush()
result: Dict[str, object] = {
"job_id": job_id,
"action": "update",
"status": "running",
"stage": "start",
"message": "",
"started_at": ts_str(),
"finished_at": None,
"duration_seconds": None,
"requested_by": requested_by,
"branch": branch,
"build_no_cache": bool(build_no_cache),
"build_pull": bool(build_pull),
"from_commit": None,
"to_commit": None,
"backup_db": None,
"health_url": health_url,
"health_ok": None,
"health_message": None,
"error": "",
}
write_result(paths, result)
start_ts = time.time()
try:
result["stage"] = "backup"
result["message"] = "备份数据库"
write_result(paths, result)
result["backup_db"] = backup_db(paths=paths, log_fp=log_fp)
result["stage"] = "git_fetch"
result["message"] = "拉取远端代码"
write_result(paths, result)
_run(["git", "fetch", "origin", branch], cwd=paths.repo_dir, log_fp=log_fp, env={"GIT_TERMINAL_PROMPT": "0"})
from_commit = _git_rev_parse("HEAD", cwd=paths.repo_dir)
result["from_commit"] = from_commit
result["stage"] = "git_reset"
result["message"] = f"切换到 origin/{branch}"
write_result(paths, result)
_run(["git", "reset", "--hard", f"origin/{branch}"], cwd=paths.repo_dir, log_fp=log_fp, env={"GIT_TERMINAL_PROMPT": "0"})
to_commit = _git_rev_parse("HEAD", cwd=paths.repo_dir)
result["to_commit"] = to_commit
compose = _compose_cmd()
result["stage"] = "docker_build"
result["message"] = "构建容器镜像"
write_result(paths, result)
build_no_cache = bool(result.get("build_no_cache") is True)
build_pull = bool(result.get("build_pull") is True)
build_cmd = [*compose, "build"]
if build_pull:
build_cmd.append("--pull")
if build_no_cache:
build_cmd.append("--no-cache")
try:
_run(build_cmd, cwd=paths.repo_dir, log_fp=log_fp)
except subprocess.CalledProcessError as e:
if (not build_no_cache) and (e.returncode != 0):
log_fp.write(f"[{ts_str()}] build failed, retry with --no-cache\n")
log_fp.flush()
build_no_cache = True
result["build_no_cache"] = True
write_result(paths, result)
retry_cmd = [*compose, "build"]
if build_pull:
retry_cmd.append("--pull")
retry_cmd.append("--no-cache")
_run(retry_cmd, cwd=paths.repo_dir, log_fp=log_fp)
else:
raise
result["stage"] = "docker_up"
result["message"] = "重建并启动服务"
write_result(paths, result)
_run([*compose, "up", "-d", "--force-recreate"], cwd=paths.repo_dir, log_fp=log_fp)
result["stage"] = "health_check"
result["message"] = "健康检查"
write_result(paths, result)
ok = False
health_msg = ""
deadline = time.time() + 180
while time.time() < deadline:
ok, health_msg = _http_healthcheck(health_url, timeout=5.0)
if ok:
break
time.sleep(3)
result["health_ok"] = ok
result["health_message"] = health_msg
if not ok:
raise RuntimeError(f"healthcheck failed: {health_msg}")
result["status"] = "success"
result["stage"] = "done"
result["message"] = "更新完成"
except Exception as e:
result["status"] = "failed"
result["error"] = f"{type(e).__name__}: {e}"
result["stage"] = result.get("stage") or "failed"
result["message"] = "更新失败"
log_fp.write(f"[{ts_str()}] ERROR: {result['error']}\n")
log_fp.flush()
finally:
result["finished_at"] = ts_str()
result["duration_seconds"] = int(time.time() - start_ts)
write_result(paths, result)
            # Refresh status.json (best effort, on both success and failure).
try:
status = check_updates(paths=paths, branch=branch, log_fp=log_fp)
json_dump_atomic(paths.status_path, status)
except Exception:
pass
log_fp.write(f"[{ts_str()}] job end: {job_id}\n")
log_fp.flush()


def handle_check_job(*, paths: Paths, branch: str, job_id: str, requested_by: str) -> None:
ensure_dirs(paths)
log_path = paths.jobs_dir / f"{job_id}.log"
with open(log_path, "a", encoding="utf-8") as log_fp:
log_fp.write(f"[{ts_str()}] job start: {job_id}, action=check, branch={branch}, by={requested_by}\n")
log_fp.flush()
status = check_updates(paths=paths, branch=branch, log_fp=log_fp)
json_dump_atomic(paths.status_path, status)
log_fp.write(f"[{ts_str()}] job end: {job_id}\n")
log_fp.flush()


def main(argv: list[str]) -> int:
parser = argparse.ArgumentParser(description="ZSGLPT Update-Agent (host)")
parser.add_argument("--repo-dir", default=".", help="部署仓库目录(包含 docker-compose.yml")
parser.add_argument("--data-dir", default="", help="数据目录(默认 <repo>/data")
parser.add_argument("--branch", default="master", help="允许更新的分支名(默认 master")
parser.add_argument("--health-url", default="http://127.0.0.1:51232/", help="更新后健康检查URL")
parser.add_argument("--check-interval-seconds", type=int, default=300, help="自动检查更新间隔(秒)")
parser.add_argument("--poll-seconds", type=int, default=5, help="轮询 request.json 的间隔(秒)")
args = parser.parse_args(argv)
repo_dir = Path(args.repo_dir).resolve()
if not (repo_dir / "docker-compose.yml").exists():
print(f"[fatal] docker-compose.yml not found in {repo_dir}", file=sys.stderr)
return 2
if not (repo_dir / ".git").exists():
print(f"[fatal] .git not found in {repo_dir} (need git repo)", file=sys.stderr)
return 2
data_dir = Path(args.data_dir).resolve() if args.data_dir else None
paths = build_paths(repo_dir, data_dir=data_dir)
ensure_dirs(paths)
last_check_ts = 0.0
check_interval = max(30, int(args.check_interval_seconds))
poll_seconds = max(2, int(args.poll_seconds))
branch = str(args.branch or "master").strip()
health_url = str(args.health_url or "").strip()
    # Write status once at startup so the backend sees something immediately.
try:
status = check_updates(paths=paths, branch=branch)
json_dump_atomic(paths.status_path, status)
last_check_ts = time.time()
except Exception:
pass
while True:
try:
            # 1) Handle a pending request first.
req, err = consume_request(paths)
if err:
                # Corrupt request file: write a result so the backend can see why.
write_result(
paths,
{
"job_id": f"badreq_{uuid.uuid4().hex[:8]}",
"action": "unknown",
"status": "failed",
"stage": "parse_request",
"message": "request.json 解析失败",
"error": err,
"started_at": ts_str(),
"finished_at": ts_str(),
},
)
elif req:
action = str(req.get("action") or "").strip().lower()
job_id = sanitize_job_id(req.get("job_id"))
requested_by = str(req.get("requested_by") or "")
                # Only the fixed, CLI-configured branch is ever used (never one
                # from the request), to avoid injection or accidental misuse.
if action not in ("check", "update"):
write_result(
paths,
{
"job_id": job_id,
"action": action,
"status": "failed",
"stage": "validate",
"message": "不支持的 action",
"error": f"unsupported action: {action}",
"started_at": ts_str(),
"finished_at": ts_str(),
},
)
elif action == "check":
handle_check_job(paths=paths, branch=branch, job_id=job_id, requested_by=requested_by)
else:
build_no_cache = _as_bool(req.get("build_no_cache") or req.get("no_cache") or False)
build_pull = _as_bool(req.get("build_pull") or req.get("pull") or False)
handle_update_job(
paths=paths,
branch=branch,
health_url=health_url,
job_id=job_id,
requested_by=requested_by,
build_no_cache=build_no_cache,
build_pull=build_pull,
)
last_check_ts = time.time()
            # 2) Periodic update check.
now = time.time()
if now - last_check_ts >= check_interval:
try:
status = check_updates(paths=paths, branch=branch)
json_dump_atomic(paths.status_path, status)
except Exception:
pass
last_check_ts = now
time.sleep(poll_seconds)
except KeyboardInterrupt:
return 0
except Exception:
time.sleep(2)


if __name__ == "__main__":
raise SystemExit(main(sys.argv[1:]))
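
# Typical invocation (paths illustrative; flags defined in main() above):
#   python3 zsglpt/tools/update_agent.py \
#       --repo-dir /opt/zsglpt \
#       --branch master \
#       --health-url http://127.0.0.1:51232/
# Run it as a long-lived host process (e.g. under systemd); it polls
# request.json every --poll-seconds and re-checks the remote every
# --check-interval-seconds.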