|
@@ -1151,16 +1151,14 @@ async def process_search_results(
|
|
|
async def iterative_search_loop(
|
|
async def iterative_search_loop(
|
|
|
context: RunContext,
|
|
context: RunContext,
|
|
|
max_iterations: int = 20,
|
|
max_iterations: int = 20,
|
|
|
- max_concurrent_queries: int = 5,
|
|
|
|
|
relevance_threshold: float = 0.6
|
|
relevance_threshold: float = 0.6
|
|
|
) -> list[dict]:
|
|
) -> list[dict]:
|
|
|
"""
|
|
"""
|
|
|
- 主循环:迭代搜索
|
|
|
|
|
|
|
+ 主循环:迭代搜索(按层级处理)
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
context: 运行上下文
|
|
context: 运行上下文
|
|
|
- max_iterations: 最大迭代次数
|
|
|
|
|
- max_concurrent_queries: 最大并发query数量
|
|
|
|
|
|
|
+ max_iterations: 最大迭代次数(层级数)
|
|
|
relevance_threshold: 相关度门槛
|
|
relevance_threshold: 相关度门槛
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
@@ -1236,19 +1234,25 @@ async def iterative_search_loop(
|
|
|
|
|
|
|
|
while query_queue and iteration < max_iterations:
|
|
while query_queue and iteration < max_iterations:
|
|
|
iteration += 1
|
|
iteration += 1
|
|
|
|
|
+
|
|
|
|
|
+ # 获取当前层级(队列中最小的level)
|
|
|
|
|
+ current_level = min(q.level for q in query_queue)
|
|
|
|
|
+
|
|
|
|
|
+ # 提取当前层级的所有query
|
|
|
|
|
+ current_batch = [q for q in query_queue if q.level == current_level]
|
|
|
|
|
+ query_queue = [q for q in query_queue if q.level != current_level]
|
|
|
|
|
+
|
|
|
print(f"\n{'='*60}")
|
|
print(f"\n{'='*60}")
|
|
|
- print(f"迭代 {iteration}: 队列中有 {len(query_queue)} 个query")
|
|
|
|
|
|
|
+ print(f"迭代 {iteration}: 处理第 {current_level} 层,共 {len(current_batch)} 个query")
|
|
|
print(f"{'='*60}")
|
|
print(f"{'='*60}")
|
|
|
|
|
|
|
|
- # 限制并发数量
|
|
|
|
|
- current_batch = query_queue[:max_concurrent_queries]
|
|
|
|
|
- query_queue = query_queue[max_concurrent_queries:]
|
|
|
|
|
-
|
|
|
|
|
# 记录本轮处理的queries
|
|
# 记录本轮处理的queries
|
|
|
add_step(context, f"迭代 {iteration}", "iteration", {
|
|
add_step(context, f"迭代 {iteration}", "iteration", {
|
|
|
"iteration": iteration,
|
|
"iteration": iteration,
|
|
|
- "queue_size": len(query_queue) + len(current_batch),
|
|
|
|
|
- "processing_queries": [q.query for q in current_batch]
|
|
|
|
|
|
|
+ "current_level": current_level,
|
|
|
|
|
+ "current_batch_size": len(current_batch),
|
|
|
|
|
+ "remaining_queue_size": len(query_queue),
|
|
|
|
|
+ "processing_queries": [{"query": q.query, "level": q.level} for q in current_batch]
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
new_queries_from_sug = []
|
|
new_queries_from_sug = []
|
|
@@ -1392,7 +1396,6 @@ async def main(input_dir: str, max_iterations: int = 20, visualize: bool = False
|
|
|
satisfied_notes = await iterative_search_loop(
|
|
satisfied_notes = await iterative_search_loop(
|
|
|
run_context,
|
|
run_context,
|
|
|
max_iterations=max_iterations,
|
|
max_iterations=max_iterations,
|
|
|
- max_concurrent_queries=3,
|
|
|
|
|
relevance_threshold=0.6
|
|
relevance_threshold=0.6
|
|
|
)
|
|
)
|
|
|
|
|
|