""" Librarian Agent — KnowHub 的知识管理 Agent 同一个 Agent 处理多种任务模式(查询 / 上传等),调用方通过 skills 参数选择策略: - skills=["ask_strategy"] → 知识库查询整合 - skills=["upload_strategy"] → 知识上传图谱编排 通过 HTTP API 被 FastAPI server 调用,每次请求一次 AgentRunner.run()。 续跑由 caller 显式传入 continue_from 指定。 """ import json import logging import sys from pathlib import Path from typing import Optional, Dict, Any, List # 确保项目路径可用 sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from agent.core.runner import AgentRunner, RunConfig from agent.trace import FileSystemTraceStore from agent.llm import create_qwen_llm_call from agent.llm.prompts import SimplePrompt from agent.tools.builtin.knowledge import KnowledgeConfig logger = logging.getLogger("agents.librarian") # ===== 配置 ===== ENABLE_DATABASE_COMMIT = False # Librarian 允许调用方注入的 skill 白名单。不在列表中的 skill 会被过滤掉(只打 warning,不报错)。 ALLOWED_SKILLS = ["ask_strategy", "upload_strategy"] DEFAULT_SKILLS = ["ask_strategy"] # 调用方不传时的默认值 def get_librarian_config(enable_db_commit: bool = ENABLE_DATABASE_COMMIT) -> RunConfig: """获取 Librarian Agent 配置""" tools = [ "knowledge_search", "tool_search", "capability_search", "requirement_search", "relation_search", "read_file", "write_file", "list_cache_status", "match_tree_nodes", "sync_atomic_capabilities", "skill", ] if enable_db_commit: tools.extend(["commit_to_database", "organize_cached_data", "cache_research_data"]) else: tools.extend(["organize_cached_data", "cache_research_data"]) return RunConfig( model="qwen3.5-plus", temperature=0.2, max_iterations=30, agent_type="default", name="Librarian Agent", goal_compression="on_complete", skills=[], # 不注入通用 skills(planning/research/browser),使用指定注入 knowledge=KnowledgeConfig( enable_extraction=False, enable_completion_extraction=False, enable_injection=False, ), tools=tools, tool_groups=[], # 精确指定工具,不从分组加载 ) def _register_internal_tools(): """注册内部工具""" try: sys.path.insert(0, str(Path(__file__).parent.parent)) from internal_tools.cache_manager import ( cache_research_data, organize_cached_data, commit_to_database, list_cache_status, ) from internal_tools.tree_matcher import match_tree_nodes from internal_tools.capability_extractor import sync_atomic_capabilities from agent.tools import get_tool_registry registry = get_tool_registry() registry.register(cache_research_data) registry.register(organize_cached_data) registry.register(commit_to_database) registry.register(list_cache_status) registry.register(match_tree_nodes) registry.register(sync_atomic_capabilities) logger.info("✓ 已注册 Librarian 内部工具") except Exception as e: logger.error(f"✗ 注册内部工具失败: {e}") # ===== 单例 Runner ===== _runner: Optional[AgentRunner] = None _prompt_messages = None _initialized = False def _ensure_initialized(): """延迟初始化 Runner 和 Prompt(首次调用时执行)""" global _runner, _prompt_messages, _initialized if _initialized: return _initialized = True _register_internal_tools() _runner = AgentRunner( trace_store=FileSystemTraceStore(base_path=".trace"), llm_call=create_qwen_llm_call(model="qwen3.5-plus"), skills_dir=str(Path(__file__).parent / "skills"), debug=True, logger_name="agents.librarian", ) prompt_path = Path(__file__).parent / "librarian_agent.prompt" if prompt_path.exists(): prompt = SimplePrompt(prompt_path) _prompt_messages = prompt.build_messages() else: _prompt_messages = [] logger.warning(f"Librarian prompt 文件不存在: {prompt_path}") logger.info("✓ Librarian Agent 已初始化") # ===== 核心方法 ===== async def run_librarian( query: str, continue_from: Optional[str] = None, skills: Optional[List[str]] = None, ) -> Dict[str, Any]: """ 同步运行 Librarian Agent。由 skills 参数决定当前是什么模式: - skills=["ask_strategy"] → 查询模式(默认) - skills=["upload_strategy"] → 上传模式(query 应为 JSON 字符串) Args: query: 任务内容。查询模式为自然语言问题;上传模式为 JSON 字符串 {knowledge, tools, resources} continue_from: 已有 sub_trace_id,传入则续跑该 trace skills: 调用方指定的 skill 列表,会被 ALLOWED_SKILLS 过滤 Returns: {"status", "sub_trace_id", "summary", "stats", "error"?} """ _ensure_initialized() # Skill 过滤 skills = _filter_skills(skills) # 判断模式:upload_strategy 走上传路径(会写 buffer + 解析 JSON) is_upload = "upload_strategy" in skills buffer_file = None if is_upload: try: data = json.loads(query) if isinstance(query, str) else query except json.JSONDecodeError as e: return _fail(f"upload 模式下 query 不是合法 JSON: {e}") if not isinstance(data, dict): return _fail("upload 模式下 query 应为 JSON 对象") if not (data.get("knowledge") or data.get("tools") or data.get("resources")): return _fail("upload 模式下 data 中无有效条目") buffer_file = _write_upload_buffer(data) content = f"[UPLOAD:BATCH] 收到上传请求,请处理:\n{json.dumps(data, ensure_ascii=False)}" else: content = f"[ASK] {query}" # 运行 Librarian config = get_librarian_config() config.trace_id = continue_from if continue_from is None: messages = _prompt_messages + [{"role": "user", "content": content}] else: messages = [{"role": "user", "content": content}] try: result = await _runner.run_result( messages=messages, config=config, inject_skills=skills, ) actual_trace_id = result.get("trace_id") if buffer_file: _update_buffer_status(buffer_file, "completed", trace_id=actual_trace_id) return { "status": result.get("status", "completed"), "sub_trace_id": actual_trace_id, "summary": result.get("summary", ""), "stats": result.get("stats", {}), "error": result.get("error"), } except Exception as e: if buffer_file: _update_buffer_status(buffer_file, "failed", error=str(e)) logger.error(f"[Librarian] run 失败: {e}") return _fail(str(e)) def _filter_skills(skills: Optional[List[str]]) -> List[str]: """把调用方传的 skills 按 ALLOWED_SKILLS 过滤;空列表退化到 DEFAULT_SKILLS。""" if not skills: return list(DEFAULT_SKILLS) allowed = [s for s in skills if s in ALLOWED_SKILLS] dropped = [s for s in skills if s not in ALLOWED_SKILLS] if dropped: logger.warning(f"[Librarian] 忽略不在白名单的 skills: {dropped}(允许: {ALLOWED_SKILLS})") return allowed or list(DEFAULT_SKILLS) def _fail(error: str) -> Dict[str, Any]: return {"status": "failed", "sub_trace_id": None, "summary": "", "stats": {}, "error": error} def _write_upload_buffer(data: Dict[str, Any]) -> Optional[str]: """把 upload 数据写到 buffer 目录,便于审计和失败重跑。""" try: from datetime import datetime as dt buffer_dir = Path(".cache/.knowledge/buffer") buffer_dir.mkdir(parents=True, exist_ok=True) timestamp = dt.now().strftime("%Y%m%d_%H%M%S") path = buffer_dir / f"upload_{timestamp}.json" path.write_text(json.dumps({ "data": data, "received_at": dt.now().isoformat(), }, ensure_ascii=False, indent=2), encoding="utf-8") return str(path) except Exception as e: logger.warning(f"写 buffer 失败: {e}") return None def _update_buffer_status(buffer_file: Optional[str], status: str, trace_id: str = None, error: str = None): """更新 buffer 文件中的处理状态""" if not buffer_file: return try: from datetime import datetime as dt path = Path(buffer_file) if not path.exists(): return data = json.loads(path.read_text(encoding="utf-8")) data["status"] = status data["processed_at"] = dt.now().isoformat() if trace_id: data["librarian_trace_id"] = trace_id if error: data["error"] = error path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") except Exception as e: logger.warning(f"更新 buffer 状态失败: {e}") def list_pending_uploads() -> list: """列出所有未处理或失败的 upload buffer 文件""" buffer_dir = Path(".cache/.knowledge/buffer") if not buffer_dir.exists(): return [] pending = [] for f in sorted(buffer_dir.glob("upload_*.json")): try: data = json.loads(f.read_text(encoding="utf-8")) status = data.get("status", "pending") if status in ("pending", "failed"): pending.append({ "file": str(f), "status": status, "received_at": data.get("received_at", ""), "error": data.get("error", ""), "trace_id": data.get("trace_id", ""), }) except Exception: pass return pending