function_knowledge.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
'''
方法知识获取模块
1. 输入:问题 + 帖子信息 + 账号人设信息
2. 将输入的问题转化成query,调用大模型,prompt在 function_generate_query_prompt.md 中
3. 从已有方法工具库中尝试选择合适的方法工具(调用大模型执行,prompt在 function_knowledge_select_tools_prompt.md 中),如果有,则返回选择的方法工具,否则:
- 调用 multi_search_knowledge.py 获取知识
- 返回新的方法工具知识
- 异步从新方法知识中获取新工具(调用大模型执行,prompt在 function_knowledge_generate_new_tool_prompt.md 中),调用工具库系统,接入新的工具
4. 若选中了工具:先调用大模型从工具信息中提取调用参数(prompt在 function_knowledge_extract_tool_params_prompt.md 中),再调用该工具执行验证,返回工具执行结果
'''
import hashlib
import json
import os
import re
import sys
import threading
import time
import traceback

from loguru import logger
  16. # 设置路径以便导入工具类
  17. current_dir = os.path.dirname(os.path.abspath(__file__))
  18. root_dir = os.path.dirname(current_dir)
  19. sys.path.insert(0, root_dir)
  20. from utils.gemini_client import generate_text
  21. from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info
  22. from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
  23. from knowledge_v2.cache_manager import CacheManager
  24. class FunctionKnowledge:
  25. """方法知识获取类"""
  26. def __init__(self, use_cache: bool = True):
  27. """
  28. 初始化
  29. Args:
  30. use_cache: 是否启用缓存,默认启用
  31. """
  32. logger.info("=" * 80)
  33. logger.info("初始化 FunctionKnowledge - 方法知识获取入口")
  34. self.prompt_dir = os.path.join(current_dir, "prompt")
  35. self.use_cache = use_cache
  36. self.cache = CacheManager() if use_cache else None
  37. logger.info(f"缓存状态: {'启用' if use_cache else '禁用'}")
  38. logger.info("=" * 80)
  39. def _load_prompt(self, filename: str) -> str:
  40. """加载prompt文件内容"""
  41. prompt_path = os.path.join(self.prompt_dir, filename)
  42. if not os.path.exists(prompt_path):
  43. raise FileNotFoundError(f"Prompt文件不存在: {prompt_path}")
  44. with open(prompt_path, 'r', encoding='utf-8') as f:
  45. return f.read().strip()
  46. def generate_query(self, question: str, post_info: str, persona_info: str) -> tuple:
  47. """
  48. 生成查询语句
  49. Returns:
  50. tuple: (query, detail_info)
  51. - query: 生成的查询语句
  52. - detail_info: 详细信息dict,包含prompt和response
  53. """
  54. logger.info(f"[步骤1] 生成Query...")
  55. # 组合问题的唯一标识
  56. combined_question = f"{question}||{post_info}||{persona_info}"
  57. detail_info = {"cached": False, "prompt": None, "response": None}
  58. # 尝试从缓存读取
  59. if self.use_cache:
  60. cached_query = self.cache.get(combined_question, 'function_knowledge', 'generated_query.txt')
  61. if cached_query:
  62. logger.info(f"✓ 使用缓存的Query: {cached_query}")
  63. detail_info["cached"] = True
  64. return cached_query, detail_info
  65. try:
  66. prompt_template = self._load_prompt("function_generate_query_prompt.md")
  67. prompt = prompt_template.format(
  68. question=question,
  69. post_info=post_info,
  70. persona_info=persona_info
  71. )
  72. detail_info["prompt"] = prompt
  73. logger.info("→ 调用Gemini生成Query...")
  74. query = generate_text(prompt=prompt)
  75. query = query.strip()
  76. detail_info["response"] = query
  77. logger.info(f"✓ 生成Query: {query}")
  78. # 写入缓存
  79. if self.use_cache:
  80. self.cache.set(combined_question, 'function_knowledge', 'generated_query.txt', query)
  81. return query, detail_info
  82. except Exception as e:
  83. logger.error(f"✗ 生成Query失败: {e}")
  84. detail_info["error"] = str(e)
  85. return question, detail_info # 降级使用原问题
  86. def select_tool(self, combined_question: str, query: str) -> tuple:
  87. """
  88. 选择合适的工具
  89. Returns:
  90. tuple: (tool_name, detail_info)
  91. """
  92. logger.info(f"[步骤2] 选择工具...")
  93. detail_info = {"cached": False, "prompt": None, "response": None, "available_tools_count": 0}
  94. # 尝试从缓存读取
  95. if self.use_cache:
  96. cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
  97. if cached_tool:
  98. logger.info(f"✓ 使用缓存的工具: {cached_tool}")
  99. detail_info["cached"] = True
  100. return cached_tool, detail_info
  101. try:
  102. all_tool_infos = get_all_tool_infos()
  103. if not all_tool_infos:
  104. logger.info(" 工具库为空,无可用工具")
  105. return "None", detail_info
  106. tool_count = len(all_tool_infos.split('--- Tool:')) - 1
  107. detail_info["available_tools_count"] = tool_count
  108. logger.info(f" 当前可用工具数: {tool_count}")
  109. prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
  110. prompt = prompt_template.format(
  111. query=query,
  112. tool_infos=all_tool_infos
  113. )
  114. detail_info["prompt"] = prompt
  115. detail_info["tool_infos"] = all_tool_infos
  116. logger.info("→ 调用Gemini选择工具...")
  117. tool_name = generate_text(prompt=prompt)
  118. tool_name = tool_name.strip()
  119. detail_info["response"] = tool_name
  120. logger.info(f"✓ 选择结果: {tool_name}")
  121. # 写入缓存
  122. if self.use_cache:
  123. self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', tool_name)
  124. return tool_name, detail_info
  125. except Exception as e:
  126. logger.error(f"✗ 选择工具失败: {e}")
  127. detail_info["error"] = str(e)
  128. return "None", detail_info
  129. def extract_tool_params(self, combined_question: str, tool_name: str, query: str) -> tuple:
  130. """
  131. 根据工具信息和查询提取调用参数
  132. Args:
  133. combined_question: 组合问题(用于缓存)
  134. tool_name: 工具名称
  135. query: 查询内容
  136. Returns:
  137. tuple: (params, detail_info)
  138. """
  139. logger.info(f"[步骤3] 提取工具参数...")
  140. # 初始化detail_info
  141. detail_info = {"cached": False, "prompt": None, "response": None, "tool_info": None}
  142. # 尝试从缓存读取
  143. if self.use_cache:
  144. cached_params = self.cache.get(combined_question, 'function_knowledge', 'tool_params.json')
  145. if cached_params:
  146. logger.info(f"✓ 使用缓存的参数: {cached_params}")
  147. detail_info["cached"] = True
  148. return cached_params, detail_info
  149. try:
  150. # 获取工具信息
  151. tool_info = get_tool_info(tool_name)
  152. if not tool_info:
  153. logger.warning(f" ⚠ 未找到工具 {tool_name} 的信息,使用默认参数")
  154. # 降级:使用query作为keyword
  155. default_params = {"keyword": query}
  156. detail_info["fallback"] = "tool_info_not_found"
  157. return default_params, detail_info
  158. detail_info["tool_info"] = tool_info
  159. logger.info(f" 工具 {tool_name} 信息长度: {len(tool_info)}")
  160. # 加载prompt
  161. prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
  162. prompt = prompt_template.format(
  163. query=query,
  164. tool_info=tool_info
  165. )
  166. detail_info["prompt"] = prompt
  167. # 调用LLM提取参数
  168. logger.info(" → 调用Gemini提取参数...")
  169. response_text = generate_text(prompt=prompt)
  170. detail_info["response"] = response_text
  171. # 解析JSON
  172. logger.info(" → 解析参数JSON...")
  173. try:
  174. # 清理可能的markdown标记
  175. response_text = response_text.strip()
  176. if response_text.startswith("```json"):
  177. response_text = response_text[7:]
  178. if response_text.startswith("```"):
  179. response_text = response_text[3:]
  180. if response_text.endswith("```"):
  181. response_text = response_text[:-3]
  182. response_text = response_text.strip()
  183. params = json.loads(response_text)
  184. logger.info(f"✓ 提取参数成功: {params}")
  185. # 写入缓存
  186. if self.use_cache:
  187. self.cache.set(combined_question, 'function_knowledge', 'tool_params.json', params)
  188. return params, detail_info
  189. except json.JSONDecodeError as e:
  190. logger.error(f" ✗ 解析JSON失败: {e}")
  191. logger.error(f" 响应内容: {response_text}")
  192. # 降级:使用query作为keyword
  193. default_params = {"keyword": query}
  194. logger.warning(f" 使用默认参数: {default_params}")
  195. detail_info["fallback"] = "json_decode_error"
  196. return default_params, detail_info
  197. except Exception as e:
  198. logger.error(f"✗ 提取工具参数失败: {e}")
  199. # 降级:使用query作为keyword
  200. default_params = {"keyword": query}
  201. detail_info["error"] = str(e)
  202. detail_info["fallback"] = "exception"
  203. return default_params, detail_info
  204. def generate_and_save_new_tool(self, knowledge: str):
  205. """异步生成并保存新工具"""
  206. try:
  207. logger.info("开始生成新工具...")
  208. prompt_template = self._load_prompt("function_knowledge_generate_new_tool_prompt.md")
  209. prompt = prompt_template.format(knowledge=knowledge)
  210. tool_code = generate_text(prompt=prompt)
  211. # 简单解析工具名(假设工具名在 def xxx( 中)
  212. # 这里做一个简单的提取,实际可能需要更复杂的解析
  213. import re
  214. match = re.search(r"def\s+([a-zA-Z_][a-zA-Z0-9_]*)", tool_code)
  215. if match:
  216. tool_name = match.group(1)
  217. save_path = save_tool_info(tool_name, tool_code)
  218. logger.info(f"新工具已保存: {save_path}")
  219. else:
  220. logger.warning("无法从生成的代码中提取工具名")
  221. except Exception as e:
  222. logger.error(f"生成新工具失败: {e}")
  223. def get_knowledge(self, question: str, post_info: str, persona_info: str) -> dict:
  224. """
  225. 获取方法知识的主流程
  226. Returns:
  227. dict: 包含完整执行信息的字典
  228. {
  229. "input": {...}, # 原始输入
  230. "execution": {...}, # 执行过程信息
  231. "result": {...}, # 最终结果
  232. "metadata": {...} # 元数据
  233. }
  234. """
  235. import time
  236. start_time = time.time()
  237. timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
  238. logger.info("=" * 80)
  239. logger.info(f"Function Knowledge - 开始处理")
  240. logger.info(f"问题: {question}")
  241. logger.info(f"帖子信息: {post_info}")
  242. logger.info(f"人设信息: {persona_info}")
  243. logger.info("=" * 80)
  244. # 组合问题的唯一标识
  245. combined_question = f"{question}||{post_info}||{persona_info}"
  246. # 初始化执行记录
  247. execution_record = {
  248. "input": {
  249. "question": question,
  250. "post_info": post_info,
  251. "persona_info": persona_info,
  252. "timestamp": timestamp
  253. },
  254. "execution": {
  255. "steps": [],
  256. "tool_info": None,
  257. "knowledge_search_info": None
  258. },
  259. "result": {
  260. "type": None, # "tool" 或 "knowledge"
  261. "content": None,
  262. "raw_data": None
  263. },
  264. "metadata": {
  265. "execution_time": None,
  266. "cache_hits": [],
  267. "errors": []
  268. }
  269. }
  270. # 检查最终结果缓存
  271. if self.use_cache:
  272. cached_final = self.cache.get(combined_question, 'function_knowledge', 'final_result.json')
  273. if cached_final:
  274. logger.info(f"✓ 使用缓存的最终结果")
  275. logger.info("=" * 80 + "\n")
  276. # 如果是完整的执行记录,直接返回
  277. if isinstance(cached_final, dict) and "execution" in cached_final:
  278. return cached_final
  279. # 否则构造一个简单的返回
  280. return {
  281. "input": execution_record["input"],
  282. "execution": {"cached": True},
  283. "result": {"type": "cached", "content": cached_final},
  284. "metadata": {"cache_hit": True}
  285. }
  286. try:
  287. # 步骤1: 生成Query
  288. step1_start = time.time()
  289. query, query_detail = self.generate_query(question, post_info, persona_info)
  290. execution_record["execution"]["steps"].append({
  291. "step": 1,
  292. "name": "generate_query",
  293. "duration": time.time() - step1_start,
  294. "output": query,
  295. "detail": query_detail # 包含prompt和response
  296. })
  297. # 步骤2: 选择工具
  298. step2_start = time.time()
  299. tool_name, tool_select_detail = self.select_tool(combined_question, query)
  300. execution_record["execution"]["steps"].append({
  301. "step": 2,
  302. "name": "select_tool",
  303. "duration": time.time() - step2_start,
  304. "output": tool_name,
  305. "detail": tool_select_detail # 包含prompt、response和可用工具列表
  306. })
  307. result_content = None
  308. if tool_name and tool_name != "None":
  309. # 路径A: 使用工具
  310. execution_record["result"]["type"] = "tool"
  311. # 步骤3: 提取参数
  312. step3_start = time.time()
  313. arguments, params_detail = self.extract_tool_params(combined_question, tool_name, query)
  314. execution_record["execution"]["steps"].append({
  315. "step": 3,
  316. "name": "extract_tool_params",
  317. "duration": time.time() - step3_start,
  318. "output": arguments,
  319. "detail": params_detail # 包含prompt、response和工具信息
  320. })
  321. # 步骤4: 调用工具
  322. logger.info(f"[步骤4] 调用工具: {tool_name}")
  323. # 检查工具调用缓存
  324. if self.use_cache:
  325. cached_tool_result = self.cache.get(combined_question, 'function_knowledge', 'tool_result.json')
  326. if cached_tool_result:
  327. logger.info(f"✓ 使用缓存的工具调用结果")
  328. execution_record["metadata"]["cache_hits"].append("tool_result")
  329. tool_result = cached_tool_result
  330. else:
  331. step4_start = time.time()
  332. logger.info(f" → 调用工具,参数: {arguments}")
  333. tool_result = call_tool(tool_name, arguments)
  334. # 缓存工具调用结果
  335. self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
  336. execution_record["execution"]["steps"].append({
  337. "step": 4,
  338. "name": "call_tool",
  339. "duration": time.time() - step4_start,
  340. "output": "success"
  341. })
  342. else:
  343. step4_start = time.time()
  344. logger.info(f" → 调用工具,参数: {arguments}")
  345. tool_result = call_tool(tool_name, arguments)
  346. execution_record["execution"]["steps"].append({
  347. "step": 4,
  348. "name": "call_tool",
  349. "duration": time.time() - step4_start,
  350. "output": "success"
  351. })
  352. # 记录工具调用信息
  353. execution_record["execution"]["tool_info"] = {
  354. "tool_name": tool_name,
  355. "parameters": arguments,
  356. "result": tool_result
  357. }
  358. result_content = f"工具 {tool_name} 执行结果: {json.dumps(tool_result, ensure_ascii=False)}"
  359. execution_record["result"]["content"] = result_content
  360. execution_record["result"]["raw_data"] = tool_result
  361. logger.info(f"✓ 工具调用完成")
  362. else:
  363. # 路径B: 知识搜索
  364. execution_record["result"]["type"] = "knowledge_search"
  365. logger.info("[步骤4] 未找到合适工具,调用 MultiSearch...")
  366. step4_start = time.time()
  367. knowledge = get_multi_search_knowledge(query)
  368. execution_record["execution"]["steps"].append({
  369. "step": 4,
  370. "name": "multi_search_knowledge",
  371. "duration": time.time() - step4_start,
  372. "output": f"knowledge_length: {len(knowledge)}"
  373. })
  374. # 记录知识搜索信息
  375. execution_record["execution"]["knowledge_search_info"] = {
  376. "query": query,
  377. "knowledge_length": len(knowledge),
  378. "source": "multi_search"
  379. }
  380. result_content = knowledge
  381. execution_record["result"]["content"] = knowledge
  382. execution_record["result"]["raw_data"] = {"knowledge": knowledge, "query": query}
  383. # 异步生成新工具
  384. logger.info("[后台任务] 启动新工具生成线程...")
  385. threading.Thread(target=self.generate_and_save_new_tool, args=(knowledge,)).start()
  386. # 计算总执行时间
  387. execution_record["metadata"]["execution_time"] = time.time() - start_time
  388. # 保存完整的执行记录到JSON文件
  389. if self.use_cache:
  390. self.cache.set(combined_question, 'function_knowledge', 'final_result.json', execution_record)
  391. # 同时保存一个格式化的JSON文件供人类阅读
  392. from knowledge_v2.cache_manager import CacheManager
  393. cache = CacheManager()
  394. import hashlib
  395. question_hash = hashlib.md5(combined_question.encode('utf-8')).hexdigest()[:12]
  396. output_file = os.path.join(cache.base_cache_dir, question_hash, 'execution_record.json')
  397. try:
  398. with open(output_file, 'w', encoding='utf-8') as f:
  399. json.dump(execution_record, f, ensure_ascii=False, indent=2)
  400. logger.info(f"✓ 完整执行记录已保存: {output_file}")
  401. except Exception as e:
  402. logger.error(f"保存执行记录失败: {e}")
  403. logger.info("=" * 80)
  404. logger.info(f"✓ Function Knowledge 完成")
  405. logger.info(f" 类型: {execution_record['result']['type']}")
  406. logger.info(f" 结果长度: {len(result_content) if result_content else 0}")
  407. logger.info(f" 执行时间: {execution_record['metadata']['execution_time']:.2f}秒")
  408. logger.info("=" * 80 + "\n")
  409. return execution_record
  410. except Exception as e:
  411. logger.error(f"✗ 执行失败: {e}")
  412. import traceback
  413. error_trace = traceback.format_exc()
  414. execution_record["metadata"]["errors"].append({
  415. "error": str(e),
  416. "traceback": error_trace
  417. })
  418. execution_record["result"]["type"] = "error"
  419. execution_record["result"]["content"] = f"执行失败: {str(e)}"
  420. execution_record["metadata"]["execution_time"] = time.time() - start_time
  421. return execution_record
  422. def get_knowledge(question: str, post_info: str, persona_info: str) -> dict:
  423. """
  424. 便捷调用函数
  425. Returns:
  426. dict: 完整的执行记录,包含输入、执行过程、结果和元数据
  427. """
  428. agent = FunctionKnowledge()
  429. return agent.get_knowledge(question, post_info, persona_info)
  430. if __name__ == "__main__":
  431. # 测试代码
  432. question = "小老虎 穿搭"
  433. post_info = "无"
  434. persona_info = "游戏博主"
  435. try:
  436. agent = FunctionKnowledge()
  437. execution_result = agent.get_knowledge(question, post_info, persona_info)
  438. print("=" * 50)
  439. print("执行结果:")
  440. print("=" * 50)
  441. print(f"类型: {execution_result['result']['type']}")
  442. print(f"内容预览: {execution_result['result']['content'][:200]}...")
  443. print(f"执行时间: {execution_result['metadata']['execution_time']:.2f}秒")
  444. print(f"\n完整JSON已保存到缓存目录")
  445. except Exception as e:
  446. logger.error(f"测试失败: {e}")