p1_helpers.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from __future__ import annotations
  2. import copy
  3. import json
  4. from pathlib import Path
  5. from typing import Any
  6. from content_agent.interfaces import QueryVariantResult
  7. REAL_SOURCE_FIXTURE = Path("tests/fixtures/real_case_source/source_context.json")
  8. def real_source_payload(demand_content_id: int = 1) -> dict[str, Any]:
  9. payload = json.loads(REAL_SOURCE_FIXTURE.read_text(encoding="utf-8"))
  10. payload["id"] = demand_content_id
  11. payload["demand_content_id"] = str(demand_content_id)
  12. payload["raw_demand_content"] = copy.deepcopy(payload)
  13. return payload
  14. class FakeDemandSource:
  15. def __init__(self, payload: dict[str, Any] | None = None) -> None:
  16. self.payload = payload or real_source_payload()
  17. self.calls: list[str] = []
  18. def get_default_pg_pattern_source(self) -> dict[str, Any]:
  19. self.calls.append("get_default_pg_pattern_source")
  20. return copy.deepcopy(self.payload)
  21. def get_by_id(self, demand_content_id: int) -> dict[str, Any]:
  22. self.calls.append(f"get_by_id:{demand_content_id}")
  23. payload = copy.deepcopy(self.payload)
  24. payload["id"] = demand_content_id
  25. payload["demand_content_id"] = str(demand_content_id)
  26. return payload
  27. def get_by_run_label(self, run_label: str) -> dict[str, Any]:
  28. self.calls.append(f"get_by_run_label:{run_label}")
  29. payload = copy.deepcopy(self.payload)
  30. payload["run_label"] = run_label
  31. return payload
  32. class FakeQueryVariantClient:
  33. def __init__(
  34. self,
  35. variants: dict[str, str] | None = None,
  36. error: Exception | None = None,
  37. ) -> None:
  38. self.variants = variants or {}
  39. self.error = error
  40. self.calls: list[dict[str, Any]] = []
  41. def generate_variant(
  42. self,
  43. *,
  44. seed_term: str,
  45. evidence_context: dict[str, Any],
  46. ) -> QueryVariantResult:
  47. self.calls.append(
  48. {
  49. "seed_term": seed_term,
  50. "evidence_context": copy.deepcopy(evidence_context),
  51. }
  52. )
  53. if self.error:
  54. raise self.error
  55. return QueryVariantResult(
  56. query=self.variants.get(seed_term, f"{seed_term} 拓展素材"),
  57. model="fake-query-model",
  58. prompt_version="fake-query-prompt-v1",
  59. input_evidence=copy.deepcopy(evidence_context),
  60. )