| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- templateData.py - 生成 Trace 可视化的 Mock 数据
- """
- import os
- import asyncio
- import json
- from datetime import datetime
- from typing import Dict, List, Any, Optional, Tuple
- import httpx
- import websockets
- from templateHtml import generate_trace_visualization_html
# Module-level state shared by the WebSocket event handlers below.

# Goals of the trace being watched: replaced wholesale by a "connected" event
# and patched in place by "goal_added" / "goal_updated" events (_apply_event).
goalList: List[Dict[str, Any]] = []
# Every message delivered by a "message_added" event, in arrival order.
msgList: List[Dict[str, Any]] = []
# Messages bucketed by their goal_id ("START" when a message has no goal_id);
# each bucket is kept sorted by _message_sort_key (_update_message_groups).
msgGroups: Dict[str, List[Dict[str, Any]]] = {}
def generate_trace_list(
    base_url: str = "http://43.106.118.91:8000",
    status: Optional[str] = None,
    mode: Optional[str] = None,
    limit: int = 20,
) -> Dict[str, Any]:
    """Fetch the trace list from ``GET {base_url}/api/traces``.

    Args:
        base_url: Root URL of the trace service; trailing slashes are stripped.
        status: Value for the ``status`` query parameter; omitted when falsy.
        mode: Value for the ``mode`` query parameter; omitted when falsy.
        limit: Value for the ``limit`` query parameter (always sent).

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    query: Dict[str, Any] = {"limit": limit}
    # Optional filters are only sent when the caller supplied a truthy value.
    optional_filters = (("status", status), ("mode", mode))
    query.update({name: value for name, value in optional_filters if value})

    endpoint = "{}/api/traces".format(base_url.rstrip("/"))
    reply = httpx.get(endpoint, params=query, timeout=10.0)
    reply.raise_for_status()
    return reply.json()
def generate_goal_list(
    trace_id: str = "trace_001", base_url: str = "http://43.106.118.91:8000"
) -> Dict[str, Any]:
    """Fetch one trace's detail from ``GET {base_url}/api/traces/{trace_id}``.

    Args:
        trace_id: Identifier placed in the request path.
        base_url: Root URL of the trace service; trailing slashes are stripped.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    endpoint = "{}/api/traces/{}".format(base_url.rstrip("/"), trace_id)
    reply = httpx.get(endpoint, timeout=10.0)
    reply.raise_for_status()
    return reply.json()
def generate_subgoal_list(
    sub_trace_id: str, base_url: str = "http://43.106.118.91:8000"
) -> Dict[str, Any]:
    """Fetch a sub-trace's detail from ``GET {base_url}/api/traces/{sub_trace_id}``.

    Args:
        sub_trace_id: Identifier of the sub-trace, placed in the request path.
        base_url: Root URL of the trace service; trailing slashes are stripped.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    root = base_url.rstrip("/")
    reply = httpx.get(f"{root}/api/traces/{sub_trace_id}", timeout=10.0)
    reply.raise_for_status()
    return reply.json()
def generate_messages_list(
    trace_id: str, goal_id: Optional[str] = None, base_url: str = "http://43.106.118.91:8000"
) -> Dict[str, Any]:
    """Fetch messages from ``GET {base_url}/api/traces/{trace_id}/messages``.

    Args:
        trace_id: Identifier of the trace, placed in the request path.
        goal_id: When truthy, sent as the ``goal_id`` query parameter.
        base_url: Root URL of the trace service; trailing slashes are stripped.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    endpoint = "{}/api/traces/{}/messages".format(base_url.rstrip("/"), trace_id)
    # An empty query dict is sent when no goal filter was requested.
    query = {"goal_id": goal_id} if goal_id else {}
    reply = httpx.get(endpoint, params=query, timeout=10.0)
    reply.raise_for_status()
    return reply.json()
def generate_mock_branch_detail(trace_id: str = "trace_001", branch_id: str = "branch_001") -> Dict[str, Any]:
    """Build a fixed mock payload describing one branch (a JWT example).

    Args:
        trace_id: Accepted for signature compatibility; not used in the payload.
        branch_id: Used as the payload ``id`` and as each goal's ``branch_id``.

    Returns:
        A freshly built dict with branch metadata, a ``goal_tree`` holding
        three chained goals, and the branch's ``cumulative_stats``.
    """

    def stats(message_count, total_tokens, total_cost, preview):
        # Each call returns a new dict so no two stats blocks share an object.
        return {
            "message_count": message_count,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "preview": preview,
        }

    def goal(goal_id, parent_id, description, reason, summary, self_stats, cumulative_stats):
        # All mock goals share the same type, status and owning branch.
        return {
            "id": goal_id,
            "parent_id": parent_id,
            "branch_id": branch_id,
            "type": "normal",
            "description": description,
            "reason": reason,
            "status": "completed",
            "summary": summary,
            "self_stats": self_stats,
            "cumulative_stats": cumulative_stats,
        }

    goals = [
        goal(
            "branch_goal_001",
            None,
            "研究 JWT 原理",
            "需要理解 JWT 的工作机制",
            "已完成 JWT 原理学习",
            stats(2, 400, 0.005, "research × 2"),
            stats(5, 1100, 0.015, "research × 2 → implement × 3"),
        ),
        goal(
            "branch_goal_002",
            "branch_goal_001",
            "实现 JWT 生成和验证",
            "需要实现核心功能",
            "已完成 JWT 的生成和验证逻辑",
            stats(2, 500, 0.007, "implement × 2"),
            stats(3, 700, 0.01, "implement × 2 → test"),
        ),
        goal(
            "branch_goal_003",
            "branch_goal_002",
            "测试 JWT 性能",
            "需要验证性能是否满足要求",
            "性能测试通过,QPS 达到 5000+",
            stats(1, 200, 0.003, "test"),
            stats(1, 200, 0.003, "test"),
        ),
    ]

    goal_tree = {
        "mission": "实现 JWT 认证",
        "current_id": "branch_goal_003",
        "goals": goals,
    }

    return {
        "id": branch_id,
        "explore_start_id": "goal_003",
        "description": "JWT 认证方案",
        "status": "completed",
        "summary": "JWT 方案实现完成,性能测试通过",
        "goal_tree": goal_tree,
        "cumulative_stats": stats(5, 1100, 0.015, "research × 2 → implement × 2 → test"),
    }
async def _fetch_ws_connected_event(trace_id: str, since_event_id: int = 0, ws_url: Optional[str] = None) -> Dict[str, Any]:
    """Open the trace watch WebSocket and return its first ``connected`` event.

    Args:
        trace_id: Trace identifier used to build the default watch URL.
        since_event_id: Value of the ``since_event_id`` query parameter in the
            default URL.
        ws_url: Full WebSocket URL; when truthy it replaces the default URL.

    Returns:
        The first decoded JSON message whose ``"event"`` field equals
        ``"connected"``. Messages received before it are discarded.
    """
    if ws_url:
        endpoint = ws_url
    else:
        endpoint = f"ws://43.106.118.91:8000/api/traces/{trace_id}/watch?since_event_id={since_event_id}"

    async with websockets.connect(endpoint) as connection:
        while True:
            payload = json.loads(await connection.recv())
            if payload.get("event") != "connected":
                continue
            return payload
def _get_goals_container(trace_detail: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Locate the dict that owns the ``goals`` list inside a trace detail.

    Lookup order:
      1. ``trace_detail["goal_tree"]`` when it is a dict with a list ``goals``.
      2. ``trace_detail`` itself when it has a list ``goals``.
      3. Otherwise ``trace_detail["goal_tree"]`` is overwritten with a new
         ``{"goals": []}`` and that new container is used.

    Args:
        trace_detail: Trace detail dict; mutated only in case 3.

    Returns:
        ``(container, goals)`` where ``goals`` is the very list object stored
        under ``container["goals"]``.
    """
    nested = trace_detail.get("goal_tree")
    candidates = [nested] if isinstance(nested, dict) else []
    candidates.append(trace_detail)

    for container in candidates:
        found = container.get("goals")
        if isinstance(found, list):
            return container, found

    # Neither location holds a goals list: install an empty goal tree.
    fresh_tree = {"goals": []}
    trace_detail["goal_tree"] = fresh_tree
    return fresh_tree, fresh_tree["goals"]
def _message_sort_key(message: Dict[str, Any]) -> int:
    """Return the numeric suffix of ``message["message_id"]`` for sorting.

    The suffix is the text after the last ``-`` in the id (``"abc-0012"``
    gives ``12``).

    Args:
        message: Message dict; only its ``"message_id"`` entry is read.

    Returns:
        The suffix as an int, or ``0`` when the id is missing, is not a
        string, contains no ``-``, or has a suffix that is not a decimal
        number.
    """
    message_id = message.get("message_id")
    if not isinstance(message_id, str) or "-" not in message_id:
        return 0
    suffix = message_id.rsplit("-", 1)[-1]
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # "²" that int() rejects, which would raise ValueError in the middle of a
    # sort instead of falling back to 0.
    return int(suffix) if suffix.isdecimal() else 0
def _update_message_groups(message: Dict[str, Any]):
    """File ``message`` into ``msgGroups`` and keep its bucket ordered.

    The bucket is chosen by the message's ``goal_id``; messages without a
    truthy ``goal_id`` go into the ``"START"`` bucket. After appending, the
    bucket is re-sorted in place with ``_message_sort_key``.

    Args:
        message: Message dict to store (the same object is stored, not a copy).
    """
    bucket_name = message.get("goal_id") or "START"
    if bucket_name not in msgGroups:
        msgGroups[bucket_name] = []
    bucket = msgGroups[bucket_name]
    bucket.append(message)
    bucket.sort(key=_message_sort_key)
def _apply_event(data: Dict[str, Any]) -> None:
    """Fold one watch event into the module-level goal and message state.

    Handled values of ``data["event"]``:

    * ``"connected"``: replace the contents of ``goalList`` with the goals of
      the goal tree carried by the event.
    * ``"goal_added"``: merge ``data["goal"]`` into the entry with the same
      ``id``, or append it when no entry matches.
    * ``"goal_updated"``: copy ``status`` / ``summary`` from
      ``data["updates"]`` onto the goal whose ``id`` equals ``data["goal_id"]``.
    * ``"message_added"``: append ``data["message"]`` to ``msgList`` and file
      it into ``msgGroups``.

    Any other event leaves the state untouched.

    Args:
        data: Decoded event payload.
    """
    event = data.get("event")
    if event == "connected":
        # The goal tree may sit at the top level or under data["trace"].
        goal_tree = data.get("goal_tree") or (data.get("trace") or {}).get("goal_tree") or {}
        goals = goal_tree.get("goals") if isinstance(goal_tree, dict) else []
        # goalList is only reset when a real list is available; a goal tree
        # without a "goals" list leaves the previous goals in place.
        if isinstance(goals, list):
            # Mutate in place so other holders of goalList see the new goals.
            goalList.clear()
            goalList.extend(goals)
    if event == "goal_added":
        goal = data.get("goal")
        if isinstance(goal, dict):
            for idx, existing in enumerate(goalList):
                if existing.get("id") == goal.get("id"):
                    # Known goal: new fields override the stored ones.
                    goalList[idx] = {**existing, **goal}
                    break
            # NOTE(review): nesting was reconstructed from unindented text;
            # this else is read as the for-loop's else (runs only when no
            # existing goal matched) — confirm against the original file.
            else:
                goalList.append(goal)
    elif event == "goal_updated":
        goal_id = data.get("goal_id")
        updates = data.get("updates") or {}
        for g in goalList:
            if g.get("id") == goal_id:
                # Only status and summary are applied; other keys in
                # updates are ignored.
                if "status" in updates:
                    g["status"] = updates.get("status")
                if "summary" in updates:
                    g["summary"] = updates.get("summary")
                break
    elif event == "message_added":
        message = data.get("message")
        if isinstance(message, dict):
            msgList.append(message)
            _update_message_groups(message)
def _append_event_jsonl(event_data: Dict[str, Any], mock_dir: str):
    """Append one event as a JSON line to ``<mock_dir>/event.jsonl``.

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written as-is rather than as ``\\u`` escapes.

    Args:
        event_data: Event payload to serialise.
        mock_dir: Directory that holds ``event.jsonl``; it must already exist.
    """
    target = os.path.join(mock_dir, "event.jsonl")
    with open(target, mode="a", encoding="utf-8") as sink:
        serialized = json.dumps(event_data, ensure_ascii=False)
        sink.write(f"{serialized}\n")
async def _watch_ws_events(trace_id: str, since_event_id: int = 0, ws_url: Optional[str] = None) -> None:
    """Watch a trace's WebSocket stream and re-render the HTML on every event.

    Each received message is decoded as JSON, applied to the module state via
    ``_apply_event``, appended to ``ws_data/event.jsonl`` beside this file, and
    followed by a call to ``generate_trace_visualization_html``. The outer loop
    has no exit: the connection is re-opened whenever it ends or fails.

    Args:
        trace_id: Trace identifier used to build the default watch URL.
        since_event_id: Value of the ``since_event_id`` query parameter in the
            default URL.
        ws_url: Full WebSocket URL; when truthy it replaces the default URL.
    """
    # The URL is built once, so every reconnect reuses the original
    # since_event_id. NOTE(review): if the server replays events from that id,
    # "message_added" events would be appended again — confirm.
    url = ws_url or f"ws://43.106.118.91:8000/api/traces/{trace_id}/watch?since_event_id={since_event_id}"
    # Received events are logged under a "ws_data" directory next to this file.
    mock_dir = os.path.join(os.path.dirname(__file__), "ws_data")
    os.makedirs(mock_dir, exist_ok=True)
    while True:
        try:
            print(f"开始监听 WebSocket: {url}")
            async with websockets.connect(url) as ws:
                async for raw_message in ws:
                    data = json.loads(raw_message)
                    _apply_event(data)
                    _append_event_jsonl(data, mock_dir)
                    # The visualisation is regenerated after every event.
                    generate_trace_visualization_html(goalList, msgGroups)
                    event = data.get("event")
                    if event:
                        print(f"收到事件: {event}")
        # Deliberately broad: any failure (connection, JSON decoding, event
        # handling, rendering) is logged and followed by a reconnect.
        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"WebSocket 连接断开: {e},1 秒后重连")
            # NOTE(review): nesting was reconstructed from unindented text;
            # the sleep is read as part of the except handler — confirm.
            await asyncio.sleep(1)
def save_ws_data_to_file(trace_list_data: Dict[str, Any], goal_list: List[Dict[str, Any]]):
    """Dump the trace list and goal list as JSON under ``api_data``.

    The ``api_data`` directory sits beside this file and is created when
    missing. ``trace_list.json`` and ``goal_list.json`` are overwritten on
    every call, UTF-8 encoded with a 2-space indent.

    Args:
        trace_list_data: Payload written to ``trace_list.json``.
        goal_list: Payload written to ``goal_list.json``.
    """
    mock_dir = os.path.join(os.path.dirname(__file__), "api_data")
    os.makedirs(mock_dir, exist_ok=True)

    outputs = (
        ("trace_list.json", trace_list_data),
        ("goal_list.json", goal_list),
    )
    for filename, payload in outputs:
        with open(os.path.join(mock_dir, filename), "w", encoding="utf-8") as out:
            json.dump(payload, out, ensure_ascii=False, indent=2)

    print(f"Trace 数据已保存到: {mock_dir}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point.
    #   --trace-id ID --watch : stream events for ID and re-render on each one.
    #   --trace-id ID         : only prints a message (no snapshot is taken).
    #   (no --trace-id)       : fetch the trace list, then watch or snapshot
    #                           the pinned trace below.
    cli = argparse.ArgumentParser()
    cli.add_argument("--trace-id", dest="trace_id")
    cli.add_argument("--since-event-id", dest="since_event_id", type=int, default=0)
    cli.add_argument("--ws-url", dest="ws_url")
    cli.add_argument("--watch", action="store_true")
    args = cli.parse_args()

    if args.trace_id and args.watch:
        print(f"使用 trace_id 监听: {args.trace_id}")
        asyncio.run(_watch_ws_events(args.trace_id, args.since_event_id, args.ws_url))
    elif args.trace_id:
        # NOTE(review): this message says no trace_id is available although
        # one was passed; snapshotting an explicit trace is not implemented.
        print("❌暂无 trace_id")
    else:
        trace_list_data = generate_trace_list()
        traces = trace_list_data.get("traces") or []
        # NOTE(review): the trace id is pinned instead of being taken from
        # `traces`, which leaves `traces` unused and the raise below
        # unreachable — confirm whether the first listed trace should be used.
        trace_id = "eb3aa9f6-37d4-4888-96ba-a9b9c5a4766b"
        if not trace_id:
            raise Exception("trace_list.json 中没有 trace_id")
        if args.watch:
            print(f"✅使用 trace_id 监听: {trace_id}")
            asyncio.run(_watch_ws_events(trace_id, args.since_event_id, args.ws_url))
        else:
            goal_list = generate_goal_list(trace_id)
            print(f"✅使用 trace_id 生成 goal_list: {goal_list}")
            save_ws_data_to_file(trace_list_data, goal_list)
|