function_knowledge.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. '''
  2. 方法知识获取模块
  3. 1. 输入:问题 + 帖子信息 + 账号人设信息
  4. 2. 将输入的问题转化成query,调用大模型,prompt在 function_knowledge_generate_query_prompt.md 中
  5. 3. 从已有方法工具库中尝试选择合适的方法工具(调用大模型执行,prompt在 function_knowledge_select_tools_prompt.md 中),如果有,则返回选择的方法工具,否则:
  6. - 调用 multi_search_knowledge.py 获取知识
  7. - 返回新的方法工具知识
  8. - 异步从新方法知识中获取新工具(调用大模型执行,prompt在 function_knowledge_generate_new_tool_prompt.md 中),调用工具库系统,接入新的工具
  9. 4. 调用选择的方法工具执行验证,返回工具执行结果
  10. '''
  11. import os
  12. import sys
  13. import json
  14. import threading
  15. from loguru import logger
  16. # 设置路径以便导入工具类
  17. current_dir = os.path.dirname(os.path.abspath(__file__))
  18. root_dir = os.path.dirname(current_dir)
  19. sys.path.insert(0, root_dir)
  20. from utils.gemini_client import generate_text
  21. from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info
  22. from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
  23. from knowledge_v2.cache_manager import CacheManager
  24. class FunctionKnowledge:
  25. """方法知识获取类"""
  26. def __init__(self, use_cache: bool = True):
  27. """
  28. 初始化
  29. Args:
  30. use_cache: 是否启用缓存,默认启用
  31. """
  32. logger.info("=" * 80)
  33. logger.info("初始化 FunctionKnowledge - 方法知识获取入口")
  34. self.prompt_dir = os.path.join(current_dir, "prompt")
  35. self.use_cache = use_cache
  36. self.cache = CacheManager() if use_cache else None
  37. # 执行详情收集
  38. self.execution_detail = {
  39. "generate_query": {},
  40. "select_tool": {},
  41. "extract_params": {},
  42. "execution_time": 0,
  43. "cache_hits": []
  44. }
  45. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  46. logger.info("=" * 80)
  47. def _save_execution_detail(self, cache_key: str):
  48. """保存执行详情到缓存(支持合并旧记录)"""
  49. if not self.use_cache or not self.cache:
  50. return
  51. try:
  52. import hashlib
  53. question_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()[:12]
  54. detail_dir = os.path.join(
  55. self.cache.base_cache_dir,
  56. question_hash,
  57. 'function_knowledge'
  58. )
  59. os.makedirs(detail_dir, exist_ok=True)
  60. detail_file = os.path.join(detail_dir, 'execution_detail.json')
  61. # 准备最终要保存的数据,默认为当前内存中的数据
  62. final_detail = self.execution_detail.copy()
  63. # 尝试读取旧文件进行合并
  64. if os.path.exists(detail_file):
  65. try:
  66. with open(detail_file, 'r', encoding='utf-8') as f:
  67. old_detail = json.load(f)
  68. # 智能合并逻辑:保留更有价值的历史信息
  69. for key, new_val in self.execution_detail.items():
  70. # 跳过非字典字段或旧文件中不存在的字段
  71. if not isinstance(new_val, dict) or key not in old_detail:
  72. continue
  73. old_val = old_detail[key]
  74. if not isinstance(old_val, dict):
  75. continue
  76. # 核心逻辑:如果新记录是缓存命中(cached=True),而旧记录包含prompt(说明是当初生成的)
  77. # 则保留旧记录,防止被简略信息覆盖
  78. if new_val.get("cached", False) is True and "prompt" in old_val:
  79. # logger.debug(f" 保留 {key} 的历史详细记录")
  80. final_detail[key] = old_val
  81. except Exception as e:
  82. logger.warning(f" ⚠ 读取旧详情失败,将使用新记录: {e}")
  83. with open(detail_file, 'w', encoding='utf-8') as f:
  84. json.dump(final_detail, f, ensure_ascii=False, indent=2)
  85. logger.info(f"✓ 执行详情已保存: {detail_file}")
  86. except Exception as e:
  87. logger.error(f"✗ 保存执行详情失败: {e}")
  88. def _load_prompt(self, filename: str) -> str:
  89. """加载prompt文件内容"""
  90. prompt_path = os.path.join(self.prompt_dir, filename)
  91. if not os.path.exists(prompt_path):
  92. raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}")
  93. with open(prompt_path, 'r', encoding='utf-8') as f:
  94. return f.read().strip()
  95. def generate_query(self, question: str, post_info: str, persona_info: str) -> str:
  96. """
  97. 生成查询语句
  98. Returns:
  99. str: 生成的查询语句
  100. """
  101. logger.info(f"[步骤1] 生成Query...")
  102. # 组合问题的唯一标识
  103. combined_question = f"{question}||{post_info}||{persona_info}"
  104. # 尝试从缓存读取
  105. if self.use_cache:
  106. cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
  107. if cached_query:
  108. logger.info(f"✓ 使用缓存的Query: {cached_query}")
  109. # 记录缓存命中
  110. self.execution_detail["generate_query"].update({"cached": True, "query": cached_query})
  111. return cached_query
  112. try:
  113. prompt_template = self._load_prompt("function_generate_query_prompt.md")
  114. prompt = prompt_template.replace("{question}", question)
  115. logger.info("→ 调用Gemini生成Query...")
  116. query = generate_text(prompt=prompt)
  117. query = query.strip()
  118. logger.info(f"✓ 生成Query: {query}")
  119. # 写入缓存
  120. if self.use_cache:
  121. self.cache.set(combined_question, 'function_knowledge', 'generated_query.txt', query)
  122. # 记录详情
  123. self.execution_detail["generate_query"] = {
  124. "cached": False,
  125. "prompt": prompt,
  126. "response": query,
  127. "query": query
  128. }
  129. return query
  130. except Exception as e:
  131. logger.error(f"✗ 生成Query失败: {e}")
  132. return question # 降级使用原问题
  133. def select_tool(self, combined_question: str, query: str) -> str:
  134. """
  135. 选择合适的工具
  136. Returns:
  137. str: 工具名称,如果没有合适的工具则返回"None"
  138. """
  139. logger.info(f"[步骤2] 选择工具...")
  140. # 尝试从缓存读取
  141. if self.use_cache:
  142. cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
  143. if cached_tool:
  144. logger.info(f"✓ 使用缓存的工具: {cached_tool}")
  145. # 记录缓存命中
  146. self.execution_detail["select_tool"].update({
  147. "cached": True,
  148. "tool_name": cached_tool
  149. })
  150. return cached_tool
  151. try:
  152. all_tool_infos = self._load_prompt("all_tools_infos.md")
  153. if not all_tool_infos:
  154. logger.info(" 工具库为空,无可用工具")
  155. return "None"
  156. tool_count = len(all_tool_infos.split('--- Tool:')) - 1
  157. logger.info(f" 当前可用工具数: {tool_count}")
  158. prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
  159. prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos)
  160. logger.info("→ 调用Gemini选择工具...")
  161. tool_name = generate_text(prompt=prompt)
  162. tool_name = tool_name.strip()
  163. logger.info(f"✓ 选择结果: {tool_name}")
  164. # 写入缓存
  165. if self.use_cache:
  166. self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', tool_name)
  167. # 记录详情
  168. self.execution_detail["select_tool"] = {
  169. "cached": False,
  170. "prompt": prompt,
  171. "response": tool_name,
  172. "tool_name": tool_name,
  173. "available_tools_count": tool_count
  174. }
  175. return tool_name
  176. except Exception as e:
  177. logger.error(f"✗ 选择工具失败: {e}")
  178. return "None"
  179. def extract_tool_params(self, combined_question: str, tool_name: str, query: str) -> dict:
  180. """
  181. 根据工具信息和查询提取调用参数
  182. Args:
  183. combined_question: 组合问题(用于缓存)
  184. tool_name: 工具名称
  185. query: 查询内容
  186. Returns:
  187. dict: 提取的参数字典
  188. """
  189. logger.info(f"[步骤3] 提取工具参数...")
  190. # 尝试从缓存读取
  191. if self.use_cache:
  192. cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
  193. if cached_params:
  194. logger.info(f"✓ 使用缓存的参数: {cached_params}")
  195. # 记录缓存命中
  196. self.execution_detail["extract_params"].update({
  197. "cached": True,
  198. "params": cached_params
  199. })
  200. return cached_params
  201. try:
  202. # 获取工具信息
  203. tool_info = get_tool_info(tool_name)
  204. if not tool_info:
  205. logger.warning(f" ⚠ 未找到工具 {tool_name} 的信息,使用默认参数")
  206. return {"keyword": query}
  207. logger.info(f" 工具 {tool_name} 信息长度: {len(tool_info)}")
  208. # 加载prompt
  209. prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
  210. prompt = prompt_template.format(
  211. query=query,
  212. tool_info=tool_info
  213. )
  214. # 调用LLM提取参数
  215. logger.info(" → 调用Gemini提取参数...")
  216. response_text = generate_text(prompt=prompt)
  217. # 解析JSON
  218. logger.info(" → 解析参数JSON...")
  219. try:
  220. # 清理可能的markdown标记
  221. response_text = response_text.strip()
  222. if response_text.startswith("```json"):
  223. response_text = response_text[7:]
  224. if response_text.startswith("```"):
  225. response_text = response_text[3:]
  226. if response_text.endswith("```"):
  227. response_text = response_text[:-3]
  228. response_text = response_text.strip()
  229. params = json.loads(response_text)
  230. logger.info(f"✓ 提取参数成功: {params}")
  231. # 写入缓存
  232. if self.use_cache:
  233. self.cache.set(combined_question, 'function_knowledge', 'tool_params.json', params)
  234. # 记录详情
  235. self.execution_detail["extract_params"].update({
  236. "cached": False,
  237. "prompt": prompt,
  238. "response": response_text,
  239. "params": params
  240. })
  241. return params
  242. except json.JSONDecodeError as e:
  243. logger.error(f" ✗ 解析JSON失败: {e}")
  244. logger.error(f" 响应内容: {response_text}")
  245. # 降级:使用query作为keyword
  246. default_params = {"keyword": query}
  247. logger.warning(f" 使用默认参数: {default_params}")
  248. return default_params
  249. except Exception as e:
  250. logger.error(f"✗ 提取工具参数失败: {e}")
  251. # 降级:使用query作为keyword
  252. return {"keyword": query}
  253. def save_knowledge_to_file(self, knowledge: str, combined_question: str):
  254. """保存获取到的知识到文件"""
  255. try:
  256. logger.info("[保存知识] 开始保存知识到文件...")
  257. # 获取问题hash
  258. import hashlib
  259. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  260. # 获取缓存目录(和execution_record.json同级)
  261. if self.use_cache and self.cache:
  262. cache_dir = os.path.join(self.cache.base_cache_dir, question_hash)
  263. else:
  264. cache_dir = os.path.join(os.path.dirname(__file__), '.cache', question_hash)
  265. os.makedirs(cache_dir, exist_ok=True)
  266. # 保存到knowledge.txt
  267. knowledge_file = os.path.join(cache_dir, 'knowledge.txt')
  268. with open(knowledge_file, 'w', encoding='utf-8') as f:
  269. f.write(knowledge)
  270. logger.info(f"✓ 知识已保存到: {knowledge_file}")
  271. logger.info(f" 知识长度: {len(knowledge)} 字符")
  272. except Exception as e:
  273. logger.error(f"✗ 保存知识失败: {e}")
  274. def get_knowledge(self, question: str, post_info: str, persona_info: str) -> dict:
  275. """
  276. 获取方法知识的主流程(重构后)
  277. Returns:
  278. dict: 完整的执行记录
  279. """
  280. import time
  281. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  282. start_time = time.time()
  283. logger.info("=" * 80)
  284. logger.info(f"Function Knowledge - 开始处理")
  285. logger.info(f"问题: {question}")
  286. logger.info(f"帖子信息: {post_info}")
  287. logger.info(f"人设信息: {persona_info}")
  288. logger.info("=" * 80)
  289. # 组合问题的唯一标识
  290. combined_question = f"{question}||{post_info}||{persona_info}"
  291. try:
  292. # 步骤1: 生成Query
  293. query = self.generate_query(question, post_info, persona_info)
  294. # 步骤2: 选择工具
  295. tool_name = self.select_tool(combined_question, query)
  296. if tool_name and tool_name != "None":
  297. # 路径A: 使用工具
  298. # 步骤3: 提取参数
  299. arguments = self.extract_tool_params(combined_question, tool_name, query)
  300. # 步骤4: 调用工具
  301. logger.info(f"[步骤4] 调用工具: {tool_name}")
  302. # 检查工具调用缓存
  303. if self.use_cache:
  304. cached_tool_result = self.cache.get(combined_question, 'function_knowledge', 'tool_result.json')
  305. if cached_tool_result:
  306. logger.info(f"✓ 使用缓存的工具调用结果")
  307. tool_result = cached_tool_result
  308. else:
  309. logger.info(f" → 调用工具,参数: {arguments}")
  310. tool_result = call_tool(tool_name, arguments)
  311. # 缓存工具调用结果
  312. self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
  313. else:
  314. logger.info(f" → 调用工具,参数: {arguments}")
  315. tool_result = call_tool(tool_name, arguments)
  316. logger.info(f"✓ 工具调用完成")
  317. else:
  318. # 路径B: 知识搜索
  319. logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...")
  320. knowledge = get_multi_search_knowledge(query, cache_key=combined_question)
  321. # 异步保存知识到文件
  322. logger.info("[后台任务] 保存知识到文件...")
  323. threading.Thread(target=self.save_knowledge_to_file, args=(knowledge, combined_question)).start()
  324. # 计算执行时间并保存详情
  325. self.execution_detail["execution_time"] = time.time() - start_time
  326. self._save_execution_detail(combined_question)
  327. # 收集所有执行记录
  328. logger.info("=" * 80)
  329. logger.info("收集执行记录...")
  330. logger.info("=" * 80)
  331. from knowledge_v2.execution_collector import collect_and_save_execution_record
  332. execution_record = collect_and_save_execution_record(
  333. combined_question,
  334. {
  335. "question": question,
  336. "post_info": post_info,
  337. "persona_info": persona_info,
  338. "timestamp": timestamp
  339. }
  340. )
  341. logger.info("=" * 80)
  342. logger.info(f"✓ Function Knowledge 完成")
  343. logger.info(f" 执行时间: {execution_record.get('metadata', {}).get('execution_time', 0):.2f}秒")
  344. logger.info("=" * 80 + "\n")
  345. return execution_record
  346. except Exception as e:
  347. logger.error(f"✗ 执行失败: {e}")
  348. import traceback
  349. logger.error(traceback.format_exc())
  350. # 即使失败也尝试保存详情和收集记录
  351. try:
  352. self.execution_detail["execution_time"] = time.time() - start_time
  353. self._save_execution_detail(combined_question)
  354. from knowledge_v2.execution_collector import collect_and_save_execution_record
  355. execution_record = collect_and_save_execution_record(
  356. combined_question,
  357. {
  358. "question": question,
  359. "post_info": post_info,
  360. "persona_info": persona_info,
  361. "timestamp": timestamp
  362. }
  363. )
  364. return execution_record
  365. except Exception as collect_error:
  366. logger.error(f"收集执行记录也失败: {collect_error}")
  367. # 返回基本错误信息
  368. return {
  369. "input": {
  370. "question": question,
  371. "post_info": post_info,
  372. "persona_info": persona_info,
  373. "timestamp": timestamp
  374. },
  375. "result": {
  376. "type": "error",
  377. "content": f"执行失败: {str(e)}"
  378. },
  379. "metadata": {
  380. "errors": [str(e)]
  381. }
  382. }
  383. if __name__ == "__main__":
  384. # 测试代码
  385. question = "教资查分这个信息怎么来的"
  386. post_info = "发帖时间:2025.11.07"
  387. persona_info = ""
  388. try:
  389. agent = FunctionKnowledge()
  390. execution_result = agent.get_knowledge(question, post_info, persona_info)
  391. print("=" * 50)
  392. print("执行结果:")
  393. print("=" * 50)
  394. print(json.dumps(execution_result, ensure_ascii=False, indent=2))
  395. print(f"\n完整JSON已保存到缓存目录")
  396. except Exception as e:
  397. logger.error(f"测试失败: {e}")