| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- #!/usr/bin/env python3
- """
- 补全 atomic_capability 的 implements 和 requirements 字段:
- 1. 从 atomic_capabilities.md 重新解析 implements(修复工具名解析)
- 2. 从 requirement_table.atomics 反向构建 requirements
- """
- import os, json, re, psycopg2
- from psycopg2.extras import RealDictCursor
- from dotenv import load_dotenv
# Locate the repository root (two directory levels above this script) and load
# its .env file, so the KNOWHUB_* settings read in get_conn() are available via
# os.getenv.
_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.normpath(os.path.join(_dir, os.pardir, os.pardir))
load_dotenv(os.path.join(_root, '.env'))
def get_conn():
    """Open a psycopg2 connection to the KnowHub database in autocommit mode.

    Connection settings come from the KNOWHUB_* environment variables; the
    port falls back to 5432 when KNOWHUB_PORT is unset, and the connect
    attempt times out after 10 seconds.
    """
    params = {
        'host': os.getenv('KNOWHUB_DB'),
        'port': int(os.getenv('KNOWHUB_PORT', 5432)),
        'user': os.getenv('KNOWHUB_USER'),
        'password': os.getenv('KNOWHUB_PASSWORD'),
        'database': os.getenv('KNOWHUB_DB_NAME'),
        'connect_timeout': 10,
    }
    connection = psycopg2.connect(**params)
    connection.autocommit = True
    return connection
# "### CAP-XXX<sep>name" heading; the id stops at whitespace or either colon.
_CAP_HEADING_RE = re.compile(r'^### (CAP-[^\s:：]*)')
# "**实现方式**" ("implementation") label, also accepting a colon inside the bold.
_IMPL_LABEL_RE = re.compile(r'\*\*实现方式[:：]?\*\*')
# Tool entries may use an ASCII ':' or a full-width '：' as separator.
_COLON_RE = re.compile(r'[:：]')


def _split_tool_entry(text):
    """Split ``"Tool: description"`` at the first ASCII or full-width colon.

    Returns ``(raw_name, desc)``; ``desc`` is ``''`` when there is no colon.
    """
    parts = _COLON_RE.split(text, maxsplit=1)
    raw_name = parts[0].strip()
    desc = parts[1].strip() if len(parts) > 1 else ''
    return raw_name, desc


def parse_implements(md_path):
    """Parse the per-capability ``implements`` map from atomic_capabilities.md.

    Expected layout (sections end at a ``---`` line)::

        ### CAP-XXX: name
        - **实现方式**: Tool: description        <- single-line form
        - **实现方式**:                            <- multi-line form
          - Tool A: description
          - Tool B: description
        - **<other field>**: ...                  <- ends the tool list

    ASCII ``:`` and full-width ``：`` are both accepted as separators. In the
    single-line form the whole remainder is treated as ONE tool entry.
    Another ``- **`` field, any other markdown heading, or ``---`` ends the
    multi-line tool list.

    Args:
        md_path: path of the markdown file (read as UTF-8).

    Returns:
        ``{cap_id: {normalized_tool_name: description}}``. When a tool name
        repeats within one capability, its descriptions are joined with
        ``"; "`` (empty descriptions are not joined).
    """
    caps = {}
    current_id = None
    in_implements = False

    def record(tool_name, desc):
        # Merge repeated tools consistently for both label forms, and avoid a
        # dangling "; " when either side is empty.
        tools = caps[current_id]
        old = tools.get(tool_name, '')
        tools[tool_name] = f'{old}; {desc}' if old and desc else (old or desc)

    with open(md_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.rstrip()

            heading = _CAP_HEADING_RE.match(line)
            if heading:
                current_id = heading.group(1)
                caps[current_id] = {}
                in_implements = False
                continue
            if not current_id:
                continue

            label = _IMPL_LABEL_RE.search(line)
            if label:
                after_label = line[label.end():].strip().lstrip(':：').strip()
                if after_label:
                    # Single-line form, e.g. "- **实现方式**: ComfyUI: xxx".
                    raw_name, desc = _split_tool_entry(after_label)
                    tool_name = normalize_tool_name(raw_name)
                    if tool_name:
                        record(tool_name, desc)
                    in_implements = False
                else:
                    # Tools follow as bullet lines.
                    in_implements = True
                continue

            # Another bold field or a sub-heading ends the tool list.
            if line.startswith('- **') or line.startswith('#'):
                in_implements = False
                continue

            # A horizontal rule ends the whole capability section.
            if line.startswith('---'):
                in_implements = False
                current_id = None
                continue

            stripped = line.strip()
            if in_implements and stripped.startswith('- '):
                # Drop exactly the bullet marker; lstrip('- ') would also eat
                # leading dashes that belong to the entry itself.
                raw_name, desc = _split_tool_entry(stripped[2:].strip())
                tool_name = normalize_tool_name(raw_name)
                if tool_name:
                    record(tool_name, desc)
    return caps
def normalize_tool_name(raw):
    """Map a raw tool label from the markdown to a canonical tool name.

    ``**`` bold markers are removed first. A label that then starts with a
    known product prefix collapses to that product's canonical name (e.g. any
    ``"Midjourney ..."`` becomes ``"Midjourney v8"``). Any other label has
    inline code spans and parenthesised asides (ASCII or full-width
    parentheses) removed, whitespace collapsed, and is truncated to 60
    characters.

    Returns ``''`` when nothing is left; callers treat that as "no tool".
    """
    # Known tool prefixes -> canonical names.
    known_tools = {
        'ComfyUI': 'ComfyUI',
        'Midjourney': 'Midjourney v8',
        'FLUX': 'FLUX.2 [max]',
        'Nano Banana': 'Nano Banana Pro',
        'Seedream': 'Seedream 5.0 Lite',
    }
    # Strip bold markers so "**ComfyUI**" still hits the prefix table.
    text = raw.replace('**', '').strip()
    for prefix, canonical in known_tools.items():
        if text.startswith(prefix):
            return canonical
    # Fallback: keep the label but drop markdown code spans and
    # "(...)" / "（...）" asides.
    cleaned = re.sub(r'`[^`]*`', '', text)
    cleaned = re.sub(r'[(（][^)）]*[)）]', '', cleaned)
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()
    # Truncate, then drop any whitespace the cut exposed.
    return cleaned[:60].rstrip()
def build_requirements_map(cur):
    """Build the reverse map ``{cap_id: [req_id, ...]}`` from requirement_table.

    Each requirement row's ``atomics`` lists the CAP ids it uses. The value
    may arrive already decoded (e.g. a list) or as a JSON string; NULL, empty
    strings and empty lists are skipped.

    Rows are read in ``id`` order, so the lists (which the caller writes back
    to the database) are deterministic across runs. A CAP id repeated inside
    one row's ``atomics`` contributes that requirement only once.

    Args:
        cur: cursor returning dict-like rows (e.g. a RealDictCursor).

    Returns:
        dict mapping CAP id to the list of requirement ids that reference it.
    """
    cur.execute("SELECT id, atomics FROM requirement_table ORDER BY id")
    cap_to_reqs = {}
    for row in cur.fetchall():
        atomics = row['atomics']
        if isinstance(atomics, str):
            # An empty string is not valid JSON; treat it like NULL.
            atomics = json.loads(atomics) if atomics.strip() else None
        if not atomics:
            continue
        # dict.fromkeys de-duplicates while preserving first-seen order.
        for cap_id in dict.fromkeys(atomics):
            cap_to_reqs.setdefault(cap_id, []).append(row['id'])
    return cap_to_reqs
def main():
    """Backfill ``implements`` and ``requirements`` on every atomic_capability row.

    1. Parse per-CAP tool implementations from atomic_capabilities.md.
    2. Build the CAP -> [REQ ids] reverse map from requirement_table.atomics.
    3. Overwrite both columns for every atomic_capability row inside ONE
       transaction, so a failure part-way leaves the table unchanged.
    4. Print the first five rows as a sanity check.

    Capabilities with no entry in the markdown get ``implements = {}``; such
    ids (and parsed/referenced ids missing from the table) are printed as
    warnings so parser misses are visible before they land in the DB.
    """
    md_path = os.path.join(_root, 'examples', 'tool_research', 'atomic_cap', '1', 'atomic_capabilities.md')

    # 1. Parse implements first: it needs no database, so a missing file exits
    #    before any connection is opened (previously the connection leaked).
    print("=== [1] Parsing implements from MD ===")
    if not os.path.exists(md_path):
        print(f" File not found: {md_path}")
        return
    cap_implements = parse_implements(md_path)
    for cap_id, impl in cap_implements.items():
        print(f" {cap_id}: {list(impl.keys())}")

    conn = get_conn()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            print("\nConnected.")

            # 2. Reverse map CAP -> requirements.
            print("\n=== [2] Building requirements from requirement_table ===")
            cap_to_reqs = build_requirements_map(cur)
            for cap_id, reqs in sorted(cap_to_reqs.items()):
                print(f" {cap_id}: {len(reqs)} requirements")

            # 3. Update the table.
            print("\n=== [3] Updating atomic_capability ===")
            cur.execute("SELECT id FROM atomic_capability ORDER BY id")
            all_caps = [r['id'] for r in cur.fetchall()]

            known = set(all_caps)
            not_in_md = sorted(known.difference(cap_implements), key=str)
            if not_in_md:
                print(f" WARNING: no implements parsed for {len(not_in_md)} capabilities (set to {{}}): {not_in_md}")
            not_in_db = sorted(set(cap_implements).union(cap_to_reqs).difference(known), key=str)
            if not_in_db:
                print(f" WARNING: ids not present in atomic_capability (ignored): {not_in_db}")

            # get_conn() enables autocommit; switch it off so all UPDATEs
            # commit together or not at all.
            conn.autocommit = False
            updated = 0
            try:
                for cap_id in all_caps:
                    impl = cap_implements.get(cap_id, {})
                    reqs = cap_to_reqs.get(cap_id, [])
                    cur.execute("""
                        UPDATE atomic_capability
                        SET implements = %s, requirements = %s
                        WHERE id = %s
                    """, (json.dumps(impl, ensure_ascii=False), json.dumps(reqs), cap_id))
                    updated += 1
                    print(f" {cap_id}: {len(impl)} tools, {len(reqs)} requirements")
                conn.commit()
            except Exception:
                conn.rollback()
                raise

            # 4. Verify a sample of the written rows.
            print(f"\n=== Updated {updated} capabilities ===")
            print("\n=== Verify ===")
            cur.execute("SELECT id, name, implements, requirements FROM atomic_capability ORDER BY id LIMIT 5")
            for r in cur.fetchall():
                # Columns may come back decoded (json/jsonb) or as text.
                impl = r['implements'] if isinstance(r['implements'], dict) else json.loads(r['implements'] or '{}')
                reqs = r['requirements'] if isinstance(r['requirements'], list) else json.loads(r['requirements'] or '[]')
                print(f" {r['id']}: {r['name']}")
                print(f" tools: {list(impl.keys())}")
                print(f" reqs: {reqs[:5]}{'...' if len(reqs) > 5 else ''}")
    finally:
        conn.close()
    print("\nDone.")
if __name__ == '__main__':
    # Run the backfill when executed as a script; main() returns None, so the
    # process exits with status 0 on success.
    raise SystemExit(main())
|