fill_cap_relations.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. #!/usr/bin/env python3
  2. """
  3. 补全 atomic_capability 的 implements 和 requirements 字段:
  4. 1. 从 atomic_capabilities.md 重新解析 implements(修复工具名解析)
  5. 2. 从 requirement_table.atomics 反向构建 requirements
  6. """
  7. import os, json, re, psycopg2
  8. from psycopg2.extras import RealDictCursor
  9. from dotenv import load_dotenv
  10. _dir = os.path.dirname(os.path.abspath(__file__))
  11. _root = os.path.normpath(os.path.join(_dir, '..', '..'))
  12. load_dotenv(os.path.join(_root, '.env'))
  13. def get_conn():
  14. conn = psycopg2.connect(
  15. host=os.getenv('KNOWHUB_DB'),
  16. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  17. user=os.getenv('KNOWHUB_USER'),
  18. password=os.getenv('KNOWHUB_PASSWORD'),
  19. database=os.getenv('KNOWHUB_DB_NAME'),
  20. connect_timeout=10
  21. )
  22. conn.autocommit = True
  23. return conn
  24. def parse_implements(md_path):
  25. """从 atomic_capabilities.md 解析每个 CAP 的 implements"""
  26. caps = {}
  27. current_id = None
  28. in_implements = False
  29. with open(md_path, 'r', encoding='utf-8') as f:
  30. lines = f.readlines()
  31. for line in lines:
  32. line = line.rstrip()
  33. # 匹配 ### CAP-XXX: 名称
  34. if line.startswith('### CAP-'):
  35. parts = line.split(':', 1)
  36. current_id = parts[0].replace('### ', '').strip()
  37. caps[current_id] = {}
  38. in_implements = False
  39. continue
  40. if not current_id:
  41. continue
  42. # 检测 **实现方式** 区块开始
  43. if '**实现方式**' in line:
  44. # 有些 CAP 的实现方式写在同一行:`- **实现方式**: ComfyUI: xxx`
  45. after_label = line.split('**实现方式**', 1)[1].lstrip(':').lstrip(':').strip()
  46. if after_label:
  47. # 同一行内容,可能包含多个工具用分号分隔,也可能只有一个
  48. # 例如 "ComfyUI: xxx 工作流;ReActor 节点用于换脸场景"
  49. # 先当作一整个工具条目处理
  50. split_char = ':' if ':' in after_label else ':'
  51. parts = after_label.split(split_char, 1)
  52. raw_name = parts[0].strip()
  53. desc = parts[1].strip() if len(parts) > 1 else ''
  54. tool_name = normalize_tool_name(raw_name)
  55. if tool_name and current_id:
  56. caps[current_id][tool_name] = desc
  57. in_implements = False # 单行格式,不进入多行模式
  58. else:
  59. in_implements = True
  60. continue
  61. # 检测其他 ** 区块开始(结束 implements 区块)
  62. if line.startswith('- **') and '实现方式' not in line:
  63. in_implements = False
  64. continue
  65. # 分隔线也结束区块
  66. if line.startswith('---'):
  67. in_implements = False
  68. current_id = None
  69. continue
  70. # 在 implements 区块内解析工具
  71. if in_implements and line.strip().startswith('- '):
  72. text = line.strip().lstrip('- ').strip()
  73. # 提取工具名:取冒号前的部分,但要清理掉多余内容
  74. # 例如 "ComfyUI: xxx" -> "ComfyUI"
  75. # 例如 "FLUX.2 [max]:xxx" -> "FLUX.2 [max]"
  76. # 例如 "Midjourney v8 `--cref`:xxx" -> "Midjourney v8"
  77. # 例如 "Nano Banana Pro (Gemini 3 Pro Image):xxx" -> "Nano Banana Pro"
  78. # 先用中文冒号或英文冒号分割
  79. split_char = ':' if ':' in text else ':'
  80. parts = text.split(split_char, 1)
  81. raw_name = parts[0].strip()
  82. desc = parts[1].strip() if len(parts) > 1 else ''
  83. # 规范化工具名:提取核心名称
  84. tool_name = normalize_tool_name(raw_name)
  85. if tool_name:
  86. # 如果同一个工具名已存在,追加描述
  87. if tool_name in caps[current_id]:
  88. caps[current_id][tool_name] += '; ' + desc
  89. else:
  90. caps[current_id][tool_name] = desc
  91. return caps
  92. def normalize_tool_name(raw):
  93. """规范化工具名"""
  94. # 已知的工具名映射
  95. known_tools = {
  96. 'ComfyUI': 'ComfyUI',
  97. 'Midjourney': 'Midjourney v8',
  98. 'FLUX': 'FLUX.2 [max]',
  99. 'Nano Banana': 'Nano Banana Pro',
  100. 'Seedream': 'Seedream 5.0 Lite',
  101. }
  102. for prefix, canonical in known_tools.items():
  103. if raw.startswith(prefix):
  104. return canonical
  105. # 兜底:返回原始名(去掉 markdown 语法)
  106. cleaned = re.sub(r'`[^`]*`', '', raw).strip()
  107. cleaned = re.sub(r'\([^)]*\)', '', cleaned).strip()
  108. return cleaned if len(cleaned) < 60 else cleaned[:60]
  109. def build_requirements_map(cur):
  110. """从 requirement_table.atomics 构建 CAP -> [REQ_IDs] 的反向映射"""
  111. cur.execute("SELECT id, atomics FROM requirement_table")
  112. rows = cur.fetchall()
  113. cap_to_reqs = {}
  114. for r in rows:
  115. atomics = r['atomics']
  116. if isinstance(atomics, str):
  117. atomics = json.loads(atomics)
  118. if not atomics:
  119. continue
  120. for cap_id in atomics:
  121. if cap_id not in cap_to_reqs:
  122. cap_to_reqs[cap_id] = []
  123. cap_to_reqs[cap_id].append(r['id'])
  124. return cap_to_reqs
  125. def main():
  126. conn = get_conn()
  127. cur = conn.cursor(cursor_factory=RealDictCursor)
  128. print("Connected.\n")
  129. # 1. 解析 implements
  130. md_path = os.path.join(_root, 'examples', 'tool_research', 'atomic_cap', '1', 'atomic_capabilities.md')
  131. print("=== [1] Parsing implements from MD ===")
  132. if not os.path.exists(md_path):
  133. print(f" File not found: {md_path}")
  134. return
  135. cap_implements = parse_implements(md_path)
  136. for cap_id, impl in cap_implements.items():
  137. print(f" {cap_id}: {list(impl.keys())}")
  138. # 2. 构建 requirements 反向映射
  139. print("\n=== [2] Building requirements from requirement_table ===")
  140. cap_to_reqs = build_requirements_map(cur)
  141. for cap_id, reqs in sorted(cap_to_reqs.items()):
  142. print(f" {cap_id}: {len(reqs)} requirements")
  143. # 3. 更新数据库
  144. print("\n=== [3] Updating atomic_capability ===")
  145. cur.execute("SELECT id FROM atomic_capability")
  146. all_caps = [r['id'] for r in cur.fetchall()]
  147. updated = 0
  148. for cap_id in all_caps:
  149. impl = cap_implements.get(cap_id, {})
  150. reqs = cap_to_reqs.get(cap_id, [])
  151. cur.execute("""
  152. UPDATE atomic_capability
  153. SET implements = %s, requirements = %s
  154. WHERE id = %s
  155. """, (json.dumps(impl, ensure_ascii=False), json.dumps(reqs), cap_id))
  156. updated += 1
  157. print(f" {cap_id}: {len(impl)} tools, {len(reqs)} requirements")
  158. # 4. 验证
  159. print(f"\n=== Updated {updated} capabilities ===")
  160. print("\n=== Verify ===")
  161. cur.execute("SELECT id, name, implements, requirements FROM atomic_capability ORDER BY id LIMIT 5")
  162. for r in cur.fetchall():
  163. impl = r['implements'] if isinstance(r['implements'], dict) else json.loads(r['implements'] or '{}')
  164. reqs = r['requirements'] if isinstance(r['requirements'], list) else json.loads(r['requirements'] or '[]')
  165. print(f" {r['id']}: {r['name']}")
  166. print(f" tools: {list(impl.keys())}")
  167. print(f" reqs: {reqs[:5]}{'...' if len(reqs) > 5 else ''}")
  168. cur.close()
  169. conn.close()
  170. print("\nDone.")
  171. if __name__ == '__main__':
  172. main()