vector_store.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. """
  2. Milvus Lite 存储封装
  3. 单一存储架构,存储完整知识数据 + 向量。
  4. """
  5. from milvus import default_server
  6. from pymilvus import (
  7. connections, Collection, FieldSchema,
  8. CollectionSchema, DataType, utility
  9. )
  10. from typing import List, Dict, Optional
  11. import json
  12. import time
  13. class MilvusStore:
  def __init__(self, data_dir: str = "./milvus_data"):
      """
      Attach to (or start) the embedded Milvus server and open the collection.

      The sequence is: try to connect to a server already listening on
      ``default_server.listen_port``; if that fails, start the embedded
      server and connect; if starting fails, try one last connect before
      giving up.

      Args:
          data_dir: Base directory handed to ``default_server.set_base_dir``.

      Raises:
          RuntimeError: If the server can neither be started nor connected to.
      """
      default_server.set_base_dir(data_dir)

      def connect():
          # The port is read on every call because it is looked up on
          # default_server again after start() has run.
          connections.connect(
              alias="default",
              host='127.0.0.1',
              port=default_server.listen_port,
              timeout=5
          )

      try:
          # First, try an instance that may already be running.
          connect()
          print(f"[Milvus] 连接到已存在的 Milvus 实例 (端口 {default_server.listen_port})")
      except Exception:
          # Nothing answered: start a new embedded instance.
          print("[Milvus] 启动新的 Milvus Lite 实例...")
          try:
              default_server.start()
              print(f"[Milvus] Milvus Lite 启动成功 (端口 {default_server.listen_port})")
              connect()
              print("[Milvus] 已连接到新启动的实例")
          except Exception as start_error:
              print(f"[Milvus] 启动失败: {start_error}")
              # Last resort: an instance may have come up in the meantime.
              try:
                  connect()
                  print("[Milvus] 连接到已存在的实例")
              except Exception as connect_error:
                  # Chain the cause so the original traceback is preserved.
                  raise RuntimeError(
                      f"无法启动或连接到 Milvus: {start_error}, {connect_error}"
                  ) from connect_error
      self._init_collection()
  def _init_collection(self):
      """
      Open the "knowledge" collection, creating it with its indexes if absent.

      When the collection already exists it is opened as-is. Otherwise it is
      created with the schema below, an HNSW/COSINE index on ``embedding``
      and a best-effort Trie index on ``status``. The collection is loaded
      in both cases and stored on ``self.collection``.
      """
      # NOTE(review): the pasted source lost its indentation; load() is
      # assumed to run for both the existing and the newly created collection.
      collection_name = "knowledge"
      if utility.has_collection(collection_name):
          self.collection = Collection(collection_name)
      else:
          fields = [
              FieldSchema(name="id", dtype=DataType.VARCHAR,
                          max_length=100, is_primary=True),
              FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
                          dim=1536),
              FieldSchema(name="message_id", dtype=DataType.VARCHAR,
                          max_length=100),
              FieldSchema(name="task", dtype=DataType.VARCHAR,
                          max_length=2000),
              FieldSchema(name="content", dtype=DataType.VARCHAR,
                          max_length=50000),
              FieldSchema(name="types", dtype=DataType.ARRAY,
                          element_type=DataType.VARCHAR, max_capacity=20, max_length=50),
              FieldSchema(name="tags", dtype=DataType.JSON),
              FieldSchema(name="tag_keys", dtype=DataType.ARRAY,
                          element_type=DataType.VARCHAR, max_capacity=50, max_length=100),
              FieldSchema(name="scopes", dtype=DataType.ARRAY,
                          element_type=DataType.VARCHAR, max_capacity=20, max_length=100),
              FieldSchema(name="owner", dtype=DataType.VARCHAR,
                          max_length=200),
              FieldSchema(name="resource_ids", dtype=DataType.ARRAY,
                          element_type=DataType.VARCHAR, max_capacity=50, max_length=200),
              FieldSchema(name="source", dtype=DataType.JSON),
              FieldSchema(name="eval", dtype=DataType.JSON),
              FieldSchema(name="created_at", dtype=DataType.INT64),
              FieldSchema(name="updated_at", dtype=DataType.INT64),
              FieldSchema(name="status", dtype=DataType.VARCHAR,
                          max_length=20, default_value="approved"),
              # Stored as a JSON-encoded string, hence the "[]" default.
              FieldSchema(name="relationships", dtype=DataType.VARCHAR,
                          max_length=10000, default_value="[]"),
          ]
          schema = CollectionSchema(fields, description="KnowHub Knowledge")
          self.collection = Collection(collection_name, schema)

          # Vector index used by search().
          index_params = {
              "metric_type": "COSINE",
              "index_type": "HNSW",
              "params": {"M": 16, "efConstruction": 200}
          }
          self.collection.create_index("embedding", index_params)

          # Scalar index on status to speed up filtering. Still best-effort,
          # but a failure is now reported instead of silently swallowed.
          try:
              self.collection.create_index("status", {"index_type": "Trie"})
          except Exception as index_error:
              print(f"[Milvus] 创建 status 标量索引失败,已跳过: {index_error}")
      self.collection.load()
  def insert(self, knowledge: Dict):
      """
      Insert one knowledge record and flush the collection.

      Args:
          knowledge: Record to store; it is passed to the collection as a
              single-element list.
      """
      target = self.collection
      target.insert([knowledge])
      target.flush()
  def insert_batch(self, knowledge_list: List[Dict]):
      """
      Insert several knowledge records in one call, then flush.

      Args:
          knowledge_list: Records to store. An empty (or otherwise falsy)
              value makes the call a no-op.
      """
      if knowledge_list:
          target = self.collection
          target.insert(knowledge_list)
          target.flush()
  def search(self,
             query_embedding: List[float],
             filters: Optional[str] = None,
             limit: int = 10) -> List[Dict]:
      """
      Run a vector search with an optional scalar filter.

      Args:
          query_embedding: Query vector, searched against the "embedding" field.
          filters: Filter expression (e.g. 'owner == "agent"'), or None.
          limit: Maximum number of hits requested.

      Returns:
          One dict per hit of the first (only) query, carrying every stored
          field plus "score". Array fields are returned as lists, timestamps
          are multiplied by 1000 (None when falsy) and "relationships" is
          decoded from its JSON string.
      """
      wanted_fields = ["id", "message_id", "task", "content", "types",
                       "tags", "tag_keys", "scopes", "owner", "resource_ids",
                       "source", "eval", "created_at", "updated_at",
                       "status", "relationships"]
      results = self.collection.search(
          data=[query_embedding],
          anns_field="embedding",
          param={"metric_type": "COSINE", "params": {"ef": 100}},
          limit=limit,
          expr=filters,
          output_fields=wanted_fields
      )
      if not results or not results[0]:
          return []

      def as_list(value):
          # Falsy values (None, empty) become an empty list.
          return list(value) if value else []

      def to_millis(value):
          # Stored timestamps are scaled by 1000 for the caller; falsy -> None.
          return value * 1000 if value else None

      rows = []
      for hit in results[0]:
          entity = hit.entity
          rows.append({
              "id": entity.get("id"),
              "message_id": entity.get("message_id"),
              "task": entity.get("task"),
              "content": entity.get("content"),
              "types": as_list(entity.get("types")),
              "tags": entity.get("tags"),
              "tag_keys": as_list(entity.get("tag_keys")),
              "scopes": as_list(entity.get("scopes")),
              "owner": entity.get("owner"),
              "resource_ids": as_list(entity.get("resource_ids")),
              "source": entity.get("source"),
              "eval": entity.get("eval"),
              "created_at": to_millis(entity.get("created_at")),
              "updated_at": to_millis(entity.get("updated_at")),
              "status": entity.get("status", "approved"),
              "relationships": json.loads(entity.get("relationships") or "[]"),
              "score": hit.score,
          })
      return rows
  def query(self, filters: str, limit: int = 100) -> List[Dict]:
      """
      Run a scalar-only query (no vector involved).

      Args:
          filters: Filter expression passed to the collection.
          limit: Maximum number of records requested.

      Returns:
          The records returned by the collection, normalised in place:
          truthy timestamps multiplied by 1000, array fields converted to
          lists, "status" defaulting to "approved" and "relationships"
          decoded from JSON (or [] when missing/None).
      """
      records = self.collection.query(
          expr=filters,
          output_fields=["id", "message_id", "task", "content", "types",
                         "tags", "tag_keys", "scopes", "owner", "resource_ids",
                         "source", "eval", "created_at", "updated_at",
                         "status", "relationships"],
          limit=limit
      )
      for record in records:
          # Timestamps are scaled by 1000 for the caller.
          for ts_field in ("created_at", "updated_at"):
              if record.get(ts_field):
                  record[ts_field] = record[ts_field] * 1000
          # Array fields must come back as plain lists.
          for array_field in ("types", "tag_keys", "scopes", "resource_ids"):
              value = record.get(array_field)
              if value and not isinstance(value, list):
                  record[array_field] = list(value)
          # Records written before status/relationships existed lack them.
          record.setdefault("status", "approved")
          raw_relationships = record.get("relationships")
          if raw_relationships is None:
              record["relationships"] = []
          elif isinstance(raw_relationships, str):
              record["relationships"] = json.loads(raw_relationships)
      return records
  def get_by_id(self, knowledge_id: str) -> Optional[Dict]:
      """
      Fetch one knowledge record, including its embedding, by primary key.

      Args:
          knowledge_id: Primary key of the record.

      Returns:
          The first matching record, normalised (truthy timestamps multiplied
          by 1000, array fields as lists, "status" defaulting to "approved",
          "relationships" decoded from JSON), or None when nothing matches.
      """
      # Escape backslashes and double quotes so an id containing them cannot
      # terminate the string literal and alter the filter expression.
      escaped_id = str(knowledge_id).replace("\\", "\\\\").replace('"', '\\"')
      results = self.collection.query(
          expr=f'id == "{escaped_id}"',
          output_fields=["id", "embedding", "message_id", "task", "content", "types",
                         "tags", "tag_keys", "scopes", "owner", "resource_ids",
                         "source", "eval", "created_at", "updated_at",
                         "status", "relationships"]
      )
      if not results:
          return None

      record = results[0]
      # Timestamps are scaled by 1000 for the caller.
      for ts_field in ("created_at", "updated_at"):
          if record.get(ts_field):
              record[ts_field] = record[ts_field] * 1000
      # Array fields must come back as plain lists.
      for array_field in ("types", "tag_keys", "scopes", "resource_ids"):
          value = record.get(array_field)
          if value and not isinstance(value, list):
              record[array_field] = list(value)
      # Records written before status/relationships existed lack them.
      record.setdefault("status", "approved")
      raw_relationships = record.get("relationships")
      if raw_relationships is None:
          record["relationships"] = []
      elif isinstance(raw_relationships, str):
          record["relationships"] = json.loads(raw_relationships)
      return record
  def update(self, knowledge_id: str, updates: Dict):
      """
      Update a knowledge record by deleting it and inserting the merged copy.

      If inserting the merged record fails after the old one was deleted,
      the previous record is re-inserted on a best-effort basis and the
      original error is re-raised, so a rejected update does not silently
      lose the record.

      Args:
          knowledge_id: Primary key of the record to update.
          updates: Fields to overwrite on the existing record.

      Raises:
          ValueError: If no record with ``knowledge_id`` exists.
      """
      def to_seconds(value):
          # get_by_id multiplies timestamps by 1000; undo that when the value
          # is clearly in milliseconds.
          if value and value > 1_000_000_000_000:
              return value // 1000
          return value

      def encode_relationships(record):
          # get_by_id decodes relationships to a list; storage expects the
          # JSON string form.
          if isinstance(record.get("relationships"), list):
              record["relationships"] = json.dumps(record["relationships"])

      # 1. Load the current record.
      existing = self.get_by_id(knowledge_id)
      if not existing:
          raise ValueError(f"Knowledge not found: {knowledge_id}")

      # Keep a storage-format copy of the untouched record for rollback.
      previous = dict(existing)
      previous["created_at"] = to_seconds(previous.get("created_at"))
      previous["updated_at"] = to_seconds(previous.get("updated_at"))
      encode_relationships(previous)

      # 2. Merge the updates and stamp the modification time (seconds).
      existing.update(updates)
      existing["updated_at"] = int(time.time())

      # 3. Convert the merged record back to storage format.
      if existing.get("created_at"):
          existing["created_at"] = to_seconds(existing["created_at"])
      encode_relationships(existing)

      # 4. Replace the stored record.
      self.delete(knowledge_id)
      try:
          self.insert(existing)
      except Exception:
          # The old record is already gone: try to put it back, then let the
          # original insert error propagate.
          try:
              self.insert(previous)
          except Exception as restore_error:
              print(f"[Milvus] 更新失败且恢复旧数据失败 (id={knowledge_id}): {restore_error}")
          raise
  def delete(self, knowledge_id: str):
      """
      Delete the knowledge record with the given primary key, then flush.

      Args:
          knowledge_id: Primary key of the record to delete.
      """
      # Escape backslashes and double quotes so an id containing them cannot
      # terminate the string literal and widen the delete expression.
      escaped_id = str(knowledge_id).replace("\\", "\\\\").replace('"', '\\"')
      self.collection.delete(f'id == "{escaped_id}"')
      self.collection.flush()
  def count(self) -> int:
      """
      Return the collection's ``num_entities`` value.

      NOTE(review): whether this figure excludes deleted-but-not-compacted
      rows depends on Milvus; confirm before relying on it as an exact count.
      """
      entity_total = self.collection.num_entities
      return entity_total
  def drop_collection(self):
      """
      Drop the "knowledge" collection (destructive).

      ``self.collection`` is not reassigned by this call.
      """
      target_name = "knowledge"
      utility.drop_collection(target_name)
  def migrate_schema(self) -> int:
      """
      Migrate an old collection (no status/relationships) to the new schema.

      The data is copied into a staging collection, the main collection is
      dropped and recreated with the new schema, and the data is copied
      back. Until Step 3 the data exists in two collections.

      Returns:
          Number of records copied out of the old collection.
      """
      # NOTE(review): if a previous run failed after Step 3, the staging
      # collection may hold the only complete copy, yet it is dropped below
      # as "residue". Confirm this is acceptable for the deployment.
      # NOTE(review): offset pagination is presumably subject to Milvus'
      # offset+limit cap; verify for large collections.
      MIGRATION_NAME = "knowledge_migration"
      MAIN_NAME = "knowledge"
      base_fields = ["id", "embedding", "message_id", "task", "content", "types",
                     "tags", "tag_keys", "scopes", "owner", "resource_ids",
                     "source", "eval", "created_at", "updated_at"]
      batch_size = 200

      def build_indexes(col):
          # Vector index plus a best-effort scalar index on status.
          col.create_index("embedding", {"metric_type": "COSINE", "index_type": "HNSW", "params": {"M": 16, "efConstruction": 200}})
          try:
              col.create_index("status", {"index_type": "Trie"})
          except Exception as index_error:
              print(f"[Migrate] 创建 status 标量索引失败,已跳过: {index_error}")

      def backfill(item):
          # Add the fields the old schema lacks. relationships is a VARCHAR
          # column holding JSON, so the default must be the string "[]",
          # not a Python list.
          item["status"] = item.get("status", "approved")
          relationships = item.get("relationships")
          if not (isinstance(relationships, str) and relationships):
              item["relationships"] = json.dumps(relationships or [])

      def copy_all(source, target, output_fields, prepare=None, report=False):
          # Page through source and insert every batch into target.
          offset = 0
          copied = 0
          while True:
              batch = source.query(
                  expr='id != ""',
                  output_fields=output_fields,
                  limit=batch_size,
                  offset=offset
              )
              if not batch:
                  break
              if prepare is not None:
                  for item in batch:
                      prepare(item)
              # Timestamps are the raw stored values here (no *1000 applied).
              target.insert(batch)
              target.flush()
              copied += len(batch)
              offset += len(batch)
              if report:
                  print(f"[Migrate] 已迁移 {copied} 条...")
              if len(batch) < batch_size:
                  break
          return copied

      # Clean up a staging collection left behind by an interrupted run.
      if utility.has_collection(MIGRATION_NAME):
          print("[Migrate] 检测到残留中转 collection,清理...")
          utility.drop_collection(MIGRATION_NAME)

      # Step 1: create the staging collection with the new schema.
      print(f"[Migrate] Step 1: 创建中转 collection {MIGRATION_NAME}...")
      fields = [
          FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=100, is_primary=True),
          FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1536),
          FieldSchema(name="message_id", dtype=DataType.VARCHAR, max_length=100),
          FieldSchema(name="task", dtype=DataType.VARCHAR, max_length=2000),
          FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=50000),
          FieldSchema(name="types", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=20, max_length=50),
          FieldSchema(name="tags", dtype=DataType.JSON),
          FieldSchema(name="tag_keys", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=50, max_length=100),
          FieldSchema(name="scopes", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=20, max_length=100),
          FieldSchema(name="owner", dtype=DataType.VARCHAR, max_length=200),
          FieldSchema(name="resource_ids", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=50, max_length=200),
          FieldSchema(name="source", dtype=DataType.JSON),
          FieldSchema(name="eval", dtype=DataType.JSON),
          FieldSchema(name="created_at", dtype=DataType.INT64),
          FieldSchema(name="updated_at", dtype=DataType.INT64),
          FieldSchema(name="status", dtype=DataType.VARCHAR, max_length=20, default_value="approved"),
          FieldSchema(name="relationships", dtype=DataType.VARCHAR, max_length=10000, default_value="[]"),
      ]
      schema = CollectionSchema(fields, description="KnowHub Knowledge")
      migration_col = Collection(MIGRATION_NAME, schema)
      build_indexes(migration_col)
      migration_col.load()

      # Step 2: read the old collection in batches, backfill, insert.
      print("[Migrate] Step 2: 读取旧数据并插入中转 collection...")
      total = copy_all(self.collection, migration_col, base_fields,
                       prepare=backfill, report=True)

      # Step 3: drop the old collection.
      print(f"[Migrate] Step 3: drop 旧 collection {MAIN_NAME}...")
      self.collection.release()
      utility.drop_collection(MAIN_NAME)

      # Step 4: recreate the main collection under the same name.
      print(f"[Migrate] Step 4: 创建新 collection {MAIN_NAME}...")
      new_col = Collection(MAIN_NAME, schema)
      build_indexes(new_col)
      new_col.load()

      # Step 5: copy everything back from the staging collection.
      print("[Migrate] Step 5: 从中转 collection 回写到新 collection...")
      copy_all(migration_col, new_col, base_fields + ["status", "relationships"])

      # Step 6: drop the staging collection.
      print(f"[Migrate] Step 6: drop 中转 collection {MIGRATION_NAME}...")
      migration_col.release()
      utility.drop_collection(MIGRATION_NAME)

      # Step 7: point this store at the new collection.
      print("[Migrate] Step 7: 更新 collection 引用...")
      self.collection = new_col
      print(f"[Migrate] 迁移完成,共迁移 {total} 条知识。")
      return total
  409. return total