import asyncio
import copy
from typing import List, Optional, Dict, Any, Union


class MilvusBase:
    """Shared state and helpers for Milvus-backed search classes."""

    # Fields requested back from Milvus on every search call.
    output_fields = [
        "id",
        "doc_id",
        "chunk_id",
    ]

    def __init__(self, milvus_pool):
        """Keep a reference to the object whose ``search`` method is called."""
        self.milvus_pool = milvus_pool

    @staticmethod
    def _default_search_params() -> Dict[str, Any]:
        """Return a fresh copy of the default search parameters."""
        return {"metric_type": "COSINE", "params": {"ef": 64}}

    @staticmethod
    def hits_to_json(hits):
        """Convert the hits of the first query in a search response to dicts.

        Args:
            hits: Search response; only ``hits[0]`` is read. A falsy value
                yields an empty list.

        Returns:
            One dict per hit with ``pk`` (``hit.id``), ``score``
            (``hit.distance``) and every key of the hit's ``entity`` mapping.
            Values stored under ``entities``, ``concepts``, ``questions`` and
            ``keywords`` are converted to plain lists.
        """
        if not hits:
            return []

        special_keys = {"entities", "concepts", "questions", "keywords"}
        rows = []
        for hit in hits[0]:
            # A missing or None entity is treated as "no extra fields".
            entity = hit.get("entity", {}) or {}
            row = {"pk": hit.id, "score": hit.distance}
            for key, value in entity.items():
                row[key] = list(value) if key in special_keys else value
            rows.append(row)
        return rows


class MilvusSearch(MilvusBase):
    """Asynchronous vector search on top of a blocking ``search`` call."""

    # Metrics for which Milvus reports a similarity (larger = closer) rather
    # than a distance (smaller = closer).
    _SIMILARITY_METRICS = frozenset({"COSINE", "IP"})

    # Coarse retrieval by vector similarity.
    async def base_vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Run a vector search with an optional filter expression.

        Args:
            query_vec: Query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits requested.
            expr: Optional filter expression forwarded to ``search``.
            search_params: Search parameters; defaults to COSINE with
                ``ef=64`` when omitted.

        Returns:
            ``{"results": [...]}`` where the list comes from
            :meth:`hits_to_json`.
        """
        if search_params is None:
            search_params = self._default_search_params()

        # The pool's search is blocking, so run it in a worker thread.
        response = await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )
        return {"results": self.hits_to_json(response)}

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict,
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Search several vector fields concurrently and fuse the scores.

        Each key of ``weight_map`` is searched with ``query_vec``; every hit
        contributes ``weight * similarity`` to the chunk identified by
        ``(id, doc_id, chunk_id)``. The ``limit`` best chunks are returned.

        Args:
            query_vec: Query embedding used for every field.
            weight_map: Mapping of vector field name to its weight.
            limit: Hits requested per field and size of the fused result.
            expr: Optional filter expression forwarded to ``search``.
            search_params: Search parameters applied to every sub-search;
                defaults to COSINE with ``ef=64`` when omitted.

        Returns:
            List of ``{"pk", "doc_id", "chunk_id", "score"}`` dicts sorted by
            descending fused score.
        """
        if not weight_map:
            return []

        if search_params is None:
            search_params = self._default_search_params()

        # NOTE(review): when the caller omits ``metric_type`` COSINE is
        # assumed here — confirm against the collection's index metric.
        metric = str(search_params.get("metric_type", "COSINE")).upper()
        larger_is_closer = metric in self._SIMILARITY_METRICS

        fields = list(weight_map)

        async def _sub_search(field):
            return await asyncio.to_thread(
                self.milvus_pool.search,
                data=[query_vec],
                anns_field=field,
                # Each worker thread gets its own copy so concurrent calls
                # never share one mutable params dict.
                param=copy.deepcopy(search_params),
                limit=limit,
                expr=expr,
                output_fields=self.output_fields,
            )

        responses = await asyncio.gather(*(_sub_search(f) for f in fields))

        scores: Dict[Any, float] = {}
        for field, response in zip(fields, responses):
            if not response:
                # Empty response for this field: nothing to accumulate.
                continue
            weight = weight_map[field]
            for hit in response[0]:
                key = (
                    hit.id,
                    hit.entity.get("doc_id"),
                    hit.entity.get("chunk_id"),
                )
                # COSINE/IP already report "larger is closer"; for
                # distance-style metrics flip the sign so that a higher
                # fused score always means a better match.
                if larger_is_closer:
                    similarity = hit.distance
                else:
                    similarity = 1 - hit.distance
                scores[key] = scores.get(key, 0) + weight * similarity

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [
            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": score}
            for key, score in ranked[:limit]
        ]