|  | @@ -75,8 +75,9 @@ object recsys_01_ros_reg_xgb_train {
 | 
											
												
													
														|  |      fields = fields ++ Array(
 |  |      fields = fields ++ Array(
 | 
											
												
													
														|  |        DataTypes.createStructField("logKey", DataTypes.StringType, true)
 |  |        DataTypes.createStructField("logKey", DataTypes.StringType, true)
 | 
											
												
													
														|  |      )
 |  |      )
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  |      val schema = DataTypes.createStructType(fields)
 |  |      val schema = DataTypes.createStructType(fields)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
 |  |      val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
 | 
											
												
													
														|  |      val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
 |  |      val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
 | 
											
												
													
														|  |      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label").persist()
 |  |      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label").persist()
 | 
											
										
											
												
													
														|  | @@ -108,7 +109,14 @@ object recsys_01_ros_reg_xgb_train {
 | 
											
												
													
														|  |        features
 |  |        features
 | 
											
												
													
														|  |      )
 |  |      )
 | 
											
												
													
														|  |      val testDataSet = spark.createDataFrame(testData, schema)
 |  |      val testDataSet = spark.createDataFrame(testData, schema)
 | 
											
												
													
														|  | 
 |  | +    println("recsys ros testDataSet schema")
 | 
											
												
													
														|  | 
 |  | +    testDataSet.printSchema()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
 |  |      val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
 | 
											
												
													
														|  | 
 |  | +    println("recsys ros testDataSetTrans schema")
 | 
											
												
													
														|  | 
 |  | +    testDataSetTrans.printSchema()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      val predictions = model.transform(testDataSetTrans)
 |  |      val predictions = model.transform(testDataSetTrans)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      // 保存评估结果
 |  |      // 保存评估结果
 | 
											
										
											
												
													
														|  | @@ -162,7 +170,9 @@ object recsys_01_ros_reg_xgb_train {
 | 
											
												
													
														|  |      data.map(r => {
 |  |      data.map(r => {
 | 
											
												
													
														|  |        val line: Array[String] = StringUtils.split(r, '\t')
 |  |        val line: Array[String] = StringUtils.split(r, '\t')
 | 
											
												
													
														|  |        val logKey = line(0)
 |  |        val logKey = line(0)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |        val label: Double = NumberUtils.toDouble(line(1))
 |  |        val label: Double = NumberUtils.toDouble(line(1))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |        val map: util.Map[String, Double] = new util.HashMap[String, Double]
 |  |        val map: util.Map[String, Double] = new util.HashMap[String, Double]
 | 
											
												
													
														|  |        for (i <- 2 until line.length) {
 |  |        for (i <- 2 until line.length) {
 | 
											
												
													
														|  |          val fv: Array[String] = StringUtils.split(line(i), ':')
 |  |          val fv: Array[String] = StringUtils.split(line(i), ':')
 |