Ver código fonte

组合词和query词也增加检索

刘立冬 2 semanas atrás
pai
commit
7ddc8edaf1
2 arquivos alterados com 366 adições e 68 exclusões
  1. 281 0
      extract_topn_multimodal.py
  2. 85 68
      knowledge_search_traverse.py

+ 281 - 0
extract_topn_multimodal.py

@@ -0,0 +1,281 @@
+"""
+从 run_context_v3.json 中提取 topN 帖子并进行多模态解析
+
+功能:
+1. 读取 run_context_v3.json
+2. 提取所有帖子,按 final_score 排序,取 topN
+3. 使用 multimodal_extractor 进行图片内容解析
+4. 保存结果到独立的 JSON 文件
+
+参数化配置:
+- top_n: 提取前N个帖子(默认10)
+- max_concurrent: 最大并发数(默认5)
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Optional
+
+# 导入必要的模块
+from knowledge_search_traverse import Post
+from multimodal_extractor import extract_all_posts
+
+
def load_run_context(json_path: str) -> dict:
    """Read and parse the run_context_v3.json file at *json_path*."""
    with open(json_path, encoding='utf-8') as fp:
        data = json.load(fp)
    return data
+
+
def extract_all_posts_from_context(context_data: dict) -> list[dict]:
    """Collect every post from the context data, deduplicated by note_id.

    When the same note_id appears in several rounds/search results, the
    copy with the higher final_score wins; a scored copy always replaces
    an unscored one. Posts without a note_id are ignored.
    """
    best_by_id: dict = {}

    for rnd in context_data.get('rounds', []):
        for result in rnd.get('search_results', []):
            for candidate in result.get('post_list', []):
                nid = candidate.get('note_id')
                if not nid:
                    continue

                kept = best_by_id.get(nid)
                if kept is None:
                    # First sighting of this note_id.
                    best_by_id[nid] = candidate
                    continue

                kept_score = kept.get('final_score')
                new_score = candidate.get('final_score')
                # Replace when the kept copy is unscored, or the new copy
                # carries a strictly higher score.
                if kept_score is None or (new_score is not None and new_score > kept_score):
                    best_by_id[nid] = candidate

    return list(best_by_id.values())
+
+
def filter_and_sort_topn(posts: list[dict], top_n: int = 10) -> list[dict]:
    """Return the top_n posts ranked by final_score, highest first.

    Posts whose final_score is null are dropped before ranking; ties keep
    their original relative order (stable sort).
    """
    scored = [post for post in posts if post.get('final_score') is not None]
    scored.sort(key=lambda post: post['final_score'], reverse=True)
    return scored[:top_n]
+
+
def convert_to_post_objects(post_dicts: list[dict]) -> list[Post]:
    """Build Post objects from raw post dicts.

    The source data carries no 'type' field, so every Post is created
    with type='normal'; missing fields fall back to empty defaults.
    """
    return [
        Post(
            note_id=raw.get('note_id', ''),
            note_url=raw.get('note_url', ''),
            title=raw.get('title', ''),
            body_text=raw.get('body_text', ''),
            type='normal',  # default: source dicts lack this field
            images=raw.get('images', []),
            video=raw.get('video', ''),
            interact_info=raw.get('interact_info', {}),
        )
        for raw in post_dicts
    ]
+
+
def save_extraction_results(results: dict, output_path: str, topn_posts: list[dict]):
    """Persist multimodal extraction results to a JSON file.

    Args:
        results: mapping of note_id -> extraction object (expected
            attributes: note_id, note_url, title, body_text, type,
            extraction_time, and images whose items carry image_index,
            original_url, description, extract_text).
        output_path: destination path for the JSON file.
        topn_posts: original post dicts, used to recover each final_score.
    """
    # Index the original posts once so each result is an O(1) lookup,
    # instead of the previous linear scan per extraction (O(n*m)).
    posts_by_id = {post.get('note_id'): post for post in topn_posts}

    output_data = {
        'total_extracted': len(results),
        'extraction_results': []
    }

    for note_id, extraction in results.items():
        original_post = posts_by_id.get(note_id)

        # Flatten the extraction object (plus the original score, when
        # available) into a JSON-serializable entry.
        result_entry = {
            'note_id': extraction.note_id,
            'note_url': extraction.note_url,
            'title': extraction.title,
            'body_text': extraction.body_text,
            'type': extraction.type,
            'extraction_time': extraction.extraction_time,
            'final_score': original_post.get('final_score') if original_post else None,
            'images': [
                {
                    'image_index': img.image_index,
                    'original_url': img.original_url,
                    'description': img.description,
                    'extract_text': img.extract_text
                }
                for img in extraction.images
            ]
        }

        output_data['extraction_results'].append(result_entry)

    # ensure_ascii=False keeps CJK text readable in the output file.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    print(f"\n✅ 结果已保存到: {output_path}")
+
+
async def main(context_file_path: str, output_file_path: str, top_n: int = 10,
               max_concurrent: int = 5):
    """Pipeline entry point: load context, pick topN posts, run extraction.

    Args:
        context_file_path: path to run_context_v3.json
        output_file_path: path for the result JSON file
        top_n: number of top-scored posts to extract (default 10)
        max_concurrent: concurrency cap for image extraction (default 5)
    """
    banner = "=" * 80
    print(banner)
    print(f"多模态解析 - Top{top_n} 帖子")
    print(banner)

    # Load the raw run context from disk.
    print(f"\n📂 加载文件: {context_file_path}")
    context_data = load_run_context(context_file_path)

    # Gather all posts across rounds, deduplicated by note_id.
    print(f"\n🔍 提取所有帖子...")
    unique_posts = extract_all_posts_from_context(context_data)
    print(f"   去重后共找到 {len(unique_posts)} 个唯一帖子")

    # Keep only scored posts and rank them.
    print(f"\n📊 筛选 top{top_n} 帖子...")
    selected = filter_and_sort_topn(unique_posts, top_n)

    if not selected:
        print("   ⚠️  没有找到有效的帖子")
        return

    print(f"   Top{top_n} 帖子得分范围: {selected[-1].get('final_score')} ~ {selected[0].get('final_score')}")

    # Show the ranked shortlist before the (slow) extraction step.
    print(f"\n   Top{top_n} 帖子列表:")
    for rank, entry in enumerate(selected, 1):
        print(f"   {rank}. [{entry.get('final_score')}] {entry.get('title')[:40]}... ({entry.get('note_id')})")

    # Convert raw dicts into Post objects for the extractor.
    print(f"\n🔄 转换为 Post 对象...")
    posts = convert_to_post_objects(selected)
    print(f"   成功转换 {len(posts)} 个 Post 对象")

    # Run the multimodal image-content extraction.
    print(f"\n🖼️  开始多模态图片内容解析...")
    print(f"   (并发限制: {max_concurrent})")
    extraction_results = await extract_all_posts(
        posts,
        max_concurrent=max_concurrent
    )

    # Persist the results alongside each post's original final_score.
    print(f"\n💾 保存解析结果...")
    save_extraction_results(extraction_results, output_file_path, selected)

    print("\n" + banner)
    print("✅ 处理完成!")
    print(banner)
+
+
if __name__ == "__main__":
    # Build the command-line argument parser.
    # NOTE: description/epilog/help strings are user-facing CLI output and
    # are kept exactly as authored.
    parser = argparse.ArgumentParser(
        description='从 run_context_v3.json 中提取 topN 帖子并进行多模态解析',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
示例用法:
  # 使用默认参数 (top10, 并发5)
  python3 extract_topn_multimodal.py

  # 提取前20个帖子
  python3 extract_topn_multimodal.py --top-n 20

  # 自定义并发数
  python3 extract_topn_multimodal.py --top-n 15 --max-concurrent 10

  # 指定输入输出文件
  python3 extract_topn_multimodal.py -i input.json -o output.json --top-n 30
        '''
    )

    # Default input/output locations, overridable via -i / -o below.
    DEFAULT_CONTEXT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/run_context_v3.json"
    DEFAULT_OUTPUT_FILE = "input/test_case/output/knowledge_search_traverse/20251114/005215_b1/multimodal_extraction_topn.json"

    # Register the CLI options.
    parser.add_argument(
        '-i', '--input',
        dest='context_file',
        default=DEFAULT_CONTEXT_FILE,
        help=f'输入的 run_context_v3.json 文件路径 (默认: {DEFAULT_CONTEXT_FILE})'
    )
    parser.add_argument(
        '-o', '--output',
        dest='output_file',
        default=DEFAULT_OUTPUT_FILE,
        help=f'输出的 JSON 文件路径 (默认: {DEFAULT_OUTPUT_FILE})'
    )
    parser.add_argument(
        '-n', '--top-n',
        dest='top_n',
        type=int,
        default=10,
        help='提取前N个帖子 (默认: 10)'
    )
    parser.add_argument(
        '-c', '--max-concurrent',
        dest='max_concurrent',
        type=int,
        default=5,
        help='最大并发数 (默认: 5)'
    )

    # Parse the command line.
    args = parser.parse_args()

    # Fail fast when the input file does not exist.
    if not os.path.exists(args.context_file):
        print(f"❌ 错误: 文件不存在 - {args.context_file}")
        sys.exit(1)

    # Echo the effective configuration before starting.
    print(f"\n📋 参数配置:")
    print(f"   输入文件: {args.context_file}")
    print(f"   输出文件: {args.output_file}")
    print(f"   提取数量: Top{args.top_n}")
    print(f"   最大并发: {args.max_concurrent}")
    print()

    # Drive the async pipeline to completion.
    asyncio.run(main(
        args.context_file,
        args.output_file,
        args.top_n,
        args.max_concurrent
    ))

+ 85 - 68
knowledge_search_traverse.py

@@ -3642,91 +3642,85 @@ async def run_round_v2(
                     "type": "sug"
                 })
 
-    # 步骤3: 搜索高分SUG
-    print(f"\n[步骤3] 搜索高分SUG(阈值 > {sug_threshold})...")
-    high_score_sugs = [sug for sug in all_sugs if sug.score_with_o > sug_threshold]
-    print(f"  找到 {len(high_score_sugs)} 个高分SUG")
-
-    search_list = []
-    # extraction_results = {}  # 内容提取流程已断开
-
-    if len(high_score_sugs) > 0:
-        async def search_for_sug(sug: Sug) -> Search:
-            """返回Search结果"""
-            print(f"    搜索: {sug.text}")
-            # post_extractions = {}  # 内容提取流程已断开
-
-            try:
-                search_result = xiaohongshu_search.search(keyword=sug.text)
-                # xiaohongshu_search.search() 已经返回解析后的数据
-                notes = search_result.get("data", {}).get("data", [])
-                post_list = []
-                for note in notes[:10]:
+    # 定义通用搜索函数(供步骤2.5、3、5.5共用)
+    async def search_keyword(text: str, score: float, source_type: str) -> Search:
+        """通用搜索函数"""
+        print(f"    搜索: {text} (来源: {source_type})")
+        try:
+            search_result = xiaohongshu_search.search(keyword=text)
+            notes = search_result.get("data", {}).get("data", [])
+            post_list = []
+
+            for note in notes[:10]:
+                try:
+                    post = process_note_data(note)
+                    post_list.append(post)
+                except Exception as e:
+                    print(f"      ⚠️  解析帖子失败 {note.get('id', 'unknown')}: {str(e)[:50]}")
+
+            # 补充详情信息(仅视频类型需要补充视频URL)
+            video_posts = [p for p in post_list if p.type == "video"]
+            if video_posts:
+                print(f"      补充详情({len(video_posts)}个视频)...")
+                for post in video_posts:
                     try:
-                        post = process_note_data(note)
-
-                        # # 🆕 多模态提取(搜索后立即处理) - 内容提取流程已断开
-                        # if post.type == "normal" and len(post.images) > 0:
-                        #     extraction = await extract_post_images(post)
-                        #     if extraction:
-                        #         post_extractions[post.note_id] = extraction
-
-                        post_list.append(post)
+                        detail_response = xiaohongshu_detail.get_detail(post.note_id)
+                        enrich_post_with_detail(post, detail_response)
                     except Exception as e:
-                        print(f"      ⚠️  解析帖子失败 {note.get('id', 'unknown')}: {str(e)[:50]}")
+                        print(f"        ⚠️  详情补充失败 {post.note_id}: {str(e)[:50]}")
 
-                # 补充详情信息(仅视频类型需要补充视频URL)
-                video_posts = [p for p in post_list if p.type == "video"]
-                if video_posts:
-                    print(f"      补充详情({len(video_posts)}个视频)...")
-                    for post in video_posts:
-                        try:
-                            detail_response = xiaohongshu_detail.get_detail(post.note_id)
-                            enrich_post_with_detail(post, detail_response)
-                        except Exception as e:
-                            print(f"        ⚠️  详情补充失败 {post.note_id}: {str(e)[:50]}")
+            print(f"      → 找到 {len(post_list)} 个帖子")
+            return Search(text=text, score_with_o=score, post_list=post_list)
+        except Exception as e:
+            print(f"      ✗ 搜索失败: {e}")
+            return Search(text=text, score_with_o=score, post_list=[])
 
-                print(f"      → 找到 {len(post_list)} 个帖子")
+    # 初始化search_list
+    search_list = []
 
-                return Search(
-                    text=sug.text,
-                    score_with_o=sug.score_with_o,
-                    from_q=sug.from_q,
-                    post_list=post_list
-                )
-                # , post_extractions  # 内容提取流程已断开
+    # 步骤2.5: 搜索高分query_input
+    print(f"\n[步骤2.5] 搜索高分输入query(阈值 > {sug_threshold})...")
+    high_score_queries = [q for q in query_input if q.score_with_o > sug_threshold]
+    print(f"  找到 {len(high_score_queries)} 个高分输入query")
 
-            except Exception as e:
-                print(f"      ✗ 搜索失败: {e}")
-                return Search(
-                    text=sug.text,
-                    score_with_o=sug.score_with_o,
-                    from_q=sug.from_q,
-                    post_list=[]
-                )
-                # , {}  # 内容提取流程已断开
+    if high_score_queries:
+        query_search_tasks = [search_keyword(q.text, q.score_with_o, "query_input")
+                              for q in high_score_queries]
+        query_searches = await asyncio.gather(*query_search_tasks)
+        search_list.extend(query_searches)
 
-        search_tasks = [search_for_sug(sug) for sug in high_score_sugs]
-        results = await asyncio.gather(*search_tasks)
+        # 评估搜索结果中的帖子
+        if enable_evaluation:
+            print(f"\n[评估] 评估query_input搜索结果中的帖子...")
+            for search in query_searches:
+                if search.post_list:
+                    print(f"  评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
+                    for post in search.post_list:
+                        knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
+                        if knowledge_eval:
+                            apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
 
-        # 收集搜索结果
-        for search in results:
-            search_list.append(search)
-            # extraction_results.update(extractions)  # 内容提取流程已断开
+    # 步骤3: 搜索高分SUG
+    print(f"\n[步骤3] 搜索高分SUG(阈值 > {sug_threshold})...")
+    high_score_sugs = [sug for sug in all_sugs if sug.score_with_o > sug_threshold]
+    print(f"  找到 {len(high_score_sugs)} 个高分SUG")
+
+    if high_score_sugs:
+        sug_search_tasks = [search_keyword(sug.text, sug.score_with_o, "sug")
+                            for sug in high_score_sugs]
+        sug_searches = await asyncio.gather(*sug_search_tasks)
+        search_list.extend(sug_searches)
 
         # 评估搜索结果中的帖子
         if enable_evaluation:
-            print(f"\n[评估] 评估搜索结果中的帖子...")
-            for search in search_list:
+            print(f"\n[评估] 评估SUG搜索结果中的帖子...")
+            for search in sug_searches:
                 if search.post_list:
                     print(f"  评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
-                    # 对每个帖子进行评估 (V3)
                     for post in search.post_list:
                         knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
                         if knowledge_eval:
                             apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
-        else:
-            print(f"\n[评估] 实时评估已关闭 (使用 --enable-evaluation 启用)")
 
     # 步骤4: 生成N域组合
     print(f"\n[步骤4] 生成{round_num}域组合...")
@@ -3824,6 +3818,29 @@ async def run_round_v2(
             comb.score_with_o > score for score in flat_scores
         )
 
+    # 步骤5.5: 搜索高分组合词
+    print(f"\n[步骤5.5] 搜索高分组合词(阈值 > {sug_threshold})...")
+    high_score_combinations = [comb for comb in domain_combinations
+                               if comb.score_with_o > sug_threshold]
+    print(f"  找到 {len(high_score_combinations)} 个高分组合词")
+
+    if high_score_combinations:
+        comb_search_tasks = [search_keyword(comb.text, comb.score_with_o, "combination")
+                             for comb in high_score_combinations]
+        comb_searches = await asyncio.gather(*comb_search_tasks)
+        search_list.extend(comb_searches)
+
+        # 评估搜索结果中的帖子
+        if enable_evaluation:
+            print(f"\n[评估] 评估组合词搜索结果中的帖子...")
+            for search in comb_searches:
+                if search.post_list:
+                    print(f"  评估来自 '{search.text}' 的 {len(search.post_list)} 个帖子")
+                    for post in search.post_list:
+                        knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level = await evaluate_post_v3(post, o, semaphore=None)
+                        if knowledge_eval:
+                            apply_evaluation_v3_to_post(post, knowledge_eval, content_eval, purpose_eval, category_eval, final_score, match_level)
+
     # 步骤6: 构建 q_list_next(组合 + 高分SUG)
     print(f"\n[步骤6] 生成下轮输入...")
     q_list_next: list[Q] = []