| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- from __future__ import annotations
- from typing import Any
- from content_agent.errors import ContentAgentError, ErrorCode
- from content_agent.interfaces import PlatformSearchClient
def run(
    search_queries: list[dict[str, Any]], platform_client: PlatformSearchClient
) -> dict[str, list[dict[str, Any]]]:
    """Run every search query against the platform and merge the results.

    Each query is executed independently through ``platform_client.search``.
    A query whose search (or the iteration of its results) raises is recorded
    in ``query_failures`` and does not stop the remaining queries; any partial
    results it produced before failing are discarded.

    Results carrying a truthy ``platform_content_id`` are de-duplicated across
    queries: the first occurrence is kept and later matches only add their
    query to that entry's ``query_sources``. Results without an id are kept
    once per occurrence.

    Args:
        search_queries: Query dicts. Annotating a result reads their
            ``search_query_id``, ``search_query`` and
            ``search_query_generation_method`` keys.
        platform_client: Client whose ``search(query)`` returns an iterable of
            result dicts.

    Returns:
        ``{"platform_results": [...], "query_failures": [...]}`` with results
        in first-seen order.

    Raises:
        ContentAgentError: ``PLATFORM_REQUEST_FAILED`` when at least one query
            was given and every query failed.
    """
    results: list[dict[str, Any]] = []
    by_platform_content_id: dict[str, dict[str, Any]] = {}
    query_failures: list[dict[str, Any]] = []
    for search_query in search_queries:
        try:
            # Materialise inside the try so an error raised while iterating a
            # lazy result (e.g. a generator) is recorded as this query's
            # failure instead of escaping run() and aborting the whole batch.
            query_results = list(platform_client.search(search_query))
        except Exception as exc:  # deliberately broad: isolate each query
            query_failures.append(_query_failure(search_query, exc))
            continue
        for result in query_results:
            platform_content_id = result.get("platform_content_id")
            if not platform_content_id:
                # Nothing to merge on: keep every id-less occurrence.
                results.append(_with_query_source(result, search_query))
                continue
            existing = by_platform_content_id.get(platform_content_id)
            if existing is not None:
                # Already returned by an earlier query: record this query as
                # another source without building a throwaway annotated copy.
                _append_query_source(existing, search_query)
                continue
            enriched = _with_query_source(result, search_query)
            by_platform_content_id[platform_content_id] = enriched
            results.append(enriched)
    # A failed query contributes no results, so "every query failed" already
    # means nothing was found; no separate emptiness check is needed.
    if search_queries and len(query_failures) == len(search_queries):
        raise ContentAgentError(
            ErrorCode.PLATFORM_REQUEST_FAILED,
            "all platform queries failed",
            {"query_failures": query_failures},
        )
    return {"platform_results": results, "query_failures": query_failures}
def _with_query_source(
    result: dict[str, Any], search_query: dict[str, Any]
) -> dict[str, Any]:
    """Return a copy of *result* tagged with the query that produced it.

    The copy gains top-level ``search_query`` and
    ``search_query_generation_method`` fields taken from *search_query*, and
    the query is registered in its ``query_sources`` via
    ``_append_query_source``.

    The copy is shallow: nested containers, including any ``query_sources``
    list already present on *result*, are shared with *result*.
    """
    tagged = dict(result)
    tagged["search_query"] = search_query["search_query"]
    tagged["search_query_generation_method"] = search_query[
        "search_query_generation_method"
    ]
    _append_query_source(tagged, search_query)
    return tagged
- def _append_query_source(result: dict[str, Any], search_query: dict[str, Any]) -> None:
- source = {
- "search_query_id": search_query["search_query_id"],
- "search_query": search_query["search_query"],
- "search_query_generation_method": search_query["search_query_generation_method"],
- }
- if search_query.get("llm_variant_of"):
- source["llm_variant_of"] = search_query["llm_variant_of"]
- query_sources = result.setdefault("query_sources", [])
- if any(
- item.get("search_query_id") == source["search_query_id"]
- for item in query_sources
- ):
- return
- query_sources.append(source)
- result["matched_search_query_ids"] = [
- item["search_query_id"] for item in query_sources
- ]
- result["matched_search_queries"] = [item["search_query"] for item in query_sources]
- result["matched_search_query_generation_methods"] = [
- item["search_query_generation_method"] for item in query_sources
- ]
def _query_failure(search_query: dict[str, Any], exc: Exception) -> dict[str, Any]:
    """Build the ``query_failures`` record for a query that raised *exc*.

    A ``ContentAgentError`` keeps its own error code (as the enum's
    ``.value``), message and detail. Any other exception is reported as
    ``PLATFORM_REQUEST_FAILED`` with only its type name as detail; the
    exception text itself is not included.

    ``search_query_id`` and ``search_query`` are required keys of
    *search_query*; ``search_query_generation_method`` may be absent (None).
    """
    if not isinstance(exc, ContentAgentError):
        code, text, extra = (
            ErrorCode.PLATFORM_REQUEST_FAILED.value,
            "platform query failed",
            {"exception_type": type(exc).__name__},
        )
    else:
        code, text, extra = exc.error_code.value, exc.message, exc.detail
    record: dict[str, Any] = {
        "search_query_id": search_query["search_query_id"],
        "search_query": search_query["search_query"],
        "search_query_generation_method": search_query.get(
            "search_query_generation_method"
        ),
    }
    record.update(status="failed", error_code=code, message=text, error_detail=extra)
    return record
|