from __future__ import annotations import threading import unittest from pathlib import Path from services.post_ops import ( run_history_clear, run_mark_issue, ) from services.workflows import ( run_correction_apply_api, run_generate_api, run_parse_api, ) class ServiceTestBase(unittest.TestCase): def setUp(self) -> None: self.progress_calls: list[tuple[str, str, int, str, str]] = [] def base_ctx(self) -> dict: lock = threading.Lock() def set_progress(token: str, *, status: str, stage: str, percent: int, detail: str = "", error: str = ""): self.progress_calls.append((token, status, int(percent), str(stage), str(detail or error))) return { "normalize_insurance_year": lambda v: str(v) if v in {"3", "5"} else None, "normalize_insurance_year_choices": lambda v: v if isinstance(v, dict) else {}, "load_config": lambda: { "relay_handling": { "dedup": {"key_fields": ["branch", "amount", "type"]}, "parse_rules": {"line_pattern": r"^\\d+、\\s*"}, } }, "resolve_history_path": lambda config: Path("/tmp/xibao_test_history.json"), "resolve_template_path": lambda config, override=None: Path("/tmp/template.pptx"), "resolve_output_dir": lambda config, override=None: Path("/tmp/xibao_output"), "load_history": lambda history_path: [], "save_history": lambda history_path, records: None, "parse_records": lambda raw_text, config, history, insurance_year_choice, insurance_year_choices: { "has_trigger": True, "records": [], "new_records": [], "skipped": [], }, "generate_records": lambda new_records, config, template_path, output_dir, progress_cb=None: { "generated_count": len(new_records), "generated": list(new_records), "download_images": [], }, "set_generation_progress": set_progress, "append_review_log": lambda event, payload=None: "/tmp/review.jsonl", "log_parse_skipped": lambda skipped, source: 0, "append_new_history": lambda history_path, history, records, key_fields: { "added": len(records), "total": len(history) + len(records), }, "upsert_issue_mark": lambda **kwargs: ({"id": "issue_1", **kwargs}, True), "suppress_skip_item": lambda line, reason="": ({"id": "skip_1", "line": line, "reason": reason}, True), "normalize_line": lambda line, pattern: str(line or "").strip(), "normalize_branch_value": lambda value, config: str(value or "").strip(), "normalize_amount_text": lambda value: str(value or "").strip(), "normalize_status_value": lambda value, config: str(value or "").strip(), "infer_page_from_type": lambda type_keyword, config: "page_2", "apply_record_overrides": lambda record, overrides, config: {**record, **(overrides or {})}, "render_output_filename": lambda config, record, index: f"喜报_{index}.png", "validate_record_for_generation": lambda record, config: None, "upsert_history_records": lambda history_path, history, records, key_fields: { "added": len(records), "updated": 0, }, "infer_correction_rule_keyword": lambda **kwargs: "keyword", "save_or_update_manual_rule": lambda **kwargs: {"keyword": kwargs.get("keyword", "keyword")}, "resolve_issue_marks_by_source_line": lambda source_line, reason="": {"count": 0, "ids": []}, "update_issue_mark": lambda **kwargs: {"id": kwargs.get("issue_id", "")}, "delete_issue_mark": lambda issue_id: True, "cleanup_output_artifacts": lambda output_dir: {"removed_dirs": 0, "removed_files": 0}, "clear_skip_suppressions": lambda: 0, "_HISTORY_LOCK": lock, } class TestWorkflows(ServiceTestBase): def test_run_parse_api_success(self) -> None: ctx = self.base_ctx() status, body = run_parse_api({"raw_text": "#接龙\n1、测试"}, ctx) self.assertEqual(int(status), 200) self.assertTrue(body.get("ok")) self.assertIn("result", body) def test_run_generate_api_requires_insurance(self) -> None: ctx = self.base_ctx() def parse_records(*args, **kwargs): return { "has_trigger": True, "records": [], "new_records": [{"branch": "营江路", "amount": "10万", "type": "一年期定期"}], "skipped": [], "needs_insurance_choice": True, } ctx["parse_records"] = parse_records status, body = run_generate_api({"raw_text": "#接龙\n1、测试"}, ctx) self.assertEqual(int(status), 400) self.assertEqual(body.get("error_code"), "insurance_year_required") self.assertTrue(any(item[1] == "need_input" for item in self.progress_calls)) def test_run_generate_api_success(self) -> None: ctx = self.base_ctx() def parse_records(*args, **kwargs): return { "has_trigger": True, "records": [], "new_records": [ { "source_line": "1、营江路揽收现金10万存一年", "raw_text": "营江路揽收现金10万存一年", "branch": "营江路", "amount": "10万", "type": "一年期定期", "page": "page_2", "status": "揽收现金", } ], "skipped": [], "needs_insurance_choice": False, "dedup_key_fields": ["branch", "amount", "type"], } ctx["parse_records"] = parse_records status, body = run_generate_api({"raw_text": "#接龙\n1、测试", "save_history": True}, ctx) self.assertEqual(int(status), 200) self.assertTrue(body.get("ok")) self.assertEqual(int(body.get("generated_count", 0)), 1) self.assertTrue(any(item[1] == "done" for item in self.progress_calls)) def test_run_correction_apply_api_validation(self) -> None: ctx = self.base_ctx() status, body = run_correction_apply_api({}, ctx) self.assertEqual(int(status), 400) self.assertFalse(body.get("ok", True)) class TestPostOps(ServiceTestBase): def test_run_mark_issue_invalid_type(self) -> None: ctx = self.base_ctx() status, body = run_mark_issue({"mark_type": "bad", "source_line": "x"}, ctx) self.assertEqual(int(status), 400) self.assertIn("mark_type", body.get("error", "")) def test_run_history_clear_success(self) -> None: ctx = self.base_ctx() status, body = run_history_clear(ctx) self.assertEqual(int(status), 200) self.assertTrue(body.get("ok")) if __name__ == "__main__": unittest.main()