@@ -75,8 +75,9 @@ object recsys_01_ros_reg_xgb_train {
     fields = fields ++ Array(
       DataTypes.createStructField("logKey", DataTypes.StringType, true)
     )
-
     val schema = DataTypes.createStructType(fields)
+
+
     val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
     val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label").persist()
@@ -108,7 +109,14 @@ object recsys_01_ros_reg_xgb_train {
       features
     )
     val testDataSet = spark.createDataFrame(testData, schema)
+    println("recsys ros testDataSet schema")
+    testDataSet.printSchema()
+
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
+    println("recsys ros testDataSetTrans schema")
+    testDataSetTrans.printSchema()
+
+
     val predictions = model.transform(testDataSetTrans)

     // Save the evaluation results
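
For reference, a minimal local sketch of the schema-plus-VectorAssembler flow the two hunks above touch. The two placeholder feature columns (f1, f2), the sample rows, the local SparkSession and the object name are assumptions; only the logKey/label/features column names and the DataTypes/VectorAssembler/printSchema calls come from the patch itself.

// Sketch only: f1/f2, the local master and the sample rows are assumptions;
// logKey/label/features and the API calls mirror the patch above.
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.DataTypes
import scala.collection.JavaConverters._

object AssemblerSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("assembler-schema-sketch").getOrCreate()

    // Build the schema the same way the job does: feature fields first, then label and logKey.
    val features = Array("f1", "f2")
    var fields = features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
    fields = fields ++ Array(
      DataTypes.createStructField("label", DataTypes.DoubleType, true),
      DataTypes.createStructField("logKey", DataTypes.StringType, true)
    )
    val schema = DataTypes.createStructType(fields)

    val rows = Seq(Row(0.1, 0.2, 1.0, "key_001"), Row(0.3, 0.4, 0.0, "key_002")).asJava
    val testDataSet = spark.createDataFrame(rows, schema)

    // Assemble the feature vector and keep logKey alongside features/label,
    // then print the schema -- the same kind of debug output the patch adds.
    val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
    println("recsys ros testDataSetTrans schema")
    testDataSetTrans.printSchema()

    spark.stop()
  }
}
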
@@ -162,7 +170,9 @@ object recsys_01_ros_reg_xgb_train {
     data.map(r => {
       val line: Array[String] = StringUtils.split(r, '\t')
       val logKey = line(0)
+
       val label: Double = NumberUtils.toDouble(line(1))
+
       val map: util.Map[String, Double] = new util.HashMap[String, Double]
       for (i <- 2 until line.length) {
         val fv: Array[String] = StringUtils.split(line(i), ':')
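
Taken on its own, the parsing loop in the last hunk can be exercised outside Spark. Below is a minimal standalone sketch, assuming each record looks like logKey<TAB>label<TAB>name:value pairs and that StringUtils/NumberUtils come from commons-lang3; the Record case class and object name are illustrative, not from the original job.

// Sketch only: Record and ParseLineSketch are assumed names; the parsing logic
// (tab-separated fields, "name:value" features from index 2 on) mirrors the hunk above.
import java.util
import org.apache.commons.lang3.StringUtils
import org.apache.commons.lang3.math.NumberUtils

object ParseLineSketch {
  final case class Record(logKey: String, label: Double, features: util.Map[String, Double])

  def parse(r: String): Record = {
    val line: Array[String] = StringUtils.split(r, '\t')
    val logKey = line(0)                              // field 0: log key kept for evaluation
    val label: Double = NumberUtils.toDouble(line(1)) // field 1: regression target
    val map: util.Map[String, Double] = new util.HashMap[String, Double]
    for (i <- 2 until line.length) {                  // fields 2..n: "name:value" pairs
      val fv: Array[String] = StringUtils.split(line(i), ':')
      if (fv.length == 2) map.put(fv(0), NumberUtils.toDouble(fv(1)))
    }
    Record(logKey, label, map)
  }

  def main(args: Array[String]): Unit = {
    println(parse("key_001\t3.5\tctr:0.12\tprice:9.9"))
  }
}
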