Files
xb/app/services/workflows.py

449 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import uuid
from http import HTTPStatus
from typing import Any
from repositories.history_repository import (
append_generated_history,
load_history_for_config,
upsert_generated_history,
)
def _normalize_key_fields(value: Any) -> list[str]:
key_fields = value if isinstance(value, list) else []
if not key_fields:
key_fields = ["branch", "amount", "type"]
return [str(x) for x in key_fields]
def run_parse_api(payload: dict[str, Any], ctx: dict[str, Any]) -> tuple[HTTPStatus, dict[str, Any]]:
    """Parse relay text from an API payload and return ``(status, json_body)``.

    ``ctx`` supplies the collaborators (config loading, parsing, logging)
    so this workflow stays decoupled from the web layer.
    """
    normalize_insurance_year = ctx["normalize_insurance_year"]
    normalize_insurance_year_choices = ctx["normalize_insurance_year_choices"]
    load_config = ctx["load_config"]
    parse_records = ctx["parse_records"]
    log_parse_skipped = ctx["log_parse_skipped"]
    append_review_log = ctx["append_review_log"]
    raw_text = str(payload.get("raw_text", ""))
    try:
        year_choice = normalize_insurance_year(payload.get("insurance_year"))
        year_choices = normalize_insurance_year_choices(payload.get("insurance_year_choices"))
        config = load_config()
        _, history = load_history_for_config(config, ctx)
        result = parse_records(raw_text, config, history, year_choice, year_choices)
        # Skipped lines are recorded for review rather than surfaced as errors.
        log_parse_skipped(result.get("skipped", []), source="api_parse")
        return HTTPStatus.OK, {"ok": True, "result": result}
    except ValueError as exc:
        # Validation problems map to 400 with the message passed through.
        return HTTPStatus.BAD_REQUEST, {"ok": False, "error": str(exc)}
    except Exception as exc:
        # Boundary handler: log unexpected failures (with the offending text)
        # and answer 500 instead of propagating.
        append_review_log(
            "parse_api_error",
            {
                "error": str(exc),
                "raw_text": raw_text,
            },
        )
        return HTTPStatus.INTERNAL_SERVER_ERROR, {"ok": False, "error": str(exc)}
def run_generate_api(payload: dict[str, Any], ctx: dict[str, Any]) -> tuple[HTTPStatus, dict[str, Any]]:
    """Parse relay text, generate documents, and return ``(status, json_body)``.

    Progress is published via ``set_generation_progress`` under a token so the
    client can poll.  When single-generation mode is active, a slot must be
    acquired first; a busy slot yields 429.  History is appended unless the
    payload opts out with ``save_history=False``.
    """
    normalize_insurance_year = ctx["normalize_insurance_year"]
    normalize_insurance_year_choices = ctx["normalize_insurance_year_choices"]
    load_config = ctx["load_config"]
    resolve_template_path = ctx["resolve_template_path"]
    resolve_output_dir = ctx["resolve_output_dir"]
    parse_records = ctx["parse_records"]
    generate_records = ctx["generate_records"]
    set_generation_progress = ctx["set_generation_progress"]
    append_review_log = ctx["append_review_log"]
    log_parse_skipped = ctx["log_parse_skipped"]
    # Optional hooks — missing entries disable the single-slot behaviour.
    resolve_single_generation_mode = ctx.get("resolve_single_generation_mode")
    acquire_generation_slot = ctx.get("acquire_generation_slot")
    release_generation_slot = ctx.get("release_generation_slot")
    raw_text = str(payload.get("raw_text", ""))
    # Reuse the caller-supplied token when present; otherwise mint a fresh one.
    progress_token = str(payload.get("progress_token", "")).strip() or uuid.uuid4().hex
    slot_acquired = False
    try:
        template_override = str(payload.get("template_file", "")).strip() or None
        output_override = str(payload.get("output_dir", "")).strip() or None
        insurance_year_choice = normalize_insurance_year(payload.get("insurance_year"))
        insurance_year_choices = normalize_insurance_year_choices(payload.get("insurance_year_choices"))
        save_history_flag = bool(payload.get("save_history", True))
        set_generation_progress(
            progress_token,
            status="running",
            stage="接收请求",
            percent=1,
            detail="已收到生成请求",
        )
        config = load_config()
        # Single-slot mode defaults to on when no resolver hook was supplied.
        if callable(resolve_single_generation_mode):
            single_mode = bool(resolve_single_generation_mode(config))
        else:
            single_mode = True
        if single_mode and callable(acquire_generation_slot):
            slot_acquired = bool(acquire_generation_slot(progress_token))
            if not slot_acquired:
                # Another generation run holds the slot — report busy (429).
                set_generation_progress(
                    progress_token,
                    status="busy",
                    stage="系统繁忙",
                    percent=0,
                    detail="已有任务在生成,请稍后重试",
                )
                return HTTPStatus.TOO_MANY_REQUESTS, {
                    "ok": False,
                    "error": "generate_busy",
                    "error_code": "generate_busy",
                    "message": "已有任务在生成,请稍后再试。",
                    "progress_token": progress_token,
                }
        history_path, history = load_history_for_config(config, ctx)
        set_generation_progress(
            progress_token,
            status="running",
            stage="解析文本",
            percent=8,
            detail="正在解析接龙内容",
        )
        parse_result = parse_records(
            raw_text,
            config,
            history,
            insurance_year_choice,
            insurance_year_choices,
        )
        log_parse_skipped(parse_result.get("skipped", []), source="api_generate")
        # Insurance records need an explicit 3-year/5-year choice before we
        # can generate; bounce back to the client when none was provided.
        if parse_result.get("needs_insurance_choice") and insurance_year_choice is None:
            set_generation_progress(
                progress_token,
                status="need_input",
                stage="等待选择",
                percent=15,
                detail="检测到保险记录等待选择3年交/5年交",
            )
            return HTTPStatus.BAD_REQUEST, {
                "ok": False,
                "error": "insurance_year_required",
                "error_code": "insurance_year_required",
                "result": parse_result,
                "options": ["3", "5"],
                "message": "检测到保险记录但未指定年限请逐条选择3年交或5年交。",
                "progress_token": progress_token,
            }
        new_records = parse_result.get("new_records", [])
        if not isinstance(new_records, list):
            new_records = []
        if not new_records:
            # Nothing new to generate — still a successful (empty) run.
            set_generation_progress(
                progress_token,
                status="done",
                stage="完成",
                percent=100,
                detail="没有可生成的新记录",
            )
            return HTTPStatus.OK, {
                "ok": True,
                "message": "没有可生成的新记录",
                "result": parse_result,
                "generated_count": 0,
                "progress_token": progress_token,
            }
        set_generation_progress(
            progress_token,
            status="running",
            stage="准备模板",
            percent=12,
            detail=f"待生成 {len(new_records)}",
        )
        template_path = resolve_template_path(config, template_override)
        output_dir = resolve_output_dir(config, output_override)

        def on_progress(percent: int, stage: str, detail: str) -> None:
            # Forward generator progress to the shared progress store.
            set_generation_progress(
                progress_token,
                status="running",
                stage=stage,
                percent=percent,
                detail=detail,
            )

        gen_result = generate_records(
            new_records,
            config,
            template_path,
            output_dir,
            progress_cb=on_progress,
        )
        history_stat = None
        if save_history_flag:
            set_generation_progress(
                progress_token,
                status="running",
                stage="更新历史",
                percent=96,
                detail="写入历史记录",
            )
            key_fields = _normalize_key_fields(parse_result.get("dedup_key_fields", ["branch", "amount", "type"]))
            history_stat = append_generated_history(
                history_path=history_path,
                generated_items=gen_result.get("generated", []),
                key_fields=key_fields,
                ctx=ctx,
            )
        set_generation_progress(
            progress_token,
            status="done",
            stage="完成",
            percent=100,
            detail=f"已生成 {gen_result.get('generated_count', 0)}",
        )
        return HTTPStatus.OK, {
            "ok": True,
            "message": "生成完成",
            "result": parse_result,
            "generated_count": gen_result.get("generated_count", 0),
            "generated": gen_result.get("generated", []),
            "download_images": gen_result.get("download_images", []),
            "generation_strategy": gen_result.get("generation_strategy", "legacy"),
            "history": history_stat,
            "progress_token": progress_token,
        }
    except ValueError as exc:
        # Validation problems: mark the run failed and answer 400.
        set_generation_progress(
            progress_token,
            status="error",
            stage="失败",
            percent=100,
            detail="请求参数错误",
            error=str(exc),
        )
        return HTTPStatus.BAD_REQUEST, {"ok": False, "error": str(exc)}
    except Exception as exc:
        # Boundary handler: record the failure, mark progress as errored, 500.
        append_review_log(
            "generate_api_error",
            {
                "error": str(exc),
                "raw_text": raw_text,
            },
        )
        set_generation_progress(
            progress_token,
            status="error",
            stage="失败",
            percent=100,
            detail="生成过程异常",
            error=str(exc),
        )
        return HTTPStatus.INTERNAL_SERVER_ERROR, {"ok": False, "error": str(exc)}
    finally:
        # Release the slot on every exit path once it was acquired.
        if slot_acquired and callable(release_generation_slot):
            release_generation_slot(progress_token)
def run_correction_apply_api(payload: dict[str, Any], ctx: dict[str, Any]) -> tuple[HTTPStatus, dict[str, Any]]:
    """Apply a manual correction to one record, regenerate it, and return ``(status, json_body)``.

    The corrected record is re-generated, upserted into history, optionally
    remembered as a manual rule, logged for review, and any open issue marks
    for the same source line are auto-resolved.
    """
    load_config = ctx["load_config"]
    resolve_template_path = ctx["resolve_template_path"]
    resolve_output_dir = ctx["resolve_output_dir"]
    resolve_history_path = ctx["resolve_history_path"]
    normalize_line = ctx["normalize_line"]
    normalize_branch_value = ctx["normalize_branch_value"]
    normalize_amount_text = ctx["normalize_amount_text"]
    normalize_status_value = ctx["normalize_status_value"]
    infer_page_from_type = ctx["infer_page_from_type"]
    apply_record_overrides = ctx["apply_record_overrides"]
    render_output_filename = ctx["render_output_filename"]
    validate_record_for_generation = ctx["validate_record_for_generation"]
    generate_records = ctx["generate_records"]
    infer_correction_rule_keyword = ctx["infer_correction_rule_keyword"]
    save_or_update_manual_rule = ctx["save_or_update_manual_rule"]
    append_review_log = ctx["append_review_log"]
    resolve_issue_marks_by_source_line = ctx["resolve_issue_marks_by_source_line"]
    # Optional hooks — missing entries disable the single-slot behaviour.
    resolve_single_generation_mode = ctx.get("resolve_single_generation_mode")
    acquire_generation_slot = ctx.get("acquire_generation_slot")
    release_generation_slot = ctx.get("release_generation_slot")
    issue_resolve_stat: dict[str, Any] = {"count": 0, "ids": []}
    slot_acquired = False
    try:
        record = payload.get("record")
        if not isinstance(record, dict):
            raise ValueError("record is required")
        overrides = payload.get("overrides", {})
        if overrides is None:
            overrides = {}
        if not isinstance(overrides, dict):
            raise ValueError("overrides must be an object")
        config = load_config()
        # Single-slot mode defaults to on when no resolver hook was supplied.
        if callable(resolve_single_generation_mode):
            single_mode = bool(resolve_single_generation_mode(config))
        else:
            single_mode = True
        if single_mode and callable(acquire_generation_slot):
            slot_acquired = bool(acquire_generation_slot("correction_apply"))
            if not slot_acquired:
                # A generation run holds the slot — report busy (429).
                return HTTPStatus.TOO_MANY_REQUESTS, {
                    "ok": False,
                    "error": "generate_busy",
                    "error_code": "generate_busy",
                    "message": "已有任务在生成,请稍后再试。",
                }
        template_override = str(payload.get("template_file", "")).strip() or None
        output_override = str(payload.get("output_dir", "")).strip() or None
        template_path = resolve_template_path(config, template_override)
        output_dir = resolve_output_dir(config, output_override)
        history_path = resolve_history_path(config)
        relay_cfg = config.get("relay_handling", {})
        parse_rules = relay_cfg.get("parse_rules", {}) if isinstance(relay_cfg, dict) else {}
        line_pattern = str(parse_rules.get("line_pattern", r"^\d+、\s*"))
        source_line = str(record.get("source_line", "")).strip()
        raw_text = str(record.get("raw_text", "")).strip()
        normalized_line = normalize_line(source_line or raw_text, line_pattern)
        # Rebuild a normalised baseline record before applying the overrides.
        base_record = {
            "source_line": source_line or raw_text,
            "raw_text": normalized_line or raw_text,
            "branch": normalize_branch_value(record.get("branch", ""), config),
            "amount": normalize_amount_text(record.get("amount", "")),
            "type": str(record.get("type", "")).strip(),
            "page": str(record.get("page", "")).strip(),
            "status": normalize_status_value(str(record.get("status", "")).strip(), config),
            "output_file": str(record.get("output_file", "")).strip(),
        }
        if not base_record["page"] and base_record["type"]:
            base_record["page"] = infer_page_from_type(base_record["type"], config)
        corrected = apply_record_overrides(base_record, overrides, config)
        corrected["source_line"] = str(corrected.get("source_line") or source_line or raw_text)
        corrected["raw_text"] = normalize_line(str(corrected.get("raw_text") or normalized_line), line_pattern)
        if not corrected.get("output_file"):
            corrected["output_file"] = render_output_filename(config, corrected, 1)
        validate_record_for_generation(corrected, config)
        gen_result = generate_records(
            [corrected],
            config,
            template_path,
            output_dir,
            progress_cb=None,
        )
        # Re-read relay config for the dedup section (mirrors the lookup above).
        relay_cfg = config.get("relay_handling", {})
        dedup_cfg = relay_cfg.get("dedup", {}) if isinstance(relay_cfg, dict) else {}
        key_fields = _normalize_key_fields(dedup_cfg.get("key_fields", ["branch", "amount", "type"]))
        history_stat = upsert_generated_history(
            history_path=history_path,
            generated_items=gen_result.get("generated", []),
            key_fields=key_fields,
            ctx=ctx,
        )
        remember_rule = bool(payload.get("remember_rule", False))
        remember_amount = bool(payload.get("remember_amount", False))
        rule_keyword = str(payload.get("rule_keyword", "")).strip()
        note = str(payload.get("note", "")).strip()
        applied_rule = None
        if remember_rule:
            # Persist only the fields the user actually overrode (non-empty),
            # and drop "amount" unless the user asked to remember it too.
            rule_updates: dict[str, Any] = {}
            for field in ("branch", "type", "page", "status", "amount"):
                if field in overrides:
                    field_value = str(corrected.get(field, "")).strip()
                    if field_value:
                        rule_updates[field] = field_value
            if not remember_amount:
                rule_updates.pop("amount", None)
            if rule_updates:
                keyword = rule_keyword or infer_correction_rule_keyword(
                    source_line=str(corrected.get("source_line", "")),
                    normalized_line=str(corrected.get("raw_text", "")),
                    corrected_record=corrected,
                )
                applied_rule = save_or_update_manual_rule(
                    keyword=keyword,
                    updates=rule_updates,
                    note=note,
                    match_mode=str(payload.get("rule_mode", "normalized")),
                )
        append_review_log(
            "manual_correction_apply",
            {
                "source_line": str(corrected.get("source_line", "")),
                "record_before": base_record,
                "record_after": corrected,
                "overrides": overrides,
                "remember_rule": remember_rule,
                "rule": applied_rule,
                "note": note,
            },
        )
        # Auto-resolve issue marks tied to this source line, then log it.
        issue_resolve_stat = resolve_issue_marks_by_source_line(
            str(corrected.get("source_line", "")),
            reason="manual_correction_apply",
        )
        if int(issue_resolve_stat.get("count", 0)) > 0:
            append_review_log(
                "issue_auto_resolve",
                {
                    "source_line": str(corrected.get("source_line", "")),
                    "resolved_issue_count": int(issue_resolve_stat.get("count", 0)),
                    "resolved_issue_ids": issue_resolve_stat.get("ids", []),
                },
            )
        return HTTPStatus.OK, {
            "ok": True,
            "message": "修正已生成",
            "generated_count": gen_result.get("generated_count", 0),
            "generated": gen_result.get("generated", []),
            "download_images": gen_result.get("download_images", []),
            "generation_strategy": gen_result.get("generation_strategy", "legacy"),
            "history": history_stat,
            "rule": applied_rule,
            "resolved_issue_count": int(issue_resolve_stat.get("count", 0)),
            "resolved_issue_ids": issue_resolve_stat.get("ids", []),
        }
    except ValueError as exc:
        # Validation problems map to 400 with the message passed through.
        return HTTPStatus.BAD_REQUEST, {"ok": False, "error": str(exc)}
    except Exception as exc:
        # Boundary handler: log the failure with the offending record, 500.
        append_review_log(
            "manual_correction_error",
            {
                "error": str(exc),
                "record": payload.get("record") if isinstance(payload, dict) else {},
            },
        )
        return HTTPStatus.INTERNAL_SERVER_ERROR, {"ok": False, "error": str(exc)}
    finally:
        # Release the slot on every exit path once it was acquired.
        if slot_acquired and callable(release_generation_slot):
            release_generation_slot("correction_apply")