function_knowledge.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
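'''
Method-knowledge (function knowledge) acquisition module.

1. Input: question + post info + account persona info.
2. Turn the question into a search query with the LLM (prompt: function_generate_query_prompt.md).
3. Try to select a suitable tool from the existing tool library (LLM call, prompt:
   function_knowledge_select_tools_prompt.md).
   - If a tool is selected: extract its call parameters with the LLM (prompt:
     function_knowledge_extract_tool_params_prompt.md), invoke the tool and return its result.
   - Otherwise: call multi_search_knowledge.py to gather knowledge and return it; a background
     thread saves that knowledge to knowledge.txt in the per-question cache directory.
     Generating a new tool from this knowledge (function_knowledge_generate_new_tool_prompt.md)
     and registering it in the tool library is the intended follow-up, but it is not
     implemented in this file yet.
4. The full execution record (input, per-step details, result, metadata) is returned to the
   caller and persisted as JSON.
'''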
  11. import os
  12. import sys
  13. import json
  14. import threading
  15. from loguru import logger
  16. # 设置路径以便导入工具类
  17. current_dir = os.path.dirname(os.path.abspath(__file__))
  18. root_dir = os.path.dirname(current_dir)
  19. sys.path.insert(0, root_dir)
  20. from utils.gemini_client import generate_text
  21. from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info
  22. from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
  23. from knowledge_v2.cache_manager import CacheManager
  24. class FunctionKnowledge:
  25. """方法知识获取类"""
  26. def __init__(self, use_cache: bool = True):
  27. """
  28. 初始化
  29. Args:
  30. use_cache: 是否启用缓存,默认启用
  31. """
  32. logger.info("=" * 80)
  33. logger.info("初始化 FunctionKnowledge - 方法知识获取入口")
  34. self.prompt_dir = os.path.join(current_dir, "prompt")
  35. self.use_cache = use_cache
  36. self.cache = CacheManager() if use_cache else None
  37. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  38. logger.info("=" * 80)
  39. def _load_prompt(self, filename: str) -> str:
  40. """加载prompt文件内容"""
  41. prompt_path = os.path.join(self.prompt_dir, filename)
  42. if not os.path.exists(prompt_path):
  43. raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}")
  44. with open(prompt_path, 'r', encoding='utf-8') as f:
  45. return f.read().strip()
  46. def generate_query(self, question: str, post_info: str, persona_info: str) -> tuple:
  47. """
  48. 生成查询语句
  49. Returns:
  50. tuple: (query, detail_info)
  51. - query: 生成的查询语句
  52. - detail_info: 详细信息dict,包含prompt和response
  53. """
  54. logger.info(f"[步骤1] 生成Query...")
  55. # 组合问题的唯一标识
  56. combined_question = f"{question}||{post_info}||{persona_info}"
  57. detail_info = {"cached": False, "prompt": None, "response": None}
  58. # 尝试从缓存读取
  59. if self.use_cache:
  60. cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
  61. if cached_query:
  62. logger.info(f"✓ 使用缓存的Query: {cached_query}")
  63. detail_info["cached"] = True
  64. return cached_query, detail_info
  65. try:
  66. prompt_template = self._load_prompt("function_generate_query_prompt.md")
  67. prompt = prompt_template.format(
  68. question=question,
  69. post_info=post_info,
  70. persona_info=persona_info
  71. )
  72. detail_info["prompt"] = prompt
  73. logger.info("→ 调用Gemini生成Query...")
  74. query = generate_text(prompt=prompt)
  75. query = query.strip()
  76. detail_info["response"] = query
  77. logger.info(f"✓ 生成Query: {query}")
  78. # 写入缓存
  79. if self.use_cache:
  80. self.cache.set(combined_question, 'function_knowledge', 'generated_query.txt', query)
  81. return query, detail_info
  82. except Exception as e:
  83. logger.error(f"✗ 生成Query失败: {e}")
  84. detail_info["error"] = str(e)
  85. return question, detail_info # 降级使用原问题
  86. def select_tool(self, combined_question: str, query: str) -> tuple:
  87. """
  88. 选择合适的工具
  89. Returns:
  90. tuple: (tool_name, detail_info)
  91. """
  92. logger.info(f"[步骤2] 选择工具...")
  93. detail_info = {"cached": False, "prompt": None, "response": None, "available_tools_count": 0}
  94. # 尝试从缓存读取
  95. if self.use_cache:
  96. cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
  97. if cached_tool:
  98. logger.info(f"✓ 使用缓存的工具: {cached_tool}")
  99. detail_info["cached"] = True
  100. return cached_tool, detail_info
  101. try:
  102. all_tool_infos = get_all_tool_infos()
  103. if not all_tool_infos:
  104. logger.info(" 工具库为空,无可用工具")
  105. return "None", detail_info
  106. tool_count = len(all_tool_infos.split('--- Tool:')) - 1
  107. detail_info["available_tools_count"] = tool_count
  108. logger.info(f" 当前可用工具数: {tool_count}")
  109. prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
  110. prompt = prompt_template.format(
  111. query=query,
  112. tool_infos=all_tool_infos
  113. )
  114. detail_info["prompt"] = prompt
  115. detail_info["tool_infos"] = all_tool_infos
  116. logger.info("→ 调用Gemini选择工具...")
  117. tool_name = generate_text(prompt=prompt)
  118. tool_name = tool_name.strip()
  119. detail_info["response"] = tool_name
  120. logger.info(f"✓ 选择结果: {tool_name}")
  121. # 写入缓存
  122. if self.use_cache:
  123. self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', tool_name)
  124. return tool_name, detail_info
  125. except Exception as e:
  126. logger.error(f"✗ 选择工具失败: {e}")
  127. detail_info["error"] = str(e)
  128. return "None", detail_info
  129. def extract_tool_params(self, combined_question: str, tool_name: str, query: str) -> tuple:
  130. """
  131. 根据工具信息和查询提取调用参数
  132. Args:
  133. combined_question: 组合问题(用于缓存)
  134. tool_name: 工具名称
  135. query: 查询内容
  136. Returns:
  137. tuple: (params, detail_info)
  138. """
  139. logger.info(f"[步骤3] 提取工具参数...")
  140. # 初始化detail_info
  141. detail_info = {"cached": False, "prompt": None, "response": None, "tool_info": None}
  142. # 尝试从缓存读取
  143. if self.use_cache:
  144. cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
  145. if cached_params:
  146. logger.info(f"✓ 使用缓存的参数: {cached_params}")
  147. detail_info["cached"] = True
  148. return cached_params, detail_info
  149. try:
  150. # 获取工具信息
  151. tool_info = get_tool_info(tool_name)
  152. if not tool_info:
  153. logger.warning(f" ⚠ 未找到工具 {tool_name} 的信息,使用默认参数")
  154. # 降级:使用query作为keyword
  155. default_params = {"keyword": query}
  156. detail_info["fallback"] = "tool_info_not_found"
  157. return default_params, detail_info
  158. detail_info["tool_info"] = tool_info
  159. logger.info(f" 工具 {tool_name} 信息长度: {len(tool_info)}")
  160. # 加载prompt
  161. prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
  162. prompt = prompt_template.format(
  163. query=query,
  164. tool_info=tool_info
  165. )
  166. detail_info["prompt"] = prompt
  167. # 调用LLM提取参数
  168. logger.info(" → 调用Gemini提取参数...")
  169. response_text = generate_text(prompt=prompt)
  170. detail_info["response"] = response_text
  171. # 解析JSON
  172. logger.info(" → 解析参数JSON...")
  173. try:
  174. # 清理可能的markdown标记
  175. response_text = response_text.strip()
  176. if response_text.startswith("```json"):
  177. response_text = response_text[7:]
  178. if response_text.startswith("```"):
  179. response_text = response_text[3:]
  180. if response_text.endswith("```"):
  181. response_text = response_text[:-3]
  182. response_text = response_text.strip()
  183. params = json.loads(response_text)
  184. logger.info(f"✓ 提取参数成功: {params}")
  185. # 写入缓存
  186. if self.use_cache:
  187. self.cache.set(combined_question, 'function_knowledge', 'tool_params.json', params)
  188. return params, detail_info
  189. except json.JSONDecodeError as e:
  190. logger.error(f" ✗ 解析JSON失败: {e}")
  191. logger.error(f" 响应内容: {response_text}")
  192. # 降级:使用query作为keyword
  193. default_params = {"keyword": query}
  194. logger.warning(f" 使用默认参数: {default_params}")
  195. detail_info["fallback"] = "json_decode_error"
  196. return default_params, detail_info
  197. except Exception as e:
  198. logger.error(f"✗ 提取工具参数失败: {e}")
  199. # 降级:使用query作为keyword
  200. default_params = {"keyword": query}
  201. detail_info["error"] = str(e)
  202. detail_info["fallback"] = "exception"
  203. return default_params, detail_info
  204. def save_knowledge_to_file(self, knowledge: str, combined_question: str):
  205. """保存获取到的知识到文件"""
  206. try:
  207. logger.info("[保存知识] 开始保存知识到文件...")
  208. # 获取问题hash
  209. import hashlib
  210. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  211. # 获取缓存目录(和execution_record.json同级)
  212. if self.use_cache and self.cache:
  213. cache_dir = os.path.join(self.cache.base_cache_dir, question_hash)
  214. else:
  215. cache_dir = os.path.join(os.path.dirname(__file__), '.cache', question_hash)
  216. os.makedirs(cache_dir, exist_ok=True)
  217. # 保存到knowledge.txt
  218. knowledge_file = os.path.join(cache_dir, 'knowledge.txt')
  219. with open(knowledge_file, 'w', encoding='utf-8') as f:
  220. f.write(knowledge)
  221. logger.info(f"✓ 知识已保存到: {knowledge_file}")
  222. logger.info(f" 知识长度: {len(knowledge)} 字符")
  223. except Exception as e:
  224. logger.error(f"✗ 保存知识失败: {e}")
  225. def get_knowledge(self, question: str, post_info: str, persona_info: str) -> dict:
  226. """
  227. 获取方法知识的主流程
  228. Returns:
  229. dict: 包含完整执行信息的字典
  230. {
  231. "input": {...}, # 原始输入
  232. "execution": {...}, # 执行过程信息
  233. "result": {...}, # 最终结果
  234. "metadata": {...} # 元数据
  235. }
  236. """
  237. import time
  238. start_time = time.time()
  239. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  240. logger.info("=" * 80)
  241. logger.info(f"Function Knowledge - 开始处理")
  242. logger.info(f"问题: {question}")
  243. logger.info(f"帖子信息: {post_info}")
  244. logger.info(f"人设信息: {persona_info}")
  245. logger.info("=" * 80)
  246. # 组合问题的唯一标识
  247. combined_question = f"{question}||{post_info}||{persona_info}"
  248. # 初始化执行记录
  249. execution_record = {
  250. "input": {
  251. "question": question,
  252. "post_info": post_info,
  253. "persona_info": persona_info,
  254. "timestamp": timestamp
  255. },
  256. "execution": {
  257. "steps": [],
  258. "tool_info": None,
  259. "knowledge_search_info": None
  260. },
  261. "result": {
  262. "type": None, # "tool" 或 "knowledge"
  263. "content": None,
  264. "raw_data": None
  265. },
  266. "metadata": {
  267. "execution_time": None,
  268. "cache_hits": [],
  269. "errors": []
  270. }
  271. }
  272. # 检查最终结果缓存
  273. if self.use_cache:
  274. cached_final = self.cache.get(combined_question, 'function_knowledge', 'final_result.json')
  275. if cached_final:
  276. logger.info(f"✓ 使用缓存的最终结果")
  277. logger.info("=" * 80 + "\n")
  278. # 如果是完整的执行记录,直接返回
  279. if isinstance(cached_final, dict) and "execution" in cached_final:
  280. return cached_final
  281. # 否则构造一个简单的返回
  282. return {
  283. "input": execution_record["input"],
  284. "execution": {"cached": True},
  285. "result": {"type": "cached", "content": cached_final},
  286. "metadata": {"cache_hit": True}
  287. }
  288. try:
  289. # 步骤1: 生成Query
  290. step1_start = time.time()
  291. query, query_detail = self.generate_query(question, post_info, persona_info)
  292. execution_record["execution"]["steps"].append({
  293. "step": 1,
  294. "name": "generate_query",
  295. "duration": time.time() - step1_start,
  296. "output": query,
  297. "detail": query_detail # 包含prompt和response
  298. })
  299. # 步骤2: 选择工具
  300. step2_start = time.time()
  301. tool_name, tool_select_detail = self.select_tool(combined_question, query)
  302. execution_record["execution"]["steps"].append({
  303. "step": 2,
  304. "name": "select_tool",
  305. "duration": time.time() - step2_start,
  306. "output": tool_name,
  307. "detail": tool_select_detail # 包含prompt、response和可用工具列表
  308. })
  309. result_content = None
  310. if tool_name and tool_name != "None":
  311. # 路径A: 使用工具
  312. execution_record["result"]["type"] = "tool"
  313. # 步骤3: 提取参数
  314. step3_start = time.time()
  315. arguments, params_detail = self.extract_tool_params(combined_question, tool_name, query)
  316. execution_record["execution"]["steps"].append({
  317. "step": 3,
  318. "name": "extract_tool_params",
  319. "duration": time.time() - step3_start,
  320. "output": arguments,
  321. "detail": params_detail # 包含prompt、response和工具信息
  322. })
  323. # 步骤4: 调用工具
  324. logger.info(f"[步骤4] 调用工具: {tool_name}")
  325. # 检查工具调用缓存
  326. if self.use_cache:
  327. cached_tool_result = self.cache.get(combined_question, 'function_knowledge', 'tool_result.json')
  328. if cached_tool_result:
  329. logger.info(f"✓ 使用缓存的工具调用结果")
  330. execution_record["metadata"]["cache_hits"].append("tool_result")
  331. tool_result = cached_tool_result
  332. else:
  333. step4_start = time.time()
  334. logger.info(f" → 调用工具,参数: {arguments}")
  335. tool_result = call_tool(tool_name, arguments)
  336. # 缓存工具调用结果
  337. self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
  338. execution_record["execution"]["steps"].append({
  339. "step": 4,
  340. "name": "call_tool",
  341. "duration": time.time() - step4_start,
  342. "output": "success"
  343. })
  344. else:
  345. step4_start = time.time()
  346. logger.info(f" → 调用工具,参数: {arguments}")
  347. tool_result = call_tool(tool_name, arguments)
  348. execution_record["execution"]["steps"].append({
  349. "step": 4,
  350. "name": "call_tool",
  351. "duration": time.time() - step4_start,
  352. "output": "success"
  353. })
  354. # 记录工具调用信息
  355. execution_record["execution"]["tool_info"] = {
  356. "tool_name": tool_name,
  357. "parameters": arguments,
  358. "result": tool_result
  359. }
  360. result_content = f"工具 {tool_name} 执行结果: {json.dumps(tool_result, ensure_ascii=False)}"
  361. execution_record["result"]["content"] = result_content
  362. execution_record["result"]["raw_data"] = tool_result
  363. logger.info(f"✓ 工具调用完成")
  364. else:
  365. # 路径B: 知识搜索
  366. execution_record["result"]["type"] = "knowledge_search"
  367. logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...")
  368. step4_start = time.time()
  369. knowledge = get_multi_search_knowledge(query, cache_key=combined_question)
  370. execution_record["execution"]["steps"].append({
  371. "step": 4,
  372. "name": "multi_search_knowledge",
  373. "duration": time.time() - step4_start,
  374. "output": f"knowledge_length: {len(knowledge)}"
  375. })
  376. # 记录知识搜索信息
  377. execution_record["execution"]["knowledge_search_info"] = {
  378. "query": query,
  379. "knowledge_length": len(knowledge),
  380. "source": "multi_search"
  381. }
  382. result_content = knowledge
  383. execution_record["result"]["content"] = knowledge
  384. execution_record["result"]["raw_data"] = {"knowledge": knowledge, "query": query}
  385. # 异步生成新工具
  386. logger.info("[后台任务] 启动新工具生成线程...")
  387. threading.Thread(target=self.save_knowledge_to_file, args=(knowledge, combined_question)).start()
  388. # 计算总执行时间
  389. execution_record["metadata"]["execution_time"] = time.time() - start_time
  390. # 保存完整的执行记录到JSON文件
  391. if self.use_cache:
  392. self.cache.set(combined_question, 'function_knowledge', 'final_result.json', execution_record)
  393. # 同时保存一个格式化的JSON文件供人类阅读
  394. from knowledge_v2.cache_manager import CacheManager
  395. cache = CacheManager()
  396. import hashlib
  397. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  398. output_file = os.path.join(cache.base_cache_dir, question_hash, 'execution_record.json')
  399. try:
  400. with open(output_file, 'w', encoding='utf-8') as f:
  401. json.dump(execution_record, f, ensure_ascii=False, indent=2)
  402. logger.info(f"✓ 完整执行记录已保存: {output_file}")
  403. except Exception as e:
  404. logger.error(f"保存执行记录失败: {e}")
  405. logger.info("=" * 80)
  406. logger.info(f"✓ Function Knowledge 完成")
  407. logger.info(f" 类型: {execution_record['result']['type']}")
  408. logger.info(f" 结果长度: {len(result_content) if result_content else 0}")
  409. logger.info(f" 执行时间: {execution_record['metadata']['execution_time']:.2f}秒")
  410. logger.info("=" * 80 + "\n")
  411. return execution_record
  412. except Exception as e:
  413. logger.error(f"✗ 执行失败: {e}")
  414. import traceback
  415. error_trace = traceback.format_exc()
  416. execution_record["metadata"]["errors"].append({
  417. "error": str(e),
  418. "traceback": error_trace
  419. })
  420. execution_record["result"]["type"] = "error"
  421. execution_record["result"]["content"] = f"执行失败: {str(e)}"
  422. execution_record["metadata"]["execution_time"] = time.time() - start_time
  423. return execution_record
  424. def get_knowledge(question: str, post_info: str, persona_info: str) -> dict:
  425. """
  426. 便捷调用函数
  427. Returns:
  428. dict: 完整的执行记录,包含输入、执行过程、结果和元数据
  429. """
  430. agent = FunctionKnowledge()
  431. return agent.get_knowledge(question, post_info, persona_info)
  432. if __name__ == "__main__":
  433. # 测试代码
  434. question = "教资查分这个信息怎么来的"
  435. post_info = "发帖时间:2025.11.07"
  436. persona_info = ""
  437. try:
  438. agent = FunctionKnowledge()
  439. execution_result = agent.get_knowledge(question, post_info, persona_info)
  440. print("=" * 50)
  441. print("执行结果:")
  442. print("=" * 50)
  443. print(f"类型: {execution_result['result']['type']}")
  444. print(f"内容预览: {execution_result['result']['content'][:200]}...")
  445. print(f"执行时间: {execution_result['metadata']['execution_time']:.2f}秒")
  446. print(f"\n完整JSON已保存到缓存目录")
  447. except Exception as e:
  448. logger.error(f"测试失败: {e}")