| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 从 tool_research 输出导入需求和 case 到数据库
- - 需求存入 requirement_table
- - case 存入 knowledge 表
- """
- import json
- import os
- import sys
- import time
- import asyncio
- from pathlib import Path
- from typing import List, Dict, Any
# Force UTF-8 console output on Windows so the Chinese/emoji progress messages
# printed by this script do not fail on legacy code pages (e.g. when output is
# redirected to a file or pipe).
# TextIOWrapper.reconfigure() (Python 3.7+) keeps the real text stream. A
# codecs.getwriter() wrapper would instead forward attribute lookups to the
# binary buffer, which hides attributes such as `.encoding` that other
# libraries read from sys.stdout / sys.stderr.
if sys.platform == 'win32':
    for _stream in (sys.stdout, sys.stderr):
        # Streams replaced by an IDE / harness (or None under pythonw) may not
        # support reconfigure(); leave those untouched.
        if hasattr(_stream, 'reconfigure'):
            _stream.reconfigure(encoding='utf-8', errors='strict')

# Project root (three directories above this file). It is put on sys.path so
# the `knowhub` package can be imported, and it is where `.env` is read from.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

# Load environment variables from the project root before the knowhub imports
# that follow (presumably they read DB/API settings from the environment —
# TODO confirm).
from dotenv import load_dotenv
env_path = project_root / '.env'
load_dotenv(env_path)
print(f"加载环境变量: {env_path}")
- from knowhub.knowhub_db.pg_store import PostgreSQLStore
- from knowhub.knowhub_db.pg_requirement_store import PostgreSQLRequirementStore
- from knowhub.embeddings import get_embedding
async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
    """Generate an embedding for *text*, retrying failed calls with linear backoff.

    Args:
        text: Text to embed; passed unchanged to ``get_embedding``.
        max_retries: Total number of attempts (not additional retries).
            Must be at least 1.

    Returns:
        The value returned by ``get_embedding`` for *text*.

    Raises:
        ValueError: If ``max_retries`` is less than 1. Previously the loop
            simply did not run and ``None`` was returned silently.
        Exception: Whatever the final attempt raised, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(1, max_retries + 1):
        try:
            return await get_embedding(text)
        except Exception:
            if attempt == max_retries:
                # Out of attempts: propagate the original error and traceback.
                raise
            # Linear backoff of 2s, 4s, ... between attempts. With the default
            # max_retries=3 the sleeps are 2s then 4s; the third failure raises.
            wait_time = attempt * 2
            print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt}/{max_retries})")
            await asyncio.sleep(wait_time)
def load_json_file(file_path: str) -> Dict:
    """Parse the UTF-8 encoded JSON file at *file_path* and return its contents."""
    with open(file_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
async def import_requirements(
    req_store: PostgreSQLRequirementStore,
    tool_name: str,
    match_result_path: str,
    case_id_to_knowledge_id: Dict[str, str]
) -> int:
    """Import one tool's merged demands into ``requirement_table``.

    Reads ``merged_demands`` from the match-result JSON and embeds each
    demand's name plus description. Each demand is then written as one
    requirement record via ``req_store.insert_or_update``. A failed write is
    reported and skipped. An embedding failure (after retries) propagates and
    aborts the import.

    Args:
        req_store: Store that receives the requirement records.
        tool_name: Tool identifier, used in record IDs and the ``tools`` field.
        match_result_path: Path to ``match_nodes_result.json``.
        case_id_to_knowledge_id: Mapping of ``str(case_id)`` to the knowledge
            ID created by ``import_cases``.

    Returns:
        The number of requirements written successfully.
    """
    print(f"\n=== 导入 {tool_name} 的需求 ===")
    # Load match_nodes_result.json
    data = load_json_file(match_result_path)
    merged_demands = data.get('merged_demands', [])
    total = len(merged_demands)
    print(f"找到 {total} 个合并后的需求")

    imported_count = 0
    seen_req_ids = set()
    for index, demand in enumerate(merged_demands, start=1):
        # `or` also covers explicit JSON nulls, which would otherwise crash the
        # slicing / iteration / .get() calls below and abort the whole import.
        demand_name = demand.get('demand_name') or ''
        description = demand.get('description', '')
        source_case_ids = demand.get('source_case_ids') or []
        mount_decision = demand.get('mount_decision') or {}
        # Use the loop position (not imported_count) so the counter still
        # advances after a failed insert.
        print(f"\n[{index}/{total}] 处理需求: {demand_name[:50]}...")

        # Build the requirement ID. Only the first 20 characters of the name
        # are used, so distinct demands can collide; because the store upserts,
        # the later demand would silently replace the earlier one.
        req_id = f"req_{tool_name}_{demand_name[:20].replace(' ', '_')}"
        print(f" - 需求 ID: {req_id}")
        if req_id in seen_req_ids:
            print(f" ⚠ 需求 ID 重复,将覆盖前一条同 ID 需求: {req_id}")
        seen_req_ids.add(req_id)

        # Generate the embedding (async call, with retries).
        print(f" - 生成 embedding...")
        embedding_text = f"{demand_name} {description}"
        embedding = await get_embedding_with_retry(embedding_text)
        print(f" - Embedding 维度: {len(embedding)}")

        # Extract the mounted-node information.
        mounted_nodes = mount_decision.get('mounted_nodes') or []
        source_items = [
            {
                'entity_id': node['entity_id'],
                'name': node['name'],
                'source_type': node['source_type']
            }
            for node in mounted_nodes
        ]

        # Translate case IDs to knowledge IDs. Unknown case IDs (case not
        # imported, or its insert failed) keep the raw case_id as before, but
        # are now reported because they are not real knowledge IDs.
        case_knowledge_ids = []
        unmapped_case_ids = []
        for case_id in source_case_ids:
            key = str(case_id)
            knowledge_id = case_id_to_knowledge_id.get(key)
            if knowledge_id is None:
                unmapped_case_ids.append(key)
                knowledge_id = key
            case_knowledge_ids.append(knowledge_id)
        print(f" - 关联 {len(case_knowledge_ids)} 个 case")
        if unmapped_case_ids:
            print(f" ⚠ {len(unmapped_case_ids)} 个 case 未找到 knowledge_id,保留原始 case_id: {unmapped_case_ids}")

        # Build the requirement record.
        requirement = {
            'id': req_id,
            'task': demand_name,
            'type': '制作',
            'source_type': 'itemset',
            'source_itemset_id': f"{tool_name}_demands",
            'source_items': source_items,
            'tools': [{'id': f'tools/image_gen/{tool_name}', 'name': tool_name}],
            'knowledge': [],
            'case_knowledge': case_knowledge_ids,  # knowledge IDs (see mapping above)
            'process_knowledge': [],
            'trace': {},
            'body': description,
            'embedding': embedding
        }
        try:
            print(f" - 插入数据库...")
            req_store.insert_or_update(requirement)
            imported_count += 1
            print(f" ✓ 成功")
        except Exception as e:
            # Best-effort per record: report and continue with the next demand.
            print(f" ✗ 失败: {e}")

    print(f"成功导入 {imported_count}/{total} 个需求\n")
    return imported_count
async def import_cases(
    knowledge_store: PostgreSQLStore,
    tool_name: str,
    cases_json_path: str
) -> Dict[str, str]:
    """Import one tool's research cases into the knowledge table.

    Reads ``cases`` from the cases JSON and renders each case as a short
    Markdown document. The document is embedded (title + output description +
    key findings) and inserted via ``knowledge_store.insert``. A failed insert
    is reported and skipped. An embedding failure (after retries) propagates
    and aborts the import.

    Args:
        knowledge_store: Store that receives the knowledge records.
        tool_name: Tool identifier, used in record IDs, tags and resource IDs.
        cases_json_path: Path to ``02_cases.json``.

    Returns:
        Mapping of ``str(case_id)`` to knowledge ID, for cases whose insert
        succeeded only.
    """
    print(f"\n=== 导入 {tool_name} 的 cases ===")
    # Load cases.json
    data = load_json_file(cases_json_path)
    cases = data.get('cases', [])
    print(f"找到 {len(cases)} 个 cases")

    imported_count = 0
    case_id_mapping = {}  # str(case_id) -> knowledge_id
    seen_knowledge_ids = set()
    for position, case in enumerate(cases, start=1):
        case_id = case.get('case_id', '')
        # `or` covers explicit JSON nulls: title is sliced and images is
        # measured with len() below, both of which would crash on None.
        title = case.get('title') or ''
        source = case.get('source', '')
        source_link = case.get('source_link', '')
        output_description = case.get('output_description', '')
        key_findings = case.get('key_findings', '')
        images = case.get('images') or []
        print(f"\n[{position}/{len(cases)}] 处理 case: {title[:50]}...")
        print(f" - Case ID: {case_id}")
        if case_id in (None, ''):
            print(f" ⚠ 缺少 case_id,生成的 knowledge ID 可能与其他 case 冲突")

        # Markdown body of the knowledge record.
        content = "\n".join([
            f"# {title}",
            "## 来源",
            f"{source}",
            f"链接: {source_link}",
            "## 输出效果",
            f"{output_description}",
            "## 关键发现",
            f"{key_findings}",
            "",  # keep the trailing newline
        ])

        # Generate the embedding (async call, with retries).
        print(f" - 生成 embedding...")
        embedding_text = f"{title} {output_description} {key_findings}"
        embedding = await get_embedding_with_retry(embedding_text)
        print(f" - Embedding 维度: {len(embedding)}")

        # The knowledge ID is deterministic per (tool, case_id). Wall-clock
        # seconds are recorded as created/updated time.
        timestamp = int(time.time())
        knowledge_id = f"knowledge-case-{tool_name}-{case_id}"
        print(f" - Knowledge ID: {knowledge_id}")
        if knowledge_id in seen_knowledge_ids:
            print(f" ⚠ Knowledge ID 重复: {knowledge_id}")
        seen_knowledge_ids.add(knowledge_id)

        # Build the knowledge record.
        knowledge = {
            'id': knowledge_id,
            'embedding': embedding,
            'message_id': '',
            'task': title,  # the post title doubles as the task
            'content': content,
            'types': ['case', 'tool_usage'],
            'tags': {
                'tool': tool_name,
                'case_id': str(case_id),
                'source': source,
                'has_images': len(images) > 0
            },
            'tag_keys': ['tool', 'case_id', 'source', 'has_images'],
            'scopes': ['org:cybertogether'],
            'owner': 'tool_research_agent',
            'resource_ids': [f'tools/image_gen/{tool_name}'],
            'source': {
                'agent_id': 'tool_research_agent',
                'category': 'case',
                'timestamp': timestamp,
                'url': source_link
            },
            'eval': {'score': 5, 'helpful': 1, 'confidence': 0.9},
            'created_at': timestamp,
            'updated_at': timestamp,
            'status': 'approved',
            'relationships': []
        }
        try:
            print(f" - 插入数据库...")
            knowledge_store.insert(knowledge)
            imported_count += 1
            # Only successful inserts are mapped.
            # NOTE(review): on a re-run the insert presumably fails on the
            # existing ID, leaving this case unmapped, so import_requirements
            # falls back to the raw case_id — confirm the store's duplicate
            # behaviour.
            case_id_mapping[str(case_id)] = knowledge_id
            print(f" ✓ 成功")
        except Exception as e:
            # Best-effort per record: report and continue with the next case.
            print(f" ✗ 失败: {e}")

    print(f"成功导入 {imported_count}/{len(cases)} 个 cases\n")
    return case_id_mapping
async def main():
    """Import every configured tool: its cases first, then its requirements.

    Cases are imported first because the requirements reference the
    knowledge IDs those cases receive. Missing input files are reported and
    skipped.
    """
    print("开始导入 tool_research 数据到数据库...")
    knowledge_store = PostgreSQLStore()
    req_store = PostgreSQLRequirementStore()

    outputs_dir = Path(__file__).parent.parent.parent / 'examples' / 'tool_research' / 'outputs'
    # (tool name, run directory containing that tool's 02_cases.json)
    tool_specs = (
        ('midjourney', 'midjourney_0'),
        ('seedream', 'seedream_1'),
    )

    total_requirements = 0
    total_cases = 0
    for tool_name, run_dir in tool_specs:
        match_result = outputs_dir / 'nodes_output' / tool_name / 'match_nodes_result.json'
        cases_file = outputs_dir / run_dir / '02_cases.json'

        case_id_mapping = {}
        if not cases_file.exists():
            print(f"⚠ 文件不存在: {cases_file}")
        else:
            case_id_mapping = await import_cases(knowledge_store, tool_name, str(cases_file))
            total_cases += len(case_id_mapping)

        if not match_result.exists():
            print(f"⚠ 文件不存在: {match_result}")
        else:
            total_requirements += await import_requirements(
                req_store, tool_name, str(match_result), case_id_mapping
            )

    separator = "=" * 50
    print("\n" + separator)
    print("导入完成!")
    print(f" - 需求: {total_requirements} 条")
    print(f" - Cases: {total_cases} 条")
    print(separator)
if __name__ == '__main__':
    # Script entry point: run the async importer to completion on a fresh event loop.
    asyncio.run(main())
|