Ver Fonte

feat: 优化帖子游走算法,实现真正的双向搜索

- 实现双向交替扩展,正向和反向同时搜索
- 同时考虑出边和入边,支持奇偶路径
- 添加中间边类型筛选配置(默认:属于、包含、分类共现)
- 添加中间步分数过滤(默认:0.3)
- 添加路径去重,避免重复路径
- UI增加全选/清空/默认按钮

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui há 17 horas atrás
pai
commit
59f09e64d6

+ 35 - 0
script/visualization/src/components/GraphView.vue

@@ -75,6 +75,22 @@
         <span class="text-base-content/60 w-20">最后步分数:</span>
         <input type="number" :min="0" :max="1" :step="0.1" v-model.number="store.postWalkConfig.lastStepMinScore" class="input input-xs input-bordered w-16 text-center" />
       </div>
+      <div class="flex items-center gap-2">
+        <span class="text-base-content/60 w-20">中间边类型:</span>
+        <button @click="selectAllMiddleEdgeTypes" class="btn btn-ghost btn-xs text-base-content/50">全选</button>
+        <button @click="clearMiddleEdgeTypes" class="btn btn-ghost btn-xs text-base-content/50">清空</button>
+        <button @click="resetMiddleEdgeTypes" class="btn btn-ghost btn-xs text-base-content/50">默认</button>
+      </div>
+      <div class="flex items-center gap-2 flex-wrap pl-20">
+        <label v-for="t in middleEdgeTypeOptions" :key="t" class="flex items-center gap-1 cursor-pointer">
+          <input type="checkbox" :value="t" v-model="store.postWalkConfig.middleEdgeTypes" class="checkbox checkbox-xs checkbox-primary" />
+          <span>{{ t }}</span>
+        </label>
+      </div>
+      <div class="flex items-center gap-2">
+        <span class="text-base-content/60 w-20">中间步分数:</span>
+        <input type="number" :min="0" :max="1" :step="0.1" v-model.number="store.postWalkConfig.middleMinScore" class="input input-xs input-bordered w-16 text-center" />
+      </div>
       <div class="flex items-center gap-2 text-base-content/50">
         <span>路径: 帖子标签</span>
         <span class="text-primary">--{{ store.postWalkConfig.firstEdgeType }}--></span>
@@ -126,6 +142,25 @@ function copyGraphJson() {
   })
 }
 
+// 中间步骤可选的边类型(排除匹配边)
+const middleEdgeTypeOptions = computed(() => {
+  return store.allEdgeTypes.filter(t => t !== '匹配')
+})
+
+// 中间边类型操作
+function selectAllMiddleEdgeTypes() {
+  store.postWalkConfig.middleEdgeTypes = [...middleEdgeTypeOptions.value]
+}
+
+function clearMiddleEdgeTypes() {
+  store.postWalkConfig.middleEdgeTypes = []
+}
+
+function resetMiddleEdgeTypes() {
+  store.postWalkConfig.middleEdgeTypes = ['属于', '包含', '分类共现']
+  store.postWalkConfig.middleMinScore = 0.3
+}
+
 let simulation = null
 
 // 游走配置操作(直接操作 store)

+ 260 - 106
script/visualization/src/stores/graph.js

@@ -67,9 +67,20 @@ export const useGraphStore = defineStore('graph', () => {
     lastStepMinScore: 0.8,  // 最后一步最小分数
     firstEdgeType: '匹配',  // 第一步边类型
     lastEdgeType: '匹配',  // 最后一步边类型(反向)
-    excludeMiddleEdgeTypes: ['匹配']  // 中间步骤排除的边类型
+    middleEdgeTypes: ['属于', '包含', '分类共现'],  // 中间步骤允许的边类型
+    middleMinScore: 0.3  // 中间步骤最小分数
   })
 
+  // 检查边类型是否允许在中间步骤使用
+  function isMiddleEdgeAllowed(edgeType) {
+    // 匹配边始终不允许
+    if (edgeType === '匹配') return false
+    // 如果配置为空,允许所有非匹配边
+    if (postWalkConfig.middleEdgeTypes.length === 0) return true
+    // 否则只允许配置中的边类型
+    return postWalkConfig.middleEdgeTypes.includes(edgeType)
+  }
+
   // 判断节点是否使用帖子游走(帖子树中的标签节点)
   function shouldPostWalk(nodeId) {
     return postWalkConfig.nodeTypes.some(prefix => nodeId.startsWith(prefix))
@@ -92,7 +103,6 @@ export const useGraphStore = defineStore('graph', () => {
     console.log('=== executePostWalk 双向搜索 ===')
     console.log('起点:', startNodeId)
 
-    const currentPostNodes = getCurrentPostNodeIds()
     const postGraph = currentPostGraph.value
     const personaGraph = graphData.value
 
@@ -104,119 +114,260 @@ export const useGraphStore = defineStore('graph', () => {
       return new Set([startNodeId])
     }
 
-    // ========== 正向:起点 → 匹配边 → 人设节点 ==========
-    const forwardNodes = new Map()  // nodeId -> { prevNode, edge }
-    const forwardFrontier = new Set()
-
-    // 直接遍历边(postGraph.edges),因为 postGraph 可能没有 index 结构
     const postEdges = Object.values(postGraph.edges || {})
     console.log('帖子图谱边数:', postEdges.length)
 
+    // ========== 正向初始化:起点 → 匹配边 → 人设节点 ==========
+    const forwardVisited = new Map()  // nodeId -> { depth, paths: [[edge, ...]] }
+    let forwardFrontier = new Set()
+
     for (const edge of postEdges) {
       if (edge.source === startNodeId && edge.type === postWalkConfig.firstEdgeType) {
-        forwardNodes.set(edge.target, {
-          prevNode: startNodeId,
-          edge: { source: startNodeId, target: edge.target, type: edge.type, score: edge.score || 0 }
-        })
+        const edgeData = { source: startNodeId, target: edge.target, type: edge.type, score: edge.score || 0 }
+        if (!forwardVisited.has(edge.target)) {
+          forwardVisited.set(edge.target, { depth: 1, paths: [] })
+        }
+        forwardVisited.get(edge.target).paths.push([edgeData])
         forwardFrontier.add(edge.target)
       }
     }
     console.log('正向第一步到达节点数:', forwardFrontier.size)
-    console.log('正向到达节点:', Array.from(forwardFrontier))
 
-    // ========== 反向:当前帖子的其他标签 ← 匹配边 ← 人设节点 ==========
-    const backwardNodes = new Map()  // nodeId -> [{ nextNode, edge }]
-    const backwardFrontier = new Set()
+    // ========== 反向初始化:终点 ← 匹配边 ← 人设节点 ==========
+    const backwardVisited = new Map()  // nodeId -> { depth, endings: [{ postNode, edge }] }
+    let backwardFrontier = new Set()
 
-    // 在当前帖子图谱中,找除了起点之外的其他标签节点
     for (const edge of postEdges) {
-      // 匹配边:帖子标签 -> 人设节点
       if (edge.type === postWalkConfig.lastEdgeType && edge.source !== startNodeId) {
         if ((edge.score || 0) >= postWalkConfig.lastStepMinScore) {
-          // edge.target 是人设节点,edge.source 是当前帖子的其他标签
-          if (!backwardNodes.has(edge.target)) {
-            backwardNodes.set(edge.target, [])
+          const edgeData = { source: edge.target, target: edge.source, type: edge.type, score: edge.score || 0 }
+          if (!backwardVisited.has(edge.target)) {
+            backwardVisited.set(edge.target, { depth: 1, endings: [] })
           }
-          backwardNodes.get(edge.target).push({
-            nextNode: edge.source,
-            edge: { source: edge.target, target: edge.source, type: edge.type, score: edge.score || 0 }
-          })
+          backwardVisited.get(edge.target).endings.push({ postNode: edge.source, edge: edgeData })
           backwardFrontier.add(edge.target)
         }
       }
     }
     console.log('反向第一步到达节点数:', backwardFrontier.size)
-    console.log('反向到达节点(部分):', Array.from(backwardFrontier).slice(0, 5))
 
-    // ========== 检查直接相遇(2步路径) ==========
-    const paths = []
-    const meetingNodes = new Set()
+    // ========== 收集所有相遇点和对应路径 ==========
+    const allMeetings = []  // { meetNode, forwardPath, backwardEnding }
 
+    // 检查初始相遇
     for (const nodeId of forwardFrontier) {
-      if (backwardFrontier.has(nodeId)) {
-        meetingNodes.add(nodeId)
+      if (backwardVisited.has(nodeId)) {
+        const fData = forwardVisited.get(nodeId)
+        const bData = backwardVisited.get(nodeId)
+        for (const fPath of fData.paths) {
+          for (const bEnd of bData.endings) {
+            allMeetings.push({ meetNode: nodeId, forwardPath: fPath, backwardEnding: bEnd })
+          }
+        }
       }
     }
-    console.log('直接相遇节点数:', meetingNodes.size)
+    console.log('初始相遇数:', allMeetings.length)
+
+    // ========== 双向交替扩展 ==========
+    const maxSteps = postWalkConfig.maxSteps
+    let forwardDepth = 1
+    let backwardDepth = 1
+
+    for (let step = 0; step < maxSteps; step++) {
+      // 选择扩展较小的一边(优化搜索效率)
+      const expandForward = forwardFrontier.size <= backwardFrontier.size
+
+      if (expandForward && forwardFrontier.size > 0) {
+        // 扩展正向
+        const nextFrontier = new Set()
+        forwardDepth++
+
+        for (const nodeId of forwardFrontier) {
+          const currentPaths = forwardVisited.get(nodeId)?.paths || []
+
+          // 出边
+          const outEdges = personaGraph.index?.outEdges?.[nodeId] || {}
+          for (const [edgeType, targets] of Object.entries(outEdges)) {
+            if (isMiddleEdgeAllowed(edgeType)) {
+              for (const t of targets) {
+                if (t.target !== startNodeId && (t.score || 0) >= postWalkConfig.middleMinScore) {
+                  const newEdge = { source: nodeId, target: t.target, type: edgeType, score: t.score || 0 }
+
+                  if (!forwardVisited.has(t.target)) {
+                    forwardVisited.set(t.target, { depth: forwardDepth, paths: [] })
+                    nextFrontier.add(t.target)
+                  }
+
+                  // 添加所有新路径
+                  const targetData = forwardVisited.get(t.target)
+                  if (targetData.depth === forwardDepth) {
+                    for (const path of currentPaths) {
+                      targetData.paths.push([...path, newEdge])
+                    }
+                  }
+
+                  // 检查相遇
+                  if (backwardVisited.has(t.target)) {
+                    const bData = backwardVisited.get(t.target)
+                    for (const path of currentPaths) {
+                      for (const bEnd of bData.endings) {
+                        allMeetings.push({ meetNode: t.target, forwardPath: [...path, newEdge], backwardEnding: bEnd })
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
 
-    // ========== 如果没有直接相遇,在人设图谱中扩展 ==========
-    const maxMiddleSteps = postWalkConfig.maxSteps  // 中间扩展步数
-    let currentForward = new Set(forwardFrontier)
-    let currentBackward = new Set(backwardFrontier)
+          // 入边(反向遍历)
+          const inEdges = personaGraph.index?.inEdges?.[nodeId] || {}
+          for (const [edgeType, sources] of Object.entries(inEdges)) {
+            if (isMiddleEdgeAllowed(edgeType)) {
+              for (const s of sources) {
+                if (s.source !== startNodeId && (s.score || 0) >= postWalkConfig.middleMinScore) {
+                  const newEdge = { source: nodeId, target: s.source, type: edgeType, score: s.score || 0, reversed: true }
+
+                  if (!forwardVisited.has(s.source)) {
+                    forwardVisited.set(s.source, { depth: forwardDepth, paths: [] })
+                    nextFrontier.add(s.source)
+                  }
+
+                  const targetData = forwardVisited.get(s.source)
+                  if (targetData.depth === forwardDepth) {
+                    for (const path of currentPaths) {
+                      targetData.paths.push([...path, newEdge])
+                    }
+                  }
+
+                  if (backwardVisited.has(s.source)) {
+                    const bData = backwardVisited.get(s.source)
+                    for (const path of currentPaths) {
+                      for (const bEnd of bData.endings) {
+                        allMeetings.push({ meetNode: s.source, forwardPath: [...path, newEdge], backwardEnding: bEnd })
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
 
-    // 记录扩展路径
-    const forwardPaths = new Map()  // nodeId -> [path from start]
-    for (const nodeId of forwardFrontier) {
-      forwardPaths.set(nodeId, [forwardNodes.get(nodeId).edge])
-    }
+        forwardFrontier = nextFrontier
+        console.log(`正向扩展第${forwardDepth}步,新增节点:`, nextFrontier.size, '累计相遇:', allMeetings.length)
+      } else if (backwardFrontier.size > 0) {
+        // 扩展反向
+        const nextFrontier = new Set()
+        backwardDepth++
+
+        for (const nodeId of backwardFrontier) {
+          const currentEndings = backwardVisited.get(nodeId)?.endings || []
+
+          // 入边(反向扩展 = 沿入边方向)
+          const inEdges = personaGraph.index?.inEdges?.[nodeId] || {}
+          for (const [edgeType, sources] of Object.entries(inEdges)) {
+            if (isMiddleEdgeAllowed(edgeType)) {
+              for (const s of sources) {
+                if ((s.score || 0) < postWalkConfig.middleMinScore) continue
+                const newEdge = { source: nodeId, target: s.source, type: edgeType, score: s.score || 0 }
+
+                if (!backwardVisited.has(s.source)) {
+                  backwardVisited.set(s.source, { depth: backwardDepth, endings: [] })
+                  nextFrontier.add(s.source)
+                }
 
-    const backwardPaths = new Map()  // nodeId -> [path to end]
-    for (const nodeId of backwardFrontier) {
-      backwardPaths.set(nodeId, backwardNodes.get(nodeId).map(b => b.edge))
-    }
+                const targetData = backwardVisited.get(s.source)
+                if (targetData.depth === backwardDepth) {
+                  for (const ending of currentEndings) {
+                    targetData.endings.push({
+                      postNode: ending.postNode,
+                      edge: ending.edge,
+                      middleEdges: [...(ending.middleEdges || []), newEdge]
+                    })
+                  }
+                }
 
-    for (let step = 0; step < maxMiddleSteps; step++) {
-      // 扩展正向(在人设图谱中)
-      const nextForward = new Set()
-      const newForwardPaths = new Map()
-
-      for (const nodeId of currentForward) {
-        const outEdges = personaGraph.index?.outEdges?.[nodeId] || {}
-        for (const [edgeType, targets] of Object.entries(outEdges)) {
-          if (!postWalkConfig.excludeMiddleEdgeTypes.includes(edgeType)) {
-            for (const t of targets) {
-              if (!forwardNodes.has(t.target) && t.target !== startNodeId) {
-                const newEdge = { source: nodeId, target: t.target, type: edgeType, score: t.score || 0 }
-                forwardNodes.set(t.target, { prevNode: nodeId, edge: newEdge })
-                nextForward.add(t.target)
-
-                // 记录路径
-                const prevPath = forwardPaths.get(nodeId) || []
-                newForwardPaths.set(t.target, [...prevPath, newEdge])
-
-                // 检查是否与反向相遇
-                if (backwardFrontier.has(t.target)) {
-                  meetingNodes.add(t.target)
+                // 检查相遇
+                if (forwardVisited.has(s.source)) {
+                  const fData = forwardVisited.get(s.source)
+                  for (const fPath of fData.paths) {
+                    for (const ending of currentEndings) {
+                      allMeetings.push({
+                        meetNode: s.source,
+                        forwardPath: fPath,
+                        backwardEnding: {
+                          postNode: ending.postNode,
+                          edge: ending.edge,
+                          middleEdges: [...(ending.middleEdges || []), newEdge]
+                        }
+                      })
+                    }
+                  }
+                }
+              }
+            }
+          }
+
+          // 出边
+          const outEdges = personaGraph.index?.outEdges?.[nodeId] || {}
+          for (const [edgeType, targets] of Object.entries(outEdges)) {
+            if (isMiddleEdgeAllowed(edgeType)) {
+              for (const t of targets) {
+                if ((t.score || 0) < postWalkConfig.middleMinScore) continue
+                const newEdge = { source: nodeId, target: t.target, type: edgeType, score: t.score || 0, reversed: true }
+
+                if (!backwardVisited.has(t.target)) {
+                  backwardVisited.set(t.target, { depth: backwardDepth, endings: [] })
+                  nextFrontier.add(t.target)
+                }
+
+                const targetData = backwardVisited.get(t.target)
+                if (targetData.depth === backwardDepth) {
+                  for (const ending of currentEndings) {
+                    targetData.endings.push({
+                      postNode: ending.postNode,
+                      edge: ending.edge,
+                      middleEdges: [...(ending.middleEdges || []), newEdge]
+                    })
+                  }
+                }
+
+                if (forwardVisited.has(t.target)) {
+                  const fData = forwardVisited.get(t.target)
+                  for (const fPath of fData.paths) {
+                    for (const ending of currentEndings) {
+                      allMeetings.push({
+                        meetNode: t.target,
+                        forwardPath: fPath,
+                        backwardEnding: {
+                          postNode: ending.postNode,
+                          edge: ending.edge,
+                          middleEdges: [...(ending.middleEdges || []), newEdge]
+                        }
+                      })
+                    }
+                  }
                 }
               }
             }
           }
         }
-      }
 
-      for (const [k, v] of newForwardPaths) {
-        forwardPaths.set(k, v)
+        backwardFrontier = nextFrontier
+        console.log(`反向扩展第${backwardDepth}步,新增节点:`, nextFrontier.size, '累计相遇:', allMeetings.length)
+      } else {
+        break
       }
-      currentForward = nextForward
-      console.log(`正向扩展第${step + 1}步,新增节点:`, nextForward.size, '相遇节点:', meetingNodes.size)
 
-      // 如果没有新节点可扩展,提前退出
-      if (nextForward.size === 0) break
+      if (forwardFrontier.size === 0 && backwardFrontier.size === 0) break
     }
 
-    console.log('最终相遇节点数:', meetingNodes.size)
+    console.log('最终相遇数:', allMeetings.length)
 
-    // ========== 构建完整路径 ==========
+    // ========== 构建完整路径(去重) ==========
+    const paths = []
+    const pathSignatures = new Set()  // 用于去重
     const allNodes = new Map()
     const allEdges = new Map()
 
@@ -226,15 +377,38 @@ export const useGraphStore = defineStore('graph', () => {
       allNodes.set(startNodeId, { id: startNodeId, ...startNodeData })
     }
 
-    for (const meetNode of meetingNodes) {
-      // 正向路径
-      const fPath = forwardPaths.get(meetNode) || []
-      for (const edge of fPath) {
+    for (const meeting of allMeetings) {
+      const { meetNode, forwardPath, backwardEnding } = meeting
+
+      // 构建完整边列表
+      const fullEdges = [...forwardPath]
+      // 反向中间边(如果有)
+      if (backwardEnding.middleEdges) {
+        fullEdges.push(...backwardEnding.middleEdges)
+      }
+      // 最后的匹配边
+      fullEdges.push(backwardEnding.edge)
+
+      // 构建节点列表
+      const nodeList = [startNodeId]
+      for (const edge of fullEdges) {
+        nodeList.push(edge.target)
+      }
+
+      // 路径签名:节点序列(用于去重)
+      const signature = nodeList.join('|')
+      if (pathSignatures.has(signature)) continue
+      pathSignatures.add(signature)
+
+      // 添加到 paths
+      paths.push({ nodes: nodeList, edges: fullEdges })
+
+      // 收集所有节点和边
+      for (const edge of fullEdges) {
         const edgeKey = `${edge.source}->${edge.target}`
         if (!allEdges.has(edgeKey)) {
           allEdges.set(edgeKey, edge)
         }
-        // 添加节点
         for (const nid of [edge.source, edge.target]) {
           if (!allNodes.has(nid)) {
             const nodeData = postGraph.nodes?.[nid] || personaGraph.nodes?.[nid]
@@ -244,44 +418,24 @@ export const useGraphStore = defineStore('graph', () => {
           }
         }
       }
-
-      // 反向路径(到终点)
-      const bEdges = backwardNodes.get(meetNode) || []
-      for (const b of bEdges) {
-        const edgeKey = `${b.edge.source}->${b.edge.target}`
-        if (!allEdges.has(edgeKey)) {
-          allEdges.set(edgeKey, b.edge)
-        }
-        // 添加终点节点(当前帖子的其他标签)
-        if (!allNodes.has(b.nextNode)) {
-          const nodeData = postGraph.nodes?.[b.nextNode]
-          if (nodeData) {
-            allNodes.set(b.nextNode, { id: b.nextNode, ...nodeData })
-          }
-        }
-      }
-
-      // 记录路径
-      for (const b of bEdges) {
-        paths.push({
-          nodes: [...fPath.map(e => e.source), meetNode, b.nextNode],
-          edges: [...fPath, b.edge]
-        })
-      }
     }
 
     console.log('找到路径数:', paths.length)
     console.log('涉及节点数:', allNodes.size)
     console.log('涉及边数:', allEdges.size)
 
-    // 打印完整路径
-    for (let i = 0; i < paths.length; i++) {
+    // 打印完整路径(限制数量避免刷屏)
+    const printLimit = Math.min(paths.length, 10)
+    for (let i = 0; i < printLimit; i++) {
       const p = paths[i]
       const pathStr = p.nodes.join(' -> ')
       const scoresStr = p.edges.map(e => `${e.type}(${e.score?.toFixed(2) || 0})`).join(' -> ')
       console.log(`路径${i + 1}: ${pathStr}`)
       console.log(`  边: ${scoresStr}`)
     }
+    if (paths.length > printLimit) {
+      console.log(`... 还有 ${paths.length - printLimit} 条路径`)
+    }
 
     postWalkedPaths.value = paths
     postWalkedNodes.value = Array.from(allNodes.values())