search.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import asyncio
  2. from typing import List, Optional, Dict, Any, Union
  3. class MilvusSearcher:
  4. output_fields = [
  5. "doc_id",
  6. "chunk_id",
  7. "summary",
  8. "topic",
  9. "domain",
  10. "task_type",
  11. "keywords",
  12. "concepts",
  13. "questions",
  14. "entities",
  15. "tokens",
  16. "topic_purity",
  17. ]
  18. def __init__(self, milvus_pool):
  19. self.milvus_pool = milvus_pool
  20. # 通过向量匹配
  21. async def vector_search(
  22. self,
  23. query_vec: List[float],
  24. anns_field: str = "vector_text",
  25. limit: int = 5,
  26. expr: Optional[str] = None,
  27. search_params: Optional[Dict[str, Any]] = None,
  28. ):
  29. """向量搜索,可选过滤"""
  30. if search_params is None:
  31. search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
  32. return await asyncio.to_thread(
  33. self.milvus_pool.search,
  34. data=[query_vec],
  35. anns_field=anns_field,
  36. param=search_params,
  37. limit=limit,
  38. expr=expr,
  39. output_fields=self.output_fields,
  40. )
  41. # 通过doc_id + chunk_id 获取数据
  42. async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
  43. expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
  44. return await asyncio.to_thread(
  45. self.milvus_pool.query,
  46. expr=expr,
  47. output_fields=self.output_fields,
  48. )
  49. # 只按 metadata 条件查询
  50. async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
  51. exprs = []
  52. for k, v in filters.items():
  53. if isinstance(v, str):
  54. exprs.append(f'{k} == "{v}"')
  55. else:
  56. exprs.append(f"{k} == {v}")
  57. expr = " and ".join(exprs)
  58. return await asyncio.to_thread(
  59. self.milvus_pool.query,
  60. expr=expr,
  61. output_fields=self.output_fields,
  62. )
  63. # 混合搜索(向量 + metadata)
  64. async def hybrid_search(
  65. self,
  66. query_vec: List[float],
  67. anns_field: str = "vector_text",
  68. limit: int = 5,
  69. filters: Optional[Dict[str, Union[str, int, float]]] = None,
  70. ):
  71. expr = None
  72. if filters:
  73. parts = []
  74. for k, v in filters.items():
  75. if isinstance(v, str):
  76. parts.append(f'{k} == "{v}"')
  77. else:
  78. parts.append(f"{k} == {v}")
  79. expr = " and ".join(parts)
  80. return await self.vector_search(
  81. query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
  82. )
  83. # 通过主键获取milvus数据
  84. async def get_by_id(self, pk: int):
  85. return await asyncio.to_thread(
  86. self.milvus_pool.query,
  87. expr=f"id == {pk}",
  88. output_fields=self.output_fields,
  89. )
  90. async def search_by_strategy(
  91. self,
  92. query_vec: List[float],
  93. weight_map: Dict,
  94. limit: int = 5,
  95. expr: Optional[str] = None,
  96. search_params: Optional[Dict[str, Any]] = None,
  97. ):
  98. async def _sub_search(vec, field):
  99. return await asyncio.to_thread(
  100. self.milvus_pool.search,
  101. data=[vec],
  102. anns_field=field,
  103. param={"metric_type": "COSINE", "params": {"ef": 64}},
  104. limit=limit,
  105. expr=expr,
  106. output_fields=self.output_fields,
  107. )
  108. tasks = {field: _sub_search(query_vec, field) for field in weight_map.keys()}
  109. results = await asyncio.gather(*tasks.values())
  110. scores = {}
  111. for (field, weight), res in zip(weight_map.items(), results):
  112. for hit in res[0]:
  113. key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
  114. sim_score = 1 - hit.distance
  115. scores[key] = scores.get(key, 0) + weight * sim_score
  116. ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:limit]
  117. return [
  118. {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
  119. for k, v in ranked
  120. ]