from typing import List, Dict, Optional, Any

from .base_search import BaseSearch
from applications.utils.neo4j import AsyncGraphExpansion
from applications.utils.elastic_search import ElasticSearchStrategy


class HybridSearch(BaseSearch):
    """Hybrid retrieval: Elasticsearch pre-filtering, optional graph expansion,
    then vector search in Milvus over the merged candidate ids."""

    def __init__(self, milvus_pool, es_pool, graph_pool):
        super().__init__(milvus_pool, es_pool)
        self.es_strategy = ElasticSearchStrategy(self.es_pool)
        self.graph_expansion = AsyncGraphExpansion(driver=graph_pool)

    async def hybrid_search(
        self,
        filters: Dict[str, Any],  # structured filter conditions
        query_vec: List[float],  # query embedding
        anns_field: str = "vector_text",  # vector field to search against
        search_params: Optional[Dict[str, Any]] = None,  # distance metric / search params
        query_text: Optional[str] = None,  # optional full-text query (topic inverted index)
        _source: bool = False,  # whether to return source metadata
        es_size: int = 10000,  # size of the first-stage ES filter
        sort_by: Optional[str] = None,  # sort field
        milvus_size: int = 10,  # number of candidates returned by Milvus coarse ranking
    ):
        """Filter candidate ids with Elasticsearch, then rank them by vector similarity."""
        milvus_ids = await self.es_strategy.base_search(
            filters=filters,
            text_query=query_text,
            _source=_source,
            size=es_size,
            sort_by=sort_by,
        )
        if not milvus_ids:
            return {"results": []}

        # Build a Milvus boolean expression; assumes `id` is an integer primary key,
        # so the ids can be interpolated without quoting.
        milvus_ids_list = ",".join(str(_id) for _id in milvus_ids)
        expr = f"id in [{milvus_ids_list}]"
        return await self.base_vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=milvus_size,
            expr=expr,
            search_params=search_params,
        )

    async def hybrid_search_with_graph(
        self,
        filters: Dict[str, Any],  # structured filter conditions
        query_vec: List[float],  # query embedding
        anns_field: str = "vector_text",  # vector field to search against
        search_params: Optional[Dict[str, Any]] = None,  # distance metric / search params
        query_text: Optional[str] = None,  # optional full-text query (topic inverted index)
        _source: bool = False,  # whether to return source metadata
        es_size: int = 10000,  # size of the first-stage ES filter
        sort_by: Optional[str] = None,  # sort field
        milvus_size: int = 10,  # number of candidates returned by Milvus coarse ranking
        co_occurrence_fields: Optional[Dict[str, Any]] = None,  # co-occurrence seed fields
        shortest_path_fields: Optional[Dict[str, Any]] = None,  # endpoints for shortest-path chunks
    ):
        """Like hybrid_search, but also pulls candidate chunks from the graph
        (co-occurrence and shortest-path expansion) before the vector stage."""
        # Step 1: use Elasticsearch to filter candidate chunks.
        es_milvus_ids = await self.es_strategy.base_search(
            filters=filters,
            text_query=query_text,
            _source=_source,
            size=es_size,
            sort_by=sort_by,
        )

        # Step 2: use the graph to fetch co-occurring chunks.
        if not co_occurrence_fields:
            co_occurrence_ids = []
        else:
            # Test build: only Entity seeds are used for now; the dict is
            # expected to map the label ("Entity") to the seed name.
            seed_label = "Entity"
            name = co_occurrence_fields.get(seed_label)
            if not name:
                co_occurrence_ids = []
            else:
                co_occurrence_ids = await self.graph_expansion.co_occurrence(
                    seed_name=name, seed_label=seed_label
                )

        # Step 3: fetch the chunks lying on the shortest path between two nodes.
        if not shortest_path_fields:
            shortest_path_ids = []
        else:
            shortest_path_ids = await self.graph_expansion.shortest_path_chunks(
                a_name=shortest_path_fields.get("a_name"),
                a_label=shortest_path_fields.get("a_label"),
                b_name=shortest_path_fields.get("b_name"),
                b_label=shortest_path_fields.get("b_label"),
            )

        print("es:", es_milvus_ids)
        print("co:", co_occurrence_ids)
        print("shortest:", shortest_path_ids)

        # Step 4: merge and deduplicate the ids collected above.
        final_milvus_ids = list(
            set(shortest_path_ids + co_occurrence_ids + es_milvus_ids)
        )

        # Step 5: rank the merged candidate set by vector similarity.
        if not final_milvus_ids:
            return {"results": []}

        milvus_ids_list = ",".join(str(_id) for _id in final_milvus_ids)
        expr = f"id in [{milvus_ids_list}]"
        return await self.base_vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=milvus_size,
            expr=expr,
            search_params=search_params,
        )

    async def expand_with_graph(
        self, milvus_ids: List[int], limit: int
    ) -> List[Dict[str, Any]]:
        """Expand the candidate set via graph neighbors of the seed ids."""
        expanded_chunks = await self.graph_expansion.expand_candidates(
            seed_ids=milvus_ids, limit=limit
        )
        return expanded_chunks
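

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal sketch of calling hybrid_search_with_graph, assuming pool objects
# (milvus_pool, es_pool, graph_pool) and an `embed_query` helper that returns
# the query embedding; those names, and the filter/graph field values below,
# are placeholders rather than part of this codebase.
#
#   import asyncio
#
#   async def demo():
#       searcher = HybridSearch(milvus_pool, es_pool, graph_pool)
#       query_vec = await embed_query("hybrid retrieval with graph expansion")
#       results = await searcher.hybrid_search_with_graph(
#           filters={"doc_type": "report"},
#           query_vec=query_vec,
#           query_text="hybrid retrieval",
#           milvus_size=20,
#           co_occurrence_fields={"Entity": "hybrid retrieval"},
#           shortest_path_fields={
#               "a_name": "Milvus", "a_label": "Entity",
#               "b_name": "Elasticsearch", "b_label": "Entity",
#           },
#       )
#       print(results)
#
#   asyncio.run(demo())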