taodev_fill_capability_tool.py 4.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python3
  2. """
  3. 补丁:为 tao_dev 的 capability 从源 JSON 的 implements 字段构造 capability_tool junction。
  4. 策略(与 dev_abstract 现有数据一致):
  5. - 把 implements 的 key 原样写入 tool_id(可能是路径 "tools/workflow/comfyui"、
  6. 下划线名 "ji_meng_add_task"、或人类可读名 "ComfyUI")
  7. - value(描述字符串)写入 capability_tool.description
  8. - 不做任何 canonical 映射,不改 tool 表
  9. 每 folder 的 cap 用 {orig_req_id}::{raw_cap_id or 'NEW-<idx>'} 重建 mapping。
  10. """
  11. import json, sys, time
  12. from pathlib import Path
  13. import psycopg2.extras
  14. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  15. from knowhub.knowhub_db.pg_capability_store import PostgreSQLCapabilityStore
# Root directory scanned by main(): every immediate subdirectory is treated as
# one folder and is read for a `capabilities_extracted.json` file.
OUTPUT = Path('/Users/sunlit/Downloads/output-new')
  17. def main():
  18. s = PostgreSQLCapabilityStore()
  19. cur = s._get_cursor()
  20. try:
  21. cur.execute("SET statement_timeout = '120s'")
  22. cur.execute("""SELECT pid FROM pg_stat_activity WHERE state='idle in transaction'
  23. AND pid!=pg_backend_pid() AND datname=current_database()""")
  24. for r in cur.fetchall():
  25. cur.execute('SELECT pg_terminate_backend(%s)', (r['pid'],))
  26. # v0 req 文本 → orig_req_id 映射
  27. cur.execute("SELECT id, description FROM requirement WHERE version='v0'")
  28. req_map = {r['description']: r['id'] for r in cur.fetchall()}
  29. print(f'v0 req 映射: {len(req_map)}', flush=True)
  30. # 建立 tao_dev cap 集合(快速校验)
  31. cur.execute("SELECT id FROM capability WHERE version='tao_dev'")
  32. valid_caps = {r['id'] for r in cur.fetchall()}
  33. print(f'tao_dev capability 现有: {len(valid_caps)}', flush=True)
  34. folders = sorted([f for f in OUTPUT.iterdir() if f.is_dir()])
  35. stats = {'inserted': 0, 'skipped_no_cap': 0, 'total_implements': 0}
  36. for folder in folders:
  37. t0 = time.time()
  38. cd = json.loads((folder / 'capabilities_extracted.json').read_text(encoding='utf-8'))
  39. req_text = cd.get('requirement')
  40. orig_req = req_map.get(req_text)
  41. if not orig_req:
  42. print(f' [{folder.name}] 无法匹配 req', flush=True); continue
  43. for idx, c in enumerate(cd.get('extracted_capabilities', [])):
  44. if not isinstance(c, dict): continue
  45. raw_id = (c.get('id') or '').strip()
  46. cap_id = f'{orig_req}::{raw_id}' if raw_id else f'{orig_req}::NEW-{idx}'
  47. if cap_id not in valid_caps:
  48. stats['skipped_no_cap'] += 1; continue
  49. implements = c.get('implements') or {}
  50. if not isinstance(implements, dict): continue
  51. for tool_key, desc in implements.items():
  52. stats['total_implements'] += 1
  53. cur.execute("""INSERT INTO capability_tool (capability_id, tool_id, description)
  54. VALUES (%s,%s,%s) ON CONFLICT DO NOTHING""",
  55. (cap_id, tool_key, str(desc) if desc is not None else ''))
  56. stats['inserted'] += cur.rowcount or 0
  57. print(f' [{folder.name}] {orig_req}: {time.time()-t0:.1f}s', flush=True)
  58. print('\n=== cap_tool 补丁统计 ===', flush=True)
  59. for k, v in stats.items():
  60. print(f' {k}: {v}', flush=True)
  61. cur.execute("""SELECT COUNT(*) c FROM capability_tool ct
  62. JOIN capability c ON c.id=ct.capability_id
  63. WHERE c.version='tao_dev'""")
  64. print(f' tao_dev cap_tool 总计: {cur.fetchone()["c"]}', flush=True)
  65. # 独特 tool_key 数量
  66. cur.execute("""SELECT COUNT(DISTINCT ct.tool_id) c FROM capability_tool ct
  67. JOIN capability c ON c.id=ct.capability_id
  68. WHERE c.version='tao_dev'""")
  69. print(f' tao_dev 独特 tool_key: {cur.fetchone()["c"]}', flush=True)
  70. finally:
  71. cur.close(); s.close()
  72. if __name__ == '__main__':
  73. main()