123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import asyncio
- from typing import List, Optional, Dict, Any, Union
class MilvusBase:
    """Common state and result formatting shared by Milvus helper classes."""

    # Field names passed as ``output_fields`` on every search request.
    output_fields = [
        "id",
        "doc_id",
        "chunk_id",
    ]

    def __init__(self, milvus_pool):
        """Keep a reference to the object whose ``search`` method is called.

        Args:
            milvus_pool: Object exposing a ``search`` callable; stored as-is.
        """
        self.milvus_pool = milvus_pool

    @staticmethod
    def hits_to_json(hits):
        """Flatten the first result set of a search response into dicts.

        Only ``hits[0]`` is read, i.e. the hits for the first query vector.

        Args:
            hits: Search response; a falsy value yields an empty list.

        Returns:
            One dict per hit holding ``"pk"`` (``hit.id``), ``"score"``
            (``hit.distance``) and every key of the hit's ``"entity"``
            mapping. Entity values stored under ``entities``, ``concepts``,
            ``questions`` or ``keywords`` are converted with ``list()``.
        """
        if not hits:
            return []

        # Entity fields whose values are materialised as plain lists.
        list_valued = {"entities", "concepts", "questions", "keywords"}

        rows = []
        for hit in hits[0]:
            row = {"pk": hit.id, "score": hit.distance}
            # A missing or ``None`` entity is treated as having no fields.
            entity = hit.get("entity", {}) or {}
            for name, value in entity.items():
                row[name] = list(value) if name in list_valued else value
            rows.append(row)
        return rows
class MilvusSearch(MilvusBase):
    """Async vector search helpers built on ``self.milvus_pool.search``.

    Each request runs the pool's ``search`` callable through
    ``asyncio.to_thread`` so the call does not block the event loop.
    """

    # Metrics for which Milvus reports a similarity (larger means closer)
    # rather than a distance; see the Milvus similarity-metric documentation.
    _SIMILARITY_METRICS = frozenset({"COSINE", "IP"})

    @staticmethod
    def _default_search_params() -> Dict[str, Any]:
        """Return a fresh copy of the default search parameters."""
        return {"metric_type": "COSINE", "params": {"ef": 64}}

    async def _search(
        self,
        query_vec: List[float],
        anns_field: str,
        limit: int,
        expr: Optional[str],
        search_params: Optional[Dict[str, Any]],
    ):
        """Run one search for ``query_vec`` on ``anns_field`` in a worker thread."""
        if search_params is None:
            search_params = self._default_search_params()
        return await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )

    # Coarse search by vector.
    async def base_vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Vector search on a single field, with an optional filter.

        Args:
            query_vec: Query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits to return.
            expr: Optional filter expression forwarded to ``search``.
            search_params: Search parameters forwarded as ``param``; when
                ``None`` the COSINE / ``ef=64`` defaults are used.

        Returns:
            ``{"results": [...]}`` where each element is a dict produced by
            ``hits_to_json``.
        """
        response = await self._search(
            query_vec, anns_field, limit, expr, search_params
        )
        # Slice by the requested limit; the previous fixed cap of 10 silently
        # dropped hits whenever the caller asked for more than 10.
        return {"results": self.hits_to_json(response)[:limit]}

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict,
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Search several vector fields concurrently and fuse the scores.

        One search is issued per key of ``weight_map``. Each hit contributes
        ``weight * similarity`` to the running total of its
        ``(id, doc_id, chunk_id)`` key, and the ``limit`` keys with the
        highest totals are returned.

        Args:
            query_vec: Query embedding, used for every field.
            weight_map: Mapping of vector field name to its weight.
            limit: Per-field hit limit and size of the fused result.
            expr: Optional filter expression forwarded to ``search``.
            search_params: Search parameters forwarded as ``param``; when
                ``None`` the COSINE / ``ef=64`` defaults are used.

        Returns:
            Dicts with ``pk``, ``doc_id``, ``chunk_id`` and fused ``score``,
            ordered by descending score.
        """
        fields = list(weight_map)
        responses = await asyncio.gather(
            *(
                self._search(query_vec, field, limit, expr, search_params)
                for field in fields
            )
        )

        # NOTE(review): when caller-supplied params omit "metric_type" the
        # index's own metric applies; COSINE is assumed here — confirm.
        effective = (
            search_params
            if search_params is not None
            else self._default_search_params()
        )
        metric = str(effective.get("metric_type", "COSINE")).upper()
        higher_is_better = metric in self._SIMILARITY_METRICS

        scores: Dict[tuple, float] = {}
        for field, response in zip(fields, responses):
            if not response:
                # No result set for this field; nothing to accumulate.
                continue
            weight = weight_map[field]
            for hit in response[0]:
                key = (
                    hit.id,
                    hit.entity.get("doc_id"),
                    hit.entity.get("chunk_id"),
                )
                # Similarity metrics are used as-is so that closer matches
                # rank higher; true distances are flipped with ``1 - d``.
                similarity = hit.distance if higher_is_better else 1 - hit.distance
                scores[key] = scores.get(key, 0) + weight * similarity

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [
            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": total}
            for key, total in ranked[:limit]
        ]
|