浏览代码

仅通过标题tag 分析全部数据

罗俊辉 1 年之前
父节点
当前提交
f80e2eddb6
共有 1 个文件被更改,包括 7 次插入6 次删除
  1. 7 6
      main.py

+ 7 - 6
main.py

@@ -44,9 +44,9 @@ class LightGBM(object):
         """
         df = pd.read_json(path)
         df = df.dropna(subset=['label'])  # 把 label 为空的删掉
-        df = df.dropna(subset=['tag1', 'tag2', 'tag3', 'tag4'], how="all")  # 把 tag 为空的数据也删掉
+        df = df.dropna(subset=['tag1', 'tag2'], how="all")  # 把 tag 为空的数据也删掉
         labels = df['label']
-        features = df.drop('label', axis=1)
+        features = df.drop(['label', 'tag3', 'tag4'], axis=1)
         for key in self.str_columns:
             features[key] = self.label_encoder.fit_transform(features[key])
         return features, labels, df
@@ -108,7 +108,7 @@ class LightGBM(object):
         train_data = lgb.Dataset(
             X_train,
             label=Y_train,
-            categorical_feature=["tag1", "tag2", "tag3", "tag4"],
+            categorical_feature=["tag1", "tag2"],
         )
         test_data = lgb.Dataset(X_test, label=Y_test, reference=train_data)
         params = {
@@ -149,8 +149,6 @@ class LightGBM(object):
             real_label = y[index]
             score = item
             prid_label = y_pred_binary[index]
-            if score < 0.169541:
-                print(real_label, "\t", prid_label, "\t", score)
             fw.write("{}\t{}\t{}\n".format(real_label, prid_label, score))
             score_list.append(score)
         print("预测样本总量: {}".format(len(score_list)))
@@ -181,7 +179,10 @@ class LightGBM(object):
         for name, imp in feature_importance:
             print(name, imp)
 
-
+# "cat summary_20240328.txt | awk -F "\t" '{print $1" "$3}'| /root/AUC/AUC/AUC"
+"""
+ ossutil64 cp /root/luojunhui/alg/data/predict_data/spider_predict_result_20240330.xlsx oss://art-pubbucket/0temp/
+"""
 if __name__ == "__main__":
     i = int(input("输入 1 训练, 输入 2 预测:\n"))
     if i == 1: