123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- import asyncio
- from typing import List, Optional, Dict, Any, Union
class MilvusBase:
    """Common base for the Milvus access classes.

    Stores the object used to reach Milvus and provides the helper that turns
    raw search/query responses into plain JSON-friendly dictionaries.
    """

    # Fields requested from Milvus on every search/query call. Further field
    # names are kept here commented out; uncomment one to request it as well.
    output_fields = [
        "doc_id",
        "chunk_id",
        # "summary",
        # "topic",
        # "domain",
        # "task_type",
        # "keywords",
        # "concepts",
        # "questions",
        # "entities",
        # "tokens",
        # "topic_purity",
    ]

    def __init__(self, milvus_pool):
        """Keep a reference to the object whose ``search``/``query`` are called.

        Args:
            milvus_pool: Object exposing the ``search`` and ``query`` callables
                used by the subclasses.
        """
        self.milvus_pool = milvus_pool

    @staticmethod
    def hits_to_json(hits: Any) -> List[Dict[str, Any]]:
        """Convert a Milvus response into a list of plain dictionaries.

        Two response shapes are accepted:

        * Search shape: ``hits[0]`` is an iterable of hit objects exposing
          ``id``, ``distance`` and ``get("entity", default)``. Each hit becomes
          ``{"pk": id, "score": distance, **entity_fields}``.
        * Row shape: ``hits`` is a flat list of ``dict`` rows. Each row is
          copied as-is, which also makes the conversion idempotent (feeding
          the output of this function back in returns an equal list).

        In both shapes the values of ``entities``, ``concepts``, ``questions``
        and ``keywords`` are converted to ``list``.

        Args:
            hits: Raw response, or any falsy value for "no results".

        Returns:
            A list of dictionaries; empty when ``hits`` is falsy.
        """
        if not hits:
            return []

        list_fields = {"entities", "concepts", "questions", "keywords"}

        def plain(fields):
            # Copy the mapping, forcing the multi-valued fields to real lists.
            return {
                name: list(value) if name in list_fields else value
                for name, value in fields.items()
            }

        first = hits[0]
        if isinstance(first, dict):
            # Flat rows: iterating ``first`` would yield its keys (strings),
            # and ``key.id`` would raise AttributeError, so handle it here.
            return [plain(row) for row in hits]

        return [
            {
                "pk": hit.id,
                "score": hit.distance,
                # ``or {}`` guards against a hit whose entity is None.
                **plain(hit.get("entity", {}) or {}),
            }
            for hit in first
        ]
class MilvusSearch(MilvusBase):
    """Vector-similarity searches issued through ``self.milvus_pool.search``.

    The blocking ``search`` callable is always run via ``asyncio.to_thread``
    so the coroutines here do not block the event loop.
    """

    @staticmethod
    def _default_search_params() -> Dict[str, Any]:
        """Return a fresh copy of the search parameters used by default."""
        return {"metric_type": "COSINE", "params": {"ef": 64}}

    # Match by vector
    async def vector_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Run a vector search with an optional filter expression.

        Args:
            query_vec: The query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits requested.
            expr: Optional filter expression passed through to ``search``.
            search_params: Search parameters; defaults to COSINE with
                ``ef=64`` when omitted.

        Returns:
            The response converted by ``hits_to_json``.
        """
        if search_params is None:
            search_params = self._default_search_params()
        response = await asyncio.to_thread(
            self.milvus_pool.search,
            data=[query_vec],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=self.output_fields,
        )
        return self.hits_to_json(response)

    # Hybrid search (vector + metadata)
    async def hybrid_search(
        self,
        query_vec: List[float],
        anns_field: str = "vector_text",
        limit: int = 5,
        filters: Optional[Dict[str, Union[str, int, float]]] = None,
    ):
        """Run a vector search restricted by equality filters on metadata.

        Args:
            query_vec: The query embedding.
            anns_field: Name of the vector field to search.
            limit: Maximum number of hits requested.
            filters: Mapping of field name to required value. All conditions
                are combined with ``and``; string values are double-quoted.

        Returns:
            The converted hit list produced by ``vector_search``.
        """
        expr = None
        if filters:
            # NOTE(review): values are interpolated verbatim, so a string
            # containing a double quote yields a malformed expression; pass
            # only trusted values until escaping rules are confirmed.
            expr = " and ".join(
                f'{field} == "{value}"'
                if isinstance(value, str)
                else f"{field} == {value}"
                for field, value in filters.items()
            )
        # ``vector_search`` already returns converted dictionaries. Converting
        # them a second time used to fail on every non-empty result.
        return await self.vector_search(
            query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
        )

    async def search_by_strategy(
        self,
        query_vec: List[float],
        weight_map: Dict,
        limit: int = 5,
        expr: Optional[str] = None,
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Search several vector fields concurrently and merge weighted scores.

        One search per key of ``weight_map`` is issued with the same query
        vector. For every hit, ``weight * (1 - hit.distance)`` is accumulated
        under the key ``(id, doc_id, chunk_id)``; the entries with the highest
        totals are returned.

        Args:
            query_vec: The query embedding, reused for every field.
            weight_map: Mapping of vector field name to its weight.
            limit: Per-field hit limit and maximum size of the merged result.
            expr: Optional filter expression applied to every sub-search.
            search_params: Search parameters for every sub-search; defaults to
                COSINE with ``ef=64`` when omitted.

        Returns:
            A list of ``{"pk", "doc_id", "chunk_id", "score"}`` dictionaries
            sorted by descending accumulated score.
        """
        if search_params is None:
            search_params = self._default_search_params()

        async def _sub_search(field):
            return await asyncio.to_thread(
                self.milvus_pool.search,
                data=[query_vec],
                anns_field=field,
                # Shallow copy so concurrent worker threads do not share the
                # same top-level dict object.
                param=dict(search_params),
                limit=limit,
                expr=expr,
                output_fields=self.output_fields,
            )

        # Snapshot the keys once so searches and weights stay aligned.
        fields = list(weight_map)
        results = await asyncio.gather(*(_sub_search(field) for field in fields))

        scores = {}
        for field, res in zip(fields, results):
            weight = weight_map[field]
            for hit in res[0]:
                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
                # NOTE(review): Milvus documents COSINE results as a
                # similarity (larger = closer). If so, ``1 - distance``
                # combined with the descending sort below favours the least
                # similar hits. Confirm the intended direction.
                sim_score = 1 - hit.distance
                scores[key] = scores.get(key, 0) + weight * sim_score

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]
        return [
            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": total}
            for key, total in ranked
        ]
class MilvusQuery(MilvusBase):
    """Scalar (non-vector) lookups issued through ``self.milvus_pool.query``."""

    async def _run_query(self, expr: str):
        """Run ``query`` with ``expr`` in a worker thread and convert the result.

        The blocking ``query`` callable is executed via ``asyncio.to_thread``
        and its response is passed to ``hits_to_json``.
        """
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )
        return self.hits_to_json(response)

    # Fetch data by doc_id + chunk_id
    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
        """Return the rows matching both ``doc_id`` and ``chunk_id``.

        Args:
            doc_id: Value compared against the ``doc_id`` field (quoted).
            chunk_id: Value compared against the ``chunk_id`` field.
        """
        return await self._run_query(
            f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
        )

    # Query by metadata conditions only
    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
        """Return the rows matching every equality condition in ``filters``.

        Args:
            filters: Mapping of field name to required value. Conditions are
                combined with ``and``; string values are double-quoted.
        """
        # NOTE(review): values are interpolated verbatim, so a string
        # containing a double quote yields a malformed expression.
        expr = " and ".join(
            f'{field} == "{value}"'
            if isinstance(value, str)
            else f"{field} == {value}"
            for field, value in filters.items()
        )
        return await self._run_query(expr)

    # Fetch Milvus data by primary key
    async def get_by_id(self, pk: int):
        """Return the rows whose ``id`` field equals ``pk``."""
        return await self._run_query(f"id == {pk}")
|