Explorar o código

Merge branch 'main'

zhangbo hai 8 meses
pai
achega
affed96ee0

+ 4 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/CMDService.java

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.model.produce.service;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -11,6 +12,9 @@ import java.util.Map;
 public class CMDService {
 
     public Map<String, String> parse(String[] args) {
+        if (args == null) {
+            return Collections.emptyMap();
+        }
         Map<String, String> map = new HashMap<>();
         for (int i = 0; i < args.length - 1; i++) {
             map.put(args[i].substring(1), args[i + 1]);

+ 38 - 33
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -33,10 +33,12 @@ public class XGBoostService {
 
     public void train(String[] args) {
         try {
-            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
-            log.info("训练样本 show");
-            assembledData.show();
-            // 创建 XGBoostClassifier 对象
+            CMDService cmd = new CMDService();
+            Map<String, String> argMap = cmd.parse(args);
+            String path = argMap.get("path");
+            // 训练
+            Dataset<Row> trainData = dataset(path);
+            trainData.show();
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.01f)
                     .setSubsample(0.8)
@@ -51,20 +53,17 @@ public class XGBoostService {
                     .setNthread(1)
                     .setNumRound(100)
                     .setNumWorkers(1);
-
-
-            // 训练模型
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
+            XGBoostClassificationModel model = xgbClassifier.fit(trainData);
 
             // 保存模型
-            String path = "/root/recommend-model/modeltrain";
-            model.write().overwrite().save("file://" + path);
-            String outputPath = "/root/recommend-model/model.tar.gz";
-            CompressUtil.compressDirectoryToGzip(path, outputPath);
+            String modelPath = "/root/recommend-model/modeltrain";
+            model.write().overwrite().save("file://" + modelPath);
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            CompressUtil.compressDirectoryToGzip(modelPath, gzPath);
             String bucketName = "art-test-video";
             String ossPath = "test/model.tar.gz";
             OSSService ossService = new OSSService();
-            ossService.upload(bucketName, outputPath, ossPath);
+            ossService.upload(bucketName, gzPath, ossPath);
 
         } catch (Throwable e) {
             log.error("", e);
@@ -74,32 +73,37 @@ public class XGBoostService {
     public void predict(String[] args) {
         try {
 
-            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
-            log.info("测试样本 show");
-            assembledData.show();
+            CMDService cmd = new CMDService();
+            Map<String, String> argMap = cmd.parse(args);
+            String path = argMap.get("path");
 
-            // 保存模型
+
+            // 加载模型
             String bucketName = "art-test-video";
             String objectName = "test/model.tar.gz";
             OSSService ossService = new OSSService();
 
-            String destPath = "/root/recommend-model/model2.tar.gz";
-            ossService.download(bucketName, destPath, objectName);
-            String destDir = "/root/recommend-model/modelpredict";
-            CompressUtil.decompressGzFile(destPath, destDir);
+            String gzPath = "/root/recommend-model/model2.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model/modelpredict";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
+            model.setMissing(0.0f)
+                    .setFeaturesCol("features");
+
 
-            // 显示预测结果
-            XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
-            Dataset<Row> predictions = model2.transform(assembledData);
-            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
+            // 预测
+            Dataset<Row> predictData = dataset(path);
+            predictData.show();
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.show();
 
             // 计算AUC
-            Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
-                    .setRawPredictionCol("rawPrediction")
-                    .setMetricName("areaUnderROC");
-            double auc = evaluator.evaluate(selected);
+                    .setRawPredictionCol("rawPrediction");
+            double auc = evaluator.evaluate(predictions);
             log.info("AUC: {}", auc);
 
         } catch (Throwable e) {
@@ -108,7 +112,8 @@ public class XGBoostService {
     }
 
     private static Dataset<Row> dataset(String path) {
-        String[] features = {"cpa",
+        String[] features = {
+                "cpa",
                 "b2_1h_ctr",
                 "b2_1h_ctcvr",
                 "b2_1h_cvr",
@@ -146,7 +151,7 @@ public class XGBoostService {
 
         JavaRDD<Row> rowRDD = rdd.map(s -> {
             String[] line = StringUtils.split(s, '\t');
-            int label = NumberUtils.toInt(line[0]);
+            double label = NumberUtils.toDouble(line[0]);
             // 选特征
             Map<String, Double> map = new HashMap<>();
             for (int i = 1; i < line.length; i++) {
@@ -166,7 +171,7 @@ public class XGBoostService {
         log.info("rowRDD count {}", rowRDD.count());
         // 将 JavaRDD<Row> 转换为 Dataset<Row>
         List<StructField> fields = new ArrayList<>();
-        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
         for (String f : features) {
             fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
         }
@@ -177,7 +182,7 @@ public class XGBoostService {
                 .setInputCols(features)
                 .setOutputCol("features");
 
-        Dataset<Row> assembledData = assembler.transform(dataset);
+        Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
         assembledData.show();
         return assembledData;
     }

+ 38 - 30
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -30,10 +31,7 @@ public class XGBoostTrainLocalTest {
 
     public static void main(String[] args) {
         try {
-
-            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
-
-            // 创建 XGBoostClassifier 对象
+            // 训练
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.01f)
                     .setSubsample(0.8)
@@ -48,22 +46,23 @@ public class XGBoostTrainLocalTest {
                     .setNthread(1)
                     .setNumRound(100)
                     .setNumWorkers(1);
+            Dataset<Row> trainData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            trainData.show();
+            XGBoostClassificationModel model = xgbClassifier.fit(trainData);
 
 
-            // 训练模型
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
-
-            // 显示预测结果
+            // 预测
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            model.setFeaturesCol("features").setMissing(0.0f);
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
+            predictions.show();
+
 
             // 计算AUC
             Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
-                    .setRawPredictionCol("rawPrediction")
-                    .setMetricName("areaUnderROC");
+                    .setRawPredictionCol("rawPrediction");
             double auc = evaluator.evaluate(selected);
 
             log.info("AUC: {}", auc);
@@ -74,21 +73,31 @@ public class XGBoostTrainLocalTest {
     }
 
     private static Dataset<Row> dataset(String path) {
-        String[] features = {"cpa",
-                "b2_12h_ctr",
-                "b2_12h_ctcvr",
-                "b2_12h_cvr",
-                "b2_12h_conver",
-                "b2_12h_click",
-                "b2_12h_conver*log(view)",
-                "b2_12h_conver*ctcvr",
-                "b2_7d_ctr",
-                "b2_7d_ctcvr",
-                "b2_7d_cvr",
-                "b2_7d_conver",
-                "b2_7d_click",
-                "b2_7d_conver*log(view)",
-                "b2_7d_conver*ctcvr"
+        String[] features = {
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
         };
 
 
@@ -103,7 +112,7 @@ public class XGBoostTrainLocalTest {
 
         JavaRDD<Row> rowRDD = rdd.map(s -> {
             String[] line = StringUtils.split(s, '\t');
-            int label = NumberUtils.toInt(line[0]);
+            double label = NumberUtils.toDouble(line[0]);
             // 选特征
             Map<String, Double> map = new HashMap<>();
             for (int i = 1; i < line.length; i++) {
@@ -116,14 +125,14 @@ public class XGBoostTrainLocalTest {
             for (int i = 0; i < features.length; i++) {
                 v[i + 1] = map.getOrDefault(features[i], 0.0d);
             }
-
+            //v[0] = (double) v[1] > 0.05 ? 1.0 : 0.0;
             return RowFactory.create(v);
         });
 
         log.info("rowRDD count {}", rowRDD.count());
         // 将 JavaRDD<Row> 转换为 Dataset<Row>
         List<StructField> fields = new ArrayList<>();
-        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
         for (String f : features) {
             fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
         }
@@ -135,7 +144,6 @@ public class XGBoostTrainLocalTest {
                 .setOutputCol("features");
 
         Dataset<Row> assembledData = assembler.transform(dataset);
-        assembledData.show();
         return assembledData;
     }
 }