query_variant.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. from __future__ import annotations
  2. import copy
  3. import os
  4. from pathlib import Path
  5. from typing import Any, Mapping
  6. import httpx
  7. from content_agent.errors import ContentAgentError, ErrorCode
  8. from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE, load_profile
  9. from content_agent.interfaces import QueryVariantClient, QueryVariantResult
  10. DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
  11. DEFAULT_QUERY_PROMPT_VERSION = "query_variant.v1"
  12. DEFAULT_QUERY_TIMEOUT_SECONDS = 60.0
  13. class MissingQueryVariantClient:
  14. def __init__(self, reason: str, detail: dict[str, Any] | None = None) -> None:
  15. self.reason = reason
  16. self.detail = detail or {}
  17. def generate_variant(
  18. self,
  19. *,
  20. seed_term: str,
  21. evidence_context: dict[str, Any],
  22. ) -> QueryVariantResult:
  23. raise ContentAgentError(
  24. ErrorCode.QUERY_GENERATION_FAILED,
  25. "query generation failed",
  26. {
  27. "reason": self.reason,
  28. "seed_term": seed_term,
  29. **self.detail,
  30. },
  31. )
  32. class OpenRouterQueryVariantClient:
  33. def __init__(
  34. self,
  35. *,
  36. api_key: str,
  37. model: str,
  38. base_url: str = DEFAULT_OPENROUTER_BASE_URL,
  39. timeout_seconds: float = DEFAULT_QUERY_TIMEOUT_SECONDS,
  40. prompt_version: str = DEFAULT_QUERY_PROMPT_VERSION,
  41. profile: dict[str, Any] | None = None,
  42. ) -> None:
  43. self.api_key = api_key
  44. self.model = model
  45. self.base_url = base_url.rstrip("/")
  46. self.timeout_seconds = timeout_seconds
  47. self.profile = copy.deepcopy(profile or DEFAULT_PROFILE)
  48. self.prompt_version = str(self.profile.get("prompt_version") or prompt_version)
  49. def generate_variant(
  50. self,
  51. *,
  52. seed_term: str,
  53. evidence_context: dict[str, Any],
  54. ) -> QueryVariantResult:
  55. try:
  56. response = httpx.post(
  57. f"{self.base_url}/chat/completions",
  58. headers={
  59. "Authorization": f"Bearer {self.api_key}",
  60. "Content-Type": "application/json",
  61. },
  62. json={
  63. "model": self.model,
  64. "messages": _render_messages(self.profile, seed_term, evidence_context),
  65. "temperature": self.profile["temperature"],
  66. "max_tokens": self.profile["max_tokens"],
  67. },
  68. timeout=self.timeout_seconds,
  69. )
  70. response.raise_for_status()
  71. query = _extract_query(response.json())
  72. except ContentAgentError:
  73. raise
  74. except httpx.HTTPStatusError as exc:
  75. raise _generation_error(
  76. "openrouter_http_status",
  77. seed_term,
  78. {"status_code": exc.response.status_code},
  79. ) from exc
  80. except httpx.HTTPError as exc:
  81. raise _generation_error(
  82. "openrouter_http_error",
  83. seed_term,
  84. {"exception_type": type(exc).__name__},
  85. ) from exc
  86. except (KeyError, TypeError, ValueError) as exc:
  87. raise _generation_error(
  88. "openrouter_response_invalid",
  89. seed_term,
  90. {"exception_type": type(exc).__name__},
  91. ) from exc
  92. return QueryVariantResult(
  93. query=query,
  94. model=self.model,
  95. prompt_version=self.prompt_version,
  96. input_evidence=evidence_context,
  97. )
  98. def query_variant_client_from_env(
  99. env: Mapping[str, str] | None = None,
  100. *,
  101. platform: str = "douyin",
  102. strategy_version: str = "V1",
  103. root_dir: Path | str = Path("."),
  104. ) -> QueryVariantClient:
  105. source = os.environ if env is None else env
  106. api_key = _env_value(source, "OPENROUTER_API_KEY") or _env_value(
  107. source,
  108. "OPEN_ROUTER_API_KEY",
  109. )
  110. model = _env_value(source, "CONTENT_AGENT_QUERY_LLM_MODEL") or _env_value(source, "MODEL")
  111. base_url = _env_value(source, "OPENROUTER_BASE_URL") or DEFAULT_OPENROUTER_BASE_URL
  112. prompt_version = _env_value(source, "CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION") or DEFAULT_QUERY_PROMPT_VERSION
  113. timeout_seconds = _float_env(
  114. source,
  115. "CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS",
  116. DEFAULT_QUERY_TIMEOUT_SECONDS,
  117. )
  118. missing = []
  119. if not api_key:
  120. missing.append("OPENROUTER_API_KEY")
  121. if not model:
  122. missing.append("CONTENT_AGENT_QUERY_LLM_MODEL")
  123. if missing:
  124. return MissingQueryVariantClient(
  125. "query variant LLM config missing",
  126. {"missing_env_keys": missing},
  127. )
  128. return OpenRouterQueryVariantClient(
  129. api_key=api_key,
  130. model=model,
  131. base_url=base_url,
  132. timeout_seconds=timeout_seconds,
  133. prompt_version=prompt_version,
  134. profile=load_profile(platform, strategy_version, root_dir=root_dir),
  135. )
  136. def _messages(seed_term: str, evidence_context: dict[str, Any]) -> list[dict[str, str]]:
  137. return _render_messages(DEFAULT_PROFILE, seed_term, evidence_context)
  138. def _render_messages(
  139. profile: dict[str, Any],
  140. seed_term: str,
  141. evidence_context: dict[str, Any],
  142. ) -> list[dict[str, str]]:
  143. return [
  144. {
  145. "role": "system",
  146. "content": str(profile["system"]),
  147. },
  148. {
  149. "role": "user",
  150. "content": str(profile["user"])
  151. .replace("{seed_term}", seed_term)
  152. .replace("{evidence_context}", str(evidence_context)),
  153. },
  154. ]
  155. def _extract_query(payload: dict[str, Any]) -> str:
  156. content = payload["choices"][0]["message"]["content"]
  157. if not isinstance(content, str):
  158. raise ValueError("OpenRouter content is not a string")
  159. query = _normalize_query(content)
  160. if not query:
  161. raise ValueError("OpenRouter content is empty")
  162. if "\n" in query:
  163. raise ValueError("OpenRouter content has multiple lines")
  164. return query
  165. def _normalize_query(value: str) -> str:
  166. query = " ".join(value.split()).strip()
  167. return query.strip("`'\"“”‘’")
  168. def _generation_error(
  169. reason: str,
  170. seed_term: str,
  171. detail: dict[str, Any] | None = None,
  172. ) -> ContentAgentError:
  173. return ContentAgentError(
  174. ErrorCode.QUERY_GENERATION_FAILED,
  175. "query generation failed",
  176. {
  177. "reason": reason,
  178. "seed_term": seed_term,
  179. **(detail or {}),
  180. },
  181. )
  182. def _env_value(env: Mapping[str, str], key: str) -> str:
  183. return str(env.get(key, "")).strip()
  184. def _float_env(env: Mapping[str, str], key: str, default: float) -> float:
  185. value = _env_value(env, key)
  186. if not value:
  187. return default
  188. try:
  189. return float(value)
  190. except ValueError:
  191. return default