丁云鹏 8 months ago
parent
commit
f4b9e487de

+ 4 - 0
recommend-model-produce/pom.xml

@@ -54,6 +54,10 @@
                     <artifactId>scala-library</artifactId>
                     <groupId>org.scala-lang</groupId>
                 </exclusion>
+                <exclusion>
+                    <artifactId>hadoop-mapreduce-client-core</artifactId>
+                    <groupId>org.apache.hadoop</groupId>
+                </exclusion>
             </exclusions>
         </dependency>
         <dependency>
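A likely motive for this exclusion (hedged, since the commit message is not shown) is a Hadoop classpath conflict: the Spark runtime already ships its own Hadoop jars, and a second transitive copy of hadoop-mapreduce-client-core can shadow them. A probe like the sketch below can confirm which jar actually serves a MapReduce class once the exclusion is in place; ClasspathProbe and the chosen class are illustrative, not part of this repo.

// Prints the jar the JVM actually loaded a MapReduce class from,
// so you can check that the excluded transitive copy is gone.
public class ClasspathProbe {
    public static void main(String[] args) throws ClassNotFoundException {
        Class<?> c = Class.forName("org.apache.hadoop.mapred.JobConf");
        // getCodeSource() can be null for bootstrap classes, but not for jar-loaded ones
        System.out.println(c.getProtectionDomain().getCodeSource().getLocation());
    }
}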

+ 12 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -7,6 +7,9 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.mllib.linalg.SparseVector;
+import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -39,21 +42,28 @@ public class XGBoostTrain {
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split("\t");
                 int label = NumberUtils.toInt(line[0]);
+                int[] indices = new int[line.length - 1];
                 double[] values = new double[line.length - 1];
                 for (int i = 1; i < line.length; i++) {
                     String[] fv = StringUtils.split(":");
+                    indices[i - 1] = i - 1;
                     values[i - 1] = NumberUtils.toDouble(fv[1], 0.0);
                 }
-                return RowFactory.create(label, values);
+                SparseVector vector = new SparseVector(indices.length, indices, values);
+                return RowFactory.create(label, vector);
             });
+
             log.info("rowRDD count {}", rowRDD.count());
             // convert the JavaRDD<Row> into a Dataset<Row>
             List<StructField> fields = new ArrayList<>();
             fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
+            fields.add(DataTypes.createStructField("features", new VectorUDT(), true));
             StructType schema = DataTypes.createStructType(fields);
             Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
 
+
+
+            // transform the data with VectorAssembler
             // split into training and test sets
             Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
             Dataset<Row> trainData = splits[0];
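Taken together, the Java change swaps the plain double[] features column for a SparseVector backed by a VectorUDT schema field, which is the shape a Spark ML vector column needs. A self-contained sketch of the per-line parse follows; the class name and sample record are illustrative, and it uses org.apache.spark.ml.linalg, which the DataFrame-based XGBoostClassifier generally expects (the commit itself imports the older mllib package).

import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.math.NumberUtils;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class SparseRowSketch {
    // Parses "label\tidx:value\tidx:value..." into a sparse feature vector.
    // As in the commit, the index is just the field's position (i - 1); the
    // "idx" half of each pair is ignored and only the value after ':' is read.
    static Vector toFeatures(String s) {
        String[] line = StringUtils.split(s, "\t");
        int[] indices = new int[line.length - 1];
        double[] values = new double[line.length - 1];
        for (int i = 1; i < line.length; i++) {
            String[] fv = StringUtils.split(line[i], ":");
            indices[i - 1] = i - 1;
            values[i - 1] = NumberUtils.toDouble(fv[1], 0.0);
        }
        return Vectors.sparse(indices.length, indices, values);
    }

    public static void main(String[] args) {
        System.out.println(toFeatures("1\t0:0.5\t1:0.0\t2:1.25"));
        // -> (3,[0,1,2],[0.5,0.0,1.25])
    }
}

If the 0.7/0.3 split needs to be reproducible across runs, Dataset.randomSplit also has a seeded overload: dataset.randomSplit(new double[]{0.7, 0.3}, 42L).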