| 
					
				 | 
			
			
				@@ -129,66 +129,72 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         partition = s"dt=20241124,hh=08", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         transfer = funcPositive, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      ).sample(false, 0.001) // 随机抽样千分之一的数据 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-       .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      ).persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       println("开始执行partiton:" + partition) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       stats.positiveSamplesCount = positivePairs.count() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       logger.info(s"start read vid list for date $dt") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val numberedPos = positivePairs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right"))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      // 2. 给候选视频表按100条分组编号 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val vidsRDD = odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        project = "loghubods", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        table = "t_vid_tag_feature", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        partition = s"dt='$dt'", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        transfer = funcVids, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 2. 获取所有可用的vid列表 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val allVids = time({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          project = "loghubods", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          table = "t_vid_tag_feature", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          partition = s"dt='$dt'", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          transfer = funcVids, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ).collect().toSet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      }, "Loading all vids") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val allVidsBroadcast = spark.sparkContext.broadcast(allVids) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 4. 生成负样本 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val negativeSamplesRDD = time({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        positivePairs.rdd 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          .mapPartitions { iter => 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            // 每次处理一批数据,而不是单条处理 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            val batchSize = 100 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            val localVids = allVidsBroadcast.value.toArray // 只转换一次为数组 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            val random = new Random() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            iter.grouped(batchSize).flatMap { batch => 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              // 为整批数据一次性生成负样本池 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              val excludeVids = batch.flatMap(row => 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              ).toSet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              // 预先过滤一次负样本池 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              val candidatePool = localVids.filter(vid => !excludeVids.contains(vid)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              batch.flatMap { row => 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                val vid_left = row.getAs[String]("vid_left") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                // 从已过滤的池中快速采样 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                random.shuffle(candidatePool.take(1000)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                  .take(20) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                  .map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+              } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      }, "Generating negative samples") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 转换回DataFrame 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      stats.negativeSamplesCount = negativeSamplesRDD.count() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      // 将RDD转换为DataFrame 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructType(Seq(StructField("vid", StringType, true)))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val negativeSamples = numberedPos 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .join(numberedVid, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          floor((col("pos_id") - 1) / 100) === col("group_id")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .select( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("extend"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("view_24h"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("total_return_uv"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("ts_right"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("apptype"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("pagesource"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          lit(null).as("mid"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          col("vid_left"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          col("vid").as("vid_right") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand()))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      // 4. 筛选每个vid_left对应的2个负样本 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val filteredNegative = negativeSamples 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .where(col("rn") <= 2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .drop("rn") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 5. 合并正负样本 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val allSamples = positivePairs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .withColumn("label", lit(1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .union( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          filteredNegative 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          negativeSamplesDF 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             .withColumn("label", lit(0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ).persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 6. 获取左侧特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 读取L1类别统计特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      /* 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val l1CatStatFeatures = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         val rdd = odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				           project = "loghubods", 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -365,12 +371,11 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				           col("vid_right_cate_l2_feature") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 保存结果到HDFS 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      //val resultWithDt = result.withColumn("dt", lit(s"$dt")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val resultWithDt = allSamples.withColumn("dt", lit(s"$dt")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val resultWithDt = result.withColumn("dt", lit(s"$dt")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       resultWithDt.write 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .mode("overwrite") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         .partitionBy("dt") 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -380,9 +385,8 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 8. 清理缓存 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       positivePairs.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      negativeSamples.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       allSamples.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      /* 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       l1CatStatFeatures.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       l2CatStatFeatures.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       tagFeatures.unpersist() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -394,8 +398,6 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       statRightFeatures.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       categoryRightWithStats.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       result.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 输出统计信息 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       stats.logStats() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |