瀏覽代碼

i2i样本制作,第一步。

zhangbo 5 月之前
父節點
當前提交
c2d6a3e77a

+ 93 - 0
src/main/scala/com/aliyun/odps/spark/examples/makedata_dssm/makedata_i2i_01_originData_20241127.scala

@@ -0,0 +1,93 @@
+package com.aliyun.odps.spark.examples.makedata_dssm
+
+import com.alibaba.fastjson.{JSON, JSONObject}
+import com.aliyun.odps.TableSchema
+import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.{MyDateUtils, MyHdfsUtils, ParamUtils, env}
+import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.spark.sql.SparkSession
+import scala.util.Random
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+object makedata_i2i_01_originData_20241127 {
+  def func(record: Record, schema: TableSchema): Record = {
+    record
+  }
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
+    // 1 读取参数
+    val param = ParamUtils.parseArgs(args)
+    val tablePart = param.getOrElse("tablePart", "64").toInt
+    val beginStr = param.getOrElse("beginStr", "2024062008")
+    val endStr = param.getOrElse("endStr", "2024062023")
+    val savePath = param.getOrElse("savePath", "/dw/recommend/model/41_dssm_i2i_sample/")
+    val project = param.getOrElse("project", "loghubods")
+    val repartition = param.getOrElse("repartition", "100").toInt
+    val filterHours = param.getOrElse("filterHours", "25").split(",").toSet
+    // 2 读取odps+表信息
+    val odpsOps = env.getODPS(sc)
+    // 3 循环执行数据生产
+    val timeRange = MyDateUtils.getDateHourRange(beginStr, endStr)
+    for (dt_hh <- timeRange) {
+      val dt = dt_hh.substring(0, 8)
+      val hh = dt_hh.substring(8, 10)
+      val partition = s"dt=$dt,hh=$hh"
+      val vidsArr = odpsOps.readTable(project = project,
+          table = "t_vid_tag_feature",
+          partition = s"dt=$dt",
+          transfer = func,
+          numPartition = tablePart)
+        .map(r => {
+          r.getString("vid")
+        }).collect().toList
+      val vids_br = sc.broadcast(vidsArr)
+      if (filterHours.nonEmpty && filterHours.contains(hh)) {
+        println("不执行partiton:" + partition)
+      } else {
+        println("开始执行partiton:" + partition)
+        val odpsData = odpsOps.readTable(project = project,
+          table = "alg_dssm_sample",
+          partition = partition,
+          transfer = func,
+          numPartition = tablePart)
+          .map(record =>{
+            val apptype = record.getString("apptype")
+            val pagesource = record.getString("pagesource")
+            val mid = record.getString("mid")
+            val vid_right = record.getString("vid_right")
+            val vid_left = record.getString("vid_left")
+            val total_return_uv = record.getString("total_return_uv")
+            val view_24h = record.getString("view_24h")
+            val logKey = (apptype, pagesource, mid, vid_right, vid_left, total_return_uv, view_24h).productIterator.mkString(",")
+            (logKey, vid_left, vid_right)
+          }).mapPartitions(row =>{
+            val result = new ArrayBuffer[String]()
+            val vids = vids_br.value
+            row.foreach {
+              case (logKey, vid_left, vid_right) =>
+                val negs = Random.shuffle(vids).take(20).filter(r => !r.equals(vid_left) && !r.equals(vid_right))
+                negs.foreach(negVid =>{
+                  result.add((logKey, "0", vid_left, negVid).productIterator.mkString("\t"))
+                })
+                result.add((logKey, "1", vid_left, vid_right).productIterator.mkString("\t"))
+            }
+            result.iterator
+          })
+        val savePartition = dt + hh
+        val hdfsPath = savePath + "/" + savePartition
+        if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
+          println("删除路径并开始数据写入:" + hdfsPath)
+          MyHdfsUtils.delete_hdfs_path(hdfsPath)
+          odpsData.coalesce(repartition).saveAsTextFile(hdfsPath, classOf[GzipCodec])
+        } else {
+          println("路径不合法,无法写入:" + hdfsPath)
+        }
+      }
+    }
+  }
+}