test_query_variant.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import copy
  2. from content_agent.integrations import query_variant
  3. from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE
  4. from content_agent.integrations.query_variant import (
  5. MissingQueryVariantClient,
  6. OpenRouterQueryVariantClient,
  7. _messages,
  8. _render_messages,
  9. query_variant_client_from_env,
  10. )
  11. def test_query_variant_client_empty_env_does_not_read_process_env(monkeypatch):
  12. monkeypatch.setenv("OPENROUTER_API_KEY", "process-secret")
  13. monkeypatch.setenv("CONTENT_AGENT_QUERY_LLM_MODEL", "process-model")
  14. client = query_variant_client_from_env({})
  15. assert isinstance(client, MissingQueryVariantClient)
  16. def test_query_variant_client_uses_model_fallback_without_exposing_key():
  17. client = query_variant_client_from_env(
  18. {
  19. "OPEN_ROUTER_API_KEY": "test-secret",
  20. "MODEL": "test-model",
  21. "OPENROUTER_BASE_URL": "https://example.invalid/api/v1",
  22. "CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS": "12",
  23. "CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION": "prompt-test",
  24. }
  25. )
  26. assert isinstance(client, OpenRouterQueryVariantClient)
  27. assert client.model == "test-model"
  28. assert client.timeout_seconds == 12
  29. assert client.prompt_version == "query_variant.v1"
  30. def test_default_render_messages_matches_legacy_messages():
  31. evidence = {"seed_term": "中医养生", "support": 0.2}
  32. rendered = _render_messages(DEFAULT_PROFILE, "中医养生", evidence)
  33. assert rendered == _messages("中医养生", evidence)
  34. assert str(evidence) in rendered[1]["content"]
  35. assert '"seed_term"' not in rendered[1]["content"]
  36. def test_openrouter_client_uses_custom_profile(monkeypatch):
  37. profile = copy.deepcopy(DEFAULT_PROFILE)
  38. profile.update(
  39. {
  40. "prompt_version": "custom-query-v2",
  41. "system": "custom system",
  42. "user": "Seed={seed_term}; Evidence={evidence_context}",
  43. "temperature": 0.9,
  44. "max_tokens": 23,
  45. }
  46. )
  47. captured = {}
  48. class FakeResponse:
  49. def raise_for_status(self):
  50. return None
  51. def json(self):
  52. return {"choices": [{"message": {"content": " 气血食疗 "}}]}
  53. def fake_post(url, *, headers, json, timeout):
  54. captured.update({"url": url, "headers": headers, "json": json, "timeout": timeout})
  55. return FakeResponse()
  56. monkeypatch.setattr(query_variant.httpx, "post", fake_post)
  57. client = OpenRouterQueryVariantClient(
  58. api_key="secret",
  59. model="model-x",
  60. base_url="https://example.invalid/api/v1",
  61. timeout_seconds=7,
  62. prompt_version="ignored-env-version",
  63. profile=profile,
  64. )
  65. result = client.generate_variant(seed_term="中医养生", evidence_context={"support": 0.2})
  66. assert result.query == "气血食疗"
  67. assert result.prompt_version == "custom-query-v2"
  68. assert captured["timeout"] == 7
  69. assert captured["json"]["temperature"] == 0.9
  70. assert captured["json"]["max_tokens"] == 23
  71. assert captured["json"]["messages"] == [
  72. {"role": "system", "content": "custom system"},
  73. {"role": "user", "content": "Seed=中医养生; Evidence={'support': 0.2}"},
  74. ]