#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Update previously imported case knowledge records.

- Use qwen to generate a better task description
- Fill in the complete case information (input, output, operation_process, images)
"""

import json
import os
import sys
import time
import asyncio
from pathlib import Path
from typing import List, Dict, Any

# Make the project root importable so the `knowhub` / `agent` packages resolve.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

# Load environment variables before importing the project modules below.
from dotenv import load_dotenv
env_path = project_root / '.env'
load_dotenv(env_path)

from knowhub.knowhub_db.pg_store import PostgreSQLStore
from knowhub.embeddings import get_embedding
from agent.llm.qwen import qwen_llm_call


async def get_embedding_with_retry(text: str, max_retries: int = 3) -> List[float]:
    """Generate an embedding for *text*, retrying on failure.

    Args:
        text: Text passed to ``get_embedding``.
        max_retries: Total number of attempts; must be at least 1.

    Returns:
        Whatever ``get_embedding`` returns for *text*.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1 (previously the
            function fell through and silently returned ``None``).
        Exception: The error of the final attempt, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    for attempt in range(1, max_retries + 1):
        try:
            return await get_embedding(text)
        except Exception:
            if attempt == max_retries:
                # Out of attempts: propagate the original error and traceback.
                raise
            # Linear back-off: 2s, 4s, 6s, ...
            wait_time = attempt * 2
            print(f" ⚠ Embedding 生成失败,{wait_time}秒后重试... ({attempt}/{max_retries})")
            await asyncio.sleep(wait_time)


async def generate_task_description(case_data: Dict) -> str:
    """Ask qwen for a one-sentence task description of a case.

    Args:
        case_data: Case dictionary; the keys ``title``, ``user_input``,
            ``output_description`` and ``key_findings`` are read, each
            with an empty default.

    Returns:
        The ``"content"`` field of the qwen response, stripped of
        surrounding whitespace. May be an empty string.
    """
    title = case_data.get('title', '')
    user_input = case_data.get('user_input', {})
    output_desc = case_data.get('output_description', '')
    key_findings = case_data.get('key_findings', '')

    prompt = f"""请为以下 AI 图像生成案例生成一个简洁的任务描述(task),用于知识库检索。

案例标题:{title}
用户输入:{json.dumps(user_input, ensure_ascii=False)}
输出效果:{output_desc}
关键发现:{key_findings}

要求:
1. 用一句话概括这个案例的核心任务或目标
2. 突出关键的技术特点或应用场景
3. 不超过 30 个字
4. 直接输出描述,不要其他内容

示例格式:
- 使用 sref 风格代码生成行星级尺度的巨构建筑场景
- 测试不同 stylize 值对电影感角色渲染的影响
- 对比 V7 和 V8 在科幻哥特风格上的表现差异
"""

    response = await qwen_llm_call(
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
        max_tokens=100
    )

    return response["content"].strip()


def load_json_file(file_path: str) -> Dict:
    """Load and return the JSON document stored at *file_path* (UTF-8)."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def build_enhanced_content(case_data: Dict) -> str:
    """Render a case as a Markdown document.

    The output contains the title, source and link, the user input as a
    bullet list, the output description and the key findings. The
    operation-process and example-image sections are only emitted when
    the case provides them.

    Args:
        case_data: Case dictionary. ``user_input`` is iterated with
            ``.items()`` when non-empty, so it is expected to be a mapping.

    Returns:
        The Markdown text.
    """
    title = case_data.get('title', '')
    source = case_data.get('source', '')
    source_link = case_data.get('source_link', '')
    user_input = case_data.get('user_input', {})
    output_desc = case_data.get('output_description', '')
    key_findings = case_data.get('key_findings', '')
    images = case_data.get('images', [])
    operation_process = case_data.get('operation_process', '')

    # User-input section: one bullet per key; list values are comma-joined.
    input_lines = ["## 用户输入\n"]
    if user_input:
        for key, value in user_input.items():
            if isinstance(value, list):
                rendered = ', '.join(str(v) for v in value)
            else:
                rendered = value
            input_lines.append(f"- **{key}**: {rendered}\n")
    else:
        input_lines.append("(无详细输入信息)\n")
    input_section = "".join(input_lines)

    # Operation-process section (optional).
    process_section = ""
    if operation_process:
        process_section = f"\n## 操作流程\n{operation_process}\n"

    # Example-images section (optional), numbered from 1.
    images_section = ""
    if images:
        numbered = "".join(
            f"{i}. {img_url}\n" for i, img_url in enumerate(images, 1)
        )
        images_section = "\n## 示例图片\n" + numbered

    content = f"""# {title}

## 来源
{source}
链接: {source_link}

{input_section}
## 输出效果
{output_desc}

## 关键发现
{key_findings}
{process_section}{images_section}
"""
    return content


async def update_cases(
    knowledge_store: PostgreSQLStore,
    tool_name: str,
    cases_json_path: str
) -> int:
    """Regenerate task, content and embedding for every case of one tool.

    For each entry under the ``cases`` key of the JSON file, a task
    description is generated (falling back to the case title when
    generation fails or yields an empty string), the Markdown content is
    rebuilt, an embedding of the task is computed, and the record
    ``knowledge-case-{tool_name}-{case_id}`` is updated through
    ``knowledge_store.update``.

    Args:
        knowledge_store: Store whose ``update(knowledge_id, updates)`` is called.
        tool_name: Tool name used in the knowledge ID and in log output.
        cases_json_path: Path to the cases JSON file.

    Returns:
        Number of cases whose ``update`` call did not raise.
    """
    print(f"\n=== 更新 {tool_name} 的 cases ===")

    data = load_json_file(cases_json_path)
    cases = data.get('cases', [])
    total = len(cases)
    print(f"找到 {total} 个 cases")

    updated_count = 0

    for index, case in enumerate(cases, 1):
        case_id = case.get('case_id', '')
        title = case.get('title', '')
        knowledge_id = f"knowledge-case-{tool_name}-{case_id}"

        print(f"\n[{index}/{total}] 更新 case: {title[:50]}...")
        print(f" - Knowledge ID: {knowledge_id}")

        # Generate the new task description.
        print(" - 使用 qwen 生成 task 描述...")
        try:
            task_desc = await generate_task_description(case)
        except Exception as e:
            print(f" ⚠ Task 生成失败,使用标题: {e}")
            task_desc = title
        else:
            if task_desc:
                print(f" - Task: {task_desc}")
            else:
                # A blank LLM answer would otherwise be stored and embedded
                # as an empty task; treat it like a failed generation.
                print(" ⚠ Task 为空,使用标题")
                task_desc = title

        # Build the enhanced content.
        print(" - 构建增强内容...")
        enhanced_content = build_enhanced_content(case)

        # Re-embed from the task only; the original author's note says this
        # matches save_knowledge (not visible in this file).
        print(" - 重新生成 embedding(基于 task)...")
        try:
            new_embedding = await get_embedding_with_retry(task_desc)
            print(f" - Embedding 维度: {len(new_embedding)}")
        except Exception as e:
            print(f" ⚠ Embedding 生成失败: {e}")
            new_embedding = None

        # Write the update; the embedding is only included when one was produced.
        try:
            print(" - 更新数据库...")
            updates = {
                'task': task_desc,
                'content': enhanced_content,
                'updated_at': int(time.time())
            }
            if new_embedding:
                updates['embedding'] = new_embedding

            knowledge_store.update(knowledge_id, updates)
            updated_count += 1
            print(" ✓ 成功")
        except Exception as e:
            print(f" ✗ 失败: {e}")

        # Throttle to avoid API rate limiting.
        await asyncio.sleep(0.5)

    print(f"\n成功更新 {updated_count}/{total} 个 cases\n")
    return updated_count


async def main() -> None:
    """Script entry point.

    Connects to the store, reports which case files exist, checks that
    the qwen API responds, then updates the cases of every tool whose
    data file is present. Returns early (after printing the error) when
    the database connection or the qwen check fails.
    """
    print("=" * 60)
    print("开始更新 case knowledge 记录...")
    print("=" * 60)

    # Initialise the store.
    print("\n[1/4] 初始化数据库连接...")
    try:
        knowledge_store = PostgreSQLStore()
        print(" ✓ 数据库连接成功")
    except Exception as e:
        print(f" ✗ 数据库连接失败: {e}")
        return

    print("\n[2/4] 检查数据文件...")
    # Data file locations, relative to the project root.
    base_path = project_root / 'examples' / 'tool_research' / 'outputs'

    tools_data = [
        {
            'name': 'midjourney',
            'cases': base_path / 'midjourney_0' / '02_cases.json'
        },
        {
            'name': 'seedream',
            'cases': base_path / 'seedream_1' / '02_cases.json'
        }
    ]

    for tool_data in tools_data:
        if tool_data['cases'].exists():
            print(f" ✓ {tool_data['name']}: {tool_data['cases']}")
        else:
            print(f" ✗ {tool_data['name']}: 文件不存在")

    print("\n[3/4] 测试 qwen API 连接...")
    try:
        # Only success/failure of the call matters; the reply is discarded.
        await qwen_llm_call(
            messages=[{"role": "user", "content": "测试"}],
            temperature=0.3,
            max_tokens=10
        )
        print(" ✓ qwen API 连接成功")
    except Exception as e:
        print(f" ✗ qwen API 连接失败: {e}")
        print(" 提示:请检查 QWEN_API_KEY 环境变量")
        return

    print("\n[4/4] 开始处理 cases...")
    print("=" * 60)

    total_updated = 0

    for tool_data in tools_data:
        tool_name = tool_data['name']

        if tool_data['cases'].exists():
            count = await update_cases(
                knowledge_store,
                tool_name,
                str(tool_data['cases'])
            )
            total_updated += count
        else:
            print(f"⚠ 文件不存在: {tool_data['cases']}")

    print("\n" + "=" * 50)
    print("更新完成!")
    print(f" - 总计更新: {total_updated} 条")
    print("=" * 50)


if __name__ == '__main__':
    asyncio.run(main())