analyze_node_origin_v2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 特征来源分析脚本 V2
  5. 基于过滤后的 how 解构结果,分析目标特征可能由哪些其他特征推导而来。
  6. 输入:intermediate/filtered_results/ 中的过滤结果
  7. 输出:特征来源分析结果
  8. """
  9. import asyncio
  10. import json
  11. from pathlib import Path
  12. from typing import Dict, List, Optional
  13. import sys
  14. # 添加项目根目录到路径
  15. project_root = Path(__file__).parent.parent.parent
  16. sys.path.insert(0, str(project_root))
  17. from agents import Agent, Runner, ModelSettings, trace
  18. from agents.tracing.create import custom_span
  19. from lib.client import get_model
  20. from lib.my_trace import set_trace_smith as set_trace
  21. from script.data_processing.path_config import PathConfig
  22. # 模型配置
  23. MODEL_NAME = "google/gemini-3-pro-preview"
  24. # MODEL_NAME = 'anthropic/claude-sonnet-4.5'
  25. agent = Agent(
  26. name="Feature Origin Analyzer",
  27. model=get_model(MODEL_NAME),
  28. model_settings=ModelSettings(
  29. temperature=0.0,
  30. max_tokens=65536,
  31. ),
  32. tools=[],
  33. )
  34. # ===== 数据提取函数 =====
  35. def extract_post_info(how_result: Dict) -> Dict:
  36. """
  37. 从 how 解构结果中提取帖子信息(灵感点、目的点、关键点列表)
  38. Args:
  39. how_result: how解构结果
  40. Returns:
  41. 包含三类点列表的字典,每个点含名称、描述、特征列表
  42. """
  43. result = {}
  44. for point_type in ["灵感点", "目的点", "关键点"]:
  45. point_list_key = f"{point_type}列表"
  46. point_list = how_result.get(point_list_key, [])
  47. extracted_points = []
  48. for point in point_list:
  49. # 提取特征名称列表
  50. feature_names = []
  51. for feature in point.get("特征列表", []):
  52. feature_name = feature.get("特征名称", "")
  53. if feature_name:
  54. feature_names.append(feature_name)
  55. extracted_points.append({
  56. "名称": point.get("名称", ""),
  57. "描述": point.get("描述", ""),
  58. "特征列表": feature_names
  59. })
  60. if extracted_points:
  61. result[point_list_key] = extracted_points
  62. return result
  63. def get_all_features(post_info: Dict) -> List[Dict]:
  64. """
  65. 从帖子信息中提取所有特征(点+特征列表中的特征)
  66. Args:
  67. post_info: 帖子信息
  68. Returns:
  69. 所有特征列表,包含名称和类型
  70. """
  71. features = []
  72. for point_type in ["灵感点", "目的点", "关键点"]:
  73. point_list_key = f"{point_type}列表"
  74. for point in post_info.get(point_list_key, []):
  75. # 添加点本身作为特征
  76. features.append({
  77. "特征名称": point["名称"],
  78. "特征类型": point_type,
  79. "描述": point.get("描述", "")
  80. })
  81. return features
  82. # ===== Prompt 构建 =====
  83. def build_prompt(target_feature: str, post_info: Dict) -> str:
  84. """
  85. 构建分析 prompt
  86. Args:
  87. target_feature: 目标关键特征名称
  88. post_info: 帖子信息
  89. Returns:
  90. prompt 文本
  91. """
  92. # 将帖子信息转为 JSON 格式
  93. post_info_json = json.dumps(post_info, ensure_ascii=False, indent=4)
  94. return f'''你是一个内容创作逆向工程分析专家。你的任务是分析给定的特征是如何从其他特征中推理得出的。
  95. 请按照以下要求进行分析:
  96. ## 目标关键特征
  97. {target_feature}
  98. ## 帖子信息
  99. {post_info_json}
  100. ## 分析任务
  101. 将所有来源特征分为两类:
  102. ### 1. 单独推理
  103. - 定义: 该特征单独存在时,可以独立推导出目标关键特征,无需其他特征辅助
  104. ### 2. 组合推理
  105. - 定义: 2个或更多特征必须同时存在才能有效推导出目标关键特征
  106. ## 输出格式
  107. 使用JSON格式输出,结构如下:
  108. {{
  109. "目标关键特征": "...",
  110. "推理类型分类": {{
  111. "单独推理": [
  112. {{
  113. "排名": 1,
  114. "特征名称": "...",
  115. "特征类型": "灵感点/目的点/关键点",
  116. "可能性": 0.xx,
  117. "推理说明": "..."
  118. }}
  119. ],
  120. "组合推理": [
  121. {{
  122. "组合编号": 1,
  123. "组合成员": ["...", "..."],
  124. "成员类型": ["...", "..."],
  125. "可能性": 0.xx,
  126. "单独可能性": {{
  127. "成员1": 0.xx,
  128. "成员2": 0.xx
  129. }},
  130. "协同效应分析": {{
  131. "单独平均值": 0.xx,
  132. "协同增益": 0.xx,
  133. "增益说明": "..."
  134. }},
  135. "推理说明": "..."
  136. }}
  137. ]
  138. }}
  139. }}
  140. ## 注意事项
  141. 1. 可能性数值需要合理评估,范围在0-1之间
  142. 2. 单独推理按可能性从高到低排序
  143. 3. 组合推理必须包含2个或以上成员
  144. 4. 协同增益 = 组合可能性 - 单独平均值
  145. 5. 推理说明要清晰说明推导逻辑,避免空洞表述
  146. 6. 每个特征只能属于一种推理类型,不能既是单独推理又是组合推理的成员
  147. 7. 优先识别组合推理,剩余的特征作为单独推理
  148. 8. 一般先有实质,再有形式,如,先有角色,再有服化道;除非形式是关键特征
  149. '''.strip()
  150. # ===== 主分析函数 =====
  151. async def analyze_feature_origin(
  152. post_data: Dict,
  153. target_feature: str = None
  154. ) -> Dict:
  155. """
  156. 分析单个帖子中目标特征的来源
  157. Args:
  158. post_data: 帖子数据(包含 how解构结果)
  159. target_feature: 目标特征名称,如果为 None 则使用关键点的第一个
  160. Returns:
  161. 分析结果
  162. """
  163. post_id = post_data.get("帖子id", "")
  164. how_result = post_data.get("how解构结果", {})
  165. # 提取帖子信息
  166. post_info = extract_post_info(how_result)
  167. if not post_info:
  168. return {
  169. "帖子id": post_id,
  170. "模型": MODEL_NAME,
  171. "输入": {"帖子信息": {}},
  172. "输出": None,
  173. "错误": "没有可分析的点"
  174. }
  175. # 确定目标特征
  176. if target_feature is None:
  177. # 默认使用关键点的第一个
  178. key_points = post_info.get("关键点列表", [])
  179. if key_points:
  180. target_feature = key_points[0]["名称"]
  181. else:
  182. return {
  183. "帖子id": post_id,
  184. "模型": MODEL_NAME,
  185. "输入": {"帖子信息": post_info},
  186. "输出": None,
  187. "错误": "没有找到关键点"
  188. }
  189. # 构建 prompt
  190. prompt = build_prompt(target_feature, post_info)
  191. # 使用 custom_span 标识分析流程
  192. with custom_span(
  193. name=f"分析特征来源 - {target_feature}",
  194. data={
  195. "帖子id": post_id,
  196. "目标特征": target_feature,
  197. "模型": MODEL_NAME
  198. }
  199. ):
  200. # 调用 agent
  201. result = await Runner.run(agent, input=prompt)
  202. output = result.final_output
  203. # 解析 JSON
  204. try:
  205. if "```json" in output:
  206. json_start = output.find("```json") + 7
  207. json_end = output.find("```", json_start)
  208. json_str = output[json_start:json_end].strip()
  209. elif "{" in output and "}" in output:
  210. json_start = output.find("{")
  211. json_end = output.rfind("}") + 1
  212. json_str = output[json_start:json_end]
  213. else:
  214. json_str = output
  215. analysis_result = json.loads(json_str)
  216. return {
  217. "帖子id": post_id,
  218. "目标特征": target_feature,
  219. "模型": MODEL_NAME,
  220. "输入": {
  221. "帖子信息": post_info,
  222. "prompt": prompt
  223. },
  224. "输出": analysis_result
  225. }
  226. except Exception as e:
  227. return {
  228. "帖子id": post_id,
  229. "目标特征": target_feature,
  230. "模型": MODEL_NAME,
  231. "输入": {
  232. "帖子信息": post_info,
  233. "prompt": prompt
  234. },
  235. "输出": None,
  236. "错误": str(e),
  237. "原始输出": output
  238. }
  239. # ===== 主函数 =====
  240. async def main(
  241. post_id: str = None,
  242. target_feature: str = None,
  243. current_time: str = None,
  244. log_url: str = None
  245. ):
  246. """
  247. 主函数
  248. Args:
  249. post_id: 帖子ID,可选(默认使用第一个)
  250. target_feature: 目标特征名称,可选(默认使用关键点第一个)
  251. current_time: 当前时间戳(从外部传入)
  252. log_url: 日志链接(从外部传入)
  253. """
  254. config = PathConfig()
  255. # 获取输入目录
  256. input_dir = config.intermediate_dir / "filtered_results"
  257. output_dir = config.intermediate_dir / "feature_origin_analysis"
  258. output_dir.mkdir(parents=True, exist_ok=True)
  259. print(f"账号: {config.account_name}")
  260. print(f"输入目录: {input_dir}")
  261. print(f"输出目录: {output_dir}")
  262. print(f"使用模型: {MODEL_NAME}")
  263. if log_url:
  264. print(f"Trace URL: {log_url}")
  265. print()
  266. # 获取输入文件
  267. input_files = sorted(input_dir.glob("*_filtered.json"))
  268. if not input_files:
  269. print(f"错误: 在 {input_dir} 中没有找到任何 *_filtered.json 文件")
  270. return
  271. # 选择帖子
  272. if post_id:
  273. target_file = next(
  274. (f for f in input_files if post_id in f.name),
  275. None
  276. )
  277. if not target_file:
  278. print(f"错误: 未找到帖子 {post_id}")
  279. return
  280. else:
  281. target_file = input_files[0] # 默认第一个
  282. # 读取文件
  283. with open(target_file, "r", encoding="utf-8") as f:
  284. post_data = json.load(f)
  285. actual_post_id = post_data.get("帖子id", "unknown")
  286. print(f"帖子ID: {actual_post_id}")
  287. print(f"目标特征: {target_feature or '(默认关键点第一个)'}")
  288. print()
  289. # 分析
  290. result = await analyze_feature_origin(post_data, target_feature)
  291. # 显示结果
  292. output = result.get("输出")
  293. if output:
  294. print("=" * 60)
  295. print("分析结果")
  296. print("=" * 60)
  297. print(f"\n目标关键特征: {output.get('目标关键特征', 'N/A')}\n")
  298. reasoning = output.get("推理类型分类", {})
  299. # 显示单独推理
  300. single = reasoning.get("单独推理", [])
  301. if single:
  302. print("【单独推理】")
  303. for item in single:
  304. print(f" #{item.get('排名', '-')} [{item.get('可能性', 0):.2f}] {item.get('特征名称', '')} ({item.get('特征类型', '')})")
  305. print(f" {item.get('推理说明', '')}")
  306. # 显示组合推理
  307. combo = reasoning.get("组合推理", [])
  308. if combo:
  309. print("\n【组合推理】")
  310. for item in combo:
  311. members = " + ".join(item.get("组合成员", []))
  312. prob = item.get("可能性", 0)
  313. synergy = item.get("协同效应分析", {})
  314. gain = synergy.get("协同增益", 0)
  315. print(f" 组合{item.get('组合编号', '-')}: [{prob:.2f}] {members}")
  316. print(f" 协同增益: {gain:+.2f}")
  317. print(f" {item.get('推理说明', '')}")
  318. else:
  319. print(f"分析失败: {result.get('错误', 'N/A')}")
  320. # 保存结果
  321. target_name = result.get("目标特征", "unknown")
  322. output_file = output_dir / f"{actual_post_id}_{target_name}_来源分析.json"
  323. save_data = {
  324. "元数据": {
  325. "current_time": current_time,
  326. "log_url": log_url,
  327. "model": MODEL_NAME
  328. },
  329. **result
  330. }
  331. with open(output_file, "w", encoding="utf-8") as f:
  332. json.dump(save_data, f, ensure_ascii=False, indent=2)
  333. print(f"\n结果已保存到: {output_file}")
  334. if log_url:
  335. print(f"Trace: {log_url}")
  336. if __name__ == "__main__":
  337. import argparse
  338. parser = argparse.ArgumentParser(description="分析特征来源 V2")
  339. parser.add_argument("--post-id", type=str, help="帖子ID")
  340. parser.add_argument("--target", type=str, help="目标特征名称")
  341. args = parser.parse_args()
  342. # 设置 trace
  343. current_time, log_url = set_trace()
  344. # 使用 trace 上下文包裹整个执行流程
  345. with trace("特征来源分析V2"):
  346. asyncio.run(main(
  347. post_id=args.post_id,
  348. target_feature=args.target,
  349. current_time=current_time,
  350. log_url=log_url
  351. ))