|
@@ -7,7 +7,7 @@ import org.apache.spark.sql.functions._
|
|
|
import org.apache.spark.sql.types._
|
|
|
import org.apache.spark.storage.StorageLevel
|
|
|
import com.jayway.jsonpath.JsonPath
|
|
|
-
|
|
|
+import org.apache.spark.sql.expressions.Window
|
|
|
import scala.util.Random
|
|
|
import scala.collection.mutable.ArrayBuffer
|
|
|
import org.apache.log4j.{Level, Logger}
|
|
@@ -136,6 +136,58 @@ object video_dssm_sampler {
|
|
|
stats.positiveSamplesCount = positivePairs.count() // count() is a Spark action: it evaluates positivePairs here
|
|
|
logger.info(s"start read vid list for date $dt")
|
|
|
|
|
|
+ val numberedPos = positivePairs // 1. Give every positive pair a sequential id (pos_id, starting at 1)
|
|
|
+ .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right"))) // NOTE(review): window has no partitionBy, so Spark evaluates it in a single partition — confirm this is acceptable for the data volume
|
|
|
+ // 2. Read the candidate video table; the rows are numbered into groups of 100 further below
|
|
|
+ val vidsRDD = odpsOps.readTable(
|
|
|
+ project = "loghubods",
|
|
|
+ table = "t_vid_tag_feature",
|
|
|
+ partition = s"dt='$dt'", // only the partition for the date being processed
|
|
|
+ transfer = funcVids, // record-mapping function defined outside this view; presumably extracts the vid as a String — TODO confirm
|
|
|
+ numPartition = CONFIG("shuffle.partitions").toInt // partition count taken from the "shuffle.partitions" config entry
|
|
|
+ )
|
|
|
+
|
|
|
+ // Convert the RDD to a DataFrame and assign each candidate vid to a random group of 100
|
|
|
+ val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)), // wrap each element in a single-field Row
|
|
|
+ StructType(Seq(StructField("vid", StringType, true)))) // schema: one nullable string column "vid"
|
|
|
+ .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100)) // random order, then rows 1-100 -> group 0, 101-200 -> group 1, ...; NOTE(review): window has no partitionBy, so Spark evaluates it in a single partition
|
|
|
+ .persist(StorageLevel.MEMORY_AND_DISK_SER) // cached so the random grouping is not recomputed on reuse
|
|
|
+
|
|
|
+ val negativeSamples = numberedPos // 3. Build candidate negatives: pair each block of 100 positive rows with one group of 100 candidate vids
|
|
|
+ .join(numberedVid, // no join type given, so this is an inner join
|
|
|
+ floor((col("pos_id") - 1) / 100) === col("group_id")) // NOTE(review): positive blocks whose index has no matching group_id produce no rows — confirm there are enough candidate groups
|
|
|
+ .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right")) // drop candidates equal to either side of the positive pair
|
|
|
+ .select( // NOTE(review): column order presumably mirrors positivePairs because the later union is positional — verify
|
|
|
+ col("dt"),
|
|
|
+ col("hh"),
|
|
|
+ lit(null).as("extend"), // the columns set to null here carry no value for negative rows
|
|
|
+ lit(null).as("view_24h"),
|
|
|
+ lit(null).as("total_return_uv"),
|
|
|
+ lit(null).as("ts_right"),
|
|
|
+ lit(null).as("apptype"),
|
|
|
+ lit(null).as("pagesource"),
|
|
|
+ lit(null).as("mid"),
|
|
|
+ col("vid_left"),
|
|
|
+ col("vid").as("vid_right") // the sampled candidate becomes the right-hand video
|
|
|
+ )
|
|
|
+ .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand()))) // random rank of each candidate within its vid_left
|
|
|
+ // 4. Keep at most 2 negative samples per vid_left
|
|
|
+ val filteredNegative = negativeSamples
|
|
|
+ .where(col("rn") <= 2) // rn is the random per-vid_left rank computed above
|
|
|
+ .drop("rn") // helper column is not part of the sample schema
|
|
|
+
|
|
|
+ // 5. Merge positive and negative samples
|
|
|
+ val allSamples = positivePairs // uses positivePairs (not numberedPos), so pos_id is not carried into the samples
|
|
|
+ .withColumn("label", lit(1)) // positives are labelled 1
|
|
|
+ .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id())) // ids are unique but not consecutive
|
|
|
+ .union( // union matches columns by position, not by name
|
|
|
+ filteredNegative
|
|
|
+ .withColumn("label", lit(0)) // negatives are labelled 0
|
|
|
+ .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
|
|
|
+ )
|
|
|
+ .persist(StorageLevel.MEMORY_AND_DISK_SER) // cached for reuse by the later steps
|
|
|
+
|
|
|
+ /*
|
|
|
// 2. Get the list of all available vids
|
|
|
val allVids = time({
|
|
|
odpsOps.readTable(
|
|
@@ -216,7 +268,7 @@ object video_dssm_sampler {
|
|
|
.withColumn("label", lit(0))
|
|
|
.withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
|
|
|
).persist(StorageLevel.MEMORY_AND_DISK_SER)
|
|
|
-
|
|
|
+*/
|
|
|
// 6. Fetch the left-side features
|
|
|
// Read the L1 category statistics features
|
|
|
val l1CatStatFeatures = {
|
|
@@ -409,7 +461,7 @@ object video_dssm_sampler {
|
|
|
|
|
|
// 8. Release cached datasets
|
|
|
positivePairs.unpersist()
|
|
|
- negativeSamplesDF.unpersist()
|
|
|
+ numberedVid.unpersist() // negativeSamples is never persisted; numberedVid is the DataFrame cached during negative sampling and must be released
|
|
|
allSamples.unpersist()
|
|
|
l1CatStatFeatures.unpersist()
|
|
|
l2CatStatFeatures.unpersist()
|