|
@@ -0,0 +1,135 @@
|
|
|
+import asyncio
|
|
|
+from typing import List, Optional, Dict, Any, Union
|
|
|
+
|
|
|
+
|
|
|
class MilvusSearcher:
    """Async search/query helper built on top of a Milvus pool object.

    Every call to the pool (``search`` / ``query``) is dispatched through
    ``asyncio.to_thread`` so it runs in a worker thread instead of on the
    event loop (presumably because the pool API is blocking).
    """

    # Entity fields requested from the pool on every search/query call.
    output_fields = [
        "doc_id",
        "chunk_id",
        "summary",
        "topic",
        "domain",
        "task_type",
        "keywords",
        "concepts",
        "questions",
        "entities",
        "tokens",
        "topic_purity",
    ]

    def __init__(self, milvus_pool):
        """Store the pool object that provides ``search`` and ``query``."""
        self.milvus_pool = milvus_pool

    @staticmethod
    def _default_search_params() -> Dict[str, Any]:
        """Return a fresh copy of the default search parameters."""
        return {"metric_type": "COSINE", "params": {"ef": 64}}

    @staticmethod
    def _quote(value: str) -> str:
        """Render *value* as a double-quoted filter-expression string literal.

        Backslashes and double quotes are escaped so the value cannot
        terminate the literal early and change the expression.
        """
        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    @classmethod
    def _build_expr(cls, filters: Dict[str, Union[str, int, float]]) -> str:
        """Join ``field == value`` conditions with ``and``.

        String values are quoted and escaped; any other value is
        interpolated as-is. An empty mapping yields an empty string.
        """
        return " and ".join(
            f"{field} == {cls._quote(value) if isinstance(value, str) else value}"
            for field, value in filters.items()
        )

    # Match by vector similarity.
    async def vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Run a vector search, optionally restricted by a filter expression.

        Args:
            query_vec: The query vector; it is sent as a single-item batch.
            anns_field: Name of the vector field to search against.
            limit: Value passed to the pool as ``limit``.
            expr: Optional filter expression passed through unchanged.
            search_params: Search parameters; defaults to COSINE with
                ``ef=64`` when omitted.

        Returns:
            Whatever ``milvus_pool.search`` returns.
        """
        if search_params is None:
            search_params = self._default_search_params()
        return await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )

    # Fetch rows by doc_id + chunk_id.
    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
        """Query the rows matching both ``doc_id`` and ``chunk_id``.

        ``doc_id`` is escaped before being embedded in the expression.

        Returns:
            Whatever ``milvus_pool.query`` returns.
        """
        expr = f"doc_id == {self._quote(str(doc_id))} and chunk_id == {chunk_id}"
        return await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )

    # Query by metadata conditions only.
    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
        """Query rows whose fields equal every value given in ``filters``.

        Args:
            filters: Mapping of field name to required value; conditions are
                combined with ``and``. String values are quoted and escaped.

        Returns:
            Whatever ``milvus_pool.query`` returns.
        """
        return await asyncio.to_thread(
            self.milvus_pool.query,
            expr=self._build_expr(filters),
            output_fields=self.output_fields,
        )

    # Hybrid search (vector + metadata).
    async def hybrid_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        filters: Optional[Dict[str, Union[str, int, float]]] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Vector search restricted by equality filters on metadata fields.

        Args:
            query_vec: The query vector.
            anns_field: Name of the vector field to search against.
            limit: Value passed to the pool as ``limit``.
            filters: Optional field -> value mapping; when empty or ``None``
                no filter expression is sent.
            search_params: Optional search parameters forwarded to
                ``vector_search`` (which applies its default when ``None``).

        Returns:
            Whatever ``vector_search`` returns.
        """
        expr = self._build_expr(filters) if filters else None
        return await self.vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=limit,
            expr=expr,
            search_params=search_params,
        )

    # Fetch a Milvus row by primary key.
    async def get_by_id(self, pk: int):
        """Query the row(s) whose ``id`` field equals ``pk``.

        Returns:
            Whatever ``milvus_pool.query`` returns.
        """
        return await asyncio.to_thread(
            self.milvus_pool.query,
            expr=f"id == {pk}",
            output_fields=self.output_fields,
        )

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict,
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Search several vector fields concurrently and fuse weighted scores.

        One search is issued per key of ``weight_map``. Each hit contributes
        ``weight * (1 - hit.distance)`` to the entry identified by
        ``(hit.id, doc_id, chunk_id)``; contributions from different fields
        are summed.

        Args:
            query_vec: The query vector used for every field.
            weight_map: Mapping of vector field name to its weight.
            limit: Per-field search limit and maximum number of fused results.
            expr: Optional filter expression applied to every sub-search.
            search_params: Search parameters for every sub-search; defaults
                to COSINE with ``ef=64`` when omitted.

        Returns:
            A list of at most ``limit`` dicts with keys ``pk``, ``doc_id``,
            ``chunk_id`` and ``score``, ordered by descending score.
        """
        fields = list(weight_map)

        async def _sub_search(field):
            # Honour caller-supplied parameters; otherwise build a fresh
            # default dict for each call so no dict is shared across threads.
            param = (
                search_params
                if search_params is not None
                else self._default_search_params()
            )
            return await asyncio.to_thread(
                self.milvus_pool.search,
                data=[query_vec],
                anns_field=field,
                param=param,
                limit=limit,
                expr=expr,
                output_fields=self.output_fields,
            )

        # gather() preserves argument order, so results line up with `fields`.
        results = await asyncio.gather(*(_sub_search(field) for field in fields))

        scores = {}
        for field, res in zip(fields, results):
            weight = weight_map[field]
            # Only one query vector was sent, so only the first hit list is read.
            for hit in res[0]:
                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
                # NOTE(review): with Milvus' COSINE metric the reported distance
                # is reportedly already a similarity (larger = closer); if that
                # holds for this pool, `1 - distance` inverts the ranking.
                # Confirm against the pool / Milvus version before changing.
                sim_score = 1 - hit.distance
                scores[key] = scores.get(key, 0) + weight * sim_score

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]

        return [
            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": score}
            for key, score in ranked
        ]
|