TanJingyu 5 ساعت پیش
والد
کامیت
4771565f11
4 فایلهای تغییر یافته به همراه 134 افزوده شده و 259 حذف شده
  1. 50 73
      knowledge_v2/function_knowledge.py
  2. 25 92
      knowledge_v2/llm_search_knowledge.py
  3. 2 25
      knowledge_v2/multi_search_knowledge.py
  4. 57 69
      knowledge_v2/tools_library.py

+ 50 - 73
knowledge_v2/function_knowledge.py

@@ -22,7 +22,7 @@ root_dir = os.path.dirname(current_dir)
 sys.path.insert(0, root_dir)
 
 from utils.gemini_client import generate_text
-from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info
+from knowledge_v2.tools_library import call_tool, save_tool_info, get_all_tool_infos, get_tool_info, get_tool_params
 from knowledge_v2.multi_search_knowledge import get_knowledge as get_multi_search_knowledge
 from knowledge_v2.cache_manager import CacheManager
 
@@ -55,7 +55,7 @@ class FunctionKnowledge:
         logger.info("=" * 80)
         
     def _save_execution_detail(self, cache_key: str):
-        """保存执行详情到缓存(支持合并旧记录)"""
+        """保存执行详情到缓存"""
         if not self.use_cache or not self.cache:
             return
         
@@ -70,37 +70,8 @@ class FunctionKnowledge:
             os.makedirs(detail_dir, exist_ok=True)
             
             detail_file = os.path.join(detail_dir, 'execution_detail.json')
-            
-            # 准备最终要保存的数据,默认为当前内存中的数据
-            final_detail = self.execution_detail.copy()
-            
-            # 尝试读取旧文件进行合并
-            if os.path.exists(detail_file):
-                try:
-                    with open(detail_file, 'r', encoding='utf-8') as f:
-                        old_detail = json.load(f)
-                    
-                    # 智能合并逻辑:保留更有价值的历史信息
-                    for key, new_val in self.execution_detail.items():
-                        # 跳过非字典字段或旧文件中不存在的字段
-                        if not isinstance(new_val, dict) or key not in old_detail:
-                            continue
-                            
-                        old_val = old_detail[key]
-                        if not isinstance(old_val, dict):
-                            continue
-                            
-                        # 核心逻辑:如果新记录是缓存命中(cached=True),而旧记录包含prompt(说明是当初生成的)
-                        # 则保留旧记录,防止被简略信息覆盖
-                        if new_val.get("cached", False) is True and "prompt" in old_val:
-                            # logger.debug(f"  保留 {key} 的历史详细记录")
-                            final_detail[key] = old_val
-                            
-                except Exception as e:
-                    logger.warning(f"  ⚠ 读取旧详情失败,将使用新记录: {e}")
-
             with open(detail_file, 'w', encoding='utf-8') as f:
-                json.dump(final_detail, f, ensure_ascii=False, indent=2)
+                json.dump(self.execution_detail, f, ensure_ascii=False, indent=2)
             
             logger.info(f"✓ 执行详情已保存: {detail_file}")
             
@@ -138,7 +109,11 @@ class FunctionKnowledge:
         
         try:
             prompt_template = self._load_prompt("function_generate_query_prompt.md")
-            prompt = prompt_template.replace("{question}", question)
+            prompt = prompt_template.format(
+                question=question,
+                post_info=post_info,
+                persona_info=persona_info
+            )
             
             logger.info("→ 调用Gemini生成Query...")
             query = generate_text(prompt=prompt)
@@ -172,20 +147,8 @@ class FunctionKnowledge:
         """
         logger.info(f"[步骤2] 选择工具...")
         
-        # 尝试从缓存读取
-        if self.use_cache:
-            cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
-            if cached_tool:
-                logger.info(f"✓ 使用缓存的工具: {cached_tool}")
-                # 记录缓存命中
-                self.execution_detail["select_tool"].update({
-                    "cached": True,
-                    "tool_name": cached_tool
-                })
-                return cached_tool
-        
         try:
-            all_tool_infos = self._load_prompt("all_tools_infos.md")
+            all_tool_infos = get_all_tool_infos()
             if not all_tool_infos:
                 logger.info("  工具库为空,无可用工具")
                 return "None"
@@ -194,36 +157,48 @@ class FunctionKnowledge:
             logger.info(f"  当前可用工具数: {tool_count}")
                 
             prompt_template = self._load_prompt("function_knowledge_select_tools_prompt.md")
-            prompt = prompt_template.replace("{all_tool_infos}", all_tool_infos)
+            prompt = prompt_template.format(
+                query=query,
+                tool_infos=all_tool_infos
+            )
+
+            # 尝试从缓存读取
+            if self.use_cache:
+                cached_tool = self.cache.get(combined_question, 'function_knowledge', 'selected_tool.txt')
+                if cached_tool:
+                    logger.info(f"✓ 使用缓存的工具: {cached_tool}")
+                    # 记录缓存命中
+                    self.execution_detail["select_tool"].update({
+                        "cached": True,
+                        "response": json.loads(cached_tool),
+                        "prompt": prompt,
+                    })
+                    return json.loads(cached_tool)
             
             logger.info("→ 调用Gemini选择工具...")
             result = generate_text(prompt=prompt)
-            result_json = json.loads(result)
-            tool_name = result_json.get('工具名', '')
-            tool_mcp_name = result_json.get('工具调用ID', '')
-            tool_instructions = result_json.get('使用方法', '')
-            
-            logger.info(f"✓ 选择结果: {tool_name}")
+            result_json = result.loads(result)
+
+            logger.info(f"✓ 选择结果: {result_json.get('工具名', 'None')}")
             
             # 写入缓存
             if self.use_cache:
-                self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', tool_name)
+                self.cache.set(combined_question, 'function_knowledge', 'selected_tool.txt', result)
             
             # 记录详情
             self.execution_detail["select_tool"] = {
                 "cached": False,
                 "prompt": prompt,
-                "response": tool_name,
-                "tool_name": tool_name,
+                "response": result_json,
                 "available_tools_count": tool_count
             }
             
-            return tool_name
+            return result_json
         except Exception as e:
             logger.error(f"✗ 选择工具失败: {e}")
             return "None"
 
-    def extract_tool_params(self, combined_question: str, tool_name: str, query: str) -> dict:
+    def extract_tool_params(self, combined_question: str, query: str, tool_id: str, tool_instructions: str) -> dict:
         """
         根据工具信息和查询提取调用参数
         
@@ -251,18 +226,16 @@ class FunctionKnowledge:
         
         try:
             # 获取工具信息
-            tool_info = get_tool_info(tool_name)
-            if not tool_info:
-                logger.warning(f"  ⚠ 未找到工具 {tool_name} 的信息,使用默认参数")
+            tool_params = get_tool_params(tool_id)
+            if not tool_params:
+                logger.warning(f"  ⚠ 未找到工具 {tool_id} 的信息,使用默认参数")
                 return {"keyword": query}
             
-            logger.info(f"  工具 {tool_name} 信息长度: {len(tool_info)}")
-            
             # 加载prompt
             prompt_template = self._load_prompt("function_knowledge_extract_tool_params_prompt.md")
             prompt = prompt_template.format(
                 query=query,
-                tool_info=tool_info
+                # tool_info=tool_info
             )
             
             # 调用LLM提取参数
@@ -367,15 +340,17 @@ class FunctionKnowledge:
             query = self.generate_query(question, post_info, persona_info)
             
             # 步骤2: 选择工具
-            tool_name = self.select_tool(combined_question, query)
-            
-            if tool_name and tool_name != "None":
+            tool_info = self.select_tool(combined_question, query)
+            # tool_name = tool_info.get("工具名")
+            tool_id = tool_info.get("工具调用ID")
+            tool_instructions = tool_info.get("使用方法")
+            if tool_id and tool_instructions:
                 # 路径A: 使用工具
                 # 步骤3: 提取参数
-                arguments = self.extract_tool_params(combined_question, tool_name, query)
-                
+                arguments = self.extract_tool_params(combined_question, query, tool_id, tool_instructions)
+
                 # 步骤4: 调用工具
-                logger.info(f"[步骤4] 调用工具: {tool_name}")
+                logger.info(f"[步骤4] 调用工具: {tool_id}")
                 
                 # 检查工具调用缓存
                 if self.use_cache:
@@ -385,13 +360,15 @@ class FunctionKnowledge:
                         tool_result = cached_tool_result
                     else:
                         logger.info(f"  → 调用工具,参数: {arguments}")
-                        tool_result = call_tool(tool_name, arguments)
+                        tool_result = call_tool(tool_id, arguments)
                         # 缓存工具调用结果
                         self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
                 else:
                     logger.info(f"  → 调用工具,参数: {arguments}")
-                    tool_result = call_tool(tool_name, arguments)
-                
+                    tool_result = call_tool(tool_id, arguments)
+                    # 缓存工具调用结果
+                    self.cache.set(combined_question, 'function_knowledge', 'tool_result.json', tool_result)
+
                 logger.info(f"✓ 工具调用完成")
                 
             else:

+ 25 - 92
knowledge_v2/llm_search_knowledge.py

@@ -111,10 +111,11 @@ class LLMSearchKnowledge:
             if cached_queries:
                 logger.info(f"✓ 使用缓存的queries: {cached_queries}")
                 # 记录缓存命中
-                self.execution_detail["generate_queries"].update({
+                self.execution_detail["generate_queries"] = {
                     "cached": True,
                     "queries_count": len(cached_queries)
-                })
+                }
+                self.execution_detail["cache_hits"].append("generated_queries")
                 return cached_queries
         
         try:
@@ -152,13 +153,13 @@ class LLMSearchKnowledge:
                     logger.info(f"  {i}. {q}")
                 
                 # 记录执行详情
-                self.execution_detail["generate_queries"].update({
+                self.execution_detail["generate_queries"] = {
                     "cached": False,
                     "prompt": prompt,
                     "response": response_text,
                     "queries_count": len(queries),
                     "queries": queries
-                })
+                }
                 
                 # 写入缓存
                 if self.use_cache:
@@ -285,21 +286,12 @@ class LLMSearchKnowledge:
             Exception: 合并失败时抛出异常
         """
         logger.info(f"[步骤3] 合并知识 - 共 {len(knowledge_texts)} 个文本")
-
-        if len(knowledge_texts) == 1:
-            return knowledge_texts[0]
-
+        
         # 尝试从缓存读取
         if self.use_cache:
             cached_merged = self.cache.get(question, 'llm_search', 'merged_knowledge.txt')
             if cached_merged:
                 logger.info(f"✓ 使用缓存的合并知识 (长度: {len(cached_merged)})")
-                # 记录缓存命中
-                self.execution_detail["merge_detail"].update({
-                    "cached": True,
-                    "knowledge_count": len(knowledge_texts),
-                    "result_length": len(cached_merged)
-                })
                 return cached_merged
         
         try:
@@ -335,15 +327,6 @@ class LLMSearchKnowledge:
             
             logger.info(f"✓ 成功合并知识文本 (长度: {len(merged_text)})")
             
-            # 记录合并详情
-            self.execution_detail["merge_detail"].update({
-                "cached": False,
-                "prompt": prompt,
-                "response": merged_text,
-                "knowledge_count": len(knowledge_texts),
-                "result_length": len(merged_text)
-            })
-            
             # 写入缓存
             if self.use_cache:
                 self.cache.set(question, 'llm_search', 'merged_knowledge.txt', merged_text.strip())
@@ -353,17 +336,17 @@ class LLMSearchKnowledge:
         except Exception as e:
             logger.error(f"✗ 合并知识文本失败: {e}")
             raise
-
+    
     def _save_execution_detail(self, cache_key: str):
         """
-        保存执行详情到缓存(支持合并旧记录)
-
+        保存执行详情到缓存
+        
         Args:
             cache_key: 缓存键
         """
         if not self.use_cache or not self.cache:
             return
-
+        
         try:
             import hashlib
             question_hash = hashlib.md5(cache_key.encode('utf-8')).hexdigest()[:12]
@@ -373,70 +356,17 @@ class LLMSearchKnowledge:
                 'llm_search'
             )
             os.makedirs(detail_dir, exist_ok=True)
-
+            
             detail_file = os.path.join(detail_dir, 'execution_detail.json')
-
-            # 准备最终要保存的数据
-            final_detail = self.execution_detail.copy()
-
-            # 尝试读取旧文件进行合并
-            if os.path.exists(detail_file):
-                try:
-                    with open(detail_file, 'r', encoding='utf-8') as f:
-                        old_detail = json.load(f)
-
-                    # 1. 合并 generate_queries
-                    new_gen = self.execution_detail.get("generate_queries")
-                    old_gen = old_detail.get("generate_queries")
-                    if (new_gen and isinstance(new_gen, dict) and
-                            new_gen.get("cached") is True and
-                            old_gen and isinstance(old_gen, dict) and
-                            "prompt" in old_gen):
-                        final_detail["generate_queries"] = old_gen
-
-                    # 2. 合并 merge_detail
-                    new_merge = self.execution_detail.get("merge_detail")
-                    old_merge = old_detail.get("merge_detail")
-                    if (new_merge and isinstance(new_merge, dict) and
-                            new_merge.get("cached") is True and
-                            old_merge and isinstance(old_merge, dict) and
-                            "prompt" in old_merge):
-                        final_detail["merge_detail"] = old_merge
-
-                    # 3. 合并 search_results (列表)
-                    new_results = self.execution_detail.get("search_results", [])
-                    old_results = old_detail.get("search_results", [])
-
-                    if new_results and old_results:
-                        merged_results = []
-                        # 建立旧结果的索引:(query, index) -> item
-                        old_map = {(item.get("query"), item.get("query_index")): item
-                                   for item in old_results if isinstance(item, dict)}
-
-                        for item in new_results:
-                            if item.get("cached") is True:
-                                key = (item.get("query"), item.get("query_index"))
-                                if key in old_map:
-                                    # 如果旧项包含更多信息(例如非cached状态),则使用旧项
-                                    old_item = old_map[key]
-                                    if old_item.get("cached") is False:
-                                        merged_results.append(old_item)
-                                        continue
-                            merged_results.append(item)
-                        final_detail["search_results"] = merged_results
-
-                except Exception as e:
-                    logger.warning(f"  ⚠ 读取旧详情失败: {e}")
-
             with open(detail_file, 'w', encoding='utf-8') as f:
-                json.dump(final_detail, f, ensure_ascii=False, indent=2)
-
+                json.dump(self.execution_detail, f, ensure_ascii=False, indent=2)
+            
             logger.info(f"✓ 执行详情已保存: {detail_file}")
-
+            
         except Exception as e:
             logger.error(f"✗ 保存执行详情失败: {e}")
     
-    def get_knowledge(self, question: str, cache_key: str = None, need_generate_query: bool = True) -> str:
+    def get_knowledge(self, question: str, cache_key: str = None) -> str:
         """
         主方法:根据问题获取知识文本
         
@@ -462,10 +392,7 @@ class LLMSearchKnowledge:
             logger.info(f"{'='*60}")
             
             # 步骤1: 生成多个query
-            if need_generate_query:
-                queries = self.generate_queries(actual_cache_key)
-            else:
-                queries = [question]
+            queries = self.generate_queries(actual_cache_key)
             
             # 步骤2: 对每个query搜索知识
             knowledge_texts = self.search_knowledge_batch(actual_cache_key, queries)
@@ -491,7 +418,7 @@ class LLMSearchKnowledge:
             raise
 
 
-def get_knowledge(question: str, cache_key: str = None, need_generate_query: bool = True) -> str:
+def get_knowledge(question: str, cache_key: str = None) -> str:
     """
     便捷函数:根据问题获取知识文本
     
@@ -503,7 +430,7 @@ def get_knowledge(question: str, cache_key: str = None, need_generate_query: boo
         str: 最终的知识文本
     """
     agent = LLMSearchKnowledge()
-    return agent.get_knowledge(question, cache_key=cache_key, need_generate_query=need_generate_query)
+    return agent.get_knowledge(question, cache_key=cache_key)
 
 
 if __name__ == "__main__":
@@ -511,7 +438,13 @@ if __name__ == "__main__":
     test_question = "关于猫咪和墨镜的服装造型元素"
     
     try:
-        result = get_knowledge(question=test_question, need_generate_query=False)
+        result = get_knowledge(test_question)
+        print("=" * 50)
+        print("最终知识文本:")
+        print("=" * 50)
+        print(result)
+    except Exception as e:
+        logger.error(f"测试失败: {e}")
         print("=" * 50)
         print("最终知识文本:")
         print("=" * 50)

+ 2 - 25
knowledge_v2/multi_search_knowledge.py

@@ -156,7 +156,7 @@ class MultiSearchKnowledge:
             raise
     
     def _save_execution_detail(self, cache_key: str):
-        """保存执行详情到缓存(支持合并旧记录)"""
+        """保存执行详情到缓存"""
         if not self.use_cache or not self.cache:
             return
         
@@ -171,31 +171,8 @@ class MultiSearchKnowledge:
             os.makedirs(detail_dir, exist_ok=True)
             
             detail_file = os.path.join(detail_dir, 'execution_detail.json')
-            
-            # 准备最终要保存的数据
-            final_detail = self.execution_detail.copy()
-            
-            # 尝试读取旧文件进行合并
-            if os.path.exists(detail_file):
-                try:
-                    with open(detail_file, 'r', encoding='utf-8') as f:
-                        old_detail = json.load(f)
-                    
-                    # 合并 merge_detail
-                    new_merge = self.execution_detail.get("merge_detail")
-                    old_merge = old_detail.get("merge_detail")
-                    
-                    if (new_merge and isinstance(new_merge, dict) and 
-                        new_merge.get("cached") is True and 
-                        old_merge and isinstance(old_merge, dict) and 
-                        "prompt" in old_merge):
-                        final_detail["merge_detail"] = old_merge
-                        
-                except Exception as e:
-                    logger.warning(f"  ⚠ 读取旧详情失败: {e}")
-
             with open(detail_file, 'w', encoding='utf-8') as f:
-                json.dump(final_detail, f, ensure_ascii=False, indent=2)
+                json.dump(self.execution_detail, f, ensure_ascii=False, indent=2)
             
             logger.info(f"✓ 执行详情已保存: {detail_file}")
             

+ 57 - 69
knowledge_v2/tools_library.py

@@ -1,81 +1,16 @@
 '''
 工具库模块,提供工具库中的工具调用、保存新的待接入工具的信息
-分为个函数:
+分为三个函数:
 
 1. 调用工具库中的工具
-curl --location 'http://47.84.182.56:8001/tools/call/wechat_search_article' \
---header 'Content-Type: application/json' \
---data '{
-    "keyword": "英雄联盟"
-}'
-其中的data和wechat_search_article是函数的入参,把工具名替换到wechat_search_article,把参数字典替换到data
-
-
 2. 保存新的待接入工具的信息
-入参为工具信息,是一个工具的文档字符串
-目前默认将这个文档字符串保存到一个文件中,文件名默认是工具的名称
-'''
-
-import requests
-import os
-import json
-
-TOOL_SERVER_URL = "http://47.84.182.56:8001/tools/call"
-
-def call_tool(tool_name: str, arguments: dict):
-    """
-    调用工具库中的工具
-    :param tool_name: 工具名称
-    :param arguments: 工具参数字典
-    :return: 工具调用结果
-    """
-    url = f"{TOOL_SERVER_URL}/{tool_name}"
-    headers = {
-        'Content-Type': 'application/json'
-    }
-    try:
-        response = requests.post(url, headers=headers, json=arguments)
-        response.raise_for_status()
-        return response.json()
-    except requests.RequestException as e:
-        # 在实际生产中可能需要更复杂的错误处理或日志记录
-        return {"error": f"Failed to call tool {tool_name}: {str(e)}"}
-
-def save_tool_info(tool_name: str, tool_doc: str):
-    """
-    保存新的待接入工具的信息
-    :param tool_name: 工具名称
-    :param tool_doc: 工具文档字符串
-    :return: 保存的文件路径
-    """
-    # 获取当前文件所在目录
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    # 创建 tool_infos 目录(如果不存在)
-    save_dir = os.path.join(current_dir, 'tool_infos')
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-    
-'''
-工具库模块,提供工具库中的工具调用、保存新的待接入工具的信息
-分为两个函数:
-
-1. 调用工具库中的工具
-curl --location 'http://47.84.182.56:8001/tools/call/wechat_search_article' \
---header 'Content-Type: application/json' \
---data '{
-    "keyword": "英雄联盟"
-}'
-其中的data和wechat_search_article是函数的入参,把工具名替换到wechat_search_article,把参数字典替换到data
-
-
-2. 保存新的待接入工具的信息
-入参为工具信息,是一个工具的文档字符串
-目前默认将这个文档字符串保存到一个文件中,文件名默认是工具的名称
+3. 获取工具调用参数信息
 '''
 
 import requests
 import os
 import json
+import re
 
 TOOL_SERVER_URL = "http://47.84.182.56:8001/tools/call"
 
@@ -165,4 +100,57 @@ def get_tool_info(tool_name: str) -> str:
         with open(file_path, 'r', encoding='utf-8') as f:
             return f.read().strip()
     except Exception as e:
-        return f"Error reading tool info: {str(e)}"
+        return f"Error reading tool info: {str(e)}"
+
+def get_tool_params(tools_id: str) -> str:
+    """
+    根据tools_id获取工具调用参数信息
+    :param tools_id: 工具调用ID
+    :return: 工具参数信息的JSON字符串,如果未找到返回空字符串
+    """
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    prompt_dir = os.path.join(current_dir, 'prompt')
+    params_file = os.path.join(prompt_dir, 'all_tools_params.md')
+    
+    if not os.path.exists(params_file):
+        return ""
+        
+    try:
+        with open(params_file, 'r', encoding='utf-8') as f:
+            content = f.read()
+            
+        # 查找工具调用ID的位置
+        # 格式:工具调用ID:<tools_id>
+        id_pattern = f"工具调用ID:{re.escape(tools_id)}"
+        match = re.search(id_pattern, content)
+        
+        if not match:
+            return ""
+            
+        # 从匹配位置开始往后找JSON块
+        start_index = match.end()
+        json_start = content.find('{', start_index)
+        
+        if json_start == -1:
+            return ""
+            
+        # 简单的JSON提取逻辑:找匹配的大括号
+        brace_count = 0
+        json_end = -1
+        
+        for i in range(json_start, len(content)):
+            if content[i] == '{':
+                brace_count += 1
+            elif content[i] == '}':
+                brace_count -= 1
+                if brace_count == 0:
+                    json_end = i + 1
+                    break
+        
+        if json_end != -1:
+            return content[json_start:json_end]
+            
+        return ""
+        
+    except Exception as e:
+        return f"Error reading tool params: {str(e)}"