| 
					
				 | 
			
			
				@@ -84,6 +84,13 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     (record.getString("vid"), category1, category2, feature) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				   } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				/**
				 * 3. UDF that draws negative-sample vids for one positive pair.
				 *
				 * Removes `vid_left` and `vid_right` from `allVids` and returns up to 20 of the
				 * remaining vids, picked at random.
				 *
				 * The candidates are copied into a Vector before shuffling. `Random.shuffle`
				 * rebuilds the same collection type it was given, so shuffling a Set directly
				 * produces another Set whose iteration order is decided by hashing rather than
				 * by the shuffle; `take(20)` would then keep returning the same vids.
				 *
				 * NOTE(review): `allVids` is not defined in this block. In the visible code it
				 * is a local val inside `generateNegativeSamples`; confirm it is reachable from
				 * this definition's scope.
				 *
				 * @return a Spark UDF taking (vid_left, vid_right) and returning up to 20 vids
				 */
				def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
				  // Exclude both sides of the positive pair, then fix an order that shuffle can permute.
				  val candidates = (allVids - vid_left - vid_right).toVector
				  Random.shuffle(candidates).take(20)
				})
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				   def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     val stats = ProcessingStats() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -94,32 +101,31 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 2 读取odps+表信息 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val odpsOps = env.getODPS(sc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 1. 获取正样本对数据 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      val positivePairs = time({ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        val schema = StructType(Array( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("dt", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("hh", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("vid_left", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("vid_right", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("extend", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("view_24h", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("total_return_uv", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("ts_right", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("apptype", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("pagesource", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          StructField("mid", StringType, true) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        val rdd = odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          project = "loghubods", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          table = "alg_dssm_sample", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          partition = s"dt='$dt'", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          transfer = funcPositive, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-          numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ).persist(StorageLevel.MEMORY_AND_DISK) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        val df = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        stats.positiveSamplesCount = df.count() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        df 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      }, "Loading positive pairs") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val schema = StructType(Array( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("dt", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("hh", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("vid_left", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("vid_right", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("extend", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("view_24h", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("total_return_uv", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("ts_right", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("apptype", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("pagesource", StringType, true), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        StructField("mid", StringType, true) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      logger.info(s"start read positivePairs for date $dt") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val rdd = odpsOps.readTable( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        project = "loghubods", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        table = "alg_dssm_sample", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        partition = s"dt='$dt'", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        transfer = funcPositive, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        numPartition = CONFIG("shuffle.partitions").toInt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      ).persist(StorageLevel.MEMORY_AND_DISK) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      stats.positiveSamplesCount = positivePairs.count() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+      logger.info(s"start read vid list for date $dt") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 2. 获取所有可用的vid列表 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       val allVids = time({ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -132,11 +138,7 @@ object video_dssm_sampler { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ).collect().toSet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       }, "Loading all vids") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      // 3. 定义UDF函数来生成负样本 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        negativeVids 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-      }) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       // 注册UDF 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				       spark.udf.register("generateNegativeVids", generateNegativeVidsUDF) 
			 |