Просмотр исходного кода

feat: 优化流程图数据加载并支持带焦点的反思

- 在 useFlowChartData 中增加 reloading 状态,过滤掉已放弃的消息,并支持传入 mode=all 参数获取所有消息
- 修改 traceApi 接口参数以支持 traceId 和可选的 focus 字段
- 为 useTrace 钩子添加 reload 方法,允许手动刷新 trace 数据
- 在 MainContent 中缓存 goals 和 msgGroups,避免切换 trace 时内容闪烁
- 在 TopBar 中为反思功能添加模态框,支持输入可选的重点内容
max_liu 1 неделя назад
Родитель
Сommit
064e6a41e2

+ 5 - 4
frontend/react-template/src/api/traceApi.ts

@@ -28,13 +28,13 @@ export const traceApi = {
     });
   },
   runTrace(
-    messageId: string,
+    traceId: string,
     data?: {
       messages?: Array<{ role: "system" | "user" | "assistant" | "tool"; content: unknown }>;
-      after_sequence?: number;
+      after_message_id?: string | null;
     },
   ) {
-    return request<void>(`/api/traces/${messageId}/run`, {
+    return request<void>(`/api/traces/${traceId}/run`, {
       method: "POST",
       data,
     });
@@ -44,9 +44,10 @@ export const traceApi = {
       method: "POST",
     });
   },
-  reflectTrace(traceId: string) {
+  reflectTrace(traceId: string, data?: { focus?: string }) {
     return request<void>(`/api/traces/${traceId}/reflect`, {
       method: "POST",
+      data,
     });
   },
   getExperiences() {

+ 7 - 3
frontend/react-template/src/components/FlowChart/hooks/useFlowChartData.ts

@@ -48,6 +48,7 @@ export const useFlowChartData = (traceId: string | null, initialGoals: Goal[], r
   const [sinceEventId, setSinceEventId] = useState(0);
   const currentEventIdRef = useRef(0);
   const restReloadingRef = useRef(false);
+  const [reloading, setReloading] = useState(false);
 
   const messageSortKey = useCallback((message: Message): number => {
     const mid =
@@ -104,11 +105,12 @@ export const useFlowChartData = (traceId: string | null, initialGoals: Goal[], r
     if (!traceId) return;
     if (restReloadingRef.current) return;
     restReloadingRef.current = true;
+    setReloading(true);
     let nextSinceEventId: number | null = null;
     try {
       const [traceRes, messagesRes] = await Promise.all([
         fetch(`http://localhost:8000/api/traces/${traceId}`),
-        fetch(`http://localhost:8000/api/traces/${traceId}/messages`),
+        fetch(`http://localhost:8000/api/traces/${traceId}/messages?mode=all`),
       ]);
 
       if (traceRes.ok) {
@@ -148,7 +150,8 @@ export const useFlowChartData = (traceId: string | null, initialGoals: Goal[], r
         const json = (await messagesRes.json()) as unknown;
         const root = isRecord(json) ? json : {};
         const list = Array.isArray(root.messages) ? (root.messages as Message[]) : [];
-        const nextMessages = [...list].sort((a, b) => messageSortKey(a) - messageSortKey(b));
+        const filtered = list.filter((message) => (message as { status?: string }).status !== "abandoned");
+        const nextMessages = [...filtered].sort((a, b) => messageSortKey(a) - messageSortKey(b));
         setMessages(nextMessages);
         const grouped: Record<string, Message[]> = {};
         nextMessages.forEach((message) => {
@@ -176,6 +179,7 @@ export const useFlowChartData = (traceId: string | null, initialGoals: Goal[], r
       }
     } finally {
       restReloadingRef.current = false;
+      setReloading(false);
     }
     return nextSinceEventId;
   }, [messageSortKey, traceId]);
@@ -325,5 +329,5 @@ export const useFlowChartData = (traceId: string | null, initialGoals: Goal[], r
   );
   const { connected } = useWebSocket(traceId, wsOptions);
 
-  return { goals, messages, msgGroups, connected };
+  return { goals, messages, msgGroups, connected, reloading };
 };

+ 32 - 6
frontend/react-template/src/components/MainContent/MainContent.tsx

@@ -29,9 +29,13 @@ export const MainContent: FC<MainContentProps> = ({
   const flowChartRef = useRef<FlowChartRef>(null);
   const [isAllExpanded, setIsAllExpanded] = useState(true);
   const [traceList, setTraceList] = useState<TraceListItem[]>([]);
-  const { trace, loading } = useTrace(traceId);
+  const [cachedGoals, setCachedGoals] = useState<Goal[]>([]);
+  const [cachedMsgGroups, setCachedMsgGroups] = useState<Record<string, Message[]>>({});
+  const { trace, loading, reload } = useTrace(traceId);
   const initialGoals = useMemo(() => trace?.goal_tree?.goals ?? [], [trace]);
-  const { goals, connected, msgGroups } = useFlowChartData(traceId, initialGoals, messageRefreshTrigger);
+  const { goals, connected, msgGroups, reloading } = useFlowChartData(traceId, initialGoals, messageRefreshTrigger);
+  const displayGoals = goals.length > 0 ? goals : cachedGoals;
+  const displayMsgGroups = Object.keys(msgGroups).length > 0 ? msgGroups : cachedMsgGroups;
 
   useEffect(() => {
     const fetchTraces = async () => {
@@ -45,6 +49,28 @@ export const MainContent: FC<MainContentProps> = ({
     fetchTraces();
   }, [refreshTrigger]);
 
+  useEffect(() => {
+    if (!messageRefreshTrigger) return;
+    void reload();
+  }, [messageRefreshTrigger, reload]);
+
+  useEffect(() => {
+    if (goals.length > 0) {
+      setCachedGoals(goals);
+    }
+  }, [goals]);
+
+  useEffect(() => {
+    if (Object.keys(msgGroups).length > 0) {
+      setCachedMsgGroups(msgGroups);
+    }
+  }, [msgGroups]);
+
+  useEffect(() => {
+    setCachedGoals([]);
+    setCachedMsgGroups({});
+  }, [traceId]);
+
   if (!traceId && !loading) {
     return (
       <div className={styles.main}>
@@ -123,15 +149,15 @@ export const MainContent: FC<MainContentProps> = ({
         </div>
       </div>
       <div className={styles.content}>
-        {loading ? (
+        {loading || reloading ? (
           <div className={styles.empty}>加载中...</div>
-        ) : goals.length === 0 ? (
+        ) : displayGoals.length === 0 ? (
           <div className={styles.empty}>暂无节点</div>
         ) : (
           <FlowChart
             ref={flowChartRef}
-            goals={goals}
-            msgGroups={msgGroups}
+            goals={displayGoals}
+            msgGroups={displayMsgGroups}
             onNodeClick={onNodeClick}
           />
         )}

+ 38 - 4
frontend/react-template/src/components/TopBar/TopBar.tsx

@@ -24,10 +24,12 @@ export const TopBar: FC<TopBarProps> = ({
   const [title, setTitle] = useState("流程图可视化系统");
   const [isModalVisible, setIsModalVisible] = useState(false);
   const [isInsertModalVisible, setIsInsertModalVisible] = useState(false);
+  const [isReflectModalVisible, setIsReflectModalVisible] = useState(false);
   const [isExperienceModalVisible, setIsExperienceModalVisible] = useState(false);
   const [experienceContent, setExperienceContent] = useState("");
   const formApiRef = useRef<{ getValues: () => { system_prompt: string; user_prompt: string } } | null>(null);
   const insertFormApiRef = useRef<{ getValues: () => { insert_prompt: string } } | null>(null);
+  const reflectFormApiRef = useRef<{ getValues: () => { reflect_focus: string } } | null>(null);
 
   const isMessageNode = (node: Goal | Message): node is Message =>
     "message_id" in node || "role" in node || "content" in node || "goal_id" in node || "tokens" in node;
@@ -112,6 +114,10 @@ export const TopBar: FC<TopBarProps> = ({
       Toast.warning("请选择插入节点");
       return;
     }
+    if (!selectedTraceId) {
+      Toast.warning("请先选择一个 Trace");
+      return;
+    }
 
     if (!isMessageNode(node)) {
       Toast.warning("插入位置错误");
@@ -132,12 +138,11 @@ export const TopBar: FC<TopBarProps> = ({
     }
 
     try {
-      const sequence = (node as Message & { sequence?: number }).sequence;
       const payload = {
         messages: [{ role: "user" as const, content: insertPrompt }],
-        after_sequence: typeof sequence === "number" ? sequence : undefined,
+        after_message_id: messageId,
       };
-      await traceApi.runTrace(messageId, payload);
+      await traceApi.runTrace(selectedTraceId, payload);
       Toast.success("插入成功");
       setIsInsertModalVisible(false);
       onMessageInserted?.();
@@ -172,9 +177,20 @@ export const TopBar: FC<TopBarProps> = ({
       Toast.warning("请先选择一个 Trace");
       return;
     }
+    setIsReflectModalVisible(true);
+  };
+
+  const handleReflectConfirm = async () => {
+    if (!selectedTraceId) {
+      Toast.warning("请先选择一个 Trace");
+      return;
+    }
+    const values = reflectFormApiRef.current?.getValues();
+    const focus = values?.reflect_focus?.trim();
     try {
-      await traceApi.reflectTrace(selectedTraceId);
+      await traceApi.reflectTrace(selectedTraceId, focus ? { focus } : undefined);
       Toast.success("已触发反思");
+      setIsReflectModalVisible(false);
     } catch (error) {
       console.error("Failed to reflect trace:", error);
       Toast.error("反思请求失败");
@@ -271,6 +287,24 @@ export const TopBar: FC<TopBarProps> = ({
           />
         </Form>
       </Modal>
+      <Modal
+        title="反思"
+        visible={isReflectModalVisible}
+        onOk={handleReflectConfirm}
+        onCancel={() => setIsReflectModalVisible(false)}
+        centered
+        style={{ width: 600 }}
+      >
+        {/* eslint-disable-next-line @typescript-eslint/no-explicit-any */}
+        <Form getFormApi={(api: any) => (reflectFormApiRef.current = api)}>
+          <Form.TextArea
+            field="reflect_focus"
+            label="反思重点"
+            placeholder="请输入反思重点(可选)"
+            autosize={{ minRows: 3, maxRows: 6 }}
+          />
+        </Form>
+      </Modal>
       <Modal
         title="经验列表"
         visible={isExperienceModalVisible}

+ 13 - 10
frontend/react-template/src/hooks/useTrace.ts

@@ -1,4 +1,4 @@
-import { useState, useEffect } from "react";
+import { useState, useEffect, useCallback } from "react";
 import { traceApi } from "../api/traceApi";
 import type { TraceDetailResponse } from "../types/trace";
 
@@ -7,24 +7,27 @@ export const useTrace = (traceId: string | null) => {
   const [loading, setLoading] = useState(false);
   const [error, setError] = useState<Error | null>(null);
 
-  useEffect(() => {
-    if (!traceId) return;
-
-    const loadTrace = async () => {
+  const reload = useCallback(
+    async (idOverride?: string | null) => {
+      const id = typeof idOverride === "string" ? idOverride : traceId;
+      if (!id) return;
       setLoading(true);
       setError(null);
       try {
-        const data = await traceApi.fetchTraceDetail(traceId);
+        const data = await traceApi.fetchTraceDetail(id);
         setTrace(data);
       } catch (err) {
         setError(err as Error);
       } finally {
         setLoading(false);
       }
-    };
+    },
+    [traceId],
+  );
 
-    loadTrace();
-  }, [traceId]);
+  useEffect(() => {
+    void reload();
+  }, [reload]);
 
-  return { trace, loading, error };
+  return { trace, loading, error, reload };
 };