|
@@ -84,6 +84,13 @@ object video_dssm_sampler {
|
|
|
(record.getString("vid"), category1, category2, feature)
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+  // 3. Define the UDF that generates negative samples.
+  //
+  // For a positive pair (vid_left, vid_right), draws up to 20 distinct vids at random
+  // from `allVids`, excluding both members of the pair.
+  //
+  // NOTE: Random.shuffle applied directly to a Set rebuilds a Set, so the shuffled order
+  // is discarded and the following take(20) keeps returning the same hash-ordered
+  // elements on every call. The candidates are therefore copied into a Vector before
+  // shuffling; the sample is converted back to a Set so the UDF's result type is the
+  // same as before.
+  // NOTE(review): assumes `allVids` is an immutable Set[String] visible at this scope
+  // (it is built with `.collect().toSet` elsewhere in this file) - confirm.
+  def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
+    val candidates = (allVids - vid_left - vid_right).toVector
+    Random.shuffle(candidates).take(20).toSet
+  })
|
|
|
+
|
|
|
def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = {
|
|
|
val stats = ProcessingStats()
|
|
|
|
|
@@ -94,32 +101,31 @@ object video_dssm_sampler {
|
|
|
// 2 读取odps+表信息
|
|
|
val odpsOps = env.getODPS(sc)
|
|
|
// 1. 获取正样本对数据
|
|
|
- val positivePairs = time({
|
|
|
- val schema = StructType(Array(
|
|
|
- StructField("dt", StringType, true),
|
|
|
- StructField("hh", StringType, true),
|
|
|
- StructField("vid_left", StringType, true),
|
|
|
- StructField("vid_right", StringType, true),
|
|
|
- StructField("extend", StringType, true),
|
|
|
- StructField("view_24h", StringType, true),
|
|
|
- StructField("total_return_uv", StringType, true),
|
|
|
- StructField("ts_right", StringType, true),
|
|
|
- StructField("apptype", StringType, true),
|
|
|
- StructField("pagesource", StringType, true),
|
|
|
- StructField("mid", StringType, true)
|
|
|
- ))
|
|
|
|
|
|
- val rdd = odpsOps.readTable(
|
|
|
- project = "loghubods",
|
|
|
- table = "alg_dssm_sample",
|
|
|
- partition = s"dt='$dt'",
|
|
|
- transfer = funcPositive,
|
|
|
- numPartition = CONFIG("shuffle.partitions").toInt
|
|
|
- ).persist(StorageLevel.MEMORY_AND_DISK)
|
|
|
- val df = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
|
|
|
- stats.positiveSamplesCount = df.count()
|
|
|
- df
|
|
|
- }, "Loading positive pairs")
|
|
|
+ val schema = StructType(Array(
|
|
|
+ StructField("dt", StringType, true),
|
|
|
+ StructField("hh", StringType, true),
|
|
|
+ StructField("vid_left", StringType, true),
|
|
|
+ StructField("vid_right", StringType, true),
|
|
|
+ StructField("extend", StringType, true),
|
|
|
+ StructField("view_24h", StringType, true),
|
|
|
+ StructField("total_return_uv", StringType, true),
|
|
|
+ StructField("ts_right", StringType, true),
|
|
|
+ StructField("apptype", StringType, true),
|
|
|
+ StructField("pagesource", StringType, true),
|
|
|
+ StructField("mid", StringType, true)
|
|
|
+ ))
|
|
|
+ logger.info(s"start read positivePairs for date $dt")
|
|
|
+ val rdd = odpsOps.readTable(
|
|
|
+ project = "loghubods",
|
|
|
+ table = "alg_dssm_sample",
|
|
|
+ partition = s"dt='$dt'",
|
|
|
+ transfer = funcPositive,
|
|
|
+ numPartition = CONFIG("shuffle.partitions").toInt
|
|
|
+ ).persist(StorageLevel.MEMORY_AND_DISK)
|
|
|
+ val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
|
|
|
+ stats.positiveSamplesCount = positivePairs.count()
|
|
|
+ logger.info(s"start read vid list for date $dt")
|
|
|
|
|
|
// 2. 获取所有可用的vid列表
|
|
|
val allVids = time({
|
|
@@ -132,11 +138,7 @@ object video_dssm_sampler {
|
|
|
).collect().toSet
|
|
|
}, "Loading all vids")
|
|
|
|
|
|
- // 3. 定义UDF函数来生成负样本
|
|
|
- def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
|
|
|
- val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20)
|
|
|
- negativeVids
|
|
|
- })
|
|
|
+
|
|
|
|
|
|
// 注册UDF
|
|
|
spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
|