|
@@ -0,0 +1,165 @@
|
|
|
+import asyncio
|
|
|
+from typing import List, Optional, Dict, Any, Union
|
|
|
+
|
|
|
+
|
|
|
class MilvusBase:
    """Shared state and result formatting for the Milvus helper classes.

    Stores the pool/client object whose ``search`` and ``query`` callables
    the subclasses invoke, and the list of fields requested on every call.
    """

    # Fields requested from Milvus on every search/query call.  The
    # commented-out names are kept as a record of fields that are currently
    # not requested.
    output_fields = [
        "doc_id",
        "chunk_id",
        # "summary",
        # "topic",
        # "domain",
        # "task_type",
        # "keywords",
        # "concepts",
        # "questions",
        # "entities",
        # "tokens",
        # "topic_purity",
    ]

    def __init__(self, milvus_pool):
        """Keep a reference to the object used to reach Milvus.

        Args:
            milvus_pool: Object exposing the ``search`` / ``query`` callables
                used by the subclasses.
        """
        self.milvus_pool = milvus_pool

    @staticmethod
    def hits_to_json(hits: Any) -> List[Dict[str, Any]]:
        """Convert the hits of the first query in a search response to dicts.

        Only ``hits[0]`` is read, i.e. the results for the first (and, for the
        callers in this file, only) query vector.

        Args:
            hits: Search response; ``hits[0]`` must be iterable and yield
                objects with ``id``, ``distance`` and ``get("entity", {})``.

        Returns:
            One dict per hit with ``pk`` and ``score`` followed by the hit's
            entity fields.  Returns ``[]`` when ``hits`` is empty or ``None``.
        """
        if not hits:
            return []

        # Fields whose values are converted to plain lists.
        list_fields = {"entities", "concepts", "questions", "keywords"}

        rows = []
        for hit in hits[0]:
            entity = hit.get("entity", {}) or {}
            row = {"pk": hit.id, "score": hit.distance}
            for key, value in entity.items():
                # A missing/null value stays None instead of crashing list().
                if key in list_fields and value is not None:
                    value = list(value)
                row[key] = value
            rows.append(row)
        return rows
|
|
|
+
|
|
|
+
|
|
|
class MilvusSearch(MilvusBase):
    """Vector search helpers that run ``self.milvus_pool.search`` in a worker thread."""

    # Metrics for which a smaller returned value means a closer match.  For
    # every other metric (COSINE, IP) the returned value is used directly as
    # the similarity, per the Milvus metric documentation.
    _DISTANCE_METRICS = frozenset({"L2", "HAMMING", "JACCARD"})

    @staticmethod
    def _default_search_params() -> Dict[str, Any]:
        """Return a fresh copy of the search parameters used when none are given."""
        return {"metric_type": "COSINE", "params": {"ef": 64}}

    @staticmethod
    def _build_filter_expr(
        filters: Optional[Dict[str, Union[str, int, float]]],
    ) -> Optional[str]:
        """Join ``field == value`` clauses with ``and``; ``None`` when there are no filters.

        String values are double-quoted with backslashes and double quotes
        escaped so a value cannot terminate the literal early.  Field names
        are inserted verbatim and must come from trusted code.
        """
        if not filters:
            return None
        clauses = []
        for field, value in filters.items():
            if isinstance(value, str):
                escaped = value.replace("\\", "\\\\").replace('"', '\\"')
                clauses.append(f'{field} == "{escaped}"')
            else:
                clauses.append(f"{field} == {value}")
        return " and ".join(clauses)

    # Match by vector.
    async def vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Run a vector search with an optional filter expression.

        Args:
            query_vec: Query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits requested.
            expr: Optional filter expression passed through unchanged.
            search_params: Search parameters; defaults to COSINE with ``ef=64``.

        Returns:
            ``{"results": [...]}`` where the list is produced by
            ``hits_to_json``.
        """
        if search_params is None:
            search_params = self._default_search_params()

        response = await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )
        return {"results": self.hits_to_json(response)}

    # Hybrid search (vector + metadata).
    async def hybrid_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        filters: Optional[Dict[str, Union[str, int, float]]] = None,
    ):
        """Run a vector search restricted by equality filters on metadata fields.

        Args:
            query_vec: Query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits requested.
            filters: Mapping of field name to required value; all clauses are
                combined with ``and``.

        Returns:
            The list of hit dicts produced by ``vector_search``.
        """
        response = await self.vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=limit,
            expr=self._build_filter_expr(filters),
        )
        # vector_search already converted the hits; converting its wrapper
        # dict a second time would fail, so just unwrap it.
        return response["results"]

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict[str, float],
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Search several vector fields concurrently and fuse the weighted scores.

        Args:
            query_vec: Query embedding, used for every field.
            weight_map: Mapping of vector field name to its weight.
            limit: Hits requested per field and maximum size of the result.
            expr: Optional filter expression applied to every sub-search.
            search_params: Search parameters for every sub-search; defaults
                to COSINE with ``ef=64``.

        Returns:
            Up to ``limit`` dicts with ``pk``, ``doc_id``, ``chunk_id`` and the
            fused ``score``, best first.
        """
        if search_params is None:
            search_params = self._default_search_params()
        metric = str(search_params.get("metric_type", "")).upper()
        smaller_is_closer = metric in self._DISTANCE_METRICS

        async def _sub_search(field):
            return await asyncio.to_thread(
                self.milvus_pool.search,
                data=[query_vec],
                anns_field=field,
                # Each concurrent call gets its own top-level dict.
                param=dict(search_params),
                limit=limit,
                expr=expr,
                output_fields=self.output_fields,
            )

        fields = list(weight_map)
        results = await asyncio.gather(*(_sub_search(field) for field in fields))

        scores = {}
        for field, res in zip(fields, results):
            if not res:
                continue
            weight = weight_map[field]
            for hit in res[0]:
                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
                # Similarity metrics are already "larger is better"; only true
                # distances are flipped so that closer hits score higher.
                if smaller_is_closer:
                    sim_score = 1 - hit.distance
                else:
                    sim_score = hit.distance
                scores[key] = scores.get(key, 0) + weight * sim_score

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]

        return [
            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": score}
            for key, score in ranked
        ]
|
|
|
+
|
|
|
+
|
|
|
class MilvusQuery(MilvusBase):
    """Scalar (non-vector) lookups that run ``self.milvus_pool.query`` in a worker thread."""

    # Fields whose values are converted to plain lists in returned rows.
    _LIST_FIELDS = frozenset({"entities", "concepts", "questions", "keywords"})

    @staticmethod
    def _quote(value: str) -> str:
        """Return ``value`` as a double-quoted literal with ``\\`` and ``"`` escaped."""
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    @classmethod
    def _rows_to_json(cls, rows: Any) -> List[Dict[str, Any]]:
        """Convert a query response into a list of plain dicts.

        NOTE(review): assumes the pool's ``query`` returns a list of dict rows,
        as pymilvus does -- confirm for the pool wrapper in use.  Anything that
        is not a list of dicts is handed to ``hits_to_json`` as before.
        """
        if not rows:
            return []
        if not isinstance(rows[0], dict):
            return cls.hits_to_json(rows)

        converted = []
        for row in rows:
            item = {}
            # Mirror the "pk" key of search results; ``id`` is the field
            # get_by_id filters on as the primary key.
            if "id" in row:
                item["pk"] = row["id"]
            for key, value in row.items():
                if key in cls._LIST_FIELDS and value is not None:
                    value = list(value)
                item[key] = value
            converted.append(item)
        return converted

    async def _run_query(self, expr: str) -> List[Dict[str, Any]]:
        """Run ``expr`` through the pool's ``query`` and convert the rows."""
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )
        return self._rows_to_json(response)

    # Fetch data by doc_id + chunk_id.
    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
        """Return the rows matching both ``doc_id`` and ``chunk_id``.

        Raises:
            ValueError, TypeError: If ``chunk_id`` cannot be converted to int.
        """
        expr = f"doc_id == {self._quote(doc_id)} and chunk_id == {int(chunk_id)}"
        return await self._run_query(expr)

    # Query by metadata conditions only.
    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
        """Return the rows matching every ``field == value`` pair in ``filters``.

        String values are quoted and escaped; field names are inserted
        verbatim and must come from trusted code.
        """
        clauses = []
        for field, value in filters.items():
            if isinstance(value, str):
                clauses.append(f"{field} == {self._quote(value)}")
            else:
                clauses.append(f"{field} == {value}")
        return await self._run_query(" and ".join(clauses))

    # Fetch Milvus data by primary key.
    async def get_by_id(self, pk: int):
        """Return the rows whose ``id`` field equals ``pk``.

        Raises:
            ValueError, TypeError: If ``pk`` cannot be converted to int.
        """
        return await self._run_query(f"id == {int(pk)}")
|