# step1_inspiration_match.py
"""
Inspiration-to-persona matching analysis (Agent framework version).

Implemented on top of the Agent framework from how_decode_v1.py; the business
logic follows step1_match_inspiration_to_persona_v11.py.
"""
import asyncio
import json
import os
import sys
from typing import Dict, List, Optional

from agents import trace
from agents.tracing.create import custom_span

from lib.async_utils import process_tasks_with_semaphore
from lib.data_loader import load_persona_data, load_inspiration_list, select_inspiration
from lib.match_analyzer import match_with_definition
from lib.my_trace import set_trace_smith as set_trace
  17. # 模型配置
  18. MODEL_NAME = 'x-ai/grok-code-fast-1'
  19. MODEL_NAME = 'anthropic/claude-sonnet-4.5'
  20. MODEL_NAME = 'google/gemini-2.5-flash'
  21. MODEL_NAME = 'openai/gpt-5'
  22. MODEL_NAME = 'deepseek/deepseek-chat-v3-0324'
  23. MODEL_NAME = 'openai/gpt-4.1'
  24. MODEL_NAME = "google/gemini-2.5-pro"
  25. def build_context_str(perspective_name: str, level1_name: str = None) -> str:
  26. """构建上下文字符串
  27. Args:
  28. perspective_name: 视角名称
  29. level1_name: 一级分类名称(仅在匹配二级分类时提供)
  30. Returns:
  31. 上下文字符串
  32. """
  33. if level1_name:
  34. # 匹配二级分类:包含视角和一级分类
  35. return f"""所属视角: {perspective_name}
  36. 一级分类: {level1_name}"""
  37. else:
  38. # 匹配一级分类:只包含视角
  39. return f"""所属视角: {perspective_name}"""
# ========== Core matching logic ==========
  41. async def match_single_task(task: dict, _index: int) -> dict:
  42. """执行单个匹配任务(异步版本)
  43. Args:
  44. task: 匹配任务,包含:
  45. - 灵感: 灵感点文本
  46. - 要素名称: 要素名称
  47. - 要素定义: 要素定义
  48. - 要素类型: "一级分类" 或 "二级分类"
  49. - 上下文: 上下文字符串
  50. _index: 任务索引(由 async_utils 传入,此处未使用)
  51. Returns:
  52. 匹配结果
  53. """
  54. inspiration = task["灵感"]
  55. element_name = task["要素名称"]
  56. element_definition = task["要素定义"]
  57. context_str = task["上下文"]
  58. # 调用名称+定义匹配模块(内部已包含错误处理和 custom_span 追踪)
  59. # B = 灵感, A = 名称+定义, A_Context = 上下文
  60. match_result = await match_with_definition(
  61. b_content=inspiration,
  62. element_name=element_name,
  63. element_definition=element_definition,
  64. model_name=MODEL_NAME,
  65. a_context=context_str # 要素的上下文
  66. )
  67. # 构建完整结果(通用字段 + 业务信息统一存储在最后)
  68. full_result = {
  69. "输入信息": {
  70. "B": inspiration, # 待匹配:灵感
  71. "A_名称": element_name, # 要素名称
  72. "A_定义": element_definition, # 要素定义
  73. "B_Context": "", # B的上下文(暂时为空)
  74. "A_Context": context_str # A的上下文:所属视角/一级分类
  75. },
  76. "匹配结果": match_result, # {"名称匹配": {...}, "定义匹配": {...}}
  77. "业务信息": { # 业务语义信息(统一存储在最后)
  78. "灵感": inspiration,
  79. "匹配要素名称": element_name,
  80. "匹配要素定义": element_definition
  81. }
  82. }
  83. return full_result
# ========== Task construction ==========
  85. def build_match_tasks(
  86. persona_data: dict,
  87. inspiration: str,
  88. max_tasks: int = None
  89. ) -> List[dict]:
  90. """构建匹配任务列表
  91. Args:
  92. persona_data: 人设数据
  93. inspiration: 灵感点
  94. max_tasks: 最大任务数(None 表示不限制)
  95. Returns:
  96. 任务列表
  97. """
  98. tasks = []
  99. # 从"灵感点列表"中提取任务
  100. for perspective in persona_data.get("灵感点列表", []):
  101. if max_tasks is not None and len(tasks) >= max_tasks:
  102. break
  103. perspective_name = perspective.get("视角名称", "")
  104. for pattern in perspective.get("模式列表", []):
  105. if max_tasks is not None and len(tasks) >= max_tasks:
  106. break
  107. level1_name = pattern.get("分类名称", "")
  108. level1_definition = pattern.get("核心定义", "")
  109. # 添加一级分类任务
  110. context_str = build_context_str(perspective_name)
  111. tasks.append({
  112. "灵感": inspiration,
  113. "要素名称": level1_name,
  114. "要素定义": level1_definition,
  115. "要素类型": "一级分类",
  116. "上下文": context_str
  117. })
  118. # 添加该一级下的所有二级分类任务
  119. for level2 in pattern.get("二级细分", []):
  120. if max_tasks is not None and len(tasks) >= max_tasks:
  121. break
  122. level2_name = level2.get("分类名称", "")
  123. level2_definition = level2.get("分类定义", "")
  124. context_str = build_context_str(perspective_name, level1_name)
  125. tasks.append({
  126. "灵感": inspiration,
  127. "要素名称": level2_name,
  128. "要素定义": level2_definition,
  129. "要素类型": "二级分类",
  130. "上下文": context_str
  131. })
  132. return tasks
# ========== Core business logic ==========
  134. async def process_inspiration_match(
  135. persona_data: dict,
  136. inspiration: str,
  137. max_tasks: int = None,
  138. max_concurrent: int = 3,
  139. current_time: str = None,
  140. log_url: str = None
  141. ) -> dict:
  142. """执行灵感与人设匹配分析(核心业务逻辑)
  143. Args:
  144. persona_data: 人设数据字典
  145. inspiration: 灵感点文本
  146. max_tasks: 最大任务数(None 表示不限制)
  147. max_concurrent: 最大并发数
  148. current_time: 当前时间戳
  149. log_url: 日志链接
  150. Returns:
  151. 匹配结果字典,包含元数据和匹配结果列表
  152. """
  153. # 构建匹配任务
  154. test_tasks = build_match_tasks(persona_data, inspiration, max_tasks)
  155. print(f"\n开始匹配分析: {inspiration}")
  156. print(f"任务数: {len(test_tasks)}, 模型: {MODEL_NAME}\n")
  157. # 使用 custom_span 标识整个匹配流程
  158. with custom_span(
  159. name=f"Step1: 灵感与人设匹配 - {inspiration}",
  160. data={
  161. "灵感": inspiration,
  162. "任务总数": len(test_tasks),
  163. "模型": MODEL_NAME,
  164. "并发数": max_concurrent,
  165. "步骤": "字面语义匹配分析"
  166. }
  167. ):
  168. # 异步并发执行匹配(match_single_task 内部已处理所有错误)
  169. results = await process_tasks_with_semaphore(
  170. test_tasks,
  171. match_single_task,
  172. max_concurrent=max_concurrent,
  173. show_progress=True
  174. )
  175. # 按 score 降序排序
  176. results.sort(
  177. key=lambda x: x.get('匹配结果', {}).get('score', 0),
  178. reverse=True
  179. )
  180. # 构建输出结果
  181. output = {
  182. "元数据": {
  183. "current_time": current_time,
  184. "log_url": log_url,
  185. "model": MODEL_NAME
  186. },
  187. "灵感": inspiration,
  188. "匹配结果列表": results
  189. }
  190. return output
# ========== Main function ==========
  192. async def main(current_time: str = None, log_url: str = None, force: bool = False):
  193. """主函数:负责参数解析、文件读取、结果保存
  194. Args:
  195. current_time: 当前时间戳(从外部传入)
  196. log_url: 日志链接(从外部传入)
  197. force: 是否强制重新执行(跳过已存在文件检查)
  198. """
  199. # 解析命令行参数
  200. # 第一个参数:人设文件夹路径(默认值)
  201. if len(sys.argv) > 1:
  202. persona_dir = sys.argv[1]
  203. else:
  204. persona_dir = "data/阿里多多酱/out/人设_1110"
  205. # 第二个参数:灵感索引(数字)或灵感名称(字符串),默认为 0
  206. inspiration_arg = sys.argv[2] if len(sys.argv) > 2 else "0"
  207. # 第三个参数:任务数限制,默认为 None(所有任务)
  208. max_tasks = None if len(sys.argv) > 3 and sys.argv[3] == "all" else (
  209. int(sys.argv[3]) if len(sys.argv) > 3 else None
  210. )
  211. # 第四个参数:force(如果从命令行调用且有该参数,则覆盖函数参数)
  212. if len(sys.argv) > 4 and sys.argv[4] == "force":
  213. force = True
  214. # 加载数据(使用辅助函数,失败时自动退出)
  215. persona_data = load_persona_data(persona_dir)
  216. inspiration_list = load_inspiration_list(persona_dir)
  217. test_inspiration = select_inspiration(inspiration_arg, inspiration_list)
  218. # 构建输出文件路径
  219. output_dir = os.path.join(persona_dir, "how", "灵感点", test_inspiration)
  220. model_name_short = MODEL_NAME.replace("google/", "").replace("/", "_")
  221. step_name_cn = "灵感人设匹配"
  222. scope_prefix = f"top{max_tasks}" if max_tasks is not None else "all"
  223. output_filename = f"{scope_prefix}_step1_{step_name_cn}_{model_name_short}.json"
  224. output_file = os.path.join(output_dir, output_filename)
  225. # 检查文件是否已存在
  226. if not force and os.path.exists(output_file):
  227. print(f"\n✓ 输出文件已存在,跳过执行: {output_file}")
  228. print(f"提示: 如需重新执行,请添加 'force' 参数\n")
  229. return
  230. # 执行核心业务逻辑
  231. output = await process_inspiration_match(
  232. persona_data=persona_data,
  233. inspiration=test_inspiration,
  234. max_tasks=max_tasks,
  235. max_concurrent=10,
  236. current_time=current_time,
  237. log_url=log_url
  238. )
  239. # 确保目录存在
  240. os.makedirs(output_dir, exist_ok=True)
  241. # 保存结果
  242. with open(output_file, 'w', encoding='utf-8') as f:
  243. json.dump(output, f, ensure_ascii=False, indent=2)
  244. print(f"\n完成!结果已保存到: {output_file}")
  245. if log_url:
  246. print(f"Trace: {log_url}\n")
  247. if __name__ == "__main__":
  248. # 设置 trace
  249. current_time, log_url = set_trace()
  250. # 使用 trace 上下文包裹整个执行流程
  251. with trace("灵感与人设匹配"):
  252. asyncio.run(main(current_time, log_url))