agent.py 15 KB


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Agent service rebuilt on FastAPI + LangGraph.

Provides workflow management and state control: HTTP endpoints fetch the
crawled items for a request ID, run identification on each item and store
the structured result -- through a LangGraph workflow when LangGraph is
installed, and through a plain loop otherwise.
"""
import json
import sys
import os
import time
from typing import Any, Dict, List, Optional, TypedDict, Annotated
from contextlib import asynccontextmanager
# NOTE(review): json, Annotated and (below) JSONResponse look unused in this
# module -- confirm before removing them.

# Make modules that live next to this file (utils, agent_tools) importable
# regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Disable LangSmith tracing to avoid network connection errors. This is done
# before the LangGraph import below so the settings are already in place
# when any LangChain code runs.
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_ENDPOINT"] = ""
os.environ["LANGCHAIN_API_KEY"] = ""

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn

# LangGraph is optional. HAS_LANGGRAPH decides whether the endpoints use the
# LangGraph workflow or the traditional per-item loop.
try:
    from langgraph.graph import StateGraph, END
    HAS_LANGGRAPH = True
except ImportError:
    HAS_LANGGRAPH = False
    # The project logger has not been created yet at this point, hence print().
    print("警告: LangGraph 未安装,将使用传统模式")

from utils.logging_config import get_logger
from agent_tools import QueryDataTool, IdentifyTool, StructureTool

# Module logger.
logger = get_logger('Agent')
  34. # 状态定义
  35. class AgentState(TypedDict):
  36. request_id: str
  37. items: List[Dict[str, Any]]
  38. details: List[Dict[str, Any]]
  39. processed: int
  40. success: int
  41. current_index: int
  42. current_item: Optional[Dict[str, Any]]
  43. identify_result: Optional[Dict[str, Any]]
  44. error: Optional[str]
  45. status: str
  46. # 请求模型
  47. class TriggerRequest(BaseModel):
  48. requestId: str = Field(..., description="请求ID")
# Response model for POST /parse.
# Fields are documented with comments rather than a docstring, because
# pydantic uses a model docstring as the description in the OpenAPI schema.
class TriggerResponse(BaseModel):
    requestId: str  # the requestId that was processed
    processed: int  # number of items fetched for the request
    success: int  # items whose DB write reported more than 0 affected rows
    details: List[Dict[str, Any]]  # per-item dicts: index, dbInserted, identifyError, status
# Module-wide IdentifyTool instance. It stays None until the `lifespan`
# start-up hook replaces it with IdentifyTool(). Code that calls
# identify_tool.run() therefore assumes the app has started.
identify_tool = None
  57. @asynccontextmanager
  58. async def lifespan(app: FastAPI):
  59. """应用生命周期管理"""
  60. # 启动时初始化
  61. global identify_tool
  62. identify_tool = IdentifyTool()
  63. logger.info("Agent 服务启动完成")
  64. yield
  65. # 关闭时清理
  66. logger.info("Agent 服务正在关闭")
  67. # 创建 FastAPI 应用
  68. app = FastAPI(
  69. title="Knowledge Agent API",
  70. description="基于 LangGraph 的智能内容识别和结构化处理服务",
  71. version="2.0.0",
  72. lifespan=lifespan
  73. )
  74. # =========================
  75. # LangGraph 工作流定义
  76. # =========================
  77. def create_langgraph_workflow():
  78. """创建 LangGraph 工作流"""
  79. if not HAS_LANGGRAPH:
  80. return None
  81. # 工作流节点定义
  82. def fetch_data(state: AgentState) -> AgentState:
  83. """获取待处理数据"""
  84. try:
  85. request_id = state["request_id"]
  86. logger.info(f"开始获取数据: requestId={request_id}")
  87. items = QueryDataTool.fetch_crawl_data_list(request_id)
  88. state["items"] = items
  89. state["processed"] = len(items)
  90. state["status"] = "data_fetched"
  91. logger.info(f"数据获取完成: requestId={request_id}, 数量={len(items)}")
  92. return state
  93. except Exception as e:
  94. logger.error(f"获取数据失败: {e}")
  95. state["error"] = str(e)
  96. state["status"] = "error"
  97. return state
  98. def process_item(state: AgentState) -> AgentState:
  99. """处理单个数据项"""
  100. try:
  101. items = state["items"]
  102. current_index = state.get("current_index", 0)
  103. if current_index >= len(items):
  104. state["status"] = "completed"
  105. return state
  106. item = items[current_index]
  107. state["current_item"] = item
  108. state["content_id"] = item.get('content_id') or ''
  109. state["task_id"] = item.get('task_id') or ''
  110. state["current_index"] = current_index + 1
  111. # 处理当前项
  112. crawl_data = item.get('crawl_data') or {}
  113. # Step 1: 识别
  114. identify_result = identify_tool.run(
  115. crawl_data if isinstance(crawl_data, dict) else {}
  116. )
  117. state["identify_result"] = identify_result
  118. # Step 2: 结构化并入库
  119. affected = StructureTool.store_parsing_result(
  120. state["request_id"],
  121. {
  122. "content_id": state["content_id"],
  123. "task_id": state["task_id"]
  124. },
  125. identify_result
  126. )
  127. ok = affected is not None and affected > 0
  128. if ok:
  129. state["success"] += 1
  130. # 记录处理详情
  131. detail = {
  132. "index": current_index + 1,
  133. "dbInserted": ok,
  134. "identifyError": identify_result.get('error'),
  135. "status": "success" if ok else "failed"
  136. }
  137. state["details"].append(detail)
  138. state["status"] = "item_processed"
  139. logger.info(f"处理进度: {current_index + 1}/{len(items)} - {'成功' if ok else '失败'}")
  140. return state
  141. except Exception as e:
  142. logger.error(f"处理第 {current_index + 1} 项时出错: {e}")
  143. detail = {
  144. "index": current_index + 1,
  145. "dbInserted": False,
  146. "identifyError": str(e),
  147. "status": "error"
  148. }
  149. state["details"].append(detail)
  150. state["status"] = "item_error"
  151. return state
  152. def should_continue(state: AgentState) -> str:
  153. """判断是否继续处理"""
  154. if state.get("error"):
  155. return "end"
  156. current_index = state.get("current_index", 0)
  157. items = state.get("items", [])
  158. if current_index >= len(items):
  159. return "end"
  160. return "continue"
  161. # 构建工作流图
  162. workflow = StateGraph(AgentState)
  163. # 添加节点
  164. workflow.add_node("fetch_data", fetch_data)
  165. workflow.add_node("process_item", process_item)
  166. # 设置入口点
  167. workflow.set_entry_point("fetch_data")
  168. # 添加边
  169. workflow.add_edge("fetch_data", "process_item")
  170. workflow.add_conditional_edges(
  171. "process_item",
  172. should_continue,
  173. {
  174. "continue": "process_item",
  175. "end": END
  176. }
  177. )
  178. # 编译工作流,禁用 LangSmith 追踪
  179. return workflow.compile()
  180. # 全局工作流实例
  181. WORKFLOW = create_langgraph_workflow() if HAS_LANGGRAPH else None
  182. # =========================
  183. # FastAPI 接口定义
  184. # =========================
  185. @app.get("/")
  186. async def root():
  187. """根路径,返回服务信息"""
  188. return {
  189. "service": "Knowledge Agent API",
  190. "version": "2.0.0",
  191. "status": "running",
  192. "langgraph_enabled": HAS_LANGGRAPH,
  193. "endpoints": {
  194. "parse": "/parse",
  195. "parse/async": "/parse/async",
  196. "health": "/health",
  197. "docs": "/docs"
  198. }
  199. }
  200. @app.get("/health")
  201. async def health_check():
  202. """健康检查接口"""
  203. return {
  204. "status": "healthy",
  205. "timestamp": time.time(),
  206. "langgraph_enabled": HAS_LANGGRAPH
  207. }
  208. @app.post("/parse", response_model=TriggerResponse)
  209. async def parse_processing(request: TriggerRequest, background_tasks: BackgroundTasks):
  210. """
  211. 解析内容处理
  212. - **requestId**: 请求ID,用于标识处理任务
  213. """
  214. try:
  215. logger.info(f"收到解析请求: requestId={request.requestId}")
  216. if WORKFLOW and HAS_LANGGRAPH:
  217. # 使用 LangGraph 工作流
  218. logger.info("使用 LangGraph 工作流处理")
  219. # 初始化状态
  220. initial_state = AgentState(
  221. request_id=request.requestId,
  222. items=[],
  223. details=[],
  224. processed=0,
  225. success=0,
  226. current_index=0,
  227. current_item=None,
  228. identify_result=None,
  229. error=None,
  230. status="started"
  231. )
  232. # 执行工作流
  233. final_state = WORKFLOW.invoke(
  234. initial_state,
  235. config={"configurable": {"thread_id": f"thread_{request.requestId}"}}
  236. )
  237. # 构建响应
  238. result = TriggerResponse(
  239. requestId=request.requestId,
  240. processed=final_state.get("processed", 0),
  241. success=final_state.get("success", 0),
  242. details=final_state.get("details", [])
  243. )
  244. else:
  245. # 回退到传统模式
  246. logger.info("使用传统模式处理")
  247. # 获取待处理数据
  248. items = QueryDataTool.fetch_crawl_data_list(request.requestId)
  249. if not items:
  250. return TriggerResponse(
  251. requestId=request.requestId,
  252. processed=0,
  253. success=0,
  254. details=[]
  255. )
  256. # 处理数据
  257. success_count = 0
  258. details: List[Dict[str, Any]] = []
  259. for idx, item in enumerate(items, start=1):
  260. try:
  261. crawl_data = item.get('crawl_data') or {}
  262. # Step 1: 识别
  263. identify_result = identify_tool.run(
  264. crawl_data if isinstance(crawl_data, dict) else {}
  265. )
  266. # Step 2: 结构化并入库
  267. affected = StructureTool.store_parsing_result(
  268. request.requestId,
  269. {
  270. content_id: item.get('content_id') or '',
  271. task_id: item.get('task_id') or ''
  272. },
  273. identify_result
  274. )
  275. ok = affected is not None and affected > 0
  276. if ok:
  277. success_count += 1
  278. details.append({
  279. "index": idx,
  280. "dbInserted": ok,
  281. "identifyError": identify_result.get('error'),
  282. "status": "success" if ok else "failed"
  283. })
  284. except Exception as e:
  285. logger.error(f"处理第 {idx} 项时出错: {e}")
  286. details.append({
  287. "index": idx,
  288. "dbInserted": False,
  289. "identifyError": str(e),
  290. "status": "error"
  291. })
  292. result = TriggerResponse(
  293. requestId=request.requestId,
  294. processed=len(items),
  295. success=success_count,
  296. details=details
  297. )
  298. logger.info(f"处理完成: requestId={request.requestId}, processed={result.processed}, success={result.success}")
  299. return result
  300. except Exception as e:
  301. logger.error(f"处理请求失败: {e}")
  302. raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
  303. @app.post("/parse/async")
  304. async def parse_processing_async(request: TriggerRequest, background_tasks: BackgroundTasks):
  305. """
  306. 异步解析内容处理(后台任务)
  307. - **requestId**: 请求ID,用于标识处理任务
  308. """
  309. try:
  310. logger.info(f"收到异步解析请求: requestId={request.requestId}")
  311. # 添加后台任务
  312. background_tasks.add_task(process_request_background, request.requestId)
  313. return {
  314. "requestId": request.requestId,
  315. "status": "processing",
  316. "message": "任务已提交到后台处理",
  317. "langgraph_enabled": HAS_LANGGRAPH
  318. }
  319. except Exception as e:
  320. logger.error(f"提交异步任务失败: {e}")
  321. raise HTTPException(status_code=500, detail=f"提交任务失败: {str(e)}")
  322. async def process_request_background(request_id: str):
  323. """后台处理请求"""
  324. try:
  325. logger.info(f"开始后台处理: requestId={request_id}")
  326. if WORKFLOW and HAS_LANGGRAPH:
  327. # 使用 LangGraph 工作流
  328. initial_state = AgentState(
  329. request_id=request_id,
  330. items=[],
  331. details=[],
  332. processed=0,
  333. success=0,
  334. current_index=0,
  335. current_item=None,
  336. identify_result=None,
  337. error=None,
  338. status="started"
  339. )
  340. final_state = WORKFLOW.invoke(
  341. initial_state,
  342. config={"configurable": {"thread_id": f"thread_{request_id}"}}
  343. )
  344. logger.info(f"LangGraph 后台处理完成: requestId={request_id}, processed={final_state.get('processed', 0)}, success={final_state.get('success', 0)}")
  345. else:
  346. # 传统模式
  347. items = QueryDataTool.fetch_crawl_data_list(request_id)
  348. if not items:
  349. logger.info(f"后台处理完成: requestId={request_id}, 无数据需要处理")
  350. return
  351. success_count = 0
  352. for idx, item in enumerate(items, start=1):
  353. try:
  354. crawl_data = item.get('crawl_data') or {}
  355. content_id = item.get('content_id') or ''
  356. identify_result = identify_tool.run(
  357. crawl_data if isinstance(crawl_data, dict) else {}
  358. )
  359. affected = StructureTool.store_parsing_result(
  360. request_id,
  361. {
  362. content_id: item.get('content_id') or '',
  363. task_id: item.get('task_id') or ''
  364. },
  365. identify_result
  366. )
  367. if affected is not None and affected > 0:
  368. success_count += 1
  369. logger.info(f"后台处理进度: {idx}/{len(items)} - {'成功' if affected else '失败'}")
  370. except Exception as e:
  371. logger.error(f"后台处理第 {idx} 项时出错: {e}")
  372. logger.info(f"传统模式后台处理完成: requestId={request_id}, processed={len(items)}, success={success_count}")
  373. except Exception as e:
  374. logger.error(f"后台处理失败: requestId={request_id}, error={e}")
  375. if __name__ == "__main__":
  376. # 启动服务
  377. uvicorn.run(
  378. "agent:app",
  379. host="0.0.0.0",
  380. port=8080,
  381. reload=True, # 开发模式,自动重载
  382. log_level="info"
  383. )