sunmingze пре 1 година
родитељ
комит
5d11f2f972

+ 1 - 14
src/main/java/examples/sparksql/SparkAdCTRSampleLoader.java

@@ -84,17 +84,6 @@ public class SparkAdCTRSampleLoader {
         return parseSamplesToString2(label, lrSamples);
     }
 
-    // 构建样本的字符串
-    public static String parseSamplesToString(String label, ListMultimap<FeatureGroup, BaseFeature> featureMap) {
-        ArrayList<String> featureList = new ArrayList<String>();
-        for (Map.Entry<FeatureGroup, BaseFeature> entry : featureMap.entries()) {
-            FeatureGroup groupedFeature = entry.getKey();
-            BaseFeature baseFeature = entry.getValue();
-            Long featureIdentifier = baseFeature.getIdentifier();
-            featureList.add(String.valueOf(featureIdentifier) + ":1");
-        }
-        return label + "\t" + String.join("\t", featureList);
-    }
 
 
     // 构建样本的字符串
@@ -106,14 +95,12 @@ public class SparkAdCTRSampleLoader {
                 for (int j = 0; j < groupedFeature.getFeaturesCount(); j++) {
                     BaseFeature baseFeature = groupedFeature.getFeatures(j);
                     if (baseFeature != null) {
-                        featureList.add(baseFeature.getFea());
+                        featureList.add(String.valueOf(baseFeature.getIdentifier()));
                     }
                 }
             }
         }
-
         return label + "\t" + String.join("\t", featureList);
-
     }
 
 

+ 17 - 12
src/main/java/examples/sparksql/SparkAdCVRSampleLoader.java

@@ -7,6 +7,8 @@ import com.tzld.piaoquan.recommend.feature.domain.ad.base.*;
 import com.tzld.piaoquan.recommend.feature.domain.ad.feature.VlogAdCtrLRFeatureExtractor;
 import com.tzld.piaoquan.recommend.feature.model.sample.BaseFeature;
 import com.tzld.piaoquan.recommend.feature.model.sample.FeatureGroup;
+import com.tzld.piaoquan.recommend.feature.model.sample.GroupedFeature;
+import com.tzld.piaoquan.recommend.feature.model.sample.LRSamples;
 import examples.dataloader.AdSampleConstructor;
 import org.apache.spark.SparkConf;
 import org.apache.spark.aliyun.odps.OdpsOps;
@@ -59,6 +61,7 @@ public class SparkAdCVRSampleLoader {
             label = "1";
         }
 
+
         // 从sql的 record中 初始化对象内容
         AdRequestContext requestContext = AdSampleConstructor.constructRequestContext(record);
         UserAdFeature userFeature = AdSampleConstructor.constructUserFeature(record);
@@ -73,23 +76,25 @@ public class SparkAdCVRSampleLoader {
         VlogAdCtrLRFeatureExtractor bytesFeatureExtractor;
         bytesFeatureExtractor = new VlogAdCtrLRFeatureExtractor();
 
-        bytesFeatureExtractor.getUserFeatures(userBytesFeature);
-        bytesFeatureExtractor.getItemFeature(adItemBytesFeature);
-        bytesFeatureExtractor.getContextFeatures(adRequestContextBytesFeature);
-        bytesFeatureExtractor.getCrossFeature(adItemBytesFeature, adRequestContextBytesFeature, userBytesFeature);
+        LRSamples lrSamples = bytesFeatureExtractor.single(userBytesFeature, adItemBytesFeature, adRequestContextBytesFeature);
 
-        ListMultimap<FeatureGroup, BaseFeature> featureMap = bytesFeatureExtractor.getFeatures();
-        return parseSamplesToString(label, featureMap);
+        return parseSamplesToString2(label, lrSamples);
     }
 
+
     // 构建样本的字符串
-    public static String parseSamplesToString(String label, ListMultimap<FeatureGroup, BaseFeature> featureMap) {
+    public static String parseSamplesToString2(String label, LRSamples lrSamples) {
         ArrayList<String> featureList = new ArrayList<String>();
-        for (Map.Entry<FeatureGroup, BaseFeature> entry : featureMap.entries()) {
-            FeatureGroup groupedFeature = entry.getKey();
-            BaseFeature baseFeature = entry.getValue();
-            Long featureIdentifier = baseFeature.getIdentifier();
-            featureList.add(String.valueOf(featureIdentifier) + ":1");
+        for (int i = 0; i < lrSamples.getFeaturesCount(); i++) {
+            GroupedFeature groupedFeature = lrSamples.getFeatures(i);
+            if (groupedFeature != null && groupedFeature.getFeaturesCount() != 0) {
+                for (int j = 0; j < groupedFeature.getFeaturesCount(); j++) {
+                    BaseFeature baseFeature = groupedFeature.getFeatures(j);
+                    if (baseFeature != null) {
+                        featureList.add(String.valueOf(baseFeature.getIdentifier()));
+                    }
+                }
+            }
         }
         return label + "\t" + String.join("\t", featureList);
     }