search.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import asyncio
  2. from typing import List, Optional, Dict, Any, Union
  3. class MilvusBase:
  4. output_fields = [
  5. "id",
  6. "doc_id",
  7. "chunk_id",
  8. ]
  9. def __init__(self, milvus_pool):
  10. self.milvus_pool = milvus_pool
  11. @staticmethod
  12. def hits_to_json(hits):
  13. if not hits:
  14. return []
  15. special_keys = {"entities", "concepts", "questions", "keywords"}
  16. return [
  17. {
  18. "pk": hit.id,
  19. "score": hit.distance,
  20. **{
  21. key: list(value) if key in special_keys else value
  22. for key, value in (hit.get("entity", {}) or {}).items()
  23. },
  24. }
  25. for hit in hits[0]
  26. ]
  27. class MilvusSearch(MilvusBase):
  28. # 通过向量粗搜索
  29. async def base_vector_search(
  30. self,
  31. query_vec: List[float],
  32. anns_field: str = "vector_text",
  33. limit: int = 5,
  34. expr: Optional[str] = None,
  35. search_params: Optional[Dict[str, Any]] = None,
  36. ):
  37. """向量搜索,可选过滤"""
  38. if search_params is None:
  39. search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
  40. response = await asyncio.to_thread(
  41. self.milvus_pool.search,
  42. data=[query_vec],
  43. anns_field=anns_field,
  44. param=search_params,
  45. limit=limit,
  46. expr=expr,
  47. output_fields=self.output_fields,
  48. )
  49. return {"results": self.hits_to_json(response)[:10]}
  50. async def search_by_strategy(
  51. self,
  52. query_vec: List[float],
  53. weight_map: Dict,
  54. limit: int = 5,
  55. expr: Optional[str] = None,
  56. search_params: Optional[Dict[str, Any]] = None,
  57. ):
  58. async def _sub_search(vec, field):
  59. return await asyncio.to_thread(
  60. self.milvus_pool.search,
  61. data=[vec],
  62. anns_field=field,
  63. param={"metric_type": "COSINE", "params": {"ef": 64}},
  64. limit=limit,
  65. expr=expr,
  66. output_fields=self.output_fields,
  67. )
  68. tasks = {field: _sub_search(query_vec, field) for field in weight_map.keys()}
  69. results = await asyncio.gather(*tasks.values())
  70. scores = {}
  71. for (field, weight), res in zip(weight_map.items(), results):
  72. for hit in res[0]:
  73. key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
  74. sim_score = 1 - hit.distance
  75. scores[key] = scores.get(key, 0) + weight * sim_score
  76. ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:limit]
  77. return [
  78. {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
  79. for k, v in ranked
  80. ]