Explorar o código

增加embedding查询

xueyiming hai 2 meses
pai
achega
035ea9008d
Modificáronse 3 ficheiros con 44 adicións e 2 borrados
  1. 4 1
      api/search.py
  2. 1 1
      utils/deepseek_utils.py
  3. 39 0
      utils/embedding_utils.py

+ 4 - 1
api/search.py

@@ -10,6 +10,7 @@ from schemas.schemas import Query, ContentData
 from tools_v1 import query_keyword_summary_results, query_keyword_content_results
 from utils.data_utils import add_data
 from utils.deepseek_utils import get_keywords
+from utils.embedding_utils import get_embedding_content_data
 
 router = APIRouter()
 
@@ -19,11 +20,13 @@ executor = ThreadPoolExecutor(max_workers=10)
 
 @router.post("/query", response_model=ResponseWrapper)
 async def query_keyword(query: Query):
+    print(query.text)
     keywords = get_keywords(query.text)['keywords']
     print(keywords)
     summary_res = query_keyword_summary_results(keywords)
     content_res = query_keyword_content_results(keywords)
-    res = {'summary_results': summary_res, 'content_results': content_res}
+    embedding_res = get_embedding_content_data(query.text)
+    res = {'summary_results': summary_res, 'content_results': content_res, 'embedding_results': embedding_res}
     return ResponseWrapper(
         status_code=200,
         detail="success",

+ 1 - 1
utils/deepseek_utils.py

@@ -169,7 +169,7 @@ def create_keyword_prompt(text):
         str: 格式化后的 prompt
     """
     prompt = f"""
-提取最能代表当前分析范围(整体或段落)核心内容的关键词或短语。避免使用过于通用和宽泛的词汇
+提取最能代表当前分析范围(整体或段落)核心内容的关键词或短语,如果本身就是一个词,直接返回这个词。避免使用过于通用和宽泛的词汇,
 ## 描述内容:
 {text}
 

+ 39 - 0
utils/embedding_utils.py

@@ -0,0 +1,39 @@
+import json
+
+import requests
+from core.config import logger
+from core.database import DBHelper
+from data_models.content_chunks import ContentChunks
+
+
+def get_embedding_data(query):
+    try:
+        response = requests.post(
+            url='http://192.168.100.31:8001/api/search',
+            json={
+                "query": query,
+                "search_type": "by_vector",
+                "limit": 5},
+            headers={"Content-Type": "application/json"},
+        )
+        return response.json()['results']
+    except Exception as e:
+        logger.error(e)
+    return []
+
+
+def get_embedding_content_data(query):
+    res = []
+    db_helper = DBHelper()
+    results = get_embedding_data(query)
+    if results:
+        for result in results:
+            content_chunk = db_helper.get(ContentChunks, doc_id=result['doc_id'], chunk_id=result['chunk_id'])
+            res.append(
+                {'content': content_chunk.text, 'content_summary': content_chunk.summary, 'score': result['score']})
+    return res
+
+
+if __name__ == '__main__':
+    results = get_embedding_content_data("AI绘图工具")
+    print(json.dumps(results, ensure_ascii=False))