Browse Source

初始化

罗俊辉 1 year ago
parent
commit
9bc4bbb819
3 changed files with 57 additions and 22 deletions
  1. 48 19
      applications/functions.py
  2. 6 1
      applications/model_init.py
  3. 3 2
      applications/routes.py

+ 48 - 19
applications/functions.py

@@ -3,29 +3,58 @@
 """
 import asyncio
 
-
 from .model_init import models
 
 
-async def process_data(params):
+class ParamProcess(models):
     """
-    执行结果
-    :param params:
-    :return:
+    处理 params, 继承 models
     """
-    print("正在处理")
-    flag = params['version']
-    if flag == "v1":
-        model = models.model_v1
-    elif flag == "v2":
-        model = models.model_v2
-    else:
-        return
-    features = params['features']
-    prediction = model.predict([features])
-    print("处理完成")
-    return {'prediction': prediction.tolist()}
-
-
+    def __init__(self):
+        self.model_v1 = models.model_v1
+        self.model_v2 = models.model_v2
+        self.layer_encoder = models.label_encoder
+
+    async def predict_score(self, version, features):
+        """
+        预测
+        :param version: 模型版本
+        :param features: 视频被 label_encoder 之后的features
+        :return: score: 返回的分数
+        """
+        match version:
+            case "v1":
+                return await self.model_v1.predict(x)
+            case "v2":
+                return await self.model_v2.predict(x)
+
+    async def process_label(self, params):
+        """
+        处理类别 features 和 float features
+        :param params: 接收到的参数
+        :return:
+        """
+        version = params['version']
+        features = params['features']
+        match version:
+            case "v1":
+                # 全部转化为类别
+                print("all to string cate")
+                # features = []
+                return version, features
+            case "v2":
+                print("all to float cate")
+                # features = []
+                return version, features
+
+    async def process(self, params):
+        """
+        处理
+        :param params:
+        :return:
+        """
+        version, features = await self.process_label(params)
+        print(version, features)
+        return await self.process_score(version, features)
 
 

+ 6 - 1
applications/model_init.py

@@ -5,6 +5,8 @@ import json
 import asyncio
 import lightgbm as lgb
 
+from sklearn.preprocessing import LabelEncoder
+
 
 class Models(object):
     """
@@ -16,7 +18,10 @@ class Models(object):
         在项目启动的时候加载好所有的模型
         :return:
         """
-        print("正在加载模型")
+
+        self.label_encoder = LabelEncoder()
+        print("标签分类器加载完成")
+        print("开始加载模型")
         self.model_v1 = lgb.Booster(model_file="/root/lightgbm_score/models/lightgbm_0409_all_tags.bin")
         print("模型 1 加载完成......")
         self.model_v2 = lgb.Booster(model_file="/root/lightgbm_score/models/lightgbm_0409_spider.bin")

+ 3 - 2
applications/routes.py

@@ -2,7 +2,7 @@
 @author: luojunhui
 """
 from quart import Blueprint, jsonify, request
-from applications.functions import process_data
+from applications.functions import ParamProcess
 
 
 video_score_blueprint = Blueprint('light_gbm_score', __name__)
@@ -23,6 +23,7 @@ async def post_data():
     请求接口代码
     :return:
     """
+    p = ParamProcess()
     data = await request.get_json()
-    processed_data = await process_data(data)
+    processed_data = await p.process(data)
     return jsonify(processed_data)