Browse Source

Add CVR model sample loader

sunmingze committed 1 year ago
parent
commit
ed3bfe0be8

+ 1 - 1
src/main/java/examples/sparksql/SparkAdCTRSampleLoader.java

@@ -56,7 +56,7 @@ public class SparkAdCTRSampleLoader {
     public static String singleParse(Record record, String labelName) {
         // 数据解析
         String label = record.getString(labelName);
-        if (label == null || label.equals("0")) {
+        if (label == null || label.equals("1")) {
             label = "0";
         } else {
             label = "1";

+ 99 - 0
src/main/java/examples/sparksql/SparkAdCVRSampleLoader.java

@@ -0,0 +1,99 @@
+package examples.sparksql;
+
+import com.aliyun.odps.TableSchema;
+import com.aliyun.odps.data.Record;
+import com.google.common.collect.ListMultimap;
+import com.tzld.piaoquan.recommend.feature.domain.ad.base.*;
+import com.tzld.piaoquan.recommend.feature.domain.ad.feature.VlogAdCtrLRFeatureExtractor;
+import com.tzld.piaoquan.recommend.feature.model.sample.BaseFeature;
+import com.tzld.piaoquan.recommend.feature.model.sample.FeatureGroup;
+import examples.dataloader.AdSampleConstructor;
+import org.apache.spark.SparkConf;
+import org.apache.spark.aliyun.odps.OdpsOps;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function2;
+
+import java.util.ArrayList;
+import java.util.Map;
+
+
+public class SparkAdCVRSampleLoader {
+
+    public static void main(String[] args) {
+
+        String partition = args[0];
+        String accessId = "LTAIWYUujJAm7CbH";
+        String accessKey = "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P";
+        String odpsUrl = "http://service.odps.aliyun.com/api";
+        String tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com";
+        String project = "loghubods";
+        String table = "alg_ad_view_sample";
+        String hdfsPath = "/dw/recommend/model/ad_cvr_samples/" + partition;
+
+        SparkConf sparkConf = new SparkConf().setAppName("E-MapReduce Demo 3-2: Spark MaxCompute Demo (Java)");
+        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
+        OdpsOps odpsOps = new OdpsOps(jsc.sc(), accessId, accessKey, odpsUrl, tunnelUrl);
+        System.out.println("Read odps table...");
+
+        JavaRDD<String> readData = odpsOps.readTableWithJava(project, table, partition, new RecordsToSamples(), Integer.valueOf(30));
+        readData.saveAsTextFile(hdfsPath);
+    }
+
+
+    static class RecordsToSamples implements Function2<Record, TableSchema, String> {
+        @Override
+        public String call(Record record, TableSchema schema) throws Exception {
+            String labelName = "adinvert_ornot";
+            String ret = singleParse(record, labelName);
+            return ret;
+        }
+    }
+
+
+    // 单条日志处理逻辑
+    public static String singleParse(Record record, String labelName) {
+        // 数据解析
+        String label = record.getString(labelName);
+        if (label == null || label.equals("1")) {
+            label = "0";
+        } else {
+            label = "1";
+        }
+
+        // 从sql的 record中 初始化对象内容
+        AdRequestContext requestContext = AdSampleConstructor.constructRequestContext(record);
+        UserAdFeature userFeature = AdSampleConstructor.constructUserFeature(record);
+        AdItemFeature itemFeature = AdSampleConstructor.constructItemFeature(record);
+
+        // 转化成bytes
+        AdRequestContextBytesFeature adRequestContextBytesFeature = new AdRequestContextBytesFeature(requestContext);
+        UserAdBytesFeature userBytesFeature = new UserAdBytesFeature(userFeature);
+        AdItemBytesFeature adItemBytesFeature = new AdItemBytesFeature(itemFeature);
+
+        // 特征抽取
+        VlogAdCtrLRFeatureExtractor bytesFeatureExtractor;
+        bytesFeatureExtractor = new VlogAdCtrLRFeatureExtractor();
+
+        bytesFeatureExtractor.getUserFeatures(userBytesFeature);
+        bytesFeatureExtractor.getItemFeature(adItemBytesFeature);
+        bytesFeatureExtractor.getContextFeatures(adRequestContextBytesFeature);
+        bytesFeatureExtractor.getCrossFeature(adItemBytesFeature, adRequestContextBytesFeature, userBytesFeature);
+
+        ListMultimap<FeatureGroup, BaseFeature> featureMap = bytesFeatureExtractor.getFeatures();
+        return parseSamplesToString(label, featureMap);
+    }
+
+    // 构建样本的字符串
+    public static String parseSamplesToString(String label, ListMultimap<FeatureGroup, BaseFeature> featureMap) {
+        ArrayList<String> featureList = new ArrayList<String>();
+        for (Map.Entry<FeatureGroup, BaseFeature> entry : featureMap.entries()) {
+            FeatureGroup groupedFeature = entry.getKey();
+            BaseFeature baseFeature = entry.getValue();
+            Long featureIdentifier = baseFeature.getIdentifier();
+            featureList.add(String.valueOf(featureIdentifier) + ":1");
+        }
+        return label + "\t" + String.join("\t", featureList);
+    }
+
+}

+ 5 - 0
src/main/scala/com/tzld/recommend/recall/algo/CollaborativeFilteringAlgo.scala

@@ -0,0 +1,5 @@
+package com.tzld.recommend.recall.algo
+
/**
 * Collaborative-filtering recall algorithm.
 *
 * NOTE(review): empty placeholder — no behavior is implemented yet.
 */
class CollaborativeFilteringAlgo {

}