#!/usr/bin/env python3
"""
Back-fill the bidirectional mapping between atomic_capability.tools and
tool_table.capabilities:

  1. Extract tool names from atomic_capability.implements (an existing
     dict of tool name -> implementation description).
  2. Fuzzy-match each name against tool_table to find its tool_id.
  3. Write the list of tool_ids into atomic_capability.tools.
  4. Build the reverse mapping and write the list of capability_ids into
     tool_table.capabilities.

All writes happen in a single transaction: either both columns are
updated, or (on any error) nothing is.

Usage:
    python fill_cap_tools.py            # normal run
    python fill_cap_tools.py --dry-run  # preview only, nothing is written
"""
import os
import sys
import json
import re
import io
import unicodedata

import psycopg2
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv

# Work around Windows console encoding problems: emit UTF-8 and replace
# (instead of raising on) characters the console cannot encode.
for _stream_name in ('stdout', 'stderr'):
    _stream = getattr(sys, _stream_name)
    if hasattr(_stream, 'reconfigure'):
        # Python 3.7+: re-encode the existing stream in place (keeps its
        # buffering mode, unlike wrapping it in a new TextIOWrapper).
        _stream.reconfigure(encoding='utf-8', errors='replace')
    elif hasattr(_stream, 'buffer'):
        setattr(sys, _stream_name,
                io.TextIOWrapper(_stream.buffer, encoding='utf-8', errors='replace'))

_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.normpath(os.path.join(_dir, '..', '..'))
load_dotenv(os.path.join(_root, '.env'))


# ─── Database connection ─────────────────────────────────────────────────────

def get_conn():
    """Open a psycopg2 connection to the KnowHub database.

    Settings come from the KNOWHUB_DB (host), KNOWHUB_PORT (default 5432),
    KNOWHUB_USER, KNOWHUB_PASSWORD and KNOWHUB_DB_NAME environment variables,
    which load_dotenv() above may populate from the project-root .env file.
    The connection is returned in autocommit mode.
    """
    conn = psycopg2.connect(
        host=os.getenv('KNOWHUB_DB'),
        port=int(os.getenv('KNOWHUB_PORT', 5432)),
        user=os.getenv('KNOWHUB_USER'),
        password=os.getenv('KNOWHUB_PASSWORD'),
        database=os.getenv('KNOWHUB_DB_NAME'),
        connect_timeout=10,
    )
    conn.autocommit = True
    return conn


# ─── Fuzzy tool-name matching ────────────────────────────────────────────────

# Known tool-name aliases: name as it appears in `implements` -> spellings
# (or name prefixes) that may appear in tool_table.name.
TOOL_NAME_ALIASES = {
    'ComfyUI': ['ComfyUI', 'comfyui'],
    'FLUX.2 [max]': ['FLUX', 'flux', 'FLUX.2', 'Flux.2'],
    'Midjourney v8': ['Midjourney', 'midjourney', 'MJ'],
    'Nano Banana Pro': ['Nano Banana', 'nano banana', 'Gemini 3 Pro Image', 'Nano'],
    'Seedream 5.0 Lite': ['Seedream', 'seedream'],
}

# Anything that is not a Unicode letter or digit (underscore included).
_NON_WORD_RE = re.compile(r'[\W_]+')


def normalize(s):
    """Return a case-folded, punctuation-free form of *s* for fuzzy comparison.

    NFKC first folds full-width / compatibility characters ('ＦＬＵＸ' ->
    'FLUX'); then every character that is not a Unicode letter or digit is
    removed. For ASCII input this is identical to lower-casing and keeping
    only [a-z0-9]. Non-ASCII letters (e.g. CJK) are kept on purpose, so that
    a purely non-Latin name does not collapse to '' -- an empty string is a
    substring of everything and would "match" any tool. None/'' -> ''.
    """
    if not s:
        return ''
    return _NON_WORD_RE.sub('', unicodedata.normalize('NFKC', s).casefold())


def _contains_either_way(a, b):
    """True when both strings are non-empty and one contains the other."""
    return bool(a) and bool(b) and (a in b or b in a)


def match_tool(impl_tool_name, db_tools):
    """
    Match a tool name from `implements` to a tool_table row.

    Strategies, in priority order (first hit wins; within a strategy the
    first row of *db_tools* wins):
      1. exact match on tool.name
      2. tool.name contains impl_tool_name, or vice versa
      3. the same containment test on normalize()d names
      4. alias table: if the name belongs to an alias group in
         TOOL_NAME_ALIASES, try every spelling in that group (normalized)

    Rows whose name is NULL/blank are ignored, and empty strings never take
    part in a containment test (they would match everything).

    Args:
        impl_tool_name: key from atomic_capability.implements.
        db_tools: rows with 'id' and 'name' keys (tool_table).

    Returns:
        The matched tool id, or None.
    """
    if not impl_tool_name or not impl_tool_name.strip():
        return None
    named = [
        tool for tool in db_tools
        if isinstance(tool['name'], str) and tool['name'].strip()
    ]

    # 1. Exact match
    for tool in named:
        if tool['name'] == impl_tool_name:
            return tool['id']

    # 2. Raw containment, either direction
    for tool in named:
        if _contains_either_way(impl_tool_name, tool['name']):
            return tool['id']

    # 3. Containment on normalized names (computed once for steps 3 and 4)
    normalized = [(tool['id'], normalize(tool['name'])) for tool in named]
    norm_impl = normalize(impl_tool_name)
    for tool_id, norm_db in normalized:
        if _contains_either_way(norm_impl, norm_db):
            return tool_id

    # 4. Alias table: locate the name's alias group, then try each spelling
    for canonical, aliases in TOOL_NAME_ALIASES.items():
        group = [canonical] + aliases
        if not norm_impl or norm_impl not in {normalize(a) for a in group}:
            continue
        for alias in group:
            norm_alias = normalize(alias)
            for tool_id, norm_db in normalized:
                if _contains_either_way(norm_alias, norm_db):
                    return tool_id

    return None


# ─── Main logic ──────────────────────────────────────────────────────────────

def _parse_implements(raw, cap_id):
    """Return an atomic_capability.implements value as a dict.

    The column may arrive already decoded (dict) or as a JSON string.
    Empty values and invalid JSON yield {}. A JSON value that is not an
    object is reported and treated as {}, since only its keys are used.
    """
    value = raw
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            value = None
    if not value:
        return {}
    if not isinstance(value, dict):
        print(f" [!] {cap_id}: implements is not a JSON object, ignored")
        return {}
    return value


def _as_list(value):
    """Decode a JSON-array column that may arrive as a list, a JSON string or NULL."""
    if isinstance(value, list):
        return value
    return json.loads(value or '[]')


def _build_mappings(caps, db_tools):
    """Match every capability's implements keys to tool_table rows (step 3).

    Returns (cap_to_tools, tool_to_caps, unmatched):
      cap_to_tools -- {cap_id: [tool_id, ...]} for every capability, in
                      implements-key order, without duplicates
      tool_to_caps -- reverse mapping {tool_id: [cap_id, ...]}, only for
                      tools that matched at least one capability
      unmatched    -- [(cap_id, impl_tool_name), ...] for unmatched names
    """
    cap_to_tools = {}
    tool_to_caps = {}
    unmatched = []

    for cap in caps:
        cap_id = cap['id']
        implements = _parse_implements(cap['implements'], cap_id)

        matched_tool_ids = []
        for impl_tool_name in implements:
            tool_id = match_tool(impl_tool_name, db_tools)
            # `is None`, not truthiness: an id such as 0 is still a match.
            if tool_id is None:
                unmatched.append((cap_id, impl_tool_name))
                continue
            # Several implementation names can resolve to the same tool;
            # store it once so both sides of the mapping agree.
            if tool_id not in matched_tool_ids:
                matched_tool_ids.append(tool_id)
            cap_ids = tool_to_caps.setdefault(tool_id, [])
            if cap_id not in cap_ids:
                cap_ids.append(cap_id)

        cap_to_tools[cap_id] = matched_tool_ids
        print(f" {cap_id} ({cap['name']}): {list(implements.keys())} -> {matched_tool_ids}")

    return cap_to_tools, tool_to_caps, unmatched


def _write_mappings(cur, cap_to_tools, tool_to_caps, dry_run):
    """Print both sides of the mapping and, unless *dry_run*, UPDATE them (steps 4-5)."""
    print(f"\n=== [4] Updating atomic_capability.tools {'(DRY RUN)' if dry_run else ''} ===")
    cap_updated = 0
    for cap_id, tool_ids in cap_to_tools.items():
        print(f" {cap_id}: tools = {tool_ids}")
        if not dry_run:
            cur.execute(
                "UPDATE atomic_capability SET tools = %s WHERE id = %s",
                (json.dumps(tool_ids), cap_id),
            )
            cap_updated += 1
    print(f" -> {cap_updated} capabilities updated")

    # NOTE(review): only tools that matched at least one capability are
    # rewritten; a tool that no longer matches anything keeps its previous
    # capabilities value. Confirm whether stale values should be reset to [].
    print(f"\n=== [5] Updating tool_table.capabilities {'(DRY RUN)' if dry_run else ''} ===")
    tool_updated = 0
    for tool_id, cap_ids in sorted(tool_to_caps.items()):
        cap_ids_sorted = sorted(cap_ids)
        print(f" {tool_id}: capabilities = {cap_ids_sorted}")
        if not dry_run:
            cur.execute(
                "UPDATE tool_table SET capabilities = %s WHERE id = %s",
                (json.dumps(cap_ids_sorted), tool_id),
            )
            tool_updated += 1
    print(f" -> {tool_updated} tools updated")


def _verify(cur):
    """Read back and print a sample of the written mappings (step 6)."""
    print("\n=== [6] Verification ===")

    print("\n -- atomic_capability.tools (sample) --")
    cur.execute("""
        SELECT id, name, tools FROM atomic_capability ORDER BY id LIMIT 5
    """)
    for r in cur.fetchall():
        tools = _as_list(r['tools'])
        print(f" {r['id']}: {r['name']} -> tools={tools}")

    print("\n -- tool_table.capabilities (all with mappings) --")
    cur.execute("""
        SELECT id, name, capabilities FROM tool_table
        WHERE capabilities IS NOT NULL AND capabilities != '[]'::jsonb
        ORDER BY id
    """)
    for r in cur.fetchall():
        caps_list = _as_list(r['capabilities'])
        print(f" {r['id']}: {r['name']} -> caps={caps_list}")


def main():
    """Build the capability <-> tool mapping and persist both directions."""
    dry_run = '--dry-run' in sys.argv

    conn = get_conn()
    # Run every read and write in ONE transaction so a failure midway cannot
    # leave atomic_capability.tools and tool_table.capabilities out of sync.
    conn.autocommit = False
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            print("Connected.\n")

            # ── Step 1: load all of tool_table ──
            print("=== [1] Loading tool_table ===")
            cur.execute("SELECT id, name FROM tool_table ORDER BY id")
            db_tools = cur.fetchall()
            print(f" Found {len(db_tools)} tools:")
            for t in db_tools:
                print(f" {t['id']}: {t['name']}")

            # ── Step 2: load atomic_capability and its implements ──
            print("\n=== [2] Loading atomic_capability.implements ===")
            cur.execute("SELECT id, name, implements FROM atomic_capability ORDER BY id")
            caps = cur.fetchall()
            print(f" Found {len(caps)} capabilities")

            # ── Step 3: match each capability to tools ──
            print("\n=== [3] Matching capability -> tools ===")
            cap_to_tools, tool_to_caps, unmatched = _build_mappings(caps, db_tools)
            if unmatched:
                print(f"\n [!] {len(unmatched)} unmatched tool names:")
                for cap_id, name in unmatched:
                    print(f" {cap_id}: \"{name}\"")

            # ── Steps 4-5: write both sides ──
            _write_mappings(cur, cap_to_tools, tool_to_caps, dry_run)

            if dry_run:
                conn.rollback()
            else:
                conn.commit()
                # ── Step 6: verify ──
                _verify(cur)
    finally:
        # Closing without a commit discards pending changes, so any
        # exception above leaves the database untouched.
        conn.close()

    # ── Summary ──
    print("\n=== Summary ===")
    print(f" Capabilities with tools: {sum(1 for v in cap_to_tools.values() if v)}/{len(cap_to_tools)}")
    print(f" Tools with capabilities: {len(tool_to_caps)}/{len(db_tools)}")
    print(f" Unmatched tool names: {len(unmatched)}")
    if dry_run:
        print("\n (DRY RUN mode - no changes written to database)")

    print("\nDone.")


if __name__ == '__main__':
    main()