#!/usr/bin/env python3
"""
Generate embedding vectors for rows that are still missing them.

- atomic_capability: name + description  -> embedding
- tool_table:        name + introduction -> embedding
- requirement_table: description         -> embedding
- knowledge:         task                -> task_embedding
                     content             -> content_embedding

Only rows whose target column IS NULL are selected. The connection runs in
autocommit mode, so each UPDATE is persisted immediately. Re-running the
script therefore continues from where a previous, interrupted run stopped.
"""
import asyncio
import json  # noqa: F401  (kept from the original file; not used here)
import os
import sys

import psycopg2
from dotenv import load_dotenv
from psycopg2 import sql
from psycopg2.extras import RealDictCursor

_dir = os.path.dirname(os.path.abspath(__file__))
_root = os.path.normpath(os.path.join(_dir, '..', '..'))
# Make the project root importable so `knowhub` resolves when run as a script.
sys.path.insert(0, _root)
load_dotenv(os.path.join(_root, '.env'))

from knowhub.embeddings import get_embeddings_batch  # noqa: E402  (needs sys.path above)

# Maximum number of texts sent per get_embeddings_batch() call.
BATCH_SIZE = 50

# Tables with a single `embedding` column: (table, source columns, text builder).
_SINGLE_EMBEDDING_TABLES = (
    ('atomic_capability', ('name', 'description'),
     lambda r: f"{r['name'] or ''}. {r['description'] or ''}"),
    ('tool_table', ('name', 'introduction'),
     lambda r: f"{r['name'] or ''}. {r['introduction'] or ''}"),
    ('requirement_table', ('description',),
     lambda r: r['description'] or ''),
)

# knowledge has two vector columns: (embedding column, source text column).
_KNOWLEDGE_COLUMNS = (
    ('task_embedding', 'task'),
    ('content_embedding', 'content'),
)


def get_conn():
    """Open an autocommit psycopg2 connection configured from KNOWHUB_* env vars."""
    conn = psycopg2.connect(
        host=os.getenv('KNOWHUB_DB'),
        port=int(os.getenv('KNOWHUB_PORT', 5432)),
        user=os.getenv('KNOWHUB_USER'),
        password=os.getenv('KNOWHUB_PASSWORD'),
        database=os.getenv('KNOWHUB_DB_NAME'),
        connect_timeout=10,
    )
    # Autocommit: every UPDATE is durable on its own, so progress survives a crash.
    conn.autocommit = True
    return conn


def _count_rows(cur, table, non_null=None):
    """Return COUNT(*) of *table*.

    If *non_null* is given, count only rows where that column IS NOT NULL.
    """
    query = sql.SQL("SELECT COUNT(*) AS cnt FROM {}").format(sql.Identifier(table))
    if non_null is not None:
        query = query + sql.SQL(" WHERE {} IS NOT NULL").format(sql.Identifier(non_null))
    cur.execute(query)
    return cur.fetchone()['cnt']


def _rows_missing(cur, table, column, source_cols):
    """Return ``id`` + *source_cols* for rows of *table* whose *column* IS NULL.

    Rows are ordered by id. Filtering in SQL avoids pulling every existing
    vector over the wire just to test it.
    """
    query = sql.SQL("SELECT {cols} FROM {tbl} WHERE {col} IS NULL ORDER BY id").format(
        cols=sql.SQL(', ').join([sql.Identifier(c) for c in ('id', *source_cols)]),
        tbl=sql.Identifier(table),
        col=sql.Identifier(column),
    )
    cur.execute(query)
    return cur.fetchall()


async def _embed_rows(cur, table, column, rows, texts, batch_size=BATCH_SIZE):
    """Embed *texts* in batches and write vector i into *column* of row ``rows[i]['id']``.

    Args:
        cur: psycopg2 cursor (RealDictCursor) on an autocommit connection.
        table: target table name.
        column: embedding column to write.
        rows: row dicts with an ``id`` key, aligned with *texts*.
        texts: strings to embed, one per row.
        batch_size: maximum texts per get_embeddings_batch() call.

    Raises:
        RuntimeError: if the service returns a different number of vectors
            than texts sent. zip() would otherwise silently skip rows.
    """
    update = sql.SQL("UPDATE {} SET {} = %s WHERE id = %s").format(
        sql.Identifier(table), sql.Identifier(column))
    done = 0
    for start in range(0, len(rows), batch_size):
        batch_rows = rows[start:start + batch_size]
        batch_texts = texts[start:start + batch_size]
        print(f"  {column} batch {start // batch_size + 1}: {len(batch_texts)} items...")
        embeddings = list(await get_embeddings_batch(batch_texts))
        if len(embeddings) != len(batch_rows):
            raise RuntimeError(
                f"{table}.{column}: expected {len(batch_rows)} embeddings, "
                f"got {len(embeddings)}")
        for row, emb in zip(batch_rows, embeddings):
            cur.execute(update, (emb, row['id']))
        done += len(batch_rows)
        print(f"    Progress: {done}/{len(rows)}")


async def _fill_single(cur, table, source_cols, make_text):
    """Fill the ``embedding`` column of *table* from the text built by *make_text*."""
    total = _count_rows(cur, table)
    rows = _rows_missing(cur, table, 'embedding', source_cols)
    print(f"  Total: {total}, need embedding: {len(rows)}")
    if not rows:
        print("  All have embeddings, skip.")
        return
    await _embed_rows(cur, table, 'embedding', rows, [make_text(r) for r in rows])
    print(f"  Done: {len(rows)} updated.")


async def _fill_knowledge(cur):
    """Fill knowledge.task_embedding and knowledge.content_embedding independently."""
    total = _count_rows(cur, 'knowledge')
    pending = [(column, source, _rows_missing(cur, 'knowledge', column, (source,)))
               for column, source in _KNOWLEDGE_COLUMNS]
    need = ", ".join(f"need {column}: {len(rows)}" for column, _, rows in pending)
    print(f"  Total: {total}, {need}")
    for column, source, rows in pending:
        if not rows:
            print(f"  All have {column}, skip.")
            continue
        await _embed_rows(cur, 'knowledge', column, rows, [r[source] or '' for r in rows])
        print(f"  {column} done: {len(rows)} updated.")


def _verify(cur):
    """Print how many rows of each table now have their embedding column(s) set."""
    print("\n=== Verify ===")
    for table, _, _ in _SINGLE_EMBEDDING_TABLES:
        total = _count_rows(cur, table)
        has_emb = _count_rows(cur, table, non_null='embedding')
        print(f"  {table}: {has_emb}/{total} have embedding")
    total = _count_rows(cur, 'knowledge')
    parts = ", ".join(
        f"{column} {_count_rows(cur, 'knowledge', non_null=column)}/{total}"
        for column, _ in _KNOWLEDGE_COLUMNS)
    print(f"  knowledge: {parts}")


async def fill_embeddings():
    """Fill every missing embedding column, then print per-table coverage.

    The connection is always closed, even if an embedding call or a query fails.
    """
    conn = get_conn()
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            print("Connected.")
            total_steps = len(_SINGLE_EMBEDDING_TABLES) + 1
            for step, (table, source_cols, make_text) in enumerate(
                    _SINGLE_EMBEDDING_TABLES, start=1):
                print(f"\n=== [{step}/{total_steps}] {table} ===")
                await _fill_single(cur, table, source_cols, make_text)
            print(f"\n=== [{total_steps}/{total_steps}] knowledge ===")
            await _fill_knowledge(cur)
            _verify(cur)
    finally:
        conn.close()
    print("\nDone.")


if __name__ == '__main__':
    asyncio.run(fill_embeddings())