Pārlūkot izejas kodu

add predict_test, update test - rov recall pool data

liqian 3 gadi atpakaļ
vecāks
revīzija
045f569c16
4 mainīti faili ar 91 papildinājumiem un 9 dzēšanām
  1. 3 3
      config.py
  2. 44 0
      db_helper.py
  3. 4 3
      requirements.txt
  4. 40 3
      rov_train.py

+ 3 - 3
config.py

@@ -1,9 +1,9 @@
 class BaseConfig(object):
     # 产品标识
     APP_TYPE = {
-        'VLOG': 0,
-        'LOVE_LIVE': 4,
-        'LONG_VIDEO': 5,
+        'VLOG': 0,  # vlog
+        'LOVE_LIVE': 4,  # 票圈视频
+        'LONG_VIDEO': 5,  # 内容精选
         'SHORT_VIDEO': 6
     }
     # 数据存放路径

+ 44 - 0
db_helper.py

@@ -1,5 +1,6 @@
 import redis
 import psycopg2
+import pymysql
 from config import set_config
 from log import Log
 
@@ -218,6 +219,49 @@ class HologresHelper(object):
         return data
 
 
+class MysqlHelper(object):
+    """Small helper that runs queries against a MySQL database through pymysql."""
+
+    def __init__(self, mysql_info):
+        """
+        Store the MySQL connection settings.
+        :param mysql_info: connection settings, a dict of the form
+            {'host': '', 'port': '', 'user': '', 'password': '', 'db': ''}
+        """
+        self.host = mysql_info['host']
+        self.port = mysql_info['port']
+        self.user = mysql_info['user']
+        self.password = mysql_info['password']
+        self.db = mysql_info['db']
+
+    def get_data(self, sql):
+        """
+        Execute a query on a fresh connection and return all fetched rows.
+        :param sql: SQL statement to execute
+        :return: the result of cursor.fetchall(), or None if executing the
+            statement or fetching the rows raised an exception
+        """
+        # Open a new connection for this query; connection errors propagate to the caller.
+        conn = pymysql.connect(
+            host=self.host,
+            port=self.port,
+            user=self.user,
+            password=self.password,
+            db=self.db,
+            charset='utf8'
+        )
+        try:
+            cursor = conn.cursor()
+            try:
+                # Execute the statement and fetch every row of the result set.
+                cursor.execute(sql)
+                return cursor.fetchall()
+            except Exception:
+                # Best-effort contract kept from the original: a failed query yields None.
+                return None
+            finally:
+                # Always release the cursor, including on the failure path.
+                cursor.close()
+        finally:
+            # Always release the connection so a failed query does not leak it.
+            conn.close()
+
+
 if __name__ == '__main__':
     redis_helper = RedisHelper()
     key = 'com.weiqu.video.hot.recommend.item.score.20210901'

+ 4 - 3
requirements.txt

@@ -1,9 +1,10 @@
+psycopg2_binary==2.9.1
 pandas==1.1.3
-pyodps==0.10.7
+PyMySQL==1.0.2
+redis==3.5.3
 lightgbm==3.2.1
+pyodps==0.10.7
 requests==2.24.0
-redis==3.5.3
-psycopg2_binary==2.9.1
 odps==3.5.1
 psycopg2==2.9.1
 scikit_learn==1.0.1

+ 40 - 3
rov_train.py

@@ -1,4 +1,5 @@
 import os
+import random
 import time
 
 import lightgbm as lgb
@@ -8,9 +9,9 @@ from sklearn.model_selection import train_test_split
 from sklearn.metrics import mean_absolute_error, r2_score, mean_absolute_percentage_error
 
 from config import set_config
-from utils import read_from_pickle, write_to_pickle, data_normalization, request_post
+from utils import read_from_pickle, write_to_pickle, data_normalization, request_post, filter_video_status
 from log import Log
-from db_helper import RedisHelper
+from db_helper import RedisHelper, MysqlHelper
 
 config_ = set_config()
 log_ = Log()
@@ -170,6 +171,43 @@ def predict():
         log_.error('notify backend fail!')
 
 
+def predict_test():
+    """
+    Generate recall-pool score data for the test environment.
+
+    Reads video ids from MySQL, filters them with filter_video_status, assigns
+    each remaining video a random score in [0, 100], stores the scores in a
+    Redis zset and notifies the backend. Returns None; if the MySQL query
+    fails, the function logs an error and stops before writing anything.
+    """
+    # Fetch the 40000 rows of wx_video with the highest ids
+    # (used as the most recently published videos of the test environment).
+    # NOTE(review): the credentials below are hard-coded in source control;
+    # consider loading them from configuration or a secret store instead.
+    mysql_info = {
+        'host': 'rm-bp1k5853td1r25g3n690.mysql.rds.aliyuncs.com',
+        'port': 3306,
+        'user': 'wx2016_longvideo',
+        'password': 'wx2016_longvideoP@assword1234',
+        'db': 'longvideo'
+    }
+    sql = "SELECT id FROM wx_video ORDER BY id DESC LIMIT 40000;"
+    mysql_helper = MysqlHelper(mysql_info=mysql_info)
+    data = mysql_helper.get_data(sql=sql)
+    if data is None:
+        # get_data returns None when the query fails; iterating it would raise TypeError.
+        log_.error('get video ids from mysql fail!')
+        return
+    video_ids = [video[0] for video in data]
+    # Filter the ids by video status.
+    filtered_videos = filter_video_status(video_ids)
+    log_.info('filtered_videos nums={}'.format(len(filtered_videos)))
+    # Use a random number between 0 and 100 as each video's score.
+    redis_data = {}
+    json_data = []
+    for video_id in filtered_videos:
+        score = random.uniform(0, 100)
+        redis_data[video_id] = score
+        json_data.append({'videoId': video_id, 'rovScore': score})
+    # Upload the scores to Redis under a key suffixed with today's date.
+    redis_helper = RedisHelper()
+    key_name = config_.RECALL_KEY_NAME_PREFIX + time.strftime('%Y%m%d')
+    redis_helper.add_data_with_zset(key_name=key_name, data=redis_data)
+    log_.info('test data to redis finished!')
+    # Notify the backend that the score data has been updated.
+    result = request_post(request_url=config_.NOTIFY_BACKEND_UPDATE_ROV_SCORE_URL, request_data={'videos': json_data})
+    if result['code'] == 0:
+        log_.info('notify backend success!')
+    else:
+        log_.error('notify backend fail!')
+
+
 if __name__ == '__main__':
     log_.info('rov model train start...')
     train_start = time.time()
@@ -185,4 +223,3 @@ if __name__ == '__main__':
     predict()
     predict_end = time.time()
     log_.info('rov model predict end, execute time = {}ms'.format((predict_end - predict_start)*1000))
-