Browse Source

feat: 增强前端配置灵活性和WebSocket可靠性

- 在API客户端和WebSocket钩子中优先读取window.CONFIG配置,支持动态配置API基础URL
- 将ErrorFallback组件提取为独立模块,提高代码复用性
- 为WebSocket连接添加心跳机制和事件轮询,确保跨进程写入时的实时推送
- 在Trace存储中添加message_added事件的WebSocket广播支持
- 修复TopBar组件中表单API的类型定义,移除any类型
- 重构流程图数据钩子,使用统一的request函数替代直接fetch调用
max_liu 3 days ago
parent
commit
ab35e4eb0b

+ 16 - 1
agent/trace/store.py

@@ -22,6 +22,7 @@ Sub-Trace 是完全独立的 Trace,有自己的目录:
 
 import json
 import os
+import logging
 from pathlib import Path
 from typing import Dict, List, Optional, Any
 from datetime import datetime
@@ -29,6 +30,8 @@ from datetime import datetime
 from .models import Trace, Message
 from .goal_models import GoalTree, Goal, GoalStats
 
+logger = logging.getLogger(__name__)
+
 
 class FileSystemTraceStore:
     """文件系统 Trace 存储"""
@@ -370,10 +373,22 @@ class FileSystemTraceStore:
 
         # 4. 追加 message_added 事件
         affected_goals = await self._get_affected_goals(trace_id, message)
-        await self.append_event(trace_id, "message_added", {
+        event_id = await self.append_event(trace_id, "message_added", {
             "message": message.to_dict(),
             "affected_goals": affected_goals
         })
+        if event_id:
+            try:
+                from . import websocket as trace_ws
+
+                await trace_ws.broadcast_message_added(
+                    trace_id=trace_id,
+                    event_id=event_id,
+                    message_dict=message.to_dict(),
+                    affected_goals=affected_goals,
+                )
+            except Exception:
+                logger.exception("Failed to broadcast message_added (trace_id=%s, event_id=%s)", trace_id, event_id)
 
         return message.message_id
 

+ 49 - 3
agent/trace/websocket.py

@@ -6,6 +6,7 @@ Trace WebSocket 推送
 
 from typing import Dict, Set, Any
 from datetime import datetime
+import asyncio
 from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
 
 from .protocols import TraceStore
@@ -100,6 +101,7 @@ async def watch_trace(
         })
 
         # 补发历史事件(since_event_id=0 表示补发所有历史)
+        last_sent_event_id = since_event_id
         if since_event_id >= 0:
             missed_events = await store.get_events(trace_id, since_event_id)
             # 限制补发数量(最多 100 条)
@@ -111,16 +113,33 @@ async def watch_trace(
             else:
                 for evt in missed_events:
                     await websocket.send_json(evt)
+                    if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
+                        last_sent_event_id = max(last_sent_event_id, evt["event_id"])
 
-        # 保持连接(等待客户端断开或接收消息)
+        # 保持连接:同时支持心跳 + 轮询 events.jsonl(跨进程写入时也能实时推送
         while True:
             try:
-                # 接收客户端消息(心跳检测)
-                data = await websocket.receive_text()
+                # 允许在没有客户端消息时继续轮询事件流
+                data = await asyncio.wait_for(websocket.receive_text(), timeout=0.5)
                 if data == "ping":
                     await websocket.send_json({"event": "pong"})
             except WebSocketDisconnect:
                 break
+            except asyncio.TimeoutError:
+                pass
+
+            new_events = await store.get_events(trace_id, last_sent_event_id)
+            if len(new_events) > 100:
+                await websocket.send_json({
+                    "event": "error",
+                    "message": f"Too many missed events ({len(new_events)}), please reload via REST API"
+                })
+                continue
+
+            for evt in new_events:
+                await websocket.send_json(evt)
+                if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
+                    last_sent_event_id = max(last_sent_event_id, evt["event_id"])
 
     finally:
         # 清理连接
@@ -196,6 +215,33 @@ async def broadcast_goal_updated(
     await _broadcast_to_trace(trace_id, message)
 
 
+async def broadcast_message_added(
+    trace_id: str,
+    event_id: int,
+    message_dict: Dict[str, Any],
+    affected_goals: list[Dict[str, Any]] = None,
+):
+    """
+    广播 Message 添加事件(不在此处写入 events.jsonl)
+
+    说明:
+    - message_added 的 events.jsonl 写入由 TraceStore.append_event 负责
+    - 这里仅负责把“已经持久化”的事件推送给当前活跃连接
+    """
+    if trace_id not in _active_connections:
+        return
+
+    message = {
+        "event": "message_added",
+        "event_id": event_id,
+        "ts": datetime.now().isoformat(),
+        "message": message_dict,
+        "affected_goals": affected_goals or [],
+    }
+
+    await _broadcast_to_trace(trace_id, message)
+
+
 async def broadcast_sub_trace_started(
     trace_id: str,
     sub_trace_id: str,

+ 5 - 1
frontend/react-template/src/api/client.ts

@@ -4,8 +4,12 @@ import { Toast } from "@douyinfe/semi-ui";
 // Determine base URL from environment variables, or fallback to default
 const DEFAULT_BASE_URL = "http://localhost:8000";
 
-// Handle various environment variable formats (Vite uses import.meta.env.VITE_*)
 const getBaseUrl = () => {
+  const winConfig =
+    typeof window !== "undefined"
+      ? (window as unknown as { CONFIG?: { API_BASE_URL?: string } }).CONFIG?.API_BASE_URL
+      : undefined;
+  if (typeof winConfig === "string" && winConfig) return winConfig;
   if (typeof import.meta !== "undefined" && import.meta.env && import.meta.env.VITE_API_BASE_URL) {
     return import.meta.env.VITE_API_BASE_URL;
   }

+ 27 - 0
frontend/react-template/src/components/ErrorFallback/ErrorFallback.tsx

@@ -0,0 +1,27 @@
+import type { FallbackProps } from "react-error-boundary";
+
+export const ErrorFallback = ({ error, resetErrorBoundary }: FallbackProps) => {
+  return (
+    <div style={{ padding: "20px", textAlign: "center", marginTop: "50px" }}>
+      <h2>Something went wrong:</h2>
+      <pre style={{ color: "red", backgroundColor: "#fce4e4", padding: "10px", borderRadius: "4px" }}>
+        {error instanceof Error ? error.message : String(error)}
+      </pre>
+      <button
+        onClick={resetErrorBoundary}
+        style={{
+          marginTop: "10px",
+          padding: "8px 16px",
+          backgroundColor: "#3b82f6",
+          color: "white",
+          border: "none",
+          borderRadius: "4px",
+          cursor: "pointer",
+        }}
+      >
+        Try again
+      </button>
+    </div>
+  );
+};
+

+ 95 - 78
frontend/react-template/src/components/FlowChart/hooks/useFlowChartData.ts

@@ -1,5 +1,6 @@
 import { useCallback, useEffect, useMemo, useRef, useState } from "react";
 import { useWebSocket } from "../../../hooks/useWebSocket";
+import { request } from "../../../api/client";
 import type { Goal } from "../../../types/goal";
 import type { Message } from "../../../types/message";
 
@@ -108,90 +109,83 @@ export const useFlowChartData = (traceId: string | null, refreshTrigger?: number
     setReloading(true);
     let nextSinceEventId: number | null = null;
     try {
-      const [traceRes, messagesRes] = await Promise.all([
-        fetch(`http://localhost:8000/api/traces/${traceId}`),
-        fetch(`http://localhost:8000/api/traces/${traceId}/messages?mode=all`),
+      const [traceJson, messagesJson] = await Promise.all([
+        request<unknown>(`/api/traces/${traceId}`),
+        request<unknown>(`/api/traces/${traceId}/messages?mode=all`),
       ]);
 
-      if (traceRes.ok) {
-        const json = (await traceRes.json()) as unknown;
-        const root = isRecord(json) ? json : {};
-        const trace = isRecord(root.trace) ? root.trace : undefined;
-        const goalTree = isRecord(root.goal_tree) ? root.goal_tree : undefined;
-        const goalList = goalTree && Array.isArray(goalTree.goals) ? (goalTree.goals as Goal[]) : [];
-
-        const lastEventId = trace && typeof trace.last_event_id === "number" ? trace.last_event_id : undefined;
-        if (typeof lastEventId === "number") {
-          currentEventIdRef.current = Math.max(currentEventIdRef.current, lastEventId);
-          setSinceEventId(lastEventId);
-          nextSinceEventId = lastEventId;
-        }
+      const traceRoot = isRecord(traceJson) ? traceJson : {};
+      const trace = isRecord(traceRoot.trace) ? traceRoot.trace : undefined;
+      const goalTree = isRecord(traceRoot.goal_tree) ? traceRoot.goal_tree : undefined;
+      const goalList = goalTree && Array.isArray(goalTree.goals) ? (goalTree.goals as Goal[]) : [];
 
-        if (goalList.length > 0) {
-          setGoals((prev) => {
-            const mergedFlat = goalList.map((ng) => {
-              const existing = prev.find((p) => p.id === ng.id);
-              if (!existing) return ng;
-              const merged: Goal = { ...existing, ...ng };
-              if (existing.sub_trace_ids && !merged.sub_trace_ids) {
-                merged.sub_trace_ids = existing.sub_trace_ids;
-              }
-              if (existing.agent_call_mode && !merged.agent_call_mode) {
-                merged.agent_call_mode = existing.agent_call_mode;
-              }
-              if (existing.knowledge && !merged.knowledge) {
-                merged.knowledge = existing.knowledge;
-              }
-              return merged;
-            });
-            return buildSubGoals(mergedFlat);
-          });
-        }
+      const lastEventId = trace && typeof trace.last_event_id === "number" ? trace.last_event_id : undefined;
+      if (typeof lastEventId === "number") {
+        currentEventIdRef.current = Math.max(currentEventIdRef.current, lastEventId);
+        setSinceEventId(lastEventId);
+        nextSinceEventId = lastEventId;
       }
 
-      if (messagesRes.ok) {
-        const json = (await messagesRes.json()) as unknown;
-        const root = isRecord(json) ? json : {};
-        const list = Array.isArray(root.messages) ? (root.messages as Message[]) : [];
-        console.log("%c [ list ]-149", "font-size:13px; background:pink; color:#bf2c9f;", list);
-
-        const filtered = list.filter((message) => (message as { status?: string }).status !== "abandoned");
-        const nextMessages = [...filtered].sort(messageComparator);
-
-        const { availableData: finalMessages, invalidBranches: invalidBranchesTemp } = processRetryLogic(nextMessages);
-
-        // Update max sequence
-        const maxSeq = finalMessages.reduce((max, msg) => {
-          const seq = typeof msg.sequence === "number" ? msg.sequence : -1;
-          return Math.max(max, seq);
-        }, 0);
-        maxSequenceRef.current = maxSeq;
-
-        setMessages(finalMessages);
-        setInvalidBranches(invalidBranchesTemp);
-        const grouped: Record<string, Message[]> = {};
-        finalMessages.forEach((message) => {
-          const groupKey = typeof message.goal_id === "string" && message.goal_id ? message.goal_id : "START";
-          if (!grouped[groupKey]) grouped[groupKey] = [];
-          grouped[groupKey].push(message);
-        });
-        Object.keys(grouped).forEach((key) => {
-          grouped[key].sort(messageComparator);
+      if (goalList.length > 0) {
+        setGoals((prev) => {
+          const mergedFlat = goalList.map((ng) => {
+            const existing = prev.find((p) => p.id === ng.id);
+            if (!existing) return ng;
+            const merged: Goal = { ...existing, ...ng };
+            if (existing.sub_trace_ids && !merged.sub_trace_ids) {
+              merged.sub_trace_ids = existing.sub_trace_ids;
+            }
+            if (existing.agent_call_mode && !merged.agent_call_mode) {
+              merged.agent_call_mode = existing.agent_call_mode;
+            }
+            if (existing.knowledge && !merged.knowledge) {
+              merged.knowledge = existing.knowledge;
+            }
+            return merged;
+          });
+          return buildSubGoals(mergedFlat);
         });
-        setMsgGroups(grouped);
+      }
 
-        if (grouped.START && grouped.START.length > 0) {
-          setGoals((prev) => {
-            if (prev.some((g) => g.id === "START")) return prev;
-            const startGoal: Goal = {
-              id: "START",
-              description: "START",
-              status: "completed",
-              created_at: "",
-            };
-            return [startGoal, ...prev];
-          });
-        }
+      const messagesRoot = isRecord(messagesJson) ? messagesJson : {};
+      const list = Array.isArray(messagesRoot.messages) ? (messagesRoot.messages as Message[]) : [];
+      console.log("%c [ list ]-149", "font-size:13px; background:pink; color:#bf2c9f;", list);
+
+      const filtered = list.filter((message) => (message as { status?: string }).status !== "abandoned");
+      const nextMessages = [...filtered].sort(messageComparator);
+
+      const { availableData: finalMessages, invalidBranches: invalidBranchesTemp } = processRetryLogic(nextMessages);
+
+      const maxSeq = finalMessages.reduce((max, msg) => {
+        const seq = typeof msg.sequence === "number" ? msg.sequence : -1;
+        return Math.max(max, seq);
+      }, 0);
+      maxSequenceRef.current = maxSeq;
+
+      setMessages(finalMessages);
+      setInvalidBranches(invalidBranchesTemp);
+      const grouped: Record<string, Message[]> = {};
+      finalMessages.forEach((message) => {
+        const groupKey = typeof message.goal_id === "string" && message.goal_id ? message.goal_id : "START";
+        if (!grouped[groupKey]) grouped[groupKey] = [];
+        grouped[groupKey].push(message);
+      });
+      Object.keys(grouped).forEach((key) => {
+        grouped[key].sort(messageComparator);
+      });
+      setMsgGroups(grouped);
+
+      if (grouped.START && grouped.START.length > 0) {
+        setGoals((prev) => {
+          if (prev.some((g) => g.id === "START")) return prev;
+          const startGoal: Goal = {
+            id: "START",
+            description: "START",
+            status: "completed",
+            created_at: "",
+          };
+          return [startGoal, ...prev];
+        });
       }
 
       // REST 请求完成后,允许建立 WebSocket 连接
@@ -199,6 +193,7 @@ export const useFlowChartData = (traceId: string | null, refreshTrigger?: number
     } finally {
       restReloadingRef.current = false;
       setReloading(false);
+      setReadyToConnect(true);
     }
     return nextSinceEventId;
   }, [messageComparator, traceId]);
@@ -263,7 +258,24 @@ export const useFlowChartData = (traceId: string | null, refreshTrigger?: number
           (typeof raw.current_event_id === "number" ? raw.current_event_id : undefined);
         if (typeof currentEventId === "number") {
           currentEventIdRef.current = Math.max(currentEventIdRef.current, currentEventId);
+          setSinceEventId(currentEventId);
         }
+
+        const goalTree = isRecord(data.goal_tree)
+          ? data.goal_tree
+          : isRecord(raw.goal_tree)
+            ? raw.goal_tree
+            : undefined;
+        if (goalTree && Array.isArray(goalTree.goals)) {
+          setGoals((prev) => {
+            if (prev.length > 0) return prev;
+            return buildSubGoals(goalTree.goals as Goal[]);
+          });
+        }
+        return;
+      }
+
+      if (event === "pong") {
         return;
       }
 
@@ -332,7 +344,12 @@ export const useFlowChartData = (traceId: string | null, refreshTrigger?: number
           (typeof data.goal_id === "string" ? data.goal_id : undefined) ||
           (isRecord(data.goal) && typeof data.goal.id === "string" ? data.goal.id : undefined) ||
           (typeof raw.goal_id === "string" ? raw.goal_id : undefined);
-        const updates = isRecord(data.updates) ? data.updates : isRecord(raw.updates) ? raw.updates : {};
+        const updates =
+          (isRecord(data.updates) ? data.updates : undefined) ||
+          (isRecord(raw.updates) ? raw.updates : undefined) ||
+          (isRecord(data.patch) ? data.patch : undefined) ||
+          (isRecord(raw.patch) ? raw.patch : undefined) ||
+          {};
         if (!goalId) return;
         setGoals((prev: Goal[]) =>
           prev.map((g: Goal) => {

+ 30 - 10
frontend/react-template/src/components/TopBar/TopBar.tsx

@@ -29,8 +29,13 @@ export const TopBar: FC<TopBarProps> = ({
   const [isReflectModalVisible, setIsReflectModalVisible] = useState(false);
   const [isExperienceModalVisible, setIsExperienceModalVisible] = useState(false);
   const [experienceContent, setExperienceContent] = useState("");
-  const [exampleProjects, setExampleProjects] = useState<Array<{ name: string; path: string; has_prompt: boolean }>>([]);
-  const formApiRef = useRef<any>(null);
+  const [exampleProjects, setExampleProjects] = useState<Array<{ name: string; path: string; has_prompt: boolean }>>(
+    [],
+  );
+  const formApiRef = useRef<{
+    getValues: () => { system_prompt?: string; user_prompt?: string };
+    setValue: (field: "system_prompt" | "user_prompt", value: string) => void;
+  } | null>(null);
   const insertFormApiRef = useRef<{ getValues: () => { insert_prompt: string } } | null>(null);
   const reflectFormApiRef = useRef<{ getValues: () => { reflect_focus: string } } | null>(null);
 
@@ -74,7 +79,7 @@ export const TopBar: FC<TopBarProps> = ({
     // 加载 example 项目列表
     try {
       const data = await traceApi.fetchExamples();
-      setExampleProjects(data.projects.filter(p => p.has_prompt));
+      setExampleProjects(data.projects.filter((p) => p.has_prompt));
     } catch (error) {
       console.error("Failed to load examples:", error);
     }
@@ -109,7 +114,13 @@ export const TopBar: FC<TopBarProps> = ({
         messages.push({ role: "user", content: values.user_prompt });
       }
 
-      await traceApi.createTrace({ messages });
+      const created = await traceApi.createTrace({ messages });
+      const nextTitle =
+        (typeof values.user_prompt === "string" && values.user_prompt.trim()
+          ? values.user_prompt.trim().split("\n")[0]
+          : "新任务") || "新任务";
+
+      onTraceSelect(created.trace_id, nextTitle);
       await loadTraces();
       onTraceCreated?.();
       setIsModalVisible(false);
@@ -292,8 +303,11 @@ export const TopBar: FC<TopBarProps> = ({
         centered
         style={{ width: 600 }}
       >
-        {/* eslint-disable-next-line @typescript-eslint/no-explicit-any */}
-        <Form getFormApi={(api: any) => (formApiRef.current = api)}>
+        <Form
+          getFormApi={(api: unknown) => {
+            formApiRef.current = api as unknown as NonNullable<typeof formApiRef.current>;
+          }}
+        >
           <Form.Select
             field="example_project"
             label="选择示例项目(可选)"
@@ -333,8 +347,11 @@ export const TopBar: FC<TopBarProps> = ({
         centered
         style={{ width: 600 }}
       >
-        {/* eslint-disable-next-line @typescript-eslint/no-explicit-any */}
-        <Form getFormApi={(api: any) => (insertFormApiRef.current = api)}>
+        <Form
+          getFormApi={(api: unknown) => {
+            insertFormApiRef.current = api as unknown as NonNullable<typeof insertFormApiRef.current>;
+          }}
+        >
           <Form.TextArea
             field="insert_prompt"
             label=" "
@@ -351,8 +368,11 @@ export const TopBar: FC<TopBarProps> = ({
         centered
         style={{ width: 600 }}
       >
-        {/* eslint-disable-next-line @typescript-eslint/no-explicit-any */}
-        <Form getFormApi={(api: any) => (reflectFormApiRef.current = api)}>
+        <Form
+          getFormApi={(api: unknown) => {
+            reflectFormApiRef.current = api as unknown as NonNullable<typeof reflectFormApiRef.current>;
+          }}
+        >
           <Form.TextArea
             field="reflect_focus"
             label=" "

+ 24 - 1
frontend/react-template/src/hooks/useWebSocket.ts

@@ -15,11 +15,26 @@ export const useWebSocket = (traceId: string | null, options: UseWebSocketOption
   useEffect(() => {
     if (!traceId) return;
 
-    const url = `ws://localhost:8000/api/traces/${traceId}/watch?since_event_id=${sinceEventId}`;
+    const httpBase =
+      (typeof window !== "undefined"
+        ? (window as unknown as { CONFIG?: { API_BASE_URL?: string } }).CONFIG?.API_BASE_URL
+        : undefined) ||
+      (typeof import.meta !== "undefined" && import.meta.env && import.meta.env.VITE_API_BASE_URL
+        ? import.meta.env.VITE_API_BASE_URL
+        : "http://localhost:8000");
+
+    const wsBase = httpBase.replace(/^http(s?):\/\//, "ws$1://").replace(/\/+$/, "");
+    const url = `${wsBase}/api/traces/${traceId}/watch?since_event_id=${sinceEventId}`;
     const ws = new WebSocket(url);
+    let pingTimer: number | null = null;
 
     ws.onopen = () => {
       setConnected(true);
+      pingTimer = window.setInterval(() => {
+        if (ws.readyState === WebSocket.OPEN) {
+          ws.send("ping");
+        }
+      }, 15000);
     };
 
     ws.onmessage = (event) => {
@@ -37,12 +52,20 @@ export const useWebSocket = (traceId: string | null, options: UseWebSocketOption
 
     ws.onclose = () => {
       setConnected(false);
+      if (pingTimer) {
+        window.clearInterval(pingTimer);
+        pingTimer = null;
+      }
       onClose?.();
     };
 
     wsRef.current = ws;
 
     return () => {
+      if (pingTimer) {
+        window.clearInterval(pingTimer);
+        pingTimer = null;
+      }
       ws.close();
     };
   }, [traceId, onMessage, onError, onClose, sinceEventId]);

+ 1 - 26
frontend/react-template/src/main.tsx

@@ -1,37 +1,12 @@
 import { createRoot } from "react-dom/client";
 import { ErrorBoundary } from "react-error-boundary";
-import type { FallbackProps } from "react-error-boundary";
 import App from "./App";
+import { ErrorFallback } from "./components/ErrorFallback/ErrorFallback";
 import "./styles/global.css";
 import "./styles/variables.css";
 
 const container = document.getElementById("root");
 
-const ErrorFallback = ({ error, resetErrorBoundary }: FallbackProps) => {
-  return (
-    <div style={{ padding: "20px", textAlign: "center", marginTop: "50px" }}>
-      <h2>Something went wrong:</h2>
-      <pre style={{ color: "red", backgroundColor: "#fce4e4", padding: "10px", borderRadius: "4px" }}>
-        {error instanceof Error ? error.message : String(error)}
-      </pre>
-      <button
-        onClick={resetErrorBoundary}
-        style={{
-          marginTop: "10px",
-          padding: "8px 16px",
-          backgroundColor: "#3b82f6",
-          color: "white",
-          border: "none",
-          borderRadius: "4px",
-          cursor: "pointer",
-        }}
-      >
-        Try again
-      </button>
-    </div>
-  );
-};
-
 if (container) {
   createRoot(container).render(
     <ErrorBoundary FallbackComponent={ErrorFallback}>