| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- """
- 选题检索工具 - 根据关键词在数据库中匹配已有帖子的选题
- 用于 Agent 执行时自主调取参考数据,并选择与当前人设最匹配的内容输出。
- """
- import json
- import os
- from typing import Any, Dict, List, Optional
- import httpx
- from agent.tools import tool, ToolResult
# Topic-search API settings.
# The base URL can be overridden via the TOPIC_SEARCH_BASE_URL environment
# variable. NOTE(review): the fallback is a private LAN address — confirm it is
# appropriate for every deployment.
TOPIC_SEARCH_BASE_URL = os.environ.get("TOPIC_SEARCH_BASE_URL", "http://192.168.81.89:8000")
# Request timeout handed to httpx.AsyncClient (seconds).
DEFAULT_TIMEOUT = 30.0
def _normalize_search_results(data: Any, limit: int) -> List[Any]:
    """Extract at most ``limit`` result items from a decoded search response.

    Accepted shapes:

    * a bare JSON list of items;
    * a JSON object whose ``data``, ``results`` or ``items`` field (first
      truthy one wins) holds a list/tuple of items.

    Any other shape yields ``[]``. A negative ``limit`` is treated as 0.
    """
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        items = data.get("data") or data.get("results") or data.get("items") or []
        if not isinstance(items, (list, tuple)):
            return []
    else:
        return []
    return list(items)[: max(limit, 0)]


async def _call_search_api(
    keywords: List[str],
    *,
    limit: int = 5,
) -> Optional[List[Dict[str, Any]]]:
    """Query the topic-search service and return at most ``limit`` results.

    Sends ``POST {TOPIC_SEARCH_BASE_URL}/search`` with the JSON body
    ``{"keywords": keywords}``.

    Args:
        keywords: Keywords forwarded verbatim to the service.
        limit: Maximum number of items returned (default 5, the previous
            hard-coded cap).

    Returns:
        A list of at most ``limit`` items (see ``_normalize_search_results``
        for the accepted response shapes); ``[]`` for unrecognized shapes.

    Raises:
        RuntimeError: if the request fails, the server answers with an error
            status, or the body is not valid JSON. The underlying exception is
            chained as ``__cause__``.
    """
    url = f"{TOPIC_SEARCH_BASE_URL.rstrip('/')}/search"
    payload = {"keywords": keywords}
    try:
        async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
            resp = await client.post(url, json=payload)
            resp.raise_for_status()
            data = resp.json()
    except httpx.HTTPStatusError as e:
        raise RuntimeError(
            f"API 请求失败: {e.response.status_code} - {e.response.text[:200]}"
        ) from e
    except Exception as e:
        # Deliberately broad: the caller converts RuntimeError into a tool
        # error result instead of crashing. Some exceptions may stringify to
        # an empty string, so fall back to the type name to keep it useful.
        raise RuntimeError(f"请求异常: {str(e) or type(e).__name__}") from e
    return _normalize_search_results(data, limit)
- def _extract_text(obj: Any) -> str:
- """从结果对象中提取可比较的文本。"""
- if obj is None:
- return ""
- if isinstance(obj, str):
- return obj
- if isinstance(obj, dict):
- text_parts = []
- for k in ("title", "content", "主题", "选题", "描述", "description", "摘要"):
- v = obj.get(k)
- if v and isinstance(v, str):
- text_parts.append(v)
- if not text_parts:
- text_parts = [str(v) for v in obj.values() if isinstance(v, str)]
- return " ".join(text_parts)
- return str(obj)
- def _score_match(result: Dict[str, Any], persona_summary: str) -> float:
- """
- 计算单条结果与人设摘要的匹配度(简单关键词重叠)。
- 返回 0~1 之间的分数,越高表示越匹配。
- """
- if not persona_summary or not persona_summary.strip():
- return 1.0
- result_text = _extract_text(result).lower()
- persona_words = set(
- w for w in persona_summary.lower().replace(",", " ").replace(",", " ").split()
- if len(w) > 1
- )
- if not persona_words:
- return 1.0
- hits = sum(1 for w in persona_words if w in result_text)
- return hits / len(persona_words)
- def _pick_best_match(results: List[Dict[str, Any]], persona_summary: Optional[str]) -> Dict[str, Any]:
- """从结果中选出与人设最匹配的一条。"""
- if not results:
- raise ValueError("无可用结果")
- if not persona_summary or len(results) == 1:
- return results[0]
- best = max(results, key=lambda r: _score_match(r, persona_summary))
- return best
@tool(
    description="根据关键词在数据库中检索已有帖子的选题,用于创作参考。最多返回5条,自动选择与当前人设最匹配的一条输出。",
    display={
        "zh": {
            "name": "爆款选题检索",
            "params": {
                "keywords": "关键词列表",
                "persona_summary": "人设摘要",
            },
        },
    },
)
async def topic_search(
    keywords: List[str],
    persona_summary: Optional[str] = None,
) -> ToolResult:
    """Search existing posts for topics and return the one that best fits the persona.

    Args:
        keywords: Keyword list, e.g. ["中老年健康养生", "爆款", "知识科普"].
            A bare string is treated as a single keyword; ``None`` and
            blank/whitespace-only entries are dropped.
        persona_summary: Optional summary of the current persona, used to pick
            the best match among the returned results.

    Returns:
        ToolResult: On success, ``output`` is the chosen result serialized as
        JSON. When nothing is found, ``output`` is a JSON message echoing the
        keywords. Validation and API failures are reported via ``error``.
    """
    # Arguments are supplied by the agent at run time, so normalize them: a
    # bare string would otherwise be sent as-is and later joined character by
    # character, and blank entries would slip past the emptiness check.
    raw_keywords = [keywords] if isinstance(keywords, str) else list(keywords or [])
    keywords = [
        text
        for text in (str(k).strip() for k in raw_keywords if k is not None)
        if text
    ]
    if not keywords:
        return ToolResult(
            title="选题检索失败",
            output="",
            error="请提供至少一个关键词",
        )
    try:
        results = await _call_search_api(keywords)
    except RuntimeError as e:
        return ToolResult(
            title="选题检索失败",
            output="",
            error=str(e),
        )
    if not results:
        return ToolResult(
            title="选题检索",
            output=json.dumps({"message": "未找到匹配的选题", "keywords": keywords}, ensure_ascii=False, indent=2),
        )
    # `results` is non-empty here, so _pick_best_match's empty-input
    # ValueError cannot occur.
    best = _pick_best_match(results, persona_summary)
    return ToolResult(
        title="选题检索 - 参考数据",
        output=json.dumps(best, ensure_ascii=False, indent=2),
        long_term_memory=f"检索到与人设匹配的选题参考,关键词: {', '.join(keywords)}",
    )
|