
feat: add apptype and abcode features

zhaohaipeng 8 months ago
parent commit 9b7c0994e9

+ 3 - 1
ad/24_ad_model_batch_calc_cid_score_avg.sh

@@ -24,4 +24,6 @@ for cid in "${cid_array[@]}"; do
     score_avg=`awk '{ sum += $2; count++ } END { if (count > 0) print sum / count }' ${PREDICT_PATH}/${model}_${cid}.txt`
 
     echo -e "CID- ${cid} -平均分计算结果: ${score_avg} \n\t模型: ${MODEL_PATH}/${model} \n\tHDFS数据路径: ${hdfs_path} \n\t"
-done
+done
+
+# nohup ./ad/24_ad_model_batch_calc_cid_score_avg.sh 3024,2966,2670,3163,3595,3594,3364,3365,3593,3363,3180,1910,2660,3478,3431,3772,3060,3178,3056,3771,3208,3041,2910,3690,1626,3318,3357,3628,3766,3770,3763,3769,3768,3541,3534,2806,3755,3760,3319,3758,3746,3759,3747,3754,3767,3745,3756,3437,3608,3527,3691,3197,3361,3362,3212,3344,3343,3346,3345,3612,3540,3526,3611,3761,3617,3762,3618,3616,3623,3765,3624,3764,3198,3542,3353,2374,3200 model_bkb8_v55_20240804 /dw/recommend/model/33_ad_train_data_v4/20240806 8 > logs/model_bkb8_v55_20240804_cid_06_12.log 2>&1 &
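
Note: the awk one-liner above simply averages the second column (the predicted score) of ${PREDICT_PATH}/${model}_${cid}.txt. A minimal Scala sketch of the same computation, assuming a whitespace-separated "label score" line layout (the layout and the example path are assumptions, not taken from the diff):

    // Sketch only: mirrors the awk averaging of column 2 in a prediction file.
    import scala.io.Source

    object CidScoreAvg {
      def main(args: Array[String]): Unit = {
        val path = args(0) // e.g. predict/model_bkb8_v55_20240804_3024.txt (hypothetical)
        val scores = Source.fromFile(path).getLines()
          .map(_.trim.split("\\s+"))
          .filter(_.length >= 2)
          .map(cols => cols(1).toDouble)
          .toList
        if (scores.nonEmpty) println(scores.sum / scores.size) // same value awk prints as score_avg
      }
    }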

+ 52 - 57
src/main/scala/com/aliyun/odps/spark/ad/xgboost/v20240808/XGBoostTrain.scala

@@ -2,13 +2,16 @@ package com.aliyun.odps.spark.ad.xgboost.v20240808
 
 import com.aliyun.odps.spark.examples.myUtils.ParamUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+import org.apache.commons.lang.StringUtils
 import org.apache.commons.lang3.math.NumberUtils
-import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DataTypes, StructField}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
 import java.net.URL
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
 import scala.io.Source
 
 object XGBoostTrain {
@@ -17,91 +20,64 @@ object XGBoostTrain {
 
       val param = ParamUtils.parseArgs(args)
 
-      val conf = new SparkConf()
-        .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "/usr/bin/python2.7")
-        .set("spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON", "/usr/bin/python2.7")
+      val dt = LocalDateTime.now.format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss"))
 
       val spark = SparkSession.builder()
-        .config(conf)
-        .appName("XGBoostTrain")
+        .appName("XGBoostTrain:" + dt)
         .getOrCreate()
       val sc = spark.sparkContext
 
       val loader = getClass.getClassLoader
 
-      val readPath = param.getOrElse("readPath", "")
+      val readPath = param.getOrElse("trainReadPath", "")
+      val predictReadPath = param.getOrElse("predictReadPath", "")
       val filterNameSet = param.getOrElse("filterNames", "").split(",").filter(_.nonEmpty).toSet
       val featureNameFile = param.getOrElse("featureNameFile", "20240718_ad_feature_name.txt")
 
       val featureNameContent = readFile(loader.getResource(featureNameFile))
 
-      val featureNameList = featureNameContent.split("\n")
+      val featureNameList: List[String] = featureNameContent.split("\n")
         .map(r => r.replace(" ", "").replaceAll("\n", ""))
         .filter(r => r.nonEmpty)
         .filter(r => !containsAny(filterNameSet, r))
         .toList
 
-      val rowRDD = sc.textFile(readPath).map(r => {
-        val line = r.split("\t")
+      val rowRDD = dataMap(sc.textFile(readPath), featureNameList)
 
-        val label = NumberUtils.toInt(line(0))
-
-        val map = line.drop(1).map { entry =>
-          val Array(key, value) = entry.split(":")
-          key -> NumberUtils.toDouble(value, 0.0)
-        }.toMap
-
-        val v = Array.ofDim[Any](featureNameList.length + 1)
-        v(0) = label
-
-        for (index <- featureNameList.indices) {
-          v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
-        }
-
-        Row.fromSeq(v)
-      })
       println(s"rowRDD count ${rowRDD.count()}")
 
-      val fields = Seq(
-        StructField("label", DataTypes.IntegerType, true)
-      ) ++ featureNameList.map(f => StructField(f, DataTypes.DoubleType, true))
-
-      val dataset = spark.createDataFrame(rowRDD, StructType(fields))
+      val fields: Array[StructField] = Array(
+        DataTypes.createStructField("label", DataTypes.IntegerType, true)
+      ) ++ featureNameList.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
-      val assembler = new VectorAssembler()
-        .setInputCols(featureNameList.toArray)
-        .setOutputCol("features")
+      val trainDataSet: Dataset[Row] = spark.createDataFrame(rowRDD, DataTypes.createStructType(fields))
 
-      val assembledData = assembler.transform(dataset)
-      assembledData.show()
+      val vectorAssembler = new VectorAssembler().setInputCols(featureNameList.toArray).setOutputCol("features")
 
-      // split into training and test sets
-      val Array(trainData, testData) = assembledData.randomSplit(Array(0.7, 0.3))
-      trainData.show()
-      testData.show()
+      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
+      xgbInput.show()
 
       // create the XGBoostClassifier
       val xgbClassifier = new XGBoostClassifier()
         .setEta(0.01f)
+        .setMissing(0.0f)
+        .setMaxDepth(5)
+        .setNumRound(1000)
         .setSubsample(0.8)
         .setColsampleBytree(0.8)
         .setScalePosWeight(1)
-        .setSeed(2024)
-        .setMissing(0.0f)
+        .setObjective("binary:logistic")
+        .setEvalMetric("auc")
         .setFeaturesCol("features")
         .setLabelCol("label")
-        .setMaxDepth(5)
-        .setObjective("binary:logistic")
         .setNthread(1)
-        .setNumWorkers(1)
-        .setNumRound(100)
+        .setNumWorkers(22)
 
       // train the model
-      val model = xgbClassifier.fit(trainData)
+      val model = xgbClassifier.fit(xgbInput)
+
+
 
-      // show prediction results
-      val predictions = model.transform(testData)
-      predictions.show(100)
     }
     catch {
       case e: Throwable => e.printStackTrace()
@@ -112,17 +88,15 @@ object XGBoostTrain {
     var source: Option[Source] = None
     try {
       source = Some(Source.fromURL(filePath))
-      source.get.getLines().mkString("\n")
+      return source.get.getLines().mkString("\n")
     }
     catch {
-      case e: Exception => {
-        println("文件读取异常: " + e.toString)
-        ""
-      }
+      case e: Exception => println("文件读取异常: " + e.toString)
     }
     finally {
       source.foreach(_.close())
     }
+    ""
   }
 
   private def containsAny(list: Iterable[String], s: String): Boolean = {
@@ -133,4 +107,25 @@ object XGBoostTrain {
     }
     false
   }
+
+  private def dataMap(data: RDD[String], featureNameList: List[String]): RDD[Row] = {
+    data.map(r => {
+      val line: Array[String] = StringUtils.split(r, "\t")
+      val label: Int = NumberUtils.toInt(line(0))
+
+      val map: Map[String, Double] = line.drop(1).map { entry =>
+        val Array(key, value) = entry.split(":")
+        key -> NumberUtils.toDouble(value, 0.0)
+      }.toMap
+
+      val v: Array[Any] = Array.ofDim[Any](featureNameList.length + 1)
+      v(0) = label
+
+      for (index <- featureNameList.indices) {
+        v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
+      }
+
+      Row.fromSeq(v)
+    })
+  }
 }
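
For reference, a hedged sketch of the record layout the new dataMap helper expects ("<label>\t<name>:<value>\t..."); the sample line and the feature names below are made up for illustration:

    // Illustration only: replays dataMap's parsing on a hand-written sample record.
    import org.apache.commons.lang.StringUtils
    import org.apache.commons.lang3.math.NumberUtils
    import org.apache.spark.sql.Row

    val featureNameList = List("cpa", "hour_10", "apptype_3")        // hypothetical feature names
    val line = StringUtils.split("1\tcpa:2.5\tapptype_3:0.1", "\t")  // hypothetical sample record
    val label = NumberUtils.toInt(line(0))
    val kv = line.drop(1).map { e =>
      val Array(k, v) = e.split(":")
      k -> NumberUtils.toDouble(v, 0.0)
    }.toMap
    val row = Row.fromSeq(label +: featureNameList.map(f => kv.getOrElse(f, 0.0)))
    // row == Row(1, 2.5, 0.0, 0.1): absent features fall back to 0.0, matching setMissing(0.0f)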

+ 12 - 3
src/main/scala/com/aliyun/odps/spark/examples/makedata_ad/v20240718/makedata_ad_31_originData_20240718.scala

@@ -58,6 +58,9 @@ object makedata_ad_31_originData_20240718 {
 
             val ts = record.getString("ts").toInt
             val cid = record.getString("cid")
+            val apptype = record.getString("apptype")
+            val extend: JSONObject = if (record.isNull("extend")) new JSONObject() else
+              JSON.parseObject(record.getString("extend"))
 
 
             val featureMap = new JSONObject()
@@ -94,10 +97,17 @@ object makedata_ad_31_originData_20240718 {
             }
 
             val hour = DateTimeUtil.getHourByTimestamp(ts)
-            featureMap.put("hour_" + hour, 0.1)
+            featureMap.put("hour_" + hour, idDefaultValue)
 
             val dayOfWeek = DateTimeUtil.getDayOrWeekByTimestamp(ts)
-            featureMap.put("dayofweek_" + dayOfWeek, 0.1);
+            featureMap.put("dayofweek_" + dayOfWeek, idDefaultValue);
+
+            featureMap.put("apptype_" + apptype, idDefaultValue);
+
+            if (extend.containsKey("abcode") && extend.getString("abcode").nonEmpty) {
+              featureMap.put("abcode_" + extend.getString("abcode"), idDefaultValue)
+            }
+
 
             if (b1.containsKey("cpa")) {
               featureMap.put("cpa", b1.getString("cpa").toDouble)
@@ -366,7 +376,6 @@ object makedata_ad_31_originData_20240718 {
               }
             }
             //5 build the log key header.
-            val apptype = record.getString("apptype")
             val mid = record.getString("mid")
             val headvideoid = record.getString("headvideoid")
             val logKey = (apptype, mid, cid, ts, headvideoid).productIterator.mkString(",")
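
A minimal sketch of what the new id features contribute to featureMap, using hand-built values (idDefaultValue = 0.1 is assumed here to match the earlier hour_/dayofweek_ literal, and the fastjson import mirrors the file's existing JSON usage):

    // Sketch only: apptype_/abcode_ features on an assumed record.
    import com.alibaba.fastjson.{JSON, JSONObject}

    val idDefaultValue = 0.1                                // assumed default weight
    val apptype = "3"                                       // hypothetical record value
    val extendStr = """{"abcode":"ab1"}"""                  // hypothetical "extend" column
    val extend: JSONObject =
      if (extendStr == null) new JSONObject() else JSON.parseObject(extendStr)

    val featureMap = new JSONObject()
    featureMap.put("apptype_" + apptype, idDefaultValue)
    if (extend.containsKey("abcode") && extend.getString("abcode").nonEmpty) {
      featureMap.put("abcode_" + extend.getString("abcode"), idDefaultValue)
    }
    // featureMap -> {"apptype_3":0.1,"abcode_ab1":0.1}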

+ 1 - 1
zhangbo/03_predict.sh

@@ -12,7 +12,7 @@ $HADOOP fs -text ${train_path}/${day}/* | /root/sunmingze/alphaFM/bin/fm_predict
 cat predict/${output_file}_$day.txt | /root/sunmingze/AUC/AUC
 
 
-# nohup sh zhangbo/03_predict.sh 20240804 /dw/recommend/model/43_recsys_train_data_274/ recommend/model_nba8_v3_20240802.txt recommen/model_nba8_v3_20240802_04 0 >logs/predict_model_nba8_v3_20240802_04.log 2>&1 &
+# nohup sh zhangbo/03_predict.sh 20240805 /dw/recommend/model/33_ad_train_data_v4/ ad/model_bkb8_v55_20240804.txt ad/model_bkb8_v55_20240804 8 >logs/predict_model_bkb8_v55_20240804_05.log 2>&1 &
 # nohup sh 03_predict.sh 20240611 /dw/recommend/model/16_train_data/ model_aka4_20240610.txt model_aka4_20240610 4 >p3_model_aka4.log 2>&1 &
 # nohup sh 03_predict.sh 20240613 /dw/recommend/model/16_train_data/ model_aka8_20240612.txt model_aka8_20240612 8 >p3_model_aka8_12.log 2>&1 &