| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- from __future__ import annotations
- import copy
- import json
- from pathlib import Path
- from typing import Any
- from content_agent.interfaces import QueryVariantResult
# Relative path: resolved against the process's current working directory
# (presumably the repository root when the test suite runs — confirm).
REAL_SOURCE_FIXTURE = Path("tests/fixtures/real_case_source/source_context.json")


def real_source_payload(demand_content_id: int = 1) -> dict[str, Any]:
    """Load the real-case source fixture and stamp it with *demand_content_id*.

    Args:
        demand_content_id: Value written to ``id``. Its string form is
            written to ``demand_content_id``.

    Returns:
        A fresh dict parsed from ``REAL_SOURCE_FIXTURE`` (UTF-8 JSON). Besides
        the two id fields, it has a ``raw_demand_content`` key holding a deep
        copy of the payload as it stood after the ids were stamped. If the
        fixture already contains ``raw_demand_content``, that old value is
        nested inside the copy.
    """
    fixture_text = REAL_SOURCE_FIXTURE.read_text(encoding="utf-8")
    record: dict[str, Any] = json.loads(fixture_text)
    record.update(id=demand_content_id, demand_content_id=str(demand_content_id))
    # Snapshot taken *after* the id overrides, so the raw copy agrees with them.
    record["raw_demand_content"] = copy.deepcopy(record)
    return record
class FakeDemandSource:
    """In-memory stand-in for the demand-content source, for tests.

    Every lookup records a call descriptor in ``calls`` and returns a deep
    copy of ``payload``. Callers can therefore mutate the result freely
    without affecting later lookups.
    """

    def __init__(self, payload: dict[str, Any] | None = None) -> None:
        """Create the fake.

        Args:
            payload: Row returned by every lookup. When ``None``, the real-case
                fixture from ``real_source_payload()`` is loaded. An explicitly
                passed empty dict is kept as-is.
        """
        # Explicit ``None`` check: ``payload or ...`` would silently replace a
        # deliberately empty ``{}`` payload with the real fixture.
        self.payload = real_source_payload() if payload is None else payload
        # Ordered log of lookups made, e.g. "get_by_id:5".
        self.calls: list[str] = []

    def get_default_pg_pattern_source(self) -> dict[str, Any]:
        """Return a deep copy of the payload unchanged."""
        self.calls.append("get_default_pg_pattern_source")
        return copy.deepcopy(self.payload)

    def get_by_id(self, demand_content_id: int) -> dict[str, Any]:
        """Return a deep copy of the payload with its id fields overridden.

        ``id`` is set to *demand_content_id* and ``demand_content_id`` to its
        string form.
        """
        self.calls.append(f"get_by_id:{demand_content_id}")
        payload = copy.deepcopy(self.payload)
        payload["id"] = demand_content_id
        payload["demand_content_id"] = str(demand_content_id)
        # NOTE(review): unlike real_source_payload(), nested data such as
        # payload["raw_demand_content"] is not restamped with the new id —
        # confirm no test relies on the two agreeing.
        return payload

    def get_by_run_label(self, run_label: str) -> dict[str, Any]:
        """Return a deep copy of the payload with ``run_label`` set."""
        self.calls.append(f"get_by_run_label:{run_label}")
        payload = copy.deepcopy(self.payload)
        payload["run_label"] = run_label
        return payload
class FakeQueryVariantClient:
    """Deterministic stand-in for the query-variant generation client.

    ``generate_variant`` records each call in ``calls``, then either raises the
    configured ``error`` or returns a ``QueryVariantResult`` built from
    ``variants``.
    """

    def __init__(
        self,
        variants: dict[str, str] | None = None,
        error: Exception | None = None,
    ) -> None:
        """Create the fake.

        Args:
            variants: Maps a seed term to the query to return for it. The
                mapping is kept by reference, so later additions by the caller
                are visible. Defaults to a fresh empty dict.
            error: When not ``None``, raised by every ``generate_variant`` call
                after the call has been recorded.
        """
        # Explicit ``None`` checks: ``variants or {}`` would silently drop a
        # caller-supplied (empty) mapping in favour of a private one.
        self.variants = {} if variants is None else variants
        self.error = error
        self.calls: list[dict[str, Any]] = []

    def generate_variant(
        self,
        *,
        seed_term: str,
        evidence_context: dict[str, Any],
    ) -> QueryVariantResult:
        """Record the call, then raise ``error`` or return a canned result.

        The query is ``variants[seed_term]`` when present, otherwise a fixed
        suffix is appended to *seed_term*. Both the recorded call and
        ``input_evidence`` hold deep copies of *evidence_context*, so later
        mutation by the caller does not leak in.

        Raises:
            Exception: ``self.error``, when one was configured.
        """
        self.calls.append(
            {
                "seed_term": seed_term,
                "evidence_context": copy.deepcopy(evidence_context),
            }
        )
        if self.error is not None:
            raise self.error
        return QueryVariantResult(
            query=self.variants.get(seed_term, f"{seed_term} 拓展素材"),
            model="fake-query-model",
            prompt_version="fake-query-prompt-v1",
            input_evidence=copy.deepcopy(evidence_context),
        )
|