|
@@ -17,6 +17,26 @@ from script.search_recommendations.xiaohongshu_search_recommendations import Xia
|
|
|
from script.search.xiaohongshu_search import XiaohongshuSearch
|
|
from script.search.xiaohongshu_search import XiaohongshuSearch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# ============================================================================
|
|
|
|
|
+# 日志工具类
|
|
|
|
|
+# ============================================================================
|
|
|
|
|
+
|
|
|
|
|
+class TeeLogger:
|
|
|
|
|
+ """同时输出到控制台和日志文件的工具类"""
|
|
|
|
|
+ def __init__(self, stdout, log_file):
|
|
|
|
|
+ self.stdout = stdout
|
|
|
|
|
+ self.log_file = log_file
|
|
|
|
|
+
|
|
|
|
|
+ def write(self, message):
|
|
|
|
|
+ self.stdout.write(message)
|
|
|
|
|
+ self.log_file.write(message)
|
|
|
|
|
+ self.log_file.flush() # 实时写入,避免丢失日志
|
|
|
|
|
+
|
|
|
|
|
+ def flush(self):
|
|
|
|
|
+ self.stdout.flush()
|
|
|
|
|
+ self.log_file.flush()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
# ============================================================================
|
|
# ============================================================================
|
|
|
# 数据模型
|
|
# 数据模型
|
|
|
# ============================================================================
|
|
# ============================================================================
|
|
@@ -1005,14 +1025,19 @@ async def run_round(
|
|
|
|
|
|
|
|
print(f" 评估完成,得到 {len(top_5)} 个组合")
|
|
print(f" 评估完成,得到 {len(top_5)} 个组合")
|
|
|
|
|
|
|
|
- # 将Top 5全部加入q_list_next(去重检查)
|
|
|
|
|
|
|
+ # 将Top 5全部加入q_list_next(去重检查 + 得分过滤)
|
|
|
for comb in top_5:
|
|
for comb in top_5:
|
|
|
|
|
+ # 得分过滤:只有得分大于种子得分的组合词才加入下一轮
|
|
|
|
|
+ if comb['score'] <= seed.score_with_o:
|
|
|
|
|
+ print(f" ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed.score_with_o:.2f})")
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
# 去重检查
|
|
# 去重检查
|
|
|
if comb['query'] in existing_q_texts:
|
|
if comb['query'] in existing_q_texts:
|
|
|
print(f" ⊗ 跳过重复: {comb['query']}")
|
|
print(f" ⊗ 跳过重复: {comb['query']}")
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- print(f" ✓ {comb['query']} (分数: {comb['score']:.2f})")
|
|
|
|
|
|
|
+ print(f" ✓ {comb['query']} (分数: {comb['score']:.2f} > 种子: {seed.score_with_o:.2f})")
|
|
|
|
|
|
|
|
new_q = Q(
|
|
new_q = Q(
|
|
|
text=comb['query'],
|
|
text=comb['query'],
|
|
@@ -1040,6 +1065,9 @@ async def run_round(
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
# 保存到all_seed_combinations(用于构建seed_list_next)
|
|
# 保存到all_seed_combinations(用于构建seed_list_next)
|
|
|
|
|
+ # 附加seed_score,用于后续过滤
|
|
|
|
|
+ for comb in top_5:
|
|
|
|
|
+ comb['seed_score'] = seed.score_with_o
|
|
|
all_seed_combinations.extend(top_5)
|
|
all_seed_combinations.extend(top_5)
|
|
|
|
|
|
|
|
# 4.2 对于sug_list_list中,每个sug大于来自的query分数,加到q_list_next(去重检查)
|
|
# 4.2 对于sug_list_list中,每个sug大于来自的query分数,加到q_list_next(去重检查)
|
|
@@ -1066,9 +1094,15 @@ async def run_round(
|
|
|
seed_list_next = []
|
|
seed_list_next = []
|
|
|
existing_seed_texts = set()
|
|
existing_seed_texts = set()
|
|
|
|
|
|
|
|
- # 5.1 加入本轮所有组合词
|
|
|
|
|
- print(f" 5.1 加入本轮所有组合词...")
|
|
|
|
|
|
|
+ # 5.1 加入本轮所有组合词(只加入得分提升的)
|
|
|
|
|
+ print(f" 5.1 加入本轮所有组合词(得分过滤)...")
|
|
|
for comb in all_seed_combinations:
|
|
for comb in all_seed_combinations:
|
|
|
|
|
+ # 得分过滤:只有得分大于种子得分的组合词才作为下一轮种子
|
|
|
|
|
+ seed_score = comb.get('seed_score', 0)
|
|
|
|
|
+ if comb['score'] <= seed_score:
|
|
|
|
|
+ print(f" ⊗ 跳过低分: {comb['query']} (分数{comb['score']:.2f} ≤ 种子{seed_score:.2f})")
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
if comb['query'] not in existing_seed_texts:
|
|
if comb['query'] not in existing_seed_texts:
|
|
|
new_seed = Seed(
|
|
new_seed = Seed(
|
|
|
text=comb['query'],
|
|
text=comb['query'],
|
|
@@ -1078,7 +1112,7 @@ async def run_round(
|
|
|
)
|
|
)
|
|
|
seed_list_next.append(new_seed)
|
|
seed_list_next.append(new_seed)
|
|
|
existing_seed_texts.add(comb['query'])
|
|
existing_seed_texts.add(comb['query'])
|
|
|
- print(f" ✓ {comb['query']} (分数: {comb['score']:.2f})")
|
|
|
|
|
|
|
+ print(f" ✓ {comb['query']} (分数: {comb['score']:.2f} > 种子: {seed_score:.2f})")
|
|
|
|
|
|
|
|
# 5.2 加入高分sug
|
|
# 5.2 加入高分sug
|
|
|
print(f" 5.2 加入高分sug...")
|
|
print(f" 5.2 加入高分sug...")
|
|
@@ -1240,78 +1274,97 @@ async def main(input_dir: str, max_rounds: int = 2, sug_threshold: float = 0.7,
|
|
|
log_url=log_url,
|
|
log_url=log_url,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # 执行迭代
|
|
|
|
|
- all_search_list = await iterative_loop(
|
|
|
|
|
- run_context,
|
|
|
|
|
- max_rounds=max_rounds,
|
|
|
|
|
- sug_threshold=sug_threshold
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 创建日志目录
|
|
|
|
|
+ os.makedirs(run_context.log_dir, exist_ok=True)
|
|
|
|
|
|
|
|
- # 格式化输出
|
|
|
|
|
- output = f"原始需求:{run_context.c}\n"
|
|
|
|
|
- output += f"原始问题:{run_context.o}\n"
|
|
|
|
|
- output += f"总搜索次数:{len(all_search_list)}\n"
|
|
|
|
|
- output += f"总帖子数:{sum(len(s.post_list) for s in all_search_list)}\n"
|
|
|
|
|
- output += "\n" + "="*60 + "\n"
|
|
|
|
|
-
|
|
|
|
|
- if all_search_list:
|
|
|
|
|
- output += "【搜索结果】\n\n"
|
|
|
|
|
- for idx, search in enumerate(all_search_list, 1):
|
|
|
|
|
- output += f"{idx}. 搜索词: {search.text} (分数: {search.score_with_o:.2f})\n"
|
|
|
|
|
- output += f" 帖子数: {len(search.post_list)}\n"
|
|
|
|
|
- if search.post_list:
|
|
|
|
|
- for post_idx, post in enumerate(search.post_list[:3], 1): # 只显示前3个
|
|
|
|
|
- output += f" {post_idx}) {post.title}\n"
|
|
|
|
|
- output += f" URL: {post.note_url}\n"
|
|
|
|
|
- output += "\n"
|
|
|
|
|
- else:
|
|
|
|
|
- output += "未找到搜索结果\n"
|
|
|
|
|
|
|
+ # 配置日志文件
|
|
|
|
|
+ log_file_path = os.path.join(run_context.log_dir, "run.log")
|
|
|
|
|
+ log_file = open(log_file_path, 'w', encoding='utf-8')
|
|
|
|
|
|
|
|
- run_context.final_output = output
|
|
|
|
|
|
|
+ # 重定向stdout到TeeLogger(同时输出到控制台和文件)
|
|
|
|
|
+ original_stdout = sys.stdout
|
|
|
|
|
+ sys.stdout = TeeLogger(original_stdout, log_file)
|
|
|
|
|
|
|
|
- print(f"\n{'='*60}")
|
|
|
|
|
- print("最终结果")
|
|
|
|
|
- print(f"{'='*60}")
|
|
|
|
|
- print(output)
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ print(f"📝 日志文件: {log_file_path}")
|
|
|
|
|
+ print(f"{'='*60}\n")
|
|
|
|
|
|
|
|
- # 保存日志
|
|
|
|
|
- os.makedirs(run_context.log_dir, exist_ok=True)
|
|
|
|
|
|
|
+ # 执行迭代
|
|
|
|
|
+ all_search_list = await iterative_loop(
|
|
|
|
|
+ run_context,
|
|
|
|
|
+ max_rounds=max_rounds,
|
|
|
|
|
+ sug_threshold=sug_threshold
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- context_file_path = os.path.join(run_context.log_dir, "run_context.json")
|
|
|
|
|
- context_dict = run_context.model_dump()
|
|
|
|
|
- with open(context_file_path, "w", encoding="utf-8") as f:
|
|
|
|
|
- json.dump(context_dict, f, ensure_ascii=False, indent=2)
|
|
|
|
|
- print(f"\nRunContext saved to: {context_file_path}")
|
|
|
|
|
-
|
|
|
|
|
- # 保存详细的搜索结果
|
|
|
|
|
- search_results_path = os.path.join(run_context.log_dir, "search_results.json")
|
|
|
|
|
- search_results_data = [s.model_dump() for s in all_search_list]
|
|
|
|
|
- with open(search_results_path, "w", encoding="utf-8") as f:
|
|
|
|
|
- json.dump(search_results_data, f, ensure_ascii=False, indent=2)
|
|
|
|
|
- print(f"Search results saved to: {search_results_path}")
|
|
|
|
|
-
|
|
|
|
|
- # 可视化
|
|
|
|
|
- if visualize:
|
|
|
|
|
- import subprocess
|
|
|
|
|
- output_html = os.path.join(run_context.log_dir, "visualization.html")
|
|
|
|
|
- print(f"\n🎨 生成可视化HTML...")
|
|
|
|
|
-
|
|
|
|
|
- # 获取绝对路径
|
|
|
|
|
- abs_context_file = os.path.abspath(context_file_path)
|
|
|
|
|
- abs_output_html = os.path.abspath(output_html)
|
|
|
|
|
-
|
|
|
|
|
- # 运行可视化脚本
|
|
|
|
|
- result = subprocess.run([
|
|
|
|
|
- "node",
|
|
|
|
|
- "visualization/sug_v6_1_2_8/index.js",
|
|
|
|
|
- abs_context_file,
|
|
|
|
|
- abs_output_html
|
|
|
|
|
- ])
|
|
|
|
|
-
|
|
|
|
|
- if result.returncode == 0:
|
|
|
|
|
- print(f"✅ 可视化已生成: {output_html}")
|
|
|
|
|
|
|
+ # 格式化输出
|
|
|
|
|
+ output = f"原始需求:{run_context.c}\n"
|
|
|
|
|
+ output += f"原始问题:{run_context.o}\n"
|
|
|
|
|
+ output += f"总搜索次数:{len(all_search_list)}\n"
|
|
|
|
|
+ output += f"总帖子数:{sum(len(s.post_list) for s in all_search_list)}\n"
|
|
|
|
|
+ output += "\n" + "="*60 + "\n"
|
|
|
|
|
+
|
|
|
|
|
+ if all_search_list:
|
|
|
|
|
+ output += "【搜索结果】\n\n"
|
|
|
|
|
+ for idx, search in enumerate(all_search_list, 1):
|
|
|
|
|
+ output += f"{idx}. 搜索词: {search.text} (分数: {search.score_with_o:.2f})\n"
|
|
|
|
|
+ output += f" 帖子数: {len(search.post_list)}\n"
|
|
|
|
|
+ if search.post_list:
|
|
|
|
|
+ for post_idx, post in enumerate(search.post_list[:3], 1): # 只显示前3个
|
|
|
|
|
+ output += f" {post_idx}) {post.title}\n"
|
|
|
|
|
+ output += f" URL: {post.note_url}\n"
|
|
|
|
|
+ output += "\n"
|
|
|
else:
|
|
else:
|
|
|
- print(f"❌ 可视化生成失败")
|
|
|
|
|
|
|
+ output += "未找到搜索结果\n"
|
|
|
|
|
+
|
|
|
|
|
+ run_context.final_output = output
|
|
|
|
|
+
|
|
|
|
|
+ print(f"\n{'='*60}")
|
|
|
|
|
+ print("最终结果")
|
|
|
|
|
+ print(f"{'='*60}")
|
|
|
|
|
+ print(output)
|
|
|
|
|
+
|
|
|
|
|
+ # 保存上下文文件
|
|
|
|
|
+ context_file_path = os.path.join(run_context.log_dir, "run_context.json")
|
|
|
|
|
+ context_dict = run_context.model_dump()
|
|
|
|
|
+ with open(context_file_path, "w", encoding="utf-8") as f:
|
|
|
|
|
+ json.dump(context_dict, f, ensure_ascii=False, indent=2)
|
|
|
|
|
+ print(f"\nRunContext saved to: {context_file_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # 保存详细的搜索结果
|
|
|
|
|
+ search_results_path = os.path.join(run_context.log_dir, "search_results.json")
|
|
|
|
|
+ search_results_data = [s.model_dump() for s in all_search_list]
|
|
|
|
|
+ with open(search_results_path, "w", encoding="utf-8") as f:
|
|
|
|
|
+ json.dump(search_results_data, f, ensure_ascii=False, indent=2)
|
|
|
|
|
+ print(f"Search results saved to: {search_results_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # 可视化
|
|
|
|
|
+ if visualize:
|
|
|
|
|
+ import subprocess
|
|
|
|
|
+ output_html = os.path.join(run_context.log_dir, "visualization.html")
|
|
|
|
|
+ print(f"\n🎨 生成可视化HTML...")
|
|
|
|
|
+
|
|
|
|
|
+ # 获取绝对路径
|
|
|
|
|
+ abs_context_file = os.path.abspath(context_file_path)
|
|
|
|
|
+ abs_output_html = os.path.abspath(output_html)
|
|
|
|
|
+
|
|
|
|
|
+ # 运行可视化脚本
|
|
|
|
|
+ result = subprocess.run([
|
|
|
|
|
+ "node",
|
|
|
|
|
+ "visualization/sug_v6_1_2_8/index.js",
|
|
|
|
|
+ abs_context_file,
|
|
|
|
|
+ abs_output_html
|
|
|
|
|
+ ])
|
|
|
|
|
+
|
|
|
|
|
+ if result.returncode == 0:
|
|
|
|
|
+ print(f"✅ 可视化已生成: {output_html}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f"❌ 可视化生成失败")
|
|
|
|
|
+
|
|
|
|
|
+ finally:
|
|
|
|
|
+ # 恢复stdout
|
|
|
|
|
+ sys.stdout = original_stdout
|
|
|
|
|
+ log_file.close()
|
|
|
|
|
+ print(f"\n📝 运行日志已保存: {log_file_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|