# test_search_intent.py
  1. import copy
  2. import pytest
  3. from content_agent.business_modules import search_intent
  4. from content_agent.errors import ContentAgentError
  5. from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE
  6. from content_agent.run_service import RunService
  7. from content_agent.schemas import RunStartRequest
  8. from tests.p1_helpers import FakeQueryVariantClient, REAL_SOURCE_FIXTURE
  9. FORBIDDEN_FIXED_BUSINESS_TERMS = [
  10. "\u8d2a\u8150",
  11. "\u57fa\u5c42\u516c\u804c\u4eba\u5458",
  12. "\u6848\u4f8b",
  13. "\u89e3\u8bfb",
  14. "\u8b66\u793a",
  15. ]
  16. class _Runtime:
  17. def __init__(self):
  18. self.rows = {}
  19. def append_jsonl(self, _run_id, filename, rows):
  20. self.rows[filename] = rows
  21. def _seed_pack():
  22. return {
  23. "seed_terms": ["中医养生"],
  24. "itemset_items": ["补气血"],
  25. "category_bindings": [{"category_id": "c1"}],
  26. "element_bindings": [{"element_id": "e1"}],
  27. "pattern_source_system": "pg_pattern_v2",
  28. "pattern_execution_id": 1987,
  29. "mining_config_id": 58,
  30. "source_post_id": "60219550",
  31. "matched_post_ids": ["60219550"],
  32. "itemset_ids": [1607977],
  33. "support": 0.2,
  34. "absolute_support": 31,
  35. "confidence": 0.8,
  36. }
  37. def test_search_seed_and_queries_do_not_inject_fixed_business_terms(tmp_path):
  38. service = RunService(
  39. runtime_root=tmp_path / "runtime" / "v1",
  40. query_variant_client=FakeQueryVariantClient(
  41. {
  42. "爱国情感": "家国叙事素材",
  43. "人物故事": "榜样人物素材",
  44. }
  45. ),
  46. )
  47. state = service.start_run(
  48. RunStartRequest(platform_mode="mock", source=str(REAL_SOURCE_FIXTURE))
  49. )
  50. run_id = state["run_id"]
  51. pattern_seed_pack = service.read_json(run_id, "pattern_seed_pack.json")
  52. queries = service.read_jsonl(run_id, "search_queries.jsonl")
  53. p2_queries = [
  54. row
  55. for row in queries
  56. if row["search_query_generation_method"] in {"item_single", "llm_variant"}
  57. ]
  58. assert pattern_seed_pack["seed_terms"] == ["爱国情感", "人物故事"]
  59. assert [row["search_query_id"] for row in p2_queries] == ["q_001", "q_002", "q_003", "q_004"]
  60. assert [row["search_query"] for row in p2_queries] == [
  61. "爱国情感",
  62. "家国叙事素材",
  63. "人物故事",
  64. "榜样人物素材",
  65. ]
  66. assert [row["search_query_generation_method"] for row in p2_queries] == [
  67. "item_single",
  68. "llm_variant",
  69. "item_single",
  70. "llm_variant",
  71. ]
  72. assert p2_queries[1]["llm_variant_of"] == "q_001"
  73. assert p2_queries[3]["llm_variant_of"] == "q_003"
  74. for value in [
  75. *pattern_seed_pack["seed_terms"],
  76. *(row["search_query"] for row in p2_queries),
  77. ]:
  78. assert not any(term in value for term in FORBIDDEN_FIXED_BUSINESS_TERMS)
  79. def test_search_queries_preserve_source_terms_for_replay(tmp_path):
  80. service = RunService(
  81. runtime_root=tmp_path / "runtime" / "v1",
  82. query_variant_client=FakeQueryVariantClient(
  83. {
  84. "爱国情感": "家国叙事素材",
  85. "人物故事": "榜样人物素材",
  86. }
  87. ),
  88. )
  89. state = service.start_run(
  90. RunStartRequest(platform_mode="mock", source=str(REAL_SOURCE_FIXTURE))
  91. )
  92. queries = service.read_jsonl(state["run_id"], "search_queries.jsonl")
  93. p2_queries = [
  94. query
  95. for query in queries
  96. if query["search_query_generation_method"] in {"item_single", "llm_variant"}
  97. ]
  98. expected_source_terms = [["爱国情感"], ["爱国情感"], ["人物故事"], ["人物故事"]]
  99. for query, source_terms in zip(p2_queries, expected_source_terms, strict=True):
  100. assert query["query_source_terms"] == source_terms
  101. assert query["query_source_fields"] == ["seed_terms"]
  102. assert query["raw_payload"]["query_source_terms"] == source_terms
  103. assert query["pattern_seed_ref"]["source_field"] == "seed_terms"
  104. assert query["pattern_seed_ref"]["seed_term"] == source_terms[0]
  105. assert query["raw_payload"]["pattern_seed_ref"]["seed_term"] == source_terms[0]
  106. llm_queries = [
  107. query
  108. for query in p2_queries
  109. if query["search_query_generation_method"] == "llm_variant"
  110. ]
  111. assert len(llm_queries) == 2
  112. for query in llm_queries:
  113. assert query["raw_payload"]["llm_prompt_version"] == "fake-query-prompt-v1"
  114. assert query["raw_payload"]["llm_generation_model"] == "fake-query-model"
  115. assert query["raw_payload"]["llm_input_evidence"]["source_field"] == "seed_terms"
  116. assert query["raw_payload"]["llm_input_evidence"]["itemset_items"]
  117. def test_search_intent_custom_evidence_fields_whitelist():
  118. client = FakeQueryVariantClient({"中医养生": "气血食疗"})
  119. client.profile = copy.deepcopy(DEFAULT_PROFILE)
  120. client.profile["evidence_fields"] = ["seed_term", "support"]
  121. runtime = _Runtime()
  122. queries = search_intent.run("run_1", "policy_1", _seed_pack(), runtime, client)
  123. llm_query = [row for row in queries if row["search_query_generation_method"] == "llm_variant"][0]
  124. assert list(llm_query["llm_input_evidence"].keys()) == ["seed_term", "support"]
  125. assert list(llm_query["raw_payload"]["llm_input_evidence"].keys()) == ["seed_term", "support"]
  126. assert llm_query["query_source_fields"] == ["seed_terms"]
  127. def test_search_intent_custom_generic_filter_blocks_query():
  128. client = FakeQueryVariantClient({"中医养生": "禁用泛词"})
  129. client.profile = copy.deepcopy(DEFAULT_PROFILE)
  130. client.profile["generic_filter"] = {"queries": ["禁用泛词"], "tokens": []}
  131. with pytest.raises(ContentAgentError) as exc:
  132. search_intent.run("run_1", "policy_1", _seed_pack(), _Runtime(), client)
  133. assert exc.value.error_code == "QUERY_GENERATION_FAILED"
  134. assert exc.value.detail["reason"] == "llm_variant_generic"
  135. def test_search_intent_rejects_unsupported_variants_per_seed():
  136. client = FakeQueryVariantClient({"中医养生": "气血食疗"})
  137. client.profile = copy.deepcopy(DEFAULT_PROFILE)
  138. client.profile["variants_per_seed"] = 2
  139. with pytest.raises(ContentAgentError) as exc:
  140. search_intent.run("run_1", "policy_1", _seed_pack(), _Runtime(), client)
  141. assert exc.value.error_code == "QUERY_GENERATION_FAILED"
  142. assert exc.value.detail == {"reason": "variants_per_seed_unsupported", "variants_per_seed": 2}