fill_cap_tools.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
#!/usr/bin/env python3
"""
Fill in the bidirectional mapping between atomic_capability.tools and
tool_table.capabilities:

1. Extract tool names from atomic_capability.implements (an existing dict of
   tool name -> implementation description).
2. Fuzzy-match each name against tool_table to find the corresponding tool_id.
3. Write the list of tool_ids to atomic_capability.tools.
4. Build the reverse mapping and write the list of capability_ids to
   tool_table.capabilities.

Usage:
    python fill_cap_tools.py            # normal run
    python fill_cap_tools.py --dry-run  # preview only, nothing is written to the database
"""
  12. import os
  13. import sys
  14. import json
  15. import re
  16. import io
  17. import psycopg2
  18. from psycopg2.extras import RealDictCursor
  19. from dotenv import load_dotenv
# Work around Windows terminal encoding problems: re-wrap stdout/stderr as
# UTF-8 text streams; characters that cannot be encoded are replaced rather
# than raising an error (errors='replace').
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
# Absolute path of the directory that contains this script.
_dir = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the script; the .env file is read from here.
_root = os.path.normpath(os.path.join(_dir, '..', '..'))
# Load the KNOWHUB_* settings used by get_conn() into the environment.
load_dotenv(os.path.join(_root, '.env'))
  26. # ─── 数据库连接 ──────────────────────────────────────────────────────────────
  27. def get_conn():
  28. conn = psycopg2.connect(
  29. host=os.getenv('KNOWHUB_DB'),
  30. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  31. user=os.getenv('KNOWHUB_USER'),
  32. password=os.getenv('KNOWHUB_PASSWORD'),
  33. database=os.getenv('KNOWHUB_DB_NAME'),
  34. connect_timeout=10
  35. )
  36. conn.autocommit = True
  37. return conn
  38. # ─── 工具名模糊匹配 ──────────────────────────────────────────────────────────
# Known tool-name aliases: the name as it appears in ``implements`` -> name
# prefixes/variants under which the same tool may appear in the database.
# match_tool() consults this table only after exact, substring and normalized
# matching have all failed.
TOOL_NAME_ALIASES = {
    'ComfyUI': ['ComfyUI', 'comfyui'],
    'FLUX.2 [max]': ['FLUX', 'flux', 'FLUX.2', 'Flux.2'],
    'Midjourney v8': ['Midjourney', 'midjourney', 'MJ'],
    'Nano Banana Pro': ['Nano Banana', 'nano banana', 'Gemini 3 Pro Image', 'Nano'],
    'Seedream 5.0 Lite': ['Seedream', 'seedream'],
}
  47. def normalize(s):
  48. """将字符串转为小写,去除特殊符号,用于模糊比较"""
  49. return re.sub(r'[^a-z0-9]', '', s.lower())
  50. def match_tool(impl_tool_name, db_tools):
  51. """
  52. 将 implements 中的工具名匹配到 tool_table 中的记录。
  53. 匹配策略(按优先级):
  54. 1. 精确匹配 tool.name
  55. 2. tool.name 包含 impl_tool_name(或反之)
  56. 3. 归一化后的子串匹配
  57. 4. 通过别名表匹配
  58. 返回匹配到的 tool_id,或 None
  59. """
  60. # 1. 精确匹配
  61. for tool in db_tools:
  62. if tool['name'] == impl_tool_name:
  63. return tool['id']
  64. # 2. 包含匹配
  65. for tool in db_tools:
  66. if impl_tool_name in tool['name'] or tool['name'] in impl_tool_name:
  67. return tool['id']
  68. # 3. 归一化子串匹配
  69. norm_impl = normalize(impl_tool_name)
  70. for tool in db_tools:
  71. norm_db = normalize(tool['name'])
  72. if norm_impl in norm_db or norm_db in norm_impl:
  73. return tool['id']
  74. # 4. 别名匹配
  75. for canonical, aliases in TOOL_NAME_ALIASES.items():
  76. if impl_tool_name == canonical or impl_tool_name in aliases:
  77. # 找到别名组,用所有别名去匹配数据库
  78. for alias in [canonical] + aliases:
  79. norm_alias = normalize(alias)
  80. for tool in db_tools:
  81. norm_db = normalize(tool['name'])
  82. if norm_alias in norm_db or norm_db in norm_alias:
  83. return tool['id']
  84. return None
  85. # ─── 主逻辑 ──────────────────────────────────────────────────────────────────
  86. def main():
  87. dry_run = '--dry-run' in sys.argv
  88. conn = get_conn()
  89. cur = conn.cursor(cursor_factory=RealDictCursor)
  90. print("Connected.\n")
  91. # ── Step 1: 加载 tool_table 全量数据 ──
  92. print("=== [1] Loading tool_table ===")
  93. cur.execute("SELECT id, name FROM tool_table ORDER BY id")
  94. db_tools = cur.fetchall()
  95. print(f" Found {len(db_tools)} tools:")
  96. for t in db_tools:
  97. print(f" {t['id']}: {t['name']}")
  98. # ── Step 2: 加载 atomic_capability 及其 implements ──
  99. print("\n=== [2] Loading atomic_capability.implements ===")
  100. cur.execute("SELECT id, name, implements FROM atomic_capability ORDER BY id")
  101. caps = cur.fetchall()
  102. print(f" Found {len(caps)} capabilities")
  103. # ── Step 3: 逐个 capability 匹配工具 ──
  104. print("\n=== [3] Matching capability -> tools ===")
  105. # cap_id -> [tool_ids]
  106. cap_to_tools = {}
  107. # tool_id -> [cap_ids] (反向映射)
  108. tool_to_caps = {}
  109. # 未匹配的工具名
  110. unmatched = []
  111. for cap in caps:
  112. cap_id = cap['id']
  113. implements = cap['implements']
  114. # implements 可能是 str 或 dict
  115. if isinstance(implements, str):
  116. try:
  117. implements = json.loads(implements)
  118. except json.JSONDecodeError:
  119. implements = {}
  120. if not implements:
  121. implements = {}
  122. matched_tool_ids = []
  123. for impl_tool_name in implements.keys():
  124. tool_id = match_tool(impl_tool_name, db_tools)
  125. if tool_id:
  126. matched_tool_ids.append(tool_id)
  127. # 反向映射
  128. if tool_id not in tool_to_caps:
  129. tool_to_caps[tool_id] = []
  130. if cap_id not in tool_to_caps[tool_id]:
  131. tool_to_caps[tool_id].append(cap_id)
  132. else:
  133. unmatched.append((cap_id, impl_tool_name))
  134. cap_to_tools[cap_id] = matched_tool_ids
  135. print(f" {cap_id} ({cap['name']}): {list(implements.keys())} -> {matched_tool_ids}")
  136. if unmatched:
  137. print(f"\n [!] {len(unmatched)} unmatched tool names:")
  138. for cap_id, name in unmatched:
  139. print(f" {cap_id}: \"{name}\"")
  140. # ── Step 4: 写入 atomic_capability.tools ──
  141. print(f"\n=== [4] Updating atomic_capability.tools {'(DRY RUN)' if dry_run else ''} ===")
  142. cap_updated = 0
  143. for cap_id, tool_ids in cap_to_tools.items():
  144. print(f" {cap_id}: tools = {tool_ids}")
  145. if not dry_run:
  146. cur.execute(
  147. "UPDATE atomic_capability SET tools = %s WHERE id = %s",
  148. (json.dumps(tool_ids), cap_id)
  149. )
  150. cap_updated += 1
  151. print(f" -> {cap_updated} capabilities updated")
  152. # ── Step 5: 写入 tool_table.capabilities ──
  153. print(f"\n=== [5] Updating tool_table.capabilities {'(DRY RUN)' if dry_run else ''} ===")
  154. tool_updated = 0
  155. for tool_id, cap_ids in sorted(tool_to_caps.items()):
  156. cap_ids_sorted = sorted(cap_ids)
  157. print(f" {tool_id}: capabilities = {cap_ids_sorted}")
  158. if not dry_run:
  159. cur.execute(
  160. "UPDATE tool_table SET capabilities = %s WHERE id = %s",
  161. (json.dumps(cap_ids_sorted), tool_id)
  162. )
  163. tool_updated += 1
  164. print(f" -> {tool_updated} tools updated")
  165. # ── Step 6: 验证 ──
  166. if not dry_run:
  167. print("\n=== [6] Verification ===")
  168. print("\n -- atomic_capability.tools (sample) --")
  169. cur.execute("""
  170. SELECT id, name, tools
  171. FROM atomic_capability
  172. ORDER BY id LIMIT 5
  173. """)
  174. for r in cur.fetchall():
  175. tools = r['tools'] if isinstance(r['tools'], list) else json.loads(r['tools'] or '[]')
  176. print(f" {r['id']}: {r['name']} -> tools={tools}")
  177. print("\n -- tool_table.capabilities (all with mappings) --")
  178. cur.execute("""
  179. SELECT id, name, capabilities
  180. FROM tool_table
  181. WHERE capabilities IS NOT NULL AND capabilities != '[]'::jsonb
  182. ORDER BY id
  183. """)
  184. for r in cur.fetchall():
  185. caps_list = r['capabilities'] if isinstance(r['capabilities'], list) else json.loads(r['capabilities'] or '[]')
  186. print(f" {r['id']}: {r['name']} -> caps={caps_list}")
  187. # ── 统计 ──
  188. print(f"\n=== Summary ===")
  189. print(f" Capabilities with tools: {sum(1 for v in cap_to_tools.values() if v)}/{len(cap_to_tools)}")
  190. print(f" Tools with capabilities: {len(tool_to_caps)}/{len(db_tools)}")
  191. print(f" Unmatched tool names: {len(unmatched)}")
  192. if dry_run:
  193. print(f"\n (DRY RUN mode - no changes written to database)")
  194. cur.close()
  195. conn.close()
  196. print("\nDone.")
# Run the synchronisation only when this file is executed as a script.
if __name__ == '__main__':
    main()