liqian 1 年之前
父节点
当前提交
621824bc45
共有 1 个文件被更改,包括 35 次插入10 次删除
  1. 35 10
      ad_xgboost_predict_data_generate.py

+ 35 - 10
ad_xgboost_predict_data_generate.py

@@ -31,9 +31,13 @@ if __name__ == '__main__':
     # 1. 获取用户特征数据
     user_filepath = f"{predict_data_dir}/{user_filename}"
     user_df = read_csv_data(filepath=user_filepath)
+    user_df = user_df[user_df['mid'] != '-1']
+    print(f"user_df shape: {user_df.shape}")
     # 2. 获取视频特征数据
     video_filepath = f"{predict_data_dir}/{video_filename}"
     video_df = read_csv_data(filepath=video_filepath)
+    video_df = video_df[video_df['videoid' != '-1']]
+    print(f"video_df shape: {video_df.shape}")
     # 3. 用户特征和视频特征进行拼接
     video_features = [
         'videoid',
@@ -52,21 +56,42 @@ if __name__ == '__main__':
         'video_share_rate_pv_30day',
         'video_return_rate_30day',
     ]
+    predict_data_dir = './data/predict_data'
+    file_list = [f"{predict_data_dir}/predict_data_0.csv", f"{predict_data_dir}/predict_data_1.csv"]
+    for file in file_list:
+        try:
+            os.remove(file)
+        except:
+            continue
     merge_df_list = []
     for ind, row in video_df.iterrows():
         merge_df_temp = user_df.copy()
         for feature in video_features:
             merge_df_temp[feature] = row[feature]
         merge_df_list.append(merge_df_temp)
-    merge_df = pd.concat(merge_df_list, ignore_index=True)
-    # 4. 拼接广告特征ad_status
-    for ad_status in [0, 1]:
-        res_df = merge_df.copy()
-        res_df['ad_status'] = ad_status
-        # 写入csv
-        predict_data_dir = './data/predict_data'
-        if not os.path.exists(predict_data_dir):
-            os.makedirs(predict_data_dir)
-        res_df.to_csv(f"{predict_data_dir}/predict_data_{ad_status}.csv", index=False)
+        if ind % 100 == 0:
+            merge_df = pd.concat(merge_df_list, ignore_index=True)
+            print(f"ind: {ind}, merge_df shape: {merge_df.shape}")
+            # 4. 拼接广告特征ad_status
+            for ad_status in [0, 1]:
+                res_df = merge_df.copy()
+                res_df['ad_status'] = ad_status
+                # 写入csv
+                if not os.path.exists(predict_data_dir):
+                    os.makedirs(predict_data_dir)
+                res_df.to_csv(f"{predict_data_dir}/predict_data_{ad_status}.csv", index=False, mode='a')
+            merge_df_list = []
+
+    if len(merge_df_list) > 0:
+        merge_df = pd.concat(merge_df_list, ignore_index=True)
+        print(f"merge_df shape: {merge_df.shape}")
+        # 4. 拼接广告特征ad_status
+        for ad_status in [0, 1]:
+            res_df = merge_df.copy()
+            res_df['ad_status'] = ad_status
+            # 写入csv
+            if not os.path.exists(predict_data_dir):
+                os.makedirs(predict_data_dir)
+            res_df.to_csv(f"{predict_data_dir}/predict_data_{ad_status}.csv", index=False, mode='a')
 
     print(f"{time.time() - st_time}s")