#!/usr/bin/env python3
"""
Generate embedding vectors for atomic_capability, tool_table,
requirement_table and knowledge.

- atomic_capability: name + description
- tool_table: name + introduction
- requirement_table: description
- knowledge: task -> task_embedding, content -> content_embedding
"""
import os
import sys
import json
import asyncio

import psycopg2
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv

# Put the directory two levels above this script on sys.path and load the
# .env file found there. This runs before the knowhub import below
# (presumably so that the package resolves from that directory and sees the
# loaded environment -- confirm against the project layout).
_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.normpath(os.path.join(_dir, '..', '..'))
sys.path.insert(0, _root)
load_dotenv(os.path.join(_root, '.env'))

from knowhub.embeddings import get_embeddings_batch  # noqa: E402
def get_conn():
    """Open a psycopg2 connection configured from KNOWHUB_* environment variables.

    Reads KNOWHUB_DB (host), KNOWHUB_PORT (defaults to 5432), KNOWHUB_USER,
    KNOWHUB_PASSWORD and KNOWHUB_DB_NAME (database), connects with a
    10-second connect timeout, and enables autocommit before returning.

    Returns:
        The open psycopg2 connection, in autocommit mode.
    """
    settings = {
        'host': os.getenv('KNOWHUB_DB'),
        'port': int(os.getenv('KNOWHUB_PORT', 5432)),
        'user': os.getenv('KNOWHUB_USER'),
        'password': os.getenv('KNOWHUB_PASSWORD'),
        'database': os.getenv('KNOWHUB_DB_NAME'),
        'connect_timeout': 10,
    }
    connection = psycopg2.connect(**settings)
    # Autocommit: every statement issued on this connection is committed
    # as soon as it executes.
    connection.autocommit = True
    return connection
def _rows_missing(rows, column):
    """Return the rows whose ``column`` value is falsy (NULL or empty)."""
    return [r for r in rows if not r[column]]


async def _store_embeddings(cur, table, column, rows, texts):
    """Embed ``texts`` and write each vector into ``table.column`` by row id.

    ``rows`` and ``texts`` are parallel sequences; every row must expose an
    ``id`` key.

    Raises:
        RuntimeError: if the number of vectors returned differs from the
            number of rows. ``zip`` would otherwise silently leave the
            trailing rows without an embedding while the caller reports
            them as updated.
    """
    embeddings = list(await get_embeddings_batch(texts))
    if len(embeddings) != len(rows):
        raise RuntimeError(
            f"{table}.{column}: expected {len(rows)} embeddings, "
            f"got {len(embeddings)}"
        )
    # table/column are identifiers hard-coded in this module, never external
    # input, so interpolating them is safe; values are still parameterized.
    update_sql = f"UPDATE {table} SET {column} = %s WHERE id = %s"
    for row, emb in zip(rows, embeddings):
        cur.execute(update_sql, (emb, row['id']))


async def _store_embeddings_batched(cur, table, column, rows, texts,
                                    batch_size, batch_msg):
    """Apply ``_store_embeddings`` to ``rows`` in slices of ``batch_size``.

    ``batch_msg`` is a ``str.format`` template printed before each batch; it
    receives ``num`` (1-based batch number) and ``count`` (items in the
    batch). A progress line is printed after each batch.
    """
    done = 0
    for start in range(0, len(rows), batch_size):
        batch_rows = rows[start:start + batch_size]
        batch_texts = texts[start:start + batch_size]
        print(batch_msg.format(num=start // batch_size + 1,
                               count=len(batch_texts)))
        await _store_embeddings(cur, table, column, batch_rows, batch_texts)
        done += len(batch_rows)
        print(f" Progress: {done}/{len(rows)}")


async def fill_embeddings(batch_size=50):
    """Backfill missing embedding columns, then print per-table coverage.

    Steps, in order:
      1. atomic_capability.embedding  from "name. description" (one request)
      2. tool_table.embedding         from "name. introduction" (batched)
      3. requirement_table.embedding  from description (one request)
      4. knowledge.task_embedding / content_embedding from task / content
         (batched)

    Only rows whose target column is falsy are embedded. NULL source text is
    embedded as an empty string. The connection comes from ``get_conn`` and
    the cursor and connection are closed even when a step fails.

    Args:
        batch_size: Number of texts sent per embedding request in the
            batched steps. Must be >= 1.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if the embedding backend returns a different number of
            vectors than texts submitted.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    conn = get_conn()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            print("Connected.\n")

            # -- 1. atomic_capability: name + description --
            print("=== [1/4] atomic_capability ===")
            cur.execute(
                "SELECT id, name, description, embedding "
                "FROM atomic_capability ORDER BY id"
            )
            rows = cur.fetchall()
            need_embed = _rows_missing(rows, 'embedding')
            print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
            if need_embed:
                # `or ''` on name matches tool_table and avoids embedding the
                # literal string "None" for a NULL name.
                texts = [
                    f"{r['name'] or ''}. {r['description'] or ''}"
                    for r in need_embed
                ]
                print(f" Generating {len(texts)} embeddings...")
                await _store_embeddings(
                    cur, 'atomic_capability', 'embedding', need_embed, texts
                )
                print(f" Done: {len(need_embed)} updated.")
            else:
                print(" All have embeddings, skip.")

            # -- 2. tool_table: name + introduction --
            print("\n=== [2/4] tool_table ===")
            cur.execute(
                "SELECT id, name, introduction, embedding "
                "FROM tool_table ORDER BY id"
            )
            rows = cur.fetchall()
            need_embed = _rows_missing(rows, 'embedding')
            print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
            if need_embed:
                texts = [
                    f"{r['name'] or ''}. {r['introduction'] or ''}"
                    for r in need_embed
                ]
                # Processed in batches (the original note says tool_table
                # held 291 rows).
                await _store_embeddings_batched(
                    cur, 'tool_table', 'embedding', need_embed, texts,
                    batch_size,
                    " Batch {num}: generating {count} embeddings...",
                )
                print(f" Done: {len(need_embed)} updated.")
            else:
                print(" All have embeddings, skip.")

            # -- 3. requirement_table: description --
            print("\n=== [3/4] requirement_table ===")
            cur.execute(
                "SELECT id, description, embedding "
                "FROM requirement_table ORDER BY id"
            )
            rows = cur.fetchall()
            need_embed = _rows_missing(rows, 'embedding')
            print(f" Total: {len(rows)}, need embedding: {len(need_embed)}")
            if need_embed:
                texts = [r['description'] or '' for r in need_embed]
                print(f" Generating {len(texts)} embeddings...")
                await _store_embeddings(
                    cur, 'requirement_table', 'embedding', need_embed, texts
                )
                print(f" Done: {len(need_embed)} updated.")
            else:
                print(" All have embeddings, skip.")

            # -- 4. knowledge: task -> task_embedding, content -> content_embedding --
            print("\n=== [4/4] knowledge ===")
            cur.execute(
                "SELECT id, task, content, task_embedding, content_embedding "
                "FROM knowledge ORDER BY id"
            )
            rows = cur.fetchall()
            need_task = _rows_missing(rows, 'task_embedding')
            need_content = _rows_missing(rows, 'content_embedding')
            print(
                f" Total: {len(rows)}, need task_embedding: {len(need_task)}, "
                f"need content_embedding: {len(need_content)}"
            )
            knowledge_jobs = (
                ('task_embedding', 'task', need_task),
                ('content_embedding', 'content', need_content),
            )
            for column, source, pending in knowledge_jobs:
                if pending:
                    texts = [r[source] or '' for r in pending]
                    await _store_embeddings_batched(
                        cur, 'knowledge', column, pending, texts, batch_size,
                        " " + column + " batch {num}: {count} items...",
                    )
                    print(f" {column} done: {len(pending)} updated.")
                else:
                    print(f" All have {column}, skip.")

            # -- Verify --
            print("\n=== Verify ===")
            for table in ('atomic_capability', 'tool_table', 'requirement_table'):
                # COUNT(col) counts only non-NULL values, so one query yields
                # both the total and the populated count.
                cur.execute(
                    f"SELECT COUNT(*) AS total, COUNT(embedding) AS cnt "
                    f"FROM {table}"
                )
                stats = cur.fetchone()
                print(f" {table}: {stats['cnt']}/{stats['total']} have embedding")

            # knowledge carries two vectors per row
            cur.execute(
                "SELECT COUNT(*) AS total, "
                "COUNT(task_embedding) AS has_task, "
                "COUNT(content_embedding) AS has_content "
                "FROM knowledge"
            )
            stats = cur.fetchone()
            print(
                f" knowledge: task_embedding {stats['has_task']}/{stats['total']}, "
                f"content_embedding {stats['has_content']}/{stats['total']}"
            )
    finally:
        conn.close()
    print("\nDone.")
if __name__ == '__main__':
    # Script entry point: run the async backfill to completion.
    asyncio.run(fill_embeddings())
|