|
|
@@ -12,7 +12,7 @@ import asyncio
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, List
|
|
|
import sys
|
|
|
-from datetime import datetime
|
|
|
+from tqdm import tqdm
|
|
|
|
|
|
# 添加项目根目录到路径
|
|
|
project_root = Path(__file__).parent.parent.parent
|
|
|
@@ -21,58 +21,8 @@ sys.path.insert(0, str(project_root))
|
|
|
from lib.hybrid_similarity import compare_phrases_cartesian
|
|
|
from script.data_processing.path_config import PathConfig
|
|
|
|
|
|
-# 进度跟踪
|
|
|
-class ProgressTracker:
|
|
|
- """进度跟踪器"""
|
|
|
- def __init__(self, total: int):
|
|
|
- self.total = total
|
|
|
- self.completed = 0
|
|
|
- self.start_time = datetime.now()
|
|
|
- self.last_update_time = datetime.now()
|
|
|
- self.last_completed = 0
|
|
|
-
|
|
|
- def update(self, count: int = 1):
|
|
|
- """更新进度"""
|
|
|
- self.completed += count
|
|
|
- current_time = datetime.now()
|
|
|
-
|
|
|
- # 每秒最多更新一次,或者达到总数时更新
|
|
|
- if (current_time - self.last_update_time).total_seconds() >= 1.0 or self.completed >= self.total:
|
|
|
- self.display()
|
|
|
- self.last_update_time = current_time
|
|
|
- self.last_completed = self.completed
|
|
|
-
|
|
|
- def display(self):
|
|
|
- """显示进度"""
|
|
|
- if self.total == 0:
|
|
|
- return
|
|
|
-
|
|
|
- percentage = (self.completed / self.total) * 100
|
|
|
- elapsed = (datetime.now() - self.start_time).total_seconds()
|
|
|
-
|
|
|
- # 计算速度和预估剩余时间
|
|
|
- if elapsed > 0:
|
|
|
- speed = self.completed / elapsed
|
|
|
- if speed > 0:
|
|
|
- remaining = (self.total - self.completed) / speed
|
|
|
- eta_str = f", 预计剩余: {int(remaining)}秒"
|
|
|
- else:
|
|
|
- eta_str = ""
|
|
|
- else:
|
|
|
- eta_str = ""
|
|
|
-
|
|
|
- bar_length = 40
|
|
|
- filled_length = int(bar_length * self.completed / self.total)
|
|
|
- bar = '█' * filled_length + '░' * (bar_length - filled_length)
|
|
|
-
|
|
|
- print(f"\r 进度: [{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
|
|
|
-
|
|
|
- # 完成时换行
|
|
|
- if self.completed >= self.total:
|
|
|
- print()
|
|
|
-
|
|
|
-# 全局进度跟踪器
|
|
|
-progress_tracker = None
|
|
|
+# 全局进度条
|
|
|
+progress_bar = None
|
|
|
|
|
|
|
|
|
async def process_single_point(
|
|
|
@@ -95,7 +45,7 @@ async def process_single_point(
|
|
|
Returns:
|
|
|
包含 how 步骤列表的点数据
|
|
|
"""
|
|
|
- global progress_tracker
|
|
|
+ global progress_bar
|
|
|
|
|
|
point_name = point.get("名称", "")
|
|
|
feature_list = point.get("特征列表", [])
|
|
|
@@ -110,19 +60,22 @@ async def process_single_point(
|
|
|
feature_names = [f.get("特征名称", "") for f in feature_list]
|
|
|
persona_names = [pf["特征名称"] for pf in persona_features]
|
|
|
|
|
|
+ # 定义进度回调函数
|
|
|
+ def on_llm_progress(count: int):
|
|
|
+ """LLM完成一个任务时的回调"""
|
|
|
+ if progress_bar:
|
|
|
+ progress_bar.update(count)
|
|
|
+
|
|
|
# 核心优化:使用混合模型笛卡尔积一次计算M×N
|
|
|
- try:
|
|
|
- similarity_results = await compare_phrases_cartesian(
|
|
|
- feature_names, # M个特征
|
|
|
- persona_names, # N个人设
|
|
|
- max_concurrent=100 # LLM最大并发数
|
|
|
- )
|
|
|
- # similarity_results[i][j] = {"相似度": float, "说明": str}
|
|
|
- except Exception as e:
|
|
|
- print(f"\n⚠️ 混合模型调用失败: {e}")
|
|
|
- result = point.copy()
|
|
|
- result["how步骤列表"] = []
|
|
|
- return result
|
|
|
+ # max_concurrent 控制的是底层 LLM 的全局并发数
|
|
|
+ similarity_results = await compare_phrases_cartesian(
|
|
|
+ feature_names, # M个特征
|
|
|
+ persona_names, # N个人设
|
|
|
+ max_concurrent=100, # LLM最大并发数(全局共享)
|
|
|
+ progress_callback=on_llm_progress # 传递进度回调
|
|
|
+ )
|
|
|
+ # similarity_results[i][j] = {"相似度": float, "说明": str}
|
|
|
+
|
|
|
|
|
|
# 构建匹配结果(使用模块返回的完整结果)
|
|
|
feature_match_results = []
|
|
|
@@ -161,7 +114,7 @@ async def process_single_point(
|
|
|
all_categories = set()
|
|
|
for ft in ["灵感点", "关键点", "目的点"]:
|
|
|
if ft in category_mapping:
|
|
|
- for fname, fdata in category_mapping[ft].items():
|
|
|
+ for _fname, fdata in category_mapping[ft].items():
|
|
|
cats = fdata.get("所属分类", [])
|
|
|
all_categories.update(cats)
|
|
|
|
|
|
@@ -181,10 +134,6 @@ async def process_single_point(
|
|
|
}
|
|
|
match_results.append(match_result)
|
|
|
|
|
|
- # 更新进度
|
|
|
- if progress_tracker:
|
|
|
- progress_tracker.update(1)
|
|
|
-
|
|
|
feature_match_results.append({
|
|
|
"特征名称": feature_name,
|
|
|
"权重": feature_weight,
|
|
|
@@ -231,36 +180,56 @@ async def process_single_task(
|
|
|
Returns:
|
|
|
包含 how 解构结果的任务
|
|
|
"""
|
|
|
+ global progress_bar
|
|
|
+
|
|
|
post_id = task.get("帖子id", "")
|
|
|
- print(f"\n[{task_index}/{total_tasks}] 处理帖子: {post_id}")
|
|
|
|
|
|
# 获取 what 解构结果
|
|
|
what_result = task.get("what解构结果", {})
|
|
|
|
|
|
+ # 计算当前帖子的总匹配任务数
|
|
|
+ current_task_match_count = 0
|
|
|
+ for point_type in ["灵感点", "关键点", "目的点"]:
|
|
|
+ point_list = what_result.get(f"{point_type}列表", [])
|
|
|
+ for point in point_list:
|
|
|
+ feature_count = len(point.get("特征列表", []))
|
|
|
+ current_task_match_count += feature_count * len(all_persona_features)
|
|
|
+
|
|
|
+ # 创建当前帖子的进度条
|
|
|
+ progress_bar = tqdm(
|
|
|
+ total=current_task_match_count,
|
|
|
+ desc=f"[{task_index}/{total_tasks}] {post_id}",
|
|
|
+ unit="匹配",
|
|
|
+ ncols=100
|
|
|
+ )
|
|
|
+
|
|
|
# 构建 how 解构结果
|
|
|
how_result = {}
|
|
|
|
|
|
- # 处理灵感点、关键点和目的点
|
|
|
+ # 串行处理灵感点、关键点和目的点
|
|
|
for point_type in ["灵感点", "关键点", "目的点"]:
|
|
|
point_list_key = f"{point_type}列表"
|
|
|
point_list = what_result.get(point_list_key, [])
|
|
|
|
|
|
if point_list:
|
|
|
- # 并发处理所有点
|
|
|
- tasks = [
|
|
|
- process_single_point(
|
|
|
+ updated_point_list = []
|
|
|
+ # 串行处理每个点
|
|
|
+ for point in point_list:
|
|
|
+ result = await process_single_point(
|
|
|
point=point,
|
|
|
point_type=point_type,
|
|
|
persona_features=all_persona_features,
|
|
|
category_mapping=category_mapping,
|
|
|
model_name=model_name
|
|
|
)
|
|
|
- for point in point_list
|
|
|
- ]
|
|
|
- updated_point_list = await asyncio.gather(*tasks)
|
|
|
+ updated_point_list.append(result)
|
|
|
|
|
|
# 添加到 how 解构结果
|
|
|
- how_result[point_list_key] = list(updated_point_list)
|
|
|
+ how_result[point_list_key] = updated_point_list
|
|
|
+
|
|
|
+ # 关闭当前帖子的进度条
|
|
|
+ if progress_bar:
|
|
|
+ progress_bar.close()
|
|
|
|
|
|
# 更新任务
|
|
|
updated_task = task.copy()
|
|
|
@@ -273,22 +242,22 @@ async def process_task_list(
|
|
|
task_list: List[Dict],
|
|
|
persona_features_dict: Dict,
|
|
|
category_mapping: Dict = None,
|
|
|
- model_name: str = None
|
|
|
+ model_name: str = None,
|
|
|
+ output_dir: Path = None
|
|
|
) -> List[Dict]:
|
|
|
"""
|
|
|
- 处理整个解构任务列表(并发执行)
|
|
|
+ 处理整个解构任务列表(串行执行,每个帖子处理完立即保存)
|
|
|
|
|
|
Args:
|
|
|
task_list: 解构任务列表
|
|
|
persona_features_dict: 人设特征字典(包含灵感点、目的点、关键点)
|
|
|
category_mapping: 特征分类映射字典
|
|
|
model_name: 使用的模型名称
|
|
|
+ output_dir: 输出目录(如果提供,每个帖子处理完立即保存)
|
|
|
|
|
|
Returns:
|
|
|
包含 how 解构结果的任务列表
|
|
|
"""
|
|
|
- global progress_tracker
|
|
|
-
|
|
|
# 合并三种人设特征(灵感点、关键点、目的点)
|
|
|
all_features = []
|
|
|
|
|
|
@@ -336,12 +305,10 @@ async def process_task_list(
|
|
|
print(f"总匹配任务数: {total_match_count:,}")
|
|
|
print()
|
|
|
|
|
|
- # 初始化全局进度跟踪器
|
|
|
- progress_tracker = ProgressTracker(total_match_count)
|
|
|
-
|
|
|
- # 并发处理所有任务
|
|
|
- tasks = [
|
|
|
- process_single_task(
|
|
|
+    # Process all tasks serially, one after another; each result is saved right after processing only when output_dir is provided
|
|
|
+ updated_task_list = []
|
|
|
+ for i, task in enumerate(task_list, 1):
|
|
|
+ updated_task = await process_single_task(
|
|
|
task=task,
|
|
|
task_index=i,
|
|
|
total_tasks=len(task_list),
|
|
|
@@ -349,11 +316,19 @@ async def process_task_list(
|
|
|
category_mapping=category_mapping,
|
|
|
model_name=model_name
|
|
|
)
|
|
|
- for i, task in enumerate(task_list, 1)
|
|
|
- ]
|
|
|
- updated_task_list = await asyncio.gather(*tasks)
|
|
|
+ updated_task_list.append(updated_task)
|
|
|
+
|
|
|
+ # 立即保存当前帖子的结果
|
|
|
+ if output_dir:
|
|
|
+ post_id = updated_task.get("帖子id", "unknown")
|
|
|
+ output_file = output_dir / f"{post_id}_how.json"
|
|
|
|
|
|
- return list(updated_task_list)
|
|
|
+ with open(output_file, "w", encoding="utf-8") as f:
|
|
|
+ json.dump(updated_task, f, ensure_ascii=False, indent=4)
|
|
|
+
|
|
|
+ print(f" ✓ 已保存: {output_file.name}")
|
|
|
+
|
|
|
+ return updated_task_list
|
|
|
|
|
|
|
|
|
async def main():
|
|
|
@@ -393,24 +368,15 @@ async def main():
|
|
|
task_list = task_list_data.get("解构任务列表", [])
|
|
|
print(f"总任务数: {len(task_list)}")
|
|
|
|
|
|
- # 处理任务列表
|
|
|
+ # 处理任务列表(每个帖子处理完立即保存)
|
|
|
updated_task_list = await process_task_list(
|
|
|
task_list=task_list,
|
|
|
persona_features_dict=persona_features_data,
|
|
|
category_mapping=category_mapping,
|
|
|
- model_name=None # 使用默认模型
|
|
|
+ model_name=None, # 使用默认模型
|
|
|
+ output_dir=output_dir # 传递输出目录,启用即时保存
|
|
|
)
|
|
|
|
|
|
- # 分文件保存结果
|
|
|
- print(f"\n保存结果到: {output_dir}")
|
|
|
- for task in updated_task_list:
|
|
|
- post_id = task.get("帖子id", "unknown")
|
|
|
- output_file = output_dir / f"{post_id}_how.json"
|
|
|
-
|
|
|
- print(f" 保存: {output_file.name}")
|
|
|
- with open(output_file, "w", encoding="utf-8") as f:
|
|
|
- json.dump(task, f, ensure_ascii=False, indent=4)
|
|
|
-
|
|
|
print("\n完成!")
|
|
|
|
|
|
# 打印统计信息
|