supeng hace 5 horas
padre
commit
f62bb2a278

+ 4 - 0
supply-demand-engine-core/pom.xml

@@ -40,6 +40,10 @@
             <groupId>org.apache.httpcomponents</groupId>
             <artifactId>httpclient</artifactId>
         </dependency>
+        <dependency>
+            <groupId>com.squareup.okhttp3</groupId>
+            <artifactId>okhttp</artifactId>
+        </dependency>
 
         <!-- Bean 校验(javax.validation.constraints.*) -->
         <dependency>

+ 165 - 0
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/integration/OpenRouterClient.java

@@ -1,12 +1,177 @@
 package com.tzld.piaoquan.sde.integration;
 
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONArray;
+import com.alibaba.fastjson.JSONObject;
 import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 
+import javax.annotation.PostConstruct;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
 /**
+ * OpenRouter API Client
+ * 支持通过 OpenRouter 调用 Gemini 和 ChatGPT 模型
+ *
  * @author supeng
  */
 @Slf4j
 @Service
 public class OpenRouterClient {
+
+    private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8");
+
+    /**
+     * 国内代理:<a href="https://openrouter.piaoquantv.com/api/v1/chat/completions">...</a>
+     * 官方链接:<a href="https://openrouter.ai/api/v1/chat/completions">...</a>
+     */
+    @Value("${openrouter.api.url:https://openrouter.piaoquantv.com/api/v1/chat/completions}")
+    private String url;
+
+    @Value("${openrouter.api.key:}")
+    private String apiKey;
+
+    @Value("${openrouter.timeout:60}")
+    private int timeout;
+
+    private OkHttpClient httpClient;
+
+    @PostConstruct
+    public void init() {
+        this.httpClient = new OkHttpClient.Builder()
+                .connectTimeout(timeout, TimeUnit.SECONDS)
+                .readTimeout(timeout, TimeUnit.SECONDS)
+                .writeTimeout(timeout, TimeUnit.SECONDS)
+                .build();
+    }
+
+    /**
+     * 调用 OpenRouter API
+     *
+     * @param model       模型名称,例如: "google/gemini-pro", "openai/gpt-4", "openai/gpt-3.5-turbo"
+     * @param messages    消息列表,每个消息包含 role 和 content
+     * @param temperature 温度参数,控制随机性 (0.0-2.0)
+     * @param maxTokens   最大生成 token 数
+     * @return API 响应内容
+     * @throws IOException 网络请求异常
+     */
+    public String chat(String model, List<Map<String, String>> messages, Double temperature, Integer maxTokens) throws IOException {
+        JSONObject requestBody = buildRequestBody(model, messages, temperature, maxTokens);
+
+        Request request = new Request.Builder()
+                .url(url)
+                .addHeader("Authorization", "Bearer " + apiKey)
+                .addHeader("Content-Type", "application/json")
+                .addHeader("HTTP-Referer", "https://github.com/tzld")
+                .addHeader("X-Title", "Supply Demand Engine")
+                .post(RequestBody.create(JSON_MEDIA_TYPE, requestBody.toJSONString()))
+                .build();
+
+        log.info("Calling OpenRouter API with model: {}", model);
+
+        try (Response response = httpClient.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                String errorBody = response.body() != null ? response.body().string() : "No error body";
+                log.error("OpenRouter API call failed: {} - {}", response.code(), errorBody);
+                throw new IOException("OpenRouter API call failed: " + response.code() + " - " + errorBody);
+            }
+
+            String responseBody = response.body().string();
+            log.debug("OpenRouter API response: {}", responseBody);
+
+            return parseResponse(responseBody);
+        }
+    }
+
+    /**
+     * 简化的调用方法,使用默认参数
+     *
+     * @param model    模型名称
+     * @param messages 消息列表
+     * @return API 响应内容
+     * @throws IOException 网络请求异常
+     */
+    public String chat(String model, List<Map<String, String>> messages) throws IOException {
+        return chat(model, messages, 0.7, 2000);
+    }
+
+    /**
+     * 调用 Gemini 模型
+     *
+     * @param messages 消息列表
+     * @return API 响应内容
+     * @throws IOException 网络请求异常
+     */
+    public String chatWithGemini(List<Map<String, String>> messages) throws IOException {
+        return chat("google/gemini-pro", messages);
+    }
+
+    /**
+     * 调用 ChatGPT 模型
+     *
+     * @param messages 消息列表
+     * @param useGpt4  是否使用 GPT-4,false 则使用 GPT-3.5-turbo
+     * @return API 响应内容
+     * @throws IOException 网络请求异常
+     */
+    public String chatWithGPT(List<Map<String, String>> messages, boolean useGpt4) throws IOException {
+        String model = useGpt4 ? "openai/gpt-4" : "openai/gpt-3.5-turbo";
+        return chat(model, messages);
+    }
+
+    /**
+     * 构建请求体
+     */
+    private JSONObject buildRequestBody(String model, List<Map<String, String>> messages, Double temperature, Integer maxTokens) {
+        JSONObject requestBody = new JSONObject();
+        requestBody.put("model", model);
+
+        JSONArray messagesArray = new JSONArray();
+        for (Map<String, String> message : messages) {
+            JSONObject messageNode = new JSONObject();
+            messageNode.put("role", message.get("role"));
+            messageNode.put("content", message.get("content"));
+            messagesArray.add(messageNode);
+        }
+        requestBody.put("messages", messagesArray);
+
+        if (temperature != null) {
+            requestBody.put("temperature", temperature);
+        }
+        if (maxTokens != null) {
+            requestBody.put("max_tokens", maxTokens);
+        }
+
+        return requestBody;
+    }
+
+    /**
+     * 解析响应,提取生成的内容
+     */
+    private String parseResponse(String responseBody) throws IOException {
+        try {
+            JSONObject root = JSON.parseObject(responseBody);
+            JSONArray choices = root.getJSONArray("choices");
+            if (choices != null && !choices.isEmpty()) {
+                JSONObject firstChoice = choices.getJSONObject(0);
+                JSONObject message = firstChoice.getJSONObject("message");
+                if (message != null) {
+                    String content = message.getString("content");
+                    if (content != null) {
+                        return content;
+                    }
+                }
+            }
+            log.warn("Unable to parse content from response, returning full response");
+            return responseBody;
+        } catch (Exception e) {
+            log.error("Error parsing response: {}", e.getMessage());
+            return responseBody;
+        }
+    }
 }

+ 6 - 43
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/service/impl/TaskServiceImpl.java

@@ -15,6 +15,8 @@ import com.tzld.piaoquan.sde.model.request.TaskGetParam;
 import com.tzld.piaoquan.sde.model.request.TaskListParam;
 import com.tzld.piaoquan.sde.model.vo.SdTaskVO;
 import com.tzld.piaoquan.sde.service.TaskService;
+import com.tzld.piaoquan.sde.service.strategy.DeconstructStrategy;
+import com.tzld.piaoquan.sde.service.strategy.DemandExtractionStrategy;
 import com.tzld.piaoquan.sde.util.DateUtil;
 import com.tzld.piaoquan.sde.util.IdGeneratorUtil;
 import lombok.extern.slf4j.Slf4j;
@@ -36,15 +38,12 @@ import java.util.stream.Collectors;
 public class TaskServiceImpl implements TaskService {
 
     private static final String PREFIX_TASK_NAME = "供给需求任务";
-
     private static final String TASK_NAME_DATETIME_PATTEN = "yyyyMMddHHmmss";
 
     @Autowired
     private SdTaskMapper sdTaskMapper;
-
     @Autowired
     private SdSubTaskMapper sdSubTaskMapper;
-
     @Autowired
     private SdStrategyMapper sdStrategyMapper;
     @Autowired
@@ -55,6 +54,8 @@ public class TaskServiceImpl implements TaskService {
     private OpenRouterClient openRouterServiceClient;
     @Autowired
     private OpenRouterClient openRouterClient;
+    @Autowired
+    private Map<Integer, DeconstructStrategy> strategyMap;
 
     @Override
     public void create(CommonRequest<TaskCreateParam> request) {
@@ -280,46 +281,8 @@ public class TaskServiceImpl implements TaskService {
                     log.info("taskExecuteHandler strategy is null");
                     return;
                 }
-                SdPromptTemplate promptTemplate = sdPromptTemplateMapper.selectById(strategy.getPromptTemplateId());
-                if (Objects.isNull(promptTemplate)) {
-                    log.info("taskExecuteHandler promptTemplate is null");
-                    return;
-                }
-
-                LambdaQueryWrapper<SdSubTask> subTaskWrapper = Wrappers.lambdaQuery(SdSubTask.class)
-                        .eq(SdSubTask::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
-                        .eq(SdSubTask::getTaskId, sdTask.getId())
-                        .eq(SdSubTask::getTaskStatus, SubTaskStatusEnum.SUCCESS.getValue())
-                        .eq(SdSubTask::getSubTaskType, sdTask.getTaskType())
-                        .select(SdSubTask::getId);
-                List<SdSubTask> subTasks = sdSubTaskMapper.selectList(subTaskWrapper);
-                if (Objects.isNull(subTasks)) {
-                    log.info("taskExecuteHandler subTasks is null");
-                    return;
-                }
-                List<Long> subTaskIds = subTasks.stream().map(SdSubTask::getId).collect(Collectors.toList());
-                //TODO
-                String itemType = "";
-                LambdaQueryWrapper<SdSubTaskResultItem> resultItemWrapper = Wrappers.lambdaQuery(SdSubTaskResultItem.class)
-                        .eq(SdSubTaskResultItem::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
-                        .in(SdSubTaskResultItem::getSubTaskId, subTaskIds)
-                        .eq(SdSubTaskResultItem::getItemType, itemType)
-                        .select(SdSubTaskResultItem::getItemContent);
-                List<SdSubTaskResultItem> resultItems = sdSubTaskResultItemMapper.selectList(resultItemWrapper);
-                if (Objects.isNull(resultItems) || resultItems.isEmpty()) {
-                    log.info("taskExecuteHandler resultItems is null");
-                    return;
-                }
-                //去掉null,换行拼接
-                String input = resultItems.stream()
-                        .map(SdSubTaskResultItem::getItemContent)
-                        .filter(Objects::nonNull)
-                        .collect(Collectors.joining("\n"));
-
-                String prompt = promptTemplate.getPromptContent();
-                String finalPrompt = prompt.replaceAll("input", input);
-
-
+                DemandExtractionStrategy demandExtractionStrategy = strategyMap.get(sdTask.getTaskType());
+                demandExtractionStrategy.execute(sdTask, strategy);
 
             } catch (Exception e) {
                 log.error("taskExecuteHandler error", e);

+ 31 - 0
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/service/strategy/ClusterStrategy.java

@@ -0,0 +1,31 @@
+package com.tzld.piaoquan.sde.service.strategy;
+
+import com.tzld.piaoquan.sde.common.enums.TaskTypeEnum;
+import com.tzld.piaoquan.sde.model.entity.SdStrategy;
+import com.tzld.piaoquan.sde.model.entity.SdTask;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+
+/**
+ * 聚类任务策略
+ *
+ * @author supeng
+ */
+@Slf4j
+@Component
+public class ClusterStrategy implements DemandExtractionStrategy {
+
+    @Override
+    public Integer getType() {
+        return TaskTypeEnum.CLUSTER.getValue();
+    }
+
+    @Override
+    public void execute(SdTask task, SdStrategy strategy) {
+        log.info("执行聚类任务: taskId={}, taskNo={}", task.getId(), task.getTaskNo());
+
+        // TODO: 实现聚类任务逻辑
+
+        log.info("聚类任务执行完成: taskId={}", task.getId());
+    }
+}

+ 103 - 0
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/service/strategy/DeconstructStrategy.java

@@ -0,0 +1,103 @@
+package com.tzld.piaoquan.sde.service.strategy;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
+import com.tzld.piaoquan.sde.common.enums.IsDeleteEnum;
+import com.tzld.piaoquan.sde.common.enums.SubTaskStatusEnum;
+import com.tzld.piaoquan.sde.common.enums.TaskTypeEnum;
+import com.tzld.piaoquan.sde.integration.OpenRouterClient;
+import com.tzld.piaoquan.sde.mapper.SdPromptTemplateMapper;
+import com.tzld.piaoquan.sde.mapper.SdSubTaskMapper;
+import com.tzld.piaoquan.sde.mapper.SdSubTaskResultItemMapper;
+import com.tzld.piaoquan.sde.model.entity.*;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * 解构任务策略
+ *
+ * @author supeng
+ */
+@Slf4j
+@Component
+public class DeconstructStrategy implements DemandExtractionStrategy {
+
+    @Autowired
+    private SdSubTaskMapper sdSubTaskMapper;
+
+    @Autowired
+    private SdSubTaskResultItemMapper sdSubTaskResultItemMapper;
+
+    @Autowired
+    private SdPromptTemplateMapper sdPromptTemplateMapper;
+
+    @Autowired
+    private OpenRouterClient openRouterClient;
+
+    @Override
+    public Integer getType() {
+        return TaskTypeEnum.DECONSTRUCT.getValue();
+    }
+
+    @Override
+    public void execute(SdTask task, SdStrategy strategy) {
+        log.info("执行解构任务: taskId={}, taskNo={}", task.getId(), task.getTaskNo());
+        // 1. 获取Prompt模板
+        SdPromptTemplate promptTemplate = sdPromptTemplateMapper.selectById(strategy.getPromptTemplateId());
+        if (Objects.isNull(promptTemplate)) {
+            log.error("Prompt模板不存在: promptTemplateId={}", strategy.getPromptTemplateId());
+            return;
+        }
+
+        // 2. 查询成功的子任务
+        LambdaQueryWrapper<SdSubTask> subTaskWrapper = Wrappers.lambdaQuery(SdSubTask.class)
+                .eq(SdSubTask::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
+                .eq(SdSubTask::getTaskId, task.getId())
+                .eq(SdSubTask::getTaskStatus, SubTaskStatusEnum.SUCCESS.getValue())
+                .eq(SdSubTask::getSubTaskType, task.getTaskType())
+                .select(SdSubTask::getId);
+        List<SdSubTask> subTasks = sdSubTaskMapper.selectList(subTaskWrapper);
+        if (Objects.isNull(subTasks) || subTasks.isEmpty()) {
+            log.warn("没有成功的子任务: taskId={}", task.getId());
+            return;
+        }
+
+        // 3. 查询子任务结果项
+        List<Long> subTaskIds = subTasks.stream().map(SdSubTask::getId).collect(Collectors.toList());
+        // TODO: 根据实际业务确定itemType
+        String itemType = "";
+        LambdaQueryWrapper<SdSubTaskResultItem> resultItemWrapper = Wrappers.lambdaQuery(SdSubTaskResultItem.class)
+                .eq(SdSubTaskResultItem::getIsDeleted, IsDeleteEnum.NORMAL.getValue())
+                .in(SdSubTaskResultItem::getSubTaskId, subTaskIds)
+                .eq(SdSubTaskResultItem::getItemType, itemType)
+                .select(SdSubTaskResultItem::getItemContent);
+        List<SdSubTaskResultItem> resultItems = sdSubTaskResultItemMapper.selectList(resultItemWrapper);
+        if (Objects.isNull(resultItems) || resultItems.isEmpty()) {
+            log.warn("没有子任务结果项: taskId=", task.getId());
+            return;
+        }
+
+        // 4. 构建输入内容
+        String input = resultItems.stream()
+                .map(SdSubTaskResultItem::getItemContent)
+                .filter(Objects::nonNull)
+                .collect(Collectors.joining("\n"));
+
+        // 5. 构建最终Prompt
+        String prompt = promptTemplate.getPromptContent();
+        String finalPrompt = prompt.replace("{{input}}", input);
+
+        // 6. 调用AI服务 (TODO: 实现具体调用逻辑)
+        log.info("准备调用AI服务: taskId={}, promptLength={}", task.getId(), finalPrompt.length());
+        // String result = openRouterClient.chat(finalPrompt, strategy.getConfig());
+
+        // 7. 保存结果 (TODO: 实现结果保存逻辑)
+
+        log.info("解构任务执行完成: taskId={}", task.getId());
+    }
+}

+ 23 - 0
supply-demand-engine-core/src/main/java/com/tzld/piaoquan/sde/service/strategy/DemandExtractionStrategy.java

@@ -0,0 +1,23 @@
+package com.tzld.piaoquan.sde.service.strategy;
+
+import com.tzld.piaoquan.sde.common.enums.TaskTypeEnum;
+import com.tzld.piaoquan.sde.model.entity.SdStrategy;
+import com.tzld.piaoquan.sde.model.entity.SdTask;
+
+/**
+ * 需求提取策略接口
+ *
+ * @author supeng
+ */
+public interface DemandExtractionStrategy {
+
+    /**
+     * 获取支持的任务类型
+     */
+    Integer getType();
+
+    /**
+     * 执行需求提取
+     */
+    void execute(SdTask task, SdStrategy strategy);
+}