import_tool_research_data.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 从 tool_research 输出导入需求和 case 到数据库
  5. - 需求存入 requirement_table
  6. - case 存入 knowledge 表
  7. """
  8. import json
  9. import os
  10. import sys
  11. import time
  12. import asyncio
  13. from pathlib import Path
  14. from typing import List, Dict, Any
  15. # 设置 Windows 控制台编码
  16. if sys.platform == 'win32':
  17. import codecs
  18. sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
  19. sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')
  20. # 添加父目录到路径以导入模块
  21. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  22. # 加载环境变量(从项目根目录)
  23. from dotenv import load_dotenv
  24. project_root = Path(__file__).parent.parent.parent
  25. env_path = project_root / '.env'
  26. load_dotenv(env_path)
  27. print(f"加载环境变量: {env_path}")
  28. from knowhub.knowhub_db.pg_store import PostgreSQLStore
  29. from knowhub.knowhub_db.pg_requirement_store import PostgreSQLRequirementStore
  30. from knowhub.embeddings import get_embedding
  31. async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
  32. """带重试的 embedding 生成"""
  33. for attempt in range(max_retries):
  34. try:
  35. return await get_embedding(text)
  36. except Exception as e:
  37. if attempt < max_retries - 1:
  38. wait_time = (attempt + 1) * 2 # 2, 4, 6 秒
  39. print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt + 1}/{max_retries})")
  40. await asyncio.sleep(wait_time)
  41. else:
  42. raise e
  43. def load_json_file(file_path: str) -> Dict:
  44. """加载 JSON 文件"""
  45. with open(file_path, 'r', encoding='utf-8') as f:
  46. return json.load(f)
  47. async def import_requirements(
  48. req_store: PostgreSQLRequirementStore,
  49. tool_name: str,
  50. match_result_path: str,
  51. case_id_to_knowledge_id: Dict[str, str]
  52. ):
  53. """导入需求到 requirement_table
  54. Args:
  55. case_id_to_knowledge_id: case_id 到 knowledge_id 的映射
  56. """
  57. print(f"\n=== 导入 {tool_name} 的需求 ===")
  58. # 加载 match_nodes_result.json
  59. data = load_json_file(match_result_path)
  60. merged_demands = data.get('merged_demands', [])
  61. print(f"找到 {len(merged_demands)} 个合并后的需求")
  62. imported_count = 0
  63. for demand in merged_demands:
  64. demand_name = demand.get('demand_name', '')
  65. description = demand.get('description', '')
  66. source_case_ids = demand.get('source_case_ids', [])
  67. mount_decision = demand.get('mount_decision', {})
  68. print(f"\n[{imported_count + 1}/{len(merged_demands)}] 处理需求: {demand_name[:50]}...")
  69. # 构建需求 ID
  70. req_id = f"req_{tool_name}_{demand_name[:20].replace(' ', '_')}"
  71. print(f" - 需求 ID: {req_id}")
  72. # 生成 embedding (异步调用,带重试)
  73. print(f" - 生成 embedding...")
  74. embedding_text = f"{demand_name} {description}"
  75. embedding = await get_embedding_with_retry(embedding_text)
  76. print(f" - Embedding 维度: {len(embedding)}")
  77. # 提取挂载的节点信息
  78. mounted_nodes = mount_decision.get('mounted_nodes', [])
  79. source_items = [
  80. {
  81. 'entity_id': node['entity_id'],
  82. 'name': node['name'],
  83. 'source_type': node['source_type']
  84. }
  85. for node in mounted_nodes
  86. ]
  87. # 将 case_id 转换为 knowledge_id
  88. case_knowledge_ids = [
  89. case_id_to_knowledge_id.get(str(case_id), str(case_id))
  90. for case_id in source_case_ids
  91. ]
  92. print(f" - 关联 {len(case_knowledge_ids)} 个 case")
  93. # 构建需求记录
  94. requirement = {
  95. 'id': req_id,
  96. 'task': demand_name,
  97. 'type': '制作',
  98. 'source_type': 'itemset',
  99. 'source_itemset_id': f"{tool_name}_demands",
  100. 'source_items': source_items,
  101. 'tools': [{'id': f'tools/image_gen/{tool_name}', 'name': tool_name}],
  102. 'knowledge': [],
  103. 'case_knowledge': case_knowledge_ids, # 使用 knowledge_id
  104. 'process_knowledge': [],
  105. 'trace': {},
  106. 'body': description,
  107. 'embedding': embedding
  108. }
  109. try:
  110. print(f" - 插入数据库...")
  111. req_store.insert_or_update(requirement)
  112. imported_count += 1
  113. print(f" ✓ 成功")
  114. except Exception as e:
  115. print(f" ✗ 失败: {e}")
  116. print(f"成功导入 {imported_count}/{len(merged_demands)} 个需求\n")
  117. return imported_count
  118. async def import_cases(
  119. knowledge_store: PostgreSQLStore,
  120. tool_name: str,
  121. cases_json_path: str
  122. ) -> Dict[str, str]:
  123. """导入 case 到 knowledge 表
  124. Returns:
  125. case_id 到 knowledge_id 的映射字典
  126. """
  127. print(f"\n=== 导入 {tool_name} 的 cases ===")
  128. # 加载 cases.json
  129. data = load_json_file(cases_json_path)
  130. cases = data.get('cases', [])
  131. print(f"找到 {len(cases)} 个 cases")
  132. imported_count = 0
  133. case_id_mapping = {} # case_id -> knowledge_id
  134. for i, case in enumerate(cases):
  135. case_id = case.get('case_id', '')
  136. title = case.get('title', '')
  137. source = case.get('source', '')
  138. source_link = case.get('source_link', '')
  139. output_description = case.get('output_description', '')
  140. key_findings = case.get('key_findings', '')
  141. images = case.get('images', [])
  142. print(f"\n[{i + 1}/{len(cases)}] 处理 case: {title[:50]}...")
  143. print(f" - Case ID: {case_id}")
  144. # 构建知识内容
  145. content = f"""# {title}
  146. ## 来源
  147. {source}
  148. 链接: {source_link}
  149. ## 输出效果
  150. {output_description}
  151. ## 关键发现
  152. {key_findings}
  153. """
  154. # 生成 embedding (异步调用,带重试)
  155. print(f" - 生成 embedding...")
  156. embedding_text = f"{title} {output_description} {key_findings}"
  157. embedding = await get_embedding_with_retry(embedding_text)
  158. print(f" - Embedding 维度: {len(embedding)}")
  159. # 生成知识 ID
  160. timestamp = int(time.time())
  161. knowledge_id = f"knowledge-case-{tool_name}-{case_id}"
  162. print(f" - Knowledge ID: {knowledge_id}")
  163. # 构建知识记录
  164. knowledge = {
  165. 'id': knowledge_id,
  166. 'embedding': embedding,
  167. 'message_id': '',
  168. 'task': title, # 使用帖子标题作为 task
  169. 'content': content,
  170. 'types': ['case', 'tool_usage'],
  171. 'tags': {
  172. 'tool': tool_name,
  173. 'case_id': str(case_id),
  174. 'source': source,
  175. 'has_images': len(images) > 0
  176. },
  177. 'tag_keys': ['tool', 'case_id', 'source', 'has_images'],
  178. 'scopes': ['org:cybertogether'],
  179. 'owner': 'tool_research_agent',
  180. 'resource_ids': [f'tools/image_gen/{tool_name}'],
  181. 'source': {
  182. 'agent_id': 'tool_research_agent',
  183. 'category': 'case',
  184. 'timestamp': timestamp,
  185. 'url': source_link
  186. },
  187. 'eval': {'score': 5, 'helpful': 1, 'confidence': 0.9},
  188. 'created_at': timestamp,
  189. 'updated_at': timestamp,
  190. 'status': 'approved',
  191. 'relationships': []
  192. }
  193. try:
  194. print(f" - 插入数据库...")
  195. knowledge_store.insert(knowledge)
  196. imported_count += 1
  197. case_id_mapping[str(case_id)] = knowledge_id # 记录映射关系
  198. print(f" ✓ 成功")
  199. except Exception as e:
  200. print(f" ✗ 失败: {e}")
  201. print(f"成功导入 {imported_count}/{len(cases)} 个 cases\n")
  202. return case_id_mapping
  203. async def main():
  204. """主函数"""
  205. print("开始导入 tool_research 数据到数据库...")
  206. # 初始化数据库连接
  207. knowledge_store = PostgreSQLStore()
  208. req_store = PostgreSQLRequirementStore()
  209. # 定义数据路径
  210. base_path = Path(__file__).parent.parent.parent / 'examples' / 'tool_research' / 'outputs'
  211. tools_data = [
  212. {
  213. 'name': 'midjourney',
  214. 'match_result': base_path / 'nodes_output' / 'midjourney' / 'match_nodes_result.json',
  215. 'cases': base_path / 'midjourney_0' / '02_cases.json'
  216. },
  217. {
  218. 'name': 'seedream',
  219. 'match_result': base_path / 'nodes_output' / 'seedream' / 'match_nodes_result.json',
  220. 'cases': base_path / 'seedream_1' / '02_cases.json'
  221. }
  222. ]
  223. total_requirements = 0
  224. total_cases = 0
  225. for tool_data in tools_data:
  226. tool_name = tool_data['name']
  227. case_id_mapping = {}
  228. # 先导入 cases,获取 case_id 到 knowledge_id 的映射
  229. if tool_data['cases'].exists():
  230. case_id_mapping = await import_cases(
  231. knowledge_store,
  232. tool_name,
  233. str(tool_data['cases'])
  234. )
  235. total_cases += len(case_id_mapping)
  236. else:
  237. print(f"⚠ 文件不存在: {tool_data['cases']}")
  238. # 再导入需求,使用 case_id_mapping
  239. if tool_data['match_result'].exists():
  240. count = await import_requirements(
  241. req_store,
  242. tool_name,
  243. str(tool_data['match_result']),
  244. case_id_mapping
  245. )
  246. total_requirements += count
  247. else:
  248. print(f"⚠ 文件不存在: {tool_data['match_result']}")
  249. print("\n" + "="*50)
  250. print(f"导入完成!")
  251. print(f" - 需求: {total_requirements} 条")
  252. print(f" - Cases: {total_cases} 条")
  253. print("="*50)
  254. if __name__ == '__main__':
  255. asyncio.run(main())