@@ -1,11 +1,16 @@
 package com.aliyun.odps.spark.examples.makedata_recsys
-import org.apache.spark.sql.{DataFrame, SparkSession, Row}
+import com.aliyun.odps.TableSchema
+import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.{ParamUtils, env}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import com.jayway.jsonpath.JsonPath
+
 import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
-import org.apache.log4j.{Logger, Level}
+import org.apache.log4j.{Level, Logger}
 
 object video_dssm_sampler {
   private val logger = Logger.getLogger(this.getClass)
@@ -48,46 +53,83 @@ object video_dssm_sampler {
     spark
   }
 
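+  // Transfer functions passed to odpsOps.readTable: each one maps a raw ODPS
+  // Record to the row or tuple shape the downstream DataFrame expects.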
+  def funcPositive(record: Record, schema: TableSchema): Row = {
+    Row(
+      record.getString("dt"),
+      record.getString("hh"),
+      record.getString("vid_left"),
+      record.getString("vid_right"),
+      record.getString("extend"),
+      record.getString("view_24h"),
+      record.getString("total_return_uv"),
+      record.getString("ts_right"),
+      record.getString("apptype"),
+      record.getString("pagesource"),
+      record.getString("mid"))
+  }
+
+  def funcVids(record: Record, schema: TableSchema): String = {
+    record.getString("vid")
+  }
+  def funcTagFeatures(record: Record, schema: TableSchema): (String, String) = {
+    (record.getString("vid"), record.getString("feature"))
+  }
+
+  def funcStatFeatures(record: Record, schema: TableSchema): (String, String) = {
+    (record.getString("vid"), record.getString("feature"))
+  }
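+  // Both feature tables expose the same (vid, feature) column pair, which is
+  // why the two transfers above are identical.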
+  // Extract the category features from the tag-feature JSON.
+  def funcCategoryFeatures(record: Record, schema: TableSchema): (String, String, String, String) = {
+    val feature = record.getString("feature")
+    // The old SQL used get_json_object, which returns NULL for a missing key;
+    // JsonPath.read instead throws PathNotFoundException, so guard the reads
+    // to preserve the original NULL semantics.
+    def readOrNull(path: String): String =
+      try JsonPath.read[String](feature, path) catch { case _: Exception => null }
+    (record.getString("vid"), readOrNull("$.category1"), readOrNull("$.category2_1"), feature)
+  }
+
   def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = {
     val stats = ProcessingStats()
 
     try {
       logger.info(s"Starting negative sample generation for date: $dt")
+      val sc = spark.sparkContext
 
-      // 1. 获取并缓存正样本对数据
+      // Set up the ODPS reader
+      val odpsOps = env.getODPS(sc)
+      // 1. Load the positive sample pairs
       val positivePairs = time({
-        val df = spark.sql(s"""
-          SELECT
-            dt,
-            hh,
-            vid_left,
-            vid_right,
-            extend,
-            view_24h,
-            total_return_uv,
-            ts_right,
-            apptype,
-            pagesource,
-            mid
-          FROM loghubods.alg_dssm_sample
-          WHERE dt = '$dt'
-        """).persist(StorageLevel.MEMORY_AND_DISK)
-
+        val schema = StructType(Array(
+          StructField("dt", StringType, true),
+          StructField("hh", StringType, true),
+          StructField("vid_left", StringType, true),
+          StructField("vid_right", StringType, true),
+          StructField("extend", StringType, true),
+          StructField("view_24h", StringType, true),
+          StructField("total_return_uv", StringType, true),
+          StructField("ts_right", StringType, true),
+          StructField("apptype", StringType, true),
+          StructField("pagesource", StringType, true),
+          StructField("mid", StringType, true)
+        ))
+
+        // Only the DataFrame is persisted: the intermediate RDD is consumed
+        // exactly once by createDataFrame, so caching it as well would just
+        // waste executor memory.
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "alg_dssm_sample",
+          partition = s"dt='$dt'",
+          transfer = funcPositive,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val df = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
         stats.positiveSamplesCount = df.count()
         df
       }, "Loading positive pairs")
 
-      // 2. 获取所有可用的vid列表
+      // 2. Collect the set of all candidate vids
       val allVids = time({
-        spark.sql(s"""
-          SELECT vid
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        """)
-          .select("vid")
-          .collect()
-          .map(_.getString(0))
-          .toSet
+        // collect().toSet materializes every vid on the driver; this assumes
+        // the vid vocabulary comfortably fits in driver memory.
+        odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcVids,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        ).collect().toSet
      }, "Loading all vids")
 
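+      // If the vid set grows large, consider sc.broadcast(allVids) rather
+      // than capturing the Set directly in the UDF closure defined below.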
-      // 3. 定义UDF函数来生成负样本
+      // 3. Define the UDF that generates negative samples
@@ -134,98 +176,74 @@
 
-      // 6. 获取左侧特征
-      // 获取tag特征
+      // 6. Build the left-side features
+      // Tag features
-      val tagFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_left_tag_feature
-        FROM loghubods.t_vid_tag_feature
-        WHERE dt = '$dt'
-      """)
-
-      // 获取统计特征
-      val statFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_left_stat_feature
-        FROM loghubods.t_vid_stat_feature
-        WHERE dt = '$dt'
-      """)
-
-      // 获取类别特征
-      val categoryFeatures = spark.sql(s"""
-        SELECT
-          a.vid,
-          b.feature as vid_left_cate_l1_feature,
-          c.feature as vid_left_cate_l2_feature
-        FROM (
-          SELECT
-            vid,
-            get_json_object(feature,"$$.category1") as category1,
-            get_json_object(feature,"$$.category2_1") as category2
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        ) a
-        LEFT JOIN (
-          SELECT category1, feature
-          FROM t_vid_l1_cat_stat_feature
-          WHERE dt = '$dt'
-        ) b ON a.category1 = b.category1
-        LEFT JOIN (
-          SELECT category2, feature
-          FROM t_vid_l2_cat_stat_feature
-          WHERE dt = '$dt'
-        ) c ON a.category2 = c.category2
-      """)
+      val tagFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcTagFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_left_tag_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+      // Stat features
+      val statFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_stat_feature",
+          partition = s"dt='$dt'",
+          transfer = funcStatFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_left_stat_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+      // Read and process the category features
+      val categoryData = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcCategoryFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+
+        // Define the schema
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("category1", StringType, true),
+          StructField("category2", StringType, true),
+          StructField("feature", StringType, true)
+        ))
+
+        // Convert the RDD to a DataFrame
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2, t._3, t._4)), schema)
+      }
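+      // The broadcast() hints below assume the three feature tables are small
+      // enough to ship to every executor; drop the hint if they grow.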
-      // 合并所有左侧特征
+      // Join all left-side features
       val vidLeftFeatures = allSamples
         .join(broadcast(tagFeatures), col("vid_left") === tagFeatures("vid"), "left")
         .join(broadcast(statFeatures), col("vid_left") === statFeatures("vid"), "left")
-        .join(broadcast(categoryFeatures), col("vid_left") === categoryFeatures("vid"), "left")
+        .join(broadcast(categoryData), col("vid_left") === categoryData("vid"), "left")
         .persist(StorageLevel.MEMORY_AND_DISK)
 
-      // 7. 获取右侧特征并生成最终结果
-      // 获取tag特征
-      val rightTagFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_right_tag_feature
-        FROM loghubods.t_vid_tag_feature
-        WHERE dt = '$dt'
-      """)
-
-      // 获取统计特征
-      val rightStatFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_right_stat_feature
-        FROM loghubods.t_vid_stat_feature
-        WHERE dt = '$dt'
-      """)
-
-      // 获取类别特征
-      val rightCategoryFeatures = spark.sql(s"""
-        SELECT
-          a.vid,
-          b.feature as vid_right_cate_l1_feature,
-          c.feature as vid_right_cate_l2_feature
-        FROM (
-          SELECT
-            vid,
-            get_json_object(feature,"$$.category1") as category1,
-            get_json_object(feature,"$$.category2_1") as category2
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        ) a
-        LEFT JOIN (
-          SELECT category1, feature
-          FROM t_vid_l1_cat_stat_feature
-          WHERE dt = '$dt'
-        ) b ON a.category1 = b.category1
-        LEFT JOIN (
-          SELECT category2, feature
-          FROM t_vid_l2_cat_stat_feature
-          WHERE dt = '$dt'
-        ) c ON a.category2 = c.category2
-      """)
 
-      // 合并所有右侧特征并生成最终结果
+      // Join the right-side features and produce the final result.
+      // Re-joining tagFeatures/statFeatures/categoryData directly would
+      // duplicate their column names and make references such as
+      // tagFeatures("vid") ambiguous after the left-side joins above, so
+      // rename the right-side copies first.
+      val rightTagFeatures = tagFeatures.toDF("vid", "vid_right_tag_feature")
+      val rightStatFeatures = statFeatures.toDF("vid", "vid_right_stat_feature")
+      val rightCategoryData =
+        categoryData.toDF("vid", "category1_right", "category2_right", "feature_right")
       val result = vidLeftFeatures
-        .join(broadcast(rightTagFeatures), col("vid_right") === rightTagFeatures("vid"), "left")
-        .join(broadcast(rightStatFeatures), col("vid_right") === rightStatFeatures("vid"), "left")
-        .join(broadcast(rightCategoryFeatures), col("vid_right") === rightCategoryFeatures("vid"), "left")
+        .join(broadcast(rightTagFeatures), col("vid_right") === rightTagFeatures("vid"), "left")
+        .join(broadcast(rightStatFeatures), col("vid_right") === rightStatFeatures("vid"), "left")
+        .join(broadcast(rightCategoryData), col("vid_right") === rightCategoryData("vid"), "left")
 
-      // 保存结果到HDFS
+      // Save the result to HDFS
       result.write
@@ -239,7 +257,7 @@ object video_dssm_sampler {
       positivePairs.unpersist()
       allSamples.unpersist()
       vidLeftFeatures.unpersist()
-      rightTagFeatures.unpersist()
+
 
-      // 输出统计信息
+      // Log the run statistics
       stats.logStats()