@@ -52,7 +52,7 @@ public class XGBoostService {
.setObjective("binary:logistic")
.setNthread(1)
.setNumRound(100)
- .setNumWorkers(1);
+ .setNumWorkers(4);
XGBoostClassificationModel model = xgbClassifier.fit(trainData);
// 保存模型
@@ -211,6 +211,11 @@ public class XGBoostPredictLocalTest {
for (int i = 0; i < features.length; i++) {
v[i + 1] = map.getOrDefault(features[i], 0.0d);
}
+ if ((double) v[1] > 0.02) {
+ v[0] = 1;
+ } else {
+ v[0] = 0;
+ }
return RowFactory.create(v);
});