# librarian.py
"""
Librarian Agent — KnowHub's knowledge-management agent.

A single agent handles multiple task modes (query / upload, etc.); the caller
picks a strategy via the ``skills`` parameter:
- skills=["ask_strategy"]    -> knowledge-base query & synthesis
- skills=["upload_strategy"] -> knowledge-upload graph orchestration

Called over HTTP by the FastAPI server; each request performs one
AgentRunner.run(). Resuming is explicit: the caller passes ``continue_from``.
"""
import json
import logging
import sys
from pathlib import Path
from typing import Optional, Dict, Any, List

# Make the project root importable regardless of how this file is launched.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from agent.core.runner import AgentRunner, RunConfig
from agent.trace import FileSystemTraceStore
from agent.llm import create_qwen_llm_call
from agent.llm.prompts import SimplePrompt
from agent.tools.builtin.knowledge import KnowledgeConfig

logger = logging.getLogger("agents.librarian")

# ===== Configuration =====
# When True, the agent is additionally given the database-commit tool.
ENABLE_DATABASE_COMMIT = False
# Whitelist of skills callers may inject. Skills outside this list are
# dropped with a warning only (no error).
ALLOWED_SKILLS = ["ask_strategy", "upload_strategy"]
DEFAULT_SKILLS = ["ask_strategy"]  # used when the caller passes no skills
  27. def get_librarian_config(enable_db_commit: bool = ENABLE_DATABASE_COMMIT) -> RunConfig:
  28. """获取 Librarian Agent 配置"""
  29. tools = [
  30. "knowledge_search",
  31. "tool_search",
  32. "capability_search",
  33. "requirement_search",
  34. "relation_search",
  35. "read_file", "write_file",
  36. "list_cache_status",
  37. "match_tree_nodes",
  38. "sync_atomic_capabilities",
  39. "skill",
  40. ]
  41. if enable_db_commit:
  42. tools.extend(["commit_to_database", "organize_cached_data", "cache_research_data"])
  43. else:
  44. tools.extend(["organize_cached_data", "cache_research_data"])
  45. return RunConfig(
  46. model="qwen3.5-plus",
  47. temperature=0.2,
  48. max_iterations=30,
  49. agent_type="default",
  50. name="Librarian Agent",
  51. goal_compression="on_complete",
  52. skills=[], # 不注入通用 skills(planning/research/browser),使用指定注入
  53. knowledge=KnowledgeConfig(
  54. enable_extraction=False,
  55. enable_completion_extraction=False,
  56. enable_injection=False,
  57. ),
  58. tools=tools,
  59. tool_groups=[], # 精确指定工具,不从分组加载
  60. )
  61. def _register_internal_tools():
  62. """注册内部工具"""
  63. try:
  64. sys.path.insert(0, str(Path(__file__).parent.parent))
  65. from internal_tools.cache_manager import (
  66. cache_research_data,
  67. organize_cached_data,
  68. commit_to_database,
  69. list_cache_status,
  70. )
  71. from internal_tools.tree_matcher import match_tree_nodes
  72. from internal_tools.capability_extractor import sync_atomic_capabilities
  73. from agent.tools import get_tool_registry
  74. registry = get_tool_registry()
  75. registry.register(cache_research_data)
  76. registry.register(organize_cached_data)
  77. registry.register(commit_to_database)
  78. registry.register(list_cache_status)
  79. registry.register(match_tree_nodes)
  80. registry.register(sync_atomic_capabilities)
  81. logger.info("✓ 已注册 Librarian 内部工具")
  82. except Exception as e:
  83. logger.error(f"✗ 注册内部工具失败: {e}")
# ===== Singleton runner state (created lazily by _ensure_initialized) =====
_runner: Optional[AgentRunner] = None  # shared AgentRunner instance
_prompt_messages = None  # cached system-prompt messages ([] when the prompt file is missing)
_initialized = False  # guards one-time initialization
  88. def _ensure_initialized():
  89. """延迟初始化 Runner 和 Prompt(首次调用时执行)"""
  90. global _runner, _prompt_messages, _initialized
  91. if _initialized:
  92. return
  93. _initialized = True
  94. _register_internal_tools()
  95. _runner = AgentRunner(
  96. trace_store=FileSystemTraceStore(base_path=".trace"),
  97. llm_call=create_qwen_llm_call(model="qwen3.5-plus"),
  98. skills_dir=str(Path(__file__).parent / "skills"),
  99. debug=True,
  100. logger_name="agents.librarian",
  101. )
  102. prompt_path = Path(__file__).parent / "librarian_agent.prompt"
  103. if prompt_path.exists():
  104. prompt = SimplePrompt(prompt_path)
  105. _prompt_messages = prompt.build_messages()
  106. else:
  107. _prompt_messages = []
  108. logger.warning(f"Librarian prompt 文件不存在: {prompt_path}")
  109. logger.info("✓ Librarian Agent 已初始化")
# ===== Core entry point =====
  111. async def run_librarian(
  112. query: str,
  113. continue_from: Optional[str] = None,
  114. skills: Optional[List[str]] = None,
  115. ) -> Dict[str, Any]:
  116. """
  117. 同步运行 Librarian Agent。由 skills 参数决定当前是什么模式:
  118. - skills=["ask_strategy"] → 查询模式(默认)
  119. - skills=["upload_strategy"] → 上传模式(query 应为 JSON 字符串)
  120. Args:
  121. query: 任务内容。查询模式为自然语言问题;上传模式为 JSON 字符串 {knowledge, tools, resources}
  122. continue_from: 已有 sub_trace_id,传入则续跑该 trace
  123. skills: 调用方指定的 skill 列表,会被 ALLOWED_SKILLS 过滤
  124. Returns:
  125. {"status", "sub_trace_id", "summary", "stats", "error"?}
  126. """
  127. _ensure_initialized()
  128. # Skill 过滤
  129. skills = _filter_skills(skills)
  130. # 判断模式:upload_strategy 走上传路径(会写 buffer + 解析 JSON)
  131. is_upload = "upload_strategy" in skills
  132. buffer_file = None
  133. if is_upload:
  134. try:
  135. data = json.loads(query) if isinstance(query, str) else query
  136. except json.JSONDecodeError as e:
  137. return _fail(f"upload 模式下 query 不是合法 JSON: {e}")
  138. if not isinstance(data, dict):
  139. return _fail("upload 模式下 query 应为 JSON 对象")
  140. if not (data.get("knowledge") or data.get("tools") or data.get("resources")):
  141. return _fail("upload 模式下 data 中无有效条目")
  142. buffer_file = _write_upload_buffer(data)
  143. content = f"[UPLOAD:BATCH] 收到上传请求,请处理:\n{json.dumps(data, ensure_ascii=False)}"
  144. else:
  145. content = f"[ASK] {query}"
  146. # 运行 Librarian
  147. config = get_librarian_config()
  148. config.trace_id = continue_from
  149. if continue_from is None:
  150. messages = _prompt_messages + [{"role": "user", "content": content}]
  151. else:
  152. messages = [{"role": "user", "content": content}]
  153. try:
  154. result = await _runner.run_result(
  155. messages=messages,
  156. config=config,
  157. inject_skills=skills,
  158. )
  159. actual_trace_id = result.get("trace_id")
  160. if buffer_file:
  161. _update_buffer_status(buffer_file, "completed", trace_id=actual_trace_id)
  162. return {
  163. "status": result.get("status", "completed"),
  164. "sub_trace_id": actual_trace_id,
  165. "summary": result.get("summary", ""),
  166. "stats": result.get("stats", {}),
  167. "error": result.get("error"),
  168. }
  169. except Exception as e:
  170. if buffer_file:
  171. _update_buffer_status(buffer_file, "failed", error=str(e))
  172. logger.error(f"[Librarian] run 失败: {e}")
  173. return _fail(str(e))
  174. def _filter_skills(skills: Optional[List[str]]) -> List[str]:
  175. """把调用方传的 skills 按 ALLOWED_SKILLS 过滤;空列表退化到 DEFAULT_SKILLS。"""
  176. if not skills:
  177. return list(DEFAULT_SKILLS)
  178. allowed = [s for s in skills if s in ALLOWED_SKILLS]
  179. dropped = [s for s in skills if s not in ALLOWED_SKILLS]
  180. if dropped:
  181. logger.warning(f"[Librarian] 忽略不在白名单的 skills: {dropped}(允许: {ALLOWED_SKILLS})")
  182. return allowed or list(DEFAULT_SKILLS)
  183. def _fail(error: str) -> Dict[str, Any]:
  184. return {"status": "failed", "sub_trace_id": None, "summary": "", "stats": {}, "error": error}
  185. def _write_upload_buffer(data: Dict[str, Any]) -> Optional[str]:
  186. """把 upload 数据写到 buffer 目录,便于审计和失败重跑。"""
  187. try:
  188. from datetime import datetime as dt
  189. buffer_dir = Path(".cache/.knowledge/buffer")
  190. buffer_dir.mkdir(parents=True, exist_ok=True)
  191. timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
  192. path = buffer_dir / f"upload_{timestamp}.json"
  193. path.write_text(json.dumps({
  194. "data": data, "received_at": dt.now().isoformat(),
  195. }, ensure_ascii=False, indent=2), encoding="utf-8")
  196. return str(path)
  197. except Exception as e:
  198. logger.warning(f"写 buffer 失败: {e}")
  199. return None
  200. def _update_buffer_status(buffer_file: Optional[str], status: str, trace_id: str = None, error: str = None):
  201. """更新 buffer 文件中的处理状态"""
  202. if not buffer_file:
  203. return
  204. try:
  205. from datetime import datetime as dt
  206. path = Path(buffer_file)
  207. if not path.exists():
  208. return
  209. data = json.loads(path.read_text(encoding="utf-8"))
  210. data["status"] = status
  211. data["processed_at"] = dt.now().isoformat()
  212. if trace_id:
  213. data["librarian_trace_id"] = trace_id
  214. if error:
  215. data["error"] = error
  216. path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  217. except Exception as e:
  218. logger.warning(f"更新 buffer 状态失败: {e}")
  219. def list_pending_uploads() -> list:
  220. """列出所有未处理或失败的 upload buffer 文件"""
  221. buffer_dir = Path(".cache/.knowledge/buffer")
  222. if not buffer_dir.exists():
  223. return []
  224. pending = []
  225. for f in sorted(buffer_dir.glob("upload_*.json")):
  226. try:
  227. data = json.loads(f.read_text(encoding="utf-8"))
  228. status = data.get("status", "pending")
  229. if status in ("pending", "failed"):
  230. pending.append({
  231. "file": str(f),
  232. "status": status,
  233. "received_at": data.get("received_at", ""),
  234. "error": data.get("error", ""),
  235. "trace_id": data.get("trace_id", ""),
  236. })
  237. except Exception:
  238. pass
  239. return pending