丁云鹏 9 months ago
parent
commit
da7830ed51

+ 1 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -97,7 +97,7 @@ public class XGBoostService {
             Dataset<Row> predictData = dataset(path);
             predictData.show();
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.show();
+            predictions.show(2000);
 
             // 计算AUC
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
@@ -183,7 +183,6 @@ public class XGBoostService {
                 .setOutputCol("features");
 
         Dataset<Row> assembledData = assembler.transform(dataset);
-        assembledData.show();
         return assembledData;
     }
 }

+ 93 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredictLocalTest.java

@@ -12,6 +12,8 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -32,7 +34,98 @@ import java.util.Map;
 public class XGBoostPredictLocalTest {
 
     public static void main(String[] args) { // Entry point for the local prediction check; args is not read.
+        //batchTest(); // batch variant is left disabled here, so only singleTest() runs
+        singleTest();
+    }
+
+    private static void singleTest() { // Builds one hand-written feature vector, loads the XGBoost model from local disk and logs model.predict for it.
+        String[] features = { // Feature names; their order here fixes each value's index in the dense vector built below.
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
+        };
+
+        Map<String, String> featureMap = new HashMap<>(); // Sample input: feature name -> value kept as a string, parsed to double below.
+        featureMap.put("cpa", "0.1");
+        featureMap.put("b2_1h_ctr", "0");
+        featureMap.put("b2_1h_ctcvr", "0");
+        featureMap.put("b2_1h_cvr", "0");
+        featureMap.put("b2_1h_conver", "0");
+        featureMap.put("b2_1h_click", "0");
+        featureMap.put("b2_1h_conver*log(view)", "0");
+        featureMap.put("b2_1h_conver*ctcvr", "0");
+        featureMap.put("b2_2h_ctr", "0");
+        featureMap.put("b2_2h_ctcvr", "0");
+        featureMap.put("b2_2h_cvr", "0");
+        featureMap.put("b2_2h_conver", "0");
+        featureMap.put("b2_2h_click", "0");
+        featureMap.put("b2_2h_conver*log(view)", "0");
+        featureMap.put("b2_2h_conver*ctcvr", "0");
+        featureMap.put("b2_3h_ctr", "0.89");
+        featureMap.put("b2_3h_ctcvr", "0");
+        featureMap.put("b2_3h_cvr", "0");
+        featureMap.put("b2_3h_conver", "0");
+        featureMap.put("b2_3h_click", "0.01");
+        featureMap.put("b2_3h_conver*log(view)", "0");
+        featureMap.put("b2_3h_conver*ctcvr", "0");
+        featureMap.put("b2_6h_ctr", "0.88");
+        featureMap.put("b2_6h_ctcvr", "0");
+
+        double[] values = new double[features.length]; // One slot per entry of features, filled in the same order.
+        for (int i = 0; i < features.length; i++) {
+            double v = NumberUtils.toDouble(featureMap.getOrDefault(features[i], "0.0"), 0.0); // Absent names use "0.0"; the trailing 0.0 is the fallback handed to NumberUtils.toDouble.
+            values[i] = v;
+        }
+        Vector v = Vectors.dense(values); // Dense Spark ML vector that is passed to model.predict below.
+
+
+        SparkConf sparkConf = new SparkConf()
+                .setMaster("local") // Runs Spark with the "local" master, i.e. inside this JVM.
+                .setAppName("XGBoostPredict");
+        JavaSparkContext jsc = new JavaSparkContext(sparkConf); // NOTE(review): jsc is never referenced again or closed in this method; presumably created only so a Spark context exists for the model load - confirm, and consider closing it.
+
+
+        String bucketName = "art-test-video";
+        String objectName = "test/model.tar.gz";
+        OSSService ossService = new OSSService();
+
+        String gzPath = "/Users/dingyunpeng/Desktop/model2.tar.gz"; // NOTE(review): hard-coded path under one user's Desktop; will not exist on other machines.
+        ossService.download(bucketName, gzPath, objectName); // Presumably fetches objectName from bucketName into gzPath - verify against OSSService.
+        String modelDir = "/Users/dingyunpeng/Desktop/modelpredict"; // Directory the model is loaded from below; same machine-specific caveat as gzPath.
+        CompressUtil.decompressGzFile(gzPath, modelDir); // Presumably unpacks the archive into modelDir - verify against CompressUtil.
+
+        XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir); // Loads the model from the local directory via a file:// URI.
+        model.setMissing(0.0f) // Sets the model's "missing" parameter to 0.0f; presumably has to match the training configuration - TODO confirm.
+                .setFeaturesCol("features");
+        double score = model.predict(v); // Scores the single vector directly rather than transforming a Dataset.
+
+        log.info("model.predict {}", score);
+    }
+
+    private static void batchTest() {
         try {
+
             SparkConf sparkConf = new SparkConf()
                     .setMaster("local")
                     .setAppName("XGBoostPredict");