match_inspiration_features.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
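Point-feature matching script.

Reads the deconstruction task list, takes the features of every inspiration
point, key point and purpose point, and matches them against the persona
features (tag features plus their categories). Pairwise scores come from
compare_phrases_cartesian in lib.hybrid_similarity.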
  7. """
  8. import json
  9. import asyncio
  10. from pathlib import Path
  11. from typing import Dict, List
  12. import sys
  13. from datetime import datetime
  14. # 添加项目根目录到路径
  15. project_root = Path(__file__).parent.parent.parent
  16. sys.path.insert(0, str(project_root))
  17. from lib.hybrid_similarity import compare_phrases_cartesian
  18. from script.data_processing.path_config import PathConfig
  19. # 进度跟踪
  20. class ProgressTracker:
  21. """进度跟踪器"""
  22. def __init__(self, total: int):
  23. self.total = total
  24. self.completed = 0
  25. self.start_time = datetime.now()
  26. self.last_update_time = datetime.now()
  27. self.last_completed = 0
  28. def update(self, count: int = 1):
  29. """更新进度"""
  30. self.completed += count
  31. current_time = datetime.now()
  32. # 每秒最多更新一次,或者达到总数时更新
  33. if (current_time - self.last_update_time).total_seconds() >= 1.0 or self.completed >= self.total:
  34. self.display()
  35. self.last_update_time = current_time
  36. self.last_completed = self.completed
  37. def display(self):
  38. """显示进度"""
  39. if self.total == 0:
  40. return
  41. percentage = (self.completed / self.total) * 100
  42. elapsed = (datetime.now() - self.start_time).total_seconds()
  43. # 计算速度和预估剩余时间
  44. if elapsed > 0:
  45. speed = self.completed / elapsed
  46. if speed > 0:
  47. remaining = (self.total - self.completed) / speed
  48. eta_str = f", 预计剩余: {int(remaining)}秒"
  49. else:
  50. eta_str = ""
  51. else:
  52. eta_str = ""
  53. bar_length = 40
  54. filled_length = int(bar_length * self.completed / self.total)
  55. bar = '█' * filled_length + '░' * (bar_length - filled_length)
  56. print(f"\r 进度: [{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
  57. # 完成时换行
  58. if self.completed >= self.total:
  59. print()
# Module-level progress tracker: process_task_list() assigns a ProgressTracker
# here, and process_single_point() advances it whenever it is set.
progress_tracker = None
  62. async def process_single_point(
  63. point: Dict,
  64. point_type: str,
  65. persona_features: List[Dict],
  66. category_mapping: Dict = None,
  67. model_name: str = None
  68. ) -> Dict:
  69. """
  70. 处理单个点 - 使用笛卡尔积批量计算(优化版)
  71. Args:
  72. point: 点数据(灵感点/关键点/目的点)
  73. point_type: 点类型("灵感点"/"关键点"/"目的点")
  74. persona_features: 人设特征列表
  75. category_mapping: 特征分类映射字典
  76. model_name: 使用的模型名称
  77. Returns:
  78. 包含 how 步骤列表的点数据
  79. """
  80. global progress_tracker
  81. point_name = point.get("名称", "")
  82. feature_list = point.get("特征列表", [])
  83. # 如果没有特征,直接返回
  84. if not feature_list or not persona_features:
  85. result = point.copy()
  86. result["how步骤列表"] = []
  87. return result
  88. # 提取特征名称和人设名称列表
  89. feature_names = [f.get("特征名称", "") for f in feature_list]
  90. persona_names = [pf["特征名称"] for pf in persona_features]
  91. # 核心优化:使用混合模型笛卡尔积一次计算M×N
  92. try:
  93. similarity_results = await compare_phrases_cartesian(
  94. feature_names, # M个特征
  95. persona_names, # N个人设
  96. max_concurrent=100 # LLM最大并发数
  97. )
  98. # similarity_results[i][j] = {"相似度": float, "说明": str}
  99. except Exception as e:
  100. print(f"\n⚠️ 混合模型调用失败: {e}")
  101. result = point.copy()
  102. result["how步骤列表"] = []
  103. return result
  104. # 构建匹配结果(使用模块返回的完整结果)
  105. feature_match_results = []
  106. for i, feature_item in enumerate(feature_list):
  107. feature_name = feature_item.get("特征名称", "")
  108. feature_weight = feature_item.get("权重", 1.0)
  109. # 该特征与所有人设的匹配结果
  110. match_results = []
  111. for j, persona_feature in enumerate(persona_features):
  112. persona_name = persona_feature["特征名称"]
  113. persona_level = persona_feature["人设特征层级"]
  114. # 直接使用模块返回的完整结果
  115. similarity_result = similarity_results[i][j]
  116. # 判断特征类型和分类
  117. feature_type = "分类" # 默认为分类
  118. categories = []
  119. if category_mapping:
  120. # 先在标签特征中查找
  121. is_tag_feature = False
  122. for ft in ["灵感点", "关键点", "目的点"]:
  123. if ft in category_mapping:
  124. type_mapping = category_mapping[ft]
  125. if persona_name in type_mapping:
  126. feature_type = "标签"
  127. categories = type_mapping[persona_name].get("所属分类", [])
  128. is_tag_feature = True
  129. break
  130. # 如果不是标签特征,检查是否是分类特征
  131. if not is_tag_feature:
  132. all_categories = set()
  133. for ft in ["灵感点", "关键点", "目的点"]:
  134. if ft in category_mapping:
  135. for fname, fdata in category_mapping[ft].items():
  136. cats = fdata.get("所属分类", [])
  137. all_categories.update(cats)
  138. if persona_name in all_categories:
  139. feature_type = "分类"
  140. categories = []
  141. # 去重分类
  142. unique_categories = list(dict.fromkeys(categories))
  143. match_result = {
  144. "人设特征名称": persona_name,
  145. "人设特征层级": persona_level,
  146. "特征类型": feature_type,
  147. "特征分类": unique_categories,
  148. "匹配结果": similarity_result # 直接使用模块返回的结果
  149. }
  150. match_results.append(match_result)
  151. # 更新进度
  152. if progress_tracker:
  153. progress_tracker.update(1)
  154. feature_match_results.append({
  155. "特征名称": feature_name,
  156. "权重": feature_weight,
  157. "匹配结果": match_results
  158. })
  159. # 构建 how 步骤(保持不变)
  160. step_name_mapping = {
  161. "灵感点": "灵感特征分别匹配人设特征",
  162. "关键点": "关键特征分别匹配人设特征",
  163. "目的点": "目的特征分别匹配人设特征"
  164. }
  165. how_step = {
  166. "步骤名称": step_name_mapping.get(point_type, f"{point_type}特征分别匹配人设特征"),
  167. "特征列表": list(feature_match_results)
  168. }
  169. result = point.copy()
  170. result["how步骤列表"] = [how_step]
  171. return result
  172. async def process_single_task(
  173. task: Dict,
  174. task_index: int,
  175. total_tasks: int,
  176. all_persona_features: List[Dict],
  177. category_mapping: Dict = None,
  178. model_name: str = None
  179. ) -> Dict:
  180. """
  181. 处理单个任务
  182. Args:
  183. task: 任务数据
  184. task_index: 任务索引(从1开始)
  185. total_tasks: 总任务数
  186. all_persona_features: 所有人设特征列表(包含三种层级)
  187. category_mapping: 特征分类映射字典
  188. model_name: 使用的模型名称
  189. Returns:
  190. 包含 how 解构结果的任务
  191. """
  192. post_id = task.get("帖子id", "")
  193. print(f"\n[{task_index}/{total_tasks}] 处理帖子: {post_id}")
  194. # 获取 what 解构结果
  195. what_result = task.get("what解构结果", {})
  196. # 构建 how 解构结果
  197. how_result = {}
  198. # 处理灵感点、关键点和目的点
  199. for point_type in ["灵感点", "关键点", "目的点"]:
  200. point_list_key = f"{point_type}列表"
  201. point_list = what_result.get(point_list_key, [])
  202. if point_list:
  203. # 并发处理所有点
  204. tasks = [
  205. process_single_point(
  206. point=point,
  207. point_type=point_type,
  208. persona_features=all_persona_features,
  209. category_mapping=category_mapping,
  210. model_name=model_name
  211. )
  212. for point in point_list
  213. ]
  214. updated_point_list = await asyncio.gather(*tasks)
  215. # 添加到 how 解构结果
  216. how_result[point_list_key] = list(updated_point_list)
  217. # 更新任务
  218. updated_task = task.copy()
  219. updated_task["how解构结果"] = how_result
  220. return updated_task
  221. async def process_task_list(
  222. task_list: List[Dict],
  223. persona_features_dict: Dict,
  224. category_mapping: Dict = None,
  225. model_name: str = None
  226. ) -> List[Dict]:
  227. """
  228. 处理整个解构任务列表(并发执行)
  229. Args:
  230. task_list: 解构任务列表
  231. persona_features_dict: 人设特征字典(包含灵感点、目的点、关键点)
  232. category_mapping: 特征分类映射字典
  233. model_name: 使用的模型名称
  234. Returns:
  235. 包含 how 解构结果的任务列表
  236. """
  237. global progress_tracker
  238. # 合并三种人设特征(灵感点、关键点、目的点)
  239. all_features = []
  240. for feature_type in ["灵感点", "关键点", "目的点"]:
  241. # 获取该类型的标签特征
  242. type_features = persona_features_dict.get(feature_type, [])
  243. # 为每个特征添加层级信息
  244. for feature in type_features:
  245. feature_with_level = feature.copy()
  246. feature_with_level["人设特征层级"] = feature_type
  247. all_features.append(feature_with_level)
  248. print(f"人设{feature_type}标签特征数量: {len(type_features)}")
  249. # 从分类映射中提取该类型的分类特征
  250. if category_mapping and feature_type in category_mapping:
  251. type_categories = set()
  252. for _, feature_data in category_mapping[feature_type].items():
  253. categories = feature_data.get("所属分类", [])
  254. type_categories.update(categories)
  255. # 转换为特征格式并添加层级信息
  256. for cat in sorted(type_categories):
  257. all_features.append({
  258. "特征名称": cat,
  259. "人设特征层级": feature_type
  260. })
  261. print(f"人设{feature_type}分类特征数量: {len(type_categories)}")
  262. print(f"总特征数量(三种类型的标签+分类): {len(all_features)}")
  263. # 计算总匹配任务数(灵感点、关键点和目的点)
  264. total_match_count = 0
  265. for task in task_list:
  266. what_result = task.get("what解构结果", {})
  267. for point_type in ["灵感点", "关键点", "目的点"]:
  268. point_list = what_result.get(f"{point_type}列表", [])
  269. for point in point_list:
  270. feature_count = len(point.get("特征列表", []))
  271. total_match_count += feature_count * len(all_features)
  272. print(f"处理灵感点、关键点和目的点特征")
  273. print(f"总匹配任务数: {total_match_count:,}")
  274. print()
  275. # 初始化全局进度跟踪器
  276. progress_tracker = ProgressTracker(total_match_count)
  277. # 并发处理所有任务
  278. tasks = [
  279. process_single_task(
  280. task=task,
  281. task_index=i,
  282. total_tasks=len(task_list),
  283. all_persona_features=all_features,
  284. category_mapping=category_mapping,
  285. model_name=model_name
  286. )
  287. for i, task in enumerate(task_list, 1)
  288. ]
  289. updated_task_list = await asyncio.gather(*tasks)
  290. return list(updated_task_list)
  291. async def main():
  292. """主函数"""
  293. # 使用路径配置
  294. config = PathConfig()
  295. # 确保输出目录存在
  296. config.ensure_dirs()
  297. # 获取路径
  298. task_list_file = config.task_list_file
  299. persona_features_file = config.feature_source_mapping_file
  300. category_mapping_file = config.feature_category_mapping_file
  301. output_dir = config.how_results_dir
  302. print(f"账号: {config.account_name}")
  303. print(f"任务列表文件: {task_list_file}")
  304. print(f"人设特征文件: {persona_features_file}")
  305. print(f"分类映射文件: {category_mapping_file}")
  306. print(f"输出目录: {output_dir}")
  307. print()
  308. print(f"读取解构任务列表: {task_list_file}")
  309. with open(task_list_file, "r", encoding="utf-8") as f:
  310. task_list_data = json.load(f)
  311. print(f"读取人设特征: {persona_features_file}")
  312. with open(persona_features_file, "r", encoding="utf-8") as f:
  313. persona_features_data = json.load(f)
  314. print(f"读取特征分类映射: {category_mapping_file}")
  315. with open(category_mapping_file, "r", encoding="utf-8") as f:
  316. category_mapping = json.load(f)
  317. # 获取任务列表
  318. task_list = task_list_data.get("解构任务列表", [])
  319. print(f"总任务数: {len(task_list)}")
  320. # 处理任务列表
  321. updated_task_list = await process_task_list(
  322. task_list=task_list,
  323. persona_features_dict=persona_features_data,
  324. category_mapping=category_mapping,
  325. model_name=None # 使用默认模型
  326. )
  327. # 分文件保存结果
  328. print(f"\n保存结果到: {output_dir}")
  329. for task in updated_task_list:
  330. post_id = task.get("帖子id", "unknown")
  331. output_file = output_dir / f"{post_id}_how.json"
  332. print(f" 保存: {output_file.name}")
  333. with open(output_file, "w", encoding="utf-8") as f:
  334. json.dump(task, f, ensure_ascii=False, indent=4)
  335. print("\n完成!")
  336. # 打印统计信息
  337. total_inspiration_points = 0
  338. total_key_points = 0
  339. total_purpose_points = 0
  340. total_inspiration_features = 0
  341. total_key_features = 0
  342. total_purpose_features = 0
  343. for task in updated_task_list:
  344. how_result = task.get("how解构结果", {})
  345. # 统计灵感点
  346. inspiration_list = how_result.get("灵感点列表", [])
  347. total_inspiration_points += len(inspiration_list)
  348. for point in inspiration_list:
  349. total_inspiration_features += len(point.get("特征列表", []))
  350. # 统计关键点
  351. key_list = how_result.get("关键点列表", [])
  352. total_key_points += len(key_list)
  353. for point in key_list:
  354. total_key_features += len(point.get("特征列表", []))
  355. # 统计目的点
  356. purpose_list = how_result.get("目的点列表", [])
  357. total_purpose_points += len(purpose_list)
  358. for point in purpose_list:
  359. total_purpose_features += len(point.get("特征列表", []))
  360. print(f"\n统计:")
  361. print(f" 处理的帖子数: {len(updated_task_list)}")
  362. print(f" 处理的灵感点数: {total_inspiration_points}")
  363. print(f" 处理的灵感点特征数: {total_inspiration_features}")
  364. print(f" 处理的关键点数: {total_key_points}")
  365. print(f" 处理的关键点特征数: {total_key_features}")
  366. print(f" 处理的目的点数: {total_purpose_points}")
  367. print(f" 处理的目的点特征数: {total_purpose_features}")
# Run the async entry point only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())