update_case_knowledge.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Update case knowledge records that have already been imported.

- Use qwen to generate a better task description
- Fill in the complete case information (input, output, operation_process, images)
"""
import json
import os
import sys
import time
import asyncio
from pathlib import Path
from typing import List, Dict, Any
# Put the directory three levels above this file at the front of sys.path;
# presumably this is what lets the knowhub / agent imports below resolve.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
# Load environment variables from the .env file in that same directory.
# This runs before the project imports below (the import order is deliberate).
from dotenv import load_dotenv
project_root = Path(__file__).parent.parent.parent
env_path = project_root / '.env'
load_dotenv(env_path)
from knowhub.knowhub_db.pg_store import PostgreSQLStore
from knowhub.embeddings import get_embedding
from agent.llm.qwen import qwen_llm_call
  25. async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
  26. """带重试的 embedding 生成"""
  27. for attempt in range(max_retries):
  28. try:
  29. return await get_embedding(text)
  30. except Exception as e:
  31. if attempt < max_retries - 1:
  32. wait_time = (attempt + 1) * 2
  33. print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt + 1}/{max_retries})")
  34. await asyncio.sleep(wait_time)
  35. else:
  36. raise e
  37. async def generate_task_description(case_data: Dict) -> str:
  38. """使用 qwen 生成 task 描述"""
  39. title = case_data.get('title', '')
  40. user_input = case_data.get('user_input', {})
  41. output_desc = case_data.get('output_description', '')
  42. key_findings = case_data.get('key_findings', '')
  43. prompt = f"""请为以下 AI 图像生成案例生成一个简洁的任务描述(task),用于知识库检索。
  44. 案例标题:{title}
  45. 用户输入:{json.dumps(user_input, ensure_ascii=False)}
  46. 输出效果:{output_desc}
  47. 关键发现:{key_findings}
  48. 要求:
  49. 1. 用一句话概括这个案例的核心任务或目标
  50. 2. 突出关键的技术特点或应用场景
  51. 3. 不超过 30 个字
  52. 4. 直接输出描述,不要其他内容
  53. 示例格式:
  54. - 使用 sref 风格代码生成行星级尺度的巨构建筑场景
  55. - 测试不同 stylize 值对电影感角色渲染的影响
  56. - 对比 V7 和 V8 在科幻哥特风格上的表现差异
  57. """
  58. response = await qwen_llm_call(
  59. messages=[{"role": "user", "content": prompt}],
  60. temperature=0.3,
  61. max_tokens=100
  62. )
  63. task_desc = response["content"].strip()
  64. return task_desc
  65. def load_json_file(file_path: str) -> Dict:
  66. """加载 JSON 文件"""
  67. with open(file_path, 'r', encoding='utf-8') as f:
  68. return json.load(f)
  69. def build_enhanced_content(case_data: Dict) -> str:
  70. """构建增强的 content 内容"""
  71. title = case_data.get('title', '')
  72. source = case_data.get('source', '')
  73. source_link = case_data.get('source_link', '')
  74. user_input = case_data.get('user_input', {})
  75. output_desc = case_data.get('output_description', '')
  76. key_findings = case_data.get('key_findings', '')
  77. images = case_data.get('images', [])
  78. operation_process = case_data.get('operation_process', '')
  79. # 构建用户输入部分
  80. input_section = "## 用户输入\n"
  81. if user_input:
  82. for key, value in user_input.items():
  83. if isinstance(value, list):
  84. input_section += f"- **{key}**: {', '.join(str(v) for v in value)}\n"
  85. else:
  86. input_section += f"- **{key}**: {value}\n"
  87. else:
  88. input_section += "(无详细输入信息)\n"
  89. # 构建操作流程部分
  90. process_section = ""
  91. if operation_process:
  92. process_section = f"\n## 操作流程\n{operation_process}\n"
  93. # 构建图片部分
  94. images_section = ""
  95. if images:
  96. images_section = "\n## 示例图片\n"
  97. for i, img_url in enumerate(images, 1):
  98. images_section += f"{i}. {img_url}\n"
  99. content = f"""# {title}
  100. ## 来源
  101. {source}
  102. 链接: {source_link}
  103. {input_section}
  104. ## 输出效果
  105. {output_desc}
  106. ## 关键发现
  107. {key_findings}
  108. {process_section}{images_section}
  109. """
  110. return content
  111. async def update_cases(
  112. knowledge_store: PostgreSQLStore,
  113. tool_name: str,
  114. cases_json_path: str
  115. ):
  116. """更新 case knowledge 记录"""
  117. print(f"\n=== 更新 {tool_name} 的 cases ===")
  118. # 加载 cases.json
  119. data = load_json_file(cases_json_path)
  120. cases = data.get('cases', [])
  121. print(f"找到 {len(cases)} 个 cases")
  122. updated_count = 0
  123. for i, case in enumerate(cases):
  124. case_id = case.get('case_id', '')
  125. title = case.get('title', '')
  126. knowledge_id = f"knowledge-case-{tool_name}-{case_id}"
  127. print(f"\n[{i + 1}/{len(cases)}] 更新 case: {title[:50]}...")
  128. print(f" - Knowledge ID: {knowledge_id}")
  129. # 生成新的 task 描述
  130. print(f" - 使用 qwen 生成 task 描述...")
  131. try:
  132. task_desc = await generate_task_description(case)
  133. print(f" - Task: {task_desc}")
  134. except Exception as e:
  135. print(f" ⚠ Task 生成失败,使用标题: {e}")
  136. task_desc = title
  137. # 构建增强的 content
  138. print(f" - 构建增强内容...")
  139. enhanced_content = build_enhanced_content(case)
  140. # 重新生成 embedding(只基于 task,与 save_knowledge 保持一致)
  141. print(f" - 重新生成 embedding(基于 task)...")
  142. try:
  143. new_embedding = await get_embedding_with_retry(task_desc)
  144. print(f" - Embedding 维度: {len(new_embedding)}")
  145. except Exception as e:
  146. print(f" ⚠ Embedding 生成失败: {e}")
  147. new_embedding = None
  148. # 更新数据库
  149. try:
  150. print(f" - 更新数据库...")
  151. updates = {
  152. 'task': task_desc,
  153. 'content': enhanced_content,
  154. 'updated_at': int(time.time())
  155. }
  156. if new_embedding:
  157. updates['embedding'] = new_embedding
  158. knowledge_store.update(knowledge_id, updates)
  159. updated_count += 1
  160. print(f" ✓ 成功")
  161. except Exception as e:
  162. print(f" ✗ 失败: {e}")
  163. # 避免 API 限流
  164. await asyncio.sleep(0.5)
  165. print(f"\n成功更新 {updated_count}/{len(cases)} 个 cases\n")
  166. return updated_count
  167. async def main():
  168. """主函数"""
  169. print("=" * 60)
  170. print("开始更新 case knowledge 记录...")
  171. print("=" * 60)
  172. # 初始化
  173. print("\n[1/4] 初始化数据库连接...")
  174. try:
  175. knowledge_store = PostgreSQLStore()
  176. print(" ✓ 数据库连接成功")
  177. except Exception as e:
  178. print(f" ✗ 数据库连接失败: {e}")
  179. return
  180. print("\n[2/4] 检查数据文件...")
  181. # 定义数据路径
  182. base_path = Path(__file__).parent.parent.parent / 'examples' / 'tool_research' / 'outputs'
  183. tools_data = [
  184. {
  185. 'name': 'midjourney',
  186. 'cases': base_path / 'midjourney_0' / '02_cases.json'
  187. },
  188. {
  189. 'name': 'seedream',
  190. 'cases': base_path / 'seedream_1' / '02_cases.json'
  191. }
  192. ]
  193. for tool_data in tools_data:
  194. if tool_data['cases'].exists():
  195. print(f" ✓ {tool_data['name']}: {tool_data['cases']}")
  196. else:
  197. print(f" ✗ {tool_data['name']}: 文件不存在")
  198. print("\n[3/4] 测试 qwen API 连接...")
  199. try:
  200. test_response = await qwen_llm_call(
  201. messages=[{"role": "user", "content": "测试"}],
  202. temperature=0.3,
  203. max_tokens=10
  204. )
  205. print(f" ✓ qwen API 连接成功")
  206. except Exception as e:
  207. print(f" ✗ qwen API 连接失败: {e}")
  208. print(" 提示:请检查 QWEN_API_KEY 环境变量")
  209. return
  210. print("\n[4/4] 开始处理 cases...")
  211. print("=" * 60)
  212. total_updated = 0
  213. for tool_data in tools_data:
  214. tool_name = tool_data['name']
  215. if tool_data['cases'].exists():
  216. count = await update_cases(
  217. knowledge_store,
  218. tool_name,
  219. str(tool_data['cases'])
  220. )
  221. total_updated += count
  222. else:
  223. print(f"⚠ 文件不存在: {tool_data['cases']}")
  224. print("\n" + "="*50)
  225. print(f"更新完成!")
  226. print(f" - 总计更新: {total_updated} 条")
  227. print("="*50)
# Run the async entry point only when this file is executed as a script.
if __name__ == '__main__':
    asyncio.run(main())