123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
import argparse
import json
import os
import signal
import subprocess
import sys
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

# Make this project's own modules importable no matter where the script is launched from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from utils.logging_config import get_logger
from agent_tools import QueryDataTool, IdentifyTool, StructureTool
# NOTE(review): start_daemon / stop_daemon / status_daemon are redefined later
# in this file, so these imported versions are shadowed and never used; the
# import still fails at startup if agent_process is missing.
from agent_process import start_daemon, stop_daemon, status_daemon
# Optional LangGraph support. If it is not installed (or fails to import for
# any reason), requests fall back to the sequential ReactAgent at runtime.
HAS_LANGGRAPH = False
try:
    from langgraph.graph import StateGraph, END
    HAS_LANGGRAPH = True
except Exception:
    # Deliberately broad: any import-time failure just disables LangGraph.
    HAS_LANGGRAPH = False

logger = get_logger('Agent')

# Resolve this script's directory to an absolute path (the same way the
# sys.path setup does) so the PID file location never depends on the current
# working directory when __file__ happens to be relative.
PID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'agent_scheduler.pid')
class ReactAgent:
    """Sequential (non-LangGraph) pipeline: fetch -> identify -> store, per item."""

    def __init__(self) -> None:
        # One identifier instance is reused for every item of every request.
        self.identify_tool = IdentifyTool()

    def _process_item(self, request_id: str, index: int, item: Dict[str, Any]) -> Dict[str, Any]:
        """Identify a single crawl item, persist the result and describe the outcome."""
        crawl_data = item.get('crawl_data') or {}
        # Only dict payloads reach the identifier; anything else is replaced by {}.
        identified = self.identify_tool.run(crawl_data if isinstance(crawl_data, dict) else {})
        affected = StructureTool.store_parsing_result(request_id, item.get('raw') or {}, identified)
        return {
            "index": index,
            # presumably an affected-row count; None or 0 counts as "not inserted"
            "dbInserted": affected is not None and affected > 0,
            "identifyError": identified.get('error'),
        }

    def handle_request(self, request_id: str) -> Dict[str, Any]:
        """Run the pipeline for every crawl item belonging to ``request_id``.

        Returns a summary with ``requestId``, ``processed`` (number of items
        fetched), ``success`` (items whose store call reported rows affected)
        and ``details`` (one entry per item: 1-based ``index``, ``dbInserted``,
        ``identifyError``).
        """
        items = QueryDataTool.fetch_crawl_data_list(request_id)
        if not items:
            return {"requestId": request_id, "processed": 0, "success": 0, "details": []}

        details: List[Dict[str, Any]] = [
            self._process_item(request_id, index, item)
            for index, item in enumerate(items, start=1)
        ]
        return {
            "requestId": request_id,
            "processed": len(items),
            "success": sum(1 for entry in details if entry["dbInserted"]),
            "details": details,
        }


AGENT = ReactAgent()
# ==================================================
# LangGraph-based implementation (optional)
# ==================================================
def build_langgraph_app():
    """Compile the two-node ``fetch -> process`` LangGraph pipeline, or return None.

    Returns None when LangGraph could not be imported. The state is a plain
    dict: the caller supplies ``request_id``; ``fetch`` adds ``items`` and
    zeroes ``details``/``processed``/``success``; ``process`` fills those in.
    Each node returns the complete state with its own keys merged over it.
    """
    if not HAS_LANGGRAPH:
        return None

    def node_fetch(state: Dict[str, Any]) -> Dict[str, Any]:
        rid = str(state.get("request_id", ""))
        fetched = QueryDataTool.fetch_crawl_data_list(rid)
        return {**state, "items": fetched, "details": [], "processed": 0, "success": 0}

    identifier = IdentifyTool()

    def node_process(state: Dict[str, Any]) -> Dict[str, Any]:
        rid = str(state.get("request_id", ""))
        batch: List[Dict[str, Any]] = state.get("items", []) or []
        outcomes: List[Dict[str, Any]] = []
        for position, entry in enumerate(batch, start=1):
            payload = entry.get('crawl_data') or {}
            # Non-dict crawl payloads are replaced by {} before identification.
            identified = identifier.run(payload if isinstance(payload, dict) else {})
            affected = StructureTool.store_parsing_result(rid, entry.get('raw') or {}, identified)
            outcomes.append({
                "index": position,
                "dbInserted": affected is not None and affected > 0,
                "identifyError": identified.get('error'),
            })
        inserted = sum(1 for outcome in outcomes if outcome["dbInserted"])
        return {**state, "details": outcomes, "processed": len(batch), "success": inserted}

    workflow = StateGraph(dict)
    workflow.add_node("fetch", node_fetch)
    workflow.add_node("process", node_process)
    workflow.set_entry_point("fetch")
    workflow.add_edge("fetch", "process")
    workflow.add_edge("process", END)
    return workflow.compile()


# Built once at import time; None means requests use the sequential agent instead.
APP = build_langgraph_app()
class AgentHttpHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing ``POST /trigger`` to run the agent for one request.

    The body must be a JSON object containing ``requestId``. On success the
    processing summary is returned as JSON with status 200; otherwise the
    response is ``{"error": ...}`` with status 400, 404 or 500.
    """

    def _set_headers(self, status_code: int = 200):
        """Send the status line and a JSON Content-Type header, then end headers.

        Kept for backward compatibility; responses below go through
        ``_send_json_bytes`` so they also carry a Content-Length.
        """
        self.send_response(status_code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.end_headers()

    def _send_json_bytes(self, status_code: int, body: bytes) -> None:
        """Send a complete response whose JSON body is already serialized."""
        self.send_response(status_code)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_error_json(self, status_code: int, message: str) -> None:
        """Send ``{"error": message}`` with the given status code."""
        self._send_json_bytes(status_code, json.dumps({"error": message}).encode('utf-8'))

    def do_POST(self):
        parsed = urlparse(self.path)
        if parsed.path != '/trigger':
            self._send_error_json(404, "not found")
            return

        # A malformed Content-Length used to raise ValueError out of the
        # handler, dropping the connection without any response.
        try:
            length = int(self.headers.get('Content-Length', '0') or '0')
        except ValueError:
            self._send_error_json(400, "invalid content-length")
            return
        body = self.rfile.read(length) if length > 0 else b''

        try:
            payload = json.loads(body.decode('utf-8')) if body else {}
        except Exception:
            self._send_error_json(400, "invalid json")
            return

        # Falsy JSON values (null, [], 0, "") behave like an empty object, as
        # before; any other non-object (e.g. [1, 2]) has no .get() and is rejected.
        payload = payload or {}
        if not isinstance(payload, dict):
            self._send_error_json(400, "request body must be a JSON object")
            return

        request_id = payload.get('requestId')
        if not request_id:
            self._send_error_json(400, "requestId is required")
            return
        request_id = str(request_id)

        try:
            logger.info(f"收到触发请求: requestId={request_id}")
            if APP is not None:
                state = APP.invoke({"request_id": request_id})
                # Normalize the graph state to the public response shape.
                result = {
                    "requestId": request_id,
                    "processed": state.get("processed", 0),
                    "success": state.get("success", 0),
                    "details": state.get("details", []),
                }
            else:
                # LangGraph unavailable: fall back to sequential execution.
                result = AGENT.handle_request(request_id)
            # Serialize before any header is sent so a failure here can still
            # be reported as a clean 500 instead of a second status line.
            response_body = json.dumps(result, ensure_ascii=False).encode('utf-8')
        except Exception as e:
            logger.exception(f"处理失败: {e}")
            self._send_error_json(500, str(e))
            return

        self._send_json_bytes(200, response_body)

    def log_message(self, format: str, *args) -> None:
        # Route the base class's access log through the project logger.
        # (Parameter name `format` matches BaseHTTPRequestHandler's signature.)
        logger.info("HTTP " + (format % args))
def run(host: str = '0.0.0.0', port: int = 8080):
    """Serve ``POST /trigger`` on ``host:port`` in the foreground.

    Blocks until SIGINT or SIGTERM is received, then stops the serve loop,
    closes the listening socket and returns.
    """
    httpd = HTTPServer((host, port), AgentHttpHandler)

    def _graceful_shutdown(signum, frame):
        logger.info(f"收到信号 {signum},正在停止HTTP服务...")
        # HTTPServer.shutdown() waits for serve_forever() to return, but
        # serve_forever() runs in this same main thread -- the one executing
        # this signal handler -- so calling it directly here deadlocks.
        # Request the shutdown from a helper thread and return immediately.
        try:
            threading.Thread(
                target=httpd.shutdown, name='agent-http-shutdown', daemon=True
            ).start()
        except Exception:
            logger.exception("failed to request HTTP server shutdown")

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _graceful_shutdown)

    logger.info(f"Agent HTTP 服务已启动: http://{host}:{port}/trigger")
    try:
        httpd.serve_forever()
    finally:
        try:
            httpd.server_close()
        except Exception:
            pass  # best-effort cleanup; the server is stopping regardless
        logger.info("Agent HTTP 服务已停止")
def _write_pid_file(pid: int, path: Optional[str] = None) -> None:
    """Record ``pid`` in the PID file (``PID_FILE`` unless ``path`` is given).

    The PID is written to a sibling ``.tmp`` file and then atomically renamed
    into place, so a concurrent reader never sees an empty or partial file.
    """
    target = PID_FILE if path is None else path
    tmp_path = f"{target}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(str(pid))
    os.replace(tmp_path, target)
def _read_pid_file(path: Optional[str] = None) -> Optional[int]:
    """Return the PID stored in the PID file, or None if absent or invalid.

    ``path`` defaults to ``PID_FILE``. Non-positive values are rejected:
    ``os.kill`` interprets 0 and negative PIDs as process-group / "every
    process" targets, so handing one to ``stop_daemon`` could signal far more
    than the daemon.
    """
    target = PID_FILE if path is None else path
    try:
        with open(target, 'r', encoding='utf-8') as f:
            content = f.read().strip()
    except (OSError, ValueError):  # missing/unreadable file, or undecodable bytes
        return None
    try:
        pid = int(content)
    except ValueError:  # empty or garbled content
        return None
    return pid if pid > 0 else None
def _is_process_running(pid: int) -> bool:
    """Return True if a process with this PID currently exists.

    Uses ``os.kill(pid, 0)``, which on POSIX performs the existence and
    permission checks without delivering a signal.
    """
    # NOTE(review): signal 0 is only a harmless probe on POSIX; on Windows
    # os.kill() gives it different semantics, so this assumes a POSIX host.
    if pid <= 0:
        # 0 and negative PIDs address process groups (or every process), not
        # a single process, so they can never identify the daemon.
        return False
    try:
        os.kill(pid, 0)
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    except (OSError, OverflowError):  # ProcessLookupError, or a PID out of range
        return False
    return True
def start_daemon(host: str, port: int, startup_wait: float = 0.5) -> Dict[str, Any]:
    """Launch this script with ``--serve`` as a detached background process.

    Returns ``{"status": "already_running", "pid": ...}`` when the PID file
    points at a live process. Otherwise it spawns the server and records its
    PID. After ``startup_wait`` seconds it reports ``"started"`` if the child
    is still alive, or ``"failed"`` if it is not.
    """
    old_pid = _read_pid_file()
    if old_pid and _is_process_running(old_pid):
        return {"status": "already_running", "pid": old_pid}

    args = [sys.executable, os.path.abspath(__file__), "--serve", "--host", host, "--port", str(port)]
    proc = subprocess.Popen(
        args,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        close_fds=True,
        # Runs setsid() in the child (detaching it from our session) without
        # preexec_fn, which is unsafe when threads exist; ignored where
        # setsid is unavailable.
        start_new_session=True,
    )
    _write_pid_file(proc.pid)

    # Give the child a moment to fail fast (port in use, import error, ...).
    time.sleep(startup_wait)
    # poll() reaps the child if it already exited. os.kill(pid, 0) would still
    # succeed on the unreaped zombie and misreport a crashed start as "started".
    if proc.poll() is None:
        return {"status": "started", "pid": proc.pid}

    # The child died: drop the now-stale PID file.
    try:
        os.remove(PID_FILE)
    except OSError:
        pass
    return {"status": "failed", "pid": proc.pid}
def _wait_for_exit(pid: int, timeout: float, interval: float = 0.2) -> bool:
    """Poll until ``pid`` is gone or ``timeout`` elapses; True if it exited."""
    deadline = time.monotonic() + timeout
    while _is_process_running(pid):
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
    return True


def _remove_pid_file() -> None:
    """Delete the PID file, tolerating it already being gone."""
    try:
        os.remove(PID_FILE)
    except OSError:
        pass


def stop_daemon(timeout: float = 5.0) -> Dict[str, Any]:
    """Stop the background server whose PID is recorded in the PID file.

    Sends SIGTERM and waits up to ``timeout`` seconds, then falls back to
    SIGKILL. Returns ``{"status": "stopped"}``, ``{"status": "not_running"}``
    or ``{"status": "error", "error": ...}``.
    """
    pid = _read_pid_file()
    if not pid:
        return {"status": "not_running"}
    if not _is_process_running(pid):
        _remove_pid_file()
        return {"status": "not_running"}

    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        # It exited between the liveness check and the signal.
        _remove_pid_file()
        return {"status": "stopped"}
    except Exception as e:
        return {"status": "error", "error": str(e)}

    if not _wait_for_exit(pid, timeout):
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # exited on its own right after the grace period
        except Exception as e:
            return {"status": "error", "error": str(e)}
        # SIGKILL delivery is asynchronous; confirm the process is really gone.
        if not _wait_for_exit(pid, 1.0):
            return {"status": "error", "error": f"process {pid} still running after SIGKILL"}

    _remove_pid_file()
    return {"status": "stopped"}
def status_daemon() -> Dict[str, Any]:
    """Report whether the daemon recorded in the PID file is currently alive."""
    pid = _read_pid_file()
    alive = bool(pid) and _is_process_running(pid)
    return {"status": "running", "pid": pid} if alive else {"status": "not_running"}
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='Agent 服务管理')
    cli.add_argument('--serve', action='store_true', help='以前台模式启动HTTP服务')
    cli.add_argument('--host', default='0.0.0.0', help='监听地址')
    cli.add_argument('--port', type=int, default=8080, help='监听端口')
    cli.add_argument('command', nargs='?', choices=['start', 'stop', 'status'], help='守护进程管理命令')
    args = cli.parse_args()

    if args.serve:
        # Foreground server mode; this is what `start` launches in the background.
        run(args.host, args.port)
        sys.exit(0)

    # command -> (action, statuses that map to exit code 0)
    management = {
        'start': (lambda: start_daemon(args.host, args.port), ('started',)),
        'stop': (stop_daemon, ('stopped', 'not_running')),
        'status': (status_daemon, ('running', 'not_running')),
    }
    if args.command in management:
        action, ok_statuses = management[args.command]
        outcome = action()
        print(json.dumps(outcome, ensure_ascii=False))
        sys.exit(0 if outcome.get('status') in ok_statuses else 1)

    # No command given: run in the foreground (backwards-compatible default).
    run(args.host, args.port)
|