| 
					
				 | 
			
			
				@@ -4,10 +4,10 @@ import lombok.extern.slf4j.Slf4j; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.commons.lang.math.NumberUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import org.apache.commons.lang3.RandomUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.commons.lang3.StringUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.api.java.JavaRDD; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.api.java.JavaSparkContext; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.ml.feature.VectorAssembler; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.Dataset; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.Row; 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -31,111 +31,22 @@ public class XGBoostTrainLocalTest { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     public static void main(String[] args) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            String[] features = {"cpa", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_12h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_7d_conver*ctcvr" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            }; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            SparkSession spark = SparkSession.builder() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .appName("XGBoostTrain") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .master("local") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .getOrCreate(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            file = "/Users/dingyunpeng/Desktop/part-00099.gz"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            JavaRDD<String> rdd = jsc.textFile(file); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//            JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                    String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-// 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                int[] indices = new int[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                double[] values = new double[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                    indices[i] = i; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                    values[i] = map.getOrDefault(features[i], 0.0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                SparseVector vector = new SparseVector(indices.length, indices, values); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//                return RowFactory.create(label, vector); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-//            }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                Object[] v = new Object[features.length + 1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                v[0] = label; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                v[0] = RandomUtils.nextInt(0, 2); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    v[i + 1] = map.getOrDefault(features[i], 0.0d); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                return RowFactory.create(v); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            log.info("rowRDD count {}", rowRDD.count()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            // 将 JavaRDD<Row> 转换为 Dataset<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            List<StructField> fields = new ArrayList<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for (String f : features) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            StructType schema = DataTypes.createStructType(fields); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            VectorAssembler assembler = new VectorAssembler() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .setInputCols(features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .setOutputCol("features"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row> assembledData = assembler.transform(dataset); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            assembledData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            // 划分训练集和测试集 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row> trainData = splits[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            trainData.show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row> testData = splits[1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            testData.show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            // 参数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 创建 XGBoostClassifier 对象 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             XGBoostClassifier xgbClassifier = new XGBoostClassifier() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .setEta(0.1f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setEta(0.01f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setSubsample(0.8) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setColsampleBytree(0.8) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setScalePosWeight(1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setSeed(2024) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setMissing(0.0f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setFeaturesCol("features") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setLabelCol("label") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setMaxDepth(5) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setObjective("binary:logistic") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setNthread(1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    .setNumRound(5) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setNumRound(100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     .setNumWorkers(1); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -143,12 +54,88 @@ public class XGBoostTrainLocalTest { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             XGBoostClassificationModel model = xgbClassifier.fit(assembledData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 显示预测结果 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row> predictions = model.transform(assembledData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> predictions = model.transform(predictData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 计算AUC 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> selected = predictions.select("label", "rawPrediction"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setLabelCol("label") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setRawPredictionCol("rawPrediction") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setMetricName("areaUnderROC"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            double auc = evaluator.evaluate(selected); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            log.info("AUC: {}", auc); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         } catch (Throwable e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             log.error("", e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private static Dataset<Row> dataset(String path) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String[] features = {"cpa", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_12h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "b2_7d_conver*ctcvr" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        SparkSession spark = SparkSession.builder() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .appName("XGBoostTrain") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .master("local") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .getOrCreate(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String file = path; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaRDD<String> rdd = jsc.textFile(file); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Object[] v = new Object[features.length + 1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            v[0] = label; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                v[i + 1] = map.getOrDefault(features[i], 0.0d); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return RowFactory.create(v); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        log.info("rowRDD count {}", rowRDD.count()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 将 JavaRDD<Row> 转换为 Dataset<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        List<StructField> fields = new ArrayList<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for (String f : features) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructType schema = DataTypes.createStructType(fields); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        VectorAssembler assembler = new VectorAssembler() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .setInputCols(features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .setOutputCol("features"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Dataset<Row> assembledData = assembler.transform(dataset); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assembledData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return assembledData; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 } 
			 |