vector_store.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """
  2. Milvus Lite 存储封装
  3. 单一存储架构,存储完整知识数据 + 向量。
  4. """
  5. from milvus import default_server
  6. from pymilvus import (
  7. connections, Collection, FieldSchema,
  8. CollectionSchema, DataType, utility
  9. )
  10. from typing import List, Dict, Optional
  11. import json
  12. import time
  13. class MilvusStore:
  14. def __init__(self, data_dir: str = "./milvus_data"):
  15. """
  16. 初始化 Milvus Lite 存储
  17. Args:
  18. data_dir: 数据存储目录
  19. """
  20. # 启动内嵌服务器
  21. default_server.set_base_dir(data_dir)
  22. default_server.start()
  23. # 连接
  24. connections.connect(
  25. host='127.0.0.1',
  26. port=default_server.listen_port
  27. )
  28. self._init_collection()
  29. def _init_collection(self):
  30. """初始化 collection"""
  31. collection_name = "knowledge"
  32. if utility.has_collection(collection_name):
  33. self.collection = Collection(collection_name)
  34. else:
  35. # 定义 schema
  36. fields = [
  37. FieldSchema(name="id", dtype=DataType.VARCHAR,
  38. max_length=100, is_primary=True),
  39. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
  40. dim=1536),
  41. FieldSchema(name="message_id", dtype=DataType.VARCHAR,
  42. max_length=100),
  43. FieldSchema(name="task", dtype=DataType.VARCHAR,
  44. max_length=2000),
  45. FieldSchema(name="content", dtype=DataType.VARCHAR,
  46. max_length=50000),
  47. FieldSchema(name="types", dtype=DataType.JSON),
  48. FieldSchema(name="tags", dtype=DataType.JSON),
  49. FieldSchema(name="scopes", dtype=DataType.JSON),
  50. FieldSchema(name="owner", dtype=DataType.VARCHAR,
  51. max_length=200),
  52. FieldSchema(name="resource_ids", dtype=DataType.JSON),
  53. FieldSchema(name="source", dtype=DataType.JSON),
  54. FieldSchema(name="eval", dtype=DataType.JSON),
  55. FieldSchema(name="created_at", dtype=DataType.INT64),
  56. FieldSchema(name="updated_at", dtype=DataType.INT64),
  57. ]
  58. schema = CollectionSchema(fields, description="KnowHub Knowledge")
  59. self.collection = Collection(collection_name, schema)
  60. # 创建向量索引
  61. index_params = {
  62. "metric_type": "COSINE",
  63. "index_type": "HNSW",
  64. "params": {"M": 16, "efConstruction": 200}
  65. }
  66. self.collection.create_index("embedding", index_params)
  67. self.collection.load()
  68. def insert(self, knowledge: Dict):
  69. """
  70. 插入单条知识
  71. Args:
  72. knowledge: 知识数据(包含 embedding)
  73. """
  74. self.collection.insert([knowledge])
  75. self.collection.flush()
  76. def insert_batch(self, knowledge_list: List[Dict]):
  77. """
  78. 批量插入知识
  79. Args:
  80. knowledge_list: 知识列表
  81. """
  82. if not knowledge_list:
  83. return
  84. self.collection.insert(knowledge_list)
  85. self.collection.flush()
  86. def search(self,
  87. query_embedding: List[float],
  88. filters: Optional[str] = None,
  89. limit: int = 10) -> List[Dict]:
  90. """
  91. 向量检索 + 标量过滤
  92. Args:
  93. query_embedding: 查询向量
  94. filters: 过滤表达式(如: 'owner == "agent"')
  95. limit: 返回数量
  96. Returns:
  97. 知识列表
  98. """
  99. search_params = {"metric_type": "COSINE", "params": {"ef": 100}}
  100. results = self.collection.search(
  101. data=[query_embedding],
  102. anns_field="embedding",
  103. param=search_params,
  104. limit=limit,
  105. expr=filters,
  106. output_fields=["id", "message_id", "task", "content", "types",
  107. "tags", "scopes", "owner", "resource_ids",
  108. "source", "eval", "created_at", "updated_at"]
  109. )
  110. if not results or not results[0]:
  111. return []
  112. return [hit.entity.to_dict() for hit in results[0]]
  113. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  114. """
  115. 纯标量查询(不使用向量)
  116. Args:
  117. filters: 过滤表达式
  118. limit: 返回数量
  119. Returns:
  120. 知识列表
  121. """
  122. results = self.collection.query(
  123. expr=filters,
  124. output_fields=["id", "message_id", "task", "content", "types",
  125. "tags", "scopes", "owner", "resource_ids",
  126. "source", "eval", "created_at", "updated_at"],
  127. limit=limit
  128. )
  129. return results
  130. def get_by_id(self, knowledge_id: str) -> Optional[Dict]:
  131. """
  132. 根据 ID 获取知识
  133. Args:
  134. knowledge_id: 知识 ID
  135. Returns:
  136. 知识数据,不存在返回 None
  137. """
  138. results = self.collection.query(
  139. expr=f'id == "{knowledge_id}"',
  140. output_fields=["id", "message_id", "task", "content", "types",
  141. "tags", "scopes", "owner", "resource_ids",
  142. "source", "eval", "created_at", "updated_at"]
  143. )
  144. return results[0] if results else None
  145. def update(self, knowledge_id: str, updates: Dict):
  146. """
  147. 更新知识(先删除再插入)
  148. Args:
  149. knowledge_id: 知识 ID
  150. updates: 更新字段
  151. """
  152. # 1. 查询现有数据
  153. existing = self.get_by_id(knowledge_id)
  154. if not existing:
  155. raise ValueError(f"Knowledge not found: {knowledge_id}")
  156. # 2. 合并更新
  157. existing.update(updates)
  158. existing["updated_at"] = int(time.time())
  159. # 3. 删除旧数据
  160. self.delete(knowledge_id)
  161. # 4. 插入新数据
  162. self.insert(existing)
  163. def delete(self, knowledge_id: str):
  164. """
  165. 删除知识
  166. Args:
  167. knowledge_id: 知识 ID
  168. """
  169. self.collection.delete(f'id == "{knowledge_id}"')
  170. self.collection.flush()
  171. def count(self) -> int:
  172. """返回知识总数"""
  173. return self.collection.num_entities
  174. def drop_collection(self):
  175. """删除 collection(危险操作)"""
  176. utility.drop_collection("knowledge")