jihuaqiang 1 miesiąc temu
rodzic
commit
f2bb1da78d
2 zmienionych plików z 88 dodań i 96 usunięć
  1. 85 94
      agent.py
  2. 3 2
      tools/agent_tools.py

+ 85 - 94
agent.py

@@ -45,9 +45,6 @@ class AgentState(TypedDict):
     details: List[Dict[str, Any]]
     processed: int
     success: int
-    current_index: int
-    current_item: Optional[Dict[str, Any]]
-    identify_result: Optional[Dict[str, Any]]
     error: Optional[str]
     status: str
 
@@ -164,82 +161,88 @@ def create_langgraph_workflow():
             state["status"] = "error"
             return state
     
-    def process_item(state: AgentState) -> AgentState:
-        """处理单个数据项"""
+    def process_items_batch(state: AgentState) -> AgentState:
+        """批量处理所有数据项"""
         try:
             items = state["items"]
-            current_index = state.get("current_index", 0)
-            
-            if current_index >= len(items):
+            if not items:
                 state["status"] = "completed"
                 return state
             
-            item = items[current_index]
-            state["current_item"] = item
-            state["content_id"] = item.get('content_id') or ''
-            state["task_id"] = item.get('task_id') or ''
-            state["current_index"] = current_index + 1
-            
-            # 处理当前项
-            crawl_data = item.get('crawl_data') or {}
-            
-            # Step 1: 识别
-            identify_result = identify_tool.run(
-                crawl_data if isinstance(crawl_data, dict) else {}
-            )
-            state["identify_result"] = identify_result
-            
-            # Step 2: 结构化并入库
-            affected = UpdateDataTool.store_indentify_result(
-                state["request_id"], 
-                {
-                    "content_id": state["content_id"],
-                    "task_id": state["task_id"]
-                }, 
-                identify_result
-            )
-            # 使用StructureTool进行内容结构化处理
-            structure_tool = StructureTool()
-            structure_result = structure_tool.process_content_structure(identify_result)
-            
-            # 存储结构化解析结果
-            parsing_affected = UpdateDataTool.store_parsing_result(
-                state["request_id"],
-                {
-                    "content_id": state["content_id"],
-                    "task_id": state["task_id"]
-                },
-                structure_result
-            )
-            
-            ok = affected is not None and affected > 0 and parsing_affected is not None and parsing_affected > 0
-            if ok:
-                state["success"] += 1
+            success_count = 0
+            details = []
             
-            # 记录处理详情
-            detail = {
-                "index": current_index + 1,
-                "dbInserted": ok,
-                "identifyError": identify_result.get('error'),
-                "status": 2 if ok else 3
-            }
-            state["details"].append(detail)
+            for idx, item in enumerate(items, start=1):
+                try:
+                    crawl_data = item.get('crawl_data') or {}
+                    content_id = item.get('content_id') or ''
+                    task_id = item.get('task_id') or ''
+                    
+                    # Step 1: identify
+                    identify_result = identify_tool.run(
+                        crawl_data if isinstance(crawl_data, dict) else {}
+                    )
+                    
+                    # Step 2: structure and persist to the database
+                    affected = UpdateDataTool.store_indentify_result(
+                        state["request_id"], 
+                        {
+                            "content_id": content_id,
+                            "task_id": task_id
+                        }, 
+                        identify_result
+                    )
+                    
+                    # Use StructureTool to structure the identified content
+                    structure_tool = StructureTool()
+                    structure_result = structure_tool.process_content_structure(identify_result)
+                    
+                    # Store the structured parsing result
+                    parsing_affected = UpdateDataTool.store_parsing_result(
+                        state["request_id"],
+                        {
+                            "id": affected,
+                            "content_id": content_id,
+                            "task_id": task_id
+                        },
+                        structure_result
+                    )
+                    
+                    ok = affected is not None and affected > 0 and parsing_affected is not None and parsing_affected > 0
+                    if ok:
+                        success_count += 1
+                    
+                    # Record the per-item processing detail
+                    detail = {
+                        "index": idx,
+                        "dbInserted": ok,
+                        "identifyError": identify_result.get('error'),
+                        "status": 2 if ok else 3
+                    }
+                    details.append(detail)
+                    
+                    logger.info(f"处理进度: {idx}/{len(items)} - {'成功' if ok else '失败'}")
+                    
+                except Exception as e:
+                    logger.error(f"处理第 {idx} 项时出错: {e}")
+                    detail = {
+                        "index": idx,
+                        "dbInserted": False,
+                        "identifyError": str(e),
+                        "status": 3
+                    }
+                    details.append(detail)
             
-            state["status"] = "item_processed"
-            logger.info(f"处理进度: {current_index + 1}/{len(items)} - {'成功' if ok else '失败'}")
+            state["success"] = success_count
+            state["details"] = details
+            state["status"] = "completed"
             
             return state
             
         except Exception as e:
-            logger.error(f"处理第 {current_index + 1} 项时出错: {e}")
-            detail = {
-                "index": current_index + 1,
-                "dbInserted": False,
-                "identifyError": str(e),
-                "status": 3
-            }
-            state["details"].append(detail)
-            state["status"] = "item_error"
+            logger.error(f"批量处理失败: {e}")
+            state["error"] = str(e)
+            state["status"] = "error"
             return state
     
     def should_continue(state: AgentState) -> str:
@@ -249,35 +252,23 @@ def create_langgraph_workflow():
             update_request_status(state["request_id"], 3)
             return "end"
         
-        current_index = state.get("current_index", 0)
-        items = state.get("items", [])
-        if current_index >= len(items):
-            # 所有数据处理完毕,更新状态为2
-            update_request_status(state["request_id"], 2)
-            return "end"
-        
-        return "continue"
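+        # All items processed: set request status to 2. NOTE(review): this diff drops the conditional edge that invoked should_continue - confirm it is still called.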
+        update_request_status(state["request_id"], 2)
+        return "end"
     
     # 构建工作流图
     workflow = StateGraph(AgentState)
     
     # 添加节点
     workflow.add_node("fetch_data", fetch_data)
-    workflow.add_node("process_item", process_item)
+    workflow.add_node("process_items_batch", process_items_batch)
     
     # 设置入口点
     workflow.set_entry_point("fetch_data")
     
     # 添加边
-    workflow.add_edge("fetch_data", "process_item")
-    workflow.add_conditional_edges(
-        "process_item",
-        should_continue,
-        {
-            "continue": "process_item",
-            "end": END
-        }
-    )
+    workflow.add_edge("fetch_data", "process_items_batch")
+    workflow.add_edge("process_items_batch", END)
     
     # 编译工作流
     return workflow.compile()
@@ -335,9 +326,6 @@ async def parse_processing(request: TriggerRequest, background_tasks: Background
                 details=[],
                 processed=0,
                 success=0,
-                current_index=0,
-                current_item=None,
-                identify_result=None,
                 error=None,
                 status="started"
             )
@@ -345,7 +333,10 @@ async def parse_processing(request: TriggerRequest, background_tasks: Background
             # 执行工作流
             final_state = WORKFLOW.invoke(
                 initial_state,
-                config={"configurable": {"thread_id": f"thread_{request.requestId}"}}
+                config={
+                    "configurable": {"thread_id": f"thread_{request.requestId}"},
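+                    "recursion_limit": 100  # raise the recursion limit; NOTE(review): graph is now linear (fetch_data -> process_items_batch -> END) - confirm still needed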
+                }
             )
             
             # 构建响应
@@ -514,16 +505,16 @@ async def process_request_background(request_id: str):
                 details=[],
                 processed=0,
                 success=0,
-                current_index=0,
-                current_item=None,
-                identify_result=None,
                 error=None,
                 status="started"
             )
             
             final_state = WORKFLOW.invoke(
                 initial_state,
-                config={"configurable": {"thread_id": f"thread_{request_id}"}}
+                config={
+                    "configurable": {"thread_id": f"thread_{request_id}"},
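+                    "recursion_limit": 100  # raise the recursion limit; NOTE(review): graph is now linear (fetch_data -> process_items_batch -> END) - confirm still needed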
+                }
             )
             logger.info(f"LangGraph 后台处理完成: requestId={request_id}, processed={final_state.get('processed', 0)}, success={final_state.get('success', 0)}")
             

+ 3 - 2
tools/agent_tools.py

@@ -311,7 +311,7 @@ class UpdateDataTool:
             sql = (
                 "UPDATE knowledge_parsing_content "
                 "SET parsing_data = %s, status = %s "
-                "WHERE content_id = %s"
+                "WHERE content_id = %s AND id = %s"
             )
             
             # 状态:5 表示结构化处理完成
@@ -319,7 +319,8 @@ class UpdateDataTool:
             params = (
                 parsing_payload,
                 status,
-                content_id
+                content_id,
+                crawl_raw.get('id') or ''
             )
             
             result = MysqlHelper.update_values(sql, params)