Selaa lähdekoodia

feat: support jina embedding

CalciumIon 1 vuosi sitten
vanhempi
commit
0830ef3305
3 muutettua tiedostoa jossa 29 lisäystä ja 2 poistoa
  1. 3 1
      relay/channel/jina/adaptor.go
  2. 25 0
      relay/channel/jina/relay-jina.go
  3. 1 1
      relay/relay-text.go

+ 3 - 1
relay/channel/jina/adaptor.go

@@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.RelayMode == constant.RelayModeRerank {
 		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
 	} else if info.RelayMode == constant.RelayModeEmbeddings {
-		return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
+		return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
 	}
 	return "", errors.New("invalid relay mode")
 }
@@ -58,6 +58,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = jinaRerankHandler(c, resp)
+	} else if info.RelayMode == constant.RelayModeEmbeddings {
+		err, usage = jinaEmbeddingHandler(c, resp)
 	}
 	return
 }

+ 25 - 0
relay/channel/jina/relay-jina.go

@@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &jinaResp.Usage
 }
+
+func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var jinaResp dto.OpenAIEmbeddingResponse
+	err = json.Unmarshal(responseBody, &jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	jsonResponse, err := json.Marshal(jinaResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &jinaResp.Usage
+}

+ 1 - 1
relay/relay-text.go

@@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 		}
 	case relayconstant.RelayModeEmbeddings:
 	case relayconstant.RelayModeModerations:
-		if textRequest.Input == "" {
+		if textRequest.Input == "" || textRequest.Input == nil {
 			return nil, errors.New("field input is required")
 		}
 	case relayconstant.RelayModeEdits: