# vector_store.py
"""
Milvus Lite storage wrapper.

Single-store architecture: stores the complete knowledge data together with its vectors.
"""
import json
import time
from typing import Any, Dict, List, Optional

from milvus import default_server
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
  13. class MilvusStore:
  14. def __init__(self, data_dir: str = "./milvus_data"):
  15. """
  16. 初始化 Milvus Lite 存储
  17. Args:
  18. data_dir: 数据存储目录
  19. """
  20. # 启动内嵌服务器
  21. default_server.set_base_dir(data_dir)
  22. # 检查是否已经有 Milvus 实例在运行
  23. try:
  24. # 尝试连接到可能已存在的实例
  25. connections.connect(
  26. alias="default",
  27. host='127.0.0.1',
  28. port=default_server.listen_port,
  29. timeout=5
  30. )
  31. print(f"[Milvus] 连接到已存在的 Milvus 实例 (端口 {default_server.listen_port})")
  32. except Exception:
  33. # 没有运行的实例,启动新的
  34. print(f"[Milvus] 启动新的 Milvus Lite 实例...")
  35. try:
  36. default_server.start()
  37. print(f"[Milvus] Milvus Lite 启动成功 (端口 {default_server.listen_port})")
  38. except Exception as e:
  39. print(f"[Milvus] 启动失败: {e}")
  40. # 尝试连接到可能已经在运行的实例
  41. try:
  42. connections.connect(
  43. alias="default",
  44. host='127.0.0.1',
  45. port=default_server.listen_port,
  46. timeout=5
  47. )
  48. print(f"[Milvus] 连接到已存在的实例")
  49. except Exception as e2:
  50. raise RuntimeError(f"无法启动或连接到 Milvus: {e}, {e2}")
  51. self._init_collection()
  52. def _init_collection(self):
  53. """初始化 collection"""
  54. collection_name = "knowledge"
  55. if utility.has_collection(collection_name):
  56. self.collection = Collection(collection_name)
  57. else:
  58. # 定义 schema
  59. fields = [
  60. FieldSchema(name="id", dtype=DataType.VARCHAR,
  61. max_length=100, is_primary=True),
  62. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
  63. dim=1536),
  64. FieldSchema(name="message_id", dtype=DataType.VARCHAR,
  65. max_length=100),
  66. FieldSchema(name="task", dtype=DataType.VARCHAR,
  67. max_length=2000),
  68. FieldSchema(name="content", dtype=DataType.VARCHAR,
  69. max_length=50000),
  70. FieldSchema(name="types", dtype=DataType.ARRAY,
  71. element_type=DataType.VARCHAR, max_capacity=20, max_length=50),
  72. FieldSchema(name="tags", dtype=DataType.JSON),
  73. FieldSchema(name="scopes", dtype=DataType.ARRAY,
  74. element_type=DataType.VARCHAR, max_capacity=20, max_length=100),
  75. FieldSchema(name="owner", dtype=DataType.VARCHAR,
  76. max_length=200),
  77. FieldSchema(name="resource_ids", dtype=DataType.ARRAY,
  78. element_type=DataType.VARCHAR, max_capacity=50, max_length=200),
  79. FieldSchema(name="source", dtype=DataType.JSON),
  80. FieldSchema(name="eval", dtype=DataType.JSON),
  81. FieldSchema(name="created_at", dtype=DataType.INT64),
  82. FieldSchema(name="updated_at", dtype=DataType.INT64),
  83. ]
  84. schema = CollectionSchema(fields, description="KnowHub Knowledge")
  85. self.collection = Collection(collection_name, schema)
  86. # 创建向量索引
  87. index_params = {
  88. "metric_type": "COSINE",
  89. "index_type": "HNSW",
  90. "params": {"M": 16, "efConstruction": 200}
  91. }
  92. self.collection.create_index("embedding", index_params)
  93. self.collection.load()
  94. def insert(self, knowledge: Dict):
  95. """
  96. 插入单条知识
  97. Args:
  98. knowledge: 知识数据(包含 embedding)
  99. """
  100. self.collection.insert([knowledge])
  101. self.collection.flush()
  102. def insert_batch(self, knowledge_list: List[Dict]):
  103. """
  104. 批量插入知识
  105. Args:
  106. knowledge_list: 知识列表
  107. """
  108. if not knowledge_list:
  109. return
  110. self.collection.insert(knowledge_list)
  111. self.collection.flush()
  112. def search(self,
  113. query_embedding: List[float],
  114. filters: Optional[str] = None,
  115. limit: int = 10) -> List[Dict]:
  116. """
  117. 向量检索 + 标量过滤
  118. Args:
  119. query_embedding: 查询向量
  120. filters: 过滤表达式(如: 'owner == "agent"')
  121. limit: 返回数量
  122. Returns:
  123. 知识列表
  124. """
  125. search_params = {"metric_type": "COSINE", "params": {"ef": 100}}
  126. results = self.collection.search(
  127. data=[query_embedding],
  128. anns_field="embedding",
  129. param=search_params,
  130. limit=limit,
  131. expr=filters,
  132. output_fields=["id", "message_id", "task", "content", "types",
  133. "tags", "scopes", "owner", "resource_ids",
  134. "source", "eval", "created_at", "updated_at"]
  135. )
  136. if not results or not results[0]:
  137. return []
  138. return [hit.entity.to_dict() for hit in results[0]]
  139. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  140. """
  141. 纯标量查询(不使用向量)
  142. Args:
  143. filters: 过滤表达式
  144. limit: 返回数量
  145. Returns:
  146. 知识列表
  147. """
  148. results = self.collection.query(
  149. expr=filters,
  150. output_fields=["id", "message_id", "task", "content", "types",
  151. "tags", "scopes", "owner", "resource_ids",
  152. "source", "eval", "created_at", "updated_at"],
  153. limit=limit
  154. )
  155. return results
  156. def get_by_id(self, knowledge_id: str) -> Optional[Dict]:
  157. """
  158. 根据 ID 获取知识
  159. Args:
  160. knowledge_id: 知识 ID
  161. Returns:
  162. 知识数据,不存在返回 None
  163. """
  164. results = self.collection.query(
  165. expr=f'id == "{knowledge_id}"',
  166. output_fields=["id", "message_id", "task", "content", "types",
  167. "tags", "scopes", "owner", "resource_ids",
  168. "source", "eval", "created_at", "updated_at"]
  169. )
  170. return results[0] if results else None
  171. def update(self, knowledge_id: str, updates: Dict):
  172. """
  173. 更新知识(先删除再插入)
  174. Args:
  175. knowledge_id: 知识 ID
  176. updates: 更新字段
  177. """
  178. # 1. 查询现有数据
  179. existing = self.get_by_id(knowledge_id)
  180. if not existing:
  181. raise ValueError(f"Knowledge not found: {knowledge_id}")
  182. # 2. 合并更新
  183. existing.update(updates)
  184. existing["updated_at"] = int(time.time())
  185. # 3. 删除旧数据
  186. self.delete(knowledge_id)
  187. # 4. 插入新数据
  188. self.insert(existing)
  189. def delete(self, knowledge_id: str):
  190. """
  191. 删除知识
  192. Args:
  193. knowledge_id: 知识 ID
  194. """
  195. self.collection.delete(f'id == "{knowledge_id}"')
  196. self.collection.flush()
  197. def count(self) -> int:
  198. """返回知识总数"""
  199. return self.collection.num_entities
  200. def drop_collection(self):
  201. """删除 collection(危险操作)"""
  202. utility.drop_collection("knowledge")