match_inspiration_features.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Point-feature matching script.

Reads the deconstruction task list, takes the features of each point
(inspiration points, key points and purpose points), and matches them against
the persona features. Similarity between feature names and persona feature
names is computed with ``compare_phrases_cartesian`` from
``lib.hybrid_similarity``.

NOTE(review): the original docstring referred to a ``relation_analyzer``
module and only mentioned inspiration points; neither matches the code in
this file any more.
"""
import json
import asyncio
from pathlib import Path
from typing import Dict, List
import sys
from tqdm import tqdm
# Put the directory three ``.parent`` steps above this file on sys.path;
# presumably this is the project root that makes the ``lib`` and ``script``
# imports below resolvable.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from lib.hybrid_similarity import compare_phrases_cartesian
from script.data_processing.path_config import PathConfig
# Module-wide progress bar: assigned by process_single_task and advanced by
# the progress callback defined inside process_single_point.
progress_bar = None
  21. async def process_single_point(
  22. point: Dict,
  23. point_type: str,
  24. persona_features: List[Dict],
  25. category_mapping: Dict = None,
  26. model_name: str = None
  27. ) -> Dict:
  28. """
  29. 处理单个点 - 使用笛卡尔积批量计算(优化版)
  30. Args:
  31. point: 点数据(灵感点/关键点/目的点)
  32. point_type: 点类型("灵感点"/"关键点"/"目的点")
  33. persona_features: 人设特征列表
  34. category_mapping: 特征分类映射字典
  35. model_name: 使用的模型名称
  36. Returns:
  37. 包含 how 步骤列表的点数据
  38. """
  39. global progress_bar
  40. point_name = point.get("名称", "")
  41. feature_list = point.get("特征列表", [])
  42. # 如果没有特征,直接返回
  43. if not feature_list or not persona_features:
  44. result = point.copy()
  45. result["how步骤列表"] = []
  46. return result
  47. # 提取特征名称和人设名称列表
  48. feature_names = [f.get("特征名称", "") for f in feature_list]
  49. persona_names = [pf["特征名称"] for pf in persona_features]
  50. # 定义进度回调函数
  51. def on_llm_progress(count: int):
  52. """LLM完成一个任务时的回调"""
  53. if progress_bar:
  54. progress_bar.update(count)
  55. # 核心优化:使用混合模型笛卡尔积一次计算M×N
  56. # max_concurrent 控制的是底层 LLM 的全局并发数
  57. similarity_results = await compare_phrases_cartesian(
  58. feature_names, # M个特征
  59. persona_names, # N个人设
  60. max_concurrent=100, # LLM最大并发数(全局共享)
  61. progress_callback=on_llm_progress # 传递进度回调
  62. )
  63. # similarity_results[i][j] = {"相似度": float, "说明": str}
  64. # 构建匹配结果(使用模块返回的完整结果)
  65. feature_match_results = []
  66. for i, feature_item in enumerate(feature_list):
  67. feature_name = feature_item.get("特征名称", "")
  68. feature_weight = feature_item.get("权重", 1.0)
  69. # 该特征与所有人设的匹配结果
  70. match_results = []
  71. for j, persona_feature in enumerate(persona_features):
  72. persona_name = persona_feature["特征名称"]
  73. persona_level = persona_feature["人设特征层级"]
  74. # 直接使用模块返回的完整结果
  75. similarity_result = similarity_results[i][j]
  76. # 判断特征类型和分类
  77. feature_type = "分类" # 默认为分类
  78. categories = []
  79. if category_mapping:
  80. # 先在标签特征中查找
  81. is_tag_feature = False
  82. for ft in ["灵感点", "关键点", "目的点"]:
  83. if ft in category_mapping:
  84. type_mapping = category_mapping[ft]
  85. if persona_name in type_mapping:
  86. feature_type = "标签"
  87. categories = type_mapping[persona_name].get("所属分类", [])
  88. is_tag_feature = True
  89. break
  90. # 如果不是标签特征,检查是否是分类特征
  91. if not is_tag_feature:
  92. all_categories = set()
  93. for ft in ["灵感点", "关键点", "目的点"]:
  94. if ft in category_mapping:
  95. for _fname, fdata in category_mapping[ft].items():
  96. cats = fdata.get("所属分类", [])
  97. all_categories.update(cats)
  98. if persona_name in all_categories:
  99. feature_type = "分类"
  100. categories = []
  101. # 去重分类
  102. unique_categories = list(dict.fromkeys(categories))
  103. match_result = {
  104. "人设特征名称": persona_name,
  105. "人设特征层级": persona_level,
  106. "特征类型": feature_type,
  107. "特征分类": unique_categories,
  108. "匹配结果": similarity_result # 直接使用模块返回的结果
  109. }
  110. match_results.append(match_result)
  111. feature_match_results.append({
  112. "特征名称": feature_name,
  113. "权重": feature_weight,
  114. "匹配结果": match_results
  115. })
  116. # 构建 how 步骤(保持不变)
  117. step_name_mapping = {
  118. "灵感点": "灵感特征分别匹配人设特征",
  119. "关键点": "关键特征分别匹配人设特征",
  120. "目的点": "目的特征分别匹配人设特征"
  121. }
  122. how_step = {
  123. "步骤名称": step_name_mapping.get(point_type, f"{point_type}特征分别匹配人设特征"),
  124. "特征列表": list(feature_match_results)
  125. }
  126. result = point.copy()
  127. result["how步骤列表"] = [how_step]
  128. return result
  129. async def process_single_task(
  130. task: Dict,
  131. task_index: int,
  132. total_tasks: int,
  133. all_persona_features: List[Dict],
  134. category_mapping: Dict = None,
  135. model_name: str = None
  136. ) -> Dict:
  137. """
  138. 处理单个任务
  139. Args:
  140. task: 任务数据
  141. task_index: 任务索引(从1开始)
  142. total_tasks: 总任务数
  143. all_persona_features: 所有人设特征列表(包含三种层级)
  144. category_mapping: 特征分类映射字典
  145. model_name: 使用的模型名称
  146. Returns:
  147. 包含 how 解构结果的任务
  148. """
  149. global progress_bar
  150. post_id = task.get("帖子id", "")
  151. # 获取 what 解构结果
  152. what_result = task.get("what解构结果", {})
  153. # 计算当前帖子的总匹配任务数
  154. current_task_match_count = 0
  155. for point_type in ["灵感点", "关键点", "目的点"]:
  156. point_list = what_result.get(f"{point_type}列表", [])
  157. for point in point_list:
  158. feature_count = len(point.get("特征列表", []))
  159. current_task_match_count += feature_count * len(all_persona_features)
  160. # 创建当前帖子的进度条
  161. progress_bar = tqdm(
  162. total=current_task_match_count,
  163. desc=f"[{task_index}/{total_tasks}] {post_id}",
  164. unit="匹配",
  165. ncols=100
  166. )
  167. # 构建 how 解构结果
  168. how_result = {}
  169. # 串行处理灵感点、关键点和目的点
  170. for point_type in ["灵感点", "关键点", "目的点"]:
  171. point_list_key = f"{point_type}列表"
  172. point_list = what_result.get(point_list_key, [])
  173. if point_list:
  174. updated_point_list = []
  175. # 串行处理每个点
  176. for point in point_list:
  177. result = await process_single_point(
  178. point=point,
  179. point_type=point_type,
  180. persona_features=all_persona_features,
  181. category_mapping=category_mapping,
  182. model_name=model_name
  183. )
  184. updated_point_list.append(result)
  185. # 添加到 how 解构结果
  186. how_result[point_list_key] = updated_point_list
  187. # 关闭当前帖子的进度条
  188. if progress_bar:
  189. progress_bar.close()
  190. # 更新任务
  191. updated_task = task.copy()
  192. updated_task["how解构结果"] = how_result
  193. return updated_task
  194. async def process_task_list(
  195. task_list: List[Dict],
  196. persona_features_dict: Dict,
  197. category_mapping: Dict = None,
  198. model_name: str = None,
  199. output_dir: Path = None
  200. ) -> List[Dict]:
  201. """
  202. 处理整个解构任务列表(串行执行,每个帖子处理完立即保存)
  203. Args:
  204. task_list: 解构任务列表
  205. persona_features_dict: 人设特征字典(包含灵感点、目的点、关键点)
  206. category_mapping: 特征分类映射字典
  207. model_name: 使用的模型名称
  208. output_dir: 输出目录(如果提供,每个帖子处理完立即保存)
  209. Returns:
  210. 包含 how 解构结果的任务列表
  211. """
  212. # 合并三种人设特征(灵感点、关键点、目的点)
  213. all_features = []
  214. for feature_type in ["灵感点", "关键点", "目的点"]:
  215. # 获取该类型的标签特征
  216. type_features = persona_features_dict.get(feature_type, [])
  217. # 为每个特征添加层级信息
  218. for feature in type_features:
  219. feature_with_level = feature.copy()
  220. feature_with_level["人设特征层级"] = feature_type
  221. all_features.append(feature_with_level)
  222. print(f"人设{feature_type}标签特征数量: {len(type_features)}")
  223. # 从分类映射中提取该类型的分类特征
  224. if category_mapping and feature_type in category_mapping:
  225. type_categories = set()
  226. for _, feature_data in category_mapping[feature_type].items():
  227. categories = feature_data.get("所属分类", [])
  228. type_categories.update(categories)
  229. # 转换为特征格式并添加层级信息
  230. for cat in sorted(type_categories):
  231. all_features.append({
  232. "特征名称": cat,
  233. "人设特征层级": feature_type
  234. })
  235. print(f"人设{feature_type}分类特征数量: {len(type_categories)}")
  236. print(f"总特征数量(三种类型的标签+分类): {len(all_features)}")
  237. # 计算总匹配任务数(灵感点、关键点和目的点)
  238. total_match_count = 0
  239. for task in task_list:
  240. what_result = task.get("what解构结果", {})
  241. for point_type in ["灵感点", "关键点", "目的点"]:
  242. point_list = what_result.get(f"{point_type}列表", [])
  243. for point in point_list:
  244. feature_count = len(point.get("特征列表", []))
  245. total_match_count += feature_count * len(all_features)
  246. print(f"处理灵感点、关键点和目的点特征")
  247. print(f"总匹配任务数: {total_match_count:,}")
  248. print()
  249. # 串行处理所有任务(一个接一个,每个处理完立即保存)
  250. updated_task_list = []
  251. for i, task in enumerate(task_list, 1):
  252. updated_task = await process_single_task(
  253. task=task,
  254. task_index=i,
  255. total_tasks=len(task_list),
  256. all_persona_features=all_features,
  257. category_mapping=category_mapping,
  258. model_name=model_name
  259. )
  260. updated_task_list.append(updated_task)
  261. # 立即保存当前帖子的结果
  262. if output_dir:
  263. post_id = updated_task.get("帖子id", "unknown")
  264. output_file = output_dir / f"{post_id}_how.json"
  265. with open(output_file, "w", encoding="utf-8") as f:
  266. json.dump(updated_task, f, ensure_ascii=False, indent=4)
  267. print(f" ✓ 已保存: {output_file.name}")
  268. return updated_task_list
  269. async def main():
  270. """主函数"""
  271. # 使用路径配置
  272. config = PathConfig()
  273. # 确保输出目录存在
  274. config.ensure_dirs()
  275. # 获取路径
  276. task_list_file = config.task_list_file
  277. persona_features_file = config.feature_source_mapping_file
  278. category_mapping_file = config.feature_category_mapping_file
  279. output_dir = config.how_results_dir
  280. print(f"账号: {config.account_name}")
  281. print(f"任务列表文件: {task_list_file}")
  282. print(f"人设特征文件: {persona_features_file}")
  283. print(f"分类映射文件: {category_mapping_file}")
  284. print(f"输出目录: {output_dir}")
  285. print()
  286. print(f"读取解构任务列表: {task_list_file}")
  287. with open(task_list_file, "r", encoding="utf-8") as f:
  288. task_list_data = json.load(f)
  289. print(f"读取人设特征: {persona_features_file}")
  290. with open(persona_features_file, "r", encoding="utf-8") as f:
  291. persona_features_data = json.load(f)
  292. print(f"读取特征分类映射: {category_mapping_file}")
  293. with open(category_mapping_file, "r", encoding="utf-8") as f:
  294. category_mapping = json.load(f)
  295. # 获取任务列表
  296. task_list = task_list_data.get("解构任务列表", [])
  297. print(f"总任务数: {len(task_list)}")
  298. # 处理任务列表(每个帖子处理完立即保存)
  299. updated_task_list = await process_task_list(
  300. task_list=task_list,
  301. persona_features_dict=persona_features_data,
  302. category_mapping=category_mapping,
  303. model_name=None, # 使用默认模型
  304. output_dir=output_dir # 传递输出目录,启用即时保存
  305. )
  306. print("\n完成!")
  307. # 打印统计信息
  308. total_inspiration_points = 0
  309. total_key_points = 0
  310. total_purpose_points = 0
  311. total_inspiration_features = 0
  312. total_key_features = 0
  313. total_purpose_features = 0
  314. for task in updated_task_list:
  315. how_result = task.get("how解构结果", {})
  316. # 统计灵感点
  317. inspiration_list = how_result.get("灵感点列表", [])
  318. total_inspiration_points += len(inspiration_list)
  319. for point in inspiration_list:
  320. total_inspiration_features += len(point.get("特征列表", []))
  321. # 统计关键点
  322. key_list = how_result.get("关键点列表", [])
  323. total_key_points += len(key_list)
  324. for point in key_list:
  325. total_key_features += len(point.get("特征列表", []))
  326. # 统计目的点
  327. purpose_list = how_result.get("目的点列表", [])
  328. total_purpose_points += len(purpose_list)
  329. for point in purpose_list:
  330. total_purpose_features += len(point.get("特征列表", []))
  331. print(f"\n统计:")
  332. print(f" 处理的帖子数: {len(updated_task_list)}")
  333. print(f" 处理的灵感点数: {total_inspiration_points}")
  334. print(f" 处理的灵感点特征数: {total_inspiration_features}")
  335. print(f" 处理的关键点数: {total_key_points}")
  336. print(f" 处理的关键点特征数: {total_key_features}")
  337. print(f" 处理的目的点数: {total_purpose_points}")
  338. print(f" 处理的目的点特征数: {total_purpose_features}")
# Script entry point: run the async main() only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())