fill_embeddings.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python3
  2. """
  3. 为 atomic_capability、tool_table、requirement_table 生成 embedding 向量
  4. - atomic_capability: name + description
  5. - tool_table: name + introduction
  6. - requirement_table: description
  7. """
  8. import os, sys, json, asyncio
  9. import psycopg2
  10. from psycopg2.extras import RealDictCursor
  11. from dotenv import load_dotenv
  12. _dir = os.path.dirname(os.path.abspath(__file__))
  13. _root = os.path.normpath(os.path.join(_dir, '..', '..'))
  14. sys.path.insert(0, _root)
  15. load_dotenv(os.path.join(_root, '.env'))
  16. from knowhub.embeddings import get_embeddings_batch
  17. def get_conn():
  18. conn = psycopg2.connect(
  19. host=os.getenv('KNOWHUB_DB'),
  20. port=int(os.getenv('KNOWHUB_PORT', 5432)),
  21. user=os.getenv('KNOWHUB_USER'),
  22. password=os.getenv('KNOWHUB_PASSWORD'),
  23. database=os.getenv('KNOWHUB_DB_NAME'),
  24. connect_timeout=10
  25. )
  26. conn.autocommit = True
  27. return conn
  28. async def fill_embeddings():
  29. conn = get_conn()
  30. cur = conn.cursor(cursor_factory=RealDictCursor)
  31. print("Connected.\n")
  32. # ── 1. atomic_capability: name + description ──
  33. print("=== [1/3] atomic_capability ===")
  34. cur.execute("SELECT id, name, description, embedding FROM atomic_capability ORDER BY id")
  35. rows = cur.fetchall()
  36. need_embed = [r for r in rows if not r['embedding']]
  37. print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
  38. if need_embed:
  39. texts = [f"{r['name']}. {r['description'] or ''}" for r in need_embed]
  40. print(f" Generating {len(texts)} embeddings...")
  41. embeddings = await get_embeddings_batch(texts)
  42. for r, emb in zip(need_embed, embeddings):
  43. cur.execute("UPDATE atomic_capability SET embedding = %s WHERE id = %s", (emb, r['id']))
  44. print(f" Done: {len(need_embed)} updated.")
  45. else:
  46. print(" All have embeddings, skip.")
  47. # ── 2. tool_table: name + introduction ──
  48. print("\n=== [2/3] tool_table ===")
  49. cur.execute("SELECT id, name, introduction, embedding FROM tool_table ORDER BY id")
  50. rows = cur.fetchall()
  51. need_embed = [r for r in rows if not r['embedding']]
  52. print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
  53. if need_embed:
  54. texts = [f"{r['name'] or ''}. {r['introduction'] or ''}" for r in need_embed]
  55. # 分批处理(tool_table 有 291 条)
  56. batch_size = 50
  57. done = 0
  58. for i in range(0, len(need_embed), batch_size):
  59. batch_rows = need_embed[i:i+batch_size]
  60. batch_texts = texts[i:i+batch_size]
  61. print(f" Batch {i//batch_size + 1}: generating {len(batch_texts)} embeddings...")
  62. embeddings = await get_embeddings_batch(batch_texts)
  63. for r, emb in zip(batch_rows, embeddings):
  64. cur.execute("UPDATE tool_table SET embedding = %s WHERE id = %s", (emb, r['id']))
  65. done += len(batch_rows)
  66. print(f" Progress: {done}/{len(need_embed)}")
  67. print(f" Done: {len(need_embed)} updated.")
  68. else:
  69. print(" All have embeddings, skip.")
  70. # ── 3. requirement_table: description ──
  71. print("\n=== [3/3] requirement_table ===")
  72. cur.execute("SELECT id, description, embedding FROM requirement_table ORDER BY id")
  73. rows = cur.fetchall()
  74. need_embed = [r for r in rows if not r['embedding']]
  75. print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
  76. if need_embed:
  77. texts = [r['description'] or '' for r in need_embed]
  78. print(f" Generating {len(texts)} embeddings...")
  79. embeddings = await get_embeddings_batch(texts)
  80. for r, emb in zip(need_embed, embeddings):
  81. cur.execute("UPDATE requirement_table SET embedding = %s WHERE id = %s", (emb, r['id']))
  82. print(f" Done: {len(need_embed)} updated.")
  83. else:
  84. print(" All have embeddings, skip.")
  85. # ── 4. knowledge: task_embedding (task) + content_embedding (content) ──
  86. print("\n=== [4/4] knowledge ===")
  87. cur.execute("SELECT id, task, content, task_embedding, content_embedding FROM knowledge ORDER BY id")
  88. rows = cur.fetchall()
  89. need_task = [r for r in rows if not r['task_embedding']]
  90. need_content = [r for r in rows if not r['content_embedding']]
  91. print(f" Total: {len(rows)}, need task_embedding: {len(need_task)}, need content_embedding: {len(need_content)}")
  92. batch_size = 50
  93. if need_task:
  94. texts = [r['task'] or '' for r in need_task]
  95. done = 0
  96. for i in range(0, len(need_task), batch_size):
  97. batch_rows = need_task[i:i+batch_size]
  98. batch_texts = texts[i:i+batch_size]
  99. print(f" task_embedding batch {i//batch_size + 1}: {len(batch_texts)} items...")
  100. embeddings = await get_embeddings_batch(batch_texts)
  101. for r, emb in zip(batch_rows, embeddings):
  102. cur.execute("UPDATE knowledge SET task_embedding = %s WHERE id = %s", (emb, r['id']))
  103. done += len(batch_rows)
  104. print(f" Progress: {done}/{len(need_task)}")
  105. print(f" task_embedding done: {len(need_task)} updated.")
  106. else:
  107. print(" All have task_embedding, skip.")
  108. if need_content:
  109. texts = [r['content'] or '' for r in need_content]
  110. done = 0
  111. for i in range(0, len(need_content), batch_size):
  112. batch_rows = need_content[i:i+batch_size]
  113. batch_texts = texts[i:i+batch_size]
  114. print(f" content_embedding batch {i//batch_size + 1}: {len(batch_texts)} items...")
  115. embeddings = await get_embeddings_batch(batch_texts)
  116. for r, emb in zip(batch_rows, embeddings):
  117. cur.execute("UPDATE knowledge SET content_embedding = %s WHERE id = %s", (emb, r['id']))
  118. done += len(batch_rows)
  119. print(f" Progress: {done}/{len(need_content)}")
  120. print(f" content_embedding done: {len(need_content)} updated.")
  121. else:
  122. print(" All have content_embedding, skip.")
  123. # ── 验证 ──
  124. print("\n=== Verify ===")
  125. for table in ['atomic_capability', 'tool_table', 'requirement_table']:
  126. cur.execute(f"SELECT COUNT(*) as total FROM {table}")
  127. total = cur.fetchone()['total']
  128. cur.execute(f"SELECT COUNT(*) as cnt FROM {table} WHERE embedding IS NOT NULL")
  129. has_emb = cur.fetchone()['cnt']
  130. print(f" {table}: {has_emb}/{total} have embedding")
  131. # knowledge 双向量
  132. cur.execute("SELECT COUNT(*) as total FROM knowledge")
  133. total = cur.fetchone()['total']
  134. cur.execute("SELECT COUNT(*) as cnt FROM knowledge WHERE task_embedding IS NOT NULL")
  135. has_task = cur.fetchone()['cnt']
  136. cur.execute("SELECT COUNT(*) as cnt FROM knowledge WHERE content_embedding IS NOT NULL")
  137. has_content = cur.fetchone()['cnt']
  138. print(f" knowledge: task_embedding {has_task}/{total}, content_embedding {has_content}/{total}")
  139. cur.close()
  140. conn.close()
  141. print("\nDone.")
  142. if __name__ == '__main__':
  143. asyncio.run(fill_embeddings())