|
@@ -1,6 +1,7 @@
|
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import ml.dmlc.xgboost4j.scala.DMatrix;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
@@ -47,6 +48,7 @@ public class XGBoostTrainLocalTest {
|
|
|
.setObjective("binary:logistic")
|
|
|
.setNthread(1)
|
|
|
.setNumRound(100)
|
|
|
+
|
|
|
.setNumWorkers(1);
|
|
|
Dataset<Row> trainData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
|
|
|
trainData.show();
|
|
@@ -124,13 +126,16 @@ public class XGBoostTrainLocalTest {
|
|
|
featureMap.put("b2_6h_ctr", "0.88");
|
|
|
featureMap.put("b2_6h_ctcvr", "0");
|
|
|
|
|
|
- double[] values = new double[features.length];
|
|
|
+ float[] values = new float[features.length];
|
|
|
for (int i = 0; i < features.length; i++) {
|
|
|
- double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0);
|
|
|
+ float v = NumberUtils.toFloat(featureMap.getOrDefault(features[i], "0.0"), 0.0f);
|
|
|
values[i] = v;
|
|
|
}
|
|
|
- Vector v = Vectors.dense(values);
|
|
|
- double result = model.predict(v);
|
|
|
+
|
|
|
+
|
|
|
+ DMatrix dm = new DMatrix(values, 1, features.length, 0.0f);
|
|
|
+ float[][] result = model._booster().predict(dm,false,100);
|
|
|
+
|
|
|
log.info("model.predict {}", result);
|
|
|
|
|
|
} catch (Throwable e) {
|