#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Import requirements and cases from tool_research output into the database.

- Requirements are stored in requirement_table.
- Cases are stored in the knowledge table.
"""
import json
import os
import sys
import time
import asyncio
from pathlib import Path
from typing import List, Dict, Any

# Configure the Windows console streams for UTF-8 so the CJK/emoji log
# output below does not raise UnicodeEncodeError.
if sys.platform == 'win32':
    import codecs
    sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
    sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')

# Add the project root to sys.path so the knowhub package can be imported.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

# Load environment variables from the .env file at the project root.
from dotenv import load_dotenv
project_root = Path(__file__).parent.parent.parent
env_path = project_root / '.env'
load_dotenv(env_path)
print(f"加载环境变量: {env_path}")

from knowhub.knowhub_db.pg_store import PostgreSQLStore
from knowhub.knowhub_db.pg_requirement_store import PostgreSQLRequirementStore
from knowhub.embeddings import get_embedding


async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
    """Generate an embedding for *text*, retrying with linear backoff.

    Args:
        text: The text to embed.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The embedding vector.

    Raises:
        Exception: Re-raises the last embedding error once all retries fail.
    """
    for attempt in range(max_retries):
        try:
            return await get_embedding(text)
        except Exception:
            if attempt < max_retries - 1:
                wait_time = (attempt + 1) * 2  # 2, 4, 6 seconds
                print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt + 1}/{max_retries})")
                await asyncio.sleep(wait_time)
            else:
                # Bare raise preserves the original traceback
                # (``raise e`` would re-anchor it here).
                raise


def load_json_file(file_path: str) -> Dict:
    """Load and parse a UTF-8 encoded JSON file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


async def import_requirements(
    req_store: PostgreSQLRequirementStore,
    tool_name: str,
    match_result_path: str,
    case_id_to_knowledge_id: Dict[str, str]
):
    """Import merged demands into requirement_table.

    Args:
        req_store: Store that performs the requirement insert/update.
        tool_name: Name of the tool the demands belong to.
        match_result_path: Path to the match_nodes_result.json file.
        case_id_to_knowledge_id: Mapping from case_id to knowledge_id.

    Returns:
        Number of requirements successfully imported.
    """
    print(f"\n=== 导入 {tool_name} 的需求 ===")

    # Load match_nodes_result.json
    data = load_json_file(match_result_path)
    merged_demands = data.get('merged_demands', [])
    print(f"找到 {len(merged_demands)} 个合并后的需求")

    imported_count = 0
    for demand in merged_demands:
        demand_name = demand.get('demand_name', '')
        description = demand.get('description', '')
        source_case_ids = demand.get('source_case_ids', [])
        mount_decision = demand.get('mount_decision', {})

        print(f"\n[{imported_count + 1}/{len(merged_demands)}] 处理需求: {demand_name[:50]}...")

        # Deterministic requirement ID: tool name + truncated demand name.
        req_id = f"req_{tool_name}_{demand_name[:20].replace(' ', '_')}"
        print(f" - 需求 ID: {req_id}")

        # Generate the embedding (async call with retries).
        print(f" - 生成 embedding...")
        embedding_text = f"{demand_name} {description}"
        embedding = await get_embedding_with_retry(embedding_text)
        print(f" - Embedding 维度: {len(embedding)}")

        # Extract the nodes this demand was mounted onto.
        mounted_nodes = mount_decision.get('mounted_nodes', [])
        source_items = [
            {
                'entity_id': node['entity_id'],
                'name': node['name'],
                'source_type': node['source_type']
            }
            for node in mounted_nodes
        ]

        # Translate case_ids to knowledge_ids; fall back to the raw case_id
        # when no mapping exists.
        case_knowledge_ids = [
            case_id_to_knowledge_id.get(str(case_id), str(case_id))
            for case_id in source_case_ids
        ]
        print(f" - 关联 {len(case_knowledge_ids)} 个 case")

        # Build the requirement record.
        requirement = {
            'id': req_id,
            'task': demand_name,
            'type': '制作',
            'source_type': 'itemset',
            'source_itemset_id': f"{tool_name}_demands",
            'source_items': source_items,
            'tools': [{'id': f'tools/image_gen/{tool_name}', 'name': tool_name}],
            'knowledge': [],
            'case_knowledge': case_knowledge_ids,  # knowledge_id values
            'process_knowledge': [],
            'trace': {},
            'body': description,
            'embedding': embedding
        }

        try:
            print(f" - 插入数据库...")
            req_store.insert_or_update(requirement)
            imported_count += 1
            print(f" ✓ 成功")
        except Exception as e:
            print(f" ✗ 失败: {e}")

    print(f"成功导入 {imported_count}/{len(merged_demands)} 个需求\n")
    return imported_count


async def import_cases(
    knowledge_store: PostgreSQLStore,
    tool_name: str,
    cases_json_path: str
) -> Dict[str, str]:
    """Import cases into the knowledge table.

    Args:
        knowledge_store: Store that performs the knowledge insert.
        tool_name: Name of the tool the cases belong to.
        cases_json_path: Path to the cases JSON file.

    Returns:
        Mapping from case_id to the knowledge_id it was stored under.
    """
    print(f"\n=== 导入 {tool_name} 的 cases ===")

    # Load cases.json
    data = load_json_file(cases_json_path)
    cases = data.get('cases', [])
    print(f"找到 {len(cases)} 个 cases")

    imported_count = 0
    case_id_mapping = {}  # case_id -> knowledge_id

    for i, case in enumerate(cases):
        case_id = case.get('case_id', '')
        title = case.get('title', '')
        source = case.get('source', '')
        source_link = case.get('source_link', '')
        output_description = case.get('output_description', '')
        key_findings = case.get('key_findings', '')
        images = case.get('images', [])

        print(f"\n[{i + 1}/{len(cases)}] 处理 case: {title[:50]}...")
        print(f" - Case ID: {case_id}")

        # Assemble the markdown knowledge content.
        content = f"""# {title}

## 来源
{source}
链接: {source_link}

## 输出效果
{output_description}

## 关键发现
{key_findings}
"""

        # Generate the embedding (async call with retries).
        print(f" - 生成 embedding...")
        embedding_text = f"{title} {output_description} {key_findings}"
        embedding = await get_embedding_with_retry(embedding_text)
        print(f" - Embedding 维度: {len(embedding)}")

        # Build the knowledge ID (deterministic; timestamp is only recorded
        # in the payload, not in the ID).
        timestamp = int(time.time())
        knowledge_id = f"knowledge-case-{tool_name}-{case_id}"
        print(f" - Knowledge ID: {knowledge_id}")

        # Build the knowledge record.
        knowledge = {
            'id': knowledge_id,
            'embedding': embedding,
            'message_id': '',
            'task': title,  # use the post title as the task
            'content': content,
            'types': ['case', 'tool_usage'],
            'tags': {
                'tool': tool_name,
                'case_id': str(case_id),
                'source': source,
                'has_images': len(images) > 0
            },
            'tag_keys': ['tool', 'case_id', 'source', 'has_images'],
            'scopes': ['org:cybertogether'],
            'owner': 'tool_research_agent',
            'resource_ids': [f'tools/image_gen/{tool_name}'],
            'source': {
                'agent_id': 'tool_research_agent',
                'category': 'case',
                'timestamp': timestamp,
                'url': source_link
            },
            'eval': {'score': 5, 'helpful': 1, 'confidence': 0.9},
            'created_at': timestamp,
            'updated_at': timestamp,
            'status': 'approved',
            'relationships': []
        }

        try:
            print(f" - 插入数据库...")
            knowledge_store.insert(knowledge)
            imported_count += 1
            case_id_mapping[str(case_id)] = knowledge_id  # record the mapping
            print(f" ✓ 成功")
        except Exception as e:
            print(f" ✗ 失败: {e}")

    print(f"成功导入 {imported_count}/{len(cases)} 个 cases\n")
    return case_id_mapping


async def main():
    """Entry point: import cases first, then requirements, per tool."""
    print("开始导入 tool_research 数据到数据库...")

    # Initialize the database connections.
    knowledge_store = PostgreSQLStore()
    req_store = PostgreSQLRequirementStore()

    # Define the per-tool input paths.
    base_path = Path(__file__).parent.parent.parent / 'examples' / 'tool_research' / 'outputs'
    tools_data = [
        {
            'name': 'midjourney',
            'match_result': base_path / 'nodes_output' / 'midjourney' / 'match_nodes_result.json',
            'cases': base_path / 'midjourney_0' / '02_cases.json'
        },
        {
            'name': 'seedream',
            'match_result': base_path / 'nodes_output' / 'seedream' / 'match_nodes_result.json',
            'cases': base_path / 'seedream_1' / '02_cases.json'
        }
    ]

    total_requirements = 0
    total_cases = 0

    for tool_data in tools_data:
        tool_name = tool_data['name']
        case_id_mapping = {}

        # Import cases first to obtain the case_id -> knowledge_id mapping.
        if tool_data['cases'].exists():
            case_id_mapping = await import_cases(
                knowledge_store,
                tool_name,
                str(tool_data['cases'])
            )
            total_cases += len(case_id_mapping)
        else:
            print(f"⚠ 文件不存在: {tool_data['cases']}")

        # Then import the requirements, wiring in the case mapping.
        if tool_data['match_result'].exists():
            count = await import_requirements(
                req_store,
                tool_name,
                str(tool_data['match_result']),
                case_id_mapping
            )
            total_requirements += count
        else:
            print(f"⚠ 文件不存在: {tool_data['match_result']}")

    print("\n" + "="*50)
    print(f"导入完成!")
    print(f" - 需求: {total_requirements} 条")
    print(f" - Cases: {total_cases} 条")
    print("="*50)


if __name__ == '__main__':
    asyncio.run(main())