# librarian.py
  1. """
  2. Librarian Agent — KnowHub 的知识管理 Agent
  3. 通过 HTTP API 被 FastAPI server 调用,每次请求是一次 AgentRunner.run()。
  4. 状态全部持久化在 trace 中,通过 trace_id 续跑实现跨请求上下文积累。
  5. 两种调用模式:
  6. - ask: 同步,运行 Agent 处理查询,等待完成后返回结果
  7. - upload: 异步,存 buffer 后由后台任务运行 Agent 处理
  8. """
  9. import json
  10. import logging
  11. import sys
  12. from pathlib import Path
  13. from typing import Optional, Dict, Any
  14. # 确保项目路径可用
  15. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  16. from agent.core.runner import AgentRunner, RunConfig
  17. from agent.trace import FileSystemTraceStore, Trace, Message
  18. from agent.llm import create_qwen_llm_call
  19. from agent.llm.prompts import SimplePrompt
  20. from agent.tools.builtin.knowledge import KnowledgeConfig
  21. logger = logging.getLogger("agents.librarian")
# ===== Configuration =====
# Default for get_librarian_config(): when False, the "commit_to_database"
# tool is left out of the agent's tool list.
ENABLE_DATABASE_COMMIT = False
# JSON file that persists the caller trace_id -> librarian trace_id mapping.
# The path is relative, so it resolves against the process working directory.
TRACE_MAP_FILE = Path(".cache/.knowledge/trace_map.json")
  26. def get_librarian_config(enable_db_commit: bool = ENABLE_DATABASE_COMMIT) -> RunConfig:
  27. """获取 Librarian Agent 配置"""
  28. tools = [
  29. "knowledge_search",
  30. "tool_search",
  31. "capability_search",
  32. "requirement_search",
  33. "relation_search",
  34. "read_file", "write_file",
  35. "list_cache_status",
  36. "match_tree_nodes",
  37. "sync_atomic_capabilities",
  38. "skill",
  39. ]
  40. if enable_db_commit:
  41. tools.extend(["commit_to_database", "organize_cached_data", "cache_research_data"])
  42. else:
  43. tools.extend(["organize_cached_data", "cache_research_data"])
  44. return RunConfig(
  45. model="qwen3.5-plus",
  46. temperature=0.2,
  47. max_iterations=30,
  48. agent_type="default",
  49. name="Librarian Agent",
  50. goal_compression="on_complete",
  51. skills=[], # 不注入通用 skills(planning/research/browser),使用指定注入
  52. knowledge=KnowledgeConfig(
  53. enable_extraction=False,
  54. enable_completion_extraction=False,
  55. enable_injection=False,
  56. ),
  57. tools=tools,
  58. )
  59. def _register_internal_tools():
  60. """注册内部工具"""
  61. try:
  62. sys.path.insert(0, str(Path(__file__).parent.parent))
  63. from internal_tools.cache_manager import (
  64. cache_research_data,
  65. organize_cached_data,
  66. commit_to_database,
  67. list_cache_status,
  68. )
  69. from internal_tools.tree_matcher import match_tree_nodes
  70. from internal_tools.capability_extractor import sync_atomic_capabilities
  71. from agent.tools import get_tool_registry
  72. registry = get_tool_registry()
  73. registry.register(cache_research_data)
  74. registry.register(organize_cached_data)
  75. registry.register(commit_to_database)
  76. registry.register(list_cache_status)
  77. registry.register(match_tree_nodes)
  78. registry.register(sync_atomic_capabilities)
  79. logger.info("✓ 已注册 Librarian 内部工具")
  80. except Exception as e:
  81. logger.error(f"✗ 注册内部工具失败: {e}")
  82. # ===== trace_id 映射 =====
  83. def _load_trace_map() -> Dict[str, str]:
  84. if TRACE_MAP_FILE.exists():
  85. return json.loads(TRACE_MAP_FILE.read_text(encoding="utf-8"))
  86. return {}
  87. def _save_trace_map(mapping: Dict[str, str]):
  88. TRACE_MAP_FILE.parent.mkdir(parents=True, exist_ok=True)
  89. TRACE_MAP_FILE.write_text(json.dumps(mapping, indent=2, ensure_ascii=False), encoding="utf-8")
  90. def get_librarian_trace_id(caller_trace_id: str) -> Optional[str]:
  91. """根据调用方 trace_id 查找对应的 Librarian trace_id"""
  92. if not caller_trace_id:
  93. return None
  94. mapping = _load_trace_map()
  95. return mapping.get(caller_trace_id)
  96. def set_librarian_trace_id(caller_trace_id: str, librarian_trace_id: str):
  97. """记录映射"""
  98. if not caller_trace_id:
  99. return
  100. mapping = _load_trace_map()
  101. mapping[caller_trace_id] = librarian_trace_id
  102. _save_trace_map(mapping)
# ===== Singleton runner =====
# Module-level state filled in by _ensure_initialized() on first use.
_runner: Optional[AgentRunner] = None
# Messages built from librarian_agent.prompt ([] when that file is missing).
_prompt_messages = None
# Guards _ensure_initialized() against repeating the setup.
_initialized = False
  107. def _ensure_initialized():
  108. """延迟初始化 Runner 和 Prompt(首次调用时执行)"""
  109. global _runner, _prompt_messages, _initialized
  110. if _initialized:
  111. return
  112. _initialized = True
  113. _register_internal_tools()
  114. _runner = AgentRunner(
  115. trace_store=FileSystemTraceStore(base_path=".trace"),
  116. llm_call=create_qwen_llm_call(model="qwen3.5-plus"),
  117. skills_dir=str(Path(__file__).parent / "skills"),
  118. debug=True,
  119. logger_name="agents.librarian",
  120. )
  121. prompt_path = Path(__file__).parent / "librarian_agent.prompt"
  122. if prompt_path.exists():
  123. prompt = SimplePrompt(prompt_path)
  124. _prompt_messages = prompt.build_messages()
  125. else:
  126. _prompt_messages = []
  127. logger.warning(f"Librarian prompt 文件不存在: {prompt_path}")
  128. logger.info("✓ Librarian Agent 已初始化")
  129. # ===== 核心方法 =====
  130. async def ask(query: str, caller_trace_id: str = "") -> Dict[str, Any]:
  131. """
  132. 同步查询知识库。运行 Librarian Agent 处理查询,返回整合结果。
  133. Args:
  134. query: 查询内容
  135. caller_trace_id: 调用方 trace_id,用于续跑
  136. Returns:
  137. {"response": str, "source_ids": [str], "sources": [dict]}
  138. """
  139. _ensure_initialized()
  140. # 查找或创建 trace
  141. librarian_trace_id = get_librarian_trace_id(caller_trace_id)
  142. config = get_librarian_config()
  143. config.trace_id = librarian_trace_id # None = 新建, 有值 = 续跑
  144. # 构建消息
  145. content = f"[ASK] {query}"
  146. if librarian_trace_id is None:
  147. messages = _prompt_messages + [{"role": "user", "content": content}]
  148. else:
  149. messages = [{"role": "user", "content": content}]
  150. # 运行 Agent(指定注入 ask_strategy skill)
  151. response_text = ""
  152. actual_trace_id = None
  153. async for item in _runner.run(
  154. messages=messages, config=config,
  155. inject_skills=["ask_strategy"],
  156. skill_recency_threshold=20,
  157. ):
  158. if isinstance(item, Trace):
  159. actual_trace_id = item.trace_id
  160. elif isinstance(item, Message):
  161. if item.role == "assistant":
  162. msg_content = item.content
  163. if isinstance(msg_content, dict):
  164. text = msg_content.get("text", "")
  165. if text:
  166. response_text = text
  167. elif isinstance(msg_content, str) and msg_content:
  168. response_text = msg_content
  169. # 记录 trace 映射
  170. if actual_trace_id and caller_trace_id:
  171. set_librarian_trace_id(caller_trace_id, actual_trace_id)
  172. # 解析 source_ids(从 Agent 回复中提取,或从工具调用结果中提取)
  173. # Agent 回复中会引用 knowledge ID,格式如 [knowledge-xxx]
  174. import re
  175. source_ids = re.findall(r'\[?(knowledge-[a-zA-Z0-9_-]+)\]?', response_text)
  176. source_ids = list(dict.fromkeys(source_ids)) # 去重保序
  177. return {
  178. "response": response_text,
  179. "source_ids": source_ids,
  180. "sources": [], # TODO: 从 trace 的工具调用结果中提取 source 详情
  181. }
  182. async def process_upload(
  183. data: Dict[str, Any],
  184. caller_trace_id: str = "",
  185. buffer_file: Optional[str] = None,
  186. max_retries: int = 2,
  187. ):
  188. """
  189. 处理上传数据。运行 Librarian Agent 做图谱编排。
  190. 失败时重试,最终失败记录到 buffer 文件的状态中。
  191. Args:
  192. data: 上传数据 {knowledge, tools, resources}
  193. caller_trace_id: 调用方 trace_id
  194. buffer_file: 对应的 buffer 文件路径(用于更新状态)
  195. max_retries: 最大重试次数
  196. """
  197. _ensure_initialized()
  198. librarian_trace_id = get_librarian_trace_id(caller_trace_id)
  199. config = get_librarian_config()
  200. config.trace_id = librarian_trace_id
  201. content = f"[UPLOAD:BATCH] 收到上传请求,请处理:\n{json.dumps(data, ensure_ascii=False)}"
  202. if librarian_trace_id is None:
  203. messages = _prompt_messages + [{"role": "user", "content": content}]
  204. else:
  205. messages = [{"role": "user", "content": content}]
  206. last_error = None
  207. for attempt in range(max_retries + 1):
  208. try:
  209. actual_trace_id = None
  210. async for item in _runner.run(
  211. messages=messages, config=config,
  212. inject_skills=["upload_strategy"],
  213. skill_recency_threshold=10,
  214. ):
  215. if isinstance(item, Trace):
  216. actual_trace_id = item.trace_id
  217. if actual_trace_id and caller_trace_id:
  218. set_librarian_trace_id(caller_trace_id, actual_trace_id)
  219. # 成功:更新 buffer 文件状态
  220. _update_buffer_status(buffer_file, "completed", trace_id=actual_trace_id)
  221. logger.info(f"[Librarian] upload 处理完成,trace: {actual_trace_id}")
  222. return
  223. except Exception as e:
  224. last_error = str(e)
  225. logger.warning(f"[Librarian] upload 处理失败 (attempt {attempt + 1}/{max_retries + 1}): {e}")
  226. if attempt < max_retries:
  227. import asyncio
  228. await asyncio.sleep(2 ** attempt) # 1s, 2s 指数退避
  229. # 所有重试都失败
  230. _update_buffer_status(buffer_file, "failed", error=last_error)
  231. logger.error(f"[Librarian] upload 处理最终失败: {last_error}")
  232. def _update_buffer_status(buffer_file: Optional[str], status: str, trace_id: str = None, error: str = None):
  233. """更新 buffer 文件中的处理状态"""
  234. if not buffer_file:
  235. return
  236. try:
  237. from datetime import datetime as dt
  238. path = Path(buffer_file)
  239. if not path.exists():
  240. return
  241. data = json.loads(path.read_text(encoding="utf-8"))
  242. data["status"] = status
  243. data["processed_at"] = dt.now().isoformat()
  244. if trace_id:
  245. data["librarian_trace_id"] = trace_id
  246. if error:
  247. data["error"] = error
  248. path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  249. except Exception as e:
  250. logger.warning(f"更新 buffer 状态失败: {e}")
  251. def list_pending_uploads() -> list:
  252. """列出所有未处理或失败的 upload buffer 文件"""
  253. buffer_dir = Path(".cache/.knowledge/buffer")
  254. if not buffer_dir.exists():
  255. return []
  256. pending = []
  257. for f in sorted(buffer_dir.glob("upload_*.json")):
  258. try:
  259. data = json.loads(f.read_text(encoding="utf-8"))
  260. status = data.get("status", "pending")
  261. if status in ("pending", "failed"):
  262. pending.append({
  263. "file": str(f),
  264. "status": status,
  265. "received_at": data.get("received_at", ""),
  266. "error": data.get("error", ""),
  267. "trace_id": data.get("trace_id", ""),
  268. })
  269. except Exception:
  270. pass
  271. return pending