| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- from __future__ import annotations
- from datetime import datetime, timezone
- from typing import Any
- from content_agent.constants import RUNTIME_RECORD_SCHEMA_VERSION
- from content_agent.errors import ContentAgentError, ErrorCode
- from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE
- from content_agent.interfaces import QueryVariantClient, QueryVariantResult, RuntimeFileStore
- from content_agent.record_payload import with_raw_payload
# Whole queries that are considered too generic to be useful. A query whose
# whitespace-stripped form is in this set is flagged by _is_generic_query,
# unless the caller's generic_filter supplies its own non-empty "queries".
# NOTE(review): two entries are written as "热" + "点…" concatenations; at
# runtime they evaluate to the joined string. The reason for splitting the
# literal is not visible in this file - presumably to keep the joined word out
# of the source text; confirm before merging the literals.
GENERIC_QUERIES = {
    "内容",
    "视频",
    "热门",
    "推荐",
    "短视频",
    "热门视频",
    "推荐视频",
    "热门内容",
    "推荐内容",
    "相关视频",
    "相关内容",
    "热" + "点视频",
    "热" + "点内容",
}
# Filler tokens used by _is_generic_query: every occurrence of each token is
# removed from the whitespace-stripped query, and the query counts as generic
# when nothing is left. The longer token "短视频" is listed ahead of its
# substring "视频" so that it is stripped as a whole.
# NOTE(review): "热" + "点" evaluates to the joined two-character string at
# runtime; the reason for splitting the literal is not visible here - confirm.
GENERIC_QUERY_TOKENS = (
    "短视频",
    "热门",
    "推荐",
    "相关",
    "热" + "点",
    "内容",
    "视频",
    "素材",
    "资料",
    "信息",
    "话题",
)
def run(
    run_id: str,
    policy_run_id: str,
    pattern_seed_pack: dict[str, Any],
    runtime: RuntimeFileStore,
    query_variant_client: QueryVariantClient,
) -> list[dict[str, Any]]:
    """Generate search queries from a pattern seed pack and persist them.

    For every distinct seed term two records are produced, in this order:
    an ``item_single`` record whose query is the seed term itself, and an
    ``llm_variant`` record whose query is obtained from
    ``query_variant_client``. All records are passed through
    ``with_raw_payload`` and appended to ``search_queries.jsonl`` for
    ``run_id`` only after every seed has been processed, so a failure on any
    seed writes nothing.

    Args:
        run_id: Identifier stored on each record and used as the target of
            the JSONL append.
        policy_run_id: Identifier stored on each record.
        pattern_seed_pack: Source mapping; ``seed_terms`` is required, the
            remaining keys are copied into provenance / LLM evidence.
        runtime: Store that receives the generated rows.
        query_variant_client: Client that proposes one variant per seed. Its
            optional ``profile`` attribute configures ``variants_per_seed``,
            ``evidence_fields`` and ``generic_filter``.

    Returns:
        The rows that were appended, two per seed term.

    Raises:
        ContentAgentError: With ``ErrorCode.QUERY_GENERATION_FAILED`` when
            there are no seed terms, the profile asks for anything other
            than exactly one variant per seed (including a value that cannot
            be read as an integer), or a query is empty, duplicated,
            generic, or could not be generated.
    """
    created_at = datetime.now(timezone.utc).isoformat()
    seed_terms = _terms(pattern_seed_pack.get("seed_terms"))
    if not seed_terms:
        raise _query_generation_error("seed_terms_empty")

    # A client may expose ``profile = None``; treat that like a missing
    # attribute instead of failing on ``None.get``.
    profile = getattr(query_variant_client, "profile", None)
    if profile is None:
        profile = DEFAULT_PROFILE

    # Report an unreadable setting as a query-generation failure rather than
    # leaking a bare TypeError/ValueError from int().
    raw_variants_per_seed = profile.get("variants_per_seed", 1)
    try:
        variants_per_seed = int(raw_variants_per_seed)
    except (TypeError, ValueError, OverflowError) as exc:
        raise _query_generation_error(
            "variants_per_seed_unsupported",
            {"variants_per_seed": raw_variants_per_seed},
        ) from exc
    if variants_per_seed != 1:
        raise _query_generation_error(
            "variants_per_seed_unsupported",
            {"variants_per_seed": variants_per_seed},
        )

    evidence_fields = profile.get("evidence_fields")
    generic_filter = profile.get("generic_filter")
    search_queries: list[dict[str, Any]] = []
    seen_queries: set[str] = set()

    for seed_index, seed_term in enumerate(seed_terms):
        # Ids are allocated in pairs: odd for the seed query, even for its
        # LLM variant.
        item_query_id = f"q_{seed_index * 2 + 1:03d}"
        variant_query_id = f"q_{seed_index * 2 + 2:03d}"
        pattern_seed_ref = _pattern_seed_ref(pattern_seed_pack, seed_term, seed_index)

        item_single = _base_query_record(
            run_id=run_id,
            policy_run_id=policy_run_id,
            search_query_id=item_query_id,
            search_query=seed_term,
            generation_method="item_single",
            seed_term=seed_term,
            pattern_seed_ref=pattern_seed_ref,
            created_at=created_at,
        )
        _reserve_query(item_single, seen_queries, seed_term=seed_term, method="item_single")
        search_queries.append(item_single)

        # The client is shown every query reserved so far (sorted for a
        # deterministic prompt input).
        evidence_context = _llm_input_evidence(
            pattern_seed_pack,
            seed_terms,
            seed_term,
            seed_index,
            sorted(seen_queries),
            evidence_fields=evidence_fields,
        )
        variant = _generate_variant(query_variant_client, seed_term, evidence_context)
        variant_query = _normalize_query(variant.query)
        _validate_variant_query(
            variant_query,
            seed_term,
            seen_queries,
            generic_filter=generic_filter,
        )

        llm_variant = _base_query_record(
            run_id=run_id,
            policy_run_id=policy_run_id,
            search_query_id=variant_query_id,
            search_query=variant_query,
            generation_method="llm_variant",
            seed_term=seed_term,
            pattern_seed_ref=pattern_seed_ref,
            created_at=created_at,
        )
        llm_variant.update(
            {
                "llm_variant_of": item_query_id,
                "llm_input_evidence": variant.input_evidence,
                "llm_prompt_version": variant.prompt_version,
                "llm_generation_model": variant.model,
            }
        )
        _reserve_query(llm_variant, seen_queries, seed_term=seed_term, method="llm_variant")
        search_queries.append(llm_variant)

    # Invariant check: exactly two rows per seed term must exist before
    # anything is written.
    expected_count = len(seed_terms) * 2
    if len(search_queries) != expected_count:
        raise _query_generation_error(
            "query_count_mismatch",
            {
                "seed_terms_count": len(seed_terms),
                "expected_query_count": expected_count,
                "actual_query_count": len(search_queries),
            },
        )

    search_queries = [with_raw_payload(row) for row in search_queries]
    runtime.append_jsonl(run_id, "search_queries.jsonl", search_queries)
    return search_queries
- def _terms(values: Any) -> list[str]:
- if not isinstance(values, list):
- return []
- unique: list[str] = []
- seen: set[str] = set()
- for value in values:
- if not isinstance(value, str):
- continue
- term = " ".join(value.split()).strip()
- if not term or term in seen:
- continue
- seen.add(term)
- unique.append(term)
- return unique
def _base_query_record(
    *,
    run_id: str,
    policy_run_id: str,
    search_query_id: str,
    search_query: str,
    generation_method: str,
    seed_term: str,
    pattern_seed_ref: dict[str, Any],
    created_at: str,
) -> dict[str, Any]:
    """Assemble the fields shared by every generated search-query record.

    The record starts with effect status ``pending`` and attributes the
    query to a single seed term taken from the ``seed_terms`` field.
    """
    # Keyword order below is the key order of the resulting record.
    record: dict[str, Any] = dict(
        record_schema_version=RUNTIME_RECORD_SCHEMA_VERSION,
        run_id=run_id,
        policy_run_id=policy_run_id,
        search_query_id=search_query_id,
        search_query=search_query,
        search_query_generation_method=generation_method,
        discovery_start_source="pattern_itemset",
        previous_discovery_step="pattern_search_query",
        search_query_effect_status="pending",
        query_source_terms=[seed_term],
        query_source_fields=["seed_terms"],
        pattern_seed_ref=pattern_seed_ref,
        created_at=created_at,
    )
    return record
def _generate_variant(
    query_variant_client: QueryVariantClient,
    seed_term: str,
    evidence_context: dict[str, Any],
) -> QueryVariantResult:
    """Ask the client for one variant of *seed_term* and type-check the reply.

    ``ContentAgentError`` raised by the client propagates unchanged. Any
    other exception is converted into a query-generation error that records
    the exception type, and a reply that is not a ``QueryVariantResult`` is
    rejected with the offending type name.
    """
    try:
        outcome = query_variant_client.generate_variant(
            seed_term=seed_term,
            evidence_context=evidence_context,
        )
    except ContentAgentError:
        # Already a domain error - let it through as-is.
        raise
    except Exception as exc:
        failure_detail = {
            "seed_term": seed_term,
            "exception_type": type(exc).__name__,
        }
        raise _query_generation_error("llm_variant_exception", failure_detail) from exc

    if isinstance(outcome, QueryVariantResult):
        return outcome

    invalid_detail = {
        "seed_term": seed_term,
        "result_type": type(outcome).__name__,
    }
    raise _query_generation_error("llm_variant_result_invalid", invalid_detail)
def _validate_variant_query(
    query: str,
    seed_term: str,
    seen_queries: set[str],
    *,
    generic_filter: dict[str, Any] | None = None,
) -> None:
    """Reject an LLM variant that is empty, redundant, or generic.

    Checks run in a fixed order and the first failing one decides the
    reported reason: empty, identical to the seed term, already reserved,
    then generic according to ``_is_generic_query``.

    Raises:
        ContentAgentError: Describing the first check that failed.
    """
    seed_only = {"seed_term": seed_term}
    seed_and_query = {"seed_term": seed_term, "search_query": query}

    if not query:
        failure = ("llm_variant_empty", seed_only)
    elif query == seed_term:
        failure = ("llm_variant_same_as_seed", seed_only)
    elif query in seen_queries:
        failure = ("llm_variant_duplicate", seed_and_query)
    elif _is_generic_query(query, generic_filter=generic_filter):
        failure = ("llm_variant_generic", seed_and_query)
    else:
        return

    reason, detail = failure
    raise _query_generation_error(reason, detail)
def _reserve_query(
    row: dict[str, Any],
    seen_queries: set[str],
    *,
    seed_term: str,
    method: str,
) -> None:
    """Normalise the row's query in place and claim it in *seen_queries*.

    On success ``row["search_query"]`` holds the normalised text and that
    text has been added to ``seen_queries``.

    Raises:
        ContentAgentError: When the normalised query is empty or has
            already been reserved.
    """
    normalised = _normalize_query(row.get("search_query", ""))

    if normalised and normalised not in seen_queries:
        row["search_query"] = normalised
        seen_queries.add(normalised)
        return

    if not normalised:
        raise _query_generation_error(
            "search_query_empty",
            {"seed_term": seed_term, "search_query_generation_method": method},
        )

    raise _query_generation_error(
        "search_query_duplicate",
        {
            "seed_term": seed_term,
            "search_query": normalised,
            "search_query_generation_method": method,
        },
    )
def _pattern_seed_ref(
    pattern_seed_pack: dict[str, Any],
    seed_term: str,
    seed_index: int,
) -> dict[str, Any]:
    """Describe where *seed_term* came from inside the seed pack.

    The reference records the term's position in ``seed_terms`` together
    with identifiers copied from the pack; a missing or falsy
    ``matched_post_ids`` becomes an empty list.
    """
    reference: dict[str, Any] = {
        "source_field": "seed_terms",
        "source_index": seed_index,
        "seed_term": seed_term,
    }
    # Identifiers copied verbatim (None when the pack lacks them).
    for key in ("pattern_execution_id", "mining_config_id", "source_post_id"):
        reference[key] = pattern_seed_pack.get(key)
    reference["matched_post_ids"] = pattern_seed_pack.get("matched_post_ids") or []
    reference["itemset_ids"] = _itemset_ids(pattern_seed_pack)
    return reference
def _llm_input_evidence(
    pattern_seed_pack: dict[str, Any],
    seed_terms: list[str],
    seed_term: str,
    seed_index: int,
    existing_search_queries: list[str],
    evidence_fields: list[str] | None = None,
) -> dict[str, Any]:
    """Collect the context handed to the variant client for one seed term.

    With ``evidence_fields`` left as ``None`` the complete evidence mapping
    is returned. Otherwise only the listed fields that exist in the mapping
    are kept, in the order given by ``evidence_fields``.
    """
    lookup = pattern_seed_pack.get
    # "itemset_items" wins over "itemsets"; both fall back to an empty list.
    itemset_items = lookup("itemset_items") or lookup("itemsets") or []

    evidence: dict[str, Any] = {
        "seed_term": seed_term,
        "seed_terms": seed_terms,
        "existing_search_queries": existing_search_queries,
        "source_field": "seed_terms",
        "source_index": seed_index,
        "itemset_items": itemset_items,
        "category_bindings": lookup("category_bindings") or [],
        "element_bindings": lookup("element_bindings") or [],
        "pattern_source_system": lookup("pattern_source_system"),
        "pattern_execution_id": lookup("pattern_execution_id"),
        "mining_config_id": lookup("mining_config_id"),
        "source_post_id": lookup("source_post_id"),
        "matched_post_ids": lookup("matched_post_ids") or [],
        "itemset_ids": _itemset_ids(pattern_seed_pack),
        "support": lookup("support"),
        "absolute_support": lookup("absolute_support"),
        "confidence": lookup("confidence"),
    }
    if evidence_fields is None:
        return evidence

    selected: dict[str, Any] = {}
    for name in evidence_fields:
        if name in evidence:
            selected[name] = evidence[name]
    return selected
- def _itemset_ids(pattern_seed_pack: dict[str, Any]) -> list[Any]:
- direct = pattern_seed_pack.get("itemset_ids")
- if isinstance(direct, list):
- return direct
- itemsets = pattern_seed_pack.get("itemsets")
- if not isinstance(itemsets, list):
- return []
- ids: list[Any] = []
- for itemset in itemsets:
- if not isinstance(itemset, dict):
- continue
- itemset_id = itemset.get("itemset_id")
- if itemset_id is not None:
- ids.append(itemset_id)
- return ids
- def _normalize_query(value: Any) -> str:
- if not isinstance(value, str):
- return ""
- return " ".join(value.split()).strip()
- def _is_generic_query(query: str, generic_filter: dict[str, Any] | None = None) -> bool:
- generic_queries = set((generic_filter or {}).get("queries") or GENERIC_QUERIES)
- generic_tokens = tuple((generic_filter or {}).get("tokens") or GENERIC_QUERY_TOKENS)
- compact = "".join(query.split())
- if not compact or len(compact) <= 1:
- return True
- if not any(char.isalnum() for char in compact):
- return True
- if compact in generic_queries:
- return True
- remainder = compact
- for token in generic_tokens:
- remainder = remainder.replace(token, "")
- return not remainder
def _query_generation_error(
    reason: str,
    detail: dict[str, Any] | None = None,
) -> ContentAgentError:
    """Build (without raising) the error used for every failure in this module.

    The payload always starts with ``reason``; entries from *detail*, when
    given, are merged in afterwards.
    """
    payload: dict[str, Any] = {"reason": reason}
    if detail:
        payload.update(detail)
    return ContentAgentError(
        ErrorCode.QUERY_GENERATION_FAILED,
        "query generation failed",
        payload,
    )
|