@@ -1,15 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
-Agent service refactored with FastAPI
-Provides a modern HTTP API
+Agent service refactored with FastAPI + LangGraph
+Provides robust workflow management and state control
 """

 import json
 import sys
 import os
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TypedDict, Annotated
 from contextlib import asynccontextmanager

 # Ensure project modules are importable
@@ -20,11 +20,32 @@ from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 import uvicorn

+# LangGraph imports
+try:
+    from langgraph.graph import StateGraph, END
+    HAS_LANGGRAPH = True
+except ImportError:
+    HAS_LANGGRAPH = False
+    print("Warning: LangGraph is not installed; falling back to legacy mode")
+
 from utils.logging_config import get_logger
 from agent_tools import QueryDataTool, IdentifyTool, StructureTool

 # Create the logger
-logger = get_logger('AgentFastAPI')
+logger = get_logger('Agent')
+
+# State definition
+class AgentState(TypedDict):
+    request_id: str
+    items: List[Dict[str, Any]]
+    details: List[Dict[str, Any]]
+    processed: int
+    success: int
+    current_index: int
+    current_item: Optional[Dict[str, Any]]
+    identify_result: Optional[Dict[str, Any]]
+    error: Optional[str]
+    status: str
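+
+# Each workflow node below receives this state, mutates it, and returns it;
+# LangGraph then routes the returned state along the graph's edges.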

 # Request models
 class TriggerRequest(BaseModel):
@@ -56,20 +77,157 @@ async def lifespan(app: FastAPI):
 # Create the FastAPI application
 app = FastAPI(
     title="Knowledge Agent API",
-    description="Intelligent content identification and structuring service",
-    version="1.0.0",
+    description="LangGraph-based intelligent content identification and structuring service",
+    version="2.0.0",
     lifespan=lifespan
 )

+# =========================
+# LangGraph workflow definition
+# =========================
+
+def create_langgraph_workflow():
+    """Create the LangGraph workflow."""
+    if not HAS_LANGGRAPH:
+        return None
+
+    # Workflow node definitions
+
+    def fetch_data(state: AgentState) -> AgentState:
+        """Fetch the data to be processed."""
+        try:
+            request_id = state["request_id"]
+            logger.info(f"Fetching data: requestId={request_id}")
+
+            items = QueryDataTool.fetch_crawl_data_list(request_id)
+            state["items"] = items
+            state["processed"] = len(items)
+            state["status"] = "data_fetched"
+
+            logger.info(f"Data fetched: requestId={request_id}, count={len(items)}")
+            return state
+
+        except Exception as e:
+            logger.error(f"Failed to fetch data: {e}")
+            state["error"] = str(e)
+            state["status"] = "error"
+            return state
+
+    def process_item(state: AgentState) -> AgentState:
+        """Process a single item."""
+        # Read the index outside the try block so the except handler below
+        # can always reference it
+        current_index = state.get("current_index", 0)
+        try:
+            items = state["items"]
+
+            if current_index >= len(items):
+                state["status"] = "completed"
+                return state
+
+            item = items[current_index]
+            state["current_item"] = item
+            state["current_index"] = current_index + 1
+
+            # Process the current item
+            crawl_data = item.get('crawl_data') or {}
+
+            # Step 1: identify
+            identify_result = identify_tool.run(
+                crawl_data if isinstance(crawl_data, dict) else {}
+            )
+            state["identify_result"] = identify_result
+
+            # Step 2: structure and persist
+            affected = StructureTool.store_parsing_result(
+                state["request_id"],
+                item.get('raw') or {},
+                identify_result
+            )
+
+            ok = affected is not None and affected > 0
+            if ok:
+                state["success"] += 1
+
+            # Record processing details
+            detail = {
+                "index": current_index + 1,
+                "dbInserted": ok,
+                "identifyError": identify_result.get('error'),
+                "status": "success" if ok else "failed"
+            }
+            state["details"].append(detail)
+
+            state["status"] = "item_processed"
+            logger.info(f"Progress: {current_index + 1}/{len(items)} - {'success' if ok else 'failed'}")
+
+            return state
+
+        except Exception as e:
+            logger.error(f"Error while processing item {current_index + 1}: {e}")
+            detail = {
+                "index": current_index + 1,
+                "dbInserted": False,
+                "identifyError": str(e),
+                "status": "error"
+            }
+            state["details"].append(detail)
+            state["status"] = "item_error"
+            return state
+
+    def should_continue(state: AgentState) -> str:
+        """Decide whether to continue processing."""
+        if state.get("error"):
+            return "end"
+
+        current_index = state.get("current_index", 0)
+        items = state.get("items", [])
+
+        if current_index >= len(items):
+            return "end"
+
+        return "continue"
+
+    # Build the workflow graph
+    workflow = StateGraph(AgentState)
+
+    # Add nodes
+    workflow.add_node("fetch_data", fetch_data)
+    workflow.add_node("process_item", process_item)
+
+    # Set the entry point
+    workflow.set_entry_point("fetch_data")
+
+    # Add edges: process_item loops on itself until should_continue returns "end"
+    workflow.add_edge("fetch_data", "process_item")
+    workflow.add_conditional_edges(
+        "process_item",
+        should_continue,
+        {
+            "continue": "process_item",
+            "end": END
+        }
+    )
+
+    # Compile the workflow
+    return workflow.compile()
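+
+# Note: each process_item pass counts as one step against LangGraph's
+# recursion limit (25 by default), so large batches can exhaust it. Raising
+# it at invoke time, e.g. config={"recursion_limit": 500, ...}, is one
+# option; the bound of 500 here is an assumption, not a tested value.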
+
+# Global workflow instance
+WORKFLOW = create_langgraph_workflow() if HAS_LANGGRAPH else None
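+# The graph is compiled once at import time and shared across requests; all
+# per-request data lives in AgentState, so reusing the instance should be
+# safe as long as no checkpointer or other shared mutable state is added.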
+
+# =========================
+# FastAPI endpoints
+# =========================
+
 @app.get("/")
 async def root():
     """Root path; returns service information."""
     return {
         "service": "Knowledge Agent API",
-        "version": "1.0.0",
+        "version": "2.0.0",
         "status": "running",
+        "langgraph_enabled": HAS_LANGGRAPH,
         "endpoints": {
             "parse": "/parse",
+            "parse/async": "/parse/async",
             "health": "/health",
             "docs": "/docs"
         }
@@ -78,7 +236,11 @@ async def root():
 @app.get("/health")
 async def health_check():
     """Health check endpoint."""
-    return {"status": "healthy", "timestamp": time.time()}
+    return {
+        "status": "healthy",
+        "timestamp": time.time(),
+        "langgraph_enabled": HAS_LANGGRAPH
+    }

 @app.post("/parse", response_model=TriggerResponse)
 async def parse_processing(request: TriggerRequest, background_tasks: BackgroundTasks):
@@ -90,64 +252,100 @@ async def parse_processing(request: TriggerRequest, background_tasks: Background
     try:
         logger.info(f"Received parse request: requestId={request.requestId}")

-        # Fetch the data to be processed
-        items = QueryDataTool.fetch_crawl_data_list(request.requestId)
-        if not items:
-            return TriggerResponse(
-                requestId=request.requestId,
+        if WORKFLOW and HAS_LANGGRAPH:
+            # Use the LangGraph workflow
+            logger.info("Processing with the LangGraph workflow")
+
+            # Initialize the state
+            initial_state = AgentState(
+                request_id=request.requestId,
+                items=[],
+                details=[],
                 processed=0,
                 success=0,
-                details=[]
+                current_index=0,
+                current_item=None,
+                identify_result=None,
+                error=None,
+                status="started"
             )
-
-        # Process the data
-        success_count = 0
-        details: List[Dict[str, Any]] = []
-
-        for idx, item in enumerate(items, start=1):
-            try:
-                crawl_data = item.get('crawl_data') or {}
-
-                # Step 1: identify
-                identify_result = identify_tool.run(
-                    crawl_data if isinstance(crawl_data, dict) else {}
-                )
-
-                # Step 2: structure and persist
-                affected = StructureTool.store_parsing_result(
-                    request.requestId,
-                    item.get('raw') or {},
-                    identify_result
+
+            # Execute the workflow
+            final_state = WORKFLOW.invoke(
+                initial_state,
+                config={"configurable": {"thread_id": f"thread_{request.requestId}"}}
+            )
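+            # Note: "thread_id" only takes effect if the graph is compiled with
+            # a checkpointer; without one it is ignored. For large batches a
+            # higher "recursion_limit" can be passed alongside "configurable"
+            # (see the note after create_langgraph_workflow).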
+
+            # Build the response
+            result = TriggerResponse(
+                requestId=request.requestId,
+                processed=final_state.get("processed", 0),
+                success=final_state.get("success", 0),
+                details=final_state.get("details", [])
+            )
+
+        else:
+            # Fall back to legacy mode
+            logger.info("Processing in legacy mode")
+
+            # Fetch the data to be processed
+            items = QueryDataTool.fetch_crawl_data_list(request.requestId)
+            if not items:
+                return TriggerResponse(
+                    requestId=request.requestId,
+                    processed=0,
+                    success=0,
+                    details=[]
                )
-
-                ok = affected is not None and affected > 0
-                if ok:
-                    success_count += 1
-
-                details.append({
-                    "index": idx,
-                    "dbInserted": ok,
-                    "identifyError": identify_result.get('error'),
-                    "status": "success" if ok else "failed"
-                })
-
-            except Exception as e:
-                logger.error(f"Error while processing item {idx}: {e}")
-                details.append({
-                    "index": idx,
-                    "dbInserted": False,
-                    "identifyError": str(e),
-                    "status": "error"
-                })
-
-        result = TriggerResponse(
-            requestId=request.requestId,
-            processed=len(items),
-            success=success_count,
-            details=details
-        )
+
+            # Process the data
+            success_count = 0
+            details: List[Dict[str, Any]] = []
+
+            for idx, item in enumerate(items, start=1):
+                try:
+                    crawl_data = item.get('crawl_data') or {}
+
+                    # Step 1: identify
+                    identify_result = identify_tool.run(
+                        crawl_data if isinstance(crawl_data, dict) else {}
+                    )
+
+                    # Step 2: structure and persist
+                    affected = StructureTool.store_parsing_result(
+                        request.requestId,
+                        item.get('raw') or {},
+                        identify_result
+                    )
+
+                    ok = affected is not None and affected > 0
+                    if ok:
+                        success_count += 1
+
+                    details.append({
+                        "index": idx,
+                        "dbInserted": ok,
+                        "identifyError": identify_result.get('error'),
+                        "status": "success" if ok else "failed"
+                    })
+
+                except Exception as e:
+                    logger.error(f"Error while processing item {idx}: {e}")
+                    details.append({
+                        "index": idx,
+                        "dbInserted": False,
+                        "identifyError": str(e),
+                        "status": "error"
+                    })
+
+            result = TriggerResponse(
+                requestId=request.requestId,
+                processed=len(items),
+                success=success_count,
+                details=details
+            )

-        logger.info(f"Processing finished: requestId={request.requestId}, processed={len(items)}, success={success_count}")
+        logger.info(f"Processing finished: requestId={request.requestId}, processed={result.processed}, success={result.success}")
         return result

     except Exception as e:
@@ -170,7 +368,8 @@ async def parse_processing_async(request: TriggerRequest, background_tasks: Back
         return {
             "requestId": request.requestId,
             "status": "processing",
-            "message": "Task submitted for background processing"
+            "message": "Task submitted for background processing",
+            "langgraph_enabled": HAS_LANGGRAPH
         }

     except Exception as e:
@@ -182,39 +381,58 @@ async def process_request_background(request_id: str):
     try:
         logger.info(f"Background processing started: requestId={request_id}")

-        # Fetch the data to be processed
-        items = QueryDataTool.fetch_crawl_data_list(request_id)
-        if not items:
-            logger.info(f"Background processing finished: requestId={request_id}, no data to process")
-            return
-
-        # Process the data
-        success_count = 0
-        for idx, item in enumerate(items, start=1):
-            try:
-                crawl_data = item.get('crawl_data') or {}
-
-                # Step 1: identify
-                identify_result = identify_tool.run(
-                    crawl_data if isinstance(crawl_data, dict) else {}
-                )
-
-                # Step 2: structure and persist
-                affected = StructureTool.store_parsing_result(
-                    request_id,
-                    item.get('raw') or {},
-                    identify_result
-                )
-
-                if affected is not None and affected > 0:
-                    success_count += 1
-
-                logger.info(f"Background progress: {idx}/{len(items)} - {'success' if affected else 'failed'}")
-
-            except Exception as e:
-                logger.error(f"Error while processing item {idx} in background: {e}")
-
-        logger.info(f"Background processing finished: requestId={request_id}, processed={len(items)}, success={success_count}")
+        if WORKFLOW and HAS_LANGGRAPH:
+            # Use the LangGraph workflow
+            initial_state = AgentState(
+                request_id=request_id,
+                items=[],
+                details=[],
+                processed=0,
+                success=0,
+                current_index=0,
+                current_item=None,
+                identify_result=None,
+                error=None,
+                status="started"
+            )
+
+            final_state = WORKFLOW.invoke(
+                initial_state,
+                config={"configurable": {"thread_id": f"thread_{request_id}"}}
+            )
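+            # Note: invoke() is synchronous, so this async background task
+            # blocks the event loop while the workflow runs; offloading to a
+            # thread or using ainvoke() (the async Runnable API) may be
+            # preferable for long batches.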
+            logger.info(f"LangGraph background processing finished: requestId={request_id}, processed={final_state.get('processed', 0)}, success={final_state.get('success', 0)}")
+
+        else:
+            # Legacy mode
+            items = QueryDataTool.fetch_crawl_data_list(request_id)
+            if not items:
+                logger.info(f"Background processing finished: requestId={request_id}, no data to process")
+                return
+
+            success_count = 0
+            for idx, item in enumerate(items, start=1):
+                try:
+                    crawl_data = item.get('crawl_data') or {}
+
+                    identify_result = identify_tool.run(
+                        crawl_data if isinstance(crawl_data, dict) else {}
+                    )
+
+                    affected = StructureTool.store_parsing_result(
+                        request_id,
+                        item.get('raw') or {},
+                        identify_result
+                    )
+
+                    if affected is not None and affected > 0:
+                        success_count += 1
+
+                    logger.info(f"Background progress: {idx}/{len(items)} - {'success' if affected else 'failed'}")
+
+                except Exception as e:
+                    logger.error(f"Error while processing item {idx} in background: {e}")
+
+            logger.info(f"Legacy-mode background processing finished: requestId={request_id}, processed={len(items)}, success={success_count}")

     except Exception as e:
         logger.error(f"Background processing failed: requestId={request_id}, error={e}")
|