罗俊辉 1 年之前
父节点
当前提交
bb8c76caa8
共有 1 个文件被更改,包括 4 次插入2 次删除
  1. 4 2
      main.py

+ 4 - 2
main.py

@@ -31,10 +31,12 @@ with open("whole_data/x_data.json") as f1:
     X_train['uid'] = X_train['uid'].astype(str)
     X_train['type'] = X_train['type'].astype(str)
     X_train['channel'] = X_train['channel'].astype(str)
+    X_train['mode'] = X_train['mode'].astype(str)
     X_test = pd.DataFrame(x_list[10000:], columns=my_c)
     X_test['uid'] = X_test['uid'].astype(str)
     X_test['type'] = X_test['type'].astype(str)
     X_test['channel'] = X_test['channel'].astype(str)
+    X_test['mode'] = X_test['mode'].astype(str)
 
 with open("whole_data/y_data.json") as f2:
     y_list = json.loads(f2.read())
@@ -42,7 +44,7 @@ with open("whole_data/y_data.json") as f2:
     y_test = np.array(y_list[10000:])
 
 # 创建LightGBM数据集
-train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=['uid', 'type', 'channel'])
+train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=['uid', 'type', 'channel', 'mode'])
 test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)
 
 # 设置模型的参数
@@ -58,7 +60,7 @@ params = {
 
 # 训练模型
 num_round = 100
-bst = lgb.train(params, train_data, num_round)
+bst = lgb.train(params, train_data, num_round, valid_sets=[test_data])
 
 # 预测
 y_pred = bst.predict(X_test, num_iteration=bst.best_iteration)