|
|
@@ -76,6 +76,10 @@ class DomainCombination(BaseModel):
|
|
|
score_with_o: float = 0.0 # 与原始问题的评分
|
|
|
reason: str = "" # 评分理由
|
|
|
from_segments: list[str] = Field(default_factory=list) # 来源segment的文本列表
|
|
|
+ source_word_details: list[dict] = Field(default_factory=list) # 词及其得分信息 [{"domain_index":0,"segment_type":"","words":[{"text":"","score":0.0}]}]
|
|
|
+ source_scores: list[float] = Field(default_factory=list) # 来源词的分数列表(扁平化)
|
|
|
+ max_source_score: float | None = None # 来源词的最高分
|
|
|
+ is_above_source_scores: bool = False # 组合得分是否超过所有来源词
|
|
|
|
|
|
|
|
|
# ============================================================================
|
|
|
@@ -2367,30 +2371,43 @@ async def run_round_v2(
|
|
|
for i, comb in enumerate(domain_combinations, 1):
|
|
|
print(f" {i}. {comb.text} {comb.type_label} (分数: {comb.score_with_o:.2f})")
|
|
|
|
|
|
- # 步骤6: 构建 q_list_next(组合 + 高分SUG)
|
|
|
- print(f"\n[步骤6] 生成下轮输入...")
|
|
|
- q_list_next = []
|
|
|
-
|
|
|
- # 6.1 添加高分组合
|
|
|
- high_score_combinations = [comb for comb in domain_combinations if comb.score_with_o > REQUIRED_SCORE_GAIN]
|
|
|
- for comb in high_score_combinations:
|
|
|
- # 生成域字符串,如 "D0,D3"
|
|
|
- domains_str = ','.join([f'D{d}' for d in comb.domains]) if comb.domains else ''
|
|
|
-
|
|
|
- q = Q(
|
|
|
- text=comb.text,
|
|
|
- score_with_o=comb.score_with_o,
|
|
|
- reason=comb.reason,
|
|
|
- from_source="domain_comb",
|
|
|
- type_label=comb.type_label,
|
|
|
- domain_type=domains_str # 添加域信息
|
|
|
+ # 为每个组合补充来源词分数信息,并判断是否超过所有来源词得分
|
|
|
+ for comb in domain_combinations:
|
|
|
+ word_details = []
|
|
|
+ flat_scores: list[float] = []
|
|
|
+ for domain_index, words in zip(comb.domains, comb.source_words):
|
|
|
+ segment = segments[domain_index] if 0 <= domain_index < len(segments) else None
|
|
|
+ segment_type = segment.type if segment else ""
|
|
|
+ segment_text = segment.text if segment else ""
|
|
|
+ items = []
|
|
|
+ for word in words:
|
|
|
+ score = 0.0
|
|
|
+ if segment and word in segment.word_scores:
|
|
|
+ score = segment.word_scores[word]
|
|
|
+ items.append({
|
|
|
+ "text": word,
|
|
|
+ "score": score
|
|
|
+ })
|
|
|
+ flat_scores.append(score)
|
|
|
+ word_details.append({
|
|
|
+ "domain_index": domain_index,
|
|
|
+ "segment_type": segment_type,
|
|
|
+ "segment_text": segment_text,
|
|
|
+ "words": items
|
|
|
+ })
|
|
|
+ comb.source_word_details = word_details
|
|
|
+ comb.source_scores = flat_scores
|
|
|
+ comb.max_source_score = max(flat_scores) if flat_scores else None
|
|
|
+ comb.is_above_source_scores = bool(flat_scores) and all(
|
|
|
+ comb.score_with_o > score for score in flat_scores
|
|
|
)
|
|
|
- q_list_next.append(q)
|
|
|
|
|
|
- print(f" 添加 {len(high_score_combinations)} 个高分组合")
|
|
|
+ # 步骤6: 构建 q_list_next(组合 + 高分SUG)
|
|
|
+ print(f"\n[步骤6] 生成下轮输入...")
|
|
|
+ q_list_next: list[Q] = []
|
|
|
|
|
|
- # 6.2 添加高分SUG(满足增益条件)
|
|
|
- high_gain_sugs = []
|
|
|
+ # 6.1 添加高增益SUG(满足增益条件),并按分数排序
|
|
|
+ sug_candidates: list[tuple[Q, Sug]] = []
|
|
|
for sug in all_sugs:
|
|
|
if sug.from_q and sug.score_with_o >= sug.from_q.score_with_o + REQUIRED_SCORE_GAIN:
|
|
|
q = Q(
|
|
|
@@ -2400,10 +2417,32 @@ async def run_round_v2(
|
|
|
from_source="sug",
|
|
|
type_label=""
|
|
|
)
|
|
|
- q_list_next.append(q)
|
|
|
- high_gain_sugs.append(sug)
|
|
|
+ sug_candidates.append((q, sug))
|
|
|
+
|
|
|
+ sug_candidates.sort(key=lambda item: item[0].score_with_o, reverse=True)
|
|
|
+ q_list_next.extend([item[0] for item in sug_candidates])
|
|
|
+ high_gain_sugs = [item[1] for item in sug_candidates]
|
|
|
+ print(f" 添加 {len(high_gain_sugs)} 个高增益SUG(增益 ≥ {REQUIRED_SCORE_GAIN:.2f})")
|
|
|
+
|
|
|
+ # 6.2 添加高分组合(需超过所有来源词得分),并按分数排序
|
|
|
+ combination_candidates: list[tuple[Q, DomainCombination]] = []
|
|
|
+ for comb in domain_combinations:
|
|
|
+ if comb.is_above_source_scores and comb.score_with_o > 0:
|
|
|
+ domains_str = ','.join([f'D{d}' for d in comb.domains]) if comb.domains else ''
|
|
|
+ q = Q(
|
|
|
+ text=comb.text,
|
|
|
+ score_with_o=comb.score_with_o,
|
|
|
+ reason=comb.reason,
|
|
|
+ from_source="domain_comb",
|
|
|
+ type_label=comb.type_label,
|
|
|
+ domain_type=domains_str # 添加域信息
|
|
|
+ )
|
|
|
+ combination_candidates.append((q, comb))
|
|
|
|
|
|
- print(f" 添加 {len(high_gain_sugs)} 个高增益SUG(增益 > {REQUIRED_SCORE_GAIN})")
|
|
|
+ combination_candidates.sort(key=lambda item: item[0].score_with_o, reverse=True)
|
|
|
+ q_list_next.extend([item[0] for item in combination_candidates])
|
|
|
+ high_score_combinations = [item[1] for item in combination_candidates]
|
|
|
+ print(f" 添加 {len(high_score_combinations)} 个高分组合(组合得分 > 所有来源词)")
|
|
|
|
|
|
# 保存round数据(包含完整帖子信息)
|
|
|
search_results_data = []
|
|
|
@@ -2435,18 +2474,50 @@ async def run_round_v2(
|
|
|
"reason": comb.reason,
|
|
|
"domains": comb.domains,
|
|
|
"source_words": comb.source_words,
|
|
|
- "from_segments": comb.from_segments
|
|
|
+ "from_segments": comb.from_segments,
|
|
|
+ "source_word_details": comb.source_word_details,
|
|
|
+ "source_scores": comb.source_scores,
|
|
|
+ "is_above_source_scores": comb.is_above_source_scores,
|
|
|
+ "max_source_score": comb.max_source_score
|
|
|
}
|
|
|
for comb in domain_combinations
|
|
|
],
|
|
|
- "high_score_combinations": [{"text": q.text, "score": q.score_with_o, "type_label": q.type_label, "type": "combination"} for q in q_list_next if q.from_source == "domain_comb"],
|
|
|
+ "high_score_combinations": [
|
|
|
+ {
|
|
|
+ "text": item[0].text,
|
|
|
+ "score": item[0].score_with_o,
|
|
|
+ "type_label": item[0].type_label,
|
|
|
+ "type": "combination",
|
|
|
+ "is_above_source_scores": item[1].is_above_source_scores
|
|
|
+ }
|
|
|
+ for item in combination_candidates
|
|
|
+ ],
|
|
|
"sug_count": len(all_sugs),
|
|
|
"sug_details": sug_details,
|
|
|
"high_score_sug_count": len(high_score_sugs),
|
|
|
"high_gain_sugs": [{"text": q.text, "score": q.score_with_o, "type": "sug"} for q in q_list_next if q.from_source == "sug"],
|
|
|
"search_count": len(search_list),
|
|
|
"search_results": search_results_data,
|
|
|
- "q_list_next_size": len(q_list_next)
|
|
|
+ "q_list_next_size": len(q_list_next),
|
|
|
+ "q_list_next_sections": {
|
|
|
+ "sugs": [
|
|
|
+ {
|
|
|
+ "text": item[0].text,
|
|
|
+ "score": item[0].score_with_o,
|
|
|
+ "from_source": "sug"
|
|
|
+ }
|
|
|
+ for item in sug_candidates
|
|
|
+ ],
|
|
|
+ "domain_combinations": [
|
|
|
+ {
|
|
|
+ "text": item[0].text,
|
|
|
+ "score": item[0].score_with_o,
|
|
|
+ "from_source": "domain_comb",
|
|
|
+ "is_above_source_scores": item[1].is_above_source_scores
|
|
|
+ }
|
|
|
+ for item in combination_candidates
|
|
|
+ ]
|
|
|
+ }
|
|
|
})
|
|
|
context.rounds.append(round_data)
|
|
|
|