search.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import asyncio
  2. from typing import List, Optional, Dict, Any, Union
  3. class MilvusBase:
  4. output_fields = [
  5. "doc_id",
  6. "chunk_id",
  7. # "summary",
  8. # "topic",
  9. # "domain",
  10. # "task_type",
  11. # "keywords",
  12. # "concepts",
  13. # "questions",
  14. # "entities",
  15. # "tokens",
  16. # "topic_purity",
  17. ]
  18. def __init__(self, milvus_pool):
  19. self.milvus_pool = milvus_pool
  20. @staticmethod
  21. def hits_to_json(hits):
  22. if not hits:
  23. return []
  24. special_keys = {"entities", "concepts", "questions", "keywords"}
  25. return [
  26. {
  27. "pk": hit.id,
  28. "score": hit.distance,
  29. **{
  30. key: list(value) if key in special_keys else value
  31. for key, value in (hit.get("entity", {}) or {}).items()
  32. }
  33. }
  34. for hit in hits[0]
  35. ]
  36. class MilvusSearch(MilvusBase):
  37. # 通过向量匹配
  38. async def vector_search(
  39. self,
  40. query_vec: List[float],
  41. anns_field: str = "vector_text",
  42. limit: int = 5,
  43. expr: Optional[str] = None,
  44. search_params: Optional[Dict[str, Any]] = None,
  45. ):
  46. """向量搜索,可选过滤"""
  47. if search_params is None:
  48. search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
  49. response = await asyncio.to_thread(
  50. self.milvus_pool.search,
  51. data=[query_vec],
  52. anns_field=anns_field,
  53. param=search_params,
  54. limit=limit,
  55. expr=expr,
  56. output_fields=self.output_fields,
  57. )
  58. res = self.hits_to_json(response)
  59. return res
  60. # 混合搜索(向量 + metadata)
  61. async def hybrid_search(
  62. self,
  63. query_vec: List[float],
  64. anns_field: str = "vector_text",
  65. limit: int = 5,
  66. filters: Optional[Dict[str, Union[str, int, float]]] = None,
  67. ):
  68. expr = None
  69. if filters:
  70. parts = []
  71. for k, v in filters.items():
  72. if isinstance(v, str):
  73. parts.append(f'{k} == "{v}"')
  74. else:
  75. parts.append(f"{k} == {v}")
  76. expr = " and ".join(parts)
  77. response = await self.vector_search(
  78. query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
  79. )
  80. return self.hits_to_json(response)
  81. async def search_by_strategy(
  82. self,
  83. query_vec: List[float],
  84. weight_map: Dict,
  85. limit: int = 5,
  86. expr: Optional[str] = None,
  87. search_params: Optional[Dict[str, Any]] = None,
  88. ):
  89. async def _sub_search(vec, field):
  90. return await asyncio.to_thread(
  91. self.milvus_pool.search,
  92. data=[vec],
  93. anns_field=field,
  94. param={"metric_type": "COSINE", "params": {"ef": 64}},
  95. limit=limit,
  96. expr=expr,
  97. output_fields=self.output_fields,
  98. )
  99. tasks = {field: _sub_search(query_vec, field) for field in weight_map.keys()}
  100. results = await asyncio.gather(*tasks.values())
  101. scores = {}
  102. for (field, weight), res in zip(weight_map.items(), results):
  103. for hit in res[0]:
  104. key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
  105. sim_score = 1 - hit.distance
  106. scores[key] = scores.get(key, 0) + weight * sim_score
  107. ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:limit]
  108. return [
  109. {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
  110. for k, v in ranked
  111. ]
  112. class MilvusQuery(MilvusBase):
  113. # 通过doc_id + chunk_id 获取数据
  114. async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
  115. expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
  116. response = await asyncio.to_thread(
  117. self.milvus_pool.query,
  118. expr=expr,
  119. output_fields=self.output_fields,
  120. )
  121. return self.hits_to_json(response)
  122. # 只按 metadata 条件查询
  123. async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
  124. exprs = []
  125. for k, v in filters.items():
  126. if isinstance(v, str):
  127. exprs.append(f'{k} == "{v}"')
  128. else:
  129. exprs.append(f"{k} == {v}")
  130. expr = " and ".join(exprs)
  131. response = await asyncio.to_thread(
  132. self.milvus_pool.query,
  133. expr=expr,
  134. output_fields=self.output_fields,
  135. )
  136. print(response)
  137. return self.hits_to_json(response)
  138. # 通过主键获取milvus数据
  139. async def get_by_id(self, pk: int):
  140. response = await asyncio.to_thread(
  141. self.milvus_pool.query,
  142. expr=f"id == {pk}",
  143. output_fields=self.output_fields,
  144. )
  145. return self.hits_to_json(response)