#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File: match_inspiration_features_v4.py
"""
Inspiration-point feature matching script, v4 (unified matching version).

Uses a single prompt to perform tag matching and category matching together,
instead of running them as separate steps.
One LLM call completes the evaluation for all levels.
"""
  8. import json
  9. import asyncio
  10. from pathlib import Path
  11. from typing import Dict, List, Optional
  12. import sys
  13. # 添加项目根目录到路径
  14. project_root = Path(__file__).parent.parent.parent
  15. sys.path.insert(0, str(project_root))
  16. from agents import trace
  17. from agents.tracing.create import custom_span
  18. from lib.my_trace import set_trace
  19. from lib.unified_match_analyzer import unified_match
  20. # 全局并发限制
  21. MAX_CONCURRENT_REQUESTS = 20
  22. semaphore = None
  23. def get_semaphore():
  24. """获取全局信号量"""
  25. global semaphore
  26. if semaphore is None:
  27. semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
  28. return semaphore
  29. def load_feature_categories(categories_file: Path) -> Dict:
  30. """加载特征分类映射"""
  31. with open(categories_file, "r", encoding="utf-8") as f:
  32. return json.load(f)
  33. def enrich_persona_combinations_with_categories(
  34. persona_combinations: List[Dict],
  35. feature_categories: Dict,
  36. point_type: str
  37. ) -> List[Dict]:
  38. """为人设特征组合添加分类信息"""
  39. enriched_combinations = []
  40. type_categories = feature_categories.get(point_type, {})
  41. for combo in persona_combinations:
  42. feature_list = combo.get("特征组合", [])
  43. # 为每个特征添加分类信息
  44. enriched_features = []
  45. for feature_name in feature_list:
  46. categories = type_categories.get(feature_name, {}).get("所属分类", [])
  47. enriched_features.append({
  48. "特征名称": feature_name,
  49. "所属分类": categories
  50. })
  51. enriched_combo = {
  52. "特征组合": enriched_features,
  53. "原始特征组合": feature_list,
  54. "特征来源": combo.get("特征来源", [])
  55. }
  56. enriched_combinations.append(enriched_combo)
  57. return enriched_combinations
  58. async def match_tag_list_with_combination(
  59. current_tag_list: List[str],
  60. persona_combination: Dict,
  61. model_name: Optional[str] = None
  62. ) -> Dict:
  63. """
  64. 使用统一匹配将当前点的标签列表与一个人设历史组合进行匹配
  65. 一次LLM调用完成标签匹配和分类匹配的评估
  66. Returns:
  67. {
  68. "人设标签组合": [...],
  69. "当前标签匹配结果": [
  70. {"当前标签": "立冬", "最终得分": 0.7, "匹配层级": "...", ...},
  71. {"当前标签": "教资查分", "最终得分": 0.6, ...},
  72. ...
  73. ],
  74. "人设标签来源": [...]
  75. }
  76. """
  77. sem = get_semaphore()
  78. async with sem:
  79. # 调用统一匹配模块(返回每个当前标签的匹配结果)
  80. tag_match_results = await unified_match(
  81. current_tags=current_tag_list,
  82. persona_combination=persona_combination["特征组合"],
  83. model_name=model_name
  84. )
  85. # 构建返回结果
  86. result = {
  87. "人设标签组合": persona_combination["原始特征组合"],
  88. "当前标签匹配结果": tag_match_results, # 每个当前标签的匹配结果
  89. "人设标签来源": persona_combination["特征来源"]
  90. }
  91. return result
  92. async def match_inspiration_point_with_combinations(
  93. current_feature_list: List[str],
  94. persona_combinations: List[Dict],
  95. model_name: Optional[str] = None
  96. ) -> List[Dict]:
  97. """将当前点的特征列表与所有人设特征组合进行匹配"""
  98. print(f" 批量匹配: 当前{len(current_feature_list)}个标签 {current_feature_list} vs {len(persona_combinations)}个人设组合")
  99. # 并发匹配所有组合
  100. tasks = [
  101. match_tag_list_with_combination(
  102. current_tag_list=current_feature_list,
  103. persona_combination=combo,
  104. model_name=model_name
  105. )
  106. for combo in persona_combinations
  107. ]
  108. match_results = await asyncio.gather(*tasks)
  109. # 过滤和修复无效结果
  110. valid_results = []
  111. for result in match_results:
  112. # 确保result是dict
  113. if not isinstance(result, dict):
  114. print(f"警告: 跳过无效结果 (不是字典): {type(result)}")
  115. continue
  116. # 确保有当前标签匹配结果字段
  117. tag_results = result.get("当前标签匹配结果")
  118. if tag_results is None:
  119. print(f"警告: 结果缺少当前标签匹配结果字段")
  120. continue
  121. # 确保当前标签匹配结果是list
  122. if not isinstance(tag_results, list):
  123. print(f"警告: 当前标签匹配结果不是列表: {type(tag_results)}")
  124. continue
  125. # 计算该人设组合的加权平均得分
  126. weighted_scores = []
  127. for tag_result in tag_results:
  128. if isinstance(tag_result, dict):
  129. match_result = tag_result.get("匹配结果", {})
  130. match_type = match_result.get("匹配类型")
  131. similarity = match_result.get("语义相似度", 0)
  132. # 根据匹配类型设置权重
  133. if match_type == "标签匹配":
  134. weight = 1.0
  135. elif match_type == "分类匹配":
  136. weight = 0.5
  137. else: # 无匹配
  138. weight = 1.0 # 无匹配也使用1.0权重,因为相似度已经是0
  139. weighted_score = similarity * weight
  140. weighted_scores.append(weighted_score)
  141. avg_score = sum(weighted_scores) / len(weighted_scores) if weighted_scores else 0
  142. result["组合平均得分"] = avg_score
  143. # 添加精简结果字段
  144. result["精简结果"] = {
  145. "人设标签组合": result.get("人设标签组合", []),
  146. "组合平均得分": avg_score,
  147. "各标签得分": [
  148. {
  149. "标签": tag_res.get("当前标签"),
  150. "原始相似度": tag_res.get("匹配结果", {}).get("语义相似度", 0),
  151. "匹配类型": tag_res.get("匹配结果", {}).get("匹配类型"),
  152. "权重": 1.0 if tag_res.get("匹配结果", {}).get("匹配类型") == "标签匹配" else 0.5 if tag_res.get("匹配结果", {}).get("匹配类型") == "分类匹配" else 1.0,
  153. "加权得分": tag_res.get("匹配结果", {}).get("语义相似度", 0) * (1.0 if tag_res.get("匹配结果", {}).get("匹配类型") == "标签匹配" else 0.5 if tag_res.get("匹配结果", {}).get("匹配类型") == "分类匹配" else 1.0),
  154. "匹配到": tag_res.get("匹配结果", {}).get("匹配到")
  155. }
  156. for tag_res in tag_results if isinstance(tag_res, dict)
  157. ]
  158. }
  159. valid_results.append(result)
  160. # 按组合平均得分降序排序
  161. valid_results.sort(
  162. key=lambda x: x.get("组合平均得分", 0),
  163. reverse=True
  164. )
  165. return valid_results
  166. async def process_single_inspiration_point(
  167. inspiration_point: Dict,
  168. persona_combinations: List[Dict],
  169. model_name: Optional[str] = None
  170. ) -> Dict:
  171. """处理单个灵感点的特征组合匹配"""
  172. point_name = inspiration_point.get("名称", "")
  173. feature_list = inspiration_point.get("特征列表", [])
  174. print(f" 处理灵感点: {point_name}")
  175. print(f" 特征列表: {feature_list}")
  176. with custom_span(
  177. name=f"处理灵感点: {point_name}",
  178. data={
  179. "灵感点": point_name,
  180. "特征列表": feature_list,
  181. "人设组合数量": len(persona_combinations)
  182. }
  183. ):
  184. # 将特征列表与所有人设组合进行匹配
  185. match_results = await match_inspiration_point_with_combinations(
  186. current_feature_list=feature_list,
  187. persona_combinations=persona_combinations,
  188. model_name=model_name
  189. )
  190. # 构建完整版 how 步骤
  191. how_step = {
  192. "步骤名称": "灵感特征列表统一匹配人设特征组合 (v4)",
  193. "当前特征列表": feature_list,
  194. "匹配结果": match_results
  195. }
  196. # 构建精简版 how 步骤(只包含精简结果)
  197. how_step_simplified = {
  198. "步骤名称": "灵感特征列表统一匹配人设特征组合 (v4) - 精简版",
  199. "当前特征列表": feature_list,
  200. "匹配结果": [
  201. match.get("精简结果", {})
  202. for match in match_results
  203. ]
  204. }
  205. # 返回更新后的灵感点
  206. result = inspiration_point.copy()
  207. result["how步骤列表"] = [how_step]
  208. result["how步骤列表_精简版"] = [how_step_simplified]
  209. return result
  210. async def process_single_task(
  211. task: Dict,
  212. task_index: int,
  213. total_tasks: int,
  214. persona_combinations: List[Dict],
  215. model_name: Optional[str] = None
  216. ) -> Dict:
  217. """处理单个任务"""
  218. post_id = task.get("帖子id", "")
  219. print(f"\n处理任务 [{task_index}/{total_tasks}]: {post_id}")
  220. what_result = task.get("what解构结果", {})
  221. inspiration_list = what_result.get("灵感点列表", [])
  222. print(f" 灵感点数量: {len(inspiration_list)}")
  223. # 并发处理所有灵感点
  224. tasks = [
  225. process_single_inspiration_point(
  226. inspiration_point=inspiration_point,
  227. persona_combinations=persona_combinations,
  228. model_name=model_name
  229. )
  230. for inspiration_point in inspiration_list
  231. ]
  232. updated_inspiration_list = await asyncio.gather(*tasks)
  233. # 构建 how 解构结果
  234. how_result = {
  235. "灵感点列表": list(updated_inspiration_list)
  236. }
  237. # 更新任务
  238. updated_task = task.copy()
  239. updated_task["how解构结果"] = how_result
  240. return updated_task
  241. async def process_task_list(
  242. task_list: List[Dict],
  243. persona_combinations: List[Dict],
  244. model_name: Optional[str] = None,
  245. current_time: Optional[str] = None,
  246. log_url: Optional[str] = None
  247. ) -> List[Dict]:
  248. """处理整个解构任务列表(并发执行)"""
  249. print(f"人设灵感特征组合数量: {len(persona_combinations)}")
  250. with custom_span(
  251. name="统一匹配 v4 - 所有任务",
  252. data={
  253. "任务总数": len(task_list),
  254. "人设组合数量": len(persona_combinations),
  255. "current_time": current_time,
  256. "log_url": log_url
  257. }
  258. ):
  259. # 并发处理所有任务
  260. tasks = [
  261. process_single_task(
  262. task=task,
  263. task_index=i,
  264. total_tasks=len(task_list),
  265. persona_combinations=persona_combinations,
  266. model_name=model_name
  267. )
  268. for i, task in enumerate(task_list, 1)
  269. ]
  270. updated_task_list = await asyncio.gather(*tasks)
  271. return list(updated_task_list)
  272. async def main(current_time: Optional[str] = None, log_url: Optional[str] = None):
  273. """主函数"""
  274. # 输入输出路径
  275. script_dir = Path(__file__).parent
  276. project_root = script_dir.parent.parent
  277. data_dir = project_root / "data" / "data_1118"
  278. task_list_file = data_dir / "当前帖子_解构任务列表.json"
  279. persona_combinations_file = data_dir / "特征组合_帖子来源.json"
  280. feature_categories_file = data_dir / "特征名称_分类映射.json"
  281. output_dir = data_dir / "当前帖子_how解构结果_v4"
  282. # 创建输出目录
  283. output_dir.mkdir(parents=True, exist_ok=True)
  284. # 获取模型名称
  285. from lib.client import MODEL_NAME
  286. model_name_short = MODEL_NAME.replace("google/", "").replace("/", "_")
  287. print(f"读取解构任务列表: {task_list_file}")
  288. with open(task_list_file, "r", encoding="utf-8") as f:
  289. task_list_data = json.load(f)
  290. print(f"读取人设特征组合: {persona_combinations_file}")
  291. with open(persona_combinations_file, "r", encoding="utf-8") as f:
  292. persona_combinations_data = json.load(f)
  293. print(f"读取特征分类映射: {feature_categories_file}")
  294. feature_categories = load_feature_categories(feature_categories_file)
  295. # 获取任务列表 - 处理所有帖子
  296. task_list = task_list_data.get("解构任务列表", [])
  297. print(f"\n总任务数: {len(task_list)}")
  298. print(f"使用模型: {MODEL_NAME}\n")
  299. # 为人设特征组合添加分类信息(只处理灵感点)- 使用所有组合
  300. persona_inspiration_combinations_raw = persona_combinations_data.get("灵感点", [])
  301. persona_inspiration_combinations = enrich_persona_combinations_with_categories(
  302. persona_combinations=persona_inspiration_combinations_raw,
  303. feature_categories=feature_categories,
  304. point_type="灵感点"
  305. )
  306. print(f"灵感点特征组合数量: {len(persona_inspiration_combinations)}")
  307. print(f"示例组合 (前2个):")
  308. for i, combo in enumerate(persona_inspiration_combinations[:2], 1):
  309. print(f" {i}. 原始组合: {combo['原始特征组合']}")
  310. print(f" 带分类: {combo['特征组合'][:2]}...") # 只显示前2个特征
  311. print()
  312. # 处理任务列表
  313. updated_task_list = await process_task_list(
  314. task_list=task_list,
  315. persona_combinations=persona_inspiration_combinations,
  316. model_name=None,
  317. current_time=current_time,
  318. log_url=log_url
  319. )
  320. # 分文件保存结果
  321. print(f"\n保存结果到: {output_dir}")
  322. for task in updated_task_list:
  323. post_id = task.get("帖子id", "unknown")
  324. output_file = output_dir / f"{post_id}_how_v4_{model_name_short}.json"
  325. # 在每个任务中添加元数据
  326. task["元数据"] = {
  327. "current_time": current_time,
  328. "log_url": log_url,
  329. "version": "v4_unified_match",
  330. "model": MODEL_NAME,
  331. "说明": "v4版本: 使用单个prompt统一完成标签匹配和分类匹配"
  332. }
  333. print(f" 保存: {output_file.name}")
  334. with open(output_file, "w", encoding="utf-8") as f:
  335. json.dump(task, f, ensure_ascii=False, indent=4)
  336. print("\n完成!")
  337. # 打印统计信息
  338. total_inspiration_points = sum(
  339. len(task["how解构结果"]["灵感点列表"])
  340. for task in updated_task_list
  341. )
  342. total_matches = sum(
  343. len(point["how步骤列表"][0]["匹配结果"])
  344. for task in updated_task_list
  345. for point in task["how解构结果"]["灵感点列表"]
  346. )
  347. print(f"\n统计:")
  348. print(f" 处理的帖子数: {len(updated_task_list)}")
  349. print(f" 处理的灵感点数: {total_inspiration_points}")
  350. print(f" 生成的匹配结果数: {total_matches}")
  351. if log_url:
  352. print(f"\nTrace: {log_url}\n")
  353. if __name__ == "__main__":
  354. # 设置 trace
  355. current_time, log_url = set_trace()
  356. # 使用 trace 上下文包裹整个执行流程
  357. with trace("灵感特征统一匹配 v4"):
  358. asyncio.run(main(current_time, log_url))