|
|
@@ -320,7 +320,7 @@ def add_step(context: RunContext, step_name: str, step_type: str, data: dict):
|
|
|
return step
|
|
|
|
|
|
|
|
|
-def add_query_to_graph(context: RunContext, query_state: QueryState, iteration: int, evaluation_reason: str = "", is_selected: bool = True):
|
|
|
+def add_query_to_graph(context: RunContext, query_state: QueryState, iteration: int, evaluation_reason: str = "", is_selected: bool = True, parent_level: int | None = None):
|
|
|
"""添加Query节点到演化图
|
|
|
|
|
|
Args:
|
|
|
@@ -329,8 +329,10 @@ def add_query_to_graph(context: RunContext, query_state: QueryState, iteration:
|
|
|
iteration: 迭代次数
|
|
|
evaluation_reason: 评估原因(可选)
|
|
|
is_selected: 是否被选中进入处理队列(默认True)
|
|
|
+ parent_level: 父节点的层级(用于构造parent_id)
|
|
|
"""
|
|
|
- query_id = query_state.query # 直接使用query作为ID
|
|
|
+ # 使用 "query_level" 格式作为节点ID
|
|
|
+ query_id = f"{query_state.query}_{query_state.level}"
|
|
|
|
|
|
# 初始化图结构
|
|
|
if "nodes" not in context.query_graph:
|
|
|
@@ -354,8 +356,9 @@ def add_query_to_graph(context: RunContext, query_state: QueryState, iteration:
|
|
|
}
|
|
|
|
|
|
# 添加边(父子关系)
|
|
|
- if query_state.parent_query:
|
|
|
- parent_id = query_state.parent_query
|
|
|
+ if query_state.parent_query and parent_level is not None:
|
|
|
+ # 构造父节点ID: parent_query_parent_level
|
|
|
+ parent_id = f"{query_state.parent_query}_{parent_level}"
|
|
|
if parent_id in context.query_graph["nodes"]:
|
|
|
context.query_graph["edges"].append({
|
|
|
"from": parent_id,
|
|
|
@@ -371,8 +374,15 @@ def add_query_to_graph(context: RunContext, query_state: QueryState, iteration:
|
|
|
context.query_graph["iterations"][iteration].append(query_id)
|
|
|
|
|
|
|
|
|
-def add_note_to_graph(context: RunContext, query: str, note: dict):
|
|
|
- """添加Note节点到演化图,并连接到对应的Query"""
|
|
|
+def add_note_to_graph(context: RunContext, query: str, query_level: int, note: dict):
|
|
|
+ """添加Note节点到演化图,并连接到对应的Query
|
|
|
+
|
|
|
+ Args:
|
|
|
+ context: 运行上下文
|
|
|
+ query: query文本
|
|
|
+ query_level: query所在层级
|
|
|
+ note: 帖子数据
|
|
|
+ """
|
|
|
note_id = note["note_id"]
|
|
|
|
|
|
# 初始化图结构
|
|
|
@@ -396,10 +406,11 @@ def add_note_to_graph(context: RunContext, query: str, note: dict):
|
|
|
"found_by_query": query
|
|
|
}
|
|
|
|
|
|
- # 添加边:Query → Note
|
|
|
- if query in context.query_graph["nodes"]:
|
|
|
+ # 添加边:Query → Note,使用 query_level 格式的ID
|
|
|
+ query_id = f"{query}_{query_level}"
|
|
|
+ if query_id in context.query_graph["nodes"]:
|
|
|
context.query_graph["edges"].append({
|
|
|
- "from": query,
|
|
|
+ "from": query_id,
|
|
|
"to": note_id,
|
|
|
"edge_type": "query_to_note",
|
|
|
"match_level": note["evaluation"]["match_level"],
|
|
|
@@ -539,7 +550,7 @@ async def process_suggestions(
|
|
|
new_queries = []
|
|
|
suggestion_evaluations = []
|
|
|
|
|
|
- for sug in suggestions[:5]: # 限制处理数量
|
|
|
+ for sug in suggestions: # 处理所有建议
|
|
|
# 评估sug与原始需求的相关度(注意:这里是与原始需求original_need对比,而非当前query)
|
|
|
# 这样可以确保生成的suggestion始终围绕用户的核心需求
|
|
|
sug_eval = await evaluate_query_relevance(sug, original_need, query_state.relevance_score, context)
|
|
|
@@ -558,7 +569,7 @@ async def process_suggestions(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=sug_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="direct_sug"
|
|
|
+ strategy="调用sug"
|
|
|
)
|
|
|
|
|
|
# 判断是否比当前query更好(只有提升的才加入待处理队列)
|
|
|
@@ -570,7 +581,8 @@ async def process_suggestions(
|
|
|
sug_state,
|
|
|
iteration,
|
|
|
evaluation_reason=sug_eval.reason,
|
|
|
- is_selected=is_selected
|
|
|
+ is_selected=is_selected,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if is_selected:
|
|
|
@@ -621,7 +633,7 @@ async def process_suggestions(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=rewrite_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="rewrite_abstract"
|
|
|
+ strategy="抽象改写"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -630,7 +642,8 @@ async def process_suggestions(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=rewrite_eval.reason,
|
|
|
- is_selected=rewrite_eval.is_improved
|
|
|
+ is_selected=rewrite_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if rewrite_eval.is_improved:
|
|
|
@@ -681,7 +694,7 @@ async def process_suggestions(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=rewrite_syn_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="rewrite_synonym"
|
|
|
+ strategy="同义改写"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -690,7 +703,8 @@ async def process_suggestions(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=rewrite_syn_eval.reason,
|
|
|
- is_selected=rewrite_syn_eval.is_improved
|
|
|
+ is_selected=rewrite_syn_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if rewrite_syn_eval.is_improved:
|
|
|
@@ -741,7 +755,7 @@ async def process_suggestions(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=insertion_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="add_word"
|
|
|
+ strategy="加词"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -750,7 +764,8 @@ async def process_suggestions(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=insertion_eval.reason,
|
|
|
- is_selected=insertion_eval.is_improved
|
|
|
+ is_selected=insertion_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if insertion_eval.is_improved:
|
|
|
@@ -766,7 +781,7 @@ async def process_suggestions(
|
|
|
"query_relevance": query_state.relevance_score,
|
|
|
"suggestions_count": len(suggestions),
|
|
|
"suggestions_evaluated": len(suggestion_evaluations),
|
|
|
- "suggestion_evaluations": suggestion_evaluations[:10], # 只保存前10个
|
|
|
+ "suggestion_evaluations": suggestion_evaluations, # 保存所有评估
|
|
|
"agent_calls": agent_calls, # 所有Agent调用的详细记录
|
|
|
"new_queries_generated": len(new_queries),
|
|
|
"new_queries": [{"query": nq.query, "score": nq.relevance_score, "strategy": nq.strategy} for nq in new_queries],
|
|
|
@@ -826,7 +841,7 @@ async def process_search_results(
|
|
|
satisfied_notes = []
|
|
|
partial_notes = []
|
|
|
|
|
|
- for note in notes[:10]: # 限制评估数量
|
|
|
+ for note in notes: # 评估所有帖子
|
|
|
note_data = process_note_data(note)
|
|
|
title = note_data["title"] or ""
|
|
|
desc = note_data["desc"] or ""
|
|
|
@@ -858,7 +873,7 @@ async def process_search_results(
|
|
|
"input": {
|
|
|
"note_id": note_data.get("note_id"),
|
|
|
"title": title,
|
|
|
- "desc": desc[:200] if len(desc) > 200 else desc # 限制长度
|
|
|
+ "desc": desc # 完整描述
|
|
|
},
|
|
|
"output": {
|
|
|
"match_level": evaluation.match_level,
|
|
|
@@ -877,7 +892,7 @@ async def process_search_results(
|
|
|
}
|
|
|
|
|
|
# 将所有评估过的帖子添加到演化图(包括satisfied、partial、unsatisfied)
|
|
|
- add_note_to_graph(context, query, note_data)
|
|
|
+ add_note_to_graph(context, query, query_state.level, note_data)
|
|
|
|
|
|
if evaluation.match_level == "satisfied":
|
|
|
satisfied_notes.append(note_data)
|
|
|
@@ -928,7 +943,7 @@ async def process_search_results(
|
|
|
</当前Query>
|
|
|
|
|
|
<缺失的方面>
|
|
|
-{', '.join(set(all_missing[:5]))}
|
|
|
+{', '.join(set(all_missing))}
|
|
|
</缺失的方面>
|
|
|
|
|
|
请改造query使其包含这些缺失的内容。
|
|
|
@@ -942,7 +957,7 @@ async def process_search_results(
|
|
|
"action": "基于缺失方面改造Query",
|
|
|
"input": {
|
|
|
"query": query,
|
|
|
- "missing_aspects": list(set(all_missing[:5]))
|
|
|
+ "missing_aspects": list(set(all_missing))
|
|
|
},
|
|
|
"output": {
|
|
|
"improved_query": improvement.improved_query,
|
|
|
@@ -961,7 +976,7 @@ async def process_search_results(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=improved_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="improve_from_partial"
|
|
|
+ strategy="基于部分匹配改进"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -970,7 +985,8 @@ async def process_search_results(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=improved_eval.reason,
|
|
|
- is_selected=improved_eval.is_improved
|
|
|
+ is_selected=improved_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if improved_eval.is_improved:
|
|
|
@@ -1025,7 +1041,7 @@ async def process_search_results(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=rewrite_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="result_rewrite_abstract"
|
|
|
+ strategy="结果分支-抽象改写"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -1034,7 +1050,8 @@ async def process_search_results(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=rewrite_eval.reason,
|
|
|
- is_selected=rewrite_eval.is_improved
|
|
|
+ is_selected=rewrite_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if rewrite_eval.is_improved:
|
|
|
@@ -1085,7 +1102,7 @@ async def process_search_results(
|
|
|
level=query_state.level + 1,
|
|
|
relevance_score=rewrite_syn_eval.relevance_score,
|
|
|
parent_query=query,
|
|
|
- strategy="result_rewrite_synonym"
|
|
|
+ strategy="结果分支-同义改写"
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(无论是否提升)
|
|
|
@@ -1094,7 +1111,8 @@ async def process_search_results(
|
|
|
new_state,
|
|
|
iteration,
|
|
|
evaluation_reason=rewrite_syn_eval.reason,
|
|
|
- is_selected=rewrite_syn_eval.is_improved
|
|
|
+ is_selected=rewrite_syn_eval.is_improved,
|
|
|
+ parent_level=query_state.level # 父节点的层级
|
|
|
)
|
|
|
|
|
|
if rewrite_syn_eval.is_improved:
|
|
|
@@ -1120,7 +1138,7 @@ async def process_search_results(
|
|
|
"score": note["evaluation"]["relevance_score"],
|
|
|
"match_level": note["evaluation"]["match_level"]
|
|
|
}
|
|
|
- for note in satisfied_notes[:10] # 只保存前10个
|
|
|
+ for note in satisfied_notes # 保存所有满足的帖子
|
|
|
],
|
|
|
"agent_calls": agent_calls, # 所有Agent调用的详细记录
|
|
|
"new_queries_generated": len(new_queries),
|
|
|
@@ -1158,7 +1176,7 @@ async def iterative_search_loop(
|
|
|
query=context.q,
|
|
|
level=0,
|
|
|
relevance_score=1.0, # 原始问题本身相关度为1.0
|
|
|
- strategy="root"
|
|
|
+ strategy="根节点"
|
|
|
)
|
|
|
add_query_to_graph(context, root_query_state, 0, evaluation_reason="原始问题,作为搜索的根节点", is_selected=True)
|
|
|
print(f"[根节点] 原始问题: {context.q}")
|
|
|
@@ -1183,30 +1201,30 @@ async def iterative_search_loop(
|
|
|
})
|
|
|
print(f" {word}: {eval_result.relevance_score:.2f}")
|
|
|
|
|
|
- # 按相关度排序,选择top 3
|
|
|
+ # 按相关度排序,使用所有分词
|
|
|
word_scores.sort(key=lambda x: x['score'], reverse=True)
|
|
|
- selected_words = word_scores[:3]
|
|
|
+ selected_words = word_scores # 使用所有分词
|
|
|
|
|
|
- # 将所有分词添加到演化图(包括未被选中的)
|
|
|
+ # 将所有分词添加到演化图(全部被选中)
|
|
|
for item in word_scores:
|
|
|
- is_selected = item in selected_words
|
|
|
+ is_selected = True # 所有分词都被选中
|
|
|
query_state = QueryState(
|
|
|
query=item['word'],
|
|
|
level=1,
|
|
|
relevance_score=item['score'],
|
|
|
- strategy="initial",
|
|
|
+ strategy="初始分词",
|
|
|
parent_query=context.q # 父节点是原始问题
|
|
|
)
|
|
|
|
|
|
# 添加到演化图(会自动创建从parent_query到该query的边)
|
|
|
- add_query_to_graph(context, query_state, 0, evaluation_reason=item['eval'].reason, is_selected=is_selected)
|
|
|
+ add_query_to_graph(context, query_state, 0, evaluation_reason=item['eval'].reason, is_selected=is_selected, parent_level=0) # 父节点是根节点(level 0)
|
|
|
|
|
|
# 只有被选中的才加入队列
|
|
|
if is_selected:
|
|
|
query_queue.append(query_state)
|
|
|
|
|
|
- print(f"\n初始query队列(按相关度选择): {[(q.query, f'{q.relevance_score:.2f}') for q in query_queue]}")
|
|
|
- print(f" (共评估了 {len(word_scores)} 个分词,选择了前 {len(query_queue)} 个)")
|
|
|
+ print(f"\n初始query队列(按相关度排序): {[(q.query, f'{q.relevance_score:.2f}') for q in query_queue]}")
|
|
|
+ print(f" (共评估了 {len(word_scores)} 个分词,全部加入队列)")
|
|
|
|
|
|
# 3. API实例
|
|
|
xiaohongshu_api = XiaohongshuSearchRecommendations()
|
|
|
@@ -1276,9 +1294,8 @@ async def iterative_search_loop(
|
|
|
# 更新队列
|
|
|
all_new_queries = new_queries_from_sug + new_queries_from_result
|
|
|
|
|
|
- # 将新生成的queries添加到演化图
|
|
|
- for new_q in all_new_queries:
|
|
|
- add_query_to_graph(context, new_q, iteration)
|
|
|
+ # 注意:不需要在这里再次添加到演化图,因为在 process_suggestions 和 process_search_results 中已经添加过了
|
|
|
+ # 如果在这里再次调用 add_query_to_graph,会覆盖之前设置的 evaluation_reason 等字段
|
|
|
|
|
|
query_queue.extend(all_new_queries)
|
|
|
|
|
|
@@ -1390,7 +1407,7 @@ async def main(input_dir: str, max_iterations: int = 20, visualize: bool = False
|
|
|
|
|
|
if satisfied_notes:
|
|
|
output += "【满足需求的帖子】\n\n"
|
|
|
- for idx, note in enumerate(satisfied_notes[:10], 1):
|
|
|
+ for idx, note in enumerate(satisfied_notes, 1):
|
|
|
output += f"{idx}. {note['title']}\n"
|
|
|
output += f" 相关度: {note['evaluation']['relevance_score']:.2f}\n"
|
|
|
output += f" URL: {note['note_url']}\n\n"
|