test_knowledge_requirement_agent.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """
  2. 知识需求生成Agent的完整测试脚本
  3. 测试KnowledgeRequirementGenerationAgent的新输出格式(PRD 1.4)
  4. 新格式包含:
  5. - 整体项目目标
  6. - 本次任务目标
  7. - 上下文
  8. - 待解构帖子信息
  9. - 需求(内容知识需求 + 工具知识需求)
  10. """
  11. import sys
  12. import json
  13. import os
  14. from pathlib import Path
  15. from dotenv import load_dotenv
  16. # 添加项目根目录到路径
  17. project_root = Path(__file__).parent.parent.parent
  18. sys.path.insert(0, str(project_root))
  19. # 加载环境变量
  20. load_dotenv(project_root / ".env")
  21. from langchain.chat_models import init_chat_model
  22. from src.components.agents.knowledge_requirement_agent import generate_knowledge_requirements
  23. from src.utils.logger import get_logger
  24. logger = get_logger(__name__)
  25. def load_test_data(file_path: str) -> dict:
  26. """加载测试数据
  27. Args:
  28. file_path: JSON文件路径
  29. Returns:
  30. 解析后的JSON数据
  31. """
  32. with open(file_path, 'r', encoding='utf-8') as f:
  33. return json.load(f)
  34. def read_prd_content(pdf_path: str) -> str:
  35. """读取PRD PDF内容
  36. Args:
  37. pdf_path: PDF文件路径
  38. Returns:
  39. PRD文本内容
  40. """
  41. try:
  42. import pymupdf # PyMuPDF
  43. # 打开PDF文件
  44. doc = pymupdf.open(pdf_path)
  45. # 提取所有页面的文本
  46. text_content = []
  47. for page_num in range(len(doc)):
  48. page = doc[page_num]
  49. text_content.append(page.get_text())
  50. doc.close()
  51. # 合并所有页面的文本
  52. full_text = "\n".join(text_content)
  53. logger.info(f"成功读取PDF文件: {pdf_path}, 内容长度: {len(full_text)} 字符")
  54. return full_text
  55. except ImportError:
  56. logger.warning("未安装 pymupdf 库,使用 PyPDF2 作为备选方案")
  57. try:
  58. from PyPDF2 import PdfReader
  59. reader = PdfReader(pdf_path)
  60. text_content = []
  61. for page in reader.pages:
  62. text_content.append(page.extract_text())
  63. full_text = "\n".join(text_content)
  64. logger.info(f"成功读取PDF文件 (PyPDF2): {pdf_path}, 内容长度: {len(full_text)} 字符")
  65. return full_text
  66. except ImportError:
  67. logger.error("未安装 PDF 读取库 (pymupdf 或 PyPDF2),无法读取PDF文件")
  68. raise ImportError("请安装 pymupdf (推荐) 或 PyPDF2: pip install pymupdf 或 pip install PyPDF2")
  69. except Exception as e:
  70. logger.error(f"读取PDF文件失败: {e}", exc_info=True)
  71. raise
  72. def format_post_content(raw_data: dict) -> dict:
  73. """格式化帖子内容为Agent所需格式
  74. Args:
  75. raw_data: 原始帖子数据
  76. Returns:
  77. 格式化后的帖子内容
  78. """
  79. return {
  80. "text": {
  81. "title": raw_data.get("title", ""),
  82. "body": raw_data.get("body_text", ""),
  83. "hashtags": [] # 可以从body_text中提取
  84. },
  85. "images": raw_data.get("images", []),
  86. "metadata": {
  87. "link": raw_data.get("link", ""),
  88. "content_id": raw_data.get("channel_content_id", ""),
  89. "account_name": raw_data.get("channel_account_name", ""),
  90. "content_type": raw_data.get("content_type", ""),
  91. "comment_count": raw_data.get("comment_count", 0),
  92. "like_count": raw_data.get("like_count", 0),
  93. "collect_count": raw_data.get("collect_count", 0)
  94. }
  95. }
  96. def test_complete_knowledge_requirement_generation():
  97. """完整测试:知识需求生成(新格式 - PRD 1.4)
  98. 测试场景:
  99. 1. 帖子整体解构
  100. 2. 验证输出格式包含所有必需章节
  101. 3. 验证知识需求正确分类(内容知识 vs 工具知识)
  102. 测试步骤:
  103. - Step 1: 初始化LLM
  104. - Step 2: 加载PRD内容
  105. - Step 3: 加载测试帖子数据
  106. - Step 4: 定义任务阶段
  107. - Step 5: 创建Agent并生成知识需求
  108. - Step 6: 展示结果
  109. - Step 7: 格式验证
  110. - Step 8: 保存Markdown文档
  111. - Step 9: 测试总结
  112. """
  113. print("\n" + "=" * 100)
  114. print("📝 完整测试:KnowledgeRequirementGenerationAgent(新格式 - PRD 1.4)")
  115. print(" - 输入:PRD文档 + 小红书帖子 + 解构上下文 + 任务阶段")
  116. print(" - 输出:结构化知识需求文档(包含项目目标、任务目标、上下文、需求等)")
  117. print("=" * 100)
  118. # Step 1: 初始化LLM
  119. print("\n🤖 Step 1: 初始化LLM (Gemini)...")
  120. # 确保 GOOGLE_API_KEY 环境变量已设置
  121. google_api_key = os.getenv("GEMINI_API_KEY")
  122. if google_api_key:
  123. os.environ["GOOGLE_API_KEY"] = google_api_key
  124. llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")
  125. print(" ✓ LLM初始化成功")
  126. # Step 2: 加载PRD内容
  127. print("\n📄 Step 2: 加载PRD内容...")
  128. prd_path = project_root / "prd1.4.pdf"
  129. prd_content = read_prd_content(str(prd_path))
  130. print(f" ✓ PRD加载成功,内容长度: {len(prd_content)} 字符")
  131. print(f" PRD摘要(前200字): {prd_content}...")
  132. # Step 3: 加载测试帖子数据
  133. print("\n📁 Step 3: 加载测试帖子数据...")
  134. test_data_path = project_root / "examples/测试数据/阿里多多酱/待解构帖子.json"
  135. raw_data = load_test_data(str(test_data_path))
  136. post_content = format_post_content(raw_data)
  137. print(f" ✓ 帖子数据加载成功")
  138. print(f" - 标题: {raw_data.get('title', 'N/A')}")
  139. print(f" - 图片数量: {len(raw_data.get('images', []))}张")
  140. print(f" - 点赞数: {raw_data.get('like_count', 0)}")
  141. # Step 4: 定义任务阶段
  142. print("\n🎯 Step 4: 定义任务阶段...")
  143. task_stage = "帖子整体解构"
  144. print(f" ✓ 任务阶段: {task_stage}")
  145. print(f" - 说明: 确定帖子的描述维度(品类、主题、脚本、内容亮点、情绪共鸣点等)")
  146. # Step 5: 创建Agent并生成知识需求
  147. print("\n⚙️ Step 5: 创建Agent并生成知识需求...")
  148. print(" 选项:启用知识检索 = False(加快测试速度)")
  149. result = generate_knowledge_requirements(
  150. llm=llm,
  151. prd_content=prd_content,
  152. post_content=post_content,
  153. task_stage=task_stage, # 新增参数
  154. enable_retrieval=False # 设为True可启用知识检索,但会较慢
  155. )
  156. # Step 6: 展示结果
  157. print("\n" + "=" * 100)
  158. print("📊 Step 6: 生成结果展示")
  159. print("=" * 100)
  160. print(f"\n📝 总结:")
  161. print(f" {result.summary}")
  162. print(f"\n📄 Markdown文档:")
  163. print("-" * 100)
  164. print(result.markdown_document)
  165. print("-" * 100)
  166. # Step 7: 格式验证
  167. print("\n" + "=" * 100)
  168. print("🔍 Step 7: 格式验证")
  169. print("=" * 100)
  170. required_sections = [
  171. "# 整体项目目标",
  172. "# 本次任务目标",
  173. "# 上下文",
  174. "# 待解构帖子信息",
  175. "# 需求",
  176. "## 内容知识需求",
  177. "### 需求约束",
  178. "### 需求描述",
  179. "## 工具知识需求"
  180. ]
  181. all_passed = True
  182. for section in required_sections:
  183. if section in result.markdown_document:
  184. print(f" ✓ 包含章节: {section}")
  185. else:
  186. print(f" ✗ 缺少章节: {section}")
  187. all_passed = False
  188. if all_passed:
  189. print(f"\n 🎉 格式验证通过!所有必需章节均存在。")
  190. else:
  191. print(f"\n ⚠️ 格式验证未通过,缺少部分章节。")
  192. # Step 8: 保存Markdown文档
  193. print("\n" + "=" * 100)
  194. print("💾 Step 8: 保存Markdown文档")
  195. print("=" * 100)
  196. output_dir = project_root / "test/outputs"
  197. output_dir.mkdir(parents=True, exist_ok=True)
  198. output_path = output_dir / "knowledge_requirement_complete.md"
  199. with open(output_path, 'w', encoding='utf-8') as f:
  200. f.write(result.markdown_document)
  201. print(f"\n ✓ Markdown文档已保存到: {output_path}")
  202. print(f" - 文件大小: {len(result.markdown_document)} 字符")
  203. # Step 9: 测试总结
  204. print("\n" + "=" * 100)
  205. print("✅ 测试完成总结")
  206. print("=" * 100)
  207. print(f" ✓ 成功读取PRD文档 (prd1.4.pdf)")
  208. print(f" ✓ 成功加载小红书帖子数据")
  209. print(f" ✓ 成功定义任务阶段: {task_stage}")
  210. print(f" ✓ 成功生成知识需求文档(新格式 - PRD 1.4)")
  211. print(f" ✓ 格式验证: {'通过' if all_passed else '未通过'}")
  212. print(f" ✓ 总结: {result.summary}")
  213. print(f" ✓ Markdown文档大小: {len(result.markdown_document)} 字符")
  214. print("\n" + "🎉" * 50)
  215. print("测试成功完成! 新输出格式符合 PRD 1.4 要求")
  216. print("🎉" * 50 + "\n")
  217. return result
  218. def main():
  219. """主测试函数"""
  220. print("\n" + "🚀" * 50)
  221. print("KnowledgeRequirementGenerationAgent 完整测试套件(PRD 1.4 新格式)")
  222. print("🚀" * 50)
  223. try:
  224. _ = test_complete_knowledge_requirement_generation() # noqa: F841
  225. print("\n✅ 所有测试通过!")
  226. return 0
  227. except Exception as e:
  228. logger.error(f"测试失败: {e}", exc_info=True)
  229. print(f"\n❌ 测试失败: {e}\n")
  230. import traceback
  231. traceback.print_exc()
  232. return 1
  233. if __name__ == "__main__":
  234. exit(main())