瀏覽代碼

generate label for mysql

罗俊辉 1 年之前
父節點
當前提交
0a74af847d
共有 2 個文件被更改,包括 2 次插入和 2 次删除
  1. 1 2
      main_spider.py
  2. 1 0
      process_data.py

+ 1 - 2
main_spider.py

@@ -154,7 +154,6 @@ class LightGBM(object):
         fw = open("result/summary_{}.txt".format(dt), "a+", encoding="utf-8")
         path = 'data/predict_data/predict_{}.json'.format(dt)
         x, y = self.read_data(path)
-        Y_test = [0 if i < 6 else 1 for i in y]
         bst = lgb.Booster(model_file=self.model)
         y_pred = bst.predict(x, num_iteration=bst.best_iteration)
         temp = sorted(list(y_pred))
@@ -163,7 +162,7 @@ class LightGBM(object):
         # 转换为二进制输出
         score_list = []
         for index, item in enumerate(list(y_pred)):
-            real_label = Y_test[index]
+            real_label = y[index]
             score = item
             prid_label = y_pred_binary[index]
             print(real_label, "\t", prid_label, "\t", score)

+ 1 - 0
process_data.py

@@ -137,6 +137,7 @@ class SpiderProcess(object):
             three_date_before = dt_time + datetime.timedelta(days=4)
             temp_time = three_date_before.strftime("%Y%m%d")
             select_sql = f"""SELECT video_id, video_title, label, channel, out_user_id, spider_mode, out_play_cnt, out_like_cnt, out_share_cnt FROM lightgbm_data WHERE type = 'spider' and daily_dt_str = '{temp_time}';"""
+            print(select_sql)
             des_path = "data/predict_data/predict_{}.json".format(dt_time.strftime("%Y%m%d"))
         else:
             return