topic_search.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
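"""
Topic search tool - matches topics of existing posts in the database by keyword.
Lets the Agent fetch reference data on its own while executing, and outputs the
content that best matches the current persona.
"""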
  5. import json
  6. import os
  7. from typing import Any, Dict, List, Optional
  8. import httpx
  9. from agent.tools import tool, ToolResult
  10. # 选题检索 API 配置
  11. TOPIC_SEARCH_BASE_URL = os.getenv("TOPIC_SEARCH_BASE_URL", "http://192.168.81.89:8000")
  12. DEFAULT_TIMEOUT = 30.0
  13. async def _call_search_api(keywords: List[str]) -> Optional[List[Dict[str, Any]]]:
  14. """调用选题检索 API,返回结果列表。"""
  15. url = f"{TOPIC_SEARCH_BASE_URL.rstrip('/')}/search"
  16. payload = {"keywords": keywords}
  17. try:
  18. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  19. resp = await client.post(url, json=payload)
  20. resp.raise_for_status()
  21. data = resp.json()
  22. except httpx.HTTPStatusError as e:
  23. raise RuntimeError(f"API 请求失败: {e.response.status_code} - {e.response.text[:200]}")
  24. except Exception as e:
  25. raise RuntimeError(f"请求异常: {str(e)}")
  26. # 兼容多种响应格式
  27. if isinstance(data, list):
  28. return data[:5]
  29. if isinstance(data, dict):
  30. items = data.get("data") or data.get("results") or data.get("items") or []
  31. return list(items)[:5] if isinstance(items, (list, tuple)) else []
  32. return []
  33. def _extract_text(obj: Any) -> str:
  34. """从结果对象中提取可比较的文本。"""
  35. if obj is None:
  36. return ""
  37. if isinstance(obj, str):
  38. return obj
  39. if isinstance(obj, dict):
  40. text_parts = []
  41. for k in ("title", "content", "主题", "选题", "描述", "description", "摘要"):
  42. v = obj.get(k)
  43. if v and isinstance(v, str):
  44. text_parts.append(v)
  45. if not text_parts:
  46. text_parts = [str(v) for v in obj.values() if isinstance(v, str)]
  47. return " ".join(text_parts)
  48. return str(obj)
  49. def _score_match(result: Dict[str, Any], persona_summary: str) -> float:
  50. """
  51. 计算单条结果与人设摘要的匹配度(简单关键词重叠)。
  52. 返回 0~1 之间的分数,越高表示越匹配。
  53. """
  54. if not persona_summary or not persona_summary.strip():
  55. return 1.0
  56. result_text = _extract_text(result).lower()
  57. persona_words = set(
  58. w for w in persona_summary.lower().replace(",", " ").replace(",", " ").split()
  59. if len(w) > 1
  60. )
  61. if not persona_words:
  62. return 1.0
  63. hits = sum(1 for w in persona_words if w in result_text)
  64. return hits / len(persona_words)
  65. def _pick_best_match(results: List[Dict[str, Any]], persona_summary: Optional[str]) -> Dict[str, Any]:
  66. """从结果中选出与人设最匹配的一条。"""
  67. if not results:
  68. raise ValueError("无可用结果")
  69. if not persona_summary or len(results) == 1:
  70. return results[0]
  71. best = max(results, key=lambda r: _score_match(r, persona_summary))
  72. return best
  73. @tool(
  74. description="根据关键词在数据库中检索已有帖子的选题,用于创作参考。最多返回5条,自动选择与当前人设最匹配的一条输出。",
  75. display={
  76. "zh": {
  77. "name": "爆款选题检索",
  78. "params": {
  79. "keywords": "关键词列表",
  80. },
  81. },
  82. },
  83. )
  84. async def topic_search(
  85. keywords: List[str],
  86. persona_summary: Optional[str] = None,
  87. ) -> ToolResult:
  88. """
  89. 根据关键词检索数据库中已有帖子的选题,选择与人设最匹配的一条作为参考。
  90. Args:
  91. keywords: 关键词列表,如 ["中老年健康养生", "爆款", "知识科普"]
  92. persona_summary: 当前人设摘要,用于从多条结果中筛选最匹配的(可选)
  93. Returns:
  94. ToolResult: 最匹配的选题参考内容
  95. """
  96. if not keywords:
  97. return ToolResult(
  98. title="选题检索失败",
  99. output="",
  100. error="请提供至少一个关键词",
  101. )
  102. try:
  103. results = await _call_search_api(keywords)
  104. except RuntimeError as e:
  105. return ToolResult(
  106. title="选题检索失败",
  107. output="",
  108. error=str(e),
  109. )
  110. if not results:
  111. return ToolResult(
  112. title="选题检索",
  113. output=json.dumps({"message": "未找到匹配的选题", "keywords": keywords}, ensure_ascii=False, indent=2),
  114. )
  115. try:
  116. best = _pick_best_match(results, persona_summary)
  117. except ValueError:
  118. return ToolResult(
  119. title="选题检索",
  120. output=json.dumps({"message": "无可用结果", "keywords": keywords}, ensure_ascii=False, indent=2),
  121. )
  122. output = json.dumps(best, ensure_ascii=False, indent=2)
  123. return ToolResult(
  124. title="选题检索 - 参考数据",
  125. output=output,
  126. long_term_memory=f"检索到与人设匹配的选题参考,关键词: {', '.join(keywords)}",
  127. )