|
@@ -1,27 +1,47 @@
|
|
import asyncio
|
|
import asyncio
|
|
from typing import List, Optional, Dict, Any, Union
|
|
from typing import List, Optional, Dict, Any, Union
|
|
|
|
|
|
-
|
|
|
|
-class MilvusSearcher:
|
|
|
|
|
|
class MilvusBase:
    """Shared state and result formatting for the Milvus access classes."""

    # Field names passed as ``output_fields`` on Milvus calls.
    output_fields = [
        "doc_id",
        "chunk_id",
        # The metadata fields below are currently switched off; restore an
        # entry to have it requested again.
        # "summary",
        # "topic",
        # "domain",
        # "task_type",
        # "keywords",
        # "concepts",
        # "questions",
        # "entities",
        # "tokens",
        # "topic_purity",
    ]

    def __init__(self, milvus_pool):
        """Keep a reference to the pool object used to run Milvus calls."""
        self.milvus_pool = milvus_pool

    @staticmethod
    def hits_to_json(hits):
        """Convert a Milvus response into a list of plain dicts.

        Two input shapes are accepted:

        * a nested search response, where ``hits[0]`` is the sequence of hits
          for the first query vector. Each hit must expose ``id``,
          ``distance`` and ``get("entity")``; it becomes
          ``{"pk": id, "score": 1 - distance, **entity_fields}``.
        * a flat list of dict rows (for example rows that were already
          converted by this method). Rows are copied field by field and no
          ``pk``/``score`` keys are added.

        In both shapes the values of ``entities``, ``concepts``,
        ``questions`` and ``keywords`` are copied into plain lists.

        Args:
            hits: The response to convert; any falsy value yields ``[]``.

        Returns:
            A list with one dict per hit or row.
        """
        if not hits:
            return []

        list_fields = {"entities", "concepts", "questions", "keywords"}

        def plain_fields(fields):
            # A missing/None entity is treated as having no fields, and a
            # None value is kept as-is rather than passed to list().
            return {
                key: list(value)
                if key in list_fields and value is not None
                else value
                for key, value in (fields or {}).items()
            }

        first = hits[0]
        if isinstance(first, dict):
            # Flat rows: there is no per-query nesting to unwrap and no
            # distance to turn into a score.
            return [plain_fields(row) for row in hits]

        # NOTE(review): if the search metric is COSINE, Milvus may already
        # report a similarity (larger means closer) as ``distance``, in which
        # case ``1 - distance`` inverts it - confirm this is intended.
        return [
            {
                "pk": hit.id,
                "score": 1 - hit.distance,
                **plain_fields(hit.get("entity", {})),
            }
            for hit in first
        ]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class MilvusSearch(MilvusBase):
|
|
|
|
+
|
|
# 通过向量匹配
|
|
# 通过向量匹配
|
|
async def vector_search(
|
|
async def vector_search(
|
|
self,
|
|
self,
|
|
@@ -34,7 +54,8 @@ class MilvusSearcher:
|
|
"""向量搜索,可选过滤"""
|
|
"""向量搜索,可选过滤"""
|
|
if search_params is None:
|
|
if search_params is None:
|
|
search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
|
|
search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
|
|
- return await asyncio.to_thread(
|
|
|
|
|
|
+
|
|
|
|
+ response = await asyncio.to_thread(
|
|
self.milvus_pool.search,
|
|
self.milvus_pool.search,
|
|
data=[query_vec],
|
|
data=[query_vec],
|
|
anns_field=anns_field,
|
|
anns_field=anns_field,
|
|
@@ -43,30 +64,9 @@ class MilvusSearcher:
|
|
expr=expr,
|
|
expr=expr,
|
|
output_fields=self.output_fields,
|
|
output_fields=self.output_fields,
|
|
)
|
|
)
|
|
|
|
+ res = self.hits_to_json(response)
|
|
|
|
+ return res
|
|
|
|
|
|
- # 通过doc_id + chunk_id 获取数据
|
|
|
|
- async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
|
|
|
|
- expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
|
|
|
|
- return await asyncio.to_thread(
|
|
|
|
- self.milvus_pool.query,
|
|
|
|
- expr=expr,
|
|
|
|
- output_fields=self.output_fields,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- # 只按 metadata 条件查询
|
|
|
|
- async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
|
|
|
|
- exprs = []
|
|
|
|
- for k, v in filters.items():
|
|
|
|
- if isinstance(v, str):
|
|
|
|
- exprs.append(f'{k} == "{v}"')
|
|
|
|
- else:
|
|
|
|
- exprs.append(f"{k} == {v}")
|
|
|
|
- expr = " and ".join(exprs)
|
|
|
|
- return await asyncio.to_thread(
|
|
|
|
- self.milvus_pool.query,
|
|
|
|
- expr=expr,
|
|
|
|
- output_fields=self.output_fields,
|
|
|
|
- )
|
|
|
|
|
|
|
|
# 混合搜索(向量 + metadata)
|
|
# 混合搜索(向量 + metadata)
|
|
async def hybrid_search(
|
|
async def hybrid_search(
|
|
@@ -86,17 +86,10 @@ class MilvusSearcher:
|
|
parts.append(f"{k} == {v}")
|
|
parts.append(f"{k} == {v}")
|
|
expr = " and ".join(parts)
|
|
expr = " and ".join(parts)
|
|
|
|
|
|
- return await self.vector_search(
|
|
|
|
|
|
+ response = await self.vector_search(
|
|
query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
|
|
query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
|
|
)
|
|
)
|
|
-
|
|
|
|
- # 通过主键获取milvus数据
|
|
|
|
- async def get_by_id(self, pk: int):
|
|
|
|
- return await asyncio.to_thread(
|
|
|
|
- self.milvus_pool.query,
|
|
|
|
- expr=f"id == {pk}",
|
|
|
|
- output_fields=self.output_fields,
|
|
|
|
- )
|
|
|
|
|
|
+ return self.hits_to_json(response)
|
|
|
|
|
|
async def search_by_strategy(
|
|
async def search_by_strategy(
|
|
self,
|
|
self,
|
|
@@ -133,3 +126,41 @@ class MilvusSearcher:
|
|
{"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
|
|
{"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
|
|
for k, v in ranked
|
|
for k, v in ranked
|
|
]
|
|
]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class MilvusQuery(MilvusBase):
    """Scalar (non-vector) lookups that run through ``milvus_pool.query``."""

    @staticmethod
    def _quote(value):
        """Return *value* as a double-quoted literal for a filter expression.

        Backslashes and double quotes are escaped so the value cannot close
        the literal early and change the rest of the expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _rows_to_json(self, rows):
        """Turn a query response into a list of plain dicts.

        A flat list of dict rows is copied field by field, with the values of
        ``entities``, ``concepts``, ``questions`` and ``keywords`` copied into
        plain lists. Any other non-empty response is handed to
        ``hits_to_json`` unchanged.
        """
        if not rows:
            return []
        if not isinstance(rows[0], dict):
            return self.hits_to_json(rows)

        list_fields = {"entities", "concepts", "questions", "keywords"}
        return [
            {
                key: list(value)
                if key in list_fields and value is not None
                else value
                for key, value in row.items()
            }
            for row in rows
        ]

    async def _run_query(self, expr):
        """Run ``milvus_pool.query`` in a worker thread and format the rows."""
        response = await asyncio.to_thread(
            self.milvus_pool.query,
            expr=expr,
            output_fields=self.output_fields,
        )
        return self._rows_to_json(response)

    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
        """Fetch the rows matching both ``doc_id`` and ``chunk_id``.

        Returns:
            A list of dicts, one per matching row.
        """
        expr = f"doc_id == {self._quote(doc_id)} and chunk_id == {chunk_id}"
        return await self._run_query(expr)

    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
        """Query by metadata equality conditions only.

        Args:
            filters: Maps a field name to the value it must equal. String
                values are quoted; other values are inserted as written.
                Conditions are joined with ``and``.

        Returns:
            A list of dicts, one per matching row.
        """
        conditions = [
            f"{field} == {self._quote(value)}"
            if isinstance(value, str)
            else f"{field} == {value}"
            for field, value in filters.items()
        ]
        return await self._run_query(" and ".join(conditions))

    async def get_by_id(self, pk: int):
        """Fetch the rows whose ``id`` field equals ``pk``.

        Returns:
            A list of dicts, one per matching row.
        """
        return await self._run_query(f"id == {pk}")
|