丁云鹏 8 months ago
parent
commit
9b4a9e846f

+ 18 - 21
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -33,10 +33,10 @@ public class XGBoostService {
 
     public void train(String[] args) {
         try {
-            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
-            log.info("训练样本 show");
-            assembledData.show();
-            // 创建 XGBoostClassifier 对象
+
+            // 训练
+            Dataset<Row> trainData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
+            trainData.show();
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.01f)
                     .setSubsample(0.8)
@@ -51,10 +51,7 @@ public class XGBoostService {
                     .setNthread(1)
                     .setNumRound(100)
                     .setNumWorkers(1);
-
-
-            // 训练模型
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
+            XGBoostClassificationModel model = xgbClassifier.fit(trainData);
 
             // 保存模型
             String path = "/root/recommend-model/modeltrain";
@@ -74,12 +71,8 @@ public class XGBoostService {
     public void predict(String[] args) {
         try {
 
-            Dataset<Row> assembledData =
-                    dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
-            log.info("测试样本 show");
-            assembledData.show();
 
-            // 保存模型
+            // 加载模型
             String bucketName = "art-test-video";
             String objectName = "test/model.tar.gz";
             OSSService ossService = new OSSService();
@@ -89,20 +82,23 @@ public class XGBoostService {
             String destDir = "/root/recommend-model/modelpredict";
             CompressUtil.decompressGzFile(destPath, destDir);
 
-            // 显示预测结果
-            XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
-            model2.setMissing(0.0f)
+            XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + destDir);
+            model.setMissing(0.0f)
                     .setFeaturesCol("features");
 
-            Dataset<Row> predictions = model2.transform(assembledData);
-            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
+
+            // 预测
+            Dataset<Row> predictData =
+                    dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
+            predictData.show();
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.show();
 
             // 计算AUC
-            Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
                     .setRawPredictionCol("rawPrediction");
-            double auc = evaluator.evaluate(selected);
+            double auc = evaluator.evaluate(predictions);
             log.info("AUC: {}", auc);
 
         } catch (Throwable e) {
@@ -111,7 +107,8 @@ public class XGBoostService {
     }
 
     private static Dataset<Row> dataset(String path) {
-        String[] features = {"cpa",
+        String[] features = {
+                "cpa",
                 "b2_1h_ctr",
                 "b2_1h_ctcvr",
                 "b2_1h_cvr",

+ 37 - 28
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -30,9 +31,7 @@ public class XGBoostTrainLocalTest {
 
     public static void main(String[] args) {
         try {
-            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
-
-            // 创建 XGBoostClassifier 对象
+            // 训练
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
                     .setEta(0.01f)
                     .setSubsample(0.8)
@@ -47,22 +46,23 @@ public class XGBoostTrainLocalTest {
                     .setNthread(1)
                     .setNumRound(100)
                     .setNumWorkers(1);
+            Dataset<Row> trainData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            trainData.show();
+            XGBoostClassificationModel model = xgbClassifier.fit(trainData);
 
 
-            // 显示预测结果
+            // 预测
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
-
-            // 训练模型
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
+            model.setFeaturesCol("features").setMissing(0.0f);
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.select("prediction", "rawPrediction", "probability", "features").show();
+            predictions.show();
+
 
             // 计算AUC
             Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
-                    .setRawPredictionCol("rawPrediction")
-                    .setMetricName("areaUnderROC");
+                    .setRawPredictionCol("rawPrediction");
             double auc = evaluator.evaluate(selected);
 
             log.info("AUC: {}", auc);
@@ -73,21 +73,31 @@ public class XGBoostTrainLocalTest {
     }
 
     private static Dataset<Row> dataset(String path) {
-        String[] features = {"cpa",
-                "b2_12h_ctr",
-                "b2_12h_ctcvr",
-                "b2_12h_cvr",
-                "b2_12h_conver",
-                "b2_12h_click",
-                "b2_12h_conver*log(view)",
-                "b2_12h_conver*ctcvr",
-                "b2_7d_ctr",
-                "b2_7d_ctcvr",
-                "b2_7d_cvr",
-                "b2_7d_conver",
-                "b2_7d_click",
-                "b2_7d_conver*log(view)",
-                "b2_7d_conver*ctcvr"
+        String[] features = {
+                "cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
         };
 
 
@@ -115,7 +125,7 @@ public class XGBoostTrainLocalTest {
             for (int i = 0; i < features.length; i++) {
                 v[i + 1] = map.getOrDefault(features[i], 0.0d);
             }
-
+            //v[0] = (double) v[1] > 0.05 ? 1.0 : 0.0;
             return RowFactory.create(v);
         });
 
@@ -133,8 +143,7 @@ public class XGBoostTrainLocalTest {
                 .setInputCols(features)
                 .setOutputCol("features");
 
-        Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
-        assembledData.show();
+        Dataset<Row> assembledData = assembler.transform(dataset);
         return assembledData;
     }
 }