import asyncio
import logging
from collections.abc import Mapping
from typing import List, Optional, Dict, Any, Union

logger = logging.getLogger(__name__)

# Metrics for which Milvus reports a similarity (larger value == closer match)
# in ``hit.distance``; see the Milvus metric-type documentation.
_SIMILARITY_METRICS = frozenset({"COSINE", "IP"})


def _default_search_params() -> Dict[str, Any]:
    """Return a fresh copy of the default search parameters.

    A new dict is built on every call so callers can never mutate a shared
    default.
    """
    return {"metric_type": "COSINE", "params": {"ef": 64}}


class MilvusBase:
    """Common state and helpers shared by the Milvus search/query wrappers."""

    # Fields requested through ``output_fields`` on every search/query call.
    # The commented-out names are additional fields that are currently disabled.
    output_fields = [
        "doc_id",
        "chunk_id",
        # "summary",
        # "topic",
        # "domain",
        # "task_type",
        # "keywords",
        # "concepts",
        # "questions",
        # "entities",
        # "tokens",
        # "topic_purity",
    ]

    # Field names whose values are converted to plain ``list`` objects.
    _LIST_FIELDS = frozenset({"entities", "concepts", "questions", "keywords"})

    def __init__(self, milvus_pool):
        """Store the pool object whose ``search``/``query`` methods are called."""
        self.milvus_pool = milvus_pool

    @staticmethod
    def _convert_fields(fields) -> Dict[str, Any]:
        """Copy *fields* into a dict, turning the list-like fields into lists."""
        return {
            key: list(value) if key in MilvusBase._LIST_FIELDS else value
            for key, value in fields.items()
        }

    @staticmethod
    def _quote(value: Any) -> str:
        """Return *value* as a double-quoted, escaped expression string literal.

        Backslashes and double quotes are escaped so that a value cannot
        terminate the literal early and alter the filter expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    @staticmethod
    def _build_expr(
        filters: Optional[Dict[str, Union[str, int, float]]],
    ) -> Optional[str]:
        """Build an ``a == x and b == y`` filter expression from *filters*.

        String values are quoted and escaped; other values are interpolated
        as-is. Returns ``None`` when *filters* is empty or ``None``.
        """
        if not filters:
            return None
        clauses = [
            f"{field} == {MilvusBase._quote(value)}"
            if isinstance(value, str)
            else f"{field} == {value}"
            for field, value in filters.items()
        ]
        return " and ".join(clauses)

    @staticmethod
    def hits_to_json(hits):
        """Convert a Milvus response into a list of plain dicts.

        Two response shapes are handled:

        * Search-style: ``hits[0]`` is an iterable of hit objects exposing
          ``id``, ``distance`` and ``get("entity")``. Each hit becomes
          ``{"pk": ..., "score": ..., **entity_fields}``.
        * Query-style: ``hits`` is a flat sequence of row mappings. Each row
          is copied into a dict (no ``pk``/``score`` keys are added).
          NOTE(review): pymilvus ``query()`` returns rows in this form;
          confirm that ``milvus_pool.query`` does as well.

        Returns an empty list for an empty or ``None`` response.
        """
        if not hits:
            return []

        first = hits[0]
        if isinstance(first, Mapping):
            # Query rows: iterating ``hits[0]`` would walk the keys of the
            # first row, so convert every row directly instead.
            return [MilvusBase._convert_fields(row) for row in hits]

        # Only the first result set is used: callers send a single query vector.
        return [
            {
                "pk": hit.id,
                "score": hit.distance,
                **MilvusBase._convert_fields(hit.get("entity", {}) or {}),
            }
            for hit in first
        ]


class MilvusSearch(MilvusBase):
    """Vector-based searches executed through ``milvus_pool.search``."""

    async def vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Run a vector search with an optional filter expression.

        Args:
            query_vec: The query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits to request.
            expr: Optional filter expression passed through unchanged.
            search_params: Search parameters; defaults to COSINE with
                ``ef=64`` when omitted.

        Returns:
            ``{"results": [...]}`` where the list is produced by
            ``hits_to_json``.
        """
        if search_params is None:
            search_params = _default_search_params()

        # The pool call is blocking, so it runs in a worker thread.
        response = await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )
        return {"results": self.hits_to_json(response)}

    async def hybrid_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        filters: Optional[Dict[str, Union[str, int, float]]] = None,
    ):
        """Run a vector search restricted by equality filters on metadata.

        Args:
            query_vec: The query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits to request.
            filters: Mapping of field name to required value; all conditions
                are combined with ``and``. No filter is applied when empty.

        Returns:
            The list of converted hits (same items as
            ``vector_search(...)["results"]``).
        """
        response = await self.vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=limit,
            expr=self._build_expr(filters),
        )
        # ``vector_search`` has already converted the hits; converting the
        # wrapper dict a second time would fail.
        return response["results"]

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict,
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Search several vector fields concurrently and fuse the scores.

        Each field in *weight_map* is searched with the same query vector;
        every hit contributes ``weight * similarity`` to the entry identified
        by ``(pk, doc_id, chunk_id)``. A hit missing from a field's results
        contributes nothing for that field.

        Args:
            query_vec: The query embedding.
            weight_map: Mapping of vector field name to its weight.
            limit: Per-field hit limit and the size of the fused result.
            expr: Optional filter expression applied to every sub-search.
            search_params: Search parameters; defaults to COSINE with
                ``ef=64`` when omitted.

        Returns:
            Up to *limit* dicts with ``pk``, ``doc_id``, ``chunk_id`` and the
            fused ``score``, best match first.
        """
        if search_params is None:
            search_params = _default_search_params()

        # For COSINE/IP the reported distance already is a similarity (larger
        # is better), so it is used directly. Other metrics keep the
        # ``1 - distance`` conversion so smaller distances rank higher.
        # A missing ``metric_type`` is assumed to be COSINE — TODO confirm
        # against the collection's index.
        metric = str(search_params.get("metric_type", "COSINE")).upper()
        larger_is_better = metric in _SIMILARITY_METRICS

        async def _sub_search(field):
            return await asyncio.to_thread(
                self.milvus_pool.search,
                data=[query_vec],
                anns_field=field,
                param=search_params,
                limit=limit,
                expr=expr,
                output_fields=self.output_fields,
            )

        fields = list(weight_map)
        responses = await asyncio.gather(*(_sub_search(field) for field in fields))

        scores: Dict[tuple, float] = {}
        for field, response in zip(fields, responses):
            if not response:
                continue  # nothing came back for this field
            weight = weight_map[field]
            for hit in response[0]:
                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
                similarity = hit.distance if larger_is_better else 1 - hit.distance
                scores[key] = scores.get(key, 0) + weight * similarity

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [
            {"pk": pk, "doc_id": doc_id, "chunk_id": chunk_id, "score": score}
            for (pk, doc_id, chunk_id), score in ranked[:limit]
        ]


class MilvusQuery(MilvusBase):
    """Scalar (non-vector) lookups executed through ``milvus_pool.query``."""

    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
        """Fetch the rows matching both *doc_id* and *chunk_id*."""
        expr = f"doc_id == {self._quote(doc_id)} and chunk_id == {chunk_id}"
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )
        return self.hits_to_json(response)

    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
        """Fetch rows using only equality conditions on metadata fields.

        All conditions in *filters* are combined with ``and``. An empty
        mapping results in an empty expression string being sent.
        """
        expr = self._build_expr(filters) or ""
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )
        logger.debug("filter_search response: %s", response)
        return self.hits_to_json(response)

    async def get_by_id(self, pk: int):
        """Fetch the row whose primary key ``id`` equals *pk*."""
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=f"id == {pk}",
            output_fields=self.output_fields,
        )
        return self.hits_to_json(response)