Files
xb/app/tests/test_services.py

164 lines
7.1 KiB
Python

from __future__ import annotations

import tempfile
import threading
import unittest
from pathlib import Path

from services.post_ops import (
    run_history_clear,
    run_mark_issue,
)
from services.workflows import (
    run_correction_apply_api,
    run_generate_api,
    run_parse_api,
)
class ServiceTestBase(unittest.TestCase):
    """Shared fixture for service-layer tests.

    ``base_ctx`` builds the dependency dict handed to the service functions,
    with every entry replaced by an in-memory stub. ``setUp`` prepares the
    progress log and the per-test scratch directory those stubs point at.
    """

    def setUp(self) -> None:
        super().setUp()
        # One (token, status, percent, stage, detail-or-error) tuple per call
        # to the stubbed ``set_generation_progress``.
        self.progress_calls: list[tuple[str, str, int, str, str]] = []
        # Isolated scratch directory, removed after each test, so stub paths
        # never collide between tests or parallel runs and do not depend on a
        # POSIX-only "/tmp".
        scratch = tempfile.TemporaryDirectory(prefix="xibao_test_")
        self.addCleanup(scratch.cleanup)
        self.tmp_path = Path(scratch.name)

    def base_ctx(self) -> dict:
        """Return a fresh context dict of stub collaborators.

        Tests overwrite individual entries (e.g. ``ctx["parse_records"]``) to
        steer a specific code path. Progress updates delivered through
        ``ctx["set_generation_progress"]`` are appended to
        ``self.progress_calls``.

        Returns:
            A new dict mapping dependency names to stub callables, plus a
            fresh ``threading.Lock`` under ``"_HISTORY_LOCK"``.
        """
        # Fall back to the system temp dir if setUp's scratch dir is absent.
        root = getattr(self, "tmp_path", None) or Path(tempfile.gettempdir())
        lock = threading.Lock()

        def set_progress(
            token: str,
            *,
            status: str,
            stage: str,
            percent: int,
            detail: str = "",
            error: str = "",
        ):
            # Keep the detail text, or the error text when no detail is given.
            self.progress_calls.append(
                (token, status, int(percent), str(stage), str(detail or error))
            )

        return {
            "normalize_insurance_year": lambda v: str(v) if v in {"3", "5"} else None,
            "normalize_insurance_year_choices": lambda v: v if isinstance(v, dict) else {},
            "load_config": lambda: {
                "relay_handling": {
                    "dedup": {"key_fields": ["branch", "amount", "type"]},
                    # Single backslashes inside the raw string: ``\d`` / ``\s``
                    # must reach the regex engine as character classes.
                    "parse_rules": {"line_pattern": r"^\d+、\s*"},
                }
            },
            "resolve_history_path": lambda config: root / "xibao_test_history.json",
            "resolve_template_path": lambda config, override=None: root / "template.pptx",
            "resolve_output_dir": lambda config, override=None: root / "xibao_output",
            "load_history": lambda history_path: [],
            "save_history": lambda history_path, records: None,
            "parse_records": lambda raw_text, config, history, insurance_year_choice, insurance_year_choices: {
                "has_trigger": True,
                "records": [],
                "new_records": [],
                "skipped": [],
            },
            "generate_records": lambda new_records, config, template_path, output_dir, progress_cb=None: {
                "generated_count": len(new_records),
                "generated": list(new_records),
                "download_images": [],
            },
            "set_generation_progress": set_progress,
            "append_review_log": lambda event, payload=None: str(root / "review.jsonl"),
            "log_parse_skipped": lambda skipped, source: 0,
            "append_new_history": lambda history_path, history, records, key_fields: {
                "added": len(records),
                "total": len(history) + len(records),
            },
            "upsert_issue_mark": lambda **kwargs: ({"id": "issue_1", **kwargs}, True),
            "suppress_skip_item": lambda line, reason="": ({"id": "skip_1", "line": line, "reason": reason}, True),
            "normalize_line": lambda line, pattern: str(line or "").strip(),
            "normalize_branch_value": lambda value, config: str(value or "").strip(),
            "normalize_amount_text": lambda value: str(value or "").strip(),
            "normalize_status_value": lambda value, config: str(value or "").strip(),
            "infer_page_from_type": lambda type_keyword, config: "page_2",
            "apply_record_overrides": lambda record, overrides, config: {**record, **(overrides or {})},
            "render_output_filename": lambda config, record, index: f"喜报_{index}.png",
            "validate_record_for_generation": lambda record, config: None,
            "upsert_history_records": lambda history_path, history, records, key_fields: {
                "added": len(records),
                "updated": 0,
            },
            "infer_correction_rule_keyword": lambda **kwargs: "keyword",
            "save_or_update_manual_rule": lambda **kwargs: {"keyword": kwargs.get("keyword", "keyword")},
            "resolve_issue_marks_by_source_line": lambda source_line, reason="": {"count": 0, "ids": []},
            "update_issue_mark": lambda **kwargs: {"id": kwargs.get("issue_id", "")},
            "delete_issue_mark": lambda issue_id: True,
            "cleanup_output_artifacts": lambda output_dir: {"removed_dirs": 0, "removed_files": 0},
            "clear_skip_suppressions": lambda: 0,
            "_HISTORY_LOCK": lock,
        }
class TestWorkflows(ServiceTestBase):
    """Exercise the request handlers imported from ``services.workflows``."""

    # Relay text shared by the requests below: trigger tag plus one numbered line.
    _RAW_TEXT = "#接龙\n1、测试"

    @staticmethod
    def _parse_needing_insurance(*_args, **_kwargs) -> dict:
        """Parser stub: one new record, flagged as still needing an insurance-year choice."""
        pending = {"branch": "营江路", "amount": "10万", "type": "一年期定期"}
        return {
            "has_trigger": True,
            "records": [],
            "new_records": [pending],
            "skipped": [],
            "needs_insurance_choice": True,
        }

    @staticmethod
    def _parse_ready_record(*_args, **_kwargs) -> dict:
        """Parser stub: one fully populated new record ready for generation."""
        ready = {
            "source_line": "1、营江路揽收现金10万存一年",
            "raw_text": "营江路揽收现金10万存一年",
            "branch": "营江路",
            "amount": "10万",
            "type": "一年期定期",
            "page": "page_2",
            "status": "揽收现金",
        }
        return {
            "has_trigger": True,
            "records": [],
            "new_records": [ready],
            "skipped": [],
            "needs_insurance_choice": False,
            "dedup_key_fields": ["branch", "amount", "type"],
        }

    def _statuses(self) -> list[str]:
        """Status field (second element) of every recorded progress update."""
        return [call[1] for call in self.progress_calls]

    def test_run_parse_api_success(self) -> None:
        """Parsing the default stubbed input answers 200 with ok and a result payload."""
        code, payload = run_parse_api({"raw_text": self._RAW_TEXT}, self.base_ctx())
        self.assertEqual(int(code), 200)
        self.assertTrue(payload.get("ok"))
        self.assertIn("result", payload)

    def test_run_generate_api_requires_insurance(self) -> None:
        """A record lacking an insurance-year choice yields 400 and a need_input progress status."""
        ctx = self.base_ctx()
        ctx["parse_records"] = self._parse_needing_insurance
        code, payload = run_generate_api({"raw_text": self._RAW_TEXT}, ctx)
        self.assertEqual(int(code), 400)
        self.assertEqual(payload.get("error_code"), "insurance_year_required")
        self.assertIn("need_input", self._statuses())

    def test_run_generate_api_success(self) -> None:
        """A complete record generates one item and records a "done" progress status."""
        ctx = self.base_ctx()
        ctx["parse_records"] = self._parse_ready_record
        request = {"raw_text": self._RAW_TEXT, "save_history": True}
        code, payload = run_generate_api(request, ctx)
        self.assertEqual(int(code), 200)
        self.assertTrue(payload.get("ok"))
        self.assertEqual(int(payload.get("generated_count", 0)), 1)
        self.assertIn("done", self._statuses())

    def test_run_correction_apply_api_validation(self) -> None:
        """An empty correction request is rejected with 400 and ok=False."""
        code, payload = run_correction_apply_api({}, self.base_ctx())
        self.assertEqual(int(code), 400)
        self.assertFalse(payload.get("ok", True))
class TestPostOps(ServiceTestBase):
    """Exercise the operations imported from ``services.post_ops``."""

    def test_run_mark_issue_invalid_type(self) -> None:
        """An unknown mark_type is rejected with 400 and an error naming the field."""
        request = {"mark_type": "bad", "source_line": "x"}
        code, payload = run_mark_issue(request, self.base_ctx())
        self.assertEqual(int(code), 400)
        error_text = payload.get("error", "")
        self.assertIn("mark_type", error_text)

    def test_run_history_clear_success(self) -> None:
        """Clearing history against the stub context answers 200 with ok."""
        code, payload = run_history_clear(self.base_ctx())
        self.assertEqual(int(code), 200)
        self.assertTrue(payload.get("ok"))
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main(module="__main__")