feat: initial import (exclude templates and runtime temp files)

This commit is contained in:
237899745
2026-02-27 15:21:15 +08:00
commit 0951732c7a
33 changed files with 11698 additions and 0 deletions

0
app/tests/__init__.py Normal file
View File

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import unittest
from server import (
resolve_build_workers,
resolve_extract_workers,
resolve_generation_strategy,
resolve_image_delivery_options,
resolve_memory_reclaim_options,
resolve_single_generation_mode,
)
class TestGenerationStrategy(unittest.TestCase):
def test_default_without_performance_uses_page_template_cache(self) -> None:
self.assertEqual(resolve_generation_strategy({}, total_records=1), "legacy")
self.assertEqual(resolve_generation_strategy({}, total_records=5), "page_template_cache")
def test_page_template_cache_requires_min_records(self) -> None:
cfg = {
"performance": {
"generation_strategy": "page_template_cache",
"template_cache_min_records": 2,
"single_slide_output": True,
}
}
self.assertEqual(resolve_generation_strategy(cfg, total_records=1), "legacy")
self.assertEqual(resolve_generation_strategy(cfg, total_records=2), "page_template_cache")
def test_page_template_cache_respects_single_slide_output(self) -> None:
cfg = {
"performance": {
"generation_strategy": "page_template_cache",
"template_cache_min_records": 1,
"single_slide_output": False,
}
}
self.assertEqual(resolve_generation_strategy(cfg, total_records=3), "legacy")
def test_legacy_strategy_always_returns_legacy(self) -> None:
cfg = {
"performance": {
"generation_strategy": "legacy",
"template_cache_min_records": 1,
"single_slide_output": True,
}
}
self.assertEqual(resolve_generation_strategy(cfg, total_records=20), "legacy")
def test_image_delivery_defaults(self) -> None:
opts = resolve_image_delivery_options({})
self.assertEqual(int(opts.get("max_kbps", 0)), 300)
self.assertEqual(int(opts.get("chunk_size", 0)), 16 * 1024)
def test_image_delivery_disable_limit(self) -> None:
cfg = {"performance": {"image_delivery": {"enabled": False, "max_kbps": 999}}}
opts = resolve_image_delivery_options(cfg)
self.assertEqual(int(opts.get("max_kbps", -1)), 0)
def test_single_generation_mode_defaults_enabled(self) -> None:
self.assertTrue(resolve_single_generation_mode({}))
self.assertFalse(resolve_single_generation_mode({"performance": {"single_generation_mode": False}}))
def test_build_workers_default_single(self) -> None:
self.assertEqual(resolve_build_workers({}, total_records=10), 1)
cfg = {"performance": {"max_build_workers": 3}}
self.assertEqual(resolve_build_workers(cfg, total_records=2), 2)
def test_extract_workers_default_single(self) -> None:
self.assertEqual(resolve_extract_workers({}), 1)
cfg = {"performance": {"max_extract_workers": 4}}
self.assertGreaterEqual(resolve_extract_workers(cfg), 1)
def test_memory_reclaim_defaults_enabled(self) -> None:
opts = resolve_memory_reclaim_options({})
self.assertTrue(opts["enabled"])
self.assertTrue(opts["gc_collect"])
self.assertTrue(opts["malloc_trim"])
def test_memory_reclaim_can_disable(self) -> None:
opts = resolve_memory_reclaim_options({"performance": {"memory_reclaim": {"enabled": False}}})
self.assertFalse(opts["enabled"])
self.assertFalse(opts["gc_collect"])
self.assertFalse(opts["malloc_trim"])
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import unittest
from server import load_config, parse_records
class TestParseAmountSplit(unittest.TestCase):
def test_amount_before_comma_and_term_after_comma(self) -> None:
config = load_config()
raw_text = "#接龙\n11、 四马桥一潘纪君 拜访门窗商户微信提现15万存定期一年"
result = parse_records(raw_text, config, history=[])
records = result.get("new_records", [])
self.assertTrue(
any(
str(item.get("branch", "")) == "四马桥"
and str(item.get("type", "")) == "一年期定期"
and str(item.get("amount", "")) == "15万"
for item in records
)
)
skipped_reasons = [str(x.get("reason", "")) for x in result.get("skipped", [])]
self.assertNotIn("amount_not_found", skipped_reasons)
if __name__ == "__main__":
unittest.main()

163
app/tests/test_services.py Normal file
View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import threading
import unittest
from pathlib import Path
from services.post_ops import (
run_history_clear,
run_mark_issue,
)
from services.workflows import (
run_correction_apply_api,
run_generate_api,
run_parse_api,
)
class ServiceTestBase(unittest.TestCase):
def setUp(self) -> None:
self.progress_calls: list[tuple[str, str, int, str, str]] = []
def base_ctx(self) -> dict:
lock = threading.Lock()
def set_progress(token: str, *, status: str, stage: str, percent: int, detail: str = "", error: str = ""):
self.progress_calls.append((token, status, int(percent), str(stage), str(detail or error)))
return {
"normalize_insurance_year": lambda v: str(v) if v in {"3", "5"} else None,
"normalize_insurance_year_choices": lambda v: v if isinstance(v, dict) else {},
"load_config": lambda: {
"relay_handling": {
"dedup": {"key_fields": ["branch", "amount", "type"]},
"parse_rules": {"line_pattern": r"^\\d+、\\s*"},
}
},
"resolve_history_path": lambda config: Path("/tmp/xibao_test_history.json"),
"resolve_template_path": lambda config, override=None: Path("/tmp/template.pptx"),
"resolve_output_dir": lambda config, override=None: Path("/tmp/xibao_output"),
"load_history": lambda history_path: [],
"save_history": lambda history_path, records: None,
"parse_records": lambda raw_text, config, history, insurance_year_choice, insurance_year_choices: {
"has_trigger": True,
"records": [],
"new_records": [],
"skipped": [],
},
"generate_records": lambda new_records, config, template_path, output_dir, progress_cb=None: {
"generated_count": len(new_records),
"generated": list(new_records),
"download_images": [],
},
"set_generation_progress": set_progress,
"append_review_log": lambda event, payload=None: "/tmp/review.jsonl",
"log_parse_skipped": lambda skipped, source: 0,
"append_new_history": lambda history_path, history, records, key_fields: {
"added": len(records),
"total": len(history) + len(records),
},
"upsert_issue_mark": lambda **kwargs: ({"id": "issue_1", **kwargs}, True),
"suppress_skip_item": lambda line, reason="": ({"id": "skip_1", "line": line, "reason": reason}, True),
"normalize_line": lambda line, pattern: str(line or "").strip(),
"normalize_branch_value": lambda value, config: str(value or "").strip(),
"normalize_amount_text": lambda value: str(value or "").strip(),
"normalize_status_value": lambda value, config: str(value or "").strip(),
"infer_page_from_type": lambda type_keyword, config: "page_2",
"apply_record_overrides": lambda record, overrides, config: {**record, **(overrides or {})},
"render_output_filename": lambda config, record, index: f"喜报_{index}.png",
"validate_record_for_generation": lambda record, config: None,
"upsert_history_records": lambda history_path, history, records, key_fields: {
"added": len(records),
"updated": 0,
},
"infer_correction_rule_keyword": lambda **kwargs: "keyword",
"save_or_update_manual_rule": lambda **kwargs: {"keyword": kwargs.get("keyword", "keyword")},
"resolve_issue_marks_by_source_line": lambda source_line, reason="": {"count": 0, "ids": []},
"update_issue_mark": lambda **kwargs: {"id": kwargs.get("issue_id", "")},
"delete_issue_mark": lambda issue_id: True,
"cleanup_output_artifacts": lambda output_dir: {"removed_dirs": 0, "removed_files": 0},
"clear_skip_suppressions": lambda: 0,
"_HISTORY_LOCK": lock,
}
class TestWorkflows(ServiceTestBase):
def test_run_parse_api_success(self) -> None:
ctx = self.base_ctx()
status, body = run_parse_api({"raw_text": "#接龙\n1、测试"}, ctx)
self.assertEqual(int(status), 200)
self.assertTrue(body.get("ok"))
self.assertIn("result", body)
def test_run_generate_api_requires_insurance(self) -> None:
ctx = self.base_ctx()
def parse_records(*args, **kwargs):
return {
"has_trigger": True,
"records": [],
"new_records": [{"branch": "营江路", "amount": "10万", "type": "一年期定期"}],
"skipped": [],
"needs_insurance_choice": True,
}
ctx["parse_records"] = parse_records
status, body = run_generate_api({"raw_text": "#接龙\n1、测试"}, ctx)
self.assertEqual(int(status), 400)
self.assertEqual(body.get("error_code"), "insurance_year_required")
self.assertTrue(any(item[1] == "need_input" for item in self.progress_calls))
def test_run_generate_api_success(self) -> None:
ctx = self.base_ctx()
def parse_records(*args, **kwargs):
return {
"has_trigger": True,
"records": [],
"new_records": [
{
"source_line": "1、营江路揽收现金10万存一年",
"raw_text": "营江路揽收现金10万存一年",
"branch": "营江路",
"amount": "10万",
"type": "一年期定期",
"page": "page_2",
"status": "揽收现金",
}
],
"skipped": [],
"needs_insurance_choice": False,
"dedup_key_fields": ["branch", "amount", "type"],
}
ctx["parse_records"] = parse_records
status, body = run_generate_api({"raw_text": "#接龙\n1、测试", "save_history": True}, ctx)
self.assertEqual(int(status), 200)
self.assertTrue(body.get("ok"))
self.assertEqual(int(body.get("generated_count", 0)), 1)
self.assertTrue(any(item[1] == "done" for item in self.progress_calls))
def test_run_correction_apply_api_validation(self) -> None:
ctx = self.base_ctx()
status, body = run_correction_apply_api({}, ctx)
self.assertEqual(int(status), 400)
self.assertFalse(body.get("ok", True))
class TestPostOps(ServiceTestBase):
def test_run_mark_issue_invalid_type(self) -> None:
ctx = self.base_ctx()
status, body = run_mark_issue({"mark_type": "bad", "source_line": "x"}, ctx)
self.assertEqual(int(status), 400)
self.assertIn("mark_type", body.get("error", ""))
def test_run_history_clear_success(self) -> None:
ctx = self.base_ctx()
status, body = run_history_clear(ctx)
self.assertEqual(int(status), 200)
self.assertTrue(body.get("ok"))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import unittest
from server import load_config, parse_records
class TestStatusAlias(unittest.TestCase):
def test_collect_external_bank_maps_to_collect_other_bank(self) -> None:
config = load_config()
raw_text = "#接龙\n21、 濂溪揽收外行5.3万存一年"
result = parse_records(raw_text, config, history=[])
records = result.get("new_records", [])
self.assertTrue(
any(
str(item.get("branch", "")) == "濂溪"
and str(item.get("type", "")) == "一年期定期"
and str(item.get("status", "")) == "揽收他行"
for item in records
)
)
def test_transfer_external_bank_maps_to_transfer_other_bank(self) -> None:
config = load_config()
raw_text = "#接龙\n22、 潇水南路挖转外行20万存半年"
result = parse_records(raw_text, config, history=[])
records = result.get("new_records", [])
self.assertTrue(
any(
str(item.get("branch", "")) == "潇水南路"
and str(item.get("type", "")) == "六个月定期"
and str(item.get("status", "")) == "他行挖转"
for item in records
)
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
from __future__ import annotations
import unittest
from argparse import Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from wechat_bot_bridge import WechatXibaoBridge
class WechatBotBridgeSkipTests(unittest.TestCase):
def _new_bridge(self, tmp_dir: str) -> WechatXibaoBridge:
base = Path(tmp_dir)
args = Namespace(
wechat_base_url="http://127.0.0.1:18238",
xibao_base_url="http://127.0.0.1:8787",
wechat_auth_key="",
wechat_session_file=str(base / "session.json"),
sync_count=30,
poll_interval=2.0,
max_images=3,
once=True,
dry_run=True,
allow_from="",
state_file=str(base / "state.json"),
meta_file=str(base / "meta.json"),
daily_cleanup_time="00:10",
)
return WechatXibaoBridge(args)
def test_parse_skip_command(self) -> None:
with TemporaryDirectory() as tmp_dir:
bridge = self._new_bridge(tmp_dir)
cmd = bridge.parse_command("跳过3")
self.assertIsNotNone(cmd)
self.assertEqual(cmd.get("action"), "set_skip")
self.assertEqual(cmd.get("count"), "3")
cmd2 = bridge.parse_command("跳过12")
self.assertIsNotNone(cmd2)
self.assertEqual(cmd2.get("action"), "set_skip")
self.assertEqual(cmd2.get("count"), "12")
def test_apply_daily_skip_to_numbered_lines(self) -> None:
with TemporaryDirectory() as tmp_dir:
bridge = self._new_bridge(tmp_dir)
bridge.set_daily_skip("wxid_u1", 2)
raw = "#接龙\n1、第一条\n2、第二条\n3、第三条\n说明文本"
text, removed, requested = bridge.apply_daily_skip_to_raw_text("wxid_u1", raw)
self.assertEqual(requested, 2)
self.assertEqual(removed, 2)
self.assertIn("3、第三条", text)
self.assertNotIn("1、第一条", text)
self.assertNotIn("2、第二条", text)
def test_skip_zero_clears_setting(self) -> None:
with TemporaryDirectory() as tmp_dir:
bridge = self._new_bridge(tmp_dir)
bridge.set_daily_skip("wxid_u2", 5)
self.assertEqual(bridge.get_daily_skip("wxid_u2"), 5)
bridge.set_daily_skip("wxid_u2", 0)
self.assertEqual(bridge.get_daily_skip("wxid_u2"), 0)
def test_stale_daily_skip_is_invalidated(self) -> None:
with TemporaryDirectory() as tmp_dir:
bridge = self._new_bridge(tmp_dir)
bridge.meta["daily_skip"] = {
"wxid_old": {
"date": "2000-01-01",
"count": 3,
"updated_at": "2000-01-01T00:00:00",
}
}
value = bridge.get_daily_skip("wxid_old")
self.assertEqual(value, 0)
self.assertNotIn("wxid_old", bridge.meta.get("daily_skip", {}))
if __name__ == "__main__":
unittest.main()