|
@@ -1,335 +1,230 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+使用 FastAPI 重构的 Agent 服务
|
|
|
+提供现代化的 HTTP API 接口
|
|
|
+"""
|
|
|
+
|
|
|
import json
|
|
|
import sys
|
|
|
import os
|
|
|
-import argparse
|
|
|
-import signal
|
|
|
-from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
-from urllib.parse import urlparse
|
|
|
-from typing import Any, Dict, List, Optional, Tuple
|
|
|
+import time
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
+from contextlib import asynccontextmanager
|
|
|
|
|
|
# 保证可以导入本项目模块
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
|
+from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
+from pydantic import BaseModel, Field
|
|
|
+import uvicorn
|
|
|
+
|
|
|
from utils.logging_config import get_logger
|
|
|
from agent_tools import QueryDataTool, IdentifyTool, StructureTool
|
|
|
-from agent_process import start_daemon, stop_daemon, status_daemon
|
|
|
-
|
|
|
-# 可选引入 LangGraph(如未安装,将在运行时优雅回退到顺序执行)
|
|
|
-HAS_LANGGRAPH = False
|
|
|
-try:
|
|
|
- from langgraph.graph import StateGraph, END
|
|
|
- HAS_LANGGRAPH = True
|
|
|
-except Exception:
|
|
|
- HAS_LANGGRAPH = False
|
|
|
|
|
|
+# 创建 logger
|
|
|
+logger = get_logger('AgentFastAPI')
|
|
|
|
|
|
-logger = get_logger('Agent')
|
|
|
+# 请求模型
|
|
|
+class TriggerRequest(BaseModel):
|
|
|
+ requestId: str = Field(..., description="请求ID")
|
|
|
|
|
|
-PID_FILE = os.path.join(os.path.dirname(__file__), 'agent_scheduler.pid')
|
|
|
+# 响应模型
|
|
|
+class TriggerResponse(BaseModel):
|
|
|
+ requestId: str
|
|
|
+ processed: int
|
|
|
+ success: int
|
|
|
+ details: List[Dict[str, Any]]
|
|
|
|
|
|
+# 全局变量
|
|
|
+identify_tool = None
|
|
|
|
|
|
-class ReactAgent:
|
|
|
- def __init__(self) -> None:
|
|
|
- self.identify_tool = IdentifyTool()
|
|
|
-
|
|
|
- def handle_request(self, request_id: str) -> Dict[str, Any]:
|
|
|
- items = QueryDataTool.fetch_crawl_data_list(request_id)
|
|
|
+@asynccontextmanager
|
|
|
+async def lifespan(app: FastAPI):
|
|
|
+ """应用生命周期管理"""
|
|
|
+ # 启动时初始化
|
|
|
+ global identify_tool
|
|
|
+ identify_tool = IdentifyTool()
|
|
|
+ logger.info("Agent 服务启动完成")
|
|
|
+
|
|
|
+ yield
|
|
|
+
|
|
|
+ # 关闭时清理
|
|
|
+ logger.info("Agent 服务正在关闭")
|
|
|
+
|
|
|
+# 创建 FastAPI 应用
|
|
|
+app = FastAPI(
|
|
|
+ title="Knowledge Agent API",
|
|
|
+ description="智能内容识别和结构化处理服务",
|
|
|
+ version="1.0.0",
|
|
|
+ lifespan=lifespan
|
|
|
+)
|
|
|
+
|
|
|
+@app.get("/")
|
|
|
+async def root():
|
|
|
+ """根路径,返回服务信息"""
|
|
|
+ return {
|
|
|
+ "service": "Knowledge Agent API",
|
|
|
+ "version": "1.0.0",
|
|
|
+ "status": "running",
|
|
|
+ "endpoints": {
|
|
|
+ "trigger": "/trigger",
|
|
|
+ "health": "/health",
|
|
|
+ "docs": "/docs"
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+@app.get("/health")
|
|
|
+async def health_check():
|
|
|
+ """健康检查接口"""
|
|
|
+ return {"status": "healthy", "timestamp": time.time()}
|
|
|
+
|
|
|
+@app.post("/trigger", response_model=TriggerResponse)
|
|
|
+async def trigger_processing(request: TriggerRequest):
|
|
|
+ """
|
|
|
+ 触发内容处理
|
|
|
+
|
|
|
+ - **requestId**: 请求ID,用于标识处理任务
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ logger.info(f"收到触发请求: requestId={request.requestId}")
|
|
|
+
|
|
|
+ # 获取待处理数据
|
|
|
+ items = QueryDataTool.fetch_crawl_data_list(request.requestId)
|
|
|
if not items:
|
|
|
- return {"requestId": request_id, "processed": 0, "success": 0, "details": []}
|
|
|
-
|
|
|
+ return TriggerResponse(
|
|
|
+ requestId=request.requestId,
|
|
|
+ processed=0,
|
|
|
+ success=0,
|
|
|
+ details=[]
|
|
|
+ )
|
|
|
+
|
|
|
+ # 处理数据
|
|
|
success_count = 0
|
|
|
details: List[Dict[str, Any]] = []
|
|
|
+
|
|
|
for idx, item in enumerate(items, start=1):
|
|
|
- crawl_data = item.get('crawl_data') or {}
|
|
|
-
|
|
|
- # Step 1: 识别
|
|
|
- identify_result = self.identify_tool.run(crawl_data if isinstance(crawl_data, dict) else {})
|
|
|
-
|
|
|
- # Step 2: 结构化并入库
|
|
|
- affected = StructureTool.store_parsing_result(request_id, item.get('raw') or {}, identify_result)
|
|
|
- ok = affected is not None and affected > 0
|
|
|
- if ok:
|
|
|
- success_count += 1
|
|
|
-
|
|
|
- details.append({
|
|
|
- "index": idx,
|
|
|
- "dbInserted": ok,
|
|
|
- "identifyError": identify_result.get('error'),
|
|
|
- })
|
|
|
-
|
|
|
+ try:
|
|
|
+ crawl_data = item.get('crawl_data') or {}
|
|
|
+
|
|
|
+ # Step 1: 识别
|
|
|
+ identify_result = identify_tool.run(
|
|
|
+ crawl_data if isinstance(crawl_data, dict) else {}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 2: 结构化并入库
|
|
|
+ affected = StructureTool.store_parsing_result(
|
|
|
+ request.requestId,
|
|
|
+ item.get('raw') or {},
|
|
|
+ identify_result
|
|
|
+ )
|
|
|
+
|
|
|
+ ok = affected is not None and affected > 0
|
|
|
+ if ok:
|
|
|
+ success_count += 1
|
|
|
+
|
|
|
+ details.append({
|
|
|
+ "index": idx,
|
|
|
+ "dbInserted": ok,
|
|
|
+ "identifyError": identify_result.get('error'),
|
|
|
+ "status": "success" if ok else "failed"
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"处理第 {idx} 项时出错: {e}")
|
|
|
+ details.append({
|
|
|
+ "index": idx,
|
|
|
+ "dbInserted": False,
|
|
|
+ "identifyError": str(e),
|
|
|
+ "status": "error"
|
|
|
+ })
|
|
|
+
|
|
|
+ result = TriggerResponse(
|
|
|
+ requestId=request.requestId,
|
|
|
+ processed=len(items),
|
|
|
+ success=success_count,
|
|
|
+ details=details
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"处理完成: requestId={request.requestId}, processed={len(items)}, success={success_count}")
|
|
|
+ return result
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"处理请求失败: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
|
|
+
|
|
|
+@app.post("/trigger/async")
|
|
|
+async def trigger_processing_async(request: TriggerRequest, background_tasks: BackgroundTasks):
|
|
|
+ """
|
|
|
+ 异步触发内容处理(后台任务)
|
|
|
+
|
|
|
+ - **requestId**: 请求ID,用于标识处理任务
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ logger.info(f"收到异步触发请求: requestId={request.requestId}")
|
|
|
+
|
|
|
+ # 添加后台任务
|
|
|
+ background_tasks.add_task(process_request_background, request.requestId)
|
|
|
+
|
|
|
return {
|
|
|
- "requestId": request_id,
|
|
|
- "processed": len(items),
|
|
|
- "success": success_count,
|
|
|
- "details": details,
|
|
|
+ "requestId": request.requestId,
|
|
|
+ "status": "processing",
|
|
|
+ "message": "任务已提交到后台处理"
|
|
|
}
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"提交异步任务失败: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail=f"提交任务失败: {str(e)}")
|
|
|
|
|
|
-
|
|
|
-AGENT = ReactAgent()
|
|
|
-
|
|
|
-
|
|
|
-# =========================
|
|
|
-# LangGraph 风格实现(可选)
|
|
|
-# =========================
|
|
|
-def build_langgraph_app():
|
|
|
- if not HAS_LANGGRAPH:
|
|
|
- return None
|
|
|
-
|
|
|
- # 状态:以 dict 形式承载
|
|
|
- # 输入: {"request_id": str}
|
|
|
- # 输出附加: items, details, processed, success
|
|
|
-
|
|
|
- def node_fetch(state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
- request_id = str(state.get("request_id", ""))
|
|
|
+async def process_request_background(request_id: str):
|
|
|
+ """后台处理请求"""
|
|
|
+ try:
|
|
|
+ logger.info(f"开始后台处理: requestId={request_id}")
|
|
|
+
|
|
|
+ # 获取待处理数据
|
|
|
items = QueryDataTool.fetch_crawl_data_list(request_id)
|
|
|
- return {
|
|
|
- **state,
|
|
|
- "items": items,
|
|
|
- "details": [],
|
|
|
- "processed": 0,
|
|
|
- "success": 0,
|
|
|
- }
|
|
|
-
|
|
|
- identify_tool = IdentifyTool()
|
|
|
+ if not items:
|
|
|
+ logger.info(f"后台处理完成: requestId={request_id}, 无数据需要处理")
|
|
|
+ return
|
|
|
|
|
|
- def node_process(state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
- request_id = str(state.get("request_id", ""))
|
|
|
- items: List[Dict[str, Any]] = state.get("items", []) or []
|
|
|
- details: List[Dict[str, Any]] = []
|
|
|
+ # 处理数据
|
|
|
success_count = 0
|
|
|
-
|
|
|
for idx, item in enumerate(items, start=1):
|
|
|
- crawl_data = item.get('crawl_data') or {}
|
|
|
- identify_result = identify_tool.run(crawl_data if isinstance(crawl_data, dict) else {})
|
|
|
- affected = StructureTool.store_parsing_result(request_id, item.get('raw') or {}, identify_result)
|
|
|
- ok = affected is not None and affected > 0
|
|
|
- if ok:
|
|
|
- success_count += 1
|
|
|
- details.append({
|
|
|
- "index": idx,
|
|
|
- "dbInserted": ok,
|
|
|
- "identifyError": identify_result.get('error'),
|
|
|
- })
|
|
|
-
|
|
|
- return {
|
|
|
- **state,
|
|
|
- "details": details,
|
|
|
- "processed": len(items),
|
|
|
- "success": success_count,
|
|
|
- }
|
|
|
-
|
|
|
- graph = StateGraph(dict)
|
|
|
- graph.add_node("fetch", node_fetch)
|
|
|
- graph.add_node("process", node_process)
|
|
|
-
|
|
|
- graph.set_entry_point("fetch")
|
|
|
- graph.add_edge("fetch", "process")
|
|
|
- graph.add_edge("process", END)
|
|
|
-
|
|
|
- return graph.compile()
|
|
|
-
|
|
|
-
|
|
|
-APP = build_langgraph_app()
|
|
|
-
|
|
|
-
|
|
|
-class AgentHttpHandler(BaseHTTPRequestHandler):
|
|
|
- def _set_headers(self, status_code: int = 200):
|
|
|
- self.send_response(status_code)
|
|
|
- self.send_header('Content-Type', 'application/json; charset=utf-8')
|
|
|
- self.end_headers()
|
|
|
-
|
|
|
- def do_POST(self):
|
|
|
- parsed = urlparse(self.path)
|
|
|
- if parsed.path != '/trigger':
|
|
|
- self._set_headers(404)
|
|
|
- self.wfile.write(json.dumps({"error": "not found"}).encode('utf-8'))
|
|
|
- return
|
|
|
-
|
|
|
- length = int(self.headers.get('Content-Length', '0') or '0')
|
|
|
- body = self.rfile.read(length) if length > 0 else b''
|
|
|
- try:
|
|
|
- payload = json.loads(body.decode('utf-8')) if body else {}
|
|
|
- except Exception:
|
|
|
- self._set_headers(400)
|
|
|
- self.wfile.write(json.dumps({"error": "invalid json"}).encode('utf-8'))
|
|
|
- return
|
|
|
-
|
|
|
- request_id = (payload or {}).get('requestId')
|
|
|
- if not request_id:
|
|
|
- self._set_headers(400)
|
|
|
- self.wfile.write(json.dumps({"error": "requestId is required"}).encode('utf-8'))
|
|
|
- return
|
|
|
-
|
|
|
- try:
|
|
|
- logger.info(f"收到触发请求: requestId={request_id}")
|
|
|
- if APP is not None:
|
|
|
- result = APP.invoke({"request_id": str(request_id)})
|
|
|
- # 标准化返回
|
|
|
- result = {
|
|
|
- "requestId": str(request_id),
|
|
|
- "processed": result.get("processed", 0),
|
|
|
- "success": result.get("success", 0),
|
|
|
- "details": result.get("details", []),
|
|
|
- }
|
|
|
- else:
|
|
|
- # 回退到顺序执行
|
|
|
- result = AGENT.handle_request(str(request_id))
|
|
|
- self._set_headers(200)
|
|
|
- self.wfile.write(json.dumps(result, ensure_ascii=False).encode('utf-8'))
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"处理失败: {e}")
|
|
|
- self._set_headers(500)
|
|
|
- self.wfile.write(json.dumps({"error": str(e)}).encode('utf-8'))
|
|
|
-
|
|
|
- def log_message(self, format: str, *args) -> None:
|
|
|
- # 重定向默认日志到我们统一的 logger
|
|
|
- logger.info("HTTP " + (format % args))
|
|
|
-
|
|
|
-
|
|
|
-def run(host: str = '0.0.0.0', port: int = 8080):
|
|
|
- server_address = (host, port)
|
|
|
- httpd = HTTPServer(server_address, AgentHttpHandler)
|
|
|
-
|
|
|
- def _graceful_shutdown(signum, frame):
|
|
|
- try:
|
|
|
- logger.info(f"收到信号 {signum},正在停止HTTP服务...")
|
|
|
- # shutdown 会在其他线程优雅停止; 这里我们直接关闭,避免阻塞
|
|
|
- httpd.shutdown()
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- for sig in (signal.SIGINT, signal.SIGTERM):
|
|
|
- signal.signal(sig, _graceful_shutdown)
|
|
|
-
|
|
|
- logger.info(f"Agent HTTP 服务已启动: http://{host}:{port}/trigger")
|
|
|
- try:
|
|
|
- httpd.serve_forever()
|
|
|
- finally:
|
|
|
- try:
|
|
|
- httpd.server_close()
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- logger.info("Agent HTTP 服务已停止")
|
|
|
-
|
|
|
-
|
|
|
-def _write_pid_file(pid: int) -> None:
|
|
|
- with open(PID_FILE, 'w') as f:
|
|
|
- f.write(str(pid))
|
|
|
-
|
|
|
-
|
|
|
-def _read_pid_file() -> Optional[int]:
|
|
|
- if not os.path.exists(PID_FILE):
|
|
|
- return None
|
|
|
- try:
|
|
|
- with open(PID_FILE, 'r') as f:
|
|
|
- content = f.read().strip()
|
|
|
- return int(content) if content else None
|
|
|
- except Exception:
|
|
|
- return None
|
|
|
-
|
|
|
-
|
|
|
-def _is_process_running(pid: int) -> bool:
|
|
|
- try:
|
|
|
- os.kill(pid, 0)
|
|
|
- return True
|
|
|
- except Exception:
|
|
|
- return False
|
|
|
-
|
|
|
-
|
|
|
-def start_daemon(host: str, port: int) -> Dict[str, Any]:
|
|
|
- old_pid = _read_pid_file()
|
|
|
- if old_pid and _is_process_running(old_pid):
|
|
|
- return {"status": "already_running", "pid": old_pid}
|
|
|
-
|
|
|
- python_exec = sys.executable
|
|
|
- script_path = os.path.abspath(__file__)
|
|
|
- args = [python_exec, script_path, "--serve", "--host", host, "--port", str(port)]
|
|
|
-
|
|
|
- with open(os.devnull, 'wb') as devnull:
|
|
|
- proc = subprocess.Popen(
|
|
|
- args,
|
|
|
- stdout=devnull,
|
|
|
- stderr=devnull,
|
|
|
- stdin=devnull,
|
|
|
- close_fds=True,
|
|
|
- preexec_fn=os.setsid if hasattr(os, 'setsid') else None,
|
|
|
- )
|
|
|
-
|
|
|
- _write_pid_file(proc.pid)
|
|
|
- # 简单等待,确认进程未立即退出
|
|
|
- time.sleep(0.5)
|
|
|
- running = _is_process_running(proc.pid)
|
|
|
- return {"status": "started" if running else "failed", "pid": proc.pid}
|
|
|
-
|
|
|
-
|
|
|
-def stop_daemon(timeout: float = 5.0) -> Dict[str, Any]:
|
|
|
- pid = _read_pid_file()
|
|
|
- if not pid:
|
|
|
- return {"status": "not_running"}
|
|
|
-
|
|
|
- if not _is_process_running(pid):
|
|
|
- try:
|
|
|
- os.remove(PID_FILE)
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- return {"status": "not_running"}
|
|
|
-
|
|
|
- try:
|
|
|
- os.kill(pid, signal.SIGTERM)
|
|
|
+ try:
|
|
|
+ crawl_data = item.get('crawl_data') or {}
|
|
|
+
|
|
|
+ # Step 1: 识别
|
|
|
+ identify_result = identify_tool.run(
|
|
|
+ crawl_data if isinstance(crawl_data, dict) else {}
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 2: 结构化并入库
|
|
|
+ affected = StructureTool.store_parsing_result(
|
|
|
+ request_id,
|
|
|
+ item.get('raw') or {},
|
|
|
+ identify_result
|
|
|
+ )
|
|
|
+
|
|
|
+ if affected is not None and affected > 0:
|
|
|
+ success_count += 1
|
|
|
+
|
|
|
+ logger.info(f"后台处理进度: {idx}/{len(items)} - {'成功' if affected is not None and affected > 0 else '失败'}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"后台处理第 {idx} 项时出错: {e}")
|
|
|
+
|
|
|
+ logger.info(f"后台处理完成: requestId={request_id}, processed={len(items)}, success={success_count}")
|
|
|
+
|
|
|
except Exception as e:
|
|
|
- return {"status": "error", "error": str(e)}
|
|
|
-
|
|
|
- start_time = time.time()
|
|
|
- while time.time() - start_time < timeout:
|
|
|
- if not _is_process_running(pid):
|
|
|
- break
|
|
|
- time.sleep(0.2)
|
|
|
-
|
|
|
- if _is_process_running(pid):
|
|
|
- try:
|
|
|
- os.kill(pid, signal.SIGKILL)
|
|
|
- except Exception as e:
|
|
|
- return {"status": "error", "error": str(e)}
|
|
|
-
|
|
|
- try:
|
|
|
- os.remove(PID_FILE)
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
-
|
|
|
- return {"status": "stopped"}
|
|
|
-
|
|
|
-
|
|
|
-def status_daemon() -> Dict[str, Any]:
|
|
|
- pid = _read_pid_file()
|
|
|
- if pid and _is_process_running(pid):
|
|
|
- return {"status": "running", "pid": pid}
|
|
|
- return {"status": "not_running"}
|
|
|
-
|
|
|
-
|
|
|
-if __name__ == '__main__':
|
|
|
- parser = argparse.ArgumentParser(description='Agent 服务管理')
|
|
|
- parser.add_argument('--serve', action='store_true', help='以前台模式启动HTTP服务')
|
|
|
- parser.add_argument('--host', default='0.0.0.0', help='监听地址')
|
|
|
- parser.add_argument('--port', type=int, default=8080, help='监听端口')
|
|
|
- parser.add_argument('command', nargs='?', choices=['start', 'stop', 'status'], help='守护进程管理命令')
|
|
|
- args = parser.parse_args()
|
|
|
-
|
|
|
- if args.serve:
|
|
|
- run(args.host, args.port)
|
|
|
- sys.exit(0)
|
|
|
-
|
|
|
- if args.command == 'start':
|
|
|
- res = start_daemon(args.host, args.port)
|
|
|
- print(json.dumps(res, ensure_ascii=False))
|
|
|
- sys.exit(0 if res.get('status') == 'started' else 1)
|
|
|
- elif args.command == 'stop':
|
|
|
- res = stop_daemon()
|
|
|
- print(json.dumps(res, ensure_ascii=False))
|
|
|
- sys.exit(0 if res.get('status') in ('stopped', 'not_running') else 1)
|
|
|
- elif args.command == 'status':
|
|
|
- res = status_daemon()
|
|
|
- print(json.dumps(res, ensure_ascii=False))
|
|
|
- sys.exit(0 if res.get('status') in ('running', 'not_running') else 1)
|
|
|
- else:
|
|
|
- # 默认行为:以前台启动(兼容旧用法)
|
|
|
- run(args.host, args.port)
|
|
|
-
|
|
|
+ logger.error(f"后台处理失败: requestId={request_id}, error={e}")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # 启动服务
|
|
|
+ uvicorn.run(
|
|
|
+ "agent:app",
|
|
|
+ host="0.0.0.0",
|
|
|
+ port=8080,
|
|
|
+ reload=True, # 开发模式,自动重载
|
|
|
+ log_level="info"
|
|
|
+ )
|