hybrid_search.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import List, Dict, Optional, Any
  2. from .base_search import BaseSearch
  3. from applications.utils.neo4j import AsyncGraphExpansion
  4. from applications.utils.elastic_search import ElasticSearchStrategy
  5. class HybridSearch(BaseSearch):
  6. def __init__(self, milvus_pool, es_pool, graph_pool):
  7. super().__init__(milvus_pool, es_pool)
  8. self.es_strategy = ElasticSearchStrategy(self.es_pool)
  9. self.graph_expansion = AsyncGraphExpansion(driver=graph_pool)
  10. async def hybrid_search(
  11. self,
  12. filters: Dict[str, Any], # 条件过滤
  13. query_vec: List[float], # query 的向量
  14. anns_field: str = "vector_text", # query指定的向量空间
  15. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  16. query_text: str = None, # 是否通过 topic 倒排
  17. _source=False, # 是否返回元数据
  18. es_size: int = 10000, # es 第一层过滤数量
  19. sort_by: str = None, # 排序
  20. milvus_size: int = 10, # milvus粗排返回数量
  21. ):
  22. milvus_ids = await self.es_strategy.base_search(
  23. filters=filters,
  24. text_query=query_text,
  25. _source=_source,
  26. size=es_size,
  27. sort_by=sort_by,
  28. )
  29. if not milvus_ids:
  30. return {"results": []}
  31. milvus_ids_list = ",".join(milvus_ids)
  32. expr = f"id in [{milvus_ids_list}]"
  33. return await self.base_vector_search(
  34. query_vec=query_vec,
  35. anns_field=anns_field,
  36. limit=milvus_size,
  37. expr=expr,
  38. search_params=search_params,
  39. )
  40. async def hybrid_search_with_graph(
  41. self,
  42. filters: Dict[str, Any], # 条件过滤
  43. query_vec: List[float], # query 的向量
  44. anns_field: str = "vector_text", # query指定的向量空间
  45. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  46. query_text: str = None, # 是否通过 topic 倒排
  47. _source=False, # 是否返回元数据
  48. es_size: int = 10000, # es 第一层过滤数量
  49. sort_by: str = None, # 排序
  50. milvus_size: int = 10, # milvus粗排返回数量
  51. co_occurrence_fields: Dict[str, Any] = None, # 共现字段
  52. shortest_path_fields: Dict[str, Any] = None, # 最短之间的 chunks
  53. ):
  54. # step1, use elastic_search to filter chunks
  55. es_milvus_ids = await self.es_strategy.base_search(
  56. filters=filters,
  57. text_query=query_text,
  58. _source=_source,
  59. size=es_size,
  60. sort_by=sort_by,
  61. )
  62. # step2, use graph to get co_occurrence chunks
  63. if not co_occurrence_fields:
  64. co_occurrence_ids = []
  65. else:
  66. # 测试版本先只用实体
  67. seed_label = "Entity"
  68. name = co_occurrence_fields.get(seed_label)
  69. if not name:
  70. co_occurrence_ids = []
  71. else:
  72. co_occurrence_ids = await self.graph_expansion.co_occurrence(
  73. seed_name=name, seed_label=seed_label
  74. )
  75. # step3, 查询两个 chunk 之间的chunks
  76. if not shortest_path_fields:
  77. shortest_path_ids = []
  78. else:
  79. shortest_path_ids = await self.graph_expansion.shortest_path_chunks(
  80. a_name=shortest_path_fields.get("a_name"),
  81. a_label=shortest_path_fields.get("a_label"),
  82. b_name=shortest_path_fields.get("b_name"),
  83. b_label=shortest_path_fields.get("b_label"),
  84. )
  85. print("es:", es_milvus_ids)
  86. print("co:", co_occurrence_ids)
  87. print("shortest:", shortest_path_ids)
  88. # step3, merge 上述 ids
  89. final_milvus_ids = list(
  90. set(shortest_path_ids + co_occurrence_ids + es_milvus_ids)
  91. )
  92. # step4, 通过向量获取候选集
  93. if not final_milvus_ids:
  94. return {"results": []}
  95. milvus_ids_list = ",".join(final_milvus_ids)
  96. expr = f"id in [{milvus_ids_list}]"
  97. return await self.base_vector_search(
  98. query_vec=query_vec,
  99. anns_field=anns_field,
  100. limit=milvus_size,
  101. expr=expr,
  102. search_params=search_params,
  103. )
  104. async def expand_with_graph(
  105. self, milvus_ids: List[int], limit: int
  106. ) -> List[Dict[str, Any]]:
  107. """拓展字段"""
  108. expanded_chunks = await self.graph_expansion.expand_candidates(
  109. seed_ids=milvus_ids, limit=limit
  110. )
  111. return expanded_chunks