vector_store.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. """
  2. Milvus Lite 存储封装
  3. 单一存储架构,存储完整知识数据 + 向量。
  4. """
  5. from milvus import default_server
  6. from pymilvus import (
  7. connections, Collection, FieldSchema,
  8. CollectionSchema, DataType, utility
  9. )
  10. from typing import List, Dict, Optional
  11. import json
  12. import time
  13. class MilvusStore:
  14. def __init__(self, data_dir: str = "./milvus_data"):
  15. """
  16. 初始化 Milvus Lite 存储
  17. Args:
  18. data_dir: 数据存储目录
  19. """
  20. # 启动内嵌服务器
  21. default_server.set_base_dir(data_dir)
  22. # 检查是否已经有 Milvus 实例在运行
  23. try:
  24. # 尝试连接到可能已存在的实例
  25. connections.connect(
  26. alias="default",
  27. host='127.0.0.1',
  28. port=default_server.listen_port,
  29. timeout=5
  30. )
  31. print(f"[Milvus] 连接到已存在的 Milvus 实例 (端口 {default_server.listen_port})")
  32. except Exception:
  33. # 没有运行的实例,启动新的
  34. print(f"[Milvus] 启动新的 Milvus Lite 实例...")
  35. try:
  36. default_server.start()
  37. print(f"[Milvus] Milvus Lite 启动成功 (端口 {default_server.listen_port})")
  38. # 启动后建立连接
  39. connections.connect(
  40. alias="default",
  41. host='127.0.0.1',
  42. port=default_server.listen_port,
  43. timeout=5
  44. )
  45. print(f"[Milvus] 已连接到新启动的实例")
  46. except Exception as e:
  47. print(f"[Milvus] 启动失败: {e}")
  48. # 尝试连接到可能已经在运行的实例
  49. try:
  50. connections.connect(
  51. alias="default",
  52. host='127.0.0.1',
  53. port=default_server.listen_port,
  54. timeout=5
  55. )
  56. print(f"[Milvus] 连接到已存在的实例")
  57. except Exception as e2:
  58. raise RuntimeError(f"无法启动或连接到 Milvus: {e}, {e2}")
  59. self._init_collection()
  60. def _init_collection(self):
  61. """初始化 collection"""
  62. collection_name = "knowledge"
  63. if utility.has_collection(collection_name):
  64. self.collection = Collection(collection_name)
  65. else:
  66. # 定义 schema
  67. fields = [
  68. FieldSchema(name="id", dtype=DataType.VARCHAR,
  69. max_length=100, is_primary=True),
  70. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
  71. dim=1536),
  72. FieldSchema(name="message_id", dtype=DataType.VARCHAR,
  73. max_length=100),
  74. FieldSchema(name="task", dtype=DataType.VARCHAR,
  75. max_length=2000),
  76. FieldSchema(name="content", dtype=DataType.VARCHAR,
  77. max_length=50000),
  78. FieldSchema(name="types", dtype=DataType.ARRAY,
  79. element_type=DataType.VARCHAR, max_capacity=20, max_length=50),
  80. FieldSchema(name="tags", dtype=DataType.JSON),
  81. FieldSchema(name="tag_keys", dtype=DataType.ARRAY,
  82. element_type=DataType.VARCHAR, max_capacity=50, max_length=100),
  83. FieldSchema(name="scopes", dtype=DataType.ARRAY,
  84. element_type=DataType.VARCHAR, max_capacity=20, max_length=100),
  85. FieldSchema(name="owner", dtype=DataType.VARCHAR,
  86. max_length=200),
  87. FieldSchema(name="resource_ids", dtype=DataType.ARRAY,
  88. element_type=DataType.VARCHAR, max_capacity=50, max_length=200),
  89. FieldSchema(name="source", dtype=DataType.JSON),
  90. FieldSchema(name="eval", dtype=DataType.JSON),
  91. FieldSchema(name="created_at", dtype=DataType.INT64),
  92. FieldSchema(name="updated_at", dtype=DataType.INT64),
  93. ]
  94. schema = CollectionSchema(fields, description="KnowHub Knowledge")
  95. self.collection = Collection(collection_name, schema)
  96. # 创建向量索引
  97. index_params = {
  98. "metric_type": "COSINE",
  99. "index_type": "HNSW",
  100. "params": {"M": 16, "efConstruction": 200}
  101. }
  102. self.collection.create_index("embedding", index_params)
  103. self.collection.load()
  104. def insert(self, knowledge: Dict):
  105. """
  106. 插入单条知识
  107. Args:
  108. knowledge: 知识数据(包含 embedding)
  109. """
  110. self.collection.insert([knowledge])
  111. self.collection.flush()
  112. def insert_batch(self, knowledge_list: List[Dict]):
  113. """
  114. 批量插入知识
  115. Args:
  116. knowledge_list: 知识列表
  117. """
  118. if not knowledge_list:
  119. return
  120. self.collection.insert(knowledge_list)
  121. self.collection.flush()
  122. def search(self,
  123. query_embedding: List[float],
  124. filters: Optional[str] = None,
  125. limit: int = 10) -> List[Dict]:
  126. """
  127. 向量检索 + 标量过滤
  128. Args:
  129. query_embedding: 查询向量
  130. filters: 过滤表达式(如: 'owner == "agent"')
  131. limit: 返回数量
  132. Returns:
  133. 知识列表
  134. """
  135. search_params = {"metric_type": "COSINE", "params": {"ef": 100}}
  136. results = self.collection.search(
  137. data=[query_embedding],
  138. anns_field="embedding",
  139. param=search_params,
  140. limit=limit,
  141. expr=filters,
  142. output_fields=["id", "message_id", "task", "content", "types",
  143. "tags", "tag_keys", "scopes", "owner", "resource_ids",
  144. "source", "eval", "created_at", "updated_at"]
  145. )
  146. if not results or not results[0]:
  147. return []
  148. # 返回实体字典,包含所有字段
  149. # 注意:时间戳需要转换为毫秒(JavaScript Date 需要)
  150. return [
  151. {
  152. "id": hit.entity.get("id"),
  153. "message_id": hit.entity.get("message_id"),
  154. "task": hit.entity.get("task"),
  155. "content": hit.entity.get("content"),
  156. "types": list(hit.entity.get("types")) if hit.entity.get("types") else [],
  157. "tags": hit.entity.get("tags"),
  158. "tag_keys": list(hit.entity.get("tag_keys")) if hit.entity.get("tag_keys") else [],
  159. "scopes": list(hit.entity.get("scopes")) if hit.entity.get("scopes") else [],
  160. "owner": hit.entity.get("owner"),
  161. "resource_ids": list(hit.entity.get("resource_ids")) if hit.entity.get("resource_ids") else [],
  162. "source": hit.entity.get("source"),
  163. "eval": hit.entity.get("eval"),
  164. "created_at": hit.entity.get("created_at") * 1000 if hit.entity.get("created_at") else None,
  165. "updated_at": hit.entity.get("updated_at") * 1000 if hit.entity.get("updated_at") else None,
  166. }
  167. for hit in results[0]
  168. ]
  169. def query(self, filters: str, limit: int = 100) -> List[Dict]:
  170. """
  171. 纯标量查询(不使用向量)
  172. Args:
  173. filters: 过滤表达式
  174. limit: 返回数量
  175. Returns:
  176. 知识列表
  177. """
  178. results = self.collection.query(
  179. expr=filters,
  180. output_fields=["id", "message_id", "task", "content", "types",
  181. "tags", "tag_keys", "scopes", "owner", "resource_ids",
  182. "source", "eval", "created_at", "updated_at"],
  183. limit=limit
  184. )
  185. # 转换时间戳为毫秒,确保数组字段格式正确
  186. for r in results:
  187. if r.get("created_at"):
  188. r["created_at"] = r["created_at"] * 1000
  189. if r.get("updated_at"):
  190. r["updated_at"] = r["updated_at"] * 1000
  191. # 确保数组字段是列表格式
  192. if r.get("types") and not isinstance(r["types"], list):
  193. r["types"] = list(r["types"])
  194. if r.get("tag_keys") and not isinstance(r["tag_keys"], list):
  195. r["tag_keys"] = list(r["tag_keys"])
  196. if r.get("scopes") and not isinstance(r["scopes"], list):
  197. r["scopes"] = list(r["scopes"])
  198. if r.get("resource_ids") and not isinstance(r["resource_ids"], list):
  199. r["resource_ids"] = list(r["resource_ids"])
  200. return results
  201. def get_by_id(self, knowledge_id: str) -> Optional[Dict]:
  202. """
  203. 根据 ID 获取知识
  204. Args:
  205. knowledge_id: 知识 ID
  206. Returns:
  207. 知识数据,不存在返回 None
  208. """
  209. results = self.collection.query(
  210. expr=f'id == "{knowledge_id}"',
  211. output_fields=["id", "message_id", "task", "content", "types",
  212. "tags", "tag_keys", "scopes", "owner", "resource_ids",
  213. "source", "eval", "created_at", "updated_at"]
  214. )
  215. if not results:
  216. return None
  217. # 转换时间戳和数组字段
  218. r = results[0]
  219. if r.get("created_at"):
  220. r["created_at"] = r["created_at"] * 1000
  221. if r.get("updated_at"):
  222. r["updated_at"] = r["updated_at"] * 1000
  223. if r.get("types") and not isinstance(r["types"], list):
  224. r["types"] = list(r["types"])
  225. if r.get("tag_keys") and not isinstance(r["tag_keys"], list):
  226. r["tag_keys"] = list(r["tag_keys"])
  227. if r.get("scopes") and not isinstance(r["scopes"], list):
  228. r["scopes"] = list(r["scopes"])
  229. if r.get("resource_ids") and not isinstance(r["resource_ids"], list):
  230. r["resource_ids"] = list(r["resource_ids"])
  231. return r
  232. def update(self, knowledge_id: str, updates: Dict):
  233. """
  234. 更新知识(先删除再插入)
  235. Args:
  236. knowledge_id: 知识 ID
  237. updates: 更新字段
  238. """
  239. # 1. 查询现有数据
  240. existing = self.get_by_id(knowledge_id)
  241. if not existing:
  242. raise ValueError(f"Knowledge not found: {knowledge_id}")
  243. # 2. 合并更新
  244. existing.update(updates)
  245. existing["updated_at"] = int(time.time())
  246. # 3. 删除旧数据
  247. self.delete(knowledge_id)
  248. # 4. 插入新数据
  249. self.insert(existing)
  250. def delete(self, knowledge_id: str):
  251. """
  252. 删除知识
  253. Args:
  254. knowledge_id: 知识 ID
  255. """
  256. self.collection.delete(f'id == "{knowledge_id}"')
  257. self.collection.flush()
  258. def count(self) -> int:
  259. """返回知识总数"""
  260. return self.collection.num_entities
  261. def drop_collection(self):
  262. """删除 collection(危险操作)"""
  263. utility.drop_collection("knowledge")