Parcourir la source

可视化 +阈值0.05

刘立冬 il y a 3 semaines
Parent
commit
044969287f
2 fichiers modifiés avec 46 ajouts et 15 suppressions
  1. 15 13
      sug_v6_1_2_119.py
  2. 31 2
      visualization/sug_v6_1_2_8/index.js

+ 15 - 13
sug_v6_1_2_119.py

@@ -13,6 +13,8 @@ from pydantic import BaseModel, Field
 from lib.utils import read_file_as_string
 from lib.client import get_model
 MODEL_NAME = "google/gemini-2.5-flash"
+# 得分提升阈值:sug或组合词必须比来源query提升至少此幅度才能进入下一轮
+REQUIRED_SCORE_GAIN = 0.05
 from script.search_recommendations.xiaohongshu_search_recommendations import XiaohongshuSearchRecommendations
 from script.search.xiaohongshu_search import XiaohongshuSearch
 
@@ -141,7 +143,6 @@ word_segmentation_instructions = """
 2. 拆分成独立的概念
 3. 保留专业术语的完整性
 4. 去除虚词(的、吗、呢等)
-如果是双标行为,单独分词 不拆分,如果有如何两个字 不要
 
 ## 输出要求
 返回分词列表和分词理由。
@@ -1306,9 +1307,9 @@ async def run_round(
 
         # 将Top 5全部加入q_list_next(去重检查 + 得分过滤)
         for comb in top_5:
-            # 得分过滤:只有得分大于种子得分的组合词才加入下一轮
-            if comb['score'] <= seed.score_with_o:
-                print(f"        ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed.score_with_o:.2f})")
+            # 得分过滤:组合词必须比种子提升至少REQUIRED_SCORE_GAIN加入下一轮
+            if comb['score'] < seed.score_with_o + REQUIRED_SCORE_GAIN:
+                print(f"        ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} < 种子{seed.score_with_o:.2f} + {REQUIRED_SCORE_GAIN:.2f})")
                 continue
 
             # 去重检查
@@ -1357,7 +1358,8 @@ async def run_round(
             print(f"    ⊗ 跳过来自被剪枝query的sug: {sug.text} (来源: {sug.from_q.text})")
             continue
 
-        if sug.from_q and sug.score_with_o > sug.from_q.score_with_o:
+        # sug必须比来源query提升至少REQUIRED_SCORE_GAIN才能加入下一轮
+        if sug.from_q and sug.score_with_o >= sug.from_q.score_with_o + REQUIRED_SCORE_GAIN:
             # 去重检查
             if sug.text in existing_q_texts:
                 print(f"    ⊗ 跳过重复: {sug.text}")
@@ -1371,7 +1373,7 @@ async def run_round(
             )
             q_list_next.append(new_q)
             existing_q_texts.add(sug.text)  # 记录到去重集合
-            print(f"    ✓ {sug.text} (分数: {sug.score_with_o:.2f} > {sug.from_q.score_with_o:.2f})")
+            print(f"    ✓ {sug.text} (分数: {sug.score_with_o:.2f} >= 来源query: {sug.from_q.score_with_o:.2f} + {REQUIRED_SCORE_GAIN:.2f})")
 
     # 5. 构建seed_list_next(关键修改:不保留上一轮的seed)
     print(f"\n[步骤5] 构建seed_list_next(不保留上轮seed)...")
@@ -1381,10 +1383,10 @@ async def run_round(
     # 5.1 加入本轮所有组合词(只加入得分提升的)
     print(f"  5.1 加入本轮所有组合词(得分过滤)...")
     for comb in all_seed_combinations:
-        # 得分过滤:只有得分大于种子得分的组合词才作为下一轮种子
+        # 得分过滤:组合词必须比种子提升至少REQUIRED_SCORE_GAIN才作为下一轮种子
         seed_score = comb.get('seed_score', 0)
-        if comb['score'] <= seed_score:
-            print(f"    ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed_score:.2f})")
+        if comb['score'] < seed_score + REQUIRED_SCORE_GAIN:
+            print(f"    ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} < 种子{seed_score:.2f} + {REQUIRED_SCORE_GAIN:.2f})")
             continue
 
         if comb['query'] not in existing_seed_texts:
@@ -1396,7 +1398,7 @@ async def run_round(
             )
             seed_list_next.append(new_seed)
             existing_seed_texts.add(comb['query'])
-            print(f"    ✓ {comb['query']} (分数: {comb['score']:.2f} > 种子: {seed_score:.2f})")
+            print(f"    ✓ {comb['query']} (分数: {comb['score']:.2f} >= 种子: {seed_score:.2f} + {REQUIRED_SCORE_GAIN:.2f})")
 
     # 5.2 加入高分sug
     print(f"  5.2 加入高分sug...")
@@ -1405,8 +1407,8 @@ async def run_round(
         if sug.from_q and sug.from_q.text in pruned_query_texts:
             continue
 
-        # sug分数 > 对应query分数
-        if sug.from_q and sug.score_with_o > sug.from_q.score_with_o and sug.text not in existing_seed_texts:
+        # sug必须比来源query提升至少REQUIRED_SCORE_GAIN才作为下一轮种子
+        if sug.from_q and sug.score_with_o >= sug.from_q.score_with_o + REQUIRED_SCORE_GAIN and sug.text not in existing_seed_texts:
             new_seed = Seed(
                 text=sug.text,
                 added_words=[],
@@ -1415,7 +1417,7 @@ async def run_round(
             )
             seed_list_next.append(new_seed)
             existing_seed_texts.add(sug.text)
-            print(f"    ✓ {sug.text} (分数: {sug.score_with_o:.2f} > 来源query: {sug.from_q.score_with_o:.2f})")
+            print(f"    ✓ {sug.text} (分数: {sug.score_with_o:.2f} >= 来源query: {sug.from_q.score_with_o:.2f} + {REQUIRED_SCORE_GAIN:.2f})")
 
     # 序列化搜索结果数据(包含帖子详情)
     search_results_data = []

+ 31 - 2
visualization/sug_v6_1_2_8/index.js

@@ -782,6 +782,18 @@ function TreeNode({ node, level, children, isCollapsed, onToggle, isSelected, on
   const strategyColor = getStrategyColor(strategy);
   const nodeActualType = node.data.nodeType || node.type; // 获取实际节点类型
 
+  // 计算字体颜色:根据分数提升幅度判断
+  let fontColor = '#374151'; // 默认颜色
+  if (node.type === 'note') {
+    fontColor = node.data.matchLevel === 'unsatisfied' ? '#ef4444' : '#374151';
+  } else if (node.data.seed_score !== undefined) {
+    const parentScore = parseFloat(node.data.seed_score);
+    const gain = score - parentScore;
+    fontColor = gain >= 0.05 ? '#16a34a' : '#ef4444';
+  } else if (node.data.isSelected === false) {
+    fontColor = '#ef4444';
+  }
+
   return (
     <div style={{ marginLeft: level * 12 + 'px' }}>
       <div
@@ -861,7 +873,7 @@ function TreeNode({ node, level, children, isCollapsed, onToggle, isSelected, on
               maxWidth: '180px',
               flex: 1,
               minWidth: 0,
-              color: node.data.scoreColor || ((node.type === 'note' ? node.data.matchLevel === 'unsatisfied' : node.data.isSelected === false) ? '#ef4444' : '#374151'),
+              color: node.data.scoreColor || fontColor,
             }}
             title={node.data.title || node.id}
             >
@@ -1807,6 +1819,23 @@ function FlowContent() {
                           const nodeIsSelected = node.type === 'note' ? node.data.matchLevel !== 'unsatisfied' : node.data.isSelected !== false;
                           const nodeActualType = node.data.nodeType || node.type; // 获取实际节点类型
 
+                          // 计算路径节点字体颜色:根据分数提升幅度判断
+                          let pathFontColor = '#374151'; // 默认颜色
+                          if (node.type === 'note') {
+                            pathFontColor = node.data.matchLevel === 'unsatisfied' ? '#ef4444' : '#374151';
+                          } else if (node.data.seed_score !== undefined) {
+                            const parentScore = parseFloat(node.data.seed_score);
+                            const gain = nodeScore - parentScore;
+                            pathFontColor = gain >= 0.05 ? '#16a34a' : '#ef4444';
+                          } else if (index > 0) {
+                            const prevNode = path[index - 1];
+                            const prevScore = prevNode.data.score ? parseFloat(prevNode.data.score) : 0;
+                            const gain = nodeScore - prevScore;
+                            pathFontColor = gain >= 0.05 ? '#16a34a' : '#ef4444';
+                          } else if (node.data.isSelected === false) {
+                            pathFontColor = '#ef4444';
+                          }
+
                           return (
                           <React.Fragment key={node.id + '-' + index}>
                             <span
@@ -1892,7 +1921,7 @@ function FlowContent() {
                                 <span style={{
                                   flex: 1,
                                   fontSize: '12px',
-                                  color: nodeIsSelected ? '#374151' : '#ef4444',
+                                  color: pathFontColor,
                                 }}>
                                   {truncateMiddle(node.data.title || node.id, 18)}
                                 </span>