| 
					
				 | 
			
			
				@@ -5,9 +5,11 @@ import lombok.extern.slf4j.Slf4j; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.commons.lang.math.NumberUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang3.RandomUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.commons.lang3.StringUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.api.java.JavaRDD; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.api.java.JavaSparkContext; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.ml.feature.VectorAssembler; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.ml.linalg.SparseVector; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.ml.linalg.VectorUDT; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.Dataset; 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -18,10 +20,8 @@ import org.apache.spark.sql.types.DataTypes; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.types.StructField; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.types.StructType; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import java.util.ArrayList; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import java.util.HashMap; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import java.util.List; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import java.util.Map; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.*; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.stream.Collectors; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 /** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  * @author dyp 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -32,30 +32,22 @@ public class XGBoostTrain { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     public static void main(String[] args) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            List<String> features = Lists.newArrayList("cpa", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_1h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_2h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_3h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_6h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    "b2_6h_ctcvr"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String[] features = {"cpa", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_12h_conver*ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_ctr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_ctcvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_cvr", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_click", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver*log(view)", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    "b2_7d_conver*ctcvr" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            }; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             SparkSession spark = SparkSession.builder() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -69,6 +61,26 @@ public class XGBoostTrain { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             JavaRDD<String> rdd = jsc.textFile(file); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                // 选特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                Map<String, Double> map = new HashMap<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                for (int i = 1; i < line.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    String[] fv = StringUtils.split(line[i], ':'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+// 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                int[] indices = new int[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                double[] values = new double[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    indices[i] = i; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                    values[i] = map.getOrDefault(features[i], 0.0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                SparseVector vector = new SparseVector(indices.length, indices, values); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//                return RowFactory.create(label, vector); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+//            }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             JavaRDD<Row> rowRDD = rdd.map(s -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 String[] line = StringUtils.split(s, '\t'); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 int label = NumberUtils.toInt(line[0]); 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -79,29 +91,40 @@ public class XGBoostTrain { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                int[] indices = new int[features.size()]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                double[] values = new double[features.size()]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for (int i = 0; i < features.size(); i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    indices[i] = i; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    values[i] = map.getOrDefault(features.get(i), 0.0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Object[] v = new Object[features.length + 1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                v[0] = label; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                v[0] = RandomUtils.nextInt(0, 2); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                double[] values = new double[features.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for (int i = 0; i < features.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    values[i] = map.getOrDefault(features[i], 0.0d); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    v[i + 1] = map.getOrDefault(features[i], 0.0d); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                SparseVector vector = new SparseVector(indices.length, indices, values); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                return RowFactory.create(label, vector); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return RowFactory.create(v); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             log.info("rowRDD count {}", rowRDD.count()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 将 JavaRDD<Row> 转换为 Dataset<Row> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             List<StructField> fields = new ArrayList<>(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            fields.add(DataTypes.createStructField("features", new VectorUDT(), true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for (String f : features) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             StructType schema = DataTypes.createStructType(fields); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            VectorAssembler assembler = new VectorAssembler() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setInputCols(features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    .setOutputCol("features"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row> assembledData = assembler.transform(dataset); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            assembledData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 划分训练集和测试集 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3}); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             Dataset<Row> trainData = splits[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            trainData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             Dataset<Row> testData = splits[1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            testData.show(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 参数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -124,7 +147,7 @@ public class XGBoostTrain { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             // 显示预测结果 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             Dataset<Row> predictions = model.transform(testData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            predictions.select("label", "prediction").show(30000); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predictions.show(100); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         } catch (Throwable e) { 
			 |