Browse Source

use local nlp function

luojunhui 4 months ago
parent
commit
cb8df1f417
7 changed files with 84 additions and 14 deletions
  1. 2 0
      .gitignore
  2. 4 1
      alg_app.py
  3. 28 0
      alive.sh
  4. 21 0
      applications/nlp_task.py
  5. 1 1
      routes/__init__.py
  6. 20 5
      routes/accountServer.py
  7. 8 7
      test/score_list_dev.py

+ 2 - 0
.gitignore

@@ -58,3 +58,5 @@ docs/_build/
 # PyBuilder
 target/
 
+/test/
+/test/

+ 4 - 1
alg_app.py

@@ -1,6 +1,7 @@
 """
 @author: luojunhui
 """
+import os
 from quart import Quart
 from similarities import BertSimilarity
 from routes import AlgRoutes
@@ -10,13 +11,15 @@ from applications.embedding_manager import EmbeddingManager
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
 
+
 @app.before_serving
 async def init():
     """
     初始化模型
     """
     await AsyncMySQL.init_pool()
-    model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
+
+    model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5", device="cuda:1")
     embedding_manager = EmbeddingManager(model)
     print("模型加载成功")
     app_routes = AlgRoutes(AsyncMySQL, model, embedding_manager)

+ 28 - 0
alive.sh

@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# 获取当前日期,格式为 YYYY-MM-DD
+CURRENT_DATE=$(date +%F)
+
+# 日志文件路径,含日期
+LOG_FILE="/home/ubuntu/luojunhui/logs/alg_server_log_$CURRENT_DATE.txt"
+
+export TRANSFORMERS_OFFLINE=1
+
+# 重定向整个脚本的输出到带日期的日志文件
+exec >> "$LOG_FILE" 2>&1
+if pgrep -f "/home/ubuntu/anaconda3/envs/alg/bin/python" > /dev/null
+then
+    echo "$(date '+%Y-%m-%d %H:%M:%S') - match_alg_server is running"
+else
+    echo "$(date '+%Y-%m-%d %H:%M:%S') - trying to restart match_alg_server"
+    # 切换到指定目录
+    cd /home/ubuntu/luojunhui/LongArticleAlgServer
+
+    # 激活 Conda 环境
+    source /home/ubuntu/miniconda3/etc/profile.d/conda.sh
+    conda activate alg
+
+    # 在后台运行 Python 脚本并重定向日志输出
+    nohup hypercorn alg_app:app --config alg.toml >> "${LOG_FILE}" 2>&1 &
+    echo "$(date '+%Y-%m-%d %H:%M:%S') - successfully restarted alg_server"
+fi

+ 21 - 0
applications/nlp_task.py

@@ -0,0 +1,21 @@
+"""
+@author: luojunhui
+"""
+
+
+async def get_nlp_similarity_score(nlp, function, data, use_cache):
+    """
+    获取nlp的相似度分数
+    """
+
+    match function:
+        case "similarities":
+            return nlp.base_string_similarity(text_dict=data, use_cache=use_cache)
+        case "similarities_cross":
+            return nlp.base_list_similarity(pair_list_dict=data, use_cache=use_cache)
+        case "similarities_cross_max":
+            return nlp.max_cross_similarity(data=data)
+        case "similarities_cross_avg":
+            return nlp.avg_cross_similarity(data=data)
+        case "similarities_cross_mean":
+            return nlp.mean_cross_similarity(data=data)

+ 1 - 1
routes/__init__.py

@@ -56,7 +56,7 @@ def AlgRoutes(mysql_client, model, embedding_manager):
         :return:
         """
         params = await request.get_json()
-        AS = AccountServer(mysql_client=mysql_client, params=params)
+        AS = AccountServer(mysql_client=mysql_client, params=params, model=model, embedding_manager=embedding_manager)
         response = await AS.deal()
         return jsonify(response)
 

+ 20 - 5
routes/accountServer.py

@@ -5,7 +5,9 @@ import json
 
 import aiohttp
 from applications.articleTools import ArticleDBTools
+from applications.nlp_task import get_nlp_similarity_score
 from applications.config import port
+from applications.textSimilarity import NLPFunction
 
 
 class AccountServer(object):
@@ -13,7 +15,7 @@ class AccountServer(object):
     获取标题和公众号文章的相关性
     """
 
-    def __init__(self, mysql_client, params):
+    def __init__(self, mysql_client, params, model, embedding_manager):
         self.account_name_list = None
         self.gh_id_list = None
         self.sim_type = None
@@ -23,8 +25,10 @@ class AccountServer(object):
         self.rate = None
         self.title_list = None
         self.view_count_filter = None
+        self.use_cache = True
         self.params = params
         self.AT = ArticleDBTools(mysql_client)
+        self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
 
     async def request_for_nlp(self, title_list, account_interest, interest_weight):
         """
@@ -66,6 +70,7 @@ class AccountServer(object):
             self.interest_type = self.params.get("interest_type", "top")
             self.sim_type = self.params.get("sim_type", "mean")
             self.view_count_filter = self.params.get("view_count_filter", None)
+            self.use_cache = self.params.get("use_cache", True)
             return None
         except Exception as e:
             response = {"error": "Params error", "detail": str(e)}
@@ -131,11 +136,21 @@ class AccountServer(object):
             interest_weight = extend_dicts['view_count']
             if self.sim_type == "weighted_by_view_count_rate":
                 interest_weight = extend_dicts['view_count_rate']
-            response = await self.request_for_nlp(
-                title_list=self.title_list,
-                account_interest=account_interest,
-                interest_weight=interest_weight
+
+            data = {
+                "text_list_a": [i.replace("'", "") for i in self.title_list],
+                "text_list_b": [i.replace("'", "") for i in account_interest],
+                "score_list_b": interest_weight,
+                "symbol": 1,
+            },
+            function = "similarities_cross_mean" if self.sim_type == "mean" else "similarities_cross_avg"
+            response = await get_nlp_similarity_score(
+                nlp=self.nlp,
+                function=function,
+                data=data,
+                use_cache=self.use_cache
             )
+
             score_list_key = "score_list_mean" if self.sim_type == "mean" else "score_list_avg"
             return {
                 "score_list": response[score_list_key],

+ 8 - 7
test/score_list_dev.py

@@ -17,20 +17,21 @@ def score_list(account):
     :param account:
     :return:
     """
-    url2 = "http://47.98.136.48:6060/score_list"
+    url2 = "http://61.48.133.26:6061/score_list"
     body = {
-        "account_nickname_list": [account],
+        "gh_id_list": [
+            "gh_02f5bca5b5d9"
+        ],
         "text_list": [
             "在俄罗斯买好地了,却发现没有公路、码头、仓储、燃气管道……”",
             "被霸占15年后成功收回,岛礁资源超100万吨,曾遭到美菲联手抢夺",
             "感人!河南姐弟被父母遗弃,7岁弟弟带着姐姐看病:别怕,以后我养",
             "山东26岁女子产下罕见“4胞胎”,丈夫却突然消失,婆婆:养不起"
         ],
-        "max_time": None,
-        "min_time": None,
-        "interest_type": "avg",
-        "sim_type": "mean",
-        "rate": 0.1
+        "interest_type": "account_avg",
+        "sim_type": "avg",
+        "rate": 0.1,
+        "view_count_filter": 1000
     }
     response = requests.post(url=url2, headers={}, json=body).json()
     print(json.dumps(response, ensure_ascii=False, indent=4))