function_knowledge.py 24 KB


'''
Method ("function") knowledge acquisition module.

Flow implemented by ``FunctionKnowledge.get_knowledge``:

1. Input: a single free-text requirement string (``input_info``). A separate
   query-generation step (``generate_query``, prompt
   ``function_generate_query_prompt.md``) exists but is not called by the
   main flow at the moment.
2. Ask the LLM to pick a tool from the existing tool library (prompt
   ``function_knowledge_select_tools_prompt.md``, tool catalogue
   ``all_tools_infos.md``).
3. If a tool is selected: extract its call parameters with the LLM
   (``function_knowledge_extract_tool_params_prompt.md``), call the tool via
   ``tools_library.call_tool`` and let the LLM judge whether the result answers
   the requirement (``function_knowledge_tool_result_eval_prompt.md``).
4. Otherwise: fetch knowledge with ``multi_search_knowledge.get_knowledge`` and
   save it to ``knowledge.txt`` from a background thread.

An execution record is then collected with
``knowledge_v2.execution_collector.collect_and_save_execution_record``.

NOTE(review): the original design also described asynchronously deriving new
tools from retrieved knowledge (prompt
``function_knowledge_generate_new_tool_prompt.md``); no code visible in this
module implements that step yet.
'''
import os
import sys
import json
import threading
from loguru import logger
import re
# Put the parent of this file's directory at the front of sys.path so that the
# project packages imported below (``utils``, ``knowledge_v2``) resolve.
# ``current_dir`` is also used later to locate the ``prompt`` directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(current_dir)
sys.path.insert(0, root_dir)
from utils.qwen_client import QwenClient
from utils.gemini_client import generate_text
from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info, get_tool_params
from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
from knowledge_v2.cache_manager import CacheManager
  26. class FunctionKnowledge:
  27. """方法知识获取类"""
  28. def __init__(self, use_cache: bool = True):
  29. """
  30. 初始化
  31. Args:
  32. use_cache: 是否启用缓存,默认启用
  33. """
  34. logger.info("=" * 80)
  35. logger.info("初始化 FunctionKnowledge - 方法知识获取入口")
  36. self.prompt_dir = os.path.join(current_dir, "prompt")
  37. self.use_cache = use_cache
  38. self.cache = CacheManager() if use_cache else None
  39. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  40. logger.info("=" * 80)
  41. def _load_prompt(self, filename: str) -> str:
  42. """加载prompt文件内容"""
  43. prompt_path = os.path.join(self.prompt_dir, filename)
  44. if not os.path.exists(prompt_path):
  45. raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}")
  46. with open(prompt_path, 'r', encoding='utf-8') as f:
  47. return f.read().strip()
  48. def generate_query(self, question: str, post_info: str, persona_info: str) -> str:
  49. """
  50. 生成查询语句
  51. Returns:
  52. str: 生成的查询语句
  53. """
  54. logger.info(f"[步骤1] 生成Query...")
  55. # 组合问题的唯一标识
  56. combined_question = f"{question}||{post_info}||{persona_info}"
  57. try:
  58. prompt_template = self._load_prompt("function_generate_query_prompt.md")
  59. prompt = prompt_template.format(
  60. question=question,
  61. post_info=post_info,
  62. persona_info=persona_info
  63. )
  64. # 尝试从缓存读取
  65. if self.use_cache:
  66. cached_data = self.cache.get(combined_question, 'function_knowledge', 'generated_query.json')
  67. if cached_data:
  68. query = cached_data.get('query', cached_data.get('response', ''))
  69. logger.info(f"✓ 使用缓存的Query: {query}")
  70. return query
  71. logger.info("→ 调用Gemini生成Query...")
  72. query = generate_text(prompt=prompt)
  73. query = query.strip()
  74. logger.info(f"✓ 生成Query: {query}")
  75. # 保存到缓存(包含完整的prompt和response)
  76. if self.use_cache:
  77. query_data = {
  78. "prompt": prompt,
  79. "response": query,
  80. "query": query
  81. }
  82. self.cache.set(combined_question, 'function_knowledge', 'generated_query.json', query_data)
  83. return query
  84. except Exception as e:
  85. logger.error(f"✗ 生成Query失败: {e}")
  86. return question # 降级使用原问题
  87. def select_tool(self, combined_question: str, input_info: str) -> str:
  88. """
  89. 选择合适的工具
  90. Returns:
  91. str: 工具名称,如果没有合适的工具则返回"None"
  92. """
  93. logger.info(f"[步骤2] 选择工具...")
  94. try:
  95. all_tool_infos = self._load_prompt("all_tools_infos.md")
  96. if not all_tool_infos:
  97. logger.info(" 工具库为空,无可用工具")
  98. return "None"
  99. prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
  100. prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos).replace("input_info", input_info)
  101. # 尝试从缓存读取
  102. if self.use_cache:
  103. cached_data = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.json')
  104. if cached_data:
  105. result_json = cached_data.get('response', {})
  106. logger.info(f"✓ 使用缓存的工具: {result_json}")
  107. return result_json
  108. logger.info("→ 调用Gemini选择工具...")
  109. result = generate_text(prompt=prompt)
  110. result = self.extract_and_validate_json(result)
  111. if not result:
  112. logger.error("✗ 选择工具失败: 无法提取有效JSON")
  113. return "None"
  114. result_json = json.loads(result)
  115. logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}")
  116. # 保存到缓存(包含完整的prompt和response)
  117. if self.use_cache:
  118. tool_data = {
  119. "prompt": prompt,
  120. "response": result_json
  121. }
  122. self.cache.set(combined_question, 'function_knowledge', 'selected_tool.json', tool_data)
  123. return result_json
  124. except Exception as e:
  125. logger.error(f"✗ 选择工具失败: {e}")
  126. return "None"
  127. def extract_and_validate_json(self, text: str):
  128. """
  129. 从字符串中提取 JSON 部分,并返回标准的 JSON 字符串。
  130. 如果无法提取或解析失败,返回 None (或者你可以改为抛出异常)。
  131. """
  132. # 1. 使用正则表达式寻找最大的 JSON 块
  133. # r"(\{[\s\S]*\}|\[[\s\S]*\])" 的含义:
  134. # - \{[\s\S]*\} : 匹配以 { 开头,} 结尾的最长字符串([\s\S] 包含换行符)
  135. # - | : 或者
  136. # - \[[\s\S]*\] : 匹配以 [ 开头,] 结尾的最长字符串(处理 JSON 数组)
  137. match = re.search(r"(\{[\s\S]*\}|\[[\s\S]*\])", text)
  138. if match:
  139. json_str = match.group(0)
  140. try:
  141. # 2. 尝试解析提取出的字符串,验证是否为合法 JSON
  142. parsed_json = json.loads(json_str)
  143. # 3. 重新转储为标准字符串 (去除原本可能存在的缩进、多余空格等)
  144. # ensure_ascii=False 保证中文不会变成 \uXXXX
  145. return json.dumps(parsed_json, ensure_ascii=False)
  146. except json.JSONDecodeError as e:
  147. print(f"提取到了类似JSON的片段,但解析失败: {e}")
  148. return None
  149. else:
  150. print("未在文本中发现 JSON 结构")
  151. return None
  152. def extract_tool_params(self, combined_question: str, input_info: str, tool_id: str, tool_instructions: str) -> dict:
  153. """
  154. 根据工具信息和查询提取调用参数
  155. Args:
  156. combined_question: 组合问题(用于缓存)
  157. tool_name: 工具名称
  158. query: 查询内容
  159. Returns:
  160. dict: 提取的参数字典
  161. """
  162. logger.info(f"[步骤3] 提取工具参数...")
  163. try:
  164. # 获取工具信息
  165. tool_params = get_tool_params(tool_id)
  166. if not tool_params:
  167. logger.warning(f" ⚠ 未找到工具 {tool_id} 的信息,使用默认参数")
  168. return {"keyword": input_info}
  169. # 加载prompt
  170. prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
  171. prompt = prompt_template.format(
  172. tool_mcp_name=tool_id,
  173. input_info=input_info,
  174. all_tool_params=tool_params
  175. )
  176. # 尝试从缓存读取
  177. if self.use_cache:
  178. cached_data = self.cache.get(combined_question, 'function_knowledge', 'extracted_params.json')
  179. if cached_data:
  180. params = cached_data.get('params', {})
  181. logger.info(f"✓ 使用缓存的参数: {params}")
  182. return params
  183. # 调用LLM提取参数
  184. logger.info(" → 调用Gemini提取参数...")
  185. response_text = generate_text(prompt=prompt)
  186. # 解析JSON
  187. logger.info(" → 解析参数JSON...")
  188. try:
  189. # 清理可能的markdown标记
  190. response_text = response_text.strip()
  191. if response_text.startswith("```json"):
  192. response_text = response_text[7:]
  193. if response_text.startswith("```"):
  194. response_text = response_text[3:]
  195. if response_text.endswith("```"):
  196. response_text = response_text[:-3]
  197. response_text = response_text.strip()
  198. params = json.loads(response_text)
  199. logger.info(f"✓ 提取参数成功: {params}")
  200. # 保存到缓存(包含完整的prompt和response)
  201. if self.use_cache:
  202. params_data = {
  203. "prompt": prompt,
  204. "response": response_text,
  205. "params": params
  206. }
  207. self.cache.set(combined_question, 'function_knowledge', 'extracted_params.json', params_data)
  208. return params
  209. except json.JSONDecodeError as e:
  210. logger.error(f" ✗ 解析JSON失败: {e}")
  211. logger.error(f" 响应内容: {response_text}")
  212. # 降级:使用input_info作为keyword
  213. default_params = {"keyword": input_info}
  214. logger.warning(f" 使用默认参数: {default_params}")
  215. return default_params
  216. except Exception as e:
  217. logger.error(f"✗ 提取工具参数失败: {e}")
  218. # 降级:使用input_info作为keyword
  219. return {"keyword": input_info}
  220. def save_knowledge_to_file(self, knowledge: str, combined_question: str):
  221. """保存获取到的知识到文件"""
  222. try:
  223. logger.info("[保存知识] 开始保存知识到文件...")
  224. # 获取问题hash
  225. import hashlib
  226. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  227. # 获取缓存目录(和execution_record.json同级)
  228. if self.use_cache and self.cache:
  229. cache_dir = os.path.join(self.cache.base_cache_dir, question_hash)
  230. else:
  231. cache_dir = os.path.join(os.path.dirname(__file__), '.cache', question_hash)
  232. os.makedirs(cache_dir, exist_ok=True)
  233. # 保存到knowledge.txt
  234. knowledge_file = os.path.join(cache_dir, 'knowledge.txt')
  235. with open(knowledge_file, 'w', encoding='utf-8') as f:
  236. f.write(knowledge)
  237. logger.info(f"✓ 知识已保存到: {knowledge_file}")
  238. logger.info(f" 知识长度: {len(knowledge)} 字符")
  239. except Exception as e:
  240. logger.error(f"✗ 保存知识失败: {e}")
  241. def organize_tool_result(self, tool_result: dict) -> dict:
  242. """
  243. 组织工具调用结果,确保包含必要字段
  244. Args:
  245. tool_result: 原始工具调用结果
  246. Returns:
  247. dict: 组织后的工具调用结果
  248. """
  249. prompt_template = self._load_prompt("tool_result_prettify_prompt.md")
  250. prompt = prompt_template.format(
  251. input=tool_result,
  252. )
  253. # qwen_client = QwenClient()
  254. # organized_result = qwen_client.chat(user_prompt=prompt)
  255. # organized_result = generate_text(prompt=prompt)
  256. # organized_result = organized_result.strip()
  257. # return organized_result
  258. try:
  259. result = tool_result.get('result')
  260. if not result:
  261. return tool_result
  262. else:
  263. return result
  264. except Exception as e:
  265. logger.error(f"✗ 组织工具调用结果失败: {e}")
  266. return tool_result
  267. def evaluate_tool_result(self, combined_question: str, input_info: str, tool_result) -> dict:
  268. """
  269. 评估工具执行结果是否可以回答输入的需求
  270. Args:
  271. combined_question: 组合问题(用于缓存)
  272. input_info: 输入的需求信息
  273. tool_result: 工具执行结果(可以是dict、list、str等任意类型)
  274. Returns:
  275. dict: 评估结果,包含"是否可以回答"和"理由"
  276. """
  277. logger.info(f"[步骤5] 评估工具执行结果...")
  278. try:
  279. # 加载prompt
  280. prompt_template = self._load_prompt("function_knowledge_tool_result_eval_prompt.md")
  281. # 将tool_result转换为字符串格式,便于在prompt中使用
  282. if isinstance(tool_result, (dict, list)):
  283. tool_result_str = json.dumps(tool_result, ensure_ascii=False, indent=2)
  284. else:
  285. tool_result_str = str(tool_result)
  286. prompt = prompt_template.replace('{tool_call_result}', tool_result_str).replace('{input_info}', input_info)
  287. # 尝试从缓存读取
  288. if self.use_cache:
  289. cached_data = self.cache.get(combined_question, 'function_knowledge', 'tool_result_eval.json')
  290. if cached_data:
  291. eval_result = cached_data.get('eval_result', {})
  292. logger.info(f"✓ 使用缓存的评估结果: {eval_result}")
  293. return eval_result
  294. # 调用LLM进行评估
  295. logger.info(" → 调用Gemini评估工具执行结果...")
  296. response_text = generate_text(prompt=prompt)
  297. # 解析JSON
  298. logger.info(" → 解析评估结果JSON...")
  299. try:
  300. # 清理可能的markdown标记
  301. response_text = response_text.strip()
  302. if response_text.startswith("```json"):
  303. response_text = response_text[7:]
  304. if response_text.startswith("```"):
  305. response_text = response_text[3:]
  306. if response_text.endswith("```"):
  307. response_text = response_text[:-3]
  308. response_text = response_text.strip()
  309. # 使用extract_and_validate_json提取JSON
  310. json_str = self.extract_and_validate_json(response_text)
  311. if json_str:
  312. eval_result = json.loads(json_str)
  313. else:
  314. # 如果提取失败,尝试直接解析
  315. eval_result = json.loads(response_text)
  316. logger.info(f"✓ 评估完成: {eval_result.get('是否可以回答', '未知')}")
  317. # 保存到缓存(包含完整的prompt和response)
  318. if self.use_cache:
  319. eval_data = {
  320. "prompt": prompt,
  321. "response": response_text,
  322. "eval_result": eval_result
  323. }
  324. self.cache.set(combined_question, 'function_knowledge', 'tool_result_eval.json', eval_data)
  325. return eval_result
  326. except json.JSONDecodeError as e:
  327. logger.error(f" ✗ 解析JSON失败: {e}")
  328. logger.error(f" 响应内容: {response_text}")
  329. # 降级:返回默认评估结果
  330. default_eval = {
  331. "是否可以回答": "未知",
  332. "理由": f"评估失败,无法解析LLM响应: {str(e)}"
  333. }
  334. logger.warning(f" 使用默认评估结果: {default_eval}")
  335. return default_eval
  336. except Exception as e:
  337. logger.error(f"✗ 评估工具执行结果失败: {e}")
  338. # 降级:返回默认评估结果
  339. return {
  340. "是否可以回答": "未知",
  341. "理由": f"评估过程出错: {str(e)}"
  342. }
  343. def get_knowledge(self, input_info: str) -> dict:
  344. """
  345. 获取方法知识的主流程(重构后)
  346. Returns:
  347. dict: 完整的执行记录
  348. """
  349. import time
  350. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  351. start_time = time.time()
  352. logger.info("=" * 80)
  353. logger.info(f"Function Knowledge - 开始处理")
  354. logger.info(f"输入: {input_info}")
  355. logger.info("=" * 80)
  356. # 组合问题的唯一标识
  357. combined_question = input_info
  358. try:
  359. # 步骤1: 生成Query
  360. # query = self.generate_query(question, post_info, persona_info)
  361. # 步骤2: 选择工具
  362. tool_info = self.select_tool(combined_question, input_info)
  363. # tool_name = tool_info.get("工具名")
  364. tool_id = tool_info.get("工具调用ID")
  365. # tool_instructions = tool_info.get("使用方法")
  366. if tool_id and len(tool_id) > 0:
  367. # 路径A: 使用工具
  368. # 步骤3: 提取参数
  369. arguments = self.extract_tool_params(combined_question, input_info, tool_id, None)
  370. # 步骤4: 调用工具
  371. logger.info(f"[步骤4] 调用工具: {tool_id}")
  372. # 检查工具调用缓存
  373. if self.use_cache:
  374. cached_tool_call = self.cache.get(combined_question, 'function_knowledge', 'tool_call.json')
  375. if cached_tool_call:
  376. logger.info(f"✓ 使用缓存的工具调用结果")
  377. response = cached_tool_call.get('response', {})
  378. tool_result = self.organize_tool_result(response)
  379. # 保存工具调用信息(包含工具名、入参、结果)
  380. tool_call_data = {
  381. "tool_name": tool_id,
  382. "arguments": arguments,
  383. "result": tool_result,
  384. "response": response
  385. }
  386. self.cache.set(combined_question, 'function_knowledge', 'tool_call.json', tool_call_data)
  387. else:
  388. logger.info(f" → 调用工具,参数: {arguments}")
  389. rs = call_tool(tool_id, arguments)
  390. tool_result = self.organize_tool_result(rs)
  391. # 保存工具调用信息(包含工具名、入参、结果)
  392. tool_call_data = {
  393. "tool_name": tool_id,
  394. "arguments": arguments,
  395. "result": tool_result,
  396. "response": rs
  397. }
  398. self.cache.set(combined_question, 'function_knowledge', 'tool_call.json', tool_call_data)
  399. else:
  400. logger.info(f" → 调用工具,参数: {arguments}")
  401. rs = call_tool(tool_id, arguments)
  402. tool_result = self.organize_tool_result(rs)
  403. logger.info(f"✓ 工具调用完成")
  404. # 步骤5: 评估工具执行结果
  405. eval_result = self.evaluate_tool_result(combined_question, input_info, tool_result)
  406. logger.info(f" 评估结果: {eval_result.get('是否可以回答', '未知')}")
  407. if eval_result.get('理由'):
  408. logger.info(f" 评估理由: {eval_result.get('理由')}")
  409. else:
  410. # 路径B: 知识搜索
  411. logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...")
  412. knowledge = get_multi_search_knowledge(input_info, cache_key=combined_question)
  413. # 异步保存知识到文件
  414. logger.info("[后台任务] 保存知识到文件...")
  415. threading.Thread(target=self.save_knowledge_to_file, args=(knowledge, combined_question)).start()
  416. # 计算执行时间
  417. execution_time = time.time() - start_time
  418. # 收集所有执行记录
  419. logger.info("=" * 80)
  420. logger.info("收集执行记录...")
  421. logger.info("=" * 80)
  422. from knowledge_v2.execution_collector import collect_and_save_execution_record
  423. execution_record = collect_and_save_execution_record(
  424. combined_question,
  425. input_info
  426. )
  427. logger.info("=" * 80)
  428. logger.info(f"✓ Function Knowledge 完成")
  429. logger.info(f" 执行时间: {execution_record.get('metadata', {}).get('execution_time', 0):.2f}秒")
  430. logger.info("=" * 80 + "\n")
  431. return execution_record
  432. except Exception as e:
  433. logger.error(f"✗ 执行失败: {e}")
  434. import traceback
  435. logger.error(traceback.format_exc())
  436. # 即使失败也尝试收集记录
  437. try:
  438. execution_time = time.time() - start_time
  439. from knowledge_v2.execution_collector import collect_and_save_execution_record
  440. execution_record = collect_and_save_execution_record(
  441. combined_question,
  442. input_info
  443. )
  444. return execution_record
  445. except Exception as collect_error:
  446. logger.error(f"收集执行记录也失败: {collect_error}")
  447. # 返回基本错误信息
  448. return {
  449. "input": f"{input_info}",
  450. "result": {
  451. "type": "error",
  452. "content": f"执行失败: {str(e)}"
  453. },
  454. "metadata": {
  455. "errors": [str(e)]
  456. }
  457. }
  458. if __name__ == "__main__":
  459. # 测试代码
  460. input_info = """1.已知信息账号人设:
  461. -账号的品类:宠物表情包账号
  462. -人设里能和该贴匹配的点:
  463. 鼓励式猫咪表情包-猫咪考试祝福
  464. 推广饮品品牌-推广餐饮品牌
  465. 互动粉丝-互动特点人群
  466. 拟人化猫咪形象-拟人化猫咪形象
  467. 表情包式图文-表情包式视觉风格
  468. 情景化植入-强关联场景植入
  469. -账号聚出来的pattern模式:
  470. 模式1: 拟人化穿搭+趣味分享意图, 萌宠主题内容+拟人化主体,视觉构图版式....
  471. 模式2:校园学生人设+商业推广意图,商业产品推厂+场景化产品植入
  472. 模式3:日常生活演绎+萌宠主题内容+图文叙事结构
  473. 模式4:视觉隐喻+趣味分享意图+视觉构图版式
  474. 2.待寻找点:
  475. -社交媒体解构贴中未与账号人设匹配的信息
  476. 考试祝福
  477. 3.帖子创作日期:2025-11-07
  478. """
  479. try:
  480. agent = FunctionKnowledge()
  481. execution_result = agent.get_knowledge(input_info=input_info)
  482. print("=" * 50)
  483. print("执行结果:")
  484. print("=" * 50)
  485. print(json.dumps(execution_result, ensure_ascii=False, indent=2))
  486. print(f"\n完整JSON已保存到缓存目录")
  487. except Exception as e:
  488. logger.error(f"测试失败: {e}")