search.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import asyncio
  2. from typing import List, Optional, Dict, Any, Union
  3. class MilvusBase:
  4. output_fields = [
  5. "doc_id",
  6. "chunk_id",
  7. # "summary",
  8. # "topic",
  9. # "domain",
  10. # "task_type",
  11. # "keywords",
  12. # "concepts",
  13. # "questions",
  14. # "entities",
  15. # "tokens",
  16. # "topic_purity",
  17. ]
  18. def __init__(self, milvus_pool):
  19. self.milvus_pool = milvus_pool
  20. @staticmethod
  21. def hits_to_json(hits):
  22. if not hits:
  23. return []
  24. special_keys = {"entities", "concepts", "questions", "keywords"}
  25. return [
  26. {
  27. "pk": hit.id,
  28. "score": hit.distance,
  29. **{
  30. key: list(value) if key in special_keys else value
  31. for key, value in (hit.get("entity", {}) or {}).items()
  32. }
  33. }
  34. for hit in hits[0]
  35. ]
  36. class MilvusSearch(MilvusBase):
  37. # 通过向量匹配
  38. async def vector_search(
  39. self,
  40. query_vec: List[float],
  41. anns_field: str = "vector_text",
  42. limit: int = 5,
  43. expr: Optional[str] = None,
  44. search_params: Optional[Dict[str, Any]] = None,
  45. ):
  46. """向量搜索,可选过滤"""
  47. if search_params is None:
  48. search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
  49. response = await asyncio.to_thread(
  50. self.milvus_pool.search,
  51. data=[query_vec],
  52. anns_field=anns_field,
  53. param=search_params,
  54. limit=limit,
  55. expr=expr,
  56. output_fields=self.output_fields,
  57. )
  58. return {"results": self.hits_to_json(response)}
  59. # 混合搜索(向量 + metadata)
  60. async def hybrid_search(
  61. self,
  62. query_vec: List[float],
  63. anns_field: str = "vector_text",
  64. limit: int = 5,
  65. filters: Optional[Dict[str, Union[str, int, float]]] = None,
  66. ):
  67. expr = None
  68. if filters:
  69. parts = []
  70. for k, v in filters.items():
  71. if isinstance(v, str):
  72. parts.append(f'{k} == "{v}"')
  73. else:
  74. parts.append(f"{k} == {v}")
  75. expr = " and ".join(parts)
  76. response = await self.vector_search(
  77. query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
  78. )
  79. return self.hits_to_json(response)
  80. async def search_by_strategy(
  81. self,
  82. query_vec: List[float],
  83. weight_map: Dict,
  84. limit: int = 5,
  85. expr: Optional[str] = None,
  86. search_params: Optional[Dict[str, Any]] = None,
  87. ):
  88. async def _sub_search(vec, field):
  89. return await asyncio.to_thread(
  90. self.milvus_pool.search,
  91. data=[vec],
  92. anns_field=field,
  93. param={"metric_type": "COSINE", "params": {"ef": 64}},
  94. limit=limit,
  95. expr=expr,
  96. output_fields=self.output_fields,
  97. )
  98. tasks = {field: _sub_search(query_vec, field) for field in weight_map.keys()}
  99. results = await asyncio.gather(*tasks.values())
  100. scores = {}
  101. for (field, weight), res in zip(weight_map.items(), results):
  102. for hit in res[0]:
  103. key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
  104. sim_score = 1 - hit.distance
  105. scores[key] = scores.get(key, 0) + weight * sim_score
  106. ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:limit]
  107. return [
  108. {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
  109. for k, v in ranked
  110. ]
  111. class MilvusQuery(MilvusBase):
  112. # 通过doc_id + chunk_id 获取数据
  113. async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
  114. expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
  115. response = await asyncio.to_thread(
  116. self.milvus_pool.query,
  117. expr=expr,
  118. output_fields=self.output_fields,
  119. )
  120. return self.hits_to_json(response)
  121. # 只按 metadata 条件查询
  122. async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
  123. exprs = []
  124. for k, v in filters.items():
  125. if isinstance(v, str):
  126. exprs.append(f'{k} == "{v}"')
  127. else:
  128. exprs.append(f"{k} == {v}")
  129. expr = " and ".join(exprs)
  130. response = await asyncio.to_thread(
  131. self.milvus_pool.query,
  132. expr=expr,
  133. output_fields=self.output_fields,
  134. )
  135. print(response)
  136. return self.hits_to_json(response)
  137. # 通过主键获取milvus数据
  138. async def get_by_id(self, pk: int):
  139. response = await asyncio.to_thread(
  140. self.milvus_pool.query,
  141. expr=f"id == {pk}",
  142. output_fields=self.output_fields,
  143. )
  144. return self.hits_to_json(response)