刘立冬 3 semanas atrás
pai
commit
34b2a0c654
1 arquivos alterados com 98 adições e 27 exclusões
  1. 98 27
      sug_v6_1_2_122.py

+ 98 - 27
sug_v6_1_2_122.py

@@ -76,6 +76,10 @@ class DomainCombination(BaseModel):
     score_with_o: float = 0.0  # 与原始问题的评分
     reason: str = ""  # 评分理由
     from_segments: list[str] = Field(default_factory=list)  # 来源segment的文本列表
+    source_word_details: list[dict] = Field(default_factory=list)  # 词及其得分信息 [{"domain_index":0,"segment_type":"","segment_text":"","words":[{"text":"","score":0.0}]}]
+    source_scores: list[float] = Field(default_factory=list)  # 来源词的分数列表(扁平化)
+    max_source_score: float | None = None  # 来源词的最高分
+    is_above_source_scores: bool = False  # 组合得分是否超过所有来源词
 
 
 # ============================================================================
@@ -2367,30 +2371,43 @@ async def run_round_v2(
     for i, comb in enumerate(domain_combinations, 1):
         print(f"    {i}. {comb.text} {comb.type_label} (分数: {comb.score_with_o:.2f})")
 
-    # 步骤6: 构建 q_list_next(组合 + 高分SUG)
-    print(f"\n[步骤6] 生成下轮输入...")
-    q_list_next = []
-
-    # 6.1 添加高分组合
-    high_score_combinations = [comb for comb in domain_combinations if comb.score_with_o > REQUIRED_SCORE_GAIN]
-    for comb in high_score_combinations:
-        # 生成域字符串,如 "D0,D3"
-        domains_str = ','.join([f'D{d}' for d in comb.domains]) if comb.domains else ''
-
-        q = Q(
-            text=comb.text,
-            score_with_o=comb.score_with_o,
-            reason=comb.reason,
-            from_source="domain_comb",
-            type_label=comb.type_label,
-            domain_type=domains_str  # 添加域信息
+    # 为每个组合补充来源词分数信息,并判断是否超过所有来源词得分
+    for comb in domain_combinations:
+        word_details = []
+        flat_scores: list[float] = []
+        for domain_index, words in zip(comb.domains, comb.source_words):
+            segment = segments[domain_index] if 0 <= domain_index < len(segments) else None
+            segment_type = segment.type if segment else ""
+            segment_text = segment.text if segment else ""
+            items = []
+            for word in words:
+                score = 0.0
+                if segment and word in segment.word_scores:
+                    score = segment.word_scores[word]
+                items.append({
+                    "text": word,
+                    "score": score
+                })
+                flat_scores.append(score)
+            word_details.append({
+                "domain_index": domain_index,
+                "segment_type": segment_type,
+                "segment_text": segment_text,
+                "words": items
+            })
+        comb.source_word_details = word_details
+        comb.source_scores = flat_scores
+        comb.max_source_score = max(flat_scores) if flat_scores else None
+        comb.is_above_source_scores = bool(flat_scores) and all(
+            comb.score_with_o > score for score in flat_scores
         )
-        q_list_next.append(q)
 
-    print(f"  添加 {len(high_score_combinations)} 个高分组合")
+    # 步骤6: 构建 q_list_next(高增益SUG + 高分组合,按此顺序追加)
+    print(f"\n[步骤6] 生成下轮输入...")
+    q_list_next: list[Q] = []
 
-    # 6.2 添加高分SUG(满足增益条件)
-    high_gain_sugs = []
+    # 6.1 添加高增益SUG(满足增益条件),并按分数排序
+    sug_candidates: list[tuple[Q, Sug]] = []
     for sug in all_sugs:
         if sug.from_q and sug.score_with_o >= sug.from_q.score_with_o + REQUIRED_SCORE_GAIN:
             q = Q(
@@ -2400,10 +2417,32 @@ async def run_round_v2(
                 from_source="sug",
                 type_label=""
             )
-            q_list_next.append(q)
-            high_gain_sugs.append(sug)
+            sug_candidates.append((q, sug))
+
+    sug_candidates.sort(key=lambda item: item[0].score_with_o, reverse=True)
+    q_list_next.extend([item[0] for item in sug_candidates])
+    high_gain_sugs = [item[1] for item in sug_candidates]
+    print(f"  添加 {len(high_gain_sugs)} 个高增益SUG(增益 ≥ {REQUIRED_SCORE_GAIN:.2f})")
+
+    # 6.2 添加高分组合(需超过所有来源词得分),并按分数排序
+    combination_candidates: list[tuple[Q, DomainCombination]] = []
+    for comb in domain_combinations:
+        if comb.is_above_source_scores and comb.score_with_o > 0:
+            domains_str = ','.join([f'D{d}' for d in comb.domains]) if comb.domains else ''
+            q = Q(
+                text=comb.text,
+                score_with_o=comb.score_with_o,
+                reason=comb.reason,
+                from_source="domain_comb",
+                type_label=comb.type_label,
+                domain_type=domains_str  # 添加域信息
+            )
+            combination_candidates.append((q, comb))
 
-    print(f"  添加 {len(high_gain_sugs)} 个高增益SUG(增益 > {REQUIRED_SCORE_GAIN})")
+    combination_candidates.sort(key=lambda item: item[0].score_with_o, reverse=True)
+    q_list_next.extend([item[0] for item in combination_candidates])
+    high_score_combinations = [item[1] for item in combination_candidates]
+    print(f"  添加 {len(high_score_combinations)} 个高分组合(组合得分 > 所有来源词)")
 
     # 保存round数据(包含完整帖子信息)
     search_results_data = []
@@ -2435,18 +2474,50 @@ async def run_round_v2(
                 "reason": comb.reason,
                 "domains": comb.domains,
                 "source_words": comb.source_words,
-                "from_segments": comb.from_segments
+                "from_segments": comb.from_segments,
+                "source_word_details": comb.source_word_details,
+                "source_scores": comb.source_scores,
+                "is_above_source_scores": comb.is_above_source_scores,
+                "max_source_score": comb.max_source_score
             }
             for comb in domain_combinations
         ],
-        "high_score_combinations": [{"text": q.text, "score": q.score_with_o, "type_label": q.type_label, "type": "combination"} for q in q_list_next if q.from_source == "domain_comb"],
+        "high_score_combinations": [
+            {
+                "text": item[0].text,
+                "score": item[0].score_with_o,
+                "type_label": item[0].type_label,
+                "type": "combination",
+                "is_above_source_scores": item[1].is_above_source_scores
+            }
+            for item in combination_candidates
+        ],
         "sug_count": len(all_sugs),
         "sug_details": sug_details,
         "high_score_sug_count": len(high_score_sugs),
         "high_gain_sugs": [{"text": q.text, "score": q.score_with_o, "type": "sug"} for q in q_list_next if q.from_source == "sug"],
         "search_count": len(search_list),
         "search_results": search_results_data,
-        "q_list_next_size": len(q_list_next)
+        "q_list_next_size": len(q_list_next),
+        "q_list_next_sections": {
+            "sugs": [
+                {
+                    "text": item[0].text,
+                    "score": item[0].score_with_o,
+                    "from_source": "sug"
+                }
+                for item in sug_candidates
+            ],
+            "domain_combinations": [
+                {
+                    "text": item[0].text,
+                    "score": item[0].score_with_o,
+                    "from_source": "domain_comb",
+                    "is_above_source_scores": item[1].is_above_source_scores
+                }
+                for item in combination_candidates
+            ]
+        }
     })
     context.rounds.append(round_data)