| 
					
				 | 
			
			
				@@ -7,7 +7,7 @@ import org.apache.spark.sql.functions._ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.sql.types._ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.spark.storage.StorageLevel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import com.jayway.jsonpath.JsonPath 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.expressions.Window 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import scala.util.Random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import scala.collection.mutable.ArrayBuffer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.apache.log4j.{Level, Logger} 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -136,6 +136,58 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       stats.positiveSamplesCount = positivePairs.count() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       logger.info(s"start read vid list for date $dt") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 1. Give every positive pair a sequential pos_id, ordered by (vid_left, vid_right).
				+      // NOTE(review): Window.orderBy is used without partitionBy, so Spark evaluates this
				+      // window in a single partition — confirm this scales to the positive-pair volume.
				+      val numberedPos = positivePairs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right"))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 2. Read the candidate video table and assign each vid to a random group of 100 (group_id)
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val vidsRDD = odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        project = "loghubods", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        table = "t_vid_tag_feature", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        partition = s"dt='$dt'", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        transfer = funcVids, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // Convert the RDD to a DataFrame
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          StructType(Seq(StructField("vid", StringType, true)))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 3. Match each block of 100 consecutive pos_id values with the candidate group of the
				+      // same index, drop candidates equal to either side of the pair, and rank the remaining
				+      // candidates in random order per vid_left.
				+      // NOTE(review): the join below passes no join type, so positive pairs whose block index
				+      // has no matching group_id produce no candidate rows — confirm there are always at
				+      // least as many candidate groups as positive blocks.
				+      val negativeSamples = numberedPos 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .join(numberedVid, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          floor((col("pos_id") - 1) / 100) === col("group_id")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .select( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          col("dt"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          col("hh"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("extend"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("view_24h"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("total_return_uv"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("ts_right"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("apptype"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("pagesource"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          lit(null).as("mid"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          col("vid_left"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          col("vid").as("vid_right") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand()))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 4. Keep at most 2 negative samples for each vid_left
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val filteredNegative = negativeSamples 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .where(col("rn") <= 2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .drop("rn") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // 5. Merge positive and negative samples
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val allSamples = positivePairs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .withColumn("label", lit(1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .union( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          filteredNegative 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            .withColumn("label", lit(0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        .persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      /* 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 2. Get the list of all available vids
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val allVids = time({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         odpsOps.readTable( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -216,7 +268,7 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             .withColumn("label", lit(0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ).persist(StorageLevel.MEMORY_AND_DISK_SER) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+*/ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 6. Fetch the left-side features
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // Read the L1 category statistics features
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val l1CatStatFeatures = { 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -409,7 +461,7 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 8. Clean up caches
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       positivePairs.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      negativeSamplesDF.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      // Release the grouped candidate-vid DataFrame, which is the one persisted earlier;
				+      // negativeSamples is never persisted, so unpersisting it would free nothing.
				+      numberedVid.unpersist()
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       allSamples.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       l1CatStatFeatures.unpersist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       l2CatStatFeatures.unpersist() 
			 |