Bladeren bron

support gemini-pro-vision

CaIon 2 jaren geleden
bovenliggende
commit
14592f9758
6 gewijzigde bestanden met toevoegingen van 141 en 15 verwijderingen
  1. 44 12
      common/image.go
  2. 1 0
      common/model-ratio.go
  3. 9 0
      controller/model.go
  4. 29 0
      controller/relay-gemini.go
  5. 4 3
      controller/relay-utils.go
  6. 54 0
      controller/relay.go

+ 44 - 12
common/image.go

@@ -12,7 +12,7 @@ import (
 	"strings"
 )
 
-func DecodeBase64ImageData(base64String string) (image.Config, error) {
+func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
 	// 去除base64数据的URL前缀(如果有)
 	if idx := strings.Index(base64String, ","); idx != -1 {
 		base64String = base64String[idx+1:]
@@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
 	decodedData, err := base64.StdEncoding.DecodeString(base64String)
 	if err != nil {
 		fmt.Println("Error: Failed to decode base64 string")
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
 
 	// 创建一个bytes.Buffer用于存储解码后的数据
 	reader := bytes.NewReader(decodedData)
-	config, err := getImageConfig(reader)
-	return config, err
+	config, format, err := getImageConfig(reader)
+	return config, format, err
 }
 
-func DecodeUrlImageData(imageUrl string) (image.Config, error) {
+func IsImageUrl(url string) (bool, error) {
+	resp, err := http.Head(url)
+	if err != nil {
+		return false, err
+	}
+	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
+		return false, nil
+	}
+	return true, nil
+}
+
+func GetImageFromUrl(url string) (mimeType string, data string, err error) {
+	isImage, err := IsImageUrl(url)
+	if !isImage {
+		return
+	}
+	resp, err := http.Get(url)
+	if err != nil {
+		return
+	}
+	defer resp.Body.Close()
+	buffer := bytes.NewBuffer(nil)
+	_, err = buffer.ReadFrom(resp.Body)
+	if err != nil {
+		return
+	}
+	mimeType = resp.Header.Get("Content-Type")
+	data = base64.StdEncoding.EncodeToString(buffer.Bytes())
+	return
+}
+
+func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 	response, err := http.Get(imageUrl)
 	if err != nil {
 		SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
 
 	// 限制读取的字节数,防止下载整个图片
@@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
 	//	log.Fatal(err)
 	//}
 	//log.Printf("%x", data)
-	config, err := getImageConfig(limitReader)
+	config, format, err := getImageConfig(limitReader)
 	response.Body.Close()
-	return config, err
+	return config, format, err
 }
 
-func getImageConfig(reader io.Reader) (image.Config, error) {
+func getImageConfig(reader io.Reader) (image.Config, string, error) {
 	// 读取图片的头部信息来获取图片尺寸
-	config, _, err := image.DecodeConfig(reader)
+	config, format, err := image.DecodeConfig(reader)
 	if err != nil {
 		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
 		SysLog(err.Error())
@@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
 			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
 			SysLog(err.Error())
 		}
+		format = "webp"
 	}
 	if err != nil {
-		return image.Config{}, err
+		return image.Config{}, "", err
 	}
-	return config, nil
+	return config, format, nil
 }

+ 1 - 0
common/model-ratio.go

@@ -62,6 +62,7 @@ var ModelRatio = map[string]float64{
 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":                    1,
 	"gemini-pro":                1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
+	"gemini-pro-vision":         1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens

+ 9 - 0
controller/model.go

@@ -432,6 +432,15 @@ func init() {
 			Root:       "gemini-pro",
 			Parent:     nil,
 		},
+		{
+			Id:         "gemini-pro-vision",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "google",
+			Permission: permission,
+			Root:       "gemini-pro-vision",
+			Parent:     nil,
+		},
 		{
 			Id:         "chatglm_turbo",
 			Object:     "model",

+ 29 - 0
controller/relay-gemini.go

@@ -12,6 +12,10 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const (
+	GeminiVisionMaxImageNum = 16
+)
+
 type GeminiChatRequest struct {
 	Contents         []GeminiChatContent        `json:"contents"`
 	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@@ -97,6 +101,31 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
 				},
 			},
 		}
+		openaiContent := message.ParseContent()
+		var parts []GeminiPart
+		imageNum := 0
+		for _, part := range openaiContent {
+
+			if part.Type == ContentTypeText {
+				parts = append(parts, GeminiPart{
+					Text: part.Text,
+				})
+			} else if part.Type == ContentTypeImageURL {
+				imageNum += 1
+				if imageNum > GeminiVisionMaxImageNum {
+					continue
+				}
+				mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
+				parts = append(parts, GeminiPart{
+					InlineData: &GeminiInlineData{
+						MimeType: mimeType,
+						Data:     data,
+					},
+				})
+			}
+		}
+		content.Parts = parts
+
 		// there's no assistant role in gemini and API shall vomit if Role is not user or model
 		if content.Role == "assistant" {
 			content.Role = "model"

+ 4 - 3
controller/relay-utils.go

@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 	}
 	var config image.Config
 	var err error
+	var format string
 	if strings.HasPrefix(imageUrl.Url, "http") {
 		common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
-		config, err = common.DecodeUrlImageData(imageUrl.Url)
+		config, format, err = common.DecodeUrlImageData(imageUrl.Url)
 	} else {
 		common.SysLog(fmt.Sprintf("decoding image"))
-		config, err = common.DecodeBase64ImageData(imageUrl.Url)
+		config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
 	}
 	if err != nil {
 		return 0, err
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
 
 	shortSide := config.Width
 	otherSide := config.Height
-	log.Printf("width: %d, height: %d", config.Width, config.Height)
+	log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
 	// 缩放倍数
 	scale := 1.0
 	if config.Height < shortSide {

+ 54 - 0
controller/relay.go

@@ -29,6 +29,60 @@ type MessageImageUrl struct {
 	Detail string `json:"detail"`
 }
 
+const (
+	ContentTypeText     = "text"
+	ContentTypeImageURL = "image_url"
+)
+
+func (m Message) ParseContent() []MediaMessage {
+	var contentList []MediaMessage
+	var stringContent string
+	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+		contentList = append(contentList, MediaMessage{
+			Type: ContentTypeText,
+			Text: stringContent,
+		})
+		return contentList
+	}
+	var arrayContent []json.RawMessage
+	if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
+		for _, contentItem := range arrayContent {
+			var contentMap map[string]any
+			if err := json.Unmarshal(contentItem, &contentMap); err != nil {
+				continue
+			}
+			switch contentMap["type"] {
+			case ContentTypeText:
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeText,
+						Text: subStr,
+					})
+				}
+			case ContentTypeImageURL:
+				if subObj, ok := contentMap["image_url"].(map[string]any); ok {
+					detail, ok := subObj["detail"]
+					if ok {
+						subObj["detail"] = detail.(string)
+					} else {
+						subObj["detail"] = "auto"
+					}
+					contentList = append(contentList, MediaMessage{
+						Type: ContentTypeImageURL,
+						ImageUrl: MessageImageUrl{
+							Url:    subObj["url"].(string),
+							Detail: subObj["detail"].(string),
+						},
+					})
+				}
+			}
+		}
+		return contentList
+	}
+
+	return nil
+}
+
 const (
 	RelayModeUnknown = iota
 	RelayModeChatCompletions