丁云鹏 8 months ago
Parent
Current commit
6cb2105348

+ 7 - 27
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -58,27 +59,6 @@ public class XGBoostTrain {
             //file = "/Users/dingyunpeng/Desktop/part-00099.gz";
             JavaRDD<String> rdd = jsc.textFile(file);
 
-            // Convert RDD[LabeledPoint] to JavaRDD<Row>
-//            JavaRDD<Row> rowRDD = rdd.map(s -> {
-//                String[] line = StringUtils.split(s, '\t');
-//                int label = NumberUtils.toInt(line[0]);
-//                // select features
-//                Map<String, Double> map = new HashMap<>();
-//                for (int i = 1; i < line.length; i++) {
-//                    String[] fv = StringUtils.split(line[i], ':');
-//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-//                }
-//
-//                int[] indices = new int[features.length];
-//                double[] values = new double[features.length];
-//                for (int i = 0; i < features.length; i++) {
-//                    indices[i] = i;
-//                    values[i] = map.getOrDefault(features[i], 0.0);
-//                }
-//                SparseVector vector = new SparseVector(indices.length, indices, values);
-//                return RowFactory.create(label, vector);
-//            });
-
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split(s, '\t');
                 int label = NumberUtils.toInt(line[0]);
@@ -91,7 +71,7 @@ public class XGBoostTrain {
 
                 Object[] v = new Object[features.length + 1];
                 v[0] = label;
-                //v[0] = RandomUtils.nextInt(0, 2);
+                v[0] = RandomUtils.nextInt(0, 2);
                 for (int i = 0; i < features.length; i++) {
                     v[i + 1] = map.getOrDefault(features[i], 0.0d);
                 }
@@ -118,9 +98,9 @@ public class XGBoostTrain {
             // Split into training and test sets
             Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
             Dataset<Row> trainData = splits[0];
-            trainData.show();
+            trainData.show(500);
             Dataset<Row> testData = splits[1];
-            testData.show();
+            testData.show(500);
 
             // Parameters
 
@@ -139,11 +119,11 @@ public class XGBoostTrain {
 
 
             // Train the model
-            XGBoostClassificationModel model = xgbClassifier.fit(trainData);
+            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
 
             // Show prediction results
-            Dataset<Row> predictions = model.transform(testData);
-            predictions.show(100);
+            Dataset<Row> predictions = model.transform(assembledData);
+            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
 
 
         } catch (Throwable e) {

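Because this revision both fits and scores on the full assembledData, the testData split above is no longer exercised; a held-out metric would have to be computed separately. A minimal sketch using Spark ML's BinaryClassificationEvaluator (this evaluation step is an assumption for illustration, not part of the commit):

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;

    // Score the held-out 30% split and report area under the ROC curve.
    Dataset<Row> testPredictions = model.transform(testData);
    double auc = new BinaryClassificationEvaluator()
            .setLabelCol("label")
            .setRawPredictionCol("rawPrediction")
            .setMetricName("areaUnderROC")
            .evaluate(testPredictions);
    log.info("held-out AUC {}", auc);
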
+ 154 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -0,0 +1,154 @@
+package com.tzld.piaoquan.recommend.model.produce.xgboost;
+
+import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class XGBoostTrainLocalTest {
+
+    public static void main(String[] args) {
+        try {
+
+            String[] features = {"cpa",
+                    "b2_12h_ctr",
+                    "b2_12h_ctcvr",
+                    "b2_12h_cvr",
+                    "b2_12h_conver",
+                    "b2_12h_click",
+                    "b2_12h_conver*log(view)",
+                    "b2_12h_conver*ctcvr",
+                    "b2_7d_ctr",
+                    "b2_7d_ctcvr",
+                    "b2_7d_cvr",
+                    "b2_7d_conver",
+                    "b2_7d_click",
+                    "b2_7d_conver*log(view)",
+                    "b2_7d_conver*ctcvr"
+            };
+
+
+            SparkSession spark = SparkSession.builder()
+                    .appName("XGBoostTrain")
+                    .master("local")
+                    .getOrCreate();
+
+            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz";
+            file = "/Users/dingyunpeng/Desktop/part-00099.gz";
+            JavaRDD<String> rdd = jsc.textFile(file);
+
+            // Convert RDD[LabeledPoint] to JavaRDD<Row>
+//            JavaRDD<Row> rowRDD = rdd.map(s -> {
+//                String[] line = StringUtils.split(s, '\t');
+//                int label = NumberUtils.toInt(line[0]);
+//                // select features
+//                Map<String, Double> map = new HashMap<>();
+//                for (int i = 1; i < line.length; i++) {
+//                    String[] fv = StringUtils.split(line[i], ':');
+//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+//                }
+//
+//                int[] indices = new int[features.length];
+//                double[] values = new double[features.length];
+//                for (int i = 0; i < features.length; i++) {
+//                    indices[i] = i;
+//                    values[i] = map.getOrDefault(features[i], 0.0);
+//                }
+//                SparseVector vector = new SparseVector(indices.length, indices, values);
+//                return RowFactory.create(label, vector);
+//            });
+
+            JavaRDD<Row> rowRDD = rdd.map(s -> {
+                String[] line = StringUtils.split(s, '\t');
+                int label = NumberUtils.toInt(line[0]);
+                // select features
+                Map<String, Double> map = new HashMap<>();
+                for (int i = 1; i < line.length; i++) {
+                    String[] fv = StringUtils.split(line[i], ':');
+                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+                }
+
+                Object[] v = new Object[features.length + 1];
+                v[0] = label;
+                v[0] = RandomUtils.nextInt(0, 2);
+                for (int i = 0; i < features.length; i++) {
+                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
+                }
+
+                return RowFactory.create(v);
+            });
+
+            log.info("rowRDD count {}", rowRDD.count());
+            // Convert JavaRDD<Row> to Dataset<Row>
+            List<StructField> fields = new ArrayList<>();
+            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+            for (String f : features) {
+                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+            }
+            StructType schema = DataTypes.createStructType(fields);
+            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+            VectorAssembler assembler = new VectorAssembler()
+                    .setInputCols(features)
+                    .setOutputCol("features");
+
+            Dataset<Row> assembledData = assembler.transform(dataset);
+            assembledData.show();
+            // Split into training and test sets
+            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
+            Dataset<Row> trainData = splits[0];
+            trainData.show(500);
+            Dataset<Row> testData = splits[1];
+            testData.show(500);
+
+            // Parameters
+
+
+            // Create the XGBoostClassifier
+            XGBoostClassifier xgbClassifier = new XGBoostClassifier()
+                    .setEta(0.1f)
+                    .setMissing(0.0f)
+                    .setFeaturesCol("features")
+                    .setLabelCol("label")
+                    .setMaxDepth(5)
+                    .setObjective("binary:logistic")
+                    .setNthread(1)
+                    .setNumRound(5)
+                    .setNumWorkers(1);
+
+
+            // Train the model
+            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
+
+            // Show prediction results
+            Dataset<Row> predictions = model.transform(assembledData);
+            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
+
+
+        } catch (Throwable e) {
+            log.error("", e);
+        }
+    }
+}
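
For eyeballing the probability column selected above, the MLlib Vector in each row can be unpacked; a hedged sketch, taking index 1 as P(label == 1) under the binary:logistic objective:

    import org.apache.spark.ml.linalg.Vector;

    // Each row's "probability" cell is an MLlib Vector of class probabilities.
    predictions.select("label", "probability").toJavaRDD()
            .map(r -> r.getInt(0) + "\t" + ((Vector) r.get(1)).apply(1))
            .take(20)
            .forEach(System.out::println);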