罗俊辉 9 ماه پیش
والد
کامیت
60b662b2a2
6فایلهای تغییر یافته به همراه107 افزوده شده و 72 حذف شده
  1. 1 1
      alg.toml
  2. 6 5
      alg_app.py
  3. 52 26
      applications/textSimilarity.py
  4. 7 6
      routes/__init__.py
  5. 17 34
      routes/nlpServer.py
  6. 24 0
      test/nlp_dev.py

+ 1 - 1
alg.toml

@@ -1,6 +1,6 @@
 reload = true
 bind = "0.0.0.0:6060"
-workers = 4
+workers = 1
 keep_alive_timeout = 120  # 保持连接的最大秒数,根据需要调整
 graceful_timeout = 30    # 重启或停止之前等待当前工作完成的时间
 loglevel = "debug"  # 日志级别

+ 6 - 5
alg_app.py

@@ -2,22 +2,23 @@
 @author: luojunhui
 """
 from quart import Quart
+from similarities import BertSimilarity
 from routes import AlgRoutes
 from applications import AsyncMySQLClient
 
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
-app_routes = AlgRoutes(AsyncMySQL)
-app.register_blueprint(app_routes)
 
 
 @app.before_serving
-async def init_db():
+async def init():
     """
-    初始化
-    :return:
+    初始化模型
     """
     await AsyncMySQL.init_pool()
+    model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
+    app_routes = AlgRoutes(AsyncMySQL, model)
+    app.register_blueprint(app_routes)
 
 
 @app.after_serving

+ 52 - 26
applications/textSimilarity.py

@@ -1,15 +1,21 @@
 """
 @author: luojunhui
 """
-import time
 import torch
 import numpy as np
-from similarities import BertSimilarity
 
 
-# bge_large_zh_v1_5 = 'bge_large_zh_v1_5'
-# text2vec_base_chinese = "text2vec_base_chinese"
-# text2vec_bge_large_chinese = "text2vec_bge_large_chinese"
+def score_to_attention(score, symbol=1):
+    """
+
+    :param score:
+    :param symbol:
+    :return:
+    """
+    score_pred = torch.FloatTensor(score).unsqueeze(0)
+    score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2)
+    score_attn = torch.nn.functional.softmax(score_norm, dim=1)
+    return score_attn, score_norm, score_pred
 
 
 class NLPFunction(object):
@@ -35,7 +41,7 @@ class NLPFunction(object):
     def base_list_similarity(self, pair_list_dict):
         """
         计算两个list的相似度
-        :return: "score_list_b": [100, 1000, 500, 40],
+        :return:
         """
         score_tensor = self.model.similarity(
             pair_list_dict['text_list_a'],
@@ -43,24 +49,44 @@ class NLPFunction(object):
         )
         return score_tensor.tolist()
 
+    def max_cross_similarity(self, data):
+        """
+        max
+        :param data:
+        :return:
+        """
+        score_list_max = []
+        text_list_max = []
+        score_array = self.base_list_similarity(data)
+        text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
+        for i, row in enumerate(score_array):
+            max_index = np.argmax(row)
+            max_value = row[max_index]
+            score_list_max.append(max_value)
+            text_list_max.append(text_list_b[max_index])
+        return score_list_max, text_list_max, score_array
+
+    def mean_cross_similarity(self, data):
+        """
+        :param data:
+        :return:
+        """
+        score_list_max, text_list_max, score_array = self.max_cross_similarity(data)
+        score_tensor = torch.tensor(score_array)
+        score_res = torch.mean(score_tensor, dim=1)
+        score_list = score_res.tolist()
+        return score_list, text_list_max, score_array
 
-if __name__ == '__main__':
-    a = time.time()
-    m = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
-    b = time.time()
-    print("模型加载时间:\t", b - a)
-    NF = NLPFunction(m)
-    td = {
-        "text_a": "王者荣耀",
-        "text_b": "斗罗大陆"
-    }
-    tld = {
-        "text_list_a": ["凯旋", "圣洁", "篮球"],
-        "text_list_b": ["胜利", "纯洁", "足球"]
-    }
-    # res = NF.base_string_similarity(text_dict=td)
-    res = NF.base_list_similarity(pair_list_dict=tld)
-    c = time.time()
-    print("计算时间:\t", c - b)
-    for i in res:
-        print(i)
+    def avg_cross_similarity(self, data):
+        """
+        :param data:
+        :return:
+        """
+        score_list_b = data['score_list_b']
+        symbol = data['symbol']
+        score_list_max, text_list_max, score_array = self.max_cross_similarity(data)
+        score_attn, score_norm, score_pred = score_to_attention(score_list_b, symbol=symbol)
+        score_tensor = torch.tensor(score_array)
+        score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
+        score_list = score_res.squeeze(-1).tolist()
+        return score_list, text_list_max, score_array

+ 7 - 6
routes/__init__.py

@@ -5,15 +5,15 @@
 from quart import Blueprint, jsonify, request
 
 from .AccountArticleRank import AccountArticleRank
+from .nlpServer import NLPServer
 
-blueprint = Blueprint("LongArticlesAlgServer", __name__)
 
-
-def AlgRoutes(mysql_client):
+def AlgRoutes(mysql_client, model):
     """
     ALG ROUTES
     :return:
     """
+    blueprint = Blueprint("LongArticlesAlgServer", __name__)
 
     @blueprint.route("/healthCheck")
     def helloFuture():
@@ -42,9 +42,10 @@ def AlgRoutes(mysql_client):
         nlper ma
         :return:
         """
-        response = {
-            "msg": "this function is developing"
-        }
+        params = await request.get_json()
+        nlpS = NLPServer(params=params, model=model)
+        result = nlpS.deal()
+        response = {"result": result}
         return jsonify(response)
 
     return blueprint

+ 17 - 34
routes/nlpServer.py

@@ -1,27 +1,21 @@
 """
 @author: luojunhui
 """
-from typing import List
-from pydantic import BaseModel
-from similarities import BertSimilarity
-import numpy as np
-import torch
-import logging
+from applications.textSimilarity import NLPFunction
 
 
 class NLPServer(object):
     """
     nlp_server
     """
-    def __init__(self, params):
+    def __init__(self, params, model):
         """
         :param params:
         """
-        self.model = None
+        self.data = None
         self.function = None
-        self.text_02 = None
-        self.text_01 = None
         self.params = params
+        self.nlp = NLPFunction(model=model)
 
     def check_params(self):
         """
@@ -29,54 +23,43 @@ class NLPServer(object):
         :return:
         """
         try:
-            self.text_01 = self.params['text_01']
-            self.text_02 = self.params['text_02']
+            self.data = self.params['data']
             self.function = self.params['function']
-            self.model = self.params['model']
+            print("参数校验成功")
             return None
         except Exception as e:
             error_info = {
                 "error": "params error",
                 "detail": str(e)
             }
+            print("参数校验失败")
             return error_info
 
-    def choose_function(self):
+    def schedule_function(self):
         """
         :return:
         """
         match self.function:
             case "similarities":
-                return
+                return self.nlp.base_string_similarity(text_dict=self.data)
             case "similarities_cross":
-                return
+                return self.nlp.base_list_similarity(pair_list_dict=self.data)
             case "similarities_cross_max":
-                return
+                return self.nlp.max_cross_similarity(data=self.data)
             case "similarities_cross_avg":
-                return
+                return self.nlp.avg_cross_similarity(data=self.data)
             case "similarities_cross_mean":
-                return
-
-    def base_similarity(self):
-        """
-        base similarity
-        :return:
-        """
-        try:
-
-            res = {
-                'score_list': []
-            }
-            return res
-        except Exception as e:
-            return {"error": str(e)}
+                return self.nlp.mean_cross_similarity(data=self.data)
 
     def deal(self):
         """
         deal function
         :return:
         """
-        return self.check_params if self.check_params else self.choose_function
+        if self.check_params():
+            return self.check_params()
+        else:
+            return self.schedule_function()
 
 
 

+ 24 - 0
test/nlp_dev.py

@@ -0,0 +1,24 @@
+"""
+@author: luojunhui
+"""
+import json
+import requests
+import time
+
+url = "http://localhost:6060/nlp"
+
+body = {
+    "data": {
+        "text_a": "毛主席",
+        "text_b": "毛泽东"
+    },
+    "function": "similarities"
+}
+
+headers = {"Content-Type": "application/json"}
+
+a = time.time()
+response = requests.post(url=url, headers=headers, json=body)
+b = time.time()
+print(json.dumps(response.json(), ensure_ascii=False, indent=4))
+print(b - a)