"""OpenRouter-backed query-variant generation for the content agent.

Provides:

* ``OpenRouterQueryVariantClient`` - asks an OpenRouter chat-completions model
  for a single-line query derived from a seed term and evidence context.
* ``MissingQueryVariantClient`` - a stand-in whose every call fails with a
  structured reason (used when configuration is incomplete).
* ``query_variant_client_from_env`` - factory choosing between the two based
  on environment variables.
"""

from __future__ import annotations

import copy
import os
import re
from pathlib import Path
from typing import Any, Mapping

import httpx

from content_agent.errors import ContentAgentError, ErrorCode
from content_agent.integrations.query_prompt_config import DEFAULT_PROFILE, load_profile
from content_agent.interfaces import QueryVariantClient, QueryVariantResult

DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
DEFAULT_QUERY_PROMPT_VERSION = "query_variant.v1"
DEFAULT_QUERY_TIMEOUT_SECONDS = 60.0

# Placeholders recognised in the profile's "user" template. They are
# substituted in a single pass so that a value which itself contains the other
# placeholder (e.g. a seed term containing "{evidence_context}") is inserted
# literally rather than expanded a second time.
_PLACEHOLDER_PATTERN = re.compile(r"\{(seed_term|evidence_context)\}")

# Quote characters stripped from both ends of the model's answer.
_QUOTE_CHARS = "`'\"“”‘’"


class MissingQueryVariantClient:
    """Stand-in client used when a real query-variant client cannot be built.

    Construction always succeeds; every call to :meth:`generate_variant`
    raises, so the misconfiguration surfaces where a query is actually needed,
    carrying ``reason`` and ``detail`` in the error payload.
    """

    def __init__(self, reason: str, detail: dict[str, Any] | None = None) -> None:
        self.reason = reason
        # Extra fields merged into the raised error's detail payload.
        self.detail = detail or {}

    def generate_variant(
        self,
        *,
        seed_term: str,
        evidence_context: dict[str, Any],
    ) -> QueryVariantResult:
        """Always fail; this method never returns.

        Raises:
            ContentAgentError: ``ErrorCode.QUERY_GENERATION_FAILED`` whose
                detail holds ``reason``, ``seed_term`` and every key of
                ``self.detail``.
        """
        raise _generation_error(self.reason, seed_term, self.detail)


class OpenRouterQueryVariantClient:
    """Generate query variants through OpenRouter's chat-completions endpoint."""

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str = DEFAULT_OPENROUTER_BASE_URL,
        timeout_seconds: float = DEFAULT_QUERY_TIMEOUT_SECONDS,
        prompt_version: str = DEFAULT_QUERY_PROMPT_VERSION,
        profile: dict[str, Any] | None = None,
    ) -> None:
        """Configure the client.

        Args:
            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
            model: Model identifier placed in the request body and on results.
            base_url: API root; a trailing ``/`` is removed.
            timeout_seconds: Passed to ``httpx.post`` as ``timeout``.
            prompt_version: Version recorded on results, used ONLY when the
                profile does not define its own ``"prompt_version"``.
            profile: Prompt profile read for the keys ``"system"``, ``"user"``,
                ``"temperature"`` and ``"max_tokens"``. ``None`` or an empty
                mapping falls back to ``DEFAULT_PROFILE``.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout_seconds = timeout_seconds
        # Deep copy so later mutation of the caller's profile (or of the shared
        # DEFAULT_PROFILE) cannot alter this client's prompts.
        self.profile = copy.deepcopy(profile or DEFAULT_PROFILE)
        # The profile's own version takes precedence over the argument.
        self.prompt_version = str(self.profile.get("prompt_version") or prompt_version)

    def generate_variant(
        self,
        *,
        seed_term: str,
        evidence_context: dict[str, Any],
    ) -> QueryVariantResult:
        """Ask the model for one query derived from ``seed_term`` and evidence.

        Returns:
            A ``QueryVariantResult`` with the normalised query, this client's
            model and prompt version, and ``evidence_context`` as passed in
            (the same object, not a copy).

        Raises:
            ContentAgentError: ``ErrorCode.QUERY_GENERATION_FAILED`` with
                detail ``reason`` set to one of:

                * ``"openrouter_http_status"`` - non-2xx response
                  (detail includes ``status_code``);
                * ``"openrouter_http_error"`` - any other httpx failure
                  (detail includes ``exception_type``);
                * ``"openrouter_response_invalid"`` - malformed or unusable
                  response body (detail includes ``exception_type``).
        """
        try:
            response = httpx.post(
                f"{self.base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self.model,
                    "messages": _render_messages(self.profile, seed_term, evidence_context),
                    "temperature": self.profile["temperature"],
                    "max_tokens": self.profile["max_tokens"],
                },
                timeout=self.timeout_seconds,
            )
            response.raise_for_status()
            query = _extract_query(response.json())
        except ContentAgentError:
            raise
        except httpx.HTTPStatusError as exc:
            # Must precede the HTTPError clause: HTTPStatusError subclasses it.
            raise _generation_error(
                "openrouter_http_status",
                seed_term,
                {"status_code": exc.response.status_code},
            ) from exc
        except httpx.HTTPError as exc:
            raise _generation_error(
                "openrouter_http_error",
                seed_term,
                {"exception_type": type(exc).__name__},
            ) from exc
        except (KeyError, IndexError, TypeError, ValueError) as exc:
            # IndexError: empty "choices" list. ValueError also covers a body
            # that is not valid JSON.
            # NOTE(review): a profile missing "system"/"user"/"temperature"/
            # "max_tokens" raises KeyError inside this try too, and is
            # therefore reported as "openrouter_response_invalid".
            raise _generation_error(
                "openrouter_response_invalid",
                seed_term,
                {"exception_type": type(exc).__name__},
            ) from exc
        return QueryVariantResult(
            query=query,
            model=self.model,
            prompt_version=self.prompt_version,
            input_evidence=evidence_context,
        )


def query_variant_client_from_env(
    env: Mapping[str, str] | None = None,
    *,
    platform: str = "douyin",
    strategy_version: str = "V1",
    root_dir: Path | str = Path("."),
) -> QueryVariantClient:
    """Build a query-variant client from environment configuration.

    Variables read from ``env`` (``os.environ`` when ``env`` is ``None``):

    * ``OPENROUTER_API_KEY``, falling back to ``OPEN_ROUTER_API_KEY``;
    * ``CONTENT_AGENT_QUERY_LLM_MODEL``, falling back to ``MODEL``;
    * ``OPENROUTER_BASE_URL`` (default ``DEFAULT_OPENROUTER_BASE_URL``);
    * ``CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION``
      (default ``DEFAULT_QUERY_PROMPT_VERSION``);
    * ``CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS``
      (default ``DEFAULT_QUERY_TIMEOUT_SECONDS``).

    Returns:
        ``MissingQueryVariantClient`` listing the missing keys when the API key
        or model is absent. Otherwise an ``OpenRouterQueryVariantClient`` whose
        profile comes from ``load_profile(platform, strategy_version,
        root_dir=root_dir)``.

    NOTE(review): the prompt version from the environment only takes effect
    when the loaded profile has no ``"prompt_version"`` of its own; the
    profile's value wins. Confirm that is intended.
    """
    source = os.environ if env is None else env
    api_key = _env_value(source, "OPENROUTER_API_KEY") or _env_value(
        source,
        "OPEN_ROUTER_API_KEY",
    )
    model = _env_value(source, "CONTENT_AGENT_QUERY_LLM_MODEL") or _env_value(source, "MODEL")
    base_url = _env_value(source, "OPENROUTER_BASE_URL") or DEFAULT_OPENROUTER_BASE_URL
    prompt_version = (
        _env_value(source, "CONTENT_AGENT_QUERY_LLM_PROMPT_VERSION")
        or DEFAULT_QUERY_PROMPT_VERSION
    )
    timeout_seconds = _float_env(
        source,
        "CONTENT_AGENT_QUERY_LLM_TIMEOUT_SECONDS",
        DEFAULT_QUERY_TIMEOUT_SECONDS,
    )
    # Zero, negative, NaN and infinite values are not meaningful timeouts;
    # treat them like an unparsable value. (NaN fails every comparison.)
    if not 0 < timeout_seconds < float("inf"):
        timeout_seconds = DEFAULT_QUERY_TIMEOUT_SECONDS

    missing = []
    if not api_key:
        missing.append("OPENROUTER_API_KEY")
    if not model:
        missing.append("CONTENT_AGENT_QUERY_LLM_MODEL")
    if missing:
        return MissingQueryVariantClient(
            "query variant LLM config missing",
            {"missing_env_keys": missing},
        )

    return OpenRouterQueryVariantClient(
        api_key=api_key,
        model=model,
        base_url=base_url,
        timeout_seconds=timeout_seconds,
        prompt_version=prompt_version,
        profile=load_profile(platform, strategy_version, root_dir=root_dir),
    )


def _messages(seed_term: str, evidence_context: dict[str, Any]) -> list[dict[str, str]]:
    """Render the chat messages using ``DEFAULT_PROFILE``."""
    return _render_messages(DEFAULT_PROFILE, seed_term, evidence_context)


def _render_messages(
    profile: dict[str, Any],
    seed_term: str,
    evidence_context: dict[str, Any],
) -> list[dict[str, str]]:
    """Build the system + user chat messages from ``profile``.

    ``{seed_term}`` and ``{evidence_context}`` in the profile's ``"user"``
    template are replaced in a single pass. ``evidence_context`` is inserted
    as ``str(evidence_context)``.

    Raises:
        KeyError: if ``profile`` lacks ``"system"`` or ``"user"``.
    """
    values: dict[str, str] = {
        "seed_term": seed_term,
        "evidence_context": str(evidence_context),
    }
    # A function replacement is inserted verbatim (no backslash/group escapes).
    user_content = _PLACEHOLDER_PATTERN.sub(
        lambda match: values[match.group(1)],
        str(profile["user"]),
    )
    return [
        {
            "role": "system",
            "content": str(profile["system"]),
        },
        {
            "role": "user",
            "content": user_content,
        },
    ]


def _extract_query(payload: dict[str, Any]) -> str:
    """Return the single-line query in a chat-completions response body.

    Raises:
        KeyError, IndexError, TypeError: if ``choices[0].message.content``
            cannot be reached in ``payload``.
        ValueError: if the content is not a string, spans more than one
            non-blank-trimmed line, or is empty after normalisation.
    """
    content = payload["choices"][0]["message"]["content"]
    if not isinstance(content, str):
        raise ValueError("OpenRouter content is not a string")
    # Check the line count BEFORE normalising: _normalize_query collapses all
    # whitespace (newlines included), so a check afterwards could never fire.
    # Leading/trailing blank lines are tolerated.
    if len(content.strip().splitlines()) > 1:
        raise ValueError("OpenRouter content has multiple lines")
    query = _normalize_query(content)
    if not query:
        raise ValueError("OpenRouter content is empty")
    return query


def _normalize_query(value: str) -> str:
    """Collapse whitespace runs to one space and trim spaces/quotes at both ends."""
    collapsed = " ".join(value.split())
    # Strip spaces together with the quote characters so that e.g. '" query "'
    # yields 'query' rather than ' query ' (or a whitespace-only string).
    return collapsed.strip(" " + _QUOTE_CHARS)


def _generation_error(
    reason: str,
    seed_term: str,
    detail: dict[str, Any] | None = None,
) -> ContentAgentError:
    """Build a ``QUERY_GENERATION_FAILED`` error; keys in ``detail`` win on clash."""
    return ContentAgentError(
        ErrorCode.QUERY_GENERATION_FAILED,
        "query generation failed",
        {
            "reason": reason,
            "seed_term": seed_term,
            **(detail or {}),
        },
    )


def _env_value(env: Mapping[str, str], key: str) -> str:
    """Return ``env[key]`` stripped of whitespace, or ``""`` if absent/None."""
    value = env.get(key)
    if value is None:
        return ""
    return str(value).strip()


def _float_env(env: Mapping[str, str], key: str, default: float) -> float:
    """Parse ``env[key]`` as a float, returning ``default`` if empty or invalid."""
    value = _env_value(env, key)
    if not value:
        return default
    try:
        return float(value)
    except ValueError:
        return default