Parcourir la source

update cluster strategy

supeng il y a 2 semaines
Parent
commit
ecaf39ae54

+ 98 - 4
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/service/strategy/ClusterStrategy.java

@@ -1,12 +1,27 @@
 package com.tzld.piaoquan.sde.service.strategy;
 
+import com.alibaba.fastjson.JSON;
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
+import com.tzld.piaoquan.sde.common.enums.IsDeleteEnum;
+import com.tzld.piaoquan.sde.common.enums.SubTaskStatusEnum;
 import com.tzld.piaoquan.sde.common.enums.TaskTypeEnum;
+import com.tzld.piaoquan.sde.integration.OpenRouterClient;
+import com.tzld.piaoquan.sde.mapper.SdPromptTemplateMapper;
+import com.tzld.piaoquan.sde.mapper.SdSubTaskMapper;
+import com.tzld.piaoquan.sde.mapper.SdSubTaskResultItemMapper;
+import com.tzld.piaoquan.sde.model.dto.StrategyConfigDTO;
 import com.tzld.piaoquan.sde.model.dto.StrategyResultDTO;
-import com.tzld.piaoquan.sde.model.entity.SdStrategy;
-import com.tzld.piaoquan.sde.model.entity.SdTask;
+import com.tzld.piaoquan.sde.model.entity.*;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Component;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
 /**
  * 聚类任务策略
  *
@@ -16,6 +31,18 @@ import org.springframework.stereotype.Component;
 @Component
 public class ClusterStrategy implements DemandExtractionStrategy {
 
+    @Autowired
+    private SdSubTaskMapper sdSubTaskMapper;
+
+    @Autowired
+    private SdSubTaskResultItemMapper sdSubTaskResultItemMapper;
+
+    @Autowired
+    private SdPromptTemplateMapper sdPromptTemplateMapper;
+
+    @Autowired
+    private OpenRouterClient openRouterClient;
+
     @Override
     public Integer getType() {
         return TaskTypeEnum.CLUSTER.getValue();
@@ -25,9 +52,76 @@ public class ClusterStrategy implements DemandExtractionStrategy {
     public StrategyResultDTO execute(SdTask task, SdStrategy strategy) {
         log.info("Executing cluster task: taskId={}, taskNo={}", task.getId(), task.getTaskNo());
 
-        // TODO: 实现聚类任务逻辑
+        StrategyConfigDTO strategyConfigDTO = JSON.parseObject(strategy.getConfig(), StrategyConfigDTO.class);
+        // 1. 获取Prompt模板
+        SdPromptTemplate promptTemplate = sdPromptTemplateMapper.selectById(strategyConfigDTO.getPromptTemplateId());
+        if (Objects.isNull(promptTemplate)) {
+            log.error("Prompt template not found: promptTemplateId={}", strategyConfigDTO.getPromptTemplateId());
+            return StrategyResultDTO.builder()
+                    .success(false)
+                    .msg("Prompt template not found")
+                    .build();
+        }
+
+        // 2. 查询成功的子任务
+        LambdaQueryWrapper<SdSubTask> subTaskWrapper = Wrappers.lambdaQuery(SdSubTask.class)
+                .eq(SdSubTask::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
+                .eq(SdSubTask::getTaskId, task.getId())
+                .eq(SdSubTask::getTaskStatus, SubTaskStatusEnum.SUCCESS.getValue())
+                .eq(SdSubTask::getSubTaskType, task.getTaskType())
+                .select(SdSubTask::getId);
+        List<SdSubTask> subTasks = sdSubTaskMapper.selectList(subTaskWrapper);
+        if (Objects.isNull(subTasks) || subTasks.isEmpty()) {
+            log.warn("No successful subtasks found: taskId={}", task.getId());
+            return StrategyResultDTO.builder()
+                    .success(false)
+                    .msg("No successful subtasks found")
+                    .build();
+        }
+
+        // 3. 查询子任务结果项
+        List<Long> subTaskIds = subTasks.stream().map(SdSubTask::getId).collect(Collectors.toList());
+        // TODO: 根据实际业务确定itemType
+        String itemType = "";
+        LambdaQueryWrapper<SdSubTaskResultItem> resultItemWrapper = Wrappers.lambdaQuery(SdSubTaskResultItem.class)
+                .eq(SdSubTaskResultItem::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
+                .in(SdSubTaskResultItem::getSubTaskId, subTaskIds)
+                .eq(SdSubTaskResultItem::getItemType, itemType)
+                .select(SdSubTaskResultItem::getItemContent);
+        List<SdSubTaskResultItem> resultItems = sdSubTaskResultItemMapper.selectList(resultItemWrapper);
+        if (Objects.isNull(resultItems) || resultItems.isEmpty()) {
+            log.warn("No subtask result items found: taskId={}", task.getId());
+            return StrategyResultDTO.builder()
+                    .success(false)
+                    .msg("No subtask result items found")
+                    .build();
+        }
+
+        // 4. 构建输入内容
+        String input = resultItems.stream()
+                .map(SdSubTaskResultItem::getItemContent)
+                .filter(Objects::nonNull)
+                .collect(Collectors.joining("\n"));
+
+        // 5. 构建最终Prompt
+        String prompt = promptTemplate.getPromptContent();
+        String finalPrompt = prompt.replace("{{input}}", input);
 
+        // 6. 调用AI服务 (TODO: 实现具体调用逻辑)
+        log.info("Preparing to call AI service: taskId={}, promptLength={}", task.getId(), finalPrompt.length());
+        try {
+            String result = openRouterClient.chat(strategyConfigDTO.getModel(), finalPrompt);
+            return StrategyResultDTO.builder()
+                    .success(true)
+                    .result(result)
+                    .build();
+        } catch (IOException e) {
+            log.error("open router client error", e);
+        }
         log.info("Cluster task execution completed: taskId={}", task.getId());
-        return StrategyResultDTO.builder().build();
+        return StrategyResultDTO.builder()
+                .success(false)
+                .msg("execute Cluster strategy failure")
+                .build();
     }
 }