
get neg sample

often committed 7 months ago
Parent commit: 37afa59359
1 changed file with 257 additions and 0 deletions

+ 257 - 0
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -0,0 +1,257 @@
+// package matches the file's path under src/main/scala
+package com.aliyun.odps.spark.examples.makedata_recsys
+
+import org.apache.log4j.Logger
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions._
+import org.apache.spark.storage.StorageLevel
+
+import scala.util.Random
+
+object video_dssm_sampler {
+  private val logger = Logger.getLogger(this.getClass)
+
+  // Configuration parameters
+  private val CONFIG = Map(
+    "shuffle.partitions" -> "200",
+    "memory.fraction" -> "0.8",
+    "default.parallelism" -> "200"
+  )
+
+  case class ProcessingStats(
+                              startTime: Long = System.currentTimeMillis(),
+                              var positiveSamplesCount: Long = 0,
+                              var negativeSamplesCount: Long = 0
+                            ) {
+    def logStats(): Unit = {
+      val duration = (System.currentTimeMillis() - startTime) / 1000
+      logger.info(s"""
+                     |Processing Statistics:
+                     |Duration: ${duration}s
+                     |Positive Samples: $positiveSamplesCount
+                     |Negative Samples: $negativeSamplesCount
+                     |Total Samples: ${positiveSamplesCount + negativeSamplesCount}
+        """.stripMargin)
+    }
+  }
+
+  def createSparkSession(appName: String): SparkSession = {
+    val spark = SparkSession.builder()
+      .appName(appName)
+      .enableHiveSupport()
+      .config("spark.sql.shuffle.partitions", CONFIG("shuffle.partitions"))
+      .config("spark.memory.fraction", CONFIG("memory.fraction"))
+      .config("spark.default.parallelism", CONFIG("default.parallelism"))
+      .getOrCreate()
+
+    // Set the log level
+    spark.sparkContext.setLogLevel("WARN")
+    spark
+  }
+
+  def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = {
+    val stats = ProcessingStats()
+
+    try {
+      logger.info(s"Starting negative sample generation for date: $dt")
+
+      // 1. Load and cache the positive sample pairs
+      val positivePairs = time({
+        val df = spark.sql(s"""
+          SELECT
+            dt,
+            hh,
+            vid_left,
+            vid_right,
+            extend,
+            view_24h,
+            total_return_uv,
+            ts_right,
+            apptype,
+            pagesource,
+            mid
+          FROM loghubods.alg_dssm_sample
+          WHERE dt = '$dt'
+        """).persist(StorageLevel.MEMORY_AND_DISK)
+
+        stats.positiveSamplesCount = df.count()
+        df
+      }, "Loading positive pairs")
+
+      // 2. Load the set of all candidate vids for this date
+      val allVids = time({
+        spark.sql(s"""
+          SELECT vid
+          FROM loghubods.t_vid_tag_feature
+          WHERE dt = '$dt'
+        """)
+          .select("vid")
+          .collect()
+          .map(_.getString(0))
+          .toSet
+      }, "Loading all vids")
+
+      // 3. UDF that samples 20 negative vids for each positive pair.
+      //    allVids was collected to the driver and is shipped to executors in
+      //    the UDF closure; for a very large vid set an explicit broadcast
+      //    variable would be preferable.
+      //    (no spark.udf.register call is needed: the UDF is only used via
+      //    this val, never by name from SQL)
+      val generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
+        // shuffle a Seq, not a Set: a Set is unordered, so shuffling it and
+        // taking 20 elements would not be a uniform random sample
+        Random.shuffle((allVids - vid_left - vid_right).toSeq).take(20)
+      }).asNondeterministic() // random output: stop the optimizer assuming determinism
+
+      // 4. Generate negative samples (20 per positive pair)
+      val negativeSamplesDF = time({
+        val df = positivePairs
+          .select("dt", "hh", "vid_left", "vid_right")
+          .withColumn("negative_vids", generateNegativeVidsUDF(col("vid_left"), col("vid_right")))
+          .select(
+            col("dt"),
+            col("hh"),
+            col("vid_left"),
+            explode(col("negative_vids")).as("vid_right"),
+            lit(null).as("extend"),
+            lit(null).as("view_24h"),
+            lit(null).as("total_return_uv"),
+            lit(null).as("ts_right"),
+            lit(null).as("apptype"),
+            lit(null).as("pagesource"),
+            lit(null).as("mid")
+          )
+
+        stats.negativeSamplesCount = df.count()
+        df
+      }, "Generating negative samples")
+
+      // 5. Union positive and negative samples (label 1 = positive, 0 = negative)
+      val allSamples = time({
+        positivePairs
+          .withColumn("label", lit(1))
+          .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))
+          .union(
+            negativeSamplesDF
+              .withColumn("label", lit(0))
+              .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
+          )
+          .persist(StorageLevel.MEMORY_AND_DISK)
+      }, "Merging positive and negative samples")
+
+      // 6. Load per-vid feature data (tag, stat, and level-1/level-2 category features)
+      val features = time({
+        spark.sql(s"""
+          SELECT
+            t.vid,
+            t.feature as tag_feature,
+            s.feature as stat_feature,
+            c.cate_l1_feature,
+            c.cate_l2_feature
+          FROM loghubods.t_vid_tag_feature t
+          LEFT JOIN loghubods.t_vid_stat_feature s ON t.vid = s.vid AND s.dt = '$dt'
+          LEFT JOIN (
+            SELECT
+              a.vid,
+              b.feature as cate_l1_feature,
+              c.feature as cate_l2_feature
+            FROM (
+              SELECT
+                vid,
+                get_json_object(feature,'$$.category1') as category1,
+                get_json_object(feature,'$$.category2_1') as category2
+              FROM loghubods.t_vid_tag_feature
+              WHERE dt = '$dt'
+            ) a
+            LEFT JOIN (
+              SELECT category1, feature
+              FROM t_vid_l1_cat_stat_feature
+              WHERE dt = '$dt'
+            ) b ON a.category1 = b.category1
+            LEFT JOIN (
+              SELECT category2, feature
+              FROM t_vid_l2_cat_stat_feature
+              WHERE dt = '$dt'
+            ) c ON a.category2 = c.category2
+          ) c ON t.vid = c.vid
+          WHERE t.dt = '$dt'
+        """).persist(StorageLevel.MEMORY_AND_DISK)
+      }, "Loading features")
+
+      // 7. Attach left- and right-vid features and save the result
+      time({
+        // broadcast(features) assumes the feature table fits in executor
+        // memory; drop the hint if it does not
+        val withLeftFeatures = allSamples
+          .join(broadcast(features).as("left_features"),
+            col("vid_left") === col("left_features.vid"),
+            "left")
+          .select(
+            allSamples("*"),
+            col("left_features.tag_feature").as("vid_left_tag_feature"),
+            col("left_features.stat_feature").as("vid_left_stat_feature"),
+            col("left_features.cate_l1_feature").as("vid_left_cate_l1_feature"),
+            col("left_features.cate_l2_feature").as("vid_left_cate_l2_feature")
+          )
+
+        // alias the intermediate frame so col("l.*") selects only its columns;
+        // a bare col("*") here would also pull in the raw right_features
+        // columns and duplicate them in the output
+        val result = withLeftFeatures.as("l")
+          .join(broadcast(features).as("right_features"),
+            col("vid_right") === col("right_features.vid"),
+            "left")
+          .select(
+            col("l.*"),
+            col("right_features.tag_feature").as("vid_right_tag_feature"),
+            col("right_features.stat_feature").as("vid_right_stat_feature"),
+            col("right_features.cate_l1_feature").as("vid_right_cate_l1_feature"),
+            col("right_features.cate_l2_feature").as("vid_right_cate_l2_feature")
+          )
+
+        // Save to HDFS. dt is already encoded in the output path, so
+        // partitionBy("dt") is dropped: keeping it would nest the layout as
+        // $outputPath/dt=$dt/dt=$dt/...
+        result.write
+          .mode("overwrite")
+          .parquet(s"$outputPath/dt=$dt")
+
+        logger.info(s"Results saved to: $outputPath/dt=$dt")
+      }, "Adding features and saving results")
+
+      // 8. Release cached data
+      positivePairs.unpersist()
+      negativeSamplesDF.unpersist()
+      allSamples.unpersist()
+      features.unpersist()
+
+      // Log processing statistics
+      stats.logStats()
+
+    } catch {
+      case e: Exception =>
+        // log the full stack trace through the logger instead of stderr
+        logger.error(s"Error processing data for date $dt: ${e.getMessage}", e)
+        throw e
+    }
+  }
+
+  private def time[R](block: => R, name: String): R = {
+    val t0 = System.nanoTime()
+    val result = block
+    val t1 = System.nanoTime()
+    logger.info(s"$name took: ${(t1 - t0) / 1e9d} seconds")
+    result
+  }
+
+  def main(args: Array[String]): Unit = {
+    if (args.length < 2) {
+      println("Usage: NegativeSampleGenerator <date> <output_path>")
+      System.exit(1)
+    }
+
+    val dt = args(0)
+    val outputPath = args(1)
+
+    val spark = createSparkSession("NegativeSampleGenerator")
+
+    try {
+      generateNegativeSamples(spark, dt, outputPath)
+    } finally {
+      spark.stop()
+    }
+  }
+}
+
+
+// To run:
+// upload the code
+// git clone
+// mvn clean install
+//
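
After a run, a quick sanity check is to read the output back and compare label counts; negatives should come out at roughly 20 per positive pair. A minimal spark-shell sketch, assuming a hypothetical date and output path matching the job arguments above:

    val dt = "20250101"                          // hypothetical date argument
    val outputPath = "hdfs:///user/dssm/samples" // hypothetical output path
    val df = spark.read.parquet(s"$outputPath/dt=$dt")
    df.groupBy("label").count().show()           // expect ~20x more label=0 than label=1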