Browse Source

支持向量化 API 模式及远程调用

wangyunpeng 2 days ago
parent
commit
ebbb9259b9

+ 142 - 0
core/src/main/java/com/tzld/videoVector/api/EmbeddingApiService.java

@@ -0,0 +1,142 @@
+package com.tzld.videoVector.api;
+
+import com.alibaba.fastjson.JSONArray;
+import com.alibaba.fastjson.JSONObject;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.PostConstruct;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Embedding API service.
+ * <p>
+ * Calls an external embedding HTTP endpoint to obtain the vector for a piece of text. The request
+ * is a JSON object with the fields {@code input} and {@code model}; the vector is read from
+ * {@code data[0].embedding} of the JSON response. Failures are logged and reported to the caller
+ * as an empty list rather than as an exception.
+ */
+@Slf4j
+@Service
+public class EmbeddingApiService {
+
+    private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8");
+
+    /** HTTP client; built in {@link #init()} once the timeout property has been injected. */
+    private OkHttpClient client;
+
+    /** Endpoint URL of the embedding API. */
+    @Value("${embedding.api.url:http://192.168.100.31:8000/v1/embeddings}")
+    private String apiUrl;
+
+    /** Model identifier sent in the {@code model} field of every request. */
+    @Value("${embedding.api.model:/models/Qwen3-Embedding-0.6B}")
+    private String model;
+
+    /** Connect/read/write timeout, in seconds. */
+    @Value("${embedding.api.timeout:60}")
+    private int timeout;
+
+    /**
+     * Builds the shared HTTP client, applying the configured timeout to the connect, read and
+     * write phases.
+     */
+    @PostConstruct
+    public void init() {
+        client = new OkHttpClient.Builder()
+                .connectTimeout(timeout, TimeUnit.SECONDS)
+                .readTimeout(timeout, TimeUnit.SECONDS)
+                .writeTimeout(timeout, TimeUnit.SECONDS)
+                .build();
+        log.info("向量化 API 服务初始化完成,API地址: {}, 模型: {}", apiUrl, model);
+    }
+
+    /**
+     * Embeds a single text.
+     *
+     * @param text input text
+     * @return the embedding vector; an empty list when the text is null/blank, the HTTP call
+     *         fails, the server answers with a non-2xx status, or the response cannot be parsed
+     */
+    public List<Float> embed(String text) {
+        if (text == null || text.trim().isEmpty()) {
+            log.warn("输入文本为空,返回空向量");
+            return Collections.emptyList();
+        }
+
+        try (Response response = client.newCall(buildRequest(text)).execute()) {
+            // body() is declared nullable, so read it once and null-check it for BOTH branches;
+            // previously only the error branch was guarded.
+            ResponseBody responseBody = response.body();
+
+            if (!response.isSuccessful()) {
+                String errorBody = responseBody != null ? responseBody.string() : "无";
+                log.error("向量化 API 请求失败,HTTP状态码: {}, 错误信息: {}", response.code(), errorBody);
+                return Collections.emptyList();
+            }
+
+            // A missing body is handed on as null; parseEmbedding reports it as "no data array".
+            return parseEmbedding(responseBody != null ? responseBody.string() : null);
+        } catch (IOException e) {
+            log.error("调用向量化 API 失败: {}", e.getMessage(), e);
+            return Collections.emptyList();
+        }
+    }
+
+    /**
+     * Embeds several texts by calling {@link #embed(String)} once per element, in order.
+     *
+     * @param texts input texts
+     * @return one vector per input text, in input order (an element is an empty list when that
+     *         text could not be embedded); an empty list when {@code texts} is null or empty
+     */
+    public List<List<Float>> batchEmbed(List<String> texts) {
+        if (texts == null || texts.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<List<Float>> results = new ArrayList<>(texts.size());
+        for (String text : texts) {
+            results.add(embed(text));
+        }
+        return results;
+    }
+
+    /**
+     * Builds the POST request carrying the JSON payload for one text.
+     *
+     * @param text text to embed
+     * @return request targeting the configured API URL
+     */
+    private Request buildRequest(String text) {
+        JSONObject payload = new JSONObject();
+        payload.put("input", text);
+        payload.put("model", model);
+
+        RequestBody body = RequestBody.create(JSON_MEDIA_TYPE, payload.toJSONString());
+        return new Request.Builder()
+                .url(apiUrl)
+                .post(body)
+                .addHeader("Content-Type", "application/json")
+                .build();
+    }
+
+    /**
+     * Parses the API response and extracts the vector at {@code data[0].embedding}.
+     *
+     * @param responseBody raw response body; may be null or empty
+     * @return the vector, or an empty list when the expected structure is absent or parsing fails
+     */
+    private List<Float> parseEmbedding(String responseBody) {
+        try {
+            JSONObject jsonResponse = JSONObject.parseObject(responseBody);
+            // Guard the parse result explicitly instead of relying on the catch below
+            // to absorb a NullPointerException.
+            JSONArray dataArray = jsonResponse != null ? jsonResponse.getJSONArray("data") : null;
+
+            if (dataArray == null || dataArray.isEmpty()) {
+                log.error("API 响应中未找到 data 数组");
+                return Collections.emptyList();
+            }
+
+            JSONObject firstData = dataArray.getJSONObject(0);
+            JSONArray embeddingArray = firstData != null ? firstData.getJSONArray("embedding") : null;
+
+            if (embeddingArray == null || embeddingArray.isEmpty()) {
+                log.error("API 响应中未找到 embedding 数组");
+                return Collections.emptyList();
+            }
+
+            List<Float> vector = new ArrayList<>(embeddingArray.size());
+            for (int i = 0; i < embeddingArray.size(); i++) {
+                vector.add(embeddingArray.getFloat(i));
+            }
+
+            log.debug("成功解析向量,维度: {}", vector.size());
+            return vector;
+        } catch (Exception e) {
+            // Boundary catch: malformed JSON or unexpected element types must not reach callers.
+            log.error("解析向量化 API 响应失败: {}", e.getMessage(), e);
+            return Collections.emptyList();
+        }
+    }
+}

+ 54 - 20
core/src/main/java/com/tzld/videoVector/service/impl/EmbeddingServiceImpl.java

@@ -6,6 +6,7 @@ import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 
 import javax.annotation.PostConstruct;
+import javax.annotation.Resource;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -31,13 +32,22 @@ public class EmbeddingServiceImpl implements EmbeddingService {
     @Value("${embedding.ngram:2}")
     private int ngram;
 
+    /**
+     * Embedding mode: local (in-process) or api (remote API)
+     */
+    @Value("${embedding.mode:local}")
+    private String embeddingMode;
+
+    @Resource
+    private com.tzld.videoVector.api.EmbeddingApiService embeddingApiService;
+
     private int[] hashSeeds;
 
     @PostConstruct
     public void init() {
         // Initialize the hash seeds
         hashSeeds = new int[]{31, 37, 41, 43, 47, 53, 59, 61};
-        log.info("本地向量化服务初始化完成,向量维度: {}, N-gram: {}", dimension, ngram);
+        log.info("向量化服务初始化完成,模式: {}, 向量维度: {}, N-gram: {}", embeddingMode, dimension, ngram);
     }
 
     @Override
@@ -47,6 +57,48 @@ public class EmbeddingServiceImpl implements EmbeddingService {
             return Collections.emptyList();
         }
 
+        // 根据配置选择向量化方式
+        if ("api".equalsIgnoreCase(embeddingMode)) {
+            return embeddingApiService.embed(text);
+        } else {
+            return localEmbed(text);
+        }
+    }
+
+    /**
+     * Embeds every text in the list, using the backend selected by {@code embeddingMode}.
+     *
+     * @param texts input texts
+     * @return one vector per input text; an empty list when {@code texts} is null or empty
+     */
+    @Override
+    public List<List<Float>> batchEmbed(List<String> texts) {
+        boolean nothingToEmbed = texts == null || texts.isEmpty();
+        if (nothingToEmbed) {
+            return Collections.emptyList();
+        }
+
+        // Remote mode: hand the whole list to the API client.
+        if ("api".equalsIgnoreCase(embeddingMode)) {
+            return embeddingApiService.batchEmbed(texts);
+        }
+
+        // Any other mode value: embed locally, one text at a time, preserving order.
+        List<List<Float>> vectors = new ArrayList<>(texts.size());
+        texts.forEach(text -> vectors.add(localEmbed(text)));
+        return vectors;
+    }
+
+    /**
+     * Vector dimension observed from the remote API, cached after the first successful probe.
+     * 0 means "not probed yet"; volatile so a value written by one thread is visible to others.
+     */
+    private volatile int apiDimension;
+
+    /**
+     * Returns the dimension of the vectors produced by the active embedding mode.
+     * <p>
+     * In API mode the dimension is measured once by embedding a short probe text and is then
+     * served from a cache, so repeated calls no longer trigger a remote request each time.
+     * If the probe fails, the configured dimension is returned and nothing is cached, so a
+     * later call probes again.
+     *
+     * @return the probed API dimension in API mode when available, otherwise the configured one
+     */
+    @Override
+    public int getDimension() {
+        if (!"api".equalsIgnoreCase(embeddingMode)) {
+            return dimension;
+        }
+
+        int cached = apiDimension;
+        if (cached > 0) {
+            return cached;
+        }
+
+        // Probe the API with a sample text to learn the real vector length.
+        List<Float> sample = embeddingApiService.embed("测试");
+        if (sample.isEmpty()) {
+            return dimension;
+        }
+
+        // Concurrent first calls may each probe once; they store the same value, which is harmless.
+        apiDimension = sample.size();
+        return sample.size();
+    }
+
+    /**
+     * 本地向量化实现
+     */
+    private List<Float> localEmbed(String text) {
         try {
             // 1. 文本预处理
             String processedText = preprocess(text);
@@ -73,27 +125,9 @@ public class EmbeddingServiceImpl implements EmbeddingService {
             
             return toFloatList(vector);
         } catch (Exception e) {
-            log.error("文本向量化失败: {}", e.getMessage(), e);
-            return Collections.emptyList();
-        }
-    }
-
-    @Override
-    public List<List<Float>> batchEmbed(List<String> texts) {
-        if (texts == null || texts.isEmpty()) {
+            log.error("本地文本向量化失败: {}", e.getMessage(), e);
             return Collections.emptyList();
         }
-
-        List<List<Float>> results = new ArrayList<>(texts.size());
-        for (String text : texts) {
-            results.add(embed(text));
-        }
-        return results;
-    }
-
-    @Override
-    public int getDimension() {
-        return dimension;
     }
 
     /**

+ 6 - 1
server/src/main/resources/application.yml

@@ -101,4 +101,9 @@ cdn:
 
 embedding:
   dimension: 1024    # 向量维度
-  ngram: 2          # N-gram 大小
+  ngram: 2          # N-gram 大小
+  mode: api         # 向量化模式:local(本地N-gram哈希) 或 api(远程API)
+  api:
+    url: http://192.168.100.31:8000/v1/embeddings  # 向量化API地址
+    model: /models/Qwen3-Embedding-0.6B            # 模型路径
+    timeout: 60                                    # 超时时间(秒)