|
@@ -165,6 +165,53 @@ class UserProcess(object):
|
|
|
df.to_json(des_path, orient='records')
|
|
|
|
|
|
|
|
|
+class AllProcess(object):
|
|
|
+ """
|
|
|
+ 全部数据
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self.client_spider = MySQLClientSpider()
|
|
|
+ self.all_features = ["label", "tag1", "tag2", "tag3"]
|
|
|
+
|
|
|
+ def read_all_data(self, flag, dt_time):
|
|
|
+ """
|
|
|
+ 生成用户数据
|
|
|
+ :param flag: predict/train
|
|
|
+ :param dt_time: 时间
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ dt_time = datetime.datetime.strptime(dt_time, "%Y%m%d")
|
|
|
+ three_date_before = dt_time + datetime.timedelta(days=4)
|
|
|
+ temp_time = three_date_before.strftime("%Y%m%d")
|
|
|
+ if flag == "train":
|
|
|
+ sql = f"""select video_title, label from lightgbm_data where daily_dt_str <= '{temp_time}';"""
|
|
|
+ des_path = "/root/luojunhui/alg/data/train_data/all_train_{}.json".format(
|
|
|
+ datetime.datetime.today().strftime("%Y%m%d"))
|
|
|
+ elif flag == "predict":
|
|
|
+ sql = f"""select video_title, label from lightgbm_data where daily_dt_str = '{temp_time}';"""
|
|
|
+ des_path = "/root/luojunhui/alg/data/predict_data/all_predict_{}.json".format(dt_time.strftime("%Y%m%d"))
|
|
|
+ else:
|
|
|
+ return
|
|
|
+ dt_list = self.client_spider.select(sql)
|
|
|
+ df = []
|
|
|
+ for line in tqdm(dt_list):
|
|
|
+ title = line[0]
|
|
|
+ title_tags = list(jieba.analyse.textrank(title, topK=4))
|
|
|
+ temp = list(line)
|
|
|
+ if title_tags:
|
|
|
+ for i in range(4):
|
|
|
+ try:
|
|
|
+ temp.append(title_tags[i])
|
|
|
+ except:
|
|
|
+ temp.append(None)
|
|
|
+ df.append(temp[1:])
|
|
|
+ else:
|
|
|
+ continue
|
|
|
+ df = pd.DataFrame(df, columns=self.all_features)
|
|
|
+ df.to_json(des_path, orient='records')
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
parser = argparse.ArgumentParser() # 新建参数解释器对象
|
|
|
parser.add_argument("--cate")
|