function_knowledge.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
'''
方法知识获取模块
1. 输入:问题 + 帖子信息 + 账号人设信息
2. 将输入的问题转化成query,调用大模型,prompt在 function_knowledge_generate_query_prompt.md 中
3. 从已有方法工具库中尝试选择合适的方法工具(调用大模型执行,prompt在 function_knowledge_select_tools_prompt.md 中),如果有,则返回选择的方法工具,否则:
    - 调用 multi_search_knowledge.py 获取知识
    - 返回新的方法工具知识
    - 异步从新方法知识中获取新工具(调用大模型执行,prompt在 function_knowledge_generate_new_tool_prompt.md 中),调用工具库系统,接入新的工具
4. 调用选择的方法工具执行验证,返回工具执行结果
'''
  11. import os
  12. import sys
  13. import json
  14. import threading
  15. from loguru import logger
  16. import re
  17. # 设置路径以便导入工具类
  18. current_dir = os.path.dirname(os.path.abspath(__file__))
  19. root_dir = os.path.dirname(current_dir)
  20. sys.path.insert(0, root_dir)
  21. from utils.gemini_client import generate_text
  22. from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info, get_tool_params
  23. from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
  24. from knowledge_v2.cache_manager import CacheManager
  25. class FunctionKnowledge:
  26. """方法知识获取类"""
  27. def __init__(self, use_cache: bool = True):
  28. """
  29. 初始化
  30. Args:
  31. use_cache: 是否启用缓存,默认启用
  32. """
  33. logger.info("=" * 80)
  34. logger.info("初始化 FunctionKnowledge - 方法知识获取入口")
  35. self.prompt_dir = os.path.join(current_dir, "prompt")
  36. self.use_cache = use_cache
  37. self.cache = CacheManager() if use_cache else None
  38. # 执行详情收集
  39. self.execution_detail = {
  40. "generate_query": {},
  41. "select_tool": {},
  42. "extract_params": {},
  43. "execution_time": 0,
  44. "cache_hits": []
  45. }
  46. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  47. logger.info("=" * 80)
  48. def _save_execution_detail(self, cache_key: str):
  49. """保存执行详情到缓存"""
  50. if not self.use_cache or not self.cache:
  51. return
  52. try:
  53. import hashlib
  54. question_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()[:12]
  55. detail_dir = os.path.join(
  56. self.cache.base_cache_dir,
  57. question_hash,
  58. 'function_knowledge'
  59. )
  60. os.makedirs(detail_dir, exist_ok=True)
  61. detail_file = os.path.join(detail_dir, 'execution_detail.json')
  62. with open(detail_file, 'w', encoding='utf-8') as f:
  63. json.dump(self.execution_detail, f, ensure_ascii=False, indent=2)
  64. logger.info(f"✓ 执行详情已保存: {detail_file}")
  65. except Exception as e:
  66. logger.error(f"✗ 保存执行详情失败: {e}")
  67. def _load_prompt(self, filename: str) -> str:
  68. """加载prompt文件内容"""
  69. prompt_path = os.path.join(self.prompt_dir, filename)
  70. if not os.path.exists(prompt_path):
  71. raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}")
  72. with open(prompt_path, 'r', encoding='utf-8') as f:
  73. return f.read().strip()
  74. def generate_query(self, question: str, post_info: str, persona_info: str) -> str:
  75. """
  76. 生成查询语句
  77. Returns:
  78. str: 生成的查询语句
  79. """
  80. logger.info(f"[步骤1] 生成Query...")
  81. # 组合问题的唯一标识
  82. combined_question = f"{question}||{post_info}||{persona_info}"
  83. try:
  84. prompt_template = self._load_prompt("function_generate_query_prompt.md")
  85. prompt = prompt_template.format(
  86. question=question,
  87. post_info=post_info,
  88. persona_info=persona_info
  89. )
  90. # 尝试从缓存读取
  91. if self.use_cache:
  92. cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
  93. if cached_query:
  94. logger.info(f"✓ 使用缓存的Query: {cached_query}")
  95. # 记录缓存命中
  96. self.execution_detail["generate_query"].update({"cached": True, "query": cached_query, "prompt": prompt})
  97. return cached_query
  98. logger.info("→ 调用Gemini生成Query...")
  99. query = generate_text(prompt=prompt)
  100. query = query.strip()
  101. logger.info(f"✓ 生成Query: {query}")
  102. # 写入缓存
  103. if self.use_cache:
  104. self.cache.set(combined_question, 'function_knowledge', 'generated_query.txt', query)
  105. # 记录详情
  106. self.execution_detail["generate_query"] = {
  107. "cached": False,
  108. "prompt": prompt,
  109. "response": query,
  110. "query": query
  111. }
  112. return query
  113. except Exception as e:
  114. logger.error(f"✗ 生成Query失败: {e}")
  115. return question # 降级使用原问题
  116. def select_tool(self, combined_question: str, query: str) -> str:
  117. """
  118. 选择合适的工具
  119. Returns:
  120. str: 工具名称,如果没有合适的工具则返回"None"
  121. """
  122. logger.info(f"[步骤2] 选择工具...")
  123. try:
  124. all_tool_infos = self._load_prompt("all_tools_infos.md")
  125. if not all_tool_infos:
  126. logger.info(" 工具库为空,无可用工具")
  127. return "None"
  128. prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
  129. prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos).replace("query", query)
  130. # 尝试从缓存读取
  131. if self.use_cache:
  132. cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
  133. if cached_tool:
  134. logger.info(f"✓ 使用缓存的工具: {cached_tool}")
  135. # 记录缓存命中
  136. self.execution_detail["select_tool"].update({
  137. "cached": True,
  138. "response": json.loads(cached_tool),
  139. "prompt": prompt,
  140. })
  141. return json.loads(cached_tool)
  142. logger.info("→ 调用Gemini选择工具...")
  143. result = generate_text(prompt=prompt)
  144. result = self.extract_and_validate_json(result)
  145. if not result:
  146. logger.error("✗ 选择工具失败: 无法提取有效JSON")
  147. return "None"
  148. result_json = json.loads(result)
  149. logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}")
  150. # 写入缓存
  151. if self.use_cache:
  152. self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', result)
  153. # 记录详情
  154. self.execution_detail["select_tool"] = {
  155. "cached": False,
  156. "prompt": prompt,
  157. "response": result_json,
  158. }
  159. return result_json
  160. except Exception as e:
  161. logger.error(f"✗ 选择工具失败: {e}")
  162. return "None"
  163. def extract_and_validate_json(self, text: str):
  164. """
  165. 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。
  166. 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。
  167. """
  168. # 1. 使用正则表达式寻找最大的 JSON 块
  169. # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义:
  170. # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符)
  171. # - | : 或者
  172. # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组)
  173. match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
  174. if match:
  175. json_str = match.group(0)
  176. try:
  177. # 2. 尝试解析提取出的字符串,验证是否为合法 JSON
  178. parsed_json = json.loads(json_str)
  179. # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等)
  180. # ensure_ascii=False 保证中文不会变成 \uXXXX
  181. return json.dumps(parsed_json, ensure_ascii=False)
  182. except json.JSONDecodeError as e:
  183. print(f"提取到了类似JSON的片段,但解析失败: {e}")
  184. return None
  185. else:
  186. print("未在文本中发现 JSON 结构")
  187. return None
  188. def extract_tool_params(self, combined_question: str, query: str, tool_id: str, tool_instructions: str) -> dict:
  189. """
  190. 根据工具信息和查询提取调用参数
  191. Args:
  192. combined_question: 组合问题(用于缓存)
  193. tool_name: 工具名称
  194. query: 查询内容
  195. Returns:
  196. dict: 提取的参数字典
  197. """
  198. logger.info(f"[步骤3] 提取工具参数...")
  199. try:
  200. # 获取工具信息
  201. tool_params = get_tool_params(tool_id)
  202. if not tool_params:
  203. logger.warning(f" ⚠ 未找到工具 {tool_id} 的信息,使用默认参数")
  204. return {"keyword": query}
  205. # 加载prompt
  206. prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
  207. prompt = prompt_template.format(
  208. query=query,
  209. all_tool_params=tool_params
  210. )
  211. # 尝试从缓存读取
  212. if self.use_cache:
  213. cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
  214. if cached_params:
  215. logger.info(f"✓ 使用缓存的参数: {cached_params}")
  216. # 记录缓存命中
  217. self.execution_detail["extract_params"].update({
  218. "cached": True,
  219. "params": cached_params,
  220. "prompt": prompt,
  221. })
  222. return cached_params
  223. # 调用LLM提取参数
  224. logger.info(" → 调用Gemini提取参数...")
  225. response_text = generate_text(prompt=prompt)
  226. # 解析JSON
  227. logger.info(" → 解析参数JSON...")
  228. try:
  229. # 清理可能的markdown标记
  230. response_text = response_text.strip()
  231. if response_text.startswith("```json"):
  232. response_text = response_text[7:]
  233. if response_text.startswith("```"):
  234. response_text = response_text[3:]
  235. if response_text.endswith("```"):
  236. response_text = response_text[:-3]
  237. response_text = response_text.strip()
  238. params = json.loads(response_text)
  239. logger.info(f"✓ 提取参数成功: {params}")
  240. # 写入缓存
  241. if self.use_cache:
  242. self.cache.set(combined_question, 'function_knowledge', 'tool_params.json', params)
  243. # 记录详情
  244. self.execution_detail["extract_params"].update({
  245. "cached": False,
  246. "prompt": prompt,
  247. "response": response_text,
  248. "params": params
  249. })
  250. return params
  251. except json.JSONDecodeError as e:
  252. logger.error(f" ✗ 解析JSON失败: {e}")
  253. logger.error(f" 响应内容: {response_text}")
  254. # 降级:使用query作为keyword
  255. default_params = {"keyword": query}
  256. logger.warning(f" 使用默认参数: {default_params}")
  257. return default_params
  258. except Exception as e:
  259. logger.error(f"✗ 提取工具参数失败: {e}")
  260. # 降级:使用query作为keyword
  261. return {"keyword": query}
  262. def save_knowledge_to_file(self, knowledge: str, combined_question: str):
  263. """保存获取到的知识到文件"""
  264. try:
  265. logger.info("[保存知识] 开始保存知识到文件...")
  266. # 获取问题hash
  267. import hashlib
  268. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  269. # 获取缓存目录(和execution_record.json同级)
  270. if self.use_cache and self.cache:
  271. cache_dir = os.path.join(self.cache.base_cache_dir, question_hash)
  272. else:
  273. cache_dir = os.path.join(os.path.dirname(__file__), '.cache', question_hash)
  274. os.makedirs(cache_dir, exist_ok=True)
  275. # 保存到knowledge.txt
  276. knowledge_file = os.path.join(cache_dir, 'knowledge.txt')
  277. with open(knowledge_file, 'w', encoding='utf-8') as f:
  278. f.write(knowledge)
  279. logger.info(f"✓ 知识已保存到: {knowledge_file}")
  280. logger.info(f" 知识长度: {len(knowledge)} 字符")
  281. except Exception as e:
  282. logger.error(f"✗ 保存知识失败: {e}")
  283. def get_knowledge(self, question: str, post_info: str, persona_info: str) -> dict:
  284. """
  285. 获取方法知识的主流程(重构后)
  286. Returns:
  287. dict: 完整的执行记录
  288. """
  289. import time
  290. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  291. start_time = time.time()
  292. logger.info("=" * 80)
  293. logger.info(f"Function Knowledge - 开始处理")
  294. logger.info(f"问题: {question}")
  295. logger.info(f"帖子信息: {post_info}")
  296. logger.info(f"人设信息: {persona_info}")
  297. logger.info("=" * 80)
  298. # 组合问题的唯一标识
  299. combined_question = f"{question}||{post_info}||{persona_info}"
  300. try:
  301. # 步骤1: 生成Query
  302. query = self.generate_query(question, post_info, persona_info)
  303. # 步骤2: 选择工具
  304. tool_info = self.select_tool(combined_question, query)
  305. # tool_name = tool_info.get("工具名")
  306. tool_id = tool_info.get("工具调用ID")
  307. tool_instructions = tool_info.get("使用方法")
  308. if tool_id and tool_instructions:
  309. # 路径A: 使用工具
  310. # 步骤3: 提取参数
  311. arguments = self.extract_tool_params(combined_question, query, tool_id, tool_instructions)
  312. # 步骤4: 调用工具
  313. logger.info(f"[步骤4] 调用工具: {tool_id}")
  314. # 检查工具调用缓存
  315. if self.use_cache:
  316. cached_tool_result = self.cache.get(combined_question, 'function_knowledge', 'tool_result.json')
  317. if cached_tool_result:
  318. logger.info(f"✓ 使用缓存的工具调用结果")
  319. tool_result = cached_tool_result
  320. else:
  321. logger.info(f" → 调用工具,参数: {arguments}")
  322. tool_result = call_tool(tool_id, arguments)
  323. # 缓存工具调用结果
  324. self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
  325. else:
  326. logger.info(f" → 调用工具,参数: {arguments}")
  327. tool_result = call_tool(tool_id, arguments)
  328. logger.info(f"✓ 工具调用完成")
  329. else:
  330. # 路径B: 知识搜索
  331. logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...")
  332. knowledge = get_multi_search_knowledge(query, cache_key=combined_question)
  333. # 异步保存知识到文件
  334. logger.info("[后台任务] 保存知识到文件...")
  335. threading.Thread(target=self.save_knowledge_to_file, args=(knowledge, combined_question)).start()
  336. # 计算执行时间并保存详情
  337. self.execution_detail["execution_time"] = time.time() - start_time
  338. self._save_execution_detail(combined_question)
  339. # 收集所有执行记录
  340. logger.info("=" * 80)
  341. logger.info("收集执行记录...")
  342. logger.info("=" * 80)
  343. from knowledge_v2.execution_collector import collect_and_save_execution_record
  344. execution_record = collect_and_save_execution_record(
  345. combined_question,
  346. {
  347. "question": question,
  348. "post_info": post_info,
  349. "persona_info": persona_info,
  350. "timestamp": timestamp
  351. }
  352. )
  353. logger.info("=" * 80)
  354. logger.info(f"✓ Function Knowledge 完成")
  355. logger.info(f" 执行时间: {execution_record.get('metadata', {}).get('execution_time', 0):.2f}秒")
  356. logger.info("=" * 80 + "\n")
  357. return execution_record
  358. except Exception as e:
  359. logger.error(f"✗ 执行失败: {e}")
  360. import traceback
  361. logger.error(traceback.format_exc())
  362. # 即使失败也尝试保存详情和收集记录
  363. try:
  364. self.execution_detail["execution_time"] = time.time() - start_time
  365. self._save_execution_detail(combined_question)
  366. from knowledge_v2.execution_collector import collect_and_save_execution_record
  367. execution_record = collect_and_save_execution_record(
  368. combined_question,
  369. {
  370. "question": question,
  371. "post_info": post_info,
  372. "persona_info": persona_info,
  373. "timestamp": timestamp
  374. }
  375. )
  376. return execution_record
  377. except Exception as collect_error:
  378. logger.error(f"收集执行记录也失败: {collect_error}")
  379. # 返回基本错误信息
  380. return {
  381. "input": {
  382. "question": question,
  383. "post_info": post_info,
  384. "persona_info": persona_info,
  385. "timestamp": timestamp
  386. },
  387. "result": {
  388. "type": "error",
  389. "content": f"执行失败: {str(e)}"
  390. },
  391. "metadata": {
  392. "errors": [str(e)]
  393. }
  394. }
  395. if __name__ == "__main__":
  396. # 测试代码
  397. question = "教资查分这个选题点怎么来的"
  398. post_info = "发帖时间:2025.11.07"
  399. persona_info = ""
  400. try:
  401. agent = FunctionKnowledge()
  402. execution_result = agent.get_knowledge(question, post_info, persona_info)
  403. print("=" * 50)
  404. print("执行结果:")
  405. print("=" * 50)
  406. print(json.dumps(execution_result, ensure_ascii=False, indent=2))
  407. print(f"\n完整JSON已保存到缓存目录")
  408. except Exception as e:
  409. logger.error(f"测试失败: {e}")