templateData.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. templateData.py - 生成 Trace 可视化的 Mock 数据
  5. """
  6. import os
  7. import asyncio
  8. import json
  9. from datetime import datetime
  10. from typing import Dict, List, Any, Optional, Tuple
  11. import httpx
  12. import websockets
  13. from templateHtml import generate_trace_visualization_html
  14. goalList: List[Dict[str, Any]] = []
  15. msgList: List[Dict[str, Any]] = []
  16. msgGroups: Dict[str, List[Dict[str, Any]]] = {}
  17. def generate_trace_list(
  18. base_url: str = "http://43.106.118.91:8000",
  19. status: Optional[str] = None,
  20. mode: Optional[str] = None,
  21. limit: int = 20,
  22. ) -> Dict[str, Any]:
  23. params: Dict[str, Any] = {"limit": limit}
  24. if status:
  25. params["status"] = status
  26. if mode:
  27. params["mode"] = mode
  28. url = f"{base_url.rstrip('/')}/api/traces"
  29. response = httpx.get(url, params=params, timeout=10.0)
  30. response.raise_for_status()
  31. return response.json()
  32. def generate_goal_list(
  33. trace_id: str = "trace_001", base_url: str = "http://43.106.118.91:8000"
  34. ) -> Dict[str, Any]:
  35. url = f"{base_url.rstrip('/')}/api/traces/{trace_id}"
  36. response = httpx.get(url, timeout=10.0)
  37. response.raise_for_status()
  38. return response.json()
  39. def generate_subgoal_list(
  40. sub_trace_id: str, base_url: str = "http://43.106.118.91:8000"
  41. ) -> Dict[str, Any]:
  42. url = f"{base_url.rstrip('/')}/api/traces/{sub_trace_id}"
  43. response = httpx.get(url, timeout=10.0)
  44. response.raise_for_status()
  45. return response.json()
  46. def generate_messages_list(
  47. trace_id: str, goal_id: Optional[str] = None, base_url: str = "http://43.106.118.91:8000"
  48. ) -> Dict[str, Any]:
  49. url = f"{base_url.rstrip('/')}/api/traces/{trace_id}/messages"
  50. params = {}
  51. if goal_id:
  52. params["goal_id"] = goal_id
  53. response = httpx.get(url, params=params, timeout=10.0)
  54. response.raise_for_status()
  55. return response.json()
  56. def generate_mock_branch_detail(trace_id: str = "trace_001", branch_id: str = "branch_001") -> Dict[str, Any]:
  57. """生成分支详情的 Mock 数据"""
  58. return {
  59. "id": branch_id,
  60. "explore_start_id": "goal_003",
  61. "description": "JWT 认证方案",
  62. "status": "completed",
  63. "summary": "JWT 方案实现完成,性能测试通过",
  64. "goal_tree": {
  65. "mission": "实现 JWT 认证",
  66. "current_id": "branch_goal_003",
  67. "goals": [
  68. {
  69. "id": "branch_goal_001",
  70. "parent_id": None,
  71. "branch_id": branch_id,
  72. "type": "normal",
  73. "description": "研究 JWT 原理",
  74. "reason": "需要理解 JWT 的工作机制",
  75. "status": "completed",
  76. "summary": "已完成 JWT 原理学习",
  77. "self_stats": {
  78. "message_count": 2,
  79. "total_tokens": 400,
  80. "total_cost": 0.005,
  81. "preview": "research × 2"
  82. },
  83. "cumulative_stats": {
  84. "message_count": 5,
  85. "total_tokens": 1100,
  86. "total_cost": 0.015,
  87. "preview": "research × 2 → implement × 3"
  88. }
  89. },
  90. {
  91. "id": "branch_goal_002",
  92. "parent_id": "branch_goal_001",
  93. "branch_id": branch_id,
  94. "type": "normal",
  95. "description": "实现 JWT 生成和验证",
  96. "reason": "需要实现核心功能",
  97. "status": "completed",
  98. "summary": "已完成 JWT 的生成和验证逻辑",
  99. "self_stats": {
  100. "message_count": 2,
  101. "total_tokens": 500,
  102. "total_cost": 0.007,
  103. "preview": "implement × 2"
  104. },
  105. "cumulative_stats": {
  106. "message_count": 3,
  107. "total_tokens": 700,
  108. "total_cost": 0.01,
  109. "preview": "implement × 2 → test"
  110. }
  111. },
  112. {
  113. "id": "branch_goal_003",
  114. "parent_id": "branch_goal_002",
  115. "branch_id": branch_id,
  116. "type": "normal",
  117. "description": "测试 JWT 性能",
  118. "reason": "需要验证性能是否满足要求",
  119. "status": "completed",
  120. "summary": "性能测试通过,QPS 达到 5000+",
  121. "self_stats": {
  122. "message_count": 1,
  123. "total_tokens": 200,
  124. "total_cost": 0.003,
  125. "preview": "test"
  126. },
  127. "cumulative_stats": {
  128. "message_count": 1,
  129. "total_tokens": 200,
  130. "total_cost": 0.003,
  131. "preview": "test"
  132. }
  133. }
  134. ]
  135. },
  136. "cumulative_stats": {
  137. "message_count": 5,
  138. "total_tokens": 1100,
  139. "total_cost": 0.015,
  140. "preview": "research × 2 → implement × 2 → test"
  141. }
  142. }
  143. async def _fetch_ws_connected_event(trace_id: str, since_event_id: int = 0, ws_url: Optional[str] = None) -> Dict[str, Any]:
  144. url = ws_url or f"ws://43.106.118.91:8000/api/traces/{trace_id}/watch?since_event_id={since_event_id}"
  145. async with websockets.connect(url) as ws:
  146. while True:
  147. raw_message = await ws.recv()
  148. data = json.loads(raw_message)
  149. if data.get("event") == "connected":
  150. return data
  151. def _get_goals_container(trace_detail: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  152. goal_tree = trace_detail.get("goal_tree")
  153. if isinstance(goal_tree, dict):
  154. goals = goal_tree.get("goals")
  155. if isinstance(goals, list):
  156. return goal_tree, goals
  157. goals = trace_detail.get("goals")
  158. if isinstance(goals, list):
  159. return trace_detail, goals
  160. trace_detail["goal_tree"] = {"goals": []}
  161. return trace_detail["goal_tree"], trace_detail["goal_tree"]["goals"]
  162. def _message_sort_key(message: Dict[str, Any]) -> int:
  163. message_id = message.get("message_id")
  164. if not isinstance(message_id, str):
  165. return 0
  166. if "-" not in message_id:
  167. return 0
  168. suffix = message_id.rsplit("-", 1)[-1]
  169. return int(suffix) if suffix.isdigit() else 0
  170. def _update_message_groups(message: Dict[str, Any]):
  171. group_key = message.get("goal_id") or "START"
  172. group_list = msgGroups.setdefault(group_key, [])
  173. group_list.append(message)
  174. group_list.sort(key=_message_sort_key)
  175. def _apply_event(data: Dict[str, Any]):
  176. event = data.get("event")
  177. if event == "connected":
  178. goal_tree = data.get("goal_tree") or (data.get("trace") or {}).get("goal_tree") or {}
  179. goals = goal_tree.get("goals") if isinstance(goal_tree, dict) else []
  180. if isinstance(goals, list):
  181. goalList.clear()
  182. goalList.extend(goals)
  183. if event == "goal_added":
  184. goal = data.get("goal")
  185. if isinstance(goal, dict):
  186. for idx, existing in enumerate(goalList):
  187. if existing.get("id") == goal.get("id"):
  188. goalList[idx] = {**existing, **goal}
  189. break
  190. else:
  191. goalList.append(goal)
  192. elif event == "goal_updated":
  193. goal_id = data.get("goal_id")
  194. updates = data.get("updates") or {}
  195. for g in goalList:
  196. if g.get("id") == goal_id:
  197. if "status" in updates:
  198. g["status"] = updates.get("status")
  199. if "summary" in updates:
  200. g["summary"] = updates.get("summary")
  201. break
  202. elif event == "message_added":
  203. message = data.get("message")
  204. if isinstance(message, dict):
  205. msgList.append(message)
  206. _update_message_groups(message)
  207. def _append_event_jsonl(event_data: Dict[str, Any], mock_dir: str):
  208. event_path = os.path.join(mock_dir, "event.jsonl")
  209. with open(event_path, "a", encoding="utf-8") as f:
  210. f.write(json.dumps(event_data, ensure_ascii=False) + "\n")
  211. async def _watch_ws_events(trace_id: str, since_event_id: int = 0, ws_url: Optional[str] = None):
  212. url = ws_url or f"ws://43.106.118.91:8000/api/traces/{trace_id}/watch?since_event_id={since_event_id}"
  213. mock_dir = os.path.join(os.path.dirname(__file__), "ws_data")
  214. os.makedirs(mock_dir, exist_ok=True)
  215. while True:
  216. try:
  217. print(f"开始监听 WebSocket: {url}")
  218. async with websockets.connect(url) as ws:
  219. async for raw_message in ws:
  220. data = json.loads(raw_message)
  221. _apply_event(data)
  222. _append_event_jsonl(data, mock_dir)
  223. generate_trace_visualization_html(goalList, msgGroups)
  224. event = data.get("event")
  225. if event:
  226. print(f"收到事件: {event}")
  227. except Exception as e:
  228. import traceback
  229. traceback.print_exc()
  230. print(f"WebSocket 连接断开: {e},1 秒后重连")
  231. await asyncio.sleep(1)
  232. def save_ws_data_to_file(trace_list_data: Dict[str, Any], goal_list: List[Dict[str, Any]]):
  233. mock_dir = os.path.join(os.path.dirname(__file__), "api_data")
  234. os.makedirs(mock_dir, exist_ok=True)
  235. with open(os.path.join(mock_dir, "trace_list.json"), "w", encoding="utf-8") as f:
  236. json.dump(trace_list_data, f, ensure_ascii=False, indent=2)
  237. with open(os.path.join(mock_dir, "goal_list.json"), "w", encoding="utf-8") as f:
  238. json.dump(goal_list, f, ensure_ascii=False, indent=2)
  239. print(f"Trace 数据已保存到: {mock_dir}")
  240. if __name__ == "__main__":
  241. import argparse
  242. parser = argparse.ArgumentParser()
  243. parser.add_argument("--trace-id", dest="trace_id")
  244. parser.add_argument("--since-event-id", dest="since_event_id", type=int, default=0)
  245. parser.add_argument("--ws-url", dest="ws_url")
  246. parser.add_argument("--watch", action="store_true")
  247. args = parser.parse_args()
  248. if args.trace_id:
  249. if args.watch:
  250. print(f"使用 trace_id 监听: {args.trace_id}")
  251. asyncio.run(_watch_ws_events(args.trace_id, args.since_event_id, args.ws_url))
  252. else:
  253. print(f"❌暂无 trace_id")
  254. # save_ws_data_to_file(args.trace_id, args.since_event_id, args.ws_url)
  255. else:
  256. trace_list_data = generate_trace_list()
  257. # print(f"🐒trace_list_data: {trace_list_data}")
  258. traces = trace_list_data.get("traces") or []
  259. # trace_id = traces[0].get("trace_id") if traces else None
  260. trace_id = "eb3aa9f6-37d4-4888-96ba-a9b9c5a4766b"
  261. if trace_id:
  262. if args.watch:
  263. print(f"✅使用 trace_id 监听: {trace_id}")
  264. asyncio.run(_watch_ws_events(trace_id, args.since_event_id, args.ws_url))
  265. else:
  266. goal_list = generate_goal_list(trace_id)
  267. print(f"✅使用 trace_id 生成 goal_list: {goal_list}")
  268. save_ws_data_to_file(trace_list_data, goal_list)
  269. # save_ws_data_to_file(trace_id, args.since_event_id, args.ws_url)
  270. else:
  271. raise Exception("trace_list.json 中没有 trace_id")