| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 更新已导入的 case knowledge 记录
- - 使用 qwen 生成更好的 task 描述
- - 补充完整的 case 信息(input, output, operation_process, images)
- """
- import json
- import os
- import sys
- import time
- import asyncio
- from pathlib import Path
- from typing import List, Dict, Any
- # 添加父目录到路径
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- # 加载环境变量
- from dotenv import load_dotenv
- project_root = Path(__file__).parent.parent.parent
- env_path = project_root / '.env'
- load_dotenv(env_path)
- from knowhub.knowhub_db.pg_store import PostgreSQLStore
- from knowhub.embeddings import get_embedding
- from agent.llm.qwen import qwen_llm_call
async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
    """Generate an embedding for ``text``, retrying on failure.

    Between attempts it sleeps 2s, 4s, 6s, ... (linear backoff).

    Args:
        text: Text to embed.
        max_retries: Total number of attempts. Values below 1 are treated as
            1, so the function always calls ``get_embedding`` at least once
            instead of silently returning ``None``.

    Returns:
        The vector returned by ``get_embedding``.

    Raises:
        Exception: whatever the final ``get_embedding`` attempt raised.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return await get_embedding(text)
        except Exception:
            if attempt == attempts - 1:
                # Out of attempts: propagate the original error and traceback.
                raise
            wait_time = (attempt + 1) * 2
            print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt + 1}/{attempts})")
            await asyncio.sleep(wait_time)
    # Not reachable: the last iteration either returns or re-raises.
    raise RuntimeError("unreachable")
async def generate_task_description(case_data: Dict) -> str:
    """Ask qwen for a one-line task description of a case, for retrieval.

    The prompt includes the case title, user input (as JSON), output
    description and key findings. It asks for a single sentence of at most
    30 characters.

    The reply is normalized to its first non-empty line. A leading ``- ``,
    ``• `` or ``* `` bullet (the prompt's examples use ``- `` bullets, so
    the model may echo one) and surrounding quotes are removed.

    Args:
        case_data: One case dict from the cases JSON file.

    Returns:
        The cleaned, non-empty task description.

    Raises:
        ValueError: if the model returned no usable text, so callers can
            fall back (``update_cases`` falls back to the case title).
        Exception: anything raised by ``qwen_llm_call`` propagates.
    """
    # `or` also covers keys that are present with a JSON null value, which
    # would otherwise show up as "None" inside the prompt.
    title = case_data.get('title') or ''
    user_input = case_data.get('user_input') or {}
    output_desc = case_data.get('output_description') or ''
    key_findings = case_data.get('key_findings') or ''
    prompt = f"""请为以下 AI 图像生成案例生成一个简洁的任务描述(task),用于知识库检索。
案例标题:{title}
用户输入:{json.dumps(user_input, ensure_ascii=False)}
输出效果:{output_desc}
关键发现:{key_findings}
要求:
1. 用一句话概括这个案例的核心任务或目标
2. 突出关键的技术特点或应用场景
3. 不超过 30 个字
4. 直接输出描述,不要其他内容
示例格式:
- 使用 sref 风格代码生成行星级尺度的巨构建筑场景
- 测试不同 stylize 值对电影感角色渲染的影响
- 对比 V7 和 V8 在科幻哥特风格上的表现差异
"""
    response = await qwen_llm_call(
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
        max_tokens=100
    )
    raw_text = response["content"] or ""

    # Keep only the first non-empty line of the reply.
    lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
    task_desc = lines[0] if lines else ""
    for bullet in ("- ", "• ", "* "):
        if task_desc.startswith(bullet):
            task_desc = task_desc[len(bullet):].strip()
            break
    task_desc = task_desc.strip('"\'“”').strip()

    if not task_desc:
        raise ValueError("qwen returned an empty task description")
    return task_desc
def load_json_file(file_path: str) -> Dict:
    """Decode the UTF-8 JSON document stored at ``file_path`` and return it."""
    raw_text = Path(file_path).read_text(encoding='utf-8')
    return json.loads(raw_text)
def build_enhanced_content(case_data: Dict) -> str:
    """Render one case as the Markdown text stored in the knowledge ``content``.

    Sections, in order:
    1. Title.
    2. Source and link.
    3. User input: one bullet per key, with list values comma-joined.
    4. Output description.
    5. Key findings.
    6. Operation process: only if present.
    7. Numbered image list: only if present.

    Scalar text fields that are missing *or* present with a JSON ``null``
    value render as empty text instead of the literal string ``"None"``.

    Args:
        case_data: One case dict from the cases JSON file.

    Returns:
        The Markdown document as a single string.
    """
    def _text(key: str) -> str:
        # `.get(key, '')` would still return None for `"key": null`.
        value = case_data.get(key)
        return '' if value is None else str(value)

    title = _text('title')
    source = _text('source')
    source_link = _text('source_link')
    output_desc = _text('output_description')
    key_findings = _text('key_findings')
    user_input = case_data.get('user_input') or {}
    images = case_data.get('images') or []
    operation_process = case_data.get('operation_process')

    # User input section.
    input_lines = ["## 用户输入\n"]
    if user_input:
        for key, value in user_input.items():
            if isinstance(value, list):
                value = ', '.join(str(v) for v in value)
            input_lines.append(f"- **{key}**: {value}\n")
    else:
        input_lines.append("(无详细输入信息)\n")
    input_section = ''.join(input_lines)

    # Optional operation process section.
    process_section = ""
    if operation_process:
        process_section = f"\n## 操作流程\n{operation_process}\n"

    # Optional numbered image list.
    images_section = ""
    if images:
        images_section = "\n## 示例图片\n" + ''.join(
            f"{i}. {img_url}\n" for i, img_url in enumerate(images, 1)
        )

    content = f"""# {title}
## 来源
{source}
链接: {source_link}
{input_section}
## 输出效果
{output_desc}
## 关键发现
{key_findings}
{process_section}{images_section}
"""
    return content
async def update_cases(
    knowledge_store: PostgreSQLStore,
    tool_name: str,
    cases_json_path: str
):
    """Refresh the case knowledge records for one tool from its cases JSON.

    For each entry of the file's ``cases`` list, it updates the record
    ``knowledge-case-{tool_name}-{case_id}`` with these fields:

    * ``task``: a qwen-generated description. It falls back to the case
      title when generation raises or yields nothing.
    * ``content``: the Markdown built by ``build_enhanced_content``.
    * ``embedding``: recomputed from ``task`` only. It is omitted from the
      update if embedding generation fails.
    * ``updated_at``: the current Unix time in seconds.

    Cases without a ``case_id`` are skipped, because their record id would
    be ``knowledge-case-{tool_name}-`` (or ``...-None``) rather than a real
    record.

    Args:
        knowledge_store: Store whose ``update(id, fields)`` is called per case.
        tool_name: Tool name used in the knowledge id, e.g. "midjourney".
        cases_json_path: Path to the tool's cases JSON file.

    Returns:
        Number of cases for which ``knowledge_store.update`` did not raise.
        NOTE(review): whether ``update`` raises or silently no-ops for an
        unknown id is not visible here. Confirm before trusting this count.
    """
    print(f"\n=== 更新 {tool_name} 的 cases ===")
    data = load_json_file(cases_json_path)
    # `"cases": null` behaves like a missing key instead of crashing len().
    cases = data.get('cases') or []
    total = len(cases)
    print(f"找到 {total} 个 cases")
    updated_count = 0
    for index, case in enumerate(cases, 1):
        case_id = case.get('case_id')
        # A JSON-null title must not crash the slice below, which sits
        # outside any try block and would abort the whole run.
        title = case.get('title') or ''
        print(f"\n[{index}/{total}] 更新 case: {title[:50]}...")
        if case_id is None or case_id == '':
            print(" ⚠ 缺少 case_id,跳过")
            continue
        knowledge_id = f"knowledge-case-{tool_name}-{case_id}"
        print(f" - Knowledge ID: {knowledge_id}")

        # Generate a new task description, falling back to the title.
        print(f" - 使用 qwen 生成 task 描述...")
        try:
            task_desc = (await generate_task_description(case) or '').strip()
            if not task_desc:
                raise ValueError("empty task description")
            print(f" - Task: {task_desc}")
        except Exception as e:
            print(f" ⚠ Task 生成失败,使用标题: {e}")
            task_desc = title

        print(f" - 构建增强内容...")
        enhanced_content = build_enhanced_content(case)

        # The embedding is derived from the task only (kept consistent with
        # save_knowledge).
        # NOTE(review): if this fails, the new task is still written while
        # the stored embedding stays tied to the old task.
        print(f" - 重新生成 embedding(基于 task)...")
        try:
            new_embedding = await get_embedding_with_retry(task_desc)
            print(f" - Embedding 维度: {len(new_embedding)}")
        except Exception as e:
            print(f" ⚠ Embedding 生成失败: {e}")
            new_embedding = None

        try:
            print(f" - 更新数据库...")
            updates = {
                'task': task_desc,
                'content': enhanced_content,
                'updated_at': int(time.time())
            }
            if new_embedding:
                updates['embedding'] = new_embedding
            knowledge_store.update(knowledge_id, updates)
            updated_count += 1
            print(f" ✓ 成功")
        except Exception as e:
            print(f" ✗ 失败: {e}")

        # Throttle between cases to avoid API rate limiting.
        await asyncio.sleep(0.5)
    print(f"\n成功更新 {updated_count}/{total} 个 cases\n")
    return updated_count
async def main() -> int:
    """Run the case-knowledge refresh for every configured tool.

    Steps:
    1. Connect to PostgreSQL.
    2. Report which cases files exist.
    3. Smoke-test the qwen API.
    4. Call ``update_cases`` for each existing file.

    A file that cannot be read or decoded is reported, and the remaining
    tools are still processed.

    Returns:
        Process exit status: 0 on success. It is 1 when the database or qwen
        check fails, or when any tool's cases file could not be processed.
    """
    print("=" * 60)
    print("开始更新 case knowledge 记录...")
    print("=" * 60)

    print("\n[1/4] 初始化数据库连接...")
    try:
        knowledge_store = PostgreSQLStore()
        print(" ✓ 数据库连接成功")
    except Exception as e:
        print(f" ✗ 数据库连接失败: {e}")
        return 1

    print("\n[2/4] 检查数据文件...")
    base_path = Path(__file__).parent.parent.parent / 'examples' / 'tool_research' / 'outputs'
    tools_data = [
        {
            'name': 'midjourney',
            'cases': base_path / 'midjourney_0' / '02_cases.json'
        },
        {
            'name': 'seedream',
            'cases': base_path / 'seedream_1' / '02_cases.json'
        }
    ]
    for tool_data in tools_data:
        if tool_data['cases'].exists():
            print(f" ✓ {tool_data['name']}: {tool_data['cases']}")
        else:
            print(f" ✗ {tool_data['name']}: 文件不存在")

    print("\n[3/4] 测试 qwen API 连接...")
    try:
        # Only checks that the call succeeds; the reply is not used.
        await qwen_llm_call(
            messages=[{"role": "user", "content": "测试"}],
            temperature=0.3,
            max_tokens=10
        )
        print(" ✓ qwen API 连接成功")
    except Exception as e:
        print(f" ✗ qwen API 连接失败: {e}")
        print(" 提示:请检查 QWEN_API_KEY 环境变量")
        return 1

    print("\n[4/4] 开始处理 cases...")
    print("=" * 60)
    total_updated = 0
    had_failure = False
    for tool_data in tools_data:
        tool_name = tool_data['name']
        cases_path = tool_data['cases']
        if not cases_path.exists():
            print(f"⚠ 文件不存在: {cases_path}")
            continue
        try:
            total_updated += await update_cases(
                knowledge_store,
                tool_name,
                str(cases_path)
            )
        except (OSError, ValueError) as e:
            # Unreadable file or invalid JSON (JSONDecodeError is a
            # ValueError): report it and continue with the next tool.
            had_failure = True
            print(f"✗ {tool_name} 处理失败: {e}")

    print("\n" + "="*50)
    print(f"更新完成!")
    print(f" - 总计更新: {total_updated} 条")
    print("="*50)
    return 1 if had_failure else 0


if __name__ == '__main__':
    sys.exit(asyncio.run(main()))
|