Bläddra i källkod

模型训练不存在的cid不用模型的打分

xueyiming 1 dag sedan
förälder
incheckning
6d4fb03b43

+ 39 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/config/OssConfig.java

@@ -0,0 +1,39 @@
+package com.tzld.piaoquan.ad.engine.commons.config;
+
+import com.aliyun.oss.ClientBuilderConfiguration;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.OSSClientBuilder;
+import com.aliyun.oss.common.auth.CredentialsProvider;
+import com.aliyun.oss.common.auth.CredentialsProviderFactory;
+import com.aliyun.oss.common.auth.DefaultCredentialProvider;
+import com.aliyun.oss.common.comm.SignVersion;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+@Configuration
+public class OssConfig {
+
+    @Value("${oss.endpoint}")
+    private String endpoint;
+
+    @Value("${oss.adplatform.accessKey}")
+    private String accessKey;
+
+    @Value("${oss.adplatform.secretKey}")
+    private String secretKey;
+
+    @Bean
+    public OSS ossClient() {
+        DefaultCredentialProvider defaultCredentialProvider = CredentialsProviderFactory.newDefaultCredentialProvider(accessKey, secretKey);
+        ClientBuilderConfiguration config = new ClientBuilderConfiguration();
+        config.setSignatureVersion(SignVersion.V4);
+        return OSSClientBuilder.create()
+                .endpoint(endpoint)
+                .credentialsProvider(defaultCredentialProvider)
+                .clientConfiguration(config)
+                .region("cn-hangzhou")
+                .build();
+    }
+
+}

+ 155 - 0
ad-engine-commons/src/main/java/com/tzld/piaoquan/ad/engine/commons/helper/DnnCidDataHelper.java

@@ -0,0 +1,155 @@
+package com.tzld.piaoquan.ad.engine.commons.helper;
+
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
+import com.aliyun.oss.OSS;
+import com.aliyun.oss.model.ListObjectsRequest;
+import com.aliyun.oss.model.OSSObject;
+import com.aliyun.oss.model.OSSObjectSummary;
+import com.aliyun.oss.model.ObjectListing;
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
+import lombok.Getter;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.scheduling.annotation.Scheduled;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.PostConstruct;
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+@Slf4j
+@Component
+public class DnnCidDataHelper {
+
+
+    @Autowired
+    private OSS ossClient;
+
+    private static final String BUCKET_NAME = "art-recommend";
+
+    @ApolloJsonValue("${modelVersionPath:fengzhoutian/pai_model_trained_cids/model_version.json}")
+    private String modelVersionPath;
+
+    @ApolloJsonValue("${cidPath:fengzhoutian/pai_model_trained_cids/}")
+    private String cidPath;
+
+
+    // 提供获取CID集合的方法
+    // 全局变量,存储CSV文件中的数据
+    @Getter
+    private volatile static Set<Long> cidSet = Collections.emptySet();
+
+    // 服务启动时初始化数据
+    @PostConstruct
+    public void init() {
+        log.info("开始初始化CID数据...");
+        updateCidSet();
+        log.info("CID数据初始化完成,共{}条记录", cidSet.size());
+    }
+
+    // 每10分钟更新一次数据
+    @Scheduled(fixedRate = 10 * 60 * 1000)
+    public void scheduledUpdate() {
+        log.info("开始定时更新CID数据...");
+        updateCidSet();
+        log.info("CID数据定时更新完成,共{}条记录", cidSet.size());
+    }
+
+    // 更新CID集合的方法
+    private synchronized void updateCidSet() {
+        try {
+            String modelVersion = readCsvPathFromOss();
+            if (StringUtils.isEmpty(modelVersion)) {
+                //TODO 报警
+                return;
+            }
+            String csvPath = cidPath + modelVersion;
+
+            // 创建临时Set存储新数据
+            Set<Long> newCidSet = new HashSet<>();
+            // 列出指定文件夹下的所有对象
+
+            ListObjectsRequest listObjectsRequest = new ListObjectsRequest(BUCKET_NAME);
+            listObjectsRequest.setPrefix(csvPath);
+            listObjectsRequest.setMaxKeys(10);
+            ObjectListing objectListing = ossClient.listObjects(listObjectsRequest);
+
+            // 查找第一个CSV文件
+            String firstCsvFilePath = null;
+            for (OSSObjectSummary objectSummary : objectListing.getObjectSummaries()) {
+                String filePath = objectSummary.getKey();
+                if (filePath.equals(csvPath)) {
+                    continue;
+                }
+                if (filePath.toLowerCase().endsWith(".csv")) {
+                    firstCsvFilePath = filePath;
+                    log.info("找到CSV文件: {}", firstCsvFilePath);
+                    break;
+                }
+            }
+
+            // 如果找到CSV文件,则读取内容
+            if (firstCsvFilePath != null) {
+                // 读取CSV文件内容到临时Set
+                try (OSSObject ossObject = ossClient.getObject(BUCKET_NAME, firstCsvFilePath);
+                     BufferedReader reader = new BufferedReader(new InputStreamReader(
+                             ossObject.getObjectContent(), StandardCharsets.UTF_8))) {
+                    String line;
+                    while ((line = reader.readLine()) != null) {
+                        if (!line.trim().isEmpty()) {
+                            newCidSet.add(Long.valueOf(line));
+                        }
+                    }
+                }
+                log.info("成功读取CSV文件,共{}行数据", newCidSet.size());
+            } else {
+                log.warn("指定文件夹下没有找到CSV文件!");
+            }
+            // 使用volatile保证可见性,一次性替换整个集合
+            cidSet = Collections.unmodifiableSet(newCidSet);
+        } catch (Exception e) {
+            log.error("更新CID数据失败", e);
+            // 发生异常时保持原数据不变
+        }
+    }
+
+    /**
+     * 从OSS配置文件中读取CSV路径(第一行)
+     */
+    private String readCsvPathFromOss() {
+        try {
+            // 检查配置文件是否存在
+            if (!ossClient.doesObjectExist(BUCKET_NAME, modelVersionPath)) {
+                log.error("OSS配置文件不存在: {}", modelVersionPath);
+                return null;
+            }
+            // 获取配置文件内容
+            OSSObject ossObject = ossClient.getObject(BUCKET_NAME, "fengzhoutian/pai_model_trained_cids/model_version.json");
+            try (BufferedReader reader = new BufferedReader(new InputStreamReader(
+                    ossObject.getObjectContent(), StandardCharsets.UTF_8))) {
+                // 读取整个文件内容
+                StringBuilder content = new StringBuilder();
+                String line;
+                while ((line = reader.readLine()) != null) {
+                    content.append(line);
+                }
+                JSONObject jsonObject = JSON.parseObject(content.toString());
+                String modelName = jsonObject.getString("modelName");
+                String dtVersion = jsonObject.getString("dtVersion");
+                if (StringUtils.isNotEmpty(modelName) && StringUtils.isNotEmpty(dtVersion)) {
+                    return modelName + "/" + dtVersion;
+                }
+            }
+        } catch (Exception e) {
+            log.error("从OSS读取配置文件失败", e);
+        }
+        return null;
+    }
+}

+ 17 - 7
ad-engine-service/src/main/java/com/tzld/piaoquan/ad/engine/service/score/strategy/RankStrategyBy688.java

@@ -1,22 +1,20 @@
 package com.tzld.piaoquan.ad.engine.service.score.strategy;
 
-import com.alibaba.fastjson.JSONObject;
 import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
-import com.tzld.piaoquan.ad.engine.commons.redis.AdRedisHelper;
+import com.tzld.piaoquan.ad.engine.commons.dto.AdPlatformCreativeDTO;
+import com.tzld.piaoquan.ad.engine.commons.helper.DnnCidDataHelper;
+import com.tzld.piaoquan.ad.engine.commons.param.RankRecommendRequestParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScoreParam;
 import com.tzld.piaoquan.ad.engine.commons.score.ScorerUtils;
 import com.tzld.piaoquan.ad.engine.commons.thread.ThreadPoolFactory;
 import com.tzld.piaoquan.ad.engine.commons.util.*;
 import com.tzld.piaoquan.ad.engine.service.entity.GuaranteeView;
 import com.tzld.piaoquan.ad.engine.service.feature.Feature;
-import com.tzld.piaoquan.ad.engine.commons.dto.AdPlatformCreativeDTO;
-import com.tzld.piaoquan.ad.engine.commons.param.RankRecommendRequestParam;
 import com.tzld.piaoquan.recommend.feature.domain.ad.base.AdRankItem;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.collections4.CollectionUtils;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.commons.lang3.StringUtils;
-import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
 import org.xm.Similarity;
@@ -262,11 +260,23 @@ public class RankStrategyBy688 extends RankStrategyBasic {
         // getScorerPipeline
         List<AdRankItem> result = ScorerUtils.getScorerPipeline(ScorerUtils.PAI_SCORE_CONF_20250214).scoring(sceneFeatureMap, userFeatureMap, adRankItems);
         long time5 = System.currentTimeMillis();
-
         // calibrate score for negative sampling
         for (AdRankItem item : result) {
             double originalScore = item.getLrScore();
-            double calibratedScore = originalScore / (originalScore + (1 - originalScore) / negSampleRate);
+            double calibratedScore;
+            if (CollectionUtils.isNotEmpty(DnnCidDataHelper.getCidSet()) && !DnnCidDataHelper.getCidSet().contains(item.getAdId())) {
+                Map<String, Map<String, String>> cidFeature = allCidFeature.getOrDefault(String.valueOf(item.getAdId()), new HashMap<>());
+                Map<String, String> b3Feature = cidFeature.getOrDefault("alg_cid_feature_cid_action", new HashMap<>());
+                double view = Double.parseDouble(b3Feature.getOrDefault("ad_view_14d", "0"));
+                double conver = Double.parseDouble(b3Feature.getOrDefault("ad_conversion_14d", "0"));
+                if (view <= 0) {
+                    calibratedScore = 0.0;
+                } else {
+                    calibratedScore = conver / view;
+                }
+            } else {
+                calibratedScore = originalScore / (originalScore + (1 - originalScore) / negSampleRate);
+            }
             item.setLrScore(calibratedScore);
             item.getScoreMap().put("originCtcvrScore", originalScore);
             item.getScoreMap().put("ctcvrScore", calibratedScore);