| 
														
															@@ -84,6 +84,13 @@ object video_dssm_sampler { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     (record.getString("vid"), category1, category2, feature) 
														 | 
														
														 | 
														
															     (record.getString("vid"), category1, category2, feature) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															   } 
														 | 
														
														 | 
														
															   } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+  // 3. 定义UDF函数来生成负样本 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+  def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    negativeVids 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+  }) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															   def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = { 
														 | 
														
														 | 
														
															   def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     val stats = ProcessingStats() 
														 | 
														
														 | 
														
															     val stats = ProcessingStats() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -94,32 +101,31 @@ object video_dssm_sampler { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       // 2 读取odps+表信息 
														 | 
														
														 | 
														
															       // 2 读取odps+表信息 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       val odpsOps = env.getODPS(sc) 
														 | 
														
														 | 
														
															       val odpsOps = env.getODPS(sc) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       // 1. 获取正样本对数据 
														 | 
														
														 | 
														
															       // 1. 获取正样本对数据 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-      val positivePairs = time({ 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        val schema = StructType(Array( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("dt", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("hh", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("vid_left", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("vid_right", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("extend", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("view_24h", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("total_return_uv", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("ts_right", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("apptype", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("pagesource", StringType, true), 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          StructField("mid", StringType, true) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        )) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        val rdd = odpsOps.readTable( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          project = "loghubods", 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          table = "alg_dssm_sample", 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          partition = s"dt='$dt'", 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          transfer = funcPositive, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-          numPartition = CONFIG("shuffle.partitions").toInt 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        ).persist(StorageLevel.MEMORY_AND_DISK) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        val df = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        stats.positiveSamplesCount = df.count() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        df 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-      }, "Loading positive pairs") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      val schema = StructType(Array( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("dt", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("hh", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("vid_left", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("vid_right", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("extend", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("view_24h", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("total_return_uv", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("ts_right", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("apptype", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("pagesource", StringType, true), 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        StructField("mid", StringType, true) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      )) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      logger.info(s"start read positivePairs for date $dt") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      val rdd = odpsOps.readTable( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        project = "loghubods", 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        table = "alg_dssm_sample", 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        partition = s"dt='$dt'", 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        transfer = funcPositive, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        numPartition = CONFIG("shuffle.partitions").toInt 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      ).persist(StorageLevel.MEMORY_AND_DISK) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      stats.positiveSamplesCount = positivePairs.count() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+      logger.info(s"start read vid list for date $dt") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															       // 2. 获取所有可用的vid列表 
														 | 
														
														 | 
														
															       // 2. 获取所有可用的vid列表 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       val allVids = time({ 
														 | 
														
														 | 
														
															       val allVids = time({ 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -132,11 +138,7 @@ object video_dssm_sampler { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ).collect().toSet 
														 | 
														
														 | 
														
															         ).collect().toSet 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       }, "Loading all vids") 
														 | 
														
														 | 
														
															       }, "Loading all vids") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-      // 3. 定义UDF函数来生成负样本 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-      def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => { 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        negativeVids 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-      }) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															       // 注册UDF 
														 | 
														
														 | 
														
															       // 注册UDF 
														 | 
													
												
											
												
													
														| 
														 | 
														
															       spark.udf.register("generateNegativeVids", generateNegativeVidsUDF) 
														 | 
														
														 | 
														
															       spark.udf.register("generateNegativeVids", generateNegativeVidsUDF) 
														 |