# agent.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Agent service rebuilt on FastAPI + LangGraph.

Provides workflow management and state control for content identification
and structuring.
"""
import json
import sys
import os
import time
from typing import Any, Dict, List, Optional, TypedDict, Annotated
from contextlib import asynccontextmanager
import asyncio

# Make project-local modules importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
import uvicorn

from agents.clean_agent.agent import execute_agent_with_api
from agents.expand_agent.agent import execute_expand_agent_with_api, _update_expansion_status

# LangGraph is optional: when it is not installed the service falls back to
# the sequential ("traditional") processing mode.
try:
    from langgraph.graph import StateGraph, END
    HAS_LANGGRAPH = True
except ImportError:
    HAS_LANGGRAPH = False
    print("警告: LangGraph 未安装,将使用传统模式")

from utils.logging_config import get_logger
from tools.agent_tools import QueryDataTool, IdentifyTool, UpdateDataTool, StructureTool

# Module-level logger.
logger = get_logger('Agent')
# Shared state threaded through the LangGraph workflow nodes.
class AgentState(TypedDict):
    request_id: str                          # id of the knowledge_request being processed
    items: List[Dict[str, Any]]              # crawl-data rows fetched for this request
    details: List[Dict[str, Any]]            # per-item processing results
    processed: int                           # total number of items fetched
    success: int                             # number of items stored successfully
    current_index: int                       # cursor into `items`
    current_item: Optional[Dict[str, Any]]   # item currently being processed
    identify_result: Optional[Dict[str, Any]]  # last identify-tool output
    error: Optional[str]                     # fatal error message, if any
    status: str                              # workflow phase marker
# Request body for /parse and /parse/async.
class TriggerRequest(BaseModel):
    requestId: str = Field(..., description="请求ID")
# Response body for /parse: summary plus per-item details.
class TriggerResponse(BaseModel):
    requestId: str
    processed: int                  # number of items fetched for the request
    success: int                    # number of items stored successfully
    details: List[Dict[str, Any]]   # per-item status entries
# Request body for /expand.
class ExpandRequest(BaseModel):
    requestId: str = Field(..., description="扩展查询请求ID")
# Global IdentifyTool instance; created in lifespan() at application startup.
identify_tool = None
  58. def update_request_status(request_id: str, status: int):
  59. """
  60. 更新 knowledge_request 表中的 parsing_status
  61. Args:
  62. request_id: 请求ID
  63. status: 状态值 (1: 处理中, 2: 处理完成, 3: 处理失败)
  64. """
  65. try:
  66. from utils.mysql_db import MysqlHelper
  67. sql = "UPDATE knowledge_request SET parsing_status = %s WHERE request_id = %s"
  68. result = MysqlHelper.update_values(sql, (status, request_id))
  69. if result is not None:
  70. logger.info(f"更新请求状态成功: requestId={request_id}, status={status}")
  71. else:
  72. logger.error(f"更新请求状态失败: requestId={request_id}, status={status}")
  73. except Exception as e:
  74. logger.error(f"更新请求状态异常: requestId={request_id}, status={status}, error={e}")
  75. def _update_expansion_status(requestId: str, status: int):
  76. """更新扩展查询状态"""
  77. try:
  78. from utils.mysql_db import MysqlHelper
  79. sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
  80. MysqlHelper.update_values(sql, (status, requestId))
  81. logger.info(f"更新扩展查询状态成功: requestId={requestId}, status={status}")
  82. except Exception as e:
  83. logger.error(f"更新扩展查询状态失败: requestId={requestId}, status={status}, error={e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build shared tools on startup, log shutdown."""
    # Startup: create the shared IdentifyTool used by all request handlers.
    global identify_tool
    identify_tool = IdentifyTool()
    logger.info("Agent 服务启动完成")
    yield
    # Shutdown: nothing to release beyond logging.
    logger.info("Agent 服务正在关闭")
# Create the FastAPI application.
app = FastAPI(
    title="Knowledge Agent API",
    description="基于 LangGraph 的智能内容识别和结构化处理服务",
    version="2.0.0",
    lifespan=lifespan
)

# Concurrency control: track requestIds currently being processed so the
# same request cannot be submitted twice concurrently.
RUNNING_REQUESTS: set = set()
RUNNING_LOCK = asyncio.Lock()
  104. # =========================
  105. # LangGraph 工作流定义
  106. # =========================
  107. def create_langgraph_workflow():
  108. """创建 LangGraph 工作流"""
  109. if not HAS_LANGGRAPH:
  110. return None
  111. # 工作流节点定义
  112. def fetch_data(state: AgentState) -> AgentState:
  113. """获取待处理数据"""
  114. try:
  115. request_id = state["request_id"]
  116. logger.info(f"开始获取数据: requestId={request_id}")
  117. # 更新状态为处理中
  118. update_request_status(request_id, 1)
  119. items = QueryDataTool.fetch_crawl_data_list(request_id)
  120. state["items"] = items
  121. state["processed"] = len(items)
  122. state["status"] = "data_fetched"
  123. logger.info(f"数据获取完成: requestId={request_id}, 数量={len(items)}")
  124. return state
  125. except Exception as e:
  126. logger.error(f"获取数据失败: {e}")
  127. state["error"] = str(e)
  128. state["status"] = "error"
  129. return state
  130. def process_item(state: AgentState) -> AgentState:
  131. """处理单个数据项"""
  132. try:
  133. items = state["items"]
  134. current_index = state.get("current_index", 0)
  135. if current_index >= len(items):
  136. state["status"] = "completed"
  137. return state
  138. item = items[current_index]
  139. state["current_item"] = item
  140. state["content_id"] = item.get('content_id') or ''
  141. state["task_id"] = item.get('task_id') or ''
  142. state["current_index"] = current_index + 1
  143. # 处理当前项
  144. crawl_data = item.get('crawl_data') or {}
  145. # Step 1: 识别
  146. identify_result = identify_tool.run(
  147. crawl_data if isinstance(crawl_data, dict) else {}
  148. )
  149. state["identify_result"] = identify_result
  150. # Step 2: 结构化并入库
  151. affected = UpdateDataTool.store_indentify_result(
  152. state["request_id"],
  153. {
  154. "content_id": state["content_id"],
  155. "task_id": state["task_id"]
  156. },
  157. identify_result
  158. )
  159. # 使用StructureTool进行内容结构化处理
  160. structure_tool = StructureTool()
  161. structure_result = structure_tool.process_content_structure(identify_result)
  162. # 存储结构化解析结果
  163. parsing_affected = UpdateDataTool.store_parsing_result(
  164. state["request_id"],
  165. {
  166. "content_id": state["content_id"],
  167. "task_id": state["task_id"]
  168. },
  169. structure_result
  170. )
  171. ok = affected is not None and affected > 0 and parsing_affected is not None and parsing_affected > 0
  172. if ok:
  173. state["success"] += 1
  174. # 记录处理详情
  175. detail = {
  176. "index": current_index + 1,
  177. "dbInserted": ok,
  178. "identifyError": identify_result.get('error'),
  179. "status": 2 if ok else 3
  180. }
  181. state["details"].append(detail)
  182. state["status"] = "item_processed"
  183. logger.info(f"处理进度: {current_index + 1}/{len(items)} - {'成功' if ok else '失败'}")
  184. return state
  185. except Exception as e:
  186. logger.error(f"处理第 {current_index + 1} 项时出错: {e}")
  187. detail = {
  188. "index": current_index + 1,
  189. "dbInserted": False,
  190. "identifyError": str(e),
  191. "status": 3
  192. }
  193. state["details"].append(detail)
  194. state["status"] = "item_error"
  195. return state
  196. def should_continue(state: AgentState) -> str:
  197. """判断是否继续处理"""
  198. if state.get("error"):
  199. # 处理失败,更新状态为3
  200. update_request_status(state["request_id"], 3)
  201. return "end"
  202. current_index = state.get("current_index", 0)
  203. items = state.get("items", [])
  204. if current_index >= len(items):
  205. # 所有数据处理完毕,更新状态为2
  206. update_request_status(state["request_id"], 2)
  207. return "end"
  208. return "continue"
  209. # 构建工作流图
  210. workflow = StateGraph(AgentState)
  211. # 添加节点
  212. workflow.add_node("fetch_data", fetch_data)
  213. workflow.add_node("process_item", process_item)
  214. # 设置入口点
  215. workflow.set_entry_point("fetch_data")
  216. # 添加边
  217. workflow.add_edge("fetch_data", "process_item")
  218. workflow.add_conditional_edges(
  219. "process_item",
  220. should_continue,
  221. {
  222. "continue": "process_item",
  223. "end": END
  224. }
  225. )
  226. # 编译工作流
  227. return workflow.compile()
# Global compiled workflow instance (None when LangGraph is unavailable).
WORKFLOW = create_langgraph_workflow() if HAS_LANGGRAPH else None
  230. # =========================
  231. # FastAPI 接口定义
  232. # =========================
  233. @app.get("/")
  234. async def root():
  235. """根路径,返回服务信息"""
  236. return {
  237. "service": "Knowledge Agent API",
  238. "version": "2.0.0",
  239. "status": "running",
  240. "langgraph_enabled": HAS_LANGGRAPH,
  241. "endpoints": {
  242. "parse": "/parse",
  243. "parse/async": "/parse/async",
  244. "health": "/health",
  245. "docs": "/docs"
  246. }
  247. }
  248. @app.get("/health")
  249. async def health_check():
  250. """健康检查接口"""
  251. return {
  252. "status": "healthy",
  253. "timestamp": time.time(),
  254. "langgraph_enabled": HAS_LANGGRAPH
  255. }
  256. @app.post("/parse", response_model=TriggerResponse)
  257. async def parse_processing(request: TriggerRequest, background_tasks: BackgroundTasks):
  258. """
  259. 解析内容处理
  260. - **requestId**: 请求ID,用于标识处理任务
  261. """
  262. try:
  263. logger.info(f"收到解析请求: requestId={request.requestId}")
  264. if WORKFLOW and HAS_LANGGRAPH:
  265. # 使用 LangGraph 工作流
  266. logger.info("使用 LangGraph 工作流处理")
  267. # 初始化状态
  268. initial_state = AgentState(
  269. request_id=request.requestId,
  270. items=[],
  271. details=[],
  272. processed=0,
  273. success=0,
  274. current_index=0,
  275. current_item=None,
  276. identify_result=None,
  277. error=None,
  278. status="started"
  279. )
  280. # 执行工作流
  281. final_state = WORKFLOW.invoke(
  282. initial_state,
  283. config={"configurable": {"thread_id": f"thread_{request.requestId}"}}
  284. )
  285. # 构建响应
  286. result = TriggerResponse(
  287. requestId=request.requestId,
  288. processed=final_state.get("processed", 0),
  289. success=final_state.get("success", 0),
  290. details=final_state.get("details", [])
  291. )
  292. else:
  293. # 回退到传统模式
  294. logger.info("使用传统模式处理")
  295. # 更新状态为处理中
  296. update_request_status(request.requestId, 1)
  297. # 获取待处理数据
  298. items = QueryDataTool.fetch_crawl_data_list(request.requestId)
  299. print(f"传统模式---items: {items}")
  300. if not items:
  301. # 无数据需要处理,更新状态为完成
  302. update_request_status(request.requestId, 2)
  303. return TriggerResponse(
  304. requestId=request.requestId,
  305. processed=0,
  306. success=0,
  307. details=[]
  308. )
  309. # 处理数据
  310. success_count = 0
  311. details: List[Dict[str, Any]] = []
  312. for idx, item in enumerate(items, start=1):
  313. try:
  314. crawl_data = item.get('crawl_data') or {}
  315. # Step 1: 识别
  316. identify_result = identify_tool.run(
  317. crawl_data if isinstance(crawl_data, dict) else {}
  318. )
  319. # Step 2: 结构化并入库
  320. affected = UpdateDataTool.store_indentify_result(
  321. request.requestId,
  322. {
  323. content_id: item.get('content_id') or '',
  324. task_id: item.get('task_id') or ''
  325. },
  326. identify_result
  327. )
  328. # 使用StructureTool进行内容结构化处理
  329. structure_tool = StructureTool()
  330. structure_result = structure_tool.process_content_structure(identify_result)
  331. # 存储结构化解析结果
  332. parsing_affected = UpdateDataTool.store_parsing_result(
  333. request.requestId,
  334. {
  335. content_id: item.get('content_id') or '',
  336. task_id: item.get('task_id') or ''
  337. },
  338. structure_result
  339. )
  340. ok = affected is not None and affected > 0
  341. if ok:
  342. success_count += 1
  343. details.append({
  344. "index": idx,
  345. "dbInserted": ok,
  346. "identifyError": identify_result.get('error'),
  347. "status": 2 if ok else 3
  348. })
  349. except Exception as e:
  350. logger.error(f"处理第 {idx} 项时出错: {e}")
  351. details.append({
  352. "index": idx,
  353. "dbInserted": False,
  354. "identifyError": str(e),
  355. "status": 3
  356. })
  357. result = TriggerResponse(
  358. requestId=request.requestId,
  359. processed=len(items),
  360. success=success_count,
  361. details=details
  362. )
  363. # 更新状态为处理完成
  364. update_request_status(request.requestId, 2)
  365. logger.info(f"处理完成: requestId={request.requestId}, processed={result.processed}, success={result.success}")
  366. return result
  367. except Exception as e:
  368. logger.error(f"处理请求失败: {e}")
  369. # 处理失败,更新状态为3
  370. update_request_status(request.requestId, 3)
  371. raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
  372. @app.post("/parse/async", status_code=200)
  373. async def parse_processing_async(request: TriggerRequest, background_tasks: BackgroundTasks):
  374. """
  375. 异步解析内容处理(后台任务)
  376. - **requestId**: 请求ID,用于标识处理任务
  377. 行为:立即返回 200,并在后台继续处理任务。
  378. 若同一个 requestId 已有任务进行中,则立即返回失败(status=3)。
  379. """
  380. try:
  381. logger.info(f"收到异步解析请求: requestId={request.requestId}")
  382. # 并发防抖:同一 requestId 只允许一个在运行
  383. async with RUNNING_LOCK:
  384. if request.requestId in RUNNING_REQUESTS:
  385. return {
  386. "requestId": request.requestId,
  387. "status": 3,
  388. "message": "已有任务进行中,稍后再试",
  389. "langgraph_enabled": HAS_LANGGRAPH
  390. }
  391. RUNNING_REQUESTS.add(request.requestId)
  392. async def _background_wrapper(rid: str):
  393. try:
  394. await process_request_background(rid)
  395. finally:
  396. async with RUNNING_LOCK:
  397. RUNNING_REQUESTS.discard(rid)
  398. # 直接使用 asyncio 创建后台任务(不阻塞当前请求返回)
  399. asyncio.create_task(_background_wrapper(request.requestId))
  400. # 立即返回(不阻塞)
  401. return {
  402. "requestId": request.requestId,
  403. "status": 1,
  404. "message": "任务已进入队列并在后台处理",
  405. "langgraph_enabled": HAS_LANGGRAPH
  406. }
  407. except Exception as e:
  408. logger.error(f"提交异步任务失败: {e}")
  409. raise HTTPException(status_code=500, detail=f"提交任务失败: {str(e)}")
  410. async def process_request_background(request_id: str):
  411. """后台处理请求"""
  412. try:
  413. logger.info(f"开始后台处理: requestId={request_id}")
  414. if WORKFLOW and HAS_LANGGRAPH:
  415. # 使用 LangGraph 工作流
  416. # 更新状态为处理中
  417. update_request_status(request_id, 1)
  418. initial_state = AgentState(
  419. request_id=request_id,
  420. items=[],
  421. details=[],
  422. processed=0,
  423. success=0,
  424. current_index=0,
  425. current_item=None,
  426. identify_result=None,
  427. error=None,
  428. status="started"
  429. )
  430. final_state = WORKFLOW.invoke(
  431. initial_state,
  432. config={"configurable": {"thread_id": f"thread_{request_id}"}}
  433. )
  434. logger.info(f"LangGraph 后台处理完成: requestId={request_id}, processed={final_state.get('processed', 0)}, success={final_state.get('success', 0)}")
  435. else:
  436. # 传统模式
  437. # 更新状态为处理中
  438. update_request_status(request_id, 1)
  439. items = QueryDataTool.fetch_crawl_data_list(request_id)
  440. print(f"传统模式process_request_background---items: {items}")
  441. if not items:
  442. logger.info(f"后台处理完成: requestId={request_id}, 无数据需要处理")
  443. # 无数据需要处理,更新状态为完成
  444. update_request_status(request_id, 2)
  445. return
  446. success_count = 0
  447. for idx, item in enumerate(items, start=1):
  448. try:
  449. crawl_data = item.get('crawl_data') or {}
  450. content_id = item.get('content_id') or ''
  451. identify_result = identify_tool.run(
  452. crawl_data if isinstance(crawl_data, dict) else {}
  453. )
  454. affected = UpdateDataTool.store_indentify_result(
  455. request_id,
  456. {
  457. content_id: item.get('content_id') or '',
  458. task_id: item.get('task_id') or ''
  459. },
  460. identify_result
  461. )
  462. # 使用StructureTool进行内容结构化处理
  463. structure_tool = StructureTool()
  464. structure_result = structure_tool.process_content_structure(identify_result)
  465. # 存储结构化解析结果
  466. parsing_affected = UpdateDataTool.store_parsing_result(
  467. request_id,
  468. {
  469. content_id: item.get('content_id') or '',
  470. task_id: item.get('task_id') or ''
  471. },
  472. structure_result
  473. )
  474. if affected is not None and affected > 0:
  475. success_count += 1
  476. logger.info(f"后台处理进度: {idx}/{len(items)} - {'成功' if affected else '失败'} - 结构化{'成功' if parsing_affected else '失败'}")
  477. except Exception as e:
  478. logger.error(f"后台处理第 {idx} 项时出错: {e}")
  479. logger.info(f"传统模式后台处理完成: requestId={request_id}, processed={len(items)}, success={success_count}")
  480. # 更新状态为处理完成
  481. update_request_status(request_id, 2)
  482. except Exception as e:
  483. logger.error(f"后台处理失败: requestId={request_id}, error={e}")
  484. # 处理失败,更新状态为3
  485. update_request_status(request_id, 3)
  486. @app.post("/extract")
  487. async def extract(input: str):
  488. """
  489. 执行Agent处理用户指令
  490. Args:
  491. input: 包含用户指令的对象
  492. Returns:
  493. dict: 包含执行结果的字典
  494. """
  495. try:
  496. result = execute_agent_with_api(input)
  497. return {"status": "success", "result": result}
  498. except Exception as e:
  499. raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
  500. @app.post("/expand")
  501. async def expand(request: ExpandRequest, background_tasks: BackgroundTasks):
  502. """
  503. 执行扩展查询处理
  504. Args:
  505. request: 包含请求ID的请求体
  506. background_tasks: FastAPI 后台任务
  507. Returns:
  508. dict: 包含执行状态的字典
  509. """
  510. try:
  511. requestId = request.requestId
  512. # 立即更新状态为处理中
  513. _update_expansion_status(requestId, 1)
  514. # 添加后台任务
  515. background_tasks.add_task(execute_expand_agent_with_api, requestId)
  516. # 立即返回状态
  517. return {"status": 1, "requestId": requestId, "message": "扩展查询处理已启动"}
  518. except Exception as e:
  519. logger.error(f"启动扩展查询处理失败: requestId={requestId}, error={e}")
  520. raise HTTPException(status_code=500, detail=f"启动扩展查询处理时出错: {str(e)}")
if __name__ == "__main__":
    # Script entry point: run the development server.
    uvicorn.run(
        "agent:app",
        host="0.0.0.0",
        port=8080,
        reload=True,  # dev mode: auto-reload on code changes
        log_level="info"
    )