Browse Source

Update textSimilarity: add EmbeddingManager for testing

StrayWarrior 5 months ago
parent
commit
6a5c2d4314
4 changed files with 22 additions and 5 deletions
  1. 3 1
      alg_app.py
  2. 15 1
      applications/textSimilarity.py
  3. 1 1
      routes/__init__.py
  4. 3 2
      routes/nlpServer.py

+ 3 - 1
alg_app.py

@@ -5,6 +5,7 @@ from quart import Quart
 from similarities import BertSimilarity
 from routes import AlgRoutes
 from applications import AsyncMySQLClient
+from applications.embedding_manager import EmbeddingManager
 
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
@@ -17,8 +18,9 @@ async def init():
     """
     await AsyncMySQL.init_pool()
     model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
+    embedding_manager = EmbeddingManager(model)
     print("模型加载成功")
-    app_routes = AlgRoutes(AsyncMySQL, model)
+    app_routes = AlgRoutes(AsyncMySQL, model, embedding_manager)
     app.register_blueprint(app_routes)
 
 

+ 15 - 1
applications/textSimilarity.py

@@ -17,14 +17,21 @@ def score_to_attention(score, symbol=1):
     score_attn = torch.nn.functional.softmax(score_norm, dim=1)
     return score_attn, score_norm, score_pred
 
+def compare_tensor(tensor1, tensor2):
+    if tensor1.shape != tensor2.shape:
+        print(f"[EmbeddingManager]shape error: {tensor1.shape} vs {tensor2.shape}")
+        return
+    if not torch.allclose(tensor1, tensor2):
+        print("[EmbeddingManager]value error: tensor1 not close to tensor2")
 
 class NLPFunction(object):
     """
     NLP Task
     """
 
-    def __init__(self, model):
+    def __init__(self, model, embedding_manager):
         self.model = model
+        self.embedding_manager = embedding_manager
 
     def base_string_similarity(self, text_dict):
         """
@@ -36,6 +43,13 @@ class NLPFunction(object):
             text_dict['text_a'],
             text_dict['text_b']
         )
+        # test embedding manager functions
+        text_emb1 = self.embedding_manager.get_embedding(text_dict['text_a'])
+        text_emb2 = self.embedding_manager.get_embedding(text_dict['text_b'])
+        score_function = self.model.score_functions['cos_sim']
+        score_tensor_new = score_function(text_emb1, text_emb2)
+        compare_tensor(score_tensor, score_tensor_new)
+
         response = {
             "score": score_tensor.squeeze().tolist()
         }

+ 1 - 1
routes/__init__.py

@@ -11,7 +11,7 @@ from .accountServer import AccountServer
 from applications.articleTools import ArticleDBTools
 
 
-def AlgRoutes(mysql_client, model):
+def AlgRoutes(mysql_client, model, embedding_manager):
     """
     ALG ROUTES
     :return:

+ 3 - 2
routes/nlpServer.py

@@ -8,14 +8,15 @@ class NLPServer(object):
     """
     nlp_server
     """
-    def __init__(self, params, model):
+    def __init__(self, params, model, embedding_manager):
         """
         :param params:
         """
         self.data = None
         self.function = None
         self.params = params
-        self.nlp = NLPFunction(model=model)
+        self.embedding_manager = embedding_manager
+        self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
 
     def check_params(self):
         """