|
@@ -129,66 +129,72 @@ object video_dssm_sampler {
|
|
partition = s"dt=20241124,hh=08",
|
|
partition = s"dt=20241124,hh=08",
|
|
transfer = funcPositive,
|
|
transfer = funcPositive,
|
|
numPartition = CONFIG("shuffle.partitions").toInt
|
|
numPartition = CONFIG("shuffle.partitions").toInt
|
|
- ).sample(false, 0.001) // 随机抽样千分之一的数据
|
|
|
|
- .persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
|
|
|
|
+ ).persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
// Announce which partition is about to be processed.
println("开始执行partiton:" + partition)

// Turn the positive-pair RDD into a DataFrame and mark it for caching
// (serialized, spilling to disk), since it is read several times further down.
val positivePairs = spark.createDataFrame(rdd, schema)
positivePairs.persist(StorageLevel.MEMORY_AND_DISK_SER)

// count() is the first action on positivePairs; its result is recorded in the stats.
val positiveTotal = positivePairs.count()
stats.positiveSamplesCount = positiveTotal

logger.info(s"start read vid list for date $dt")
|
|
|
|
|
|
- val numberedPos = positivePairs
|
|
|
|
- .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right")))
|
|
|
|
- // 2. 给候选视频表按100条分组编号
|
|
|
|
- val vidsRDD = odpsOps.readTable(
|
|
|
|
- project = "loghubods",
|
|
|
|
- table = "t_vid_tag_feature",
|
|
|
|
- partition = s"dt='$dt'",
|
|
|
|
- transfer = funcVids,
|
|
|
|
- numPartition = CONFIG("shuffle.partitions").toInt
|
|
|
|
- )
|
|
|
|
|
|
// 2. Load every available vid for the day. The table is collected to the
//    driver and de-duplicated into a Set so it can be broadcast below.
val allVids = time({
  odpsOps.readTable(
    project = "loghubods",
    table = "t_vid_tag_feature",
    partition = s"dt='$dt'",
    transfer = funcVids,
    numPartition = CONFIG("shuffle.partitions").toInt
  ).collect().toSet
}, "Loading all vids")

// Broadcast the vid set to the executors for use in the sampling closure below.
val allVidsBroadcast = spark.sparkContext.broadcast(allVids)

// 4. Generate negative samples: for every positive pair, pair its vid_left
//    with up to 20 vids drawn at random from the candidate pool.
// NOTE(review): mapPartitions is lazy, so this `time` call presumably only
// measures building the RDD lineage, not the sampling itself — confirm
// against the definition of `time`.
val negativeSamplesRDD = time({
  positivePairs.rdd
    .mapPartitions { iter =>
      // Rows are handled in batches so the exclusion filter over the full
      // vid list runs once per batch instead of once per row.
      val batchSize = 100
      val negativesPerPositive = 20
      val localVids = allVidsBroadcast.value.toArray // convert to an array only once per partition
      val random = new Random()

      iter.grouped(batchSize).flatMap { batch =>
        // Every vid appearing in this batch (on either side) is excluded
        // from the batch's negative pool.
        val excludeVids = batch.flatMap(row =>
          Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
        ).toSet

        // Filter the pool once for the whole batch.
        val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))
        val sampleSize = math.min(negativesPerPositive, candidatePool.length)

        batch.flatMap { row =>
          val vid_left = row.getAs[String]("vid_left")
          // Partial Fisher-Yates draw over the WHOLE pool. The previous code
          // shuffled only `candidatePool.take(1000)`, i.e. always the same
          // first 1000 entries, so every other vid could never be picked as
          // a negative. The swaps only permute the batch-local pool, so
          // later rows still draw from the full pool.
          val negatives = (0 until sampleSize).map { i =>
            val j = i + random.nextInt(candidatePool.length - i)
            val picked = candidatePool(j)
            candidatePool(j) = candidatePool(i)
            candidatePool(i) = picked
            picked
          }
          negatives.map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null))
        }
      }
    }
}, "Generating negative samples")

// Convert back to a DataFrame.
val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema)
// NOTE(review): negativeSamplesRDD is not persisted, so this count() runs the
// (unseeded) random sampling once and any later use of negativeSamplesDF
// runs it again. The row count is the same, but the sampled vids differ.
stats.negativeSamplesCount = negativeSamplesRDD.count()
|
|
|
|
|
|
- // 将RDD转换为DataFrame
|
|
|
|
- val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)),
|
|
|
|
- StructType(Seq(StructField("vid", StringType, true))))
|
|
|
|
- .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100))
|
|
|
|
- .persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
|
|
-
|
|
|
|
- val negativeSamples = numberedPos
|
|
|
|
- .join(numberedVid,
|
|
|
|
- floor((col("pos_id") - 1) / 100) === col("group_id"))
|
|
|
|
- .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right"))
|
|
|
|
- .select(
|
|
|
|
- lit(null).as("extend"),
|
|
|
|
- lit(null).as("view_24h"),
|
|
|
|
- lit(null).as("total_return_uv"),
|
|
|
|
- lit(null).as("ts_right"),
|
|
|
|
- lit(null).as("apptype"),
|
|
|
|
- lit(null).as("pagesource"),
|
|
|
|
- lit(null).as("mid"),
|
|
|
|
- col("vid_left"),
|
|
|
|
- col("vid").as("vid_right")
|
|
|
|
- )
|
|
|
|
- .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand())))
|
|
|
|
- // 4. 筛选每个vid_left对应的2个负样本
|
|
|
|
- val filteredNegative = negativeSamples
|
|
|
|
- .where(col("rn") <= 2)
|
|
|
|
- .drop("rn")
|
|
|
|
|
|
|
|
// 5. Merge positive and negative samples.
// Positives get label 1 and a "pos_"-prefixed logid.
val labeledPositives = positivePairs
  .withColumn("label", lit(1))
  .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))

// Negatives get label 0 and a "neg_"-prefixed logid.
val labeledNegatives = negativeSamplesDF
  .withColumn("label", lit(0))
  .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))

// union() matches columns by position; both sides append "label" and "logid"
// in the same order. The combined set is cached serialized, spilling to disk.
val allSamples = labeledPositives
  .union(labeledNegatives)
  .persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
|
|
|
|
// 6. 获取左侧特征
|
|
// 6. 获取左侧特征
|
|
// 读取L1类别统计特征
|
|
// 读取L1类别统计特征
|
|
- /*
|
|
|
|
val l1CatStatFeatures = {
|
|
val l1CatStatFeatures = {
|
|
val rdd = odpsOps.readTable(
|
|
val rdd = odpsOps.readTable(
|
|
project = "loghubods",
|
|
project = "loghubods",
|
|
@@ -365,12 +371,11 @@ object video_dssm_sampler {
|
|
col("vid_right_cate_l2_feature")
|
|
col("vid_right_cate_l2_feature")
|
|
)
|
|
)
|
|
.persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
.persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
- */
|
|
|
|
|
|
+
|
|
|
|
|
|
|
|
|
|
// 保存结果到HDFS
|
|
// 保存结果到HDFS
|
|
- //val resultWithDt = result.withColumn("dt", lit(s"$dt"))
|
|
|
|
- val resultWithDt = allSamples.withColumn("dt", lit(s"$dt"))
|
|
|
|
|
|
+ val resultWithDt = result.withColumn("dt", lit(s"$dt"))
|
|
resultWithDt.write
|
|
resultWithDt.write
|
|
.mode("overwrite")
|
|
.mode("overwrite")
|
|
.partitionBy("dt")
|
|
.partitionBy("dt")
|
|
@@ -380,9 +385,8 @@ object video_dssm_sampler {
|
|
|
|
|
|
// 8. 清理缓存
|
|
// 8. 清理缓存
|
|
positivePairs.unpersist()
|
|
positivePairs.unpersist()
|
|
- negativeSamples.unpersist()
|
|
|
|
|
|
+
|
|
allSamples.unpersist()
|
|
allSamples.unpersist()
|
|
- /*
|
|
|
|
l1CatStatFeatures.unpersist()
|
|
l1CatStatFeatures.unpersist()
|
|
l2CatStatFeatures.unpersist()
|
|
l2CatStatFeatures.unpersist()
|
|
tagFeatures.unpersist()
|
|
tagFeatures.unpersist()
|
|
@@ -394,8 +398,6 @@ object video_dssm_sampler {
|
|
statRightFeatures.unpersist()
|
|
statRightFeatures.unpersist()
|
|
categoryRightWithStats.unpersist()
|
|
categoryRightWithStats.unpersist()
|
|
result.unpersist()
|
|
result.unpersist()
|
|
- */
|
|
|
|
-
|
|
|
|
// 输出统计信息
|
|
// 输出统计信息
|
|
stats.logStats()
|
|
stats.logStats()
|
|
|
|
|