
Add title-reading feature

罗俊辉, 1 year ago
Parent commit: 8eec036ad6
3 files changed, 143 additions and 54 deletions
  1. functions/__init__.py (+2 -1)
  2. functions/mysql.py (+52 -2)
  3. process_data.py (+89 -51)

+ 2 - 1
functions/__init__.py

@@ -2,4 +2,5 @@
 init file for functions
 """
 from .date import *
-from .odps_function import PyODPS
+from .odps_function import PyODPS
+from .mysql import MysqlClient

+ 52 - 2
functions/mysql.py

@@ -1,9 +1,59 @@
 """
 Mysql Functions
 """
+import pymysql
 
 
-class MySQLClient(object):
+class MysqlClient(object):
     """
-    MySQL Client
+    MySQL utility; env defaults to the prod version
     """
+
+    def __init__(self):
+        mysql_config = {
+            "host": "rm-bp1159bu17li9hi94.mysql.rds.aliyuncs.com",
+            "port": 3306,  # port number
+            "user": "crawler",  # MySQL username
+            "passwd": "crawler123456@",  # MySQL login password
+            "db": "piaoquan-crawler",  # database name
+            "charset": "utf8mb4"  # use utf8mb4 when the text stored in the database is UTF-8 encoded
+        }
+        self.connection = pymysql.connect(
+            host=mysql_config['host'],  # database host (internal address)
+            port=mysql_config['port'],  # port number
+            user=mysql_config['user'],  # MySQL username
+            passwd=mysql_config['passwd'],  # MySQL login password
+            db=mysql_config['db'],  # database name
+            charset=mysql_config['charset']  # use utf8mb4 when the text stored in the database is UTF-8 encoded
+        )
+
+    def select(self, sql):
+        """
+        Run a query and return all fetched rows
+        :param sql:
+        :return:
+        """
+        cursor = self.connection.cursor()
+        cursor.execute(sql)
+        data = cursor.fetchall()
+        return data
+
+    def update(self, sql):
+        """
+        Execute an insert/update statement and commit
+        :param sql:
+        :return:
+        """
+        cursor = self.connection.cursor()
+        try:
+            res = cursor.execute(sql)
+            self.connection.commit()
+            return res
+        except Exception as e:
+            self.connection.rollback()
+
+    def close(self):
+        """
+        Close the connection
+        """
+        self.connection.close()

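A minimal usage sketch of the new MysqlClient (illustrative only; it assumes the crawler_video table queried later in this commit is reachable with the credentials above):

    from functions import MysqlClient

    client = MysqlClient()
    # select() returns cursor.fetchall(), i.e. a tuple of row tuples
    rows = client.select("SELECT video_id, video_title FROM crawler_video LIMIT 10;")
    for video_id, video_title in rows:
        print(video_id, video_title)
    client.close()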
+ 89 - 51
process_data.py

@@ -6,67 +6,105 @@ import sys
 import os
 import json
 from tqdm import tqdm
+import jieba.analyse
 
 sys.path.append(os.getcwd())
 
-from functions import generate_label_date
+from functions import generate_label_date, MysqlClient
 
 
-def generate_train_label(item, y_ori_data, cate):
+class DataProcessor(object):
     """
-    Generate training data, returned as np.array matrices
-    :return: x_train, the training data; y_train, the training labels
+    Process the data into the format expected by LightGBM
     """
-    video_id = item["video_id"]
-    dt = item["dt"]
-    userful_features = [
-        "uid",
-        "type",
-        "channel",
-        "fans",
-        "view_count_user_30days",
-        "share_count_user_30days",
-        "return_count_user_30days",
-        "rov_user",
-        "str_user",
-        "out_user_id",
-        "mode",
-        "out_play_cnt",
-        "out_like_cnt",
-        "out_share_cnt",
-        "out_collection_cnt",
-    ]
-    item_features = [item[i] for i in userful_features]
-    label_dt = generate_label_date(dt)
-    label_obj = y_ori_data.get(label_dt, {}).get(video_id)
-    if label_obj:
-        label = int(label_obj[cate]) if label_obj[cate] else 0
-    else:
-        label = 0
-    return label, item_features
+
+    def __init__(self, flag):
+        self.client = MysqlClient()
+        self.flag = flag
+
+    def generate_train_label(self, item, y_ori_data, cate):
+        """
+        Generate training data, returned as np.array matrices
+        :return: x_train, the training data; y_train, the training labels
+        """
+        video_id = item["video_id"]
+        dt = item["dt"]
+        userful_features = [
+            "uid",
+            "type",
+            "channel",
+            "fans",
+            "view_count_user_30days",
+            "share_count_user_30days",
+            "return_count_user_30days",
+            "rov_user",
+            "str_user",
+            "out_user_id",
+            "mode",
+            "out_play_cnt",
+            "out_like_cnt",
+            "out_share_cnt",
+            "out_collection_cnt",
+        ]
+        item_features = [item[i] for i in userful_features]
+        keywords_textrank, keywords_tf = self.title_processor(video_id)
+        item_features.append(",".join(keywords_textrank))
+        item_features.append(",".join(keywords_tf))
+        label_dt = generate_label_date(dt)
+        label_obj = y_ori_data.get(label_dt, {}).get(video_id)
+        if label_obj:
+            label = int(label_obj[cate]) if label_obj[cate] else 0
+        else:
+            label = 0
+        return label, item_features
+
+    def title_processor(self, video_id):
+        """
+        Fetch the title by video_id, then tokenize it and use the keywords as features
+        :param video_id: the video id
+        :return: tag_list [tag, tag, tag, tag......]
+        """
+        sql = f"""SELECT video_title from crawler_video where video_id = {video_id};"""
+        title = self.client.select(sql)[0][0]
+        keywords_textrank = jieba.analyse.textrank(title, topK=3)
+        keywords_tfidf = jieba.analyse.extract_tags(title, topK=3)
+        return list(keywords_textrank), list(keywords_tfidf)
+
+    def producer(self):
+        """
+        Generate the data
+        :return:none
+        """
+        if self.flag == "train":
+            x_path = "data/hour_train.json"
+            y_path = "data/daily-label-20240101-20240320.json"
+        elif self.flag == "predict":
+            x_path = "prid_data/train_0314_0317.json"
+            y_path = "data/daily-label-20240315-20240321.json"
+        else:
+            return
+        with open(x_path) as f:
+            x_data = json.loads(f.read())
+        with open(y_path) as f:
+            y_data = json.loads(f.read())
+        cate_list = ["total_return"]
+        for c in cate_list:
+            x_list = []
+            y_list = []
+            for video_obj in tqdm(x_data):
+                our_label, features = self.generate_train_label(video_obj, y_data, c)
+                x_list.append(features)
+                y_list.append(our_label)
+            with open("produce_data/x_data_{}_{}.json".format(c, self.flag), "w") as f1:
+                f1.write(json.dumps(x_list, ensure_ascii=False))
+
+            with open("produce_data/y_data_{}_{}.json".format(c, self.flag), "w") as f2:
+                f2.write(json.dumps(y_list, ensure_ascii=False))
 
 
 if __name__ == "__main__":
-    x_path = "prid_data/train_0314_0317.json"
-    y_path = "data/daily-label-20240315-20240321.json"
+    D = DataProcessor(flag="train")
+    D.producer()
 
-    with open(x_path) as f:
-        x_data = json.loads(f.read())
 
-    with open(y_path) as f:
-        y_data = json.loads(f.read())
-    cate_list = ["total_return"]
-    for c in cate_list:
-        x_list = []
-        y_list = []
-        for video_obj in tqdm(x_data):
-            print(video_obj)
-            our_label, features = generate_train_label(video_obj, y_data, c)
-            x_list.append(features)
-            y_list.append(our_label)
-        # print(len(y_list))
-        with open("whole_data/x_data_{}_prid.json".format(c), "w") as f1:
-            f1.write(json.dumps(x_list, ensure_ascii=False))
 
-        with open("whole_data/y_data_{}_prid.json".format(c), "w") as f2:
-            f2.write(json.dumps(y_list, ensure_ascii=False))
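For reference, a minimal sketch of the keyword extraction that title_processor performs, run on a hard-coded sample title instead of a database lookup (the title string below is made up for illustration):

    import jieba.analyse

    title = "三分钟学会家常红烧肉,做法简单又好吃"
    # TextRank-based keywords (graph ranking over the tokenized title)
    keywords_textrank = jieba.analyse.textrank(title, topK=3)
    # TF-IDF weighted keywords
    keywords_tfidf = jieba.analyse.extract_tags(title, topK=3)
    print(list(keywords_textrank), list(keywords_tfidf))

Both lists are joined with "," and appended to item_features, so each sample carries two extra string features, each holding up to three keywords.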