|
|
@@ -54,6 +54,25 @@ type BaiduChatStreamResponse struct {
|
|
|
IsEnd bool `json:"is_end"`
|
|
|
}
|
|
|
|
|
|
+type BaiduEmbeddingRequest struct {
|
|
|
+ Input []string `json:"input"`
|
|
|
+}
|
|
|
+
|
|
|
+type BaiduEmbeddingData struct {
|
|
|
+ Object string `json:"object"`
|
|
|
+ Embedding []float64 `json:"embedding"`
|
|
|
+ Index int `json:"index"`
|
|
|
+}
|
|
|
+
|
|
|
+type BaiduEmbeddingResponse struct {
|
|
|
+ Id string `json:"id"`
|
|
|
+ Object string `json:"object"`
|
|
|
+ Created int64 `json:"created"`
|
|
|
+ Data []BaiduEmbeddingData `json:"data"`
|
|
|
+ Usage Usage `json:"usage"`
|
|
|
+ BaiduError
|
|
|
+}
|
|
|
+
|
|
|
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
|
|
for _, message := range request.Messages {
|
|
|
@@ -112,6 +131,36 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
|
|
|
return &response
|
|
|
}
|
|
|
|
|
|
+func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
|
|
+ baiduEmbeddingRequest := BaiduEmbeddingRequest{
|
|
|
+ Input: nil,
|
|
|
+ }
|
|
|
+ switch request.Input.(type) {
|
|
|
+ case string:
|
|
|
+ baiduEmbeddingRequest.Input = []string{request.Input.(string)}
|
|
|
+ case []string:
|
|
|
+ baiduEmbeddingRequest.Input = request.Input.([]string)
|
|
|
+ }
|
|
|
+ return &baiduEmbeddingRequest
|
|
|
+}
|
|
|
+
|
|
|
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
|
|
+ openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
|
|
+ Object: "list",
|
|
|
+ Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
|
|
+ Model: "baidu-embedding",
|
|
|
+ Usage: response.Usage,
|
|
|
+ }
|
|
|
+ for _, item := range response.Data {
|
|
|
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
|
|
+ Object: item.Object,
|
|
|
+ Index: item.Index,
|
|
|
+ Embedding: item.Embedding,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ return &openAIEmbeddingResponse
|
|
|
+}
|
|
|
+
|
|
|
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
|
var usage Usage
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
@@ -212,3 +261,39 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
|
return nil, &fullTextResponse.Usage
|
|
|
}
|
|
|
+
|
|
|
+func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
|
|
+ var baiduResponse BaiduEmbeddingResponse
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ err = resp.Body.Close()
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ err = json.Unmarshal(responseBody, &baiduResponse)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ if baiduResponse.ErrorMsg != "" {
|
|
|
+ return &OpenAIErrorWithStatusCode{
|
|
|
+ OpenAIError: OpenAIError{
|
|
|
+ Message: baiduResponse.ErrorMsg,
|
|
|
+ Type: "baidu_error",
|
|
|
+ Param: "",
|
|
|
+ Code: baiduResponse.ErrorCode,
|
|
|
+ },
|
|
|
+ StatusCode: resp.StatusCode,
|
|
|
+ }, nil
|
|
|
+ }
|
|
|
+ fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
|
|
+ jsonResponse, err := json.Marshal(fullTextResponse)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
+ c.Writer.WriteHeader(resp.StatusCode)
|
|
|
+ _, err = c.Writer.Write(jsonResponse)
|
|
|
+ return nil, &fullTextResponse.Usage
|
|
|
+}
|