| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- #!/usr/bin/env python3
- """
- 补全 atomic_capability.tools 与 tool_table.capabilities 的双向映射:
- 1. 从 atomic_capability.implements(已有的工具名→实现描述字典)提取工具名
- 2. 在 tool_table 中模糊匹配,找到对应的 tool_id
- 3. 将 tool_id 列表写入 atomic_capability.tools
- 4. 反向构建映射,将 capability_id 列表写入 tool_table.capabilities
- 用法:
- python fill_cap_tools.py # 正常执行
- python fill_cap_tools.py --dry-run # 仅预览,不写入数据库
- """
- import os
- import sys
- import json
- import re
- import io
- import psycopg2
- from psycopg2.extras import RealDictCursor
- from dotenv import load_dotenv
# Make stdout/stderr encode as UTF-8 so Chinese capability/tool names can be
# printed on Windows consoles whose default code page cannot represent them;
# characters that still cannot be encoded are replaced instead of raising.
def _force_utf8(stream):
    """Return *stream* switched to UTF-8 output with errors='replace'.

    Prefers TextIOWrapper.reconfigure() (Python 3.7+), which changes the
    encoding in place and keeps the stream's buffering mode (stdout is
    line-buffered on an interactive console). The fallback wraps the underlying
    binary buffer in a new TextIOWrapper. A stream offering neither hook
    (e.g. None under pythonw) is returned unchanged.
    """
    if hasattr(stream, 'reconfigure'):
        stream.reconfigure(encoding='utf-8', errors='replace')
        return stream
    if hasattr(stream, 'buffer'):
        return io.TextIOWrapper(stream.buffer, encoding='utf-8', errors='replace')
    return stream


sys.stdout = _force_utf8(sys.stdout)
sys.stderr = _force_utf8(sys.stderr)

_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.normpath(os.path.join(_dir, '..', '..'))
# The KNOWHUB_* connection settings are loaded from the .env file two
# directories above this script.
load_dotenv(os.path.join(_root, '.env'))
- # ─── 数据库连接 ──────────────────────────────────────────────────────────────
def get_conn():
    """Open and return a psycopg2 connection to the KnowHub database.

    Settings are taken from environment variables:
      KNOWHUB_DB        host
      KNOWHUB_PORT      port (defaults to 5432)
      KNOWHUB_USER      user
      KNOWHUB_PASSWORD  password
      KNOWHUB_DB_NAME   database name
    The connection attempt times out after 10 seconds, and the returned
    connection runs in autocommit mode.
    """
    params = {
        'host': os.getenv('KNOWHUB_DB'),
        'port': int(os.getenv('KNOWHUB_PORT', 5432)),
        'user': os.getenv('KNOWHUB_USER'),
        'password': os.getenv('KNOWHUB_PASSWORD'),
        'database': os.getenv('KNOWHUB_DB_NAME'),
        'connect_timeout': 10,
    }
    connection = psycopg2.connect(**params)
    connection.autocommit = True
    return connection
- # ─── 工具名模糊匹配 ──────────────────────────────────────────────────────────
# Known tool-name aliases: name as written in atomic_capability.implements ->
# names/prefixes that may appear in tool_table.name. match_tool() selects a
# group when the implements name equals the key or one of its aliases exactly
# (case-sensitive), then tries every name in that group.
TOOL_NAME_ALIASES = {
    'ComfyUI': ['ComfyUI', 'comfyui'],
    'FLUX.2 [max]': ['FLUX', 'flux', 'FLUX.2', 'Flux.2'],
    'Midjourney v8': ['Midjourney', 'midjourney', 'MJ'],
    'Nano Banana Pro': ['Nano Banana', 'nano banana', 'Gemini 3 Pro Image', 'Nano'],
    'Seedream 5.0 Lite': ['Seedream', 'seedream'],
}

# Runs of characters that are not letters or digits. \w is Unicode-aware for
# str patterns, so CJK letters are kept.
_NON_ALNUM_RE = re.compile(r'[\W_]+')


def normalize(s):
    """Return *s* case-folded with every non-alphanumeric character removed.

    Used for fuzzy name comparison, e.g. 'FLUX.2 [max]' -> 'flux2max'.
    Non-ASCII letters (such as Chinese tool names) are preserved rather than
    stripped, so such names do not collapse to ''. None or '' gives ''.
    """
    if not s:
        return ''
    return _NON_ALNUM_RE.sub('', s.casefold())


def match_tool(impl_tool_name, db_tools):
    """Map a tool name from atomic_capability.implements to a tool_table id.

    Strategies, tried in priority order; within each, db_tools is scanned in
    order and the first hit wins:
      1. exact match on tool['name'];
      2. raw substring match in either direction;
      3. substring match in either direction after normalize();
      4. alias lookup via TOOL_NAME_ALIASES, then a normalized substring
         match for every name in the selected alias group.

    Empty strings never take part in substring matching, because '' is a
    substring of every string and would match the first tool regardless of
    input. For the same reason an empty impl_tool_name matches nothing, and
    tools whose name is None/empty are ignored.

    Args:
        impl_tool_name: tool name as written in an implements key.
        db_tools: rows with at least 'id' and 'name' keys.

    Returns:
        The matched tool's 'id', or None when nothing matches.
    """
    if not impl_tool_name:
        return None
    # Rows without a name cannot be compared (substring tests against None
    # would raise TypeError).
    named = [tool for tool in db_tools if tool['name']]

    # 1. Exact match
    for tool in named:
        if tool['name'] == impl_tool_name:
            return tool['id']

    # 2. Raw substring match in either direction
    for tool in named:
        if impl_tool_name in tool['name'] or tool['name'] in impl_tool_name:
            return tool['id']

    # Normalize every db name once for steps 3-4; drop names that normalize to
    # '' (pure punctuation), since '' would "contain" any input.
    normalized = [(tool, normalize(tool['name'])) for tool in named]
    normalized = [(tool, norm) for tool, norm in normalized if norm]

    def first_overlap(needle):
        """Return the first tool whose normalized name contains *needle* or is contained in it."""
        if not needle:
            return None
        for tool, norm_db in normalized:
            if needle in norm_db or norm_db in needle:
                return tool
        return None

    # 3. Normalized substring match
    hit = first_overlap(normalize(impl_tool_name))
    if hit is not None:
        return hit['id']

    # 4. Alias table: if the name belongs to a known group, try every name in it
    for canonical, aliases in TOOL_NAME_ALIASES.items():
        if impl_tool_name == canonical or impl_tool_name in aliases:
            for alias in [canonical] + aliases:
                hit = first_overlap(normalize(alias))
                if hit is not None:
                    return hit['id']
    return None
- # ─── 主逻辑 ──────────────────────────────────────────────────────────────────
def _load_tools(cur):
    """Step 1: fetch and list every tool_table row (id, name), ordered by id."""
    print("=== [1] Loading tool_table ===")
    cur.execute("SELECT id, name FROM tool_table ORDER BY id")
    db_tools = cur.fetchall()
    print(f" Found {len(db_tools)} tools:")
    for t in db_tools:
        print(f" {t['id']}: {t['name']}")
    return db_tools


def _load_capabilities(cur):
    """Step 2: fetch every atomic_capability row (id, name, implements), ordered by id."""
    print("\n=== [2] Loading atomic_capability.implements ===")
    cur.execute("SELECT id, name, implements FROM atomic_capability ORDER BY id")
    caps = cur.fetchall()
    print(f" Found {len(caps)} capabilities")
    return caps


def _parse_implements(cap):
    """Return cap['implements'] as a dict keyed by tool name.

    The value may arrive as a JSON string or already decoded. Empty values give
    {}. Invalid JSON and non-object values are reported and treated as {}
    instead of aborting the run.
    """
    raw = cap['implements']
    if not raw:
        return {}
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            print(f" [!] {cap['id']}: implements is not valid JSON, skipped")
            return {}
    if isinstance(raw, dict):
        return raw
    if raw:
        print(f" [!] {cap['id']}: implements is not a JSON object, skipped")
    return {}


def _build_mappings(caps, db_tools):
    """Step 3: resolve every implements key to a tool id via match_tool().

    Returns (cap_to_tools, tool_to_caps, unmatched):
      cap_to_tools: cap_id -> tool ids in implements order, without duplicates;
      tool_to_caps: tool_id -> cap ids (the reverse mapping);
      unmatched:    (cap_id, tool name) pairs that could not be resolved.
    """
    print("\n=== [3] Matching capability -> tools ===")
    cap_to_tools = {}
    tool_to_caps = {}
    unmatched = []
    for cap in caps:
        cap_id = cap['id']
        implements = _parse_implements(cap)
        matched_tool_ids = []
        for impl_tool_name in implements:
            tool_id = match_tool(impl_tool_name, db_tools)
            if tool_id is None:
                unmatched.append((cap_id, impl_tool_name))
                continue
            # Several implements keys may resolve to the same tool; keep it once
            # so the forward list agrees with the de-duplicated reverse mapping.
            if tool_id not in matched_tool_ids:
                matched_tool_ids.append(tool_id)
            caps_for_tool = tool_to_caps.setdefault(tool_id, [])
            if cap_id not in caps_for_tool:
                caps_for_tool.append(cap_id)
        cap_to_tools[cap_id] = matched_tool_ids
        print(f" {cap_id} ({cap['name']}): {list(implements.keys())} -> {matched_tool_ids}")

    if unmatched:
        print(f"\n [!] {len(unmatched)} unmatched tool names:")
        for cap_id, name in unmatched:
            print(f" {cap_id}: \"{name}\"")
    return cap_to_tools, tool_to_caps, unmatched


def _write_mappings(conn, cur, cap_to_tools, tool_to_caps, dry_run):
    """Steps 4-5: write both sides of the mapping (in dry-run mode only print them)."""
    if not dry_run:
        # get_conn() returns an autocommit connection. Switch autocommit off so
        # both UPDATE passes form one transaction. If anything raises before
        # commit(), closing the connection discards the partial writes, so
        # the two sides of the mapping cannot end up out of sync.
        conn.autocommit = False

    print(f"\n=== [4] Updating atomic_capability.tools {'(DRY RUN)' if dry_run else ''} ===")
    cap_updated = 0
    for cap_id, tool_ids in cap_to_tools.items():
        print(f" {cap_id}: tools = {tool_ids}")
        if not dry_run:
            cur.execute(
                "UPDATE atomic_capability SET tools = %s WHERE id = %s",
                (json.dumps(tool_ids), cap_id)
            )
            cap_updated += 1
    print(f" -> {cap_updated} capabilities updated")

    # NOTE(review): only tools that matched at least one capability are written;
    # a tool absent from tool_to_caps keeps whatever capabilities it already had.
    print(f"\n=== [5] Updating tool_table.capabilities {'(DRY RUN)' if dry_run else ''} ===")
    tool_updated = 0
    for tool_id, cap_ids in sorted(tool_to_caps.items()):
        cap_ids_sorted = sorted(cap_ids)
        print(f" {tool_id}: capabilities = {cap_ids_sorted}")
        if not dry_run:
            cur.execute(
                "UPDATE tool_table SET capabilities = %s WHERE id = %s",
                (json.dumps(cap_ids_sorted), tool_id)
            )
            tool_updated += 1
    print(f" -> {tool_updated} tools updated")

    if not dry_run:
        conn.commit()
        conn.autocommit = True


def _json_list(value):
    """Decode a JSON column value that may be already decoded, a JSON string, or NULL."""
    if value is None:
        return []
    if isinstance(value, str):
        return json.loads(value or '[]')
    return value


def _verify(cur):
    """Step 6: read back a sample of atomic_capability.tools and all non-empty tool_table.capabilities."""
    print("\n=== [6] Verification ===")
    print("\n -- atomic_capability.tools (sample) --")
    cur.execute("""
        SELECT id, name, tools
        FROM atomic_capability
        ORDER BY id LIMIT 5
    """)
    for r in cur.fetchall():
        print(f" {r['id']}: {r['name']} -> tools={_json_list(r['tools'])}")

    print("\n -- tool_table.capabilities (all with mappings) --")
    cur.execute("""
        SELECT id, name, capabilities
        FROM tool_table
        WHERE capabilities IS NOT NULL AND capabilities != '[]'::jsonb
        ORDER BY id
    """)
    for r in cur.fetchall():
        print(f" {r['id']}: {r['name']} -> caps={_json_list(r['capabilities'])}")


def _print_summary(cap_to_tools, tool_to_caps, db_tools, unmatched, dry_run):
    """Print coverage statistics for the computed mapping."""
    print("\n=== Summary ===")
    print(f" Capabilities with tools: {sum(1 for v in cap_to_tools.values() if v)}/{len(cap_to_tools)}")
    print(f" Tools with capabilities: {len(tool_to_caps)}/{len(db_tools)}")
    print(f" Unmatched tool names: {len(unmatched)}")
    if dry_run:
        print("\n (DRY RUN mode - no changes written to database)")


def main():
    """Fill atomic_capability.tools and tool_table.capabilities from implements.

    Pass --dry-run on the command line to print the computed mapping without
    writing to the database.
    """
    dry_run = '--dry-run' in sys.argv
    conn = get_conn()
    try:
        # The cursor's context manager closes it; the finally below closes the
        # connection even when one of the steps raises.
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            print("Connected.\n")
            db_tools = _load_tools(cur)
            caps = _load_capabilities(cur)
            cap_to_tools, tool_to_caps, unmatched = _build_mappings(caps, db_tools)
            _write_mappings(conn, cur, cap_to_tools, tool_to_caps, dry_run)
            if not dry_run:
                _verify(cur)
            _print_summary(cap_to_tools, tool_to_caps, db_tools, unmatched, dry_run)
    finally:
        conn.close()
    print("\nDone.")


if __name__ == '__main__':
    main()
|