search_intent.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. from __future__ import annotations
  2. from datetime import datetime, timezone
  3. from typing import Any
  4. from content_agent.constants import RUNTIME_RECORD_SCHEMA_VERSION
  5. from content_agent.errors import ContentAgentError, ErrorCode
  6. from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE
  7. from content_agent.interfaces import QueryVariantClient, QueryVariantResult, RuntimeFileStore
  8. from content_agent.record_payload import with_raw_payload
  9. GENERIC_QUERIES = {
  10. "内容",
  11. "视频",
  12. "热门",
  13. "推荐",
  14. "短视频",
  15. "热门视频",
  16. "推荐视频",
  17. "热门内容",
  18. "推荐内容",
  19. "相关视频",
  20. "相关内容",
  21. "热" + "点视频",
  22. "热" + "点内容",
  23. }
  24. GENERIC_QUERY_TOKENS = (
  25. "短视频",
  26. "热门",
  27. "推荐",
  28. "相关",
  29. "热" + "点",
  30. "内容",
  31. "视频",
  32. "素材",
  33. "资料",
  34. "信息",
  35. "话题",
  36. )
  37. def run(
  38. run_id: str,
  39. policy_run_id: str,
  40. pattern_seed_pack: dict[str, Any],
  41. runtime: RuntimeFileStore,
  42. query_variant_client: QueryVariantClient,
  43. ) -> list[dict[str, Any]]:
  44. created_at = datetime.now(timezone.utc).isoformat()
  45. seed_terms = _terms(pattern_seed_pack.get("seed_terms"))
  46. if not seed_terms:
  47. raise _query_generation_error("seed_terms_empty")
  48. profile = getattr(query_variant_client, "profile", DEFAULT_PROFILE)
  49. variants_per_seed = int(profile.get("variants_per_seed", 1))
  50. if variants_per_seed != 1:
  51. raise _query_generation_error(
  52. "variants_per_seed_unsupported",
  53. {"variants_per_seed": variants_per_seed},
  54. )
  55. evidence_fields = profile.get("evidence_fields")
  56. generic_filter = profile.get("generic_filter")
  57. search_queries: list[dict[str, Any]] = []
  58. seen_queries: set[str] = set()
  59. for seed_index, seed_term in enumerate(seed_terms):
  60. item_query_id = f"q_{seed_index * 2 + 1:03d}"
  61. variant_query_id = f"q_{seed_index * 2 + 2:03d}"
  62. pattern_seed_ref = _pattern_seed_ref(pattern_seed_pack, seed_term, seed_index)
  63. item_single = _base_query_record(
  64. run_id=run_id,
  65. policy_run_id=policy_run_id,
  66. search_query_id=item_query_id,
  67. search_query=seed_term,
  68. generation_method="item_single",
  69. seed_term=seed_term,
  70. pattern_seed_ref=pattern_seed_ref,
  71. created_at=created_at,
  72. )
  73. _reserve_query(item_single, seen_queries, seed_term=seed_term, method="item_single")
  74. search_queries.append(item_single)
  75. evidence_context = _llm_input_evidence(
  76. pattern_seed_pack,
  77. seed_terms,
  78. seed_term,
  79. seed_index,
  80. sorted(seen_queries),
  81. evidence_fields=evidence_fields,
  82. )
  83. variant = _generate_variant(query_variant_client, seed_term, evidence_context)
  84. variant_query = _normalize_query(variant.query)
  85. _validate_variant_query(variant_query, seed_term, seen_queries, generic_filter=generic_filter)
  86. llm_variant = _base_query_record(
  87. run_id=run_id,
  88. policy_run_id=policy_run_id,
  89. search_query_id=variant_query_id,
  90. search_query=variant_query,
  91. generation_method="llm_variant",
  92. seed_term=seed_term,
  93. pattern_seed_ref=pattern_seed_ref,
  94. created_at=created_at,
  95. )
  96. llm_variant.update(
  97. {
  98. "llm_variant_of": item_query_id,
  99. "llm_input_evidence": variant.input_evidence,
  100. "llm_prompt_version": variant.prompt_version,
  101. "llm_generation_model": variant.model,
  102. }
  103. )
  104. _reserve_query(llm_variant, seen_queries, seed_term=seed_term, method="llm_variant")
  105. search_queries.append(llm_variant)
  106. expected_count = len(seed_terms) * 2
  107. if len(search_queries) != expected_count:
  108. raise _query_generation_error(
  109. "query_count_mismatch",
  110. {
  111. "seed_terms_count": len(seed_terms),
  112. "expected_query_count": expected_count,
  113. "actual_query_count": len(search_queries),
  114. },
  115. )
  116. search_queries = [with_raw_payload(row) for row in search_queries]
  117. runtime.append_jsonl(run_id, "search_queries.jsonl", search_queries)
  118. return search_queries
  119. def _terms(values: Any) -> list[str]:
  120. if not isinstance(values, list):
  121. return []
  122. unique: list[str] = []
  123. seen: set[str] = set()
  124. for value in values:
  125. if not isinstance(value, str):
  126. continue
  127. term = " ".join(value.split()).strip()
  128. if not term or term in seen:
  129. continue
  130. seen.add(term)
  131. unique.append(term)
  132. return unique
  133. def _base_query_record(
  134. *,
  135. run_id: str,
  136. policy_run_id: str,
  137. search_query_id: str,
  138. search_query: str,
  139. generation_method: str,
  140. seed_term: str,
  141. pattern_seed_ref: dict[str, Any],
  142. created_at: str,
  143. ) -> dict[str, Any]:
  144. return {
  145. "record_schema_version": RUNTIME_RECORD_SCHEMA_VERSION,
  146. "run_id": run_id,
  147. "policy_run_id": policy_run_id,
  148. "search_query_id": search_query_id,
  149. "search_query": search_query,
  150. "search_query_generation_method": generation_method,
  151. "discovery_start_source": "pattern_itemset",
  152. "previous_discovery_step": "pattern_search_query",
  153. "search_query_effect_status": "pending",
  154. "query_source_terms": [seed_term],
  155. "query_source_fields": ["seed_terms"],
  156. "pattern_seed_ref": pattern_seed_ref,
  157. "created_at": created_at,
  158. }
  159. def _generate_variant(
  160. query_variant_client: QueryVariantClient,
  161. seed_term: str,
  162. evidence_context: dict[str, Any],
  163. ) -> QueryVariantResult:
  164. try:
  165. result = query_variant_client.generate_variant(
  166. seed_term=seed_term,
  167. evidence_context=evidence_context,
  168. )
  169. except ContentAgentError:
  170. raise
  171. except Exception as exc:
  172. raise _query_generation_error(
  173. "llm_variant_exception",
  174. {
  175. "seed_term": seed_term,
  176. "exception_type": type(exc).__name__,
  177. },
  178. ) from exc
  179. if not isinstance(result, QueryVariantResult):
  180. raise _query_generation_error(
  181. "llm_variant_result_invalid",
  182. {
  183. "seed_term": seed_term,
  184. "result_type": type(result).__name__,
  185. },
  186. )
  187. return result
  188. def _validate_variant_query(
  189. query: str,
  190. seed_term: str,
  191. seen_queries: set[str],
  192. *,
  193. generic_filter: dict[str, Any] | None = None,
  194. ) -> None:
  195. if not query:
  196. raise _query_generation_error("llm_variant_empty", {"seed_term": seed_term})
  197. if query == seed_term:
  198. raise _query_generation_error("llm_variant_same_as_seed", {"seed_term": seed_term})
  199. if query in seen_queries:
  200. raise _query_generation_error(
  201. "llm_variant_duplicate",
  202. {
  203. "seed_term": seed_term,
  204. "search_query": query,
  205. },
  206. )
  207. if _is_generic_query(query, generic_filter=generic_filter):
  208. raise _query_generation_error(
  209. "llm_variant_generic",
  210. {
  211. "seed_term": seed_term,
  212. "search_query": query,
  213. },
  214. )
  215. def _reserve_query(
  216. row: dict[str, Any],
  217. seen_queries: set[str],
  218. *,
  219. seed_term: str,
  220. method: str,
  221. ) -> None:
  222. query = _normalize_query(row.get("search_query", ""))
  223. if not query:
  224. raise _query_generation_error(
  225. "search_query_empty",
  226. {
  227. "seed_term": seed_term,
  228. "search_query_generation_method": method,
  229. },
  230. )
  231. if query in seen_queries:
  232. raise _query_generation_error(
  233. "search_query_duplicate",
  234. {
  235. "seed_term": seed_term,
  236. "search_query": query,
  237. "search_query_generation_method": method,
  238. },
  239. )
  240. row["search_query"] = query
  241. seen_queries.add(query)
  242. def _pattern_seed_ref(
  243. pattern_seed_pack: dict[str, Any],
  244. seed_term: str,
  245. seed_index: int,
  246. ) -> dict[str, Any]:
  247. return {
  248. "source_field": "seed_terms",
  249. "source_index": seed_index,
  250. "seed_term": seed_term,
  251. "pattern_execution_id": pattern_seed_pack.get("pattern_execution_id"),
  252. "mining_config_id": pattern_seed_pack.get("mining_config_id"),
  253. "source_post_id": pattern_seed_pack.get("source_post_id"),
  254. "matched_post_ids": pattern_seed_pack.get("matched_post_ids") or [],
  255. "itemset_ids": _itemset_ids(pattern_seed_pack),
  256. }
  257. def _llm_input_evidence(
  258. pattern_seed_pack: dict[str, Any],
  259. seed_terms: list[str],
  260. seed_term: str,
  261. seed_index: int,
  262. existing_search_queries: list[str],
  263. evidence_fields: list[str] | None = None,
  264. ) -> dict[str, Any]:
  265. evidence = {
  266. "seed_term": seed_term,
  267. "seed_terms": seed_terms,
  268. "existing_search_queries": existing_search_queries,
  269. "source_field": "seed_terms",
  270. "source_index": seed_index,
  271. "itemset_items": pattern_seed_pack.get("itemset_items")
  272. or pattern_seed_pack.get("itemsets")
  273. or [],
  274. "category_bindings": pattern_seed_pack.get("category_bindings") or [],
  275. "element_bindings": pattern_seed_pack.get("element_bindings") or [],
  276. "pattern_source_system": pattern_seed_pack.get("pattern_source_system"),
  277. "pattern_execution_id": pattern_seed_pack.get("pattern_execution_id"),
  278. "mining_config_id": pattern_seed_pack.get("mining_config_id"),
  279. "source_post_id": pattern_seed_pack.get("source_post_id"),
  280. "matched_post_ids": pattern_seed_pack.get("matched_post_ids") or [],
  281. "itemset_ids": _itemset_ids(pattern_seed_pack),
  282. "support": pattern_seed_pack.get("support"),
  283. "absolute_support": pattern_seed_pack.get("absolute_support"),
  284. "confidence": pattern_seed_pack.get("confidence"),
  285. }
  286. if evidence_fields is None:
  287. return evidence
  288. return {field: evidence[field] for field in evidence_fields if field in evidence}
  289. def _itemset_ids(pattern_seed_pack: dict[str, Any]) -> list[Any]:
  290. direct = pattern_seed_pack.get("itemset_ids")
  291. if isinstance(direct, list):
  292. return direct
  293. itemsets = pattern_seed_pack.get("itemsets")
  294. if not isinstance(itemsets, list):
  295. return []
  296. ids: list[Any] = []
  297. for itemset in itemsets:
  298. if not isinstance(itemset, dict):
  299. continue
  300. itemset_id = itemset.get("itemset_id")
  301. if itemset_id is not None:
  302. ids.append(itemset_id)
  303. return ids
  304. def _normalize_query(value: Any) -> str:
  305. if not isinstance(value, str):
  306. return ""
  307. return " ".join(value.split()).strip()
  308. def _is_generic_query(query: str, generic_filter: dict[str, Any] | None = None) -> bool:
  309. generic_queries = set((generic_filter or {}).get("queries") or GENERIC_QUERIES)
  310. generic_tokens = tuple((generic_filter or {}).get("tokens") or GENERIC_QUERY_TOKENS)
  311. compact = "".join(query.split())
  312. if not compact or len(compact) <= 1:
  313. return True
  314. if not any(char.isalnum() for char in compact):
  315. return True
  316. if compact in generic_queries:
  317. return True
  318. remainder = compact
  319. for token in generic_tokens:
  320. remainder = remainder.replace(token, "")
  321. return not remainder
  322. def _query_generation_error(
  323. reason: str,
  324. detail: dict[str, Any] | None = None,
  325. ) -> ContentAgentError:
  326. return ContentAgentError(
  327. ErrorCode.QUERY_GENERATION_FAILED,
  328. "query generation failed",
  329. {
  330. "reason": reason,
  331. **(detail or {}),
  332. },
  333. )