Parcourir la source

测试GPU性能

罗俊辉 il y a 9 mois
Parent
commit
e22ba2a571
2 fichiers modifiés avec 26 ajouts et 32 suppressions
  1. 1 1
      alg_app.py
  2. 25 31
      test/score_list_dev.py

+ 1 - 1
alg_app.py

@@ -16,7 +16,7 @@ async def init():
     初始化模型
     """
     await AsyncMySQL.init_pool()
-    model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5", device="cuda")
+    model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
     print("模型加载成功")
     app_routes = AlgRoutes(AsyncMySQL, model)
     app.register_blueprint(app_routes)

+ 25 - 31
test/score_list_dev.py

@@ -2,45 +2,39 @@
 @author: luojunhui
 """
 import json
+import time
 
 import requests
+from concurrent.futures.thread import ThreadPoolExecutor
 
 
-class ArticleRank(object):
-    """
-    账号排序
+def score_list(account):
     """
     url = "http://192.168.100.31:8179/score_list"
     url1 = "http://47.98.154.124:6060/score_list"
     # url1 = "http://localhost:6060/score_list"
-    url2 = "http://192.168.100.31:8179/score_list"
-
-    @classmethod
-    def rank(cls, account_list, text_list):
-        """
-        Rank
-        :param account_list:
-        :param text_list:
-        :return:
-        """
-        body = {
-            "account_nickname_list": account_list,
-            "text_list": text_list,
-            "max_time": None,
-            "min_time": None,
-            "interest_type": "avg",
-            "sim_type": "mean",
-            "rate": 0.1
-        }
-        response = requests.post(url=cls.url, headers={}, json=body).json()
-        return response
+    url2 = "http://192.168.100.31:6062/score_list"
+    :param account:
+    :return:
+    """
+    url2 = "http://192.168.100.31:6062/score_list"
+    body = {
+        "account_nickname_list": [account],
+        "text_list": ['保姆为300万拆迁款,嫁给大24岁老头,丈夫去世后,她发现房产证没有丈夫名字'] * 50,
+        "max_time": None,
+        "min_time": None,
+        "interest_type": "avg",
+        "sim_type": "mean",
+        "rate": 0.1
+    }
+    response = requests.post(url=url2, headers={}, json=body).json()
+    print(json.dumps(response, ensure_ascii=False, indent=4))
+    return response
 
 
 if __name__ == '__main__':
-    AR = ArticleRank()
-    response = AR.rank(
-        account_list=['生活良读'],
-        text_list=['保姆为300万拆迁款,嫁给大24岁老头,丈夫去世后,她发现房产证没有丈夫名字'] * 10,
-
-    )
-    print(json.dumps(response, ensure_ascii=False, indent=4))
+    a = time.time()
+    with ThreadPoolExecutor(max_workers=100) as pool:
+        pool.map(score_list, ["生活良读"] * 1000)
+    b = time.time()
+    print(b - a)