# agent.py (11 KB)

  import argparse
  import json
  import os
  import signal
  import subprocess
  import sys
  import threading
  import time
  from http.server import BaseHTTPRequestHandler, HTTPServer
  from typing import Any, Dict, List, Optional, Tuple
  from urllib.parse import urlparse

  # Make the modules that live next to this file (utils, agent_tools, agent_process) importable.
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))

  from utils.logging_config import get_logger
  from agent_tools import QueryDataTool, IdentifyTool, StructureTool
  # NOTE(review): start_daemon / stop_daemon / status_daemon are redefined further down in this
  # file, which shadows these imports; confirm which implementation is intended.
  from agent_process import start_daemon, stop_daemon, status_daemon
  # Optional LangGraph dependency. If it is missing (or fails to import for any
  # reason), build_langgraph_app() returns None and requests are handled by the
  # sequential ReactAgent instead.
  try:
      from langgraph.graph import StateGraph, END
  except Exception:
      HAS_LANGGRAPH = False
  else:
      HAS_LANGGRAPH = True

  logger = get_logger('Agent')

  # PID file of the background server, stored next to this script.
  PID_FILE = os.path.join(os.path.dirname(__file__), 'agent_scheduler.pid')
  class ReactAgent:
      """Sequential agent, used when the LangGraph app is unavailable.

      For every crawl record fetched for a request id, it runs ``IdentifyTool``
      on the record's ``crawl_data``. It then passes the result to
      ``StructureTool.store_parsing_result`` for persistence.
      """

      def __init__(self) -> None:
          self.identify_tool = IdentifyTool()

      def _process_item(self, request_id: str, position: int, record: Dict[str, Any]) -> Dict[str, Any]:
          """Identify and store one crawl record, and return its per-record detail entry."""
          raw_crawl = record.get('crawl_data') or {}
          # Step 1: identification. A non-dict crawl payload is replaced by an empty dict.
          identified = self.identify_tool.run(raw_crawl if isinstance(raw_crawl, dict) else {})
          # Step 2: structure the result and persist it.
          rows = StructureTool.store_parsing_result(request_id, record.get('raw') or {}, identified)
          return {
              "index": position,
              # Presumably an affected-row count. None or <= 0 counts as a failed insert.
              "dbInserted": rows is not None and rows > 0,
              "identifyError": identified.get('error'),
          }

      def handle_request(self, request_id: str) -> Dict[str, Any]:
          """Process every crawl record stored for *request_id*.

          Returns a summary dict with these keys:
          - ``requestId``
          - ``processed``: the number of records
          - ``success``: records whose store call reported more than 0
            affected rows
          - ``details``: one entry per record, with a 1-based ``index``
          """
          records = QueryDataTool.fetch_crawl_data_list(request_id)
          if not records:
              return {"requestId": request_id, "processed": 0, "success": 0, "details": []}

          details: List[Dict[str, Any]] = []
          for position, record in enumerate(records, start=1):
              details.append(self._process_item(request_id, position, record))

          succeeded = sum(1 for entry in details if entry["dbInserted"])
          return {
              "requestId": request_id,
              "processed": len(records),
              "success": succeeded,
              "details": details,
          }


  # Module-wide agent instance used by the HTTP handler's non-LangGraph fallback.
  AGENT = ReactAgent()
  # =========================
  # LangGraph-based implementation (optional)
  # =========================
  def build_langgraph_app():
      """Compile a two-node LangGraph pipeline (fetch -> process), or return None.

      Returns None when LangGraph could not be imported (HAS_LANGGRAPH is false).
      The graph state is a plain dict:
        input:  {"request_id": str}
        output: the input keys plus "items", "details", "processed", "success"
      """
      if not HAS_LANGGRAPH:
          return None

      identify_tool = IdentifyTool()

      def node_fetch(state: Dict[str, Any]) -> Dict[str, Any]:
          """Load the crawl records for the request and reset the result counters."""
          fetched = QueryDataTool.fetch_crawl_data_list(str(state.get("request_id", "")))
          return {**state, "items": fetched, "details": [], "processed": 0, "success": 0}

      def node_process(state: Dict[str, Any]) -> Dict[str, Any]:
          """Identify and store every fetched record, recording a detail entry for each."""
          rid = str(state.get("request_id", ""))
          records: List[Dict[str, Any]] = state.get("items", []) or []
          results: List[Dict[str, Any]] = []
          for position, record in enumerate(records, start=1):
              raw_crawl = record.get('crawl_data') or {}
              identified = identify_tool.run(raw_crawl if isinstance(raw_crawl, dict) else {})
              rows = StructureTool.store_parsing_result(rid, record.get('raw') or {}, identified)
              results.append({
                  "index": position,
                  "dbInserted": rows is not None and rows > 0,
                  "identifyError": identified.get('error'),
              })
          succeeded = sum(1 for entry in results if entry["dbInserted"])
          return {**state, "details": results, "processed": len(records), "success": succeeded}

      workflow = StateGraph(dict)
      for node_name, node_fn in (("fetch", node_fetch), ("process", node_process)):
          workflow.add_node(node_name, node_fn)
      workflow.set_entry_point("fetch")
      workflow.add_edge("fetch", "process")
      workflow.add_edge("process", END)
      return workflow.compile()


  # Compiled graph, or None when LangGraph is unavailable.
  APP = build_langgraph_app()
  class AgentHttpHandler(BaseHTTPRequestHandler):
      """Minimal JSON API: ``POST /trigger`` with body ``{"requestId": ...}``.

      The request is run through the compiled LangGraph app (``APP``) when one
      exists, otherwise through the sequential ``AGENT``. The reply is a JSON
      summary. Response codes:
      - 404: any other path
      - 400: malformed input
      - 500: processing raised an exception
      """

      def _set_headers(self, status_code: int = 200):
          """Send the status line and a JSON Content-Type header, then end the header block."""
          self.send_response(status_code)
          self.send_header('Content-Type', 'application/json; charset=utf-8')
          self.end_headers()

      def _send_json(self, status_code: int, obj: Any, ensure_ascii: bool = True) -> None:
          """Serialize *obj* to UTF-8 JSON first, then send it with *status_code*."""
          payload = json.dumps(obj, ensure_ascii=ensure_ascii).encode('utf-8')
          self._set_headers(status_code)
          self.wfile.write(payload)

      def do_POST(self):
          """Handle ``POST /trigger``. See the class docstring for the response codes."""
          if urlparse(self.path).path != '/trigger':
              self._send_json(404, {"error": "not found"})
              return

          # A non-numeric Content-Length must produce a 400. Letting ValueError escape
          # would drop the connection without any response.
          try:
              length = int(self.headers.get('Content-Length', '0') or '0')
          except ValueError:
              self._send_json(400, {"error": "invalid Content-Length"})
              return
          body = self.rfile.read(length) if length > 0 else b''

          try:
              payload = json.loads(body.decode('utf-8')) if body else {}
          except (ValueError, RecursionError):
              # ValueError covers both UnicodeDecodeError and JSONDecodeError.
              # RecursionError is raised for pathologically nested input.
              self._send_json(400, {"error": "invalid json"})
              return

          # Valid JSON that is not an object (list, string, number) has no .get().
          # Treat it like a missing requestId instead of crashing the handler.
          request_id = payload.get('requestId') if isinstance(payload, dict) else None
          if not request_id:
              self._send_json(400, {"error": "requestId is required"})
              return

          try:
              logger.info(f"收到触发请求: requestId={request_id}")
              if APP is not None:
                  state = APP.invoke({"request_id": str(request_id)})
                  # Normalize to the same shape that ReactAgent.handle_request returns.
                  result = {
                      "requestId": str(request_id),
                      "processed": state.get("processed", 0),
                      "success": state.get("success", 0),
                      "details": state.get("details", []),
                  }
              else:
                  # LangGraph unavailable: fall back to sequential execution.
                  result = AGENT.handle_request(str(request_id))
              # Serialize before any header is sent. A failure here must turn into a
              # clean 500, not a second status line after an already-sent 200.
              response_body = json.dumps(result, ensure_ascii=False).encode('utf-8')
          except Exception as e:
              logger.error(f"处理失败: {e}")
              self._send_json(500, {"error": str(e)})
              return

          self._set_headers(200)
          self.wfile.write(response_body)

      def log_message(self, format: str, *args) -> None:
          """Route BaseHTTPRequestHandler's request/error log lines to the module logger."""
          logger.info("HTTP " + (format % args))
  def run(host: str = '0.0.0.0', port: int = 8080):
      """Serve AgentHttpHandler on (host, port) in the foreground until SIGINT/SIGTERM.

      Installs SIGINT/SIGTERM handlers that stop the serve loop. The listening
      socket is closed on the way out. signal.signal() only works in the main
      thread, so this must be called from the main thread.
      """
      httpd = HTTPServer((host, port), AgentHttpHandler)

      def _graceful_shutdown(signum, frame):
          try:
              logger.info(f"收到信号 {signum},正在停止HTTP服务...")
              # HTTPServer.shutdown() blocks until serve_forever() has returned.
              # Python runs signal handlers on the main thread, which is the thread
              # sitting inside serve_forever(), so a direct shutdown() call here
              # deadlocks. Hand it to a helper thread. The handler then returns and
              # the serve loop sees the request and exits.
              threading.Thread(target=httpd.shutdown, name='agent-http-shutdown', daemon=True).start()
          except Exception:
              # Best effort: the signal handler itself must never raise.
              pass

      for sig in (signal.SIGINT, signal.SIGTERM):
          signal.signal(sig, _graceful_shutdown)

      logger.info(f"Agent HTTP 服务已启动: http://{host}:{port}/trigger")
      try:
          httpd.serve_forever()
      finally:
          try:
              httpd.server_close()
          except Exception:
              pass
          logger.info("Agent HTTP 服务已停止")
  def _write_pid_file(pid: int) -> None:
      """Overwrite PID_FILE with *pid* as decimal text, with no trailing newline."""
      with open(PID_FILE, 'w') as handle:
          handle.write(str(pid))
  def _read_pid_file() -> Optional[int]:
      """Return the PID stored in PID_FILE.

      Returns None if the file is absent, unreadable, empty, or does not
      contain an integer.
      """
      try:
          with open(PID_FILE, 'r') as handle:
              text = handle.read().strip()
      except Exception:
          # Covers a missing file as well as any other read failure.
          return None
      if not text:
          return None
      try:
          return int(text)
      except Exception:
          return None
  184. def _is_process_running(pid: int) -> bool:
  185. try:
  186. os.kill(pid, 0)
  187. return True
  188. except Exception:
  189. return False
  def start_daemon(host: str, port: int) -> Dict[str, Any]:
      """Launch this script with ``--serve`` as a detached background process.

      Returns one of:
      - ``{"status": "already_running", "pid": pid}`` when PID_FILE points at a
        live process
      - ``{"status": "started", "pid": pid}`` when the child is still alive
        after a short grace period
      - ``{"status": "failed", "pid": pid}`` when the child has already exited;
        the PID file is removed in that case
      """
      old_pid = _read_pid_file()
      if old_pid and _is_process_running(old_pid):
          return {"status": "already_running", "pid": old_pid}

      cmd = [sys.executable, os.path.abspath(__file__), "--serve", "--host", host, "--port", str(port)]
      proc = subprocess.Popen(
          cmd,
          stdin=subprocess.DEVNULL,
          stdout=subprocess.DEVNULL,
          stderr=subprocess.DEVNULL,
          close_fds=True,
          # Same effect as the former preexec_fn=os.setsid (run the child in its
          # own session), without executing Python code in the forked child.
          # preexec_fn is documented as unsafe when threads are present.
          start_new_session=hasattr(os, 'setsid'),
      )
      _write_pid_file(proc.pid)

      # Short wait to catch a child that exits immediately (bad port, import error, ...).
      time.sleep(0.5)
      # poll() reaps the child if it has exited. The os.kill(pid, 0) probe would
      # still succeed on the un-reaped zombie and wrongly report "started".
      running = proc.poll() is None
      if not running:
          try:
              os.remove(PID_FILE)
          except OSError:
              pass
      return {"status": "started" if running else "failed", "pid": proc.pid}
  def stop_daemon(timeout: float = 5.0) -> Dict[str, Any]:
      """Stop the daemon whose PID is recorded in PID_FILE.

      Sends SIGTERM, polls for up to *timeout* seconds, then falls back to
      SIGKILL. Returns one of:
      - ``{"status": "stopped"}`` on success
      - ``{"status": "not_running"}`` when no live process is recorded; a stale
        PID file is removed
      - ``{"status": "error", "error": msg}`` when a signal cannot be delivered
      """
      pid = _read_pid_file()
      if not pid:
          return {"status": "not_running"}
      if not _is_process_running(pid):
          try:
              os.remove(PID_FILE)
          except OSError:
              pass
          return {"status": "not_running"}

      try:
          os.kill(pid, signal.SIGTERM)
      except ProcessLookupError:
          # The process exited between the liveness check and the kill. That is
          # the outcome we wanted, so continue to cleanup instead of reporting an error.
          pass
      except Exception as e:
          return {"status": "error", "error": str(e)}

      # monotonic() is unaffected by wall-clock adjustments during the wait.
      deadline = time.monotonic() + timeout
      while _is_process_running(pid) and time.monotonic() < deadline:
          time.sleep(0.2)

      if _is_process_running(pid):
          try:
              os.kill(pid, signal.SIGKILL)
          except ProcessLookupError:
              pass  # It exited on its own just before the SIGKILL.
          except Exception as e:
              return {"status": "error", "error": str(e)}

      try:
          os.remove(PID_FILE)
      except OSError:
          pass
      return {"status": "stopped"}
  def status_daemon() -> Dict[str, Any]:
      """Report whether the PID recorded in PID_FILE belongs to a live process."""
      pid = _read_pid_file()
      if not pid or not _is_process_running(pid):
          return {"status": "not_running"}
      return {"status": "running", "pid": pid}
  if __name__ == '__main__':
      cli = argparse.ArgumentParser(description='Agent 服务管理')
      cli.add_argument('--serve', action='store_true', help='以前台模式启动HTTP服务')
      cli.add_argument('--host', default='0.0.0.0', help='监听地址')
      cli.add_argument('--port', type=int, default=8080, help='监听端口')
      cli.add_argument('command', nargs='?', choices=['start', 'stop', 'status'], help='守护进程管理命令')
      opts = cli.parse_args()

      # --serve runs the HTTP server in the foreground. start_daemon also launches
      # the background child this way.
      if opts.serve:
          run(opts.host, opts.port)
          sys.exit(0)

      # Maps each command to (action, statuses that yield exit code 0).
      commands = {
          'start': (lambda: start_daemon(opts.host, opts.port), ('started',)),
          'stop': (stop_daemon, ('stopped', 'not_running')),
          'status': (status_daemon, ('running', 'not_running')),
      }
      if opts.command in commands:
          action, ok_statuses = commands[opts.command]
          outcome = action()
          print(json.dumps(outcome, ensure_ascii=False))
          sys.exit(0 if outcome.get('status') in ok_statuses else 1)
      else:
          # No command: run in the foreground, kept for compatibility with the older usage.
          run(opts.host, opts.port)