|
@@ -0,0 +1,155 @@
|
|
|
+package com.tzld.piaoquan.ad.engine.commons.helper;
|
|
|
+
|
|
|
+import com.alibaba.fastjson.JSON;
|
|
|
+import com.alibaba.fastjson.JSONObject;
|
|
|
+import com.aliyun.oss.OSS;
|
|
|
+import com.aliyun.oss.model.ListObjectsRequest;
|
|
|
+import com.aliyun.oss.model.OSSObject;
|
|
|
+import com.aliyun.oss.model.OSSObjectSummary;
|
|
|
+import com.aliyun.oss.model.ObjectListing;
|
|
|
+import com.ctrip.framework.apollo.spring.annotation.ApolloJsonValue;
|
|
|
+import lombok.Getter;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.apache.commons.lang3.StringUtils;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.scheduling.annotation.Scheduled;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import javax.annotation.PostConstruct;
|
|
|
+import java.io.BufferedReader;
|
|
|
+import java.io.InputStreamReader;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.HashSet;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.Set;
|
|
|
+
|
|
|
+@Slf4j
|
|
|
+@Component
|
|
|
+public class DnnCidDataHelper {
|
|
|
+
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private OSS ossClient;
|
|
|
+
|
|
|
+ private static final String BUCKET_NAME = "art-recommend";
|
|
|
+
|
|
|
+ @ApolloJsonValue("${modelVersionPath:fengzhoutian/pai_model_trained_cids/model_version.json}")
|
|
|
+ private String modelVersionPath;
|
|
|
+
|
|
|
+ @ApolloJsonValue("${cidPath:fengzhoutian/pai_model_trained_cids/}")
|
|
|
+ private String cidPath;
|
|
|
+
|
|
|
+
|
|
|
+ // 提供获取CID集合的方法
|
|
|
+ // 全局变量,存储CSV文件中的数据
|
|
|
+ @Getter
|
|
|
+ private volatile static Set<Long> cidSet = Collections.emptySet();
|
|
|
+
|
|
|
+ // 服务启动时初始化数据
|
|
|
+ @PostConstruct
|
|
|
+ public void init() {
|
|
|
+ log.info("开始初始化CID数据...");
|
|
|
+ updateCidSet();
|
|
|
+ log.info("CID数据初始化完成,共{}条记录", cidSet.size());
|
|
|
+ }
|
|
|
+
|
|
|
+ // 每10分钟更新一次数据
|
|
|
+ @Scheduled(fixedRate = 10 * 60 * 1000)
|
|
|
+ public void scheduledUpdate() {
|
|
|
+ log.info("开始定时更新CID数据...");
|
|
|
+ updateCidSet();
|
|
|
+ log.info("CID数据定时更新完成,共{}条记录", cidSet.size());
|
|
|
+ }
|
|
|
+
|
|
|
+ // 更新CID集合的方法
|
|
|
+ private synchronized void updateCidSet() {
|
|
|
+ try {
|
|
|
+ String modelVersion = readCsvPathFromOss();
|
|
|
+ if (StringUtils.isEmpty(modelVersion)) {
|
|
|
+ //TODO 报警
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ String csvPath = cidPath + modelVersion;
|
|
|
+
|
|
|
+ // 创建临时Set存储新数据
|
|
|
+ Set<Long> newCidSet = new HashSet<>();
|
|
|
+ // 列出指定文件夹下的所有对象
|
|
|
+
|
|
|
+ ListObjectsRequest listObjectsRequest = new ListObjectsRequest(BUCKET_NAME);
|
|
|
+ listObjectsRequest.setPrefix(csvPath);
|
|
|
+ listObjectsRequest.setMaxKeys(10);
|
|
|
+ ObjectListing objectListing = ossClient.listObjects(listObjectsRequest);
|
|
|
+
|
|
|
+ // 查找第一个CSV文件
|
|
|
+ String firstCsvFilePath = null;
|
|
|
+ for (OSSObjectSummary objectSummary : objectListing.getObjectSummaries()) {
|
|
|
+ String filePath = objectSummary.getKey();
|
|
|
+ if (filePath.equals(csvPath)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (filePath.toLowerCase().endsWith(".csv")) {
|
|
|
+ firstCsvFilePath = filePath;
|
|
|
+ log.info("找到CSV文件: {}", firstCsvFilePath);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 如果找到CSV文件,则读取内容
|
|
|
+ if (firstCsvFilePath != null) {
|
|
|
+ // 读取CSV文件内容到临时Set
|
|
|
+ try (OSSObject ossObject = ossClient.getObject(BUCKET_NAME, firstCsvFilePath);
|
|
|
+ BufferedReader reader = new BufferedReader(new InputStreamReader(
|
|
|
+ ossObject.getObjectContent(), StandardCharsets.UTF_8))) {
|
|
|
+ String line;
|
|
|
+ while ((line = reader.readLine()) != null) {
|
|
|
+ if (!line.trim().isEmpty()) {
|
|
|
+ newCidSet.add(Long.valueOf(line));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.info("成功读取CSV文件,共{}行数据", newCidSet.size());
|
|
|
+ } else {
|
|
|
+ log.warn("指定文件夹下没有找到CSV文件!");
|
|
|
+ }
|
|
|
+ // 使用volatile保证可见性,一次性替换整个集合
|
|
|
+ cidSet = Collections.unmodifiableSet(newCidSet);
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("更新CID数据失败", e);
|
|
|
+ // 发生异常时保持原数据不变
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 从OSS配置文件中读取CSV路径(第一行)
|
|
|
+ */
|
|
|
+ private String readCsvPathFromOss() {
|
|
|
+ try {
|
|
|
+ // 检查配置文件是否存在
|
|
|
+ if (!ossClient.doesObjectExist(BUCKET_NAME, modelVersionPath)) {
|
|
|
+ log.error("OSS配置文件不存在: {}", modelVersionPath);
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ // 获取配置文件内容
|
|
|
+ OSSObject ossObject = ossClient.getObject(BUCKET_NAME, "fengzhoutian/pai_model_trained_cids/model_version.json");
|
|
|
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(
|
|
|
+ ossObject.getObjectContent(), StandardCharsets.UTF_8))) {
|
|
|
+ // 读取整个文件内容
|
|
|
+ StringBuilder content = new StringBuilder();
|
|
|
+ String line;
|
|
|
+ while ((line = reader.readLine()) != null) {
|
|
|
+ content.append(line);
|
|
|
+ }
|
|
|
+ JSONObject jsonObject = JSON.parseObject(content.toString());
|
|
|
+ String modelName = jsonObject.getString("modelName");
|
|
|
+ String dtVersion = jsonObject.getString("dtVersion");
|
|
|
+ if (StringUtils.isNotEmpty(modelName) && StringUtils.isNotEmpty(dtVersion)) {
|
|
|
+ return modelName + "/" + dtVersion;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("从OSS读取配置文件失败", e);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+}
|