刘立冬 3 주 전
부모
커밋
649888705e
3개의 변경된 파일139개의 추가작업 그리고 75개의 파일을 삭제
  1. 3 0
      lib/my_trace.py
  2. 124 71
      sug_v6_1_2_115.py
  3. 12 4
      visualization/sug_v6_1_2_8/convert_v8_to_graph_v3.js

+ 3 - 0
lib/my_trace.py

@@ -33,6 +33,9 @@ def set_trace():
     # 确保根logger级别生效
     logging.getLogger().setLevel(logging.WARNING)
 
+    # 禁用 openai.agents 的 WARNING 日志,避免大量 "OPENAI_API_KEY is not set" 警告
+    logging.getLogger('openai.agents').setLevel(logging.ERROR)
+
     # 临时绕过 logfire
     try:
         return set_trace_logfire()

+ 124 - 71
sug_v6_1_2_115.py

@@ -17,6 +17,26 @@ from script.search_recommendations.xiaohongshu_search_recommendations import Xia
 from script.search.xiaohongshu_search import XiaohongshuSearch
 
 
+# ============================================================================
+# 日志工具类
+# ============================================================================
+
+class TeeLogger:
+    """同时输出到控制台和日志文件的工具类"""
+    def __init__(self, stdout, log_file):
+        self.stdout = stdout
+        self.log_file = log_file
+
+    def write(self, message):
+        self.stdout.write(message)
+        self.log_file.write(message)
+        self.log_file.flush()  # 实时写入,避免丢失日志
+
+    def flush(self):
+        self.stdout.flush()
+        self.log_file.flush()
+
+
 # ============================================================================
 # 数据模型
 # ============================================================================
@@ -1005,14 +1025,19 @@ async def run_round(
 
         print(f"      评估完成,得到 {len(top_5)} 个组合")
 
-        # 将Top 5全部加入q_list_next(去重检查)
+        # 将Top 5全部加入q_list_next(去重检查 + 得分过滤
         for comb in top_5:
+            # 得分过滤:只有得分大于种子得分的组合词才加入下一轮
+            if comb['score'] <= seed.score_with_o:
+                print(f"        ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed.score_with_o:.2f})")
+                continue
+
             # 去重检查
             if comb['query'] in existing_q_texts:
                 print(f"        ⊗ 跳过重复: {comb['query']}")
                 continue
 
-            print(f"        ✓ {comb['query']} (分数: {comb['score']:.2f})")
+            print(f"        ✓ {comb['query']} (分数: {comb['score']:.2f} > 种子: {seed.score_with_o:.2f})")
 
             new_q = Q(
                 text=comb['query'],
@@ -1040,6 +1065,9 @@ async def run_round(
         ]
 
         # 保存到all_seed_combinations(用于构建seed_list_next)
+        # 附加seed_score,用于后续过滤
+        for comb in top_5:
+            comb['seed_score'] = seed.score_with_o
         all_seed_combinations.extend(top_5)
 
     # 4.2 对于sug_list_list中,每个sug大于来自的query分数,加到q_list_next(去重检查)
@@ -1066,9 +1094,15 @@ async def run_round(
     seed_list_next = []
     existing_seed_texts = set()
 
-    # 5.1 加入本轮所有组合词
-    print(f"  5.1 加入本轮所有组合词...")
+    # 5.1 加入本轮所有组合词(只加入得分提升的)
+    print(f"  5.1 加入本轮所有组合词(得分过滤)...")
     for comb in all_seed_combinations:
+        # 得分过滤:只有得分大于种子得分的组合词才作为下一轮种子
+        seed_score = comb.get('seed_score', 0)
+        if comb['score'] <= seed_score:
+            print(f"    ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed_score:.2f})")
+            continue
+
         if comb['query'] not in existing_seed_texts:
             new_seed = Seed(
                 text=comb['query'],
@@ -1078,7 +1112,7 @@ async def run_round(
             )
             seed_list_next.append(new_seed)
             existing_seed_texts.add(comb['query'])
-            print(f"    ✓ {comb['query']} (分数: {comb['score']:.2f})")
+            print(f"    ✓ {comb['query']} (分数: {comb['score']:.2f} > 种子: {seed_score:.2f})")
 
     # 5.2 加入高分sug
     print(f"  5.2 加入高分sug...")
@@ -1240,78 +1274,97 @@ async def main(input_dir: str, max_rounds: int = 2, sug_threshold: float = 0.7,
         log_url=log_url,
     )
 
-    # 执行迭代
-    all_search_list = await iterative_loop(
-        run_context,
-        max_rounds=max_rounds,
-        sug_threshold=sug_threshold
-    )
+    # 创建日志目录
+    os.makedirs(run_context.log_dir, exist_ok=True)
 
-    # 格式化输出
-    output = f"原始需求:{run_context.c}\n"
-    output += f"原始问题:{run_context.o}\n"
-    output += f"总搜索次数:{len(all_search_list)}\n"
-    output += f"总帖子数:{sum(len(s.post_list) for s in all_search_list)}\n"
-    output += "\n" + "="*60 + "\n"
-
-    if all_search_list:
-        output += "【搜索结果】\n\n"
-        for idx, search in enumerate(all_search_list, 1):
-            output += f"{idx}. 搜索词: {search.text} (分数: {search.score_with_o:.2f})\n"
-            output += f"   帖子数: {len(search.post_list)}\n"
-            if search.post_list:
-                for post_idx, post in enumerate(search.post_list[:3], 1):  # 只显示前3个
-                    output += f"   {post_idx}) {post.title}\n"
-                    output += f"      URL: {post.note_url}\n"
-            output += "\n"
-    else:
-        output += "未找到搜索结果\n"
+    # 配置日志文件
+    log_file_path = os.path.join(run_context.log_dir, "run.log")
+    log_file = open(log_file_path, 'w', encoding='utf-8')
 
-    run_context.final_output = output
+    # 重定向stdout到TeeLogger(同时输出到控制台和文件)
+    original_stdout = sys.stdout
+    sys.stdout = TeeLogger(original_stdout, log_file)
 
-    print(f"\n{'='*60}")
-    print("最终结果")
-    print(f"{'='*60}")
-    print(output)
+    try:
+        print(f"📝 日志文件: {log_file_path}")
+        print(f"{'='*60}\n")
 
-    # 保存日志
-    os.makedirs(run_context.log_dir, exist_ok=True)
+        # 执行迭代
+        all_search_list = await iterative_loop(
+            run_context,
+            max_rounds=max_rounds,
+            sug_threshold=sug_threshold
+        )
 
-    context_file_path = os.path.join(run_context.log_dir, "run_context.json")
-    context_dict = run_context.model_dump()
-    with open(context_file_path, "w", encoding="utf-8") as f:
-        json.dump(context_dict, f, ensure_ascii=False, indent=2)
-    print(f"\nRunContext saved to: {context_file_path}")
-
-    # 保存详细的搜索结果
-    search_results_path = os.path.join(run_context.log_dir, "search_results.json")
-    search_results_data = [s.model_dump() for s in all_search_list]
-    with open(search_results_path, "w", encoding="utf-8") as f:
-        json.dump(search_results_data, f, ensure_ascii=False, indent=2)
-    print(f"Search results saved to: {search_results_path}")
-
-    # 可视化
-    if visualize:
-        import subprocess
-        output_html = os.path.join(run_context.log_dir, "visualization.html")
-        print(f"\n🎨 生成可视化HTML...")
-
-        # 获取绝对路径
-        abs_context_file = os.path.abspath(context_file_path)
-        abs_output_html = os.path.abspath(output_html)
-
-        # 运行可视化脚本
-        result = subprocess.run([
-            "node",
-            "visualization/sug_v6_1_2_8/index.js",
-            abs_context_file,
-            abs_output_html
-        ])
-
-        if result.returncode == 0:
-            print(f"✅ 可视化已生成: {output_html}")
+        # 格式化输出
+        output = f"原始需求:{run_context.c}\n"
+        output += f"原始问题:{run_context.o}\n"
+        output += f"总搜索次数:{len(all_search_list)}\n"
+        output += f"总帖子数:{sum(len(s.post_list) for s in all_search_list)}\n"
+        output += "\n" + "="*60 + "\n"
+
+        if all_search_list:
+            output += "【搜索结果】\n\n"
+            for idx, search in enumerate(all_search_list, 1):
+                output += f"{idx}. 搜索词: {search.text} (分数: {search.score_with_o:.2f})\n"
+                output += f"   帖子数: {len(search.post_list)}\n"
+                if search.post_list:
+                    for post_idx, post in enumerate(search.post_list[:3], 1):  # 只显示前3个
+                        output += f"   {post_idx}) {post.title}\n"
+                        output += f"      URL: {post.note_url}\n"
+                output += "\n"
         else:
-            print(f"❌ 可视化生成失败")
+            output += "未找到搜索结果\n"
+
+        run_context.final_output = output
+
+        print(f"\n{'='*60}")
+        print("最终结果")
+        print(f"{'='*60}")
+        print(output)
+
+        # 保存上下文文件
+        context_file_path = os.path.join(run_context.log_dir, "run_context.json")
+        context_dict = run_context.model_dump()
+        with open(context_file_path, "w", encoding="utf-8") as f:
+            json.dump(context_dict, f, ensure_ascii=False, indent=2)
+        print(f"\nRunContext saved to: {context_file_path}")
+
+        # 保存详细的搜索结果
+        search_results_path = os.path.join(run_context.log_dir, "search_results.json")
+        search_results_data = [s.model_dump() for s in all_search_list]
+        with open(search_results_path, "w", encoding="utf-8") as f:
+            json.dump(search_results_data, f, ensure_ascii=False, indent=2)
+        print(f"Search results saved to: {search_results_path}")
+
+        # 可视化
+        if visualize:
+            import subprocess
+            output_html = os.path.join(run_context.log_dir, "visualization.html")
+            print(f"\n🎨 生成可视化HTML...")
+
+            # 获取绝对路径
+            abs_context_file = os.path.abspath(context_file_path)
+            abs_output_html = os.path.abspath(output_html)
+
+            # 运行可视化脚本
+            result = subprocess.run([
+                "node",
+                "visualization/sug_v6_1_2_8/index.js",
+                abs_context_file,
+                abs_output_html
+            ])
+
+            if result.returncode == 0:
+                print(f"✅ 可视化已生成: {output_html}")
+            else:
+                print(f"❌ 可视化生成失败")
+
+    finally:
+        # 恢复stdout
+        sys.stdout = original_stdout
+        log_file.close()
+        print(f"\n📝 运行日志已保存: {log_file_path}")
 
 
 if __name__ == "__main__":

+ 12 - 4
visualization/sug_v6_1_2_8/convert_v8_to_graph_v3.js

@@ -339,10 +339,18 @@ function convertV8ToGraphV2(runContext, searchResults) {
         Object.keys(round.add_word_details).forEach((seedText, seedIndex) => {
           const seedId = `seed_${seedText}_r${roundNum}_${seedIndex}`;
 
-          // 查找seed的来源信息 - 从Round 0的seed_list查找基础种子的from_type
-          const round0 = rounds.find(r => r.round_num === 0 || r.type === 'initialization');
-          const seedInfo = round0?.seed_list?.find(s => s.text === seedText) || {};
-          const fromType = seedInfo.from_type || 'unknown';
+          // 查找seed的来源信息和分数 - 动态从正确的轮次查找
+          let seedInfo = {};
+          if (roundNum === 1) {
+            // Round 1:种子来自 Round 0 的 seed_list
+            const round0 = rounds.find(r => r.round_num === 0 || r.type === 'initialization');
+            seedInfo = round0?.seed_list?.find(s => s.text === seedText) || {};
+          } else {
+            // Round 2+:种子来自前一轮的 seed_list_next
+            const prevRound = rounds.find(r => r.round_num === roundNum - 1);
+            seedInfo = prevRound?.seed_list_next?.find(s => s.text === seedText) || {};
+          }
+          const fromType = seedInfo.from_type || seedInfo.from || 'unknown';
 
           // 根据来源设置strategy
           let strategy;