|
@@ -2,6 +2,7 @@ package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
|
|
|
import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
|
|
|
+import com.tzld.piaoquan.recommend.model.produce.util.JSONUtils;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
@@ -116,11 +117,11 @@ public class XGBoostPredictLocalTest {
|
|
|
CompressUtil.decompressGzFile(gzPath, modelDir);
|
|
|
|
|
|
XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
|
|
|
- model.setMissing(0.0f)
|
|
|
- .setFeaturesCol("features");
|
|
|
+ model.setMissing(0.0f);
|
|
|
+ //Vector p = model.predictProbability(v);
|
|
|
double score = model.predict(v);
|
|
|
|
|
|
- log.info("model.predict {}", score);
|
|
|
+ log.info("model.score {}", score);
|
|
|
}
|
|
|
|
|
|
private static void batchTest() {
|
|
@@ -150,7 +151,7 @@ public class XGBoostPredictLocalTest {
|
|
|
Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
|
|
|
predictData.show();
|
|
|
Dataset<Row> predictions = model.transform(predictData);
|
|
|
- predictions.show();
|
|
|
+ predictions.show(2000);
|
|
|
|
|
|
} catch (Throwable e) {
|
|
|
log.error("", e);
|