platform_access.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from __future__ import annotations
  2. from typing import Any
  3. from content_agent.errors import ContentAgentError, ErrorCode
  4. from content_agent.interfaces import PlatformSearchClient
  5. def run(
  6. search_queries: list[dict[str, Any]], platform_client: PlatformSearchClient
  7. ) -> dict[str, list[dict[str, Any]]]:
  8. results: list[dict[str, Any]] = []
  9. by_platform_content_id: dict[str, dict[str, Any]] = {}
  10. query_failures: list[dict[str, Any]] = []
  11. for search_query in search_queries:
  12. try:
  13. query_results = platform_client.search(search_query)
  14. except Exception as exc:
  15. query_failures.append(_query_failure(search_query, exc))
  16. continue
  17. for result in query_results:
  18. platform_content_id = result.get("platform_content_id")
  19. enriched = _with_query_source(result, search_query)
  20. if not platform_content_id:
  21. results.append(enriched)
  22. continue
  23. existing = by_platform_content_id.get(platform_content_id)
  24. if existing:
  25. _append_query_source(existing, search_query)
  26. continue
  27. by_platform_content_id[platform_content_id] = enriched
  28. results.append(enriched)
  29. if search_queries and query_failures and len(query_failures) == len(search_queries) and not results:
  30. raise ContentAgentError(
  31. ErrorCode.PLATFORM_REQUEST_FAILED,
  32. "all platform queries failed",
  33. {"query_failures": query_failures},
  34. )
  35. return {"platform_results": results, "query_failures": query_failures}
  36. def _with_query_source(
  37. result: dict[str, Any], search_query: dict[str, Any]
  38. ) -> dict[str, Any]:
  39. enriched = {
  40. **result,
  41. "search_query": search_query["search_query"],
  42. "search_query_generation_method": search_query["search_query_generation_method"],
  43. }
  44. _append_query_source(enriched, search_query)
  45. return enriched
  46. def _append_query_source(result: dict[str, Any], search_query: dict[str, Any]) -> None:
  47. source = {
  48. "search_query_id": search_query["search_query_id"],
  49. "search_query": search_query["search_query"],
  50. "search_query_generation_method": search_query["search_query_generation_method"],
  51. }
  52. if search_query.get("llm_variant_of"):
  53. source["llm_variant_of"] = search_query["llm_variant_of"]
  54. query_sources = result.setdefault("query_sources", [])
  55. if any(
  56. item.get("search_query_id") == source["search_query_id"]
  57. for item in query_sources
  58. ):
  59. return
  60. query_sources.append(source)
  61. result["matched_search_query_ids"] = [
  62. item["search_query_id"] for item in query_sources
  63. ]
  64. result["matched_search_queries"] = [item["search_query"] for item in query_sources]
  65. result["matched_search_query_generation_methods"] = [
  66. item["search_query_generation_method"] for item in query_sources
  67. ]
  68. def _query_failure(search_query: dict[str, Any], exc: Exception) -> dict[str, Any]:
  69. if isinstance(exc, ContentAgentError):
  70. error_code = exc.error_code.value
  71. message = exc.message
  72. detail = exc.detail
  73. else:
  74. error_code = ErrorCode.PLATFORM_REQUEST_FAILED.value
  75. message = "platform query failed"
  76. detail = {"exception_type": type(exc).__name__}
  77. return {
  78. "search_query_id": search_query["search_query_id"],
  79. "search_query": search_query["search_query"],
  80. "search_query_generation_method": search_query.get(
  81. "search_query_generation_method"
  82. ),
  83. "status": "failed",
  84. "error_code": error_code,
  85. "message": message,
  86. "error_detail": detail,
  87. }