| 
					
				 | 
			
			
				@@ -0,0 +1,154 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+package com.tzld.piaoquan.recommend.model.produce.xgboost; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import lombok.extern.slf4j.Slf4j; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang.math.NumberUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang3.RandomUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang3.StringUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.api.java.JavaRDD; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.api.java.JavaSparkContext; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.ml.feature.VectorAssembler; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.Dataset; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.Row; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.RowFactory; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.SparkSession; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.types.DataTypes; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.types.StructField; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.types.StructType; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.ArrayList; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.HashMap; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.List; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.Map; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+/** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * @author dyp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+@Slf4j 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+public class XGBoostTrainLocalTest { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public static void main(String[] args) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String[] features = {"cpa", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver*ctcvr" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            }; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            SparkSession spark = SparkSession.builder() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .appName("XGBoostTrain") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .master("local") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .getOrCreate(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            file = "/Users/dingyunpeng/Desktop/part-00099.gz"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            JavaRDD<String> rdd = jsc.textFile(file); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+// 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                int[] indices = new int[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                double[] values = new double[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    indices[i] = i; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    values[i] = map.getOrDefault(features[i], 0.0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                SparseVector vector = new SparseVector(indices.length, indices, values); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                return RowFactory.create(label, vector); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Object[] v = new Object[features.length + 1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                v[0] = label; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                v[0] = RandomUtils.nextInt(0, 2); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    v[i + 1] = map.getOrDefault(features[i], 0.0d); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return RowFactory.create(v); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            log.info("rowRDD count {}", rowRDD.count()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 将 JavaRDD<Row> 转换为 Dataset<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            List<StructField> fields = new ArrayList<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for (String f : features) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            StructType schema = DataTypes.createStructType(fields); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            VectorAssembler assembler = new VectorAssembler() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setInputCols(features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setOutputCol("features"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> assembledData = assembler.transform(dataset); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            assembledData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 划分训练集和测试集 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> trainData = splits[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            trainData.show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> testData = splits[1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            testData.show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 参数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 创建 XGBoostClassifier 对象 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            XGBoostClassifier xgbClassifier = new XGBoostClassifier() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setEta(0.1f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setMissing(0.0f) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setFeaturesCol("features") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setLabelCol("label") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setMaxDepth(5) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setObjective("binary:logistic") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setNthread(1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setNumRound(5) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setNumWorkers(1); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 训练模型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            XGBoostClassificationModel model = xgbClassifier.fit(assembledData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 显示预测结果 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> predictions = model.transform(assembledData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Throwable e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            log.error("", e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+} 
			 |