浏览代码

migrate: move knowhub data to milvus

guantao 2 小时之前
父节点
当前提交
40a5a77318
共有 4 个文件被更改,包括 242 次插入、33 次删除
  1. 4 1
      .gitignore
  2. 64 29
      knowhub/server.py
  3. 6 3
      knowhub/vector_store.py
  4. 168 0
      migrate_knowledge.py

+ 4 - 1
.gitignore

@@ -69,4 +69,7 @@ frontend/react-template/yarn.lock
 # data
 knowhub/knowhub.db
 knowhub/knowhub.db-shm
-knowhub/knowhub.db-wal
+knowhub/knowhub.db-wal
+
+# Milvus data
+knowhub/milvus_data/

+ 64 - 29
knowhub/server.py

@@ -136,6 +136,31 @@ def decrypt_content(resource_id: str, encrypted_text: str, provided_key: Optiona
         return "[ENCRYPTED]"
 
 
def serialize_milvus_result(data):
    """Recursively convert a Milvus query/search result into JSON-serializable
    Python primitives (dict / list / str / int / float / bool / None).

    pymilvus results may contain protobuf containers (RepeatedScalarContainer)
    and result objects exposing ``to_dict()``; these are flattened into plain
    lists/dicts so the payload can be returned by FastAPI / json.dumps.
    """
    if data is None:
        return None
    if isinstance(data, (str, int, float, bool)):
        return data
    if isinstance(data, dict):
        return {k: serialize_milvus_result(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return [serialize_milvus_result(item) for item in data]
    # Prefer an explicit dict representation over blind iteration: objects
    # like pymilvus Hit expose to_dict() AND are iterable; iterating them
    # would discard field names.
    if hasattr(data, 'to_dict'):
        return serialize_milvus_result(data.to_dict())
    if hasattr(data, '__iter__') and not isinstance(data, (str, bytes)):
        # Handles RepeatedScalarContainer and other plain iterables.
        return [serialize_milvus_result(item) for item in data]
    if hasattr(data, '__dict__'):
        return serialize_milvus_result(vars(data))
    # Last resort: stringify; return None only if even str() fails.
    try:
        return str(data)
    except Exception:
        return None
+
+
 def init_db():
     """初始化 SQLite(仅用于 resources)"""
     conn = get_db()
@@ -566,7 +591,7 @@ async def search_knowledge_api(
         if types:
             type_list = [t.strip() for t in types.split(',') if t.strip()]
             for t in type_list:
-                filters.append(f'JSON_CONTAINS(types, "{t}")')
+                filters.append(f'ARRAY_CONTAINS(types, "{t}")')
         if owner:
             filters.append(f'owner == "{owner}"')
 
@@ -586,18 +611,21 @@ async def search_knowledge_api(
         if not candidates:
             return {"results": [], "count": 0, "reranked": False}
 
+        # 转换为可序列化的格式
+        serialized_candidates = [serialize_milvus_result(c) for c in candidates]
+
         # 4. LLM 精排
-        reranked_ids = await _llm_rerank(q, candidates, top_k)
+        reranked_ids = await _llm_rerank(q, serialized_candidates, top_k)
 
         if reranked_ids:
             # 按 LLM 排序返回
-            id_to_candidate = {c["id"]: c for c in candidates}
+            id_to_candidate = {c["id"]: c for c in serialized_candidates}
             results = [id_to_candidate[id] for id in reranked_ids if id in id_to_candidate]
             return {"results": results, "count": len(results), "reranked": True}
         else:
             # Fallback:直接返回向量召回的 top k
             print(f"[Knowledge Search] LLM 精排失败,fallback 到向量 top-{top_k}")
-            return {"results": candidates[:top_k], "count": len(candidates[:top_k]), "reranked": False}
+            return {"results": serialized_candidates[:top_k], "count": len(serialized_candidates[:top_k]), "reranked": False}
 
     except Exception as e:
         print(f"[Knowledge Search] 错误: {e}")
@@ -639,12 +667,11 @@ async def save_knowledge(knowledge: KnowledgeIn):
             "harmful_history": []
         }
 
-        # 生成向量
-        text = f"{knowledge.task}\n{knowledge.content}"
-        embedding = await get_embedding(text)
+        # 生成向量(只基于 task,因为搜索时用户描述的是任务场景)
+        embedding = await get_embedding(knowledge.task)
 
-        # 插入 Milvus
-        milvus_store.insert({
+        # 准备插入数据
+        insert_data = {
             "id": knowledge_id,
             "embedding": embedding,
             "message_id": knowledge.message_id,
@@ -659,7 +686,12 @@ async def save_knowledge(knowledge: KnowledgeIn):
             "eval": eval_data,
             "created_at": now,
             "updated_at": now,
-        })
+        }
+
+        print(f"[Save Knowledge] 插入数据: {json.dumps({k: v for k, v in insert_data.items() if k != 'embedding'}, ensure_ascii=False)}")
+
+        # 插入 Milvus
+        milvus_store.insert(insert_data)
 
         return {"status": "ok", "knowledge_id": knowledge_id}
 
@@ -685,10 +717,10 @@ def list_knowledge(
         if types:
             type_list = [t.strip() for t in types.split(',') if t.strip()]
             for t in type_list:
-                filters.append(f'JSON_CONTAINS(types, "{t}")')
+                filters.append(f'ARRAY_CONTAINS(types, "{t}")')
 
         if scopes:
-            filters.append(f'JSON_CONTAINS(scopes, "{scopes}")')
+            filters.append(f'ARRAY_CONTAINS(scopes, "{scopes}")')
 
         if owner:
             filters.append(f'owner like "%{owner}%"')
@@ -705,7 +737,10 @@ def list_knowledge(
         # 查询 Milvus
         results = milvus_store.query(filter_expr, limit=limit)
 
-        return {"results": results, "count": len(results)}
+        # 转换为可序列化的格式
+        serialized_results = [serialize_milvus_result(r) for r in results]
+
+        return {"results": serialized_results, "count": len(serialized_results)}
 
     except Exception as e:
         print(f"[List Knowledge] 错误: {e}")
@@ -721,7 +756,9 @@ def get_all_tags():
 
         all_tags = set()
         for item in results:
-            tags_dict = item.get("tags", {})
+            # 转换为标准字典
+            serialized_item = serialize_milvus_result(item)
+            tags_dict = serialized_item.get("tags", {})
             if isinstance(tags_dict, dict):
                 for key in tags_dict.keys():
                     all_tags.add(key)
@@ -742,7 +779,7 @@ def get_knowledge(knowledge_id: str):
         if not result:
             raise HTTPException(status_code=404, detail=f"Knowledge not found: {knowledge_id}")
 
-        return result
+        return serialize_milvus_result(result)
 
     except HTTPException:
         raise
@@ -827,8 +864,7 @@ async def update_knowledge(knowledge_id: str, update: KnowledgeUpdateIn):
 
         # 如果内容变化,重新生成向量
         if need_reembed:
-            text = f"{existing['task']}\n{content}"
-            embedding = await get_embedding(text)
+            embedding = await get_embedding(existing['task'])
             updates["embedding"] = embedding
 
         # 更新 Milvus
@@ -861,7 +897,7 @@ async def patch_knowledge(knowledge_id: str, patch: KnowledgePatchIn):
 
         if patch.content is not None:
             updates["content"] = patch.content
-            need_reembed = True
+            # content 变化不需要重新生成 embedding(只基于 task)
 
         if patch.types is not None:
             updates["types"] = patch.types
@@ -878,12 +914,10 @@ async def patch_knowledge(knowledge_id: str, patch: KnowledgePatchIn):
         if not updates:
             return {"status": "ok", "knowledge_id": knowledge_id}
 
-        # 如果 task 或 content 变化,重新生成向量
+        # 如果 task 变化,重新生成向量
         if need_reembed:
             task = updates.get("task", existing["task"])
-            content = updates.get("content", existing["content"])
-            text = f"{task}\n{content}"
-            embedding = await get_embedding(text)
+            embedding = await get_embedding(task)
             updates["embedding"] = embedding
 
         # 更新 Milvus
@@ -947,9 +981,8 @@ async def batch_update_knowledge(batch: KnowledgeBatchUpdateIn):
             for (knowledge_id, _, _, eval_data, task), evolved_content in zip(evolution_tasks, evolved_results):
                 eval_data["helpful"] = eval_data.get("helpful", 0) + 1
 
-                # 重新生成向量
-                text = f"{task}\n{evolved_content}"
-                embedding = await get_embedding(text)
+                # 重新生成向量(只基于 task)
+                embedding = await get_embedding(task)
 
                 milvus_store.update(knowledge_id, {
                     "content": evolved_content,
@@ -970,6 +1003,8 @@ async def slim_knowledge(model: str = "google/gemini-2.5-flash-lite"):
     try:
         # 获取所有知识
         all_knowledge = milvus_store.query('id != ""', limit=10000)
+        # 转换为可序列化的格式
+        all_knowledge = [serialize_milvus_result(item) for item in all_knowledge]
 
         if len(all_knowledge) < 2:
             return {"status": "ok", "message": f"知识库仅有 {len(all_knowledge)} 条,无需瘦身"}
@@ -1087,8 +1122,8 @@ REPORT: 原有 X 条,合并后 Y 条,精简了 Z 条。
         # 生成向量并重建知识库
         print(f"[知识瘦身] 正在为 {len(new_entries)} 条知识生成向量...")
 
-        # 批量生成向量
-        texts = [f"{e['task']}\n{e['content']}" for e in new_entries]
+        # 批量生成向量(只基于 task)
+        texts = [e['task'] for e in new_entries]
         embeddings = await get_embeddings_batch(texts)
 
         # 清空并重建
@@ -1231,8 +1266,8 @@ async def extract_knowledge_from_messages(extract_req: MessageExtractIn):
         if not extracted_knowledge:
             return {"status": "ok", "extracted_count": 0, "knowledge_ids": []}
 
-        # 批量生成向量
-        texts = [f"{item.get('task', '')}\n{item.get('content', '')}" for item in extracted_knowledge]
+        # 批量生成向量(只基于 task)
+        texts = [item.get('task', '') for item in extracted_knowledge]
         embeddings = await get_embeddings_batch(texts)
 
         # 保存提取的知识

+ 6 - 3
knowhub/vector_store.py

@@ -76,12 +76,15 @@ class MilvusStore:
                            max_length=2000),
                 FieldSchema(name="content", dtype=DataType.VARCHAR,
                            max_length=50000),
-                FieldSchema(name="types", dtype=DataType.JSON),
+                FieldSchema(name="types", dtype=DataType.ARRAY,
+                           element_type=DataType.VARCHAR, max_capacity=20, max_length=50),
                 FieldSchema(name="tags", dtype=DataType.JSON),
-                FieldSchema(name="scopes", dtype=DataType.JSON),
+                FieldSchema(name="scopes", dtype=DataType.ARRAY,
+                           element_type=DataType.VARCHAR, max_capacity=20, max_length=100),
                 FieldSchema(name="owner", dtype=DataType.VARCHAR,
                            max_length=200),
-                FieldSchema(name="resource_ids", dtype=DataType.JSON),
+                FieldSchema(name="resource_ids", dtype=DataType.ARRAY,
+                           element_type=DataType.VARCHAR, max_capacity=50, max_length=200),
                 FieldSchema(name="source", dtype=DataType.JSON),
                 FieldSchema(name="eval", dtype=DataType.JSON),
                 FieldSchema(name="created_at", dtype=DataType.INT64),

+ 168 - 0
migrate_knowledge.py

@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+"""
+知识库迁移脚本: SQLite -> Milvus
+从旧的 SQLite 数据库迁移知识数据到新的 Milvus 向量数据库
+"""
+
+import sys
+import json
+import sqlite3
+import asyncio
+from pathlib import Path
+from datetime import datetime
+
+# 添加项目路径
+sys.path.insert(0, str(Path(__file__).parent))
+
+from knowhub.vector_store import MilvusStore
+from knowhub.embeddings import get_embeddings_batch
+
+
async def migrate_knowledge():
    """Migrate knowledge rows from the legacy SQLite DB into the Milvus store.

    Steps: read all rows from the source SQLite ``knowledge`` table, normalize
    JSON/timestamp columns, batch-generate embeddings from the ``task`` field
    only (search queries describe task scenarios), then batch-insert into
    Milvus, falling back to per-row inserts when a batch fails.
    """

    # 源数据库路径
    source_db = Path.home() / "main_agent/knowhub/knowhub.db"
    if not source_db.exists():
        print(f"❌ 源数据库不存在: {source_db}")
        return

    # 目标 Milvus 存储
    milvus_data_dir = Path(__file__).parent / "knowhub/milvus_data"
    target_store = MilvusStore(str(milvus_data_dir))

    # Record the pre-migration count ONCE so the final "新增" delta is
    # computed against it (calling count() again after insertion would
    # always yield a delta of zero).
    initial_count = target_store.count()

    print(f"📂 源数据库: {source_db}")
    print(f"📂 目标 Milvus: {milvus_data_dir}")
    print(f"📊 当前 Milvus 中的知识数量: {initial_count}")

    # 读取源数据
    print("\n📖 正在读取源数据...")
    conn = sqlite3.connect(str(source_db))
    conn.row_factory = sqlite3.Row
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM knowledge ORDER BY created_at")
        rows = cursor.fetchall()
    finally:
        # Close even if the query raises.
        conn.close()

    print(f"✅ 读取到 {len(rows)} 条知识数据")

    if len(rows) == 0:
        print("⚠️  没有数据需要迁移")
        return

    # 显示迁移信息
    print(f"\n⚠️  即将迁移 {len(rows)} 条知识到 Milvus")
    print(f"   当前 Milvus 中已有 {initial_count} 条知识")
    print("   开始迁移...")

    # 转换数据格式
    print("\n🔄 正在转换数据格式...")
    knowledge_list = []
    tasks = []  # task texts used for batch embedding generation

    for row in rows:
        try:
            # Parse JSON columns, defaulting NULL/empty values.
            types = json.loads(row['types']) if row['types'] else ["strategy"]
            tags = json.loads(row['tags']) if row['tags'] else {}
            scopes = json.loads(row['scopes']) if row['scopes'] else ["org:cybertogether"]
            source = json.loads(row['source']) if row['source'] else {}
            eval_data = json.loads(row['eval']) if row['eval'] else {
                "score": 3, "helpful": 1, "harmful": 0, "confidence": 0.5,
                "helpful_history": [], "harmful_history": []
            }
            resource_ids = json.loads(row['resource_ids']) if row['resource_ids'] else []

            # 解析时间戳
            created_at = row['created_at']
            updated_at = row['updated_at'] if row['updated_at'] else created_at

            # Normalize ISO-8601 strings to integer epoch seconds.
            if isinstance(created_at, str):
                try:
                    created_at = int(datetime.fromisoformat(created_at.replace('Z', '+00:00')).timestamp())
                except ValueError:
                    created_at = int(datetime.now().timestamp())

            if isinstance(updated_at, str):
                try:
                    updated_at = int(datetime.fromisoformat(updated_at.replace('Z', '+00:00')).timestamp())
                except ValueError:
                    updated_at = created_at

            knowledge_list.append({
                "id": row['id'],
                "message_id": row['message_id'] or "",
                "task": row['task'],
                "content": row['content'],
                "types": types,
                "tags": tags,
                "scopes": scopes,
                "owner": row['owner'] or "agent:unknown",
                "resource_ids": resource_ids,
                "source": source,
                "eval": eval_data,
                "created_at": created_at,
                "updated_at": updated_at,
            })

            # 收集 task 用于生成 embedding(只基于 task)
            tasks.append(row['task'])

        except Exception as e:
            print(f"⚠️  跳过无效数据 {row['id']}: {e}")
            continue

    print(f"✅ 成功转换 {len(knowledge_list)} 条知识")

    # 批量生成 embeddings
    print(f"\n🧮 正在生成 embeddings (只基于 task 字段)...")
    batch_size = 100
    all_embeddings = []

    for i in range(0, len(tasks), batch_size):
        batch_tasks = tasks[i:i+batch_size]
        print(f"   处理 {i+1}-{min(i+batch_size, len(tasks))}/{len(tasks)}...")

        try:
            embeddings = await get_embeddings_batch(batch_tasks)
            all_embeddings.extend(embeddings)
        except Exception as e:
            # Abort: inserting with partial embeddings would misalign rows.
            print(f"❌ 生成 embeddings 失败: {e}")
            return

    print(f"✅ 成功生成 {len(all_embeddings)} 个 embeddings")

    # 添加 embeddings 到知识数据
    for knowledge, embedding in zip(knowledge_list, all_embeddings):
        knowledge["embedding"] = embedding

    # 批量插入到 Milvus
    print(f"\n💾 正在插入数据到 Milvus...")
    for i in range(0, len(knowledge_list), batch_size):
        batch = knowledge_list[i:i+batch_size]
        try:
            target_store.insert_batch(batch)
            print(f"   已插入 {min(i+batch_size, len(knowledge_list))}/{len(knowledge_list)}")
        except Exception as e:
            print(f"❌ 插入失败: {e}")
            print(f"   失败的批次: {i}-{i+batch_size}")
            # Fall back to row-by-row inserts so one bad record does not
            # discard the whole batch.
            for item in batch:
                try:
                    target_store.insert(item)
                except Exception as e2:
                    print(f"   ⚠️  跳过 {item['id']}: {e2}")

    # 验证 — report the true delta against the pre-migration count.
    final_count = target_store.count()
    print(f"\n✅ 迁移完成!")
    print(f"   Milvus 中的知识总数: {final_count}")
    print(f"   新增: {final_count - initial_count}")


if __name__ == "__main__":
    asyncio.run(migrate_knowledge())