Jelajahi Sumber

Merge pull request #674 from Yan-Zero/main

fix: Gemini 函数调用的文本转义,以及其他文件类型的 Base64 支持
Calcium-Ion 1 tahun lalu
induk
melakukan
a1b864bc5e
2 mengubah file dengan 107 tambahan dan 40 penghapusan
  1. 80 39
      relay/channel/gemini/relay-gemini.go
  2. 27 1
      service/image.go

+ 80 - 39
relay/channel/gemini/relay-gemini.go

@@ -12,6 +12,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+	"unicode/utf8"
 
 	"github.com/gin-gonic/gin"
 )
@@ -203,13 +204,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 						},
 					})
 				} else {
-					_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
+					format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
 					if err != nil {
 						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
 					}
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
-							MimeType: "image/" + format,
+							MimeType: format,
 							Data:     base64String,
 						},
 					})
@@ -279,57 +280,97 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
 	return v
 }
 
-// func (g *GeminiChatResponse) GetResponseText() string {
-// 	if g == nil {
-// 		return ""
-// 	}
-// 	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
-// 		return g.Candidates[0].Content.Parts[0].Text
-// 	}
-// 	return ""
-// }
+func unescapeString(s string) (string, error) {
+	var result []rune
+	escaped := false
+	i := 0
+
+	for i < len(s) {
+		r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
+		if r == utf8.RuneError {
+			return "", fmt.Errorf("invalid UTF-8 encoding")
+		}
+
+		if escaped {
+			// 如果是转义符后的字符,检查其类型
+			switch r {
+			case '"':
+				result = append(result, '"')
+			case '\\':
+				result = append(result, '\\')
+			case '/':
+				result = append(result, '/')
+			case 'b':
+				result = append(result, '\b')
+			case 'f':
+				result = append(result, '\f')
+			case 'n':
+				result = append(result, '\n')
+			case 'r':
+				result = append(result, '\r')
+			case 't':
+				result = append(result, '\t')
+			case '\'':
+				result = append(result, '\'')
+			default:
+				// 如果遇到一个非法的转义字符,直接按原样输出
+				result = append(result, '\\', r)
+			}
+			escaped = false
+		} else {
+			if r == '\\' {
+				escaped = true // 记录反斜杠作为转义符
+			} else {
+				result = append(result, r)
+			}
+		}
+		i += size // 移动到下一个字符
+	}
+
+	return string(result), nil
+}
+func unescapeMapOrSlice(data interface{}) interface{} {
+	switch v := data.(type) {
+	case map[string]interface{}:
+		for k, val := range v {
+			v[k] = unescapeMapOrSlice(val)
+		}
+	case []interface{}:
+		for i, val := range v {
+			v[i] = unescapeMapOrSlice(val)
+		}
+	case string:
+		if unescaped, err := unescapeString(v); err != nil {
+			return v
+		} else {
+			return unescaped
+		}
+	}
+	return data
+}
 
 func getToolCall(item *GeminiPart) *dto.ToolCall {
-	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+	var argsBytes []byte
+	var err error
+	if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
+		argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
+	} else {
+		argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
+	}
+
 	if err != nil {
-		//common.SysError("getToolCall failed: " + err.Error())
 		return nil
 	}
 	return &dto.ToolCall{
 		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
 		Type: "function",
 		Function: dto.FunctionCall{
-			// 不好评价,得去转义一下反斜杠,Gemini 的特性好像是,Google 返回的时候本身就会转义“\”
-			Arguments: strings.ReplaceAll(string(argsBytes), "\\\\", "\\"),
+			Arguments: string(argsBytes),
 			Name:      item.FunctionCall.FunctionName,
 		},
 	}
 }
 
-// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
-// 	var toolCalls []dto.ToolCall
-
-// 	item := candidate.Content.Parts[index]
-// 	if item.FunctionCall == nil {
-// 		return toolCalls
-// 	}
-// 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
-// 	if err != nil {
-// 		//common.SysError("getToolCalls failed: " + err.Error())
-// 		return toolCalls
-// 	}
-// 	toolCall := dto.ToolCall{
-// 		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
-// 		Type: "function",
-// 		Function: dto.FunctionCall{
-// 			Arguments: string(argsBytes),
-// 			Name:      item.FunctionCall.FunctionName,
-// 		},
-// 	}
-// 	toolCalls = append(toolCalls, toolCall)
-// 	return toolCalls
-// }
-
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),

+ 27 - 1
service/image.go

@@ -5,11 +5,12 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
-	"golang.org/x/image/webp"
 	"image"
 	"io"
 	"one-api/common"
 	"strings"
+
+	"golang.org/x/image/webp"
 )
 
 func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
@@ -31,6 +32,31 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
 	return config, format, base64String, err
 }
 
+func DecodeBase64FileData(base64String string) (string, string, error) {
+	var mimeType string
+	var idx int
+	idx = strings.Index(base64String, ",")
+	if idx == -1 {
+		_, file_type, base64, err := DecodeBase64ImageData(base64String)
+		return "image/" + file_type, base64, err
+	}
+	mimeType = base64String[:idx]
+	base64String = base64String[idx+1:]
+	idx = strings.Index(mimeType, ";")
+	if idx == -1 {
+		_, file_type, base64, err := DecodeBase64ImageData(base64String)
+		return "image/" + file_type, base64, err
+	}
+	mimeType = mimeType[:idx]
+	idx = strings.Index(mimeType, ":")
+	if idx == -1 {
+		_, file_type, base64, err := DecodeBase64ImageData(base64String)
+		return "image/" + file_type, base64, err
+	}
+	mimeType = mimeType[idx+1:]
+	return mimeType, base64String, nil
+}
+
 // GetImageFromUrl 获取图片的类型和base64编码的数据
 func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 	resp, err := DoDownloadRequest(url)