|
@@ -85,11 +85,7 @@ object video_dssm_sampler {
|
|
|
}
|
|
|
|
|
|
|
|
|
- // 3. 定义UDF函数来生成负样本
|
|
|
- def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
|
|
|
- val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20)
|
|
|
- negativeVids
|
|
|
- })
|
|
|
+
|
|
|
|
|
|
def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = {
|
|
|
val stats = ProcessingStats()
|
|
@@ -142,7 +138,11 @@ object video_dssm_sampler {
|
|
|
|
|
|
// 注册UDF
|
|
|
spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
|
|
|
-
|
|
|
+ // 3. 定义UDF函数来生成负样本
|
|
|
+ def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
|
|
|
+ val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20)
|
|
|
+ negativeVids
|
|
|
+ })
|
|
|
// 4. 生成负样本
|
|
|
val negativeSamplesDF = time({
|
|
|
val df = positivePairs
|