match_inspiration_features.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Feature matching script for deconstructed points.

Reads the deconstruction task list, takes the features of every inspiration point
(灵感点), key point (关键点) and purpose point (目的点), and matches each one against
every persona feature (tags plus their categories) using the hybrid similarity
model from lib.hybrid_similarity.compare_phrases.
"""
import json
import asyncio
from pathlib import Path
from typing import Dict, List
import sys
from datetime import datetime

# Add the project root (two directory levels above this script's directory) to
# sys.path so the project-local imports below resolve; must run before them.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from lib.hybrid_similarity import compare_phrases
from script.data_processing.path_config import PathConfig

# Global concurrency limit: maximum number of compare_phrases calls in flight.
MAX_CONCURRENT_REQUESTS = 100
# Shared asyncio.Semaphore, created lazily by get_semaphore().
semaphore = None
  22. # 进度跟踪
  23. class ProgressTracker:
  24. """进度跟踪器"""
  25. def __init__(self, total: int):
  26. self.total = total
  27. self.completed = 0
  28. self.start_time = datetime.now()
  29. self.last_update_time = datetime.now()
  30. self.last_completed = 0
  31. def update(self, count: int = 1):
  32. """更新进度"""
  33. self.completed += count
  34. current_time = datetime.now()
  35. # 每秒最多更新一次,或者达到总数时更新
  36. if (current_time - self.last_update_time).total_seconds() >= 1.0 or self.completed >= self.total:
  37. self.display()
  38. self.last_update_time = current_time
  39. self.last_completed = self.completed
  40. def display(self):
  41. """显示进度"""
  42. if self.total == 0:
  43. return
  44. percentage = (self.completed / self.total) * 100
  45. elapsed = (datetime.now() - self.start_time).total_seconds()
  46. # 计算速度和预估剩余时间
  47. if elapsed > 0:
  48. speed = self.completed / elapsed
  49. if speed > 0:
  50. remaining = (self.total - self.completed) / speed
  51. eta_str = f", 预计剩余: {int(remaining)}秒"
  52. else:
  53. eta_str = ""
  54. else:
  55. eta_str = ""
  56. bar_length = 40
  57. filled_length = int(bar_length * self.completed / self.total)
  58. bar = '█' * filled_length + '░' * (bar_length - filled_length)
  59. print(f"\r 进度: [{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
  60. # 完成时换行
  61. if self.completed >= self.total:
  62. print()
# Global progress tracker. process_task_list() replaces it with a ProgressTracker
# sized to the total number of match pairs; match_single_pair() advances it by one
# after each similarity call when it is set.
progress_tracker = None
  65. def get_semaphore():
  66. """获取全局信号量"""
  67. global semaphore
  68. if semaphore is None:
  69. semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  70. return semaphore
  71. async def match_single_pair(
  72. feature_name: str,
  73. persona_name: str,
  74. persona_feature_level: str,
  75. category_mapping: Dict = None,
  76. model_name: str = None
  77. ) -> Dict:
  78. """
  79. 匹配单个特征对(带并发限制)
  80. Args:
  81. feature_name: 要匹配的特征名称
  82. persona_name: 人设特征名称
  83. persona_feature_level: 人设特征层级(灵感点/关键点/目的点)
  84. category_mapping: 特征分类映射字典
  85. model_name: 使用的模型名称
  86. Returns:
  87. 单个匹配结果,格式:
  88. {
  89. "人设特征名称": "xxx",
  90. "人设特征层级": "灵感点",
  91. "特征类型": "标签",
  92. "特征分类": ["分类1", "分类2"],
  93. "匹配结果": {
  94. "相似度": 0.75,
  95. "说明": "..."
  96. }
  97. }
  98. """
  99. global progress_tracker
  100. sem = get_semaphore()
  101. async with sem:
  102. # 使用混合相似度模型(异步调用)
  103. similarity_result = await compare_phrases(
  104. phrase_a=feature_name,
  105. phrase_b=persona_name,
  106. weight_embedding=0.5,
  107. weight_semantic=0.5
  108. )
  109. # 更新进度
  110. if progress_tracker:
  111. progress_tracker.update(1)
  112. # 判断该特征是标签还是分类
  113. feature_type = "分类" # 默认为分类
  114. categories = []
  115. if category_mapping:
  116. # 先在标签特征中查找(灵感点、关键点、目的点)
  117. is_tag_feature = False
  118. for ft in ["灵感点", "关键点", "目的点"]:
  119. if ft in category_mapping:
  120. type_mapping = category_mapping[ft]
  121. if persona_name in type_mapping:
  122. # 找到了,说明是标签特征
  123. feature_type = "标签"
  124. categories = type_mapping[persona_name].get("所属分类", [])
  125. is_tag_feature = True
  126. break
  127. # 如果不是标签特征,检查是否是分类特征
  128. if not is_tag_feature:
  129. # 收集所有分类
  130. all_categories = set()
  131. for ft in ["灵感点", "关键点", "目的点"]:
  132. if ft in category_mapping:
  133. for fname, fdata in category_mapping[ft].items():
  134. cats = fdata.get("所属分类", [])
  135. all_categories.update(cats)
  136. # 如果当前特征名在分类列表中,则是分类特征
  137. if persona_name in all_categories:
  138. feature_type = "分类"
  139. categories = [] # 分类特征本身没有所属分类
  140. # 去重分类
  141. unique_categories = list(dict.fromkeys(categories))
  142. return {
  143. "人设特征名称": persona_name,
  144. "人设特征层级": persona_feature_level,
  145. "特征类型": feature_type,
  146. "特征分类": unique_categories,
  147. "匹配结果": similarity_result
  148. }
  149. async def match_feature_with_persona(
  150. feature_name: str,
  151. persona_features: List[Dict],
  152. category_mapping: Dict = None,
  153. model_name: str = None
  154. ) -> List[Dict]:
  155. """
  156. 将一个特征与人设特征列表进行匹配(并发执行)
  157. Args:
  158. feature_name: 要匹配的特征名称
  159. persona_features: 人设特征列表(包含"特征名称"和"人设特征层级")
  160. category_mapping: 特征分类映射字典
  161. model_name: 使用的模型名称
  162. Returns:
  163. 匹配结果列表
  164. """
  165. # 创建所有匹配任务
  166. tasks = [
  167. match_single_pair(
  168. feature_name,
  169. persona_feature["特征名称"],
  170. persona_feature["人设特征层级"],
  171. category_mapping,
  172. model_name
  173. )
  174. for persona_feature in persona_features
  175. ]
  176. # 并发执行所有匹配
  177. match_results = await asyncio.gather(*tasks)
  178. return list(match_results)
  179. async def match_single_feature(
  180. feature_item: Dict,
  181. persona_features: List[Dict],
  182. category_mapping: Dict = None,
  183. model_name: str = None
  184. ) -> Dict:
  185. """
  186. 匹配单个特征与所有人设特征
  187. Args:
  188. feature_item: 特征信息(包含"特征名称"和"权重")
  189. persona_features: 人设特征列表
  190. category_mapping: 特征分类映射字典
  191. model_name: 使用的模型名称
  192. Returns:
  193. 特征匹配结果
  194. """
  195. feature_name = feature_item.get("特征名称", "")
  196. feature_weight = feature_item.get("权重", 1.0)
  197. match_results = await match_feature_with_persona(
  198. feature_name=feature_name,
  199. persona_features=persona_features,
  200. category_mapping=category_mapping,
  201. model_name=model_name
  202. )
  203. return {
  204. "特征名称": feature_name,
  205. "权重": feature_weight,
  206. "匹配结果": match_results
  207. }
  208. async def process_single_point(
  209. point: Dict,
  210. point_type: str,
  211. persona_features: List[Dict],
  212. category_mapping: Dict = None,
  213. model_name: str = None
  214. ) -> Dict:
  215. """
  216. 处理单个点(灵感点/关键点/目的点)的特征匹配(并发执行)
  217. Args:
  218. point: 点数据(灵感点/关键点/目的点)
  219. point_type: 点类型("灵感点"/"关键点"/"目的点")
  220. persona_features: 人设特征列表
  221. category_mapping: 特征分类映射字典
  222. model_name: 使用的模型名称
  223. Returns:
  224. 包含 how 步骤列表的点数据
  225. """
  226. point_name = point.get("名称", "")
  227. feature_list = point.get("特征列表", [])
  228. # 并发匹配所有特征
  229. tasks = [
  230. match_single_feature(feature_item, persona_features, category_mapping, model_name)
  231. for feature_item in feature_list
  232. ]
  233. feature_match_results = await asyncio.gather(*tasks)
  234. # 构建 how 步骤(根据点类型生成步骤名称)
  235. step_name_mapping = {
  236. "灵感点": "灵感特征分别匹配人设特征",
  237. "关键点": "关键特征分别匹配人设特征",
  238. "目的点": "目的特征分别匹配人设特征"
  239. }
  240. how_step = {
  241. "步骤名称": step_name_mapping.get(point_type, f"{point_type}特征分别匹配人设特征"),
  242. "特征列表": list(feature_match_results)
  243. }
  244. # 返回更新后的点
  245. result = point.copy()
  246. result["how步骤列表"] = [how_step]
  247. return result
  248. async def process_single_task(
  249. task: Dict,
  250. task_index: int,
  251. total_tasks: int,
  252. all_persona_features: List[Dict],
  253. category_mapping: Dict = None,
  254. model_name: str = None
  255. ) -> Dict:
  256. """
  257. 处理单个任务
  258. Args:
  259. task: 任务数据
  260. task_index: 任务索引(从1开始)
  261. total_tasks: 总任务数
  262. all_persona_features: 所有人设特征列表(包含三种层级)
  263. category_mapping: 特征分类映射字典
  264. model_name: 使用的模型名称
  265. Returns:
  266. 包含 how 解构结果的任务
  267. """
  268. post_id = task.get("帖子id", "")
  269. print(f"\n[{task_index}/{total_tasks}] 处理帖子: {post_id}")
  270. # 获取 what 解构结果
  271. what_result = task.get("what解构结果", {})
  272. # 构建 how 解构结果
  273. how_result = {}
  274. # 处理灵感点、关键点和目的点
  275. for point_type in ["灵感点", "关键点", "目的点"]:
  276. point_list_key = f"{point_type}列表"
  277. point_list = what_result.get(point_list_key, [])
  278. if point_list:
  279. # 并发处理所有点
  280. tasks = [
  281. process_single_point(
  282. point=point,
  283. point_type=point_type,
  284. persona_features=all_persona_features,
  285. category_mapping=category_mapping,
  286. model_name=model_name
  287. )
  288. for point in point_list
  289. ]
  290. updated_point_list = await asyncio.gather(*tasks)
  291. # 添加到 how 解构结果
  292. how_result[point_list_key] = list(updated_point_list)
  293. # 更新任务
  294. updated_task = task.copy()
  295. updated_task["how解构结果"] = how_result
  296. return updated_task
  297. async def process_task_list(
  298. task_list: List[Dict],
  299. persona_features_dict: Dict,
  300. category_mapping: Dict = None,
  301. model_name: str = None
  302. ) -> List[Dict]:
  303. """
  304. 处理整个解构任务列表(并发执行)
  305. Args:
  306. task_list: 解构任务列表
  307. persona_features_dict: 人设特征字典(包含灵感点、目的点、关键点)
  308. category_mapping: 特征分类映射字典
  309. model_name: 使用的模型名称
  310. Returns:
  311. 包含 how 解构结果的任务列表
  312. """
  313. global progress_tracker
  314. # 合并三种人设特征(灵感点、关键点、目的点)
  315. all_features = []
  316. for feature_type in ["灵感点", "关键点", "目的点"]:
  317. # 获取该类型的标签特征
  318. type_features = persona_features_dict.get(feature_type, [])
  319. # 为每个特征添加层级信息
  320. for feature in type_features:
  321. feature_with_level = feature.copy()
  322. feature_with_level["人设特征层级"] = feature_type
  323. all_features.append(feature_with_level)
  324. print(f"人设{feature_type}标签特征数量: {len(type_features)}")
  325. # 从分类映射中提取该类型的分类特征
  326. if category_mapping and feature_type in category_mapping:
  327. type_categories = set()
  328. for _, feature_data in category_mapping[feature_type].items():
  329. categories = feature_data.get("所属分类", [])
  330. type_categories.update(categories)
  331. # 转换为特征格式并添加层级信息
  332. for cat in sorted(type_categories):
  333. all_features.append({
  334. "特征名称": cat,
  335. "人设特征层级": feature_type
  336. })
  337. print(f"人设{feature_type}分类特征数量: {len(type_categories)}")
  338. print(f"总特征数量(三种类型的标签+分类): {len(all_features)}")
  339. # 计算总匹配任务数(灵感点、关键点和目的点)
  340. total_match_count = 0
  341. for task in task_list:
  342. what_result = task.get("what解构结果", {})
  343. for point_type in ["灵感点", "关键点", "目的点"]:
  344. point_list = what_result.get(f"{point_type}列表", [])
  345. for point in point_list:
  346. feature_count = len(point.get("特征列表", []))
  347. total_match_count += feature_count * len(all_features)
  348. print(f"处理灵感点、关键点和目的点特征")
  349. print(f"总匹配任务数: {total_match_count:,}")
  350. print()
  351. # 初始化全局进度跟踪器
  352. progress_tracker = ProgressTracker(total_match_count)
  353. # 并发处理所有任务
  354. tasks = [
  355. process_single_task(
  356. task=task,
  357. task_index=i,
  358. total_tasks=len(task_list),
  359. all_persona_features=all_features,
  360. category_mapping=category_mapping,
  361. model_name=model_name
  362. )
  363. for i, task in enumerate(task_list, 1)
  364. ]
  365. updated_task_list = await asyncio.gather(*tasks)
  366. return list(updated_task_list)
  367. async def main():
  368. """主函数"""
  369. # 使用路径配置
  370. config = PathConfig()
  371. # 确保输出目录存在
  372. config.ensure_dirs()
  373. # 获取路径
  374. task_list_file = config.task_list_file
  375. persona_features_file = config.feature_source_mapping_file
  376. category_mapping_file = config.feature_category_mapping_file
  377. output_dir = config.how_results_dir
  378. print(f"账号: {config.account_name}")
  379. print(f"任务列表文件: {task_list_file}")
  380. print(f"人设特征文件: {persona_features_file}")
  381. print(f"分类映射文件: {category_mapping_file}")
  382. print(f"输出目录: {output_dir}")
  383. print()
  384. print(f"读取解构任务列表: {task_list_file}")
  385. with open(task_list_file, "r", encoding="utf-8") as f:
  386. task_list_data = json.load(f)
  387. print(f"读取人设特征: {persona_features_file}")
  388. with open(persona_features_file, "r", encoding="utf-8") as f:
  389. persona_features_data = json.load(f)
  390. print(f"读取特征分类映射: {category_mapping_file}")
  391. with open(category_mapping_file, "r", encoding="utf-8") as f:
  392. category_mapping = json.load(f)
  393. # 预先加载模型(混合模型会自动处理)
  394. print("\n预加载混合相似度模型...")
  395. await compare_phrases("测试", "测试", weight_embedding=0.5, weight_semantic=0.5)
  396. print("模型预加载完成!\n")
  397. # 获取任务列表
  398. task_list = task_list_data.get("解构任务列表", [])
  399. print(f"总任务数: {len(task_list)}")
  400. # 处理任务列表
  401. updated_task_list = await process_task_list(
  402. task_list=task_list,
  403. persona_features_dict=persona_features_data,
  404. category_mapping=category_mapping,
  405. model_name=None # 使用默认模型
  406. )
  407. # 分文件保存结果
  408. print(f"\n保存结果到: {output_dir}")
  409. for task in updated_task_list:
  410. post_id = task.get("帖子id", "unknown")
  411. output_file = output_dir / f"{post_id}_how.json"
  412. print(f" 保存: {output_file.name}")
  413. with open(output_file, "w", encoding="utf-8") as f:
  414. json.dump(task, f, ensure_ascii=False, indent=4)
  415. print("\n完成!")
  416. # 打印统计信息
  417. total_inspiration_points = 0
  418. total_key_points = 0
  419. total_purpose_points = 0
  420. total_inspiration_features = 0
  421. total_key_features = 0
  422. total_purpose_features = 0
  423. for task in updated_task_list:
  424. how_result = task.get("how解构结果", {})
  425. # 统计灵感点
  426. inspiration_list = how_result.get("灵感点列表", [])
  427. total_inspiration_points += len(inspiration_list)
  428. for point in inspiration_list:
  429. total_inspiration_features += len(point.get("特征列表", []))
  430. # 统计关键点
  431. key_list = how_result.get("关键点列表", [])
  432. total_key_points += len(key_list)
  433. for point in key_list:
  434. total_key_features += len(point.get("特征列表", []))
  435. # 统计目的点
  436. purpose_list = how_result.get("目的点列表", [])
  437. total_purpose_points += len(purpose_list)
  438. for point in purpose_list:
  439. total_purpose_features += len(point.get("特征列表", []))
  440. print(f"\n统计:")
  441. print(f" 处理的帖子数: {len(updated_task_list)}")
  442. print(f" 处理的灵感点数: {total_inspiration_points}")
  443. print(f" 处理的灵感点特征数: {total_inspiration_features}")
  444. print(f" 处理的关键点数: {total_key_points}")
  445. print(f" 处理的关键点特征数: {total_key_features}")
  446. print(f" 处理的目的点数: {total_purpose_points}")
  447. print(f" 处理的目的点特征数: {total_purpose_features}")
if __name__ == "__main__":
    # Run the async pipeline on a fresh event loop when executed as a script.
    asyncio.run(main())