丁云鹏 hace 9 meses
padre
commit
a6b14a7d15

+ 140 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredictLocalTest.java

@@ -0,0 +1,140 @@
+package com.tzld.piaoquan.recommend.model.produce.xgboost;
+
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class XGBoostPredictLocalTest {
+
+    public static void main(String[] args) {
+        try {
+            SparkConf sparkConf = new SparkConf()
+                    .setMaster("local")
+                    .setAppName("XGBoostPredict");
+            JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+
+
+            String bucketName = "art-test-video";
+            String objectName = "test/model.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/Users/dingyunpeng/Desktop/model2.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/Users/dingyunpeng/Desktop/modelpredict";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
+            model.setMissing(0.0f)
+                    .setFeaturesCol("features");
+
+
+            // 预测
+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            predictData.show();
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.show();
+
+        } catch (Throwable e) {
+            log.error("", e);
+        }
+    }
+
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            double label = NumberUtils.toDouble(line[0]);
+            // 选特征
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // 将 JavaRDD<Row> 转换为 Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        return assembledData;
+    }
+}