zhangbo 1 năm trước
mục cha
commit
9d870a998f

+ 55 - 39
src/main/scala/com/aliyun/odps/spark/examples/makedata/makedata_02_writeredis.scala

@@ -2,6 +2,7 @@ package com.aliyun.odps.spark.examples.makedata
 
 import com.aliyun.odps.TableSchema
 import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.ParamUtils
 import examples.dataloader.RecommRedisFeatureConstructor
 import org.apache.spark.aliyun.odps.OdpsOps
 import org.apache.spark.sql.SparkSession
@@ -9,6 +10,7 @@ import org.springframework.data.redis.connection.RedisStandaloneConfiguration
 import org.springframework.data.redis.connection.jedis.JedisConnectionFactory
 import org.springframework.data.redis.core.RedisTemplate
 import org.springframework.data.redis.serializer.StringRedisSerializer
+
 import java.util.concurrent.TimeUnit
 import java.util
 import scala.collection.JavaConversions._
@@ -22,6 +24,11 @@ object makedata_02_writeredis {
       .getOrCreate()
     val sc = spark.sparkContext
 
+    // 读取参数
+    val param = ParamUtils.parseArgs(args)
+    val ifUser = param.getOrDefault("ifUser", "True").toBoolean
+    val ifVideo = param.getOrDefault("ifVideo", "False").toBoolean
+
 
     // 读取数据库odps
     val accessKeyId = "LTAIWYUujJAm7CbH"
@@ -37,46 +44,55 @@ object makedata_02_writeredis {
     val odpsOps = OdpsOps(sc, accessKeyId, accessKeySecret, odpsUrl, tunnelUrl)
 
     //用户测特征处理
-    val userData = odpsOps.readTable(project = project, table = tableUser, partition = partition, transfer = handleUser, numPartition = 100)
-    val userDataTake = userData.take(10)
-    userDataTake.foreach(r=>{
-      println(r.get(0) + "\t" + r.get(1))
-    })
-
-    val userDataTakeRddRun = userData.sample(false, 0.1).mapPartitions(row=>{
-      val redisTemplate = this.getRedisTemplate()
-      val redisFormat = new util.HashMap[String, String]
-      row.foreach(r =>{
-        val key = r.get(0)
-        val value = r.get(1)
-        redisFormat.put(key, value)
+    if (ifUser){
+      val userData = odpsOps.readTable(project = project, table = tableUser, partition = partition, transfer = handleUser, numPartition = 100)
+      val userDataTake = userData.take(10)
+      userDataTake.foreach(r => {
+        println(r.get(0) + "\t" + r.get(1))
+      })
+
+      val userDataTakeRddRun = userData.sample(false, 0.05).mapPartitions(row => {
+        val redisTemplate = this.getRedisTemplate()
+        val redisFormat = new util.HashMap[String, String]
+        row.foreach(r => {
+          val key = r.get(0)
+          val value = r.get(1)
+          redisFormat.put(key, value)
+        })
+        redisTemplate.opsForValue.multiSet(redisFormat)
+        redisFormat.keySet.foreach(key => redisTemplate.expire(key, 24 * 7, TimeUnit.HOURS))
+        redisFormat.iterator
       })
-      redisTemplate.opsForValue.multiSet(redisFormat)
-      redisFormat.keySet.foreach(key => redisTemplate.expire(key, 24*7, TimeUnit.HOURS))
-      redisFormat.iterator
-    })
-    println("user.action.count="+userDataTakeRddRun.count())
-
-    //video测特征处理
-    println("video测特征处理")
-    val itemData = odpsOps.readTable(project = project, table = tableItem, partition = partition, transfer = handleItem, numPartition = 100)
-    val itemDataTake = itemData.take(10)
-    itemDataTake.foreach(r => {
-      println(r.get(0) + "\t" + r.get(1))
-    })
-    val itemDataTakeRddRun = itemData.mapPartitions(row => {
-      val redisTemplate = this.getRedisTemplate()
-      val redisFormat = new util.HashMap[String, String]
-      for (r <- row) {
-        val key = r.get(0)
-        val value = r.get(1)
-        redisFormat.put(key, value)
-      }
-      redisTemplate.opsForValue.multiSet(redisFormat)
-      redisFormat.keySet.foreach(key => redisTemplate.expire(key, 24*7, TimeUnit.HOURS))
-      redisFormat.iterator
-    })
-    println("item.action.count="+itemDataTakeRddRun.count())
+      println("user.action.count=" + userDataTakeRddRun.count())
+    }else{
+      println("不处理user")
+    }
+    if (ifVideo){
+      //video测特征处理
+      println("video测特征处理")
+      val itemData = odpsOps.readTable(project = project, table = tableItem, partition = partition, transfer = handleItem, numPartition = 100)
+      val itemDataTake = itemData.take(10)
+      itemDataTake.foreach(r => {
+        println(r.get(0) + "\t" + r.get(1))
+      })
+      val itemDataTakeRddRun = itemData.mapPartitions(row => {
+        val redisTemplate = this.getRedisTemplate()
+        val redisFormat = new util.HashMap[String, String]
+        for (r <- row) {
+          val key = r.get(0)
+          val value = r.get(1)
+          redisFormat.put(key, value)
+        }
+        redisTemplate.opsForValue.multiSet(redisFormat)
+        redisFormat.keySet.foreach(key => redisTemplate.expire(key, 24 * 7, TimeUnit.HOURS))
+        redisFormat.iterator
+      })
+      println("item.action.count=" + itemDataTakeRddRun.count())
+    }else{
+      println("不处理video")
+    }
+
+
   }
 
   def handleUser(record: Record, schema: TableSchema): util.ArrayList[String] = {

+ 28 - 0
src/main/scala/com/aliyun/odps/spark/examples/myUtils/ParamUtils.scala

@@ -0,0 +1,28 @@
package com.aliyun.odps.spark.examples.myUtils

import scala.collection.mutable

/**
 * Helpers for parsing `key:value`-style command-line arguments.
 */
object ParamUtils {

  /**
   * Parses arguments of the form `key:value` into a mutable map.
   *
   * Only the first ':' separates key from value, so values may themselves
   * contain ':'. When the same key appears more than once, its values are
   * accumulated comma-separated (works around HDFS path globbing expanding
   * a single argument into several). Arguments without a ':' separator are
   * silently ignored. Looking up a key that was never supplied returns the
   * sentinel "无参数传入" via the map's overridden `default`.
   *
   * @param args raw program arguments
   * @return mutable map of parsed key/value pairs
   */
  def parseArgs(args: Array[String]): mutable.HashMap[String, String] = {
    println("args size:" + args.length)

    // Missing keys resolve to a sentinel instead of throwing NoSuchElementException.
    val rst = new mutable.HashMap[String, String]() {
      override def default(key: String) = "无参数传入"
    }
    for (a <- args) {
      val keyVal = a.split(":")
      if (keyVal.length >= 2) {
        val key = keyVal(0)
        // Re-join everything after the first segment so values may contain ':'.
        val value = keyVal.drop(1).mkString(":")
        // Repeated keys accumulate comma-separated (HDFS glob expansion workaround).
        val merged = if (rst.contains(key)) rst(key) + "," + value else value
        rst += (key -> merged)
        println(key + ":" + merged)
      }
    }
    rst
  }
}