浏览代码

debug code

often 5 月之前
父节点
当前提交
c28367f7b1
共有 1 个文件被更改,包括 7 次插入84 次删除
  1. 7 84
      src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

+ 7 - 84
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -186,90 +186,9 @@ object video_dssm_sampler {
         )
         .persist(StorageLevel.MEMORY_AND_DISK_SER)
 
-      /*
-      // 2. 获取所有可用的vid列表
-      val allVids = time({
-        odpsOps.readTable(
-          project = "loghubods",
-          table = "t_vid_tag_feature",
-          partition = s"dt='$dt'",
-          transfer = funcVids,
-          numPartition = CONFIG("shuffle.partitions").toInt
-        ).collect().toSet
-      }, "Loading all vids")
-
-
-      val allVidsBroadcast = spark.sparkContext.broadcast(allVids)
-
-      // 注册UDF
-      spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
-      // 3. 定义UDF函数来生成负样本
-      def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
-        val localAllVids = allVidsBroadcast.value
-        val negativeVids = Random.shuffle(localAllVids - vid_left - vid_right).take(2) //20
-        negativeVids.toArray
-      })
-
-      // 4. 生成负样本
-      val negativeSamplesRDD = time({
-        positivePairs.rdd
-          .map(row => (row.getAs[String]("vid_left"), row.getAs[String]("vid_right")))
-          .flatMap { case (vid_left, vid_right) =>
-            val localAllVids = allVidsBroadcast.value
-            // 生成负样本对
-            Random.shuffle(localAllVids - vid_left - vid_right)
-              .take(2)
-              .map(negative_vid => Row(
-                vid_left,          // vid_left
-                negative_vid,      // vid_right
-                null,             // extend
-                null,             // view_24h
-                null,             // total_return_uv
-                null,             // ts_right
-                null,             // apptype
-                null,             // pagesource
-                null              // mid
-              ))
-          }
-      }, "Generating negative samples")
-      // 转换回DataFrame
-      val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema)
-      stats.negativeSamplesCount = negativeSamplesRDD.count()
-
-
-      /*
-      val negativeSamplesDF = time({
-        val df = positivePairs
-          .select("vid_left", "vid_right")
-          .withColumn("negative_vids", generateNegativeVidsUDF(col("vid_left"), col("vid_right")))
-          .select(
-            col("vid_left"),
-            explode(col("negative_vids")).as("vid_right"),
-            lit(null).as("extend"),
-            lit(null).as("view_24h"),
-            lit(null).as("total_return_uv"),
-            lit(null).as("ts_right"),
-            lit(null).as("apptype"),
-            lit(null).as("pagesource"),
-            lit(null).as("mid")
-          )
-
-        stats.negativeSamplesCount = df.count()
-        df
-      }, "Generating negative samples")
-      */
-      // 5. 合并正负样本
-      val allSamples = positivePairs
-        .withColumn("label", lit(1))
-        .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))
-        .union(
-          negativeSamplesDF
-            .withColumn("label", lit(0))
-            .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
-        ).persist(StorageLevel.MEMORY_AND_DISK_SER)
-*/
       // 6. 获取左侧特征
       // 读取L1类别统计特征
+      /*
       val l1CatStatFeatures = {
         val rdd = odpsOps.readTable(
           project = "loghubods",
@@ -446,11 +365,12 @@ object video_dssm_sampler {
           col("vid_right_cate_l2_feature")
         )
         .persist(StorageLevel.MEMORY_AND_DISK_SER)
-
+      */
 
 
       // 保存结果到HDFS
-      val resultWithDt = result.withColumn("dt", lit(s"$dt"))
+      //val resultWithDt = result.withColumn("dt", lit(s"$dt"))
+      val resultWithDt = allSamples.withColumn("dt", lit(s"$dt"))
       resultWithDt.write
         .mode("overwrite")
         .partitionBy("dt")
@@ -462,6 +382,7 @@ object video_dssm_sampler {
       positivePairs.unpersist()
       negativeSamples.unpersist()
       allSamples.unpersist()
+      /*
       l1CatStatFeatures.unpersist()
       l2CatStatFeatures.unpersist()
       tagFeatures.unpersist()
@@ -473,6 +394,8 @@ object video_dssm_sampler {
       statRightFeatures.unpersist()
       categoryRightWithStats.unpersist()
       result.unpersist()
+      */
+
       // 输出统计信息
       stats.logStats()