| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- from __future__ import annotations
- import copy
- import os
- from pathlib import Path
- from typing import Any, Mapping
- import httpx
- from content_agent.errors import ContentAgentError, ErrorCode
- from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE, load_profile
- from content_agent.interfaces import QueryVariantClient, QueryVariantResult
# API root for OpenRouter; endpoint paths such as "/chat/completions" are appended to it.
DEFAULT_OPENROUTER_BASE_URL: str = "https://openrouter.ai/api/v1"
# Fallback prompt-version label when neither the environment nor the profile supplies one.
DEFAULT_QUERY_PROMPT_VERSION: str = "query_variant.v1"
# Per-request HTTP timeout, in seconds, for query-variant generation calls.
DEFAULT_QUERY_TIMEOUT_SECONDS: float = 60.0
class MissingQueryVariantClient:
    """Stand-in client returned when the query-variant LLM is not configured.

    Construction never fails. Every call to ``generate_variant`` raises a
    ``ContentAgentError`` with ``ErrorCode.QUERY_GENERATION_FAILED``, so the
    configuration problem surfaces where a query is actually requested.
    """

    def __init__(self, reason: str, detail: dict[str, Any] | None = None) -> None:
        # Human-readable explanation of why no real client is available.
        self.reason = reason
        # Extra diagnostic fields merged into every raised error's detail.
        self.detail = detail if detail else {}

    def generate_variant(
        self,
        *,
        seed_term: str,
        evidence_context: dict[str, Any],
    ) -> QueryVariantResult:
        """Always raise; no query can be produced without a configured LLM.

        ``evidence_context`` is accepted for interface compatibility and is
        otherwise unused.

        Raises:
            ContentAgentError: always. The detail carries ``reason``,
                ``seed_term`` and every entry of ``self.detail``; entries in
                ``self.detail`` win on key collisions.
        """
        error_detail = {
            "reason": self.reason,
            "seed_term": seed_term,
            **self.detail,
        }
        raise ContentAgentError(
            ErrorCode.QUERY_GENERATION_FAILED,
            "query generation failed",
            error_detail,
        )
class OpenRouterQueryVariantClient:
    """Query-variant client backed by OpenRouter's ``/chat/completions`` endpoint.

    Prompt templates (``system``/``user``), ``temperature`` and ``max_tokens``
    are read from a profile dict. ``DEFAULT_PROFILE`` is used unless a profile
    is supplied.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str = DEFAULT_OPENROUTER_BASE_URL,
        timeout_seconds: float = DEFAULT_QUERY_TIMEOUT_SECONDS,
        prompt_version: str = DEFAULT_QUERY_PROMPT_VERSION,
        profile: dict[str, Any] | None = None,
    ) -> None:
        """Store connection settings and a private copy of the prompt profile.

        Args:
            api_key: OpenRouter API key, sent as a bearer token.
            model: Model identifier forwarded in the request body.
            base_url: API root; any trailing slash is removed.
            timeout_seconds: Timeout passed to ``httpx.post`` on every request.
            prompt_version: Fallback version label, used only when the profile
                has no truthy ``prompt_version`` entry.
            profile: Prompt profile; ``DEFAULT_PROFILE`` is used when this is
                ``None`` or empty.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout_seconds = timeout_seconds
        # Deep copy so later mutation of the caller's profile (or of the shared
        # default) cannot change this client's prompts.
        self.profile = copy.deepcopy(profile or DEFAULT_PROFILE)
        # The profile's own version label takes precedence over the argument.
        self.prompt_version = str(self.profile.get("prompt_version") or prompt_version)

    def generate_variant(
        self,
        *,
        seed_term: str,
        evidence_context: dict[str, Any],
    ) -> QueryVariantResult:
        """Ask the configured model for one query variant of ``seed_term``.

        Args:
            seed_term: Term substituted into the profile's user template.
            evidence_context: Context rendered into the user template and
                echoed back as ``input_evidence`` on the result.

        Returns:
            A ``QueryVariantResult`` with the normalized single-line query, the
            model name, the prompt version and the evidence context.

        Raises:
            ContentAgentError: ``QUERY_GENERATION_FAILED`` with a ``reason`` of
                ``openrouter_http_status``, ``openrouter_http_error`` or
                ``openrouter_response_invalid``. Errors already raised as
                ``ContentAgentError`` propagate unchanged.
        """
        try:
            response = httpx.post(
                f"{self.base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": _render_messages(self.profile, seed_term, evidence_context),
                    "temperature": self.profile["temperature"],
                    "max_tokens": self.profile["max_tokens"],
                },
                timeout=self.timeout_seconds,
            )
            response.raise_for_status()
            query = _extract_query(response.json())
        except ContentAgentError:
            # Already a domain error; do not re-wrap it.
            raise
        except httpx.HTTPStatusError as exc:
            # Non-2xx status. Must precede HTTPError, of which it is a subclass.
            raise _generation_error(
                "openrouter_http_status",
                seed_term,
                {"status_code": exc.response.status_code},
            ) from exc
        except httpx.HTTPError as exc:
            # Remaining httpx failures, e.g. connection errors and timeouts.
            raise _generation_error(
                "openrouter_http_error",
                seed_term,
                {"exception_type": type(exc).__name__},
            ) from exc
        except (KeyError, IndexError, TypeError, ValueError) as exc:
            # Unusable response body. IndexError covers an empty ``choices``
            # list. ValueError covers invalid JSON, since json.JSONDecodeError
            # subclasses ValueError.
            # NOTE: a profile missing "temperature", "max_tokens", "system" or
            # "user" also lands here as a KeyError, because the request body is
            # built inside this try.
            raise _generation_error(
                "openrouter_response_invalid",
                seed_term,
                {"exception_type": type(exc).__name__},
            ) from exc
        return QueryVariantResult(
            query=query,
            model=self.model,
            prompt_version=self.prompt_version,
            input_evidence=evidence_context,
        )
def query_variant_client_from_env(
    env: Mapping[str, str] | None = None,
    *,
    platform: str = "douyin",
    strategy_version: str = "V1",
    root_dir: Path | str = Path("."),
) -> QueryVariantClient:
    """Build a query-variant client from environment variables.

    Variables read:
        * API key: ``OPENROUTER_API_KEY``, else ``OPEN_ROUTER_API_KEY``.
        * Model: ``CONTENT_AGENT_QUERY_LLM_MODEL``, else ``MODEL``.
        * Optional: ``OPENROUTER_BASE_URL``,
          ``CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION`` and
          ``CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS``.

    A timeout that is missing, unparseable, non-positive or non-finite falls
    back to ``DEFAULT_QUERY_TIMEOUT_SECONDS``.

    Args:
        env: Mapping to read instead of ``os.environ``.
        platform: Forwarded to ``load_profile``.
        strategy_version: Forwarded to ``load_profile``.
        root_dir: Forwarded to ``load_profile``.

    Returns:
        An ``OpenRouterQueryVariantClient`` when both the API key and model are
        set, with a profile loaded via ``load_profile``. Otherwise a
        ``MissingQueryVariantClient`` whose detail lists the missing keys; no
        profile is loaded in that case.
    """
    source = os.environ if env is None else env
    api_key = _env_value(source, "OPENROUTER_API_KEY") or _env_value(
        source,
        "OPEN_ROUTER_API_KEY",
    )
    model = _env_value(source, "CONTENT_AGENT_QUERY_LLM_MODEL") or _env_value(source, "MODEL")
    base_url = _env_value(source, "OPENROUTER_BASE_URL") or DEFAULT_OPENROUTER_BASE_URL
    prompt_version = (
        _env_value(source, "CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION")
        or DEFAULT_QUERY_PROMPT_VERSION
    )
    timeout_seconds = _float_env(
        source,
        "CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS",
        DEFAULT_QUERY_TIMEOUT_SECONDS,
    )
    # Values like "0", "-5", "nan" or "inf" parse as floats but are not usable
    # request timeouts; treat them like an unparseable value. The chained
    # comparison is False for NaN, so NaN is rejected too.
    if not (0 < timeout_seconds < float("inf")):
        timeout_seconds = DEFAULT_QUERY_TIMEOUT_SECONDS

    missing = []
    if not api_key:
        missing.append("OPENROUTER_API_KEY")
    if not model:
        missing.append("CONTENT_AGENT_QUERY_LLM_MODEL")
    if missing:
        return MissingQueryVariantClient(
            "query variant LLM config missing",
            {"missing_env_keys": missing},
        )
    return OpenRouterQueryVariantClient(
        api_key=api_key,
        model=model,
        base_url=base_url,
        timeout_seconds=timeout_seconds,
        prompt_version=prompt_version,
        profile=load_profile(platform, strategy_version, root_dir=root_dir),
    )
def _messages(seed_term: str, evidence_context: dict[str, Any]) -> list[dict[str, str]]:
    """Render the chat messages for ``seed_term`` using ``DEFAULT_PROFILE``."""
    return _render_messages(
        profile=DEFAULT_PROFILE,
        seed_term=seed_term,
        evidence_context=evidence_context,
    )
- def _render_messages(
- profile: dict[str, Any],
- seed_term: str,
- evidence_context: dict[str, Any],
- ) -> list[dict[str, str]]:
- return [
- {
- "role": "system",
- "content": str(profile["system"]),
- },
- {
- "role": "user",
- "content": str(profile["user"])
- .replace("{seed_term}", seed_term)
- .replace("{evidence_context}", str(evidence_context)),
- },
- ]
- def _extract_query(payload: dict[str, Any]) -> str:
- content = payload["choices"][0]["message"]["content"]
- if not isinstance(content, str):
- raise ValueError("OpenRouter content is not a string")
- query = _normalize_query(content)
- if not query:
- raise ValueError("OpenRouter content is empty")
- if "\n" in query:
- raise ValueError("OpenRouter content has multiple lines")
- return query
- def _normalize_query(value: str) -> str:
- query = " ".join(value.split()).strip()
- return query.strip("`'\"“”‘’")
def _generation_error(
    reason: str,
    seed_term: str,
    detail: dict[str, Any] | None = None,
) -> ContentAgentError:
    """Build (without raising) a ``QUERY_GENERATION_FAILED`` domain error.

    The error detail always starts with ``reason`` and ``seed_term``. Entries
    from ``detail`` are merged in afterwards and win on key collisions.
    """
    extra = detail if detail else {}
    error_detail = {"reason": reason, "seed_term": seed_term, **extra}
    return ContentAgentError(
        ErrorCode.QUERY_GENERATION_FAILED,
        "query generation failed",
        error_detail,
    )
- def _env_value(env: Mapping[str, str], key: str) -> str:
- return str(env.get(key, "")).strip()
def _float_env(env: Mapping[str, str], key: str, default: float) -> float:
    """Parse ``env[key]`` as a float, falling back to ``default``.

    ``default`` is returned when the key is unset, blank, or not accepted by
    ``float()``.
    """
    text = _env_value(env, key)
    if text:
        try:
            parsed = float(text)
        except ValueError:
            # Deliberately lenient: an unparseable value falls through to the default.
            pass
        else:
            return parsed
    return default
|