Browse Source

expand任务修改

jihuaqiang 5 days ago
parent
commit
0fad9f8f1f
3 changed files with 40 additions and 39 deletions
  1. 37 36
      agents/expand_agent/agent.py
  2. 1 1
      prompt/expansion.md
  3. 2 2
      tools/agent_tools.py

+ 37 - 36
agents/expand_agent/agent.py

@@ -19,15 +19,15 @@ except ImportError:
 
 
 def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
-    """从 knowledge_parsing_content 表中根据 request_id 获取 parsing_data 字段"""
-    sql = "SELECT parsing_data FROM knowledge_parsing_content WHERE request_id = %s ORDER BY id DESC LIMIT 10"
+    """从 knowledge_extraction_content 表中根据 request_id 获取 data 字段"""
+    sql = "SELECT data FROM knowledge_extraction_content WHERE request_id = %s"
     rows = MysqlHelper.get_values(sql, (request_id,)) or []
     
     results = []
     for row in rows:
-        parsing_data = row[0]  # 获取 parsing_data 字段
-        if parsing_data:
-            results.append(parsing_data)
+        data = row[0]  # 获取 data 字段
+        if data:
+            results.append(data)
     
     print(f"Final results: {len(results)} items")
     return results
@@ -144,49 +144,50 @@ def _update_expansion_status(requestId: str, status: int):
 
 def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
     """对外暴露的API:根据requestId查询数据,生成扩展查询"""
-    # 获取数据
+    # 获取数据(可能多条)
     data_samples = _fetch_parsing_data_by_request(requestId)
-    
-    # 构建prompt
-    prompt = _build_prompt(data_samples[0], query)
 
-    # 生成扩展查询
-    expanded = _run_llm(prompt)
-    if not expanded:
-        expanded = _heuristic_expand(query)
-    
-    # 将扩展查询结果插入到 knowledge_expand_content 表
     try:
-        # 先检查是否已存在记录
-        check_sql = "SELECT id FROM knowledge_expand_content WHERE request_id = %s LIMIT 1"
-        existing_record = MysqlHelper.get_values(check_sql, (requestId,))
-        
-        # 将 expanded 列表转换为 JSON 字符串
-        expand_querys_json = json.dumps(expanded, ensure_ascii=False)
-        print(f"expand_querys_json: {expand_querys_json}")
-        
-        if existing_record:
-            # 记录已存在,执行更新
-            update_sql = """
-            UPDATE knowledge_expand_content 
-            SET expand_querys = %s, query = %s, create_time = NOW()
-            WHERE request_id = %s
+        total = 0
+        success = 0
+        if not data_samples:
+            # 即使没有数据,也基于 query 生成一次兜底扩展
+            prompt = _build_prompt("", query)
+            expanded = _run_llm(prompt)
+            if not expanded:
+                expanded = _heuristic_expand(query)
+            expand_querys_json = json.dumps(expanded, ensure_ascii=False)
+            insert_sql = """
+            INSERT INTO knowledge_expand_content 
+            (request_id, create_time, expand_querys, query) 
+            VALUES (%s, NOW(), %s, %s)
             """
-            affected_rows = MysqlHelper.update_values(update_sql, (expand_querys_json, query, requestId))
-            logger.info(f"扩展查询结果已更新: requestId={requestId}, affected_rows={affected_rows}")
+            MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
+            total = 1
+            success = 1 if expanded else 0
         else:
-            # 记录不存在,执行插入
+            # 针对每条 parsing_data 分别生成与入库
             insert_sql = """
             INSERT INTO knowledge_expand_content 
             (request_id, create_time, expand_querys, query) 
             VALUES (%s, NOW(), %s, %s)
             """
-            insert_result = MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
-            logger.info(f"扩展查询结果已插入: requestId={requestId}, insert_id={insert_result}")
-        
+            for sample in data_samples:
+                total += 1
+                prompt = _build_prompt(sample, query)
+                expanded = _run_llm(prompt)
+                if not expanded:
+                    expanded = _heuristic_expand(query)
+                try:
+                    expand_querys_json = json.dumps(expanded, ensure_ascii=False)
+                    MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
+                    success += 1
+                except Exception as ie:
+                    logger.error(f"单条扩展结果入库失败: requestId={requestId}, error={ie}")
+
         # 更新状态为处理完成
         _update_expansion_status(requestId, 2)
-            
+        logger.info(f"扩展完成: requestId={requestId}, total={total}, success={success}")
     except Exception as e:
         logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
         _update_expansion_status(requestId, 3)

+ 1 - 1
prompt/expansion.md

@@ -47,7 +47,7 @@
     *   是否超出原始内容范围?
     *   是否包含敏感词?
     *   如果评估结果不符合标准,则舍弃该Query词。
-6.  **生成输出:** 将通过评估的所有Query词以JSON数组的格式输出。
+6.  **生成输出:** 将通过评估的所有Query词以JSON数组的格式输出3-5条构思词
 
 # 输出格式
 ```json

+ 2 - 2
tools/agent_tools.py

@@ -120,7 +120,7 @@ class QueryDataTool:
                     "update_timestamp": 1755239186502
                 }
             }]
-            return default_data
+            return []
 
         results: List[Dict[str, Any]] = []
         for row in rows:
@@ -129,7 +129,7 @@ class QueryDataTool:
                 continue
             try:
                 parsed = json.loads(data_cell) if isinstance(data_cell, (str, bytes)) else data_cell
-                logger.info(f"parsed: {parsed}")
+                # logger.info(f"parsed: {parsed}")
                 
                 # 处理元组类型(数据库查询结果)
                 if isinstance(parsed, tuple) and len(parsed) > 4: