test_script_orthogonal_with_real_data.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 使用真实数据测试脚本正交分析Agent
  5. """
  6. import sys
  7. import os
  8. import json
  9. from pathlib import Path
  10. from dotenv import load_dotenv
  11. # 添加项目根目录到 Python 路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. # 加载环境变量
  15. load_dotenv(project_root / ".env")
  16. from src.components.agents.script_orthogonal_analysis_agent import ScriptOrthogonalAnalysisAgent
  17. def load_script_result(file_path: str) -> dict:
  18. """加载脚本解析结果文件
  19. Args:
  20. file_path: JSON文件路径
  21. Returns:
  22. 解析后的JSON数据
  23. """
  24. with open(file_path, 'r', encoding='utf-8') as f:
  25. return json.load(f)
  26. def test_with_real_data():
  27. """使用真实数据测试脚本正交分析Agent"""
  28. print("=" * 100)
  29. print("使用真实数据测试脚本正交分析Agent")
  30. print("=" * 100)
  31. # 加载真实数据
  32. script_result_path = project_root / "examples/阿里多多酱/output/script_result_20251118_152626.json"
  33. print(f"\n📁 加载脚本解析结果: {script_result_path}")
  34. script_result = load_script_result(str(script_result_path))
  35. # 提取所需数据
  36. topic_description = script_result.get("选题描述", {})
  37. content_weight = script_result.get("图文权重", {})
  38. script_understanding = script_result.get("脚本理解", {})
  39. # 构建state
  40. state = {
  41. "text": {
  42. "title": "当代年轻人对食物的双标日常",
  43. "body": "#讨好型水果[话题]#"
  44. },
  45. "images": script_understanding.get("图片列表", []),
  46. "topic_selection_understanding": topic_description,
  47. "content_weight": content_weight,
  48. "script_sections": {
  49. "内容品类": script_understanding.get("内容品类", ""),
  50. "段落列表": script_understanding.get("段落列表", [])
  51. },
  52. "script_elements": {
  53. "元素列表": script_understanding.get("元素列表", [])
  54. }
  55. }
  56. print(f"\n✓ 数据加载成功")
  57. print(f" - 选题主题: {topic_description.get('主题', 'N/A')}")
  58. print(f" - 内容品类: {script_understanding.get('内容品类', 'N/A')}")
  59. print(f" - 段落数量: {len(script_understanding.get('段落列表', []))}")
  60. print(f" - 元素数量: {len(script_understanding.get('元素列表', []))}")
  61. print(f" - 图片数量: {len(script_understanding.get('图片列表', []))}")
  62. # 初始化Agent
  63. print("\n🤖 初始化ScriptOrthogonalAnalysisAgent...")
  64. agent = ScriptOrthogonalAnalysisAgent()
  65. print(" ✓ Agent初始化成功")
  66. # 执行正交分析
  67. print("\n⚙️ 开始执行正交分析...")
  68. print("-" * 100)
  69. result = agent.process(state)
  70. print("-" * 100)
  71. # 输出结果
  72. print("\n" + "=" * 100)
  73. print("📊 正交分析结果")
  74. print("=" * 100)
  75. orthogonal_matrix = result.get("正交矩阵", [])
  76. element_type_list = result.get("元素类型列表", [])
  77. print(f"\n✓ 正交矩阵生成成功")
  78. print(f" - 段落行数: {len(orthogonal_matrix)}")
  79. print(f" - 元素类型列数: {len(element_type_list)}")
  80. print(f"\n📋 元素类型列表 (共{len(element_type_list)}个):")
  81. for idx, element_type in enumerate(element_type_list, 1):
  82. print(f" {idx}. {element_type}")
  83. print(f"\n📋 正交矩阵预览 (前3个段落):")
  84. for idx, row in enumerate(orthogonal_matrix[:3], 1):
  85. print(f"\n 段落 {idx}: {row['段落']}")
  86. print(f" 内容范围: {row['内容范围']}")
  87. print(f" 元素类型分析:")
  88. for element_type, analysis in row.get('元素类型分析', {}).items():
  89. if analysis: # 只显示非空的分析
  90. print(f" - {element_type}: {analysis[:80]}{'...' if len(analysis) > 80 else ''}")
  91. # 保存结果
  92. output_dir = project_root / "examples/阿里多多酱/output"
  93. output_path = output_dir / "orthogonal_analysis_result.json"
  94. print(f"\n💾 保存结果到: {output_path}")
  95. with open(output_path, 'w', encoding='utf-8') as f:
  96. json.dump(result, f, ensure_ascii=False, indent=2)
  97. print(f" ✓ 结果保存成功")
  98. # 生成表格形式的输出(Markdown格式)
  99. markdown_output_path = output_dir / "orthogonal_analysis_table.md"
  100. print(f"\n📝 生成Markdown表格: {markdown_output_path}")
  101. with open(markdown_output_path, 'w', encoding='utf-8') as f:
  102. f.write("# 脚本正交分析矩阵\n\n")
  103. f.write(f"**选题主题**: {topic_description.get('主题', 'N/A')}\n\n")
  104. f.write(f"**内容品类**: {script_understanding.get('内容品类', 'N/A')}\n\n")
  105. # 生成表格
  106. f.write("## 正交矩阵表格\n\n")
  107. # 表头
  108. header = "| 段落 |"
  109. separator = "|------|"
  110. for element_type in element_type_list:
  111. header += f" {element_type} |"
  112. separator += "------|"
  113. f.write(header + "\n")
  114. f.write(separator + "\n")
  115. # 表格内容
  116. for row in orthogonal_matrix:
  117. line = f"| {row['段落']} |"
  118. for element_type in element_type_list:
  119. analysis = row.get('元素类型分析', {}).get(element_type, "")
  120. # 处理换行和特殊字符
  121. analysis = analysis.replace("\n", " ").replace("|", "\\|")
  122. line += f" {analysis} |"
  123. f.write(line + "\n")
  124. # 添加详细说明
  125. f.write("\n## 段落内容范围详情\n\n")
  126. for idx, row in enumerate(orthogonal_matrix, 1):
  127. f.write(f"### {idx}. {row['段落']}\n\n")
  128. f.write(f"**内容范围**:\n")
  129. for content in row['内容范围']:
  130. f.write(f"- {content}\n")
  131. f.write("\n")
  132. print(f" ✓ Markdown表格生成成功")
  133. # 测试总结
  134. print("\n" + "=" * 100)
  135. print("✅ 测试完成总结")
  136. print("=" * 100)
  137. print(f" ✓ 成功加载真实数据文件")
  138. print(f" ✓ 成功提取 {len(orthogonal_matrix)} 个段落")
  139. print(f" ✓ 成功提取 {len(element_type_list)} 个元素类型")
  140. print(f" ✓ 成功生成正交分析矩阵")
  141. print(f" ✓ 结果已保存为JSON和Markdown格式")
  142. print(f"\n 📂 输出文件:")
  143. print(f" - JSON: {output_path}")
  144. print(f" - Markdown: {markdown_output_path}")
  145. print("\n" + "🎉" * 50)
  146. print("测试成功完成!")
  147. print("🎉" * 50 + "\n")
  148. return result
  149. if __name__ == "__main__":
  150. try:
  151. test_with_real_data()
  152. except Exception as e:
  153. print(f"\n❌ 测试失败: {e}\n")
  154. import traceback
  155. traceback.print_exc()
  156. exit(1)