丁云鹏 9 months ago
parent
commit
1efcfdbe07

+ 1 - 1
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -97,7 +97,7 @@ public class XGBoostService {
             Dataset<Row> predictData = dataset(path);
             predictData.show();
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.show(2000);
+            predictions.show(50000);
 
             // 计算AUC
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()

+ 5 - 4
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredictLocalTest.java

@@ -2,6 +2,7 @@ package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
 import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
 import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import com.tzld.piaoquan.recommend.model.produce.util.JSONUtils;
 import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
@@ -116,11 +117,11 @@ public class XGBoostPredictLocalTest {
         CompressUtil.decompressGzFile(gzPath, modelDir);
 
         XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
-        model.setMissing(0.0f)
-                .setFeaturesCol("features");
+        model.setMissing(0.0f);
+        //Vector p = model.predictProbability(v);
         double score = model.predict(v);
 
-        log.info("model.predict {}", score);
+        log.info("model.score {}", score);
     }
 
     private static void batchTest() {
@@ -150,7 +151,7 @@ public class XGBoostPredictLocalTest {
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
             predictData.show();
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.show();
+            predictions.show(2000);
 
         } catch (Throwable e) {
             log.error("", e);