Browse Source

修改获取特征为批量

xueyiming 1 week ago
parent
commit
499b24ee49

+ 8 - 0
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/feature/Feature.java

@@ -23,4 +23,12 @@ public class Feature {
     // k1:skuid、k2:表、v:特征值
     private Map<String, Map<String, Map<String, String>>> skuFeature = new HashMap<>();
 
+    // 合并方法,使用 putAll() 合并,且如果目标为空则取源
+    public void merge(Feature other) {
+        this.cidFeature.putAll(other.cidFeature);
+        this.videoFeature.putAll(other.videoFeature);
+        this.adVerFeature.putAll(other.adVerFeature);
+        this.userFeature.putAll(other.userFeature);
+        this.skuFeature.putAll(other.skuFeature);
+    }
 }

+ 86 - 15
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/strategy/RankStrategyBasic.java

@@ -13,6 +13,7 @@ import com.tzld.piaoquan.ad.engine.commons.param.RankRecommendRequestParam;
 import com.tzld.piaoquan.ad.engine.commons.redis.AdRedisHelper;
 import com.tzld.piaoquan.ad.engine.commons.redis.AlgorithmRedisHelper;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
+import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.ad.engine.commons.util.DateUtils;
 import com.tzld.piaoquan.ad.engine.commons.util.ObjUtil;
 import com.tzld.piaoquan.ad.engine.service.entity.*;
@@ -27,6 +28,10 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 
 import java.util.*;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -113,6 +118,9 @@ public abstract class RankStrategyBasic implements RankStrategy {
     @Value("${filter.config.value:[]}")
     protected String filterConfigValue;
 
+    @Value("${feature.branch.config.value:500}")
+    protected Integer featureBranchConfigValue;
+
     @Autowired
     private FeatureService featureService;
     @Autowired
@@ -153,25 +161,88 @@ public abstract class RankStrategyBasic implements RankStrategy {
     }};
 
 
+//    protected Feature getFeature(ScoreParam param, RankRecommendRequestParam request) {
+//        List<AdPlatformCreativeDTO> adIdList = request.getAdIdList();
+//        List<String> cidList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getCreativeId)
+//                .map(Object::toString)
+//                .collect(Collectors.toList());
+//
+//        List<String> adVerIdList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getAdVerId)
+//                .filter(StringUtils::isNotBlank)
+//                .distinct()
+//                .collect(Collectors.toList());
+//
+//        List<Long> skuIdList = adIdList.stream()
+//                .map(AdPlatformCreativeDTO::getSkuId)
+//                .filter(Objects::nonNull)
+//                .distinct()
+//                .collect(Collectors.toList());
+//        return featureService.getFeature(cidList, adVerIdList, skuIdList, param);
+//    }
+
     protected Feature getFeature(ScoreParam param, RankRecommendRequestParam request) {
         List<AdPlatformCreativeDTO> adIdList = request.getAdIdList();
-        List<String> cidList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getCreativeId)
-                .map(Object::toString)
-                .collect(Collectors.toList());
+        Feature finalFeature = null;
 
-        List<String> adVerIdList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getAdVerId)
-                .filter(StringUtils::isNotBlank)
-                .distinct()
-                .collect(Collectors.toList());
+        // 分批处理 AdPlatformCreativeDTO 列表
+        List<List<AdPlatformCreativeDTO>> adIdBatches = partitionList(adIdList, featureBranchConfigValue);
 
-        List<Long> skuIdList = adIdList.stream()
-                .map(AdPlatformCreativeDTO::getSkuId)
-                .filter(Objects::nonNull)
-                .distinct()
-                .collect(Collectors.toList());
-        return featureService.getFeature(cidList, adVerIdList, skuIdList, param);
+        // 使用 ThreadPoolFactory 获取 DEFAULT 线程池
+        ExecutorService executorService = ThreadPoolFactory.defaultPool();  // 使用 DEFAULT 线程池
+        List<Callable<Feature>> tasks = new ArrayList<>();
+
+        for (List<AdPlatformCreativeDTO> batch : adIdBatches) {
+            // 提取批次中的 cidList、adVerIdList 和 skuIdList
+            List<String> cidList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getCreativeId)
+                    .map(Object::toString)
+                    .collect(Collectors.toList());
+
+            List<String> adVerIdList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getAdVerId)
+                    .filter(StringUtils::isNotBlank)
+                    .distinct()
+                    .collect(Collectors.toList());
+
+            List<Long> skuIdList = batch.stream()
+                    .map(AdPlatformCreativeDTO::getSkuId)
+                    .filter(Objects::nonNull)
+                    .distinct()
+                    .collect(Collectors.toList());
+
+            // 将每个批次的请求任务封装为 Callable
+            tasks.add(() -> featureService.getFeature(cidList, adVerIdList, skuIdList, param));
+        }
+
+        try {
+            // 执行所有的任务并等待所有任务完成
+            List<Future<Feature>> futures = executorService.invokeAll(tasks);
+
+            // 等待所有任务完成并合并结果
+            for (Future<Feature> future : futures) {
+                Feature batchFeature = future.get();  // 获取每个任务的结果
+                if (finalFeature == null) {
+                    finalFeature = batchFeature;
+                } else {
+                    // 合并特征
+                    finalFeature.merge(batchFeature);
+                }
+            }
+        } catch (InterruptedException | ExecutionException e) {
+            log.error("getFeature error", e);
+        }
+        return finalFeature;
+    }
+
+    // 辅助方法:将列表分成指定大小的批次
+    private <T> List<List<T>> partitionList(List<T> originalList, Integer batchSize) {
+        List<List<T>> batches = new ArrayList<>();
+        for (int i = 0; i < originalList.size(); i += batchSize) {
+            batches.add(originalList.subList(i, Math.min(i + batchSize, originalList.size())));
+        }
+        return batches;
     }
 
     protected Set<String> getNoApiAdVerIds() {