#!/usr/bin/env python3
"""
Repair script: move knowledge entries out of ``tool_knowledge`` into the
correct field of ``tool_table``, based on each knowledge row's ``types`` tags.

Mapping rules:
- types contains "plan"    -> process_knowledge (process knowledge)
- types contains "usecase" -> case_knowledge (use-case knowledge)
- anything else            -> stays in tool_knowledge

When an entry carries both tags, "plan" wins.
"""

import os
import sys
import json

import psycopg2
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv

_script_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.normpath(os.path.join(_script_dir, '..', '..'))
load_dotenv(os.path.join(_project_root, '.env'))

# Maximum number of IDs bound into a single ``IN (...)`` query.
_BATCH_SIZE = 100


def get_connection():
    """Open a non-autocommit PostgreSQL connection configured from env vars.

    Reads KNOWHUB_DB (host), KNOWHUB_PORT (default 5432), KNOWHUB_USER,
    KNOWHUB_PASSWORD and KNOWHUB_DB_NAME.
    """
    host = os.getenv('KNOWHUB_DB')
    port = int(os.getenv('KNOWHUB_PORT', 5432))
    user = os.getenv('KNOWHUB_USER')
    password = os.getenv('KNOWHUB_PASSWORD')
    dbname = os.getenv('KNOWHUB_DB_NAME')

    print(f"连接到 {host}:{port}/{dbname} as {user} ...")
    conn = psycopg2.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=dbname,
        connect_timeout=10,
    )
    # Changes are only persisted by the explicit commit in main().
    conn.autocommit = False
    print("连接成功。\n")
    return conn


def parse_jsonb(val):
    """Safely coerce a JSONB column value into a list.

    The value may arrive already decoded (a list) or as a JSON-encoded
    string. ``None``, undecodable strings, and anything that is not (or does
    not decode to) a list all yield ``[]``, so malformed data is never
    iterated as though it were a list of IDs.
    """
    if isinstance(val, str):
        try:
            val = json.loads(val)
        except json.JSONDecodeError:
            return []
    return val if isinstance(val, list) else []


def _fetch_tools(cursor):
    """Return all tool_table rows whose tool_knowledge is non-null and not ``[]``."""
    cursor.execute("""
        SELECT id, tool_knowledge, case_knowledge, process_knowledge
        FROM tool_table
        WHERE tool_knowledge IS NOT NULL
          AND tool_knowledge != '[]'::jsonb
    """)
    return cursor.fetchall()


def _fetch_knowledge_types(cursor, knowledge_ids):
    """Map each knowledge ID found in the ``knowledge`` table to its types list.

    IDs are queried in batches of ``_BATCH_SIZE`` to keep the bound-parameter
    list short. IDs without a matching row are absent from the result.
    """
    id_list = list(knowledge_ids)
    knowledge_types = {}
    for start in range(0, len(id_list), _BATCH_SIZE):
        batch = id_list[start:start + _BATCH_SIZE]
        placeholders = ','.join(['%s'] * len(batch))
        cursor.execute(f"""
            SELECT id, types FROM knowledge
            WHERE id IN ({placeholders})
        """, batch)
        for row in cursor.fetchall():
            knowledge_types[row['id']] = parse_jsonb(row['types'])
    return knowledge_types


def _print_type_distribution(knowledge_types):
    """Print how many knowledge entries fall into each target bucket.

    The counts follow the same precedence as the reclassification: an entry
    tagged with both "plan" and "usecase" counts only as plan. The three
    buckets therefore add up to the total, and "other" cannot go negative.
    """
    plan_count = sum(1 for t in knowledge_types.values() if 'plan' in t)
    usecase_count = sum(
        1 for t in knowledge_types.values()
        if 'usecase' in t and 'plan' not in t
    )
    other_count = len(knowledge_types) - plan_count - usecase_count
    print(f"类型分布:plan={plan_count}, usecase={usecase_count}, "
          f"其他={other_count}\n")


def _reclassify(tool_knowledge, case_knowledge, process_knowledge, knowledge_types):
    """Split one tool's knowledge IDs into the three target lists.

    Existing case/process entries are preserved. IDs moved out of
    ``tool_knowledge`` are appended only if not already present. IDs with
    no known types stay in ``tool_knowledge``.

    Returns ``(new_tool, new_case, new_process, moved_to_case, moved_to_process)``,
    where the ``moved_*`` lists hold the IDs newly appended to each target.
    """
    new_tool = []
    new_case = list(case_knowledge)
    new_process = list(process_knowledge)
    moved_to_case = []
    moved_to_process = []

    for kid in tool_knowledge:
        types = knowledge_types.get(kid, [])
        if 'plan' in types:
            if kid not in new_process:
                new_process.append(kid)
                moved_to_process.append(kid)
        elif 'usecase' in types:
            if kid not in new_case:
                new_case.append(kid)
                moved_to_case.append(kid)
        else:
            new_tool.append(kid)

    return new_tool, new_case, new_process, moved_to_case, moved_to_process


def _process_tool(cursor, tool, knowledge_types, dry_run):
    """Reclassify one tool row, report the change and, unless dry_run, UPDATE it.

    Returns True when the row's knowledge lists need to change.
    """
    tool_id = tool['id']
    old_tool_knowledge = parse_jsonb(tool['tool_knowledge'])
    old_case_knowledge = parse_jsonb(tool['case_knowledge'])
    old_process_knowledge = parse_jsonb(tool['process_knowledge'])

    (new_tool_knowledge, new_case_knowledge, new_process_knowledge,
     moved_to_case, moved_to_process) = _reclassify(
        old_tool_knowledge, old_case_knowledge, old_process_knowledge,
        knowledge_types,
    )

    # Compare as sets: only membership changes matter, not order or duplicates.
    changed = (
        set(new_tool_knowledge) != set(old_tool_knowledge)
        or set(new_case_knowledge) != set(old_case_knowledge)
        or set(new_process_knowledge) != set(old_process_knowledge)
    )
    if not changed:
        return False

    print(f"工具: {tool_id}")
    if moved_to_process:
        print(f" → process_knowledge (plan): +{len(moved_to_process)} 条")
        for kid in moved_to_process:
            print(f" {kid}")
    if moved_to_case:
        print(f" → case_knowledge (usecase): +{len(moved_to_case)} 条")
        for kid in moved_to_case:
            print(f" {kid}")
    print(f" tool_knowledge: {len(old_tool_knowledge)} → {len(new_tool_knowledge)}")
    print()

    if not dry_run:
        cursor.execute("""
            UPDATE tool_table
            SET tool_knowledge = %s,
                case_knowledge = %s,
                process_knowledge = %s
            WHERE id = %s
        """, (
            json.dumps(new_tool_knowledge),
            json.dumps(new_case_knowledge),
            json.dumps(new_process_knowledge),
            tool_id,
        ))
    return True


def main():
    """Reclassify tool_knowledge entries; ``--dry-run`` only previews changes."""
    dry_run = '--dry-run' in sys.argv
    if dry_run:
        print("=== DRY RUN 模式(不会实际修改数据)===\n")

    conn = get_connection()
    cursor = conn.cursor(cursor_factory=RealDictCursor)
    try:
        # 1. Tools that currently have something in tool_knowledge.
        tools = _fetch_tools(cursor)
        print(f"找到 {len(tools)} 个有关联知识的工具\n")
        if not tools:
            print("没有需要处理的工具,退出。")
            return

        # 2. Collect every knowledge ID referenced by those tools.
        all_knowledge_ids = set()
        for tool in tools:
            all_knowledge_ids.update(parse_jsonb(tool['tool_knowledge']))
        print(f"涉及 {len(all_knowledge_ids)} 条知识,正在查询类型...\n")

        # 3. Look up the types tags of the referenced knowledge in batches.
        knowledge_types = _fetch_knowledge_types(cursor, all_knowledge_ids)
        print(f"成功查询到 {len(knowledge_types)} 条知识的类型信息\n")
        _print_type_distribution(knowledge_types)

        # 4. Reclassify each tool and write back the rows that changed.
        updated_count = sum(
            1 for tool in tools
            if _process_tool(cursor, tool, knowledge_types, dry_run)
        )

        if not dry_run and updated_count > 0:
            conn.commit()

        print("=" * 60)
        print(f"处理完成:共 {len(tools)} 个工具,{updated_count} 个需要更新")
        if dry_run:
            print("(DRY RUN 模式,未实际修改数据)")
            print("确认无误后,去掉 --dry-run 参数重新运行即可。")
        else:
            print(f"已成功更新 {updated_count} 个工具的知识分类。")
    finally:
        # Always release DB resources. Closing without a commit discards any
        # UPDATEs issued before an exception, so a failed run changes nothing.
        cursor.close()
        conn.close()


if __name__ == '__main__':
    print("知识分类修复工具")
    print("用法:")
    print(" python reclassify_tool_knowledge.py --dry-run # 预览变更")
    print(" python reclassify_tool_knowledge.py # 执行变更")
    print()
    main()