
feat: support gemini output text and inline images. (close #866)

CaIon, 10 months ago
commit 473e8e0eaf

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -99,7 +99,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	ai, err := CovertGemini2OpenAI(*request)
+	ai, err := CovertGemini2OpenAI(*request, info)
 	if err != nil {
 		return nil, err
 	}

+ 10 - 9
relay/channel/gemini/dto.go

@@ -71,15 +71,16 @@ type GeminiChatTool struct {
 }
 
 type GeminiChatGenerationConfig struct {
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopP             float64  `json:"topP,omitempty"`
-	TopK             float64  `json:"topK,omitempty"`
-	MaxOutputTokens  uint     `json:"maxOutputTokens,omitempty"`
-	CandidateCount   int      `json:"candidateCount,omitempty"`
-	StopSequences    []string `json:"stopSequences,omitempty"`
-	ResponseMimeType string   `json:"responseMimeType,omitempty"`
-	ResponseSchema   any      `json:"responseSchema,omitempty"`
-	Seed             int64    `json:"seed,omitempty"`
+	Temperature        *float64 `json:"temperature,omitempty"`
+	TopP               float64  `json:"topP,omitempty"`
+	TopK               float64  `json:"topK,omitempty"`
+	MaxOutputTokens    uint     `json:"maxOutputTokens,omitempty"`
+	CandidateCount     int      `json:"candidateCount,omitempty"`
+	StopSequences      []string `json:"stopSequences,omitempty"`
+	ResponseMimeType   string   `json:"responseMimeType,omitempty"`
+	ResponseSchema     any      `json:"responseSchema,omitempty"`
+	Seed               int64    `json:"seed,omitempty"`
+	ResponseModalities []string `json:"responseModalities,omitempty"`
 }
 
 type GeminiChatCandidate struct {
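
For reference, a minimal standalone sketch (with a simplified stand-in type, not the project's real `GeminiChatGenerationConfig`) of the `generationConfig` JSON that the new `ResponseModalities` field serializes to when image output is enabled:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// generationConfig keeps only the fields relevant to this commit.
type generationConfig struct {
	Temperature        *float64 `json:"temperature,omitempty"`
	MaxOutputTokens    uint     `json:"maxOutputTokens,omitempty"`
	ResponseModalities []string `json:"responseModalities,omitempty"`
}

func main() {
	temp := 0.7
	cfg := generationConfig{
		Temperature:        &temp,
		MaxOutputTokens:    1024,
		ResponseModalities: []string{"TEXT", "IMAGE"},
	}
	out, _ := json.MarshalIndent(cfg, "", "  ")
	fmt.Println(string(out))
	// {
	//   "temperature": 0.7,
	//   "maxOutputTokens": 1024,
	//   "responseModalities": [
	//     "TEXT",
	//     "IMAGE"
	//   ]
	// }
}
```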

+ 30 - 7
relay/channel/gemini/relay-gemini.go

@@ -19,7 +19,7 @@ import (
 )
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
 
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
@@ -32,6 +32,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		},
 	}
 
+	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
+		geminiRequest.GenerationConfig.ResponseModalities = []string{
+			"TEXT",
+			"IMAGE",
+		}
+	}
+
 	safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
 	for _, category := range SafetySettingList {
 		safetySettings = append(safetySettings, GeminiChatSafetySettings{
@@ -546,9 +553,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 	return &fullTextResponse
 }
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
 	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
 	isStop := false
+	hasImage := false
 	for _, candidate := range geminiResponse.Candidates {
 		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
 			isStop = true
@@ -574,7 +582,13 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 			}
 		}
 		for _, part := range candidate.Content.Parts {
-			if part.FunctionCall != nil {
+			if part.InlineData != nil {
+				if strings.HasPrefix(part.InlineData.MimeType, "image") {
+					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
+					texts = append(texts, imgText)
+					hasImage = true
+				}
+			} else if part.FunctionCall != nil {
 				isTools = true
 				if call := getResponseToolCall(&part); call != nil {
 					call.SetIndex(len(choice.Delta.ToolCalls))
@@ -602,7 +616,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Choices = choices
-	return &response, isStop
+	return &response, isStop, hasImage
 }
 
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -610,20 +624,23 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
+	var imageCount int
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
-		err := json.Unmarshal([]byte(data), &geminiResponse)
+		err := common.DecodeJsonStr(data, &geminiResponse)
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 		}
 
-		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		if hasImage {
+			imageCount++
+		}
 		response.Id = id
 		response.Created = createAt
 		response.Model = info.UpstreamModelName
-		// responseText += response.Choices[0].Delta.GetContentString()
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
 			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -641,6 +658,12 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 
 	var response *dto.ChatCompletionsStreamResponse
 
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
 	usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
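
In the stream handler above, each inline-data part with an `image/*` MIME type is rewritten as a markdown image carrying a base64 data URI, and when the upstream usage metadata reports zero completion tokens the handler falls back to a fixed estimate of 258 tokens per image-bearing chunk (`imageCount * 258`, the constant used by this commit). A minimal standalone sketch of the markdown conversion, using a stand-in type for the real inline-data struct:

```go
package main

import (
	"fmt"
	"strings"
)

// geminiInlineData is a stand-in for the real inline-data part type.
type geminiInlineData struct {
	MimeType string // e.g. "image/png"
	Data     string // base64-encoded bytes
}

// inlineDataToMarkdown mirrors the stream handling above: image parts become
// a markdown image with a base64 data URI; other inline data is skipped.
func inlineDataToMarkdown(d geminiInlineData) (string, bool) {
	if !strings.HasPrefix(d.MimeType, "image") {
		return "", false
	}
	return "![image](data:" + d.MimeType + ";base64," + d.Data + ")", true
}

func main() {
	md, ok := inlineDataToMarkdown(geminiInlineData{
		MimeType: "image/png",
		Data:     "iVBORw0KGgo...", // truncated for the example
	})
	fmt.Println(ok)
	fmt.Println(md)
}
```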

+ 1 - 1
relay/channel/vertex/adaptor.go

@@ -143,7 +143,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 		info.UpstreamModelName = claudeReq.Model
 		return vertexClaudeReq, nil
 	} else if a.RequestMode == RequestModeGemini {
-		geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
+		geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
 		if err != nil {
 			return nil, err
 		}

+ 2 - 2
relay/common/relay_info.go

@@ -90,7 +90,7 @@ type RelayInfo struct {
 	RelayFormat          string
 	SendResponseCount    int
 	ThinkingContentInfo
-	ClaudeConvertInfo
+	*ClaudeConvertInfo
 	*RerankerInfo
 }
 
@@ -120,7 +120,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.RelayFormat = RelayFormatClaude
 	info.ShouldIncludeUsage = false
-	info.ClaudeConvertInfo = ClaudeConvertInfo{
+	info.ClaudeConvertInfo = &ClaudeConvertInfo{
 		LastMessagesType: LastMessageTypeNone,
 	}
 	return info

+ 16 - 2
setting/model_setting/gemini.go

@@ -6,8 +6,9 @@ import (
 
 // GeminiSettings 定义Gemini模型的配置
 type GeminiSettings struct {
-	SafetySettings  map[string]string `json:"safety_settings"`
-	VersionSettings map[string]string `json:"version_settings"`
+	SafetySettings         map[string]string `json:"safety_settings"`
+	VersionSettings        map[string]string `json:"version_settings"`
+	SupportedImagineModels []string          `json:"supported_imagine_models"`
 }
 
 // 默认配置
@@ -20,6 +21,10 @@ var defaultGeminiSettings = GeminiSettings{
 		"default":        "v1beta",
 		"gemini-1.0-pro": "v1",
 	},
+	SupportedImagineModels: []string{
+		"gemini-2.0-flash-exp-image-generation",
+		"gemini-2.0-flash-exp",
+	},
 }
 
 // 全局实例
@@ -50,3 +55,12 @@ func GetGeminiVersionSetting(key string) string {
 	}
 	return geminiSettings.VersionSettings["default"]
 }
+
+func IsGeminiModelSupportImagine(model string) bool {
+	for _, v := range geminiSettings.SupportedImagineModels {
+		if v == model {
+			return true
+		}
+	}
+	return false
+}
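
A minimal standalone sketch of how the new setting and the request converter interact: `ResponseModalities` is only set to TEXT + IMAGE when the upstream model name appears in `supported_imagine_models` (names simplified, defaults taken from this commit):

```go
package main

import "fmt"

// Defaults added by this commit; in the project they come from GeminiSettings.
var supportedImagineModels = []string{
	"gemini-2.0-flash-exp-image-generation",
	"gemini-2.0-flash-exp",
}

func isGeminiModelSupportImagine(model string) bool {
	for _, v := range supportedImagineModels {
		if v == model {
			return true
		}
	}
	return false
}

type generationConfig struct {
	ResponseModalities []string
}

// buildGenerationConfig mirrors the gate added to CovertGemini2OpenAI.
func buildGenerationConfig(upstreamModel string) generationConfig {
	var cfg generationConfig
	if isGeminiModelSupportImagine(upstreamModel) {
		cfg.ResponseModalities = []string{"TEXT", "IMAGE"}
	}
	return cfg
}

func main() {
	fmt.Printf("%+v\n", buildGenerationConfig("gemini-2.0-flash-exp")) // image output enabled
	fmt.Printf("%+v\n", buildGenerationConfig("gemini-1.5-pro"))       // unchanged
}
```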

+ 3 - 1
web/src/components/ModelSetting.js

@@ -13,6 +13,7 @@ const ModelSetting = () => {
   let [inputs, setInputs] = useState({
     'gemini.safety_settings': '',
     'gemini.version_settings': '',
+    'gemini.supported_imagine_models': '',
     'claude.model_headers_settings': '',
     'claude.thinking_adapter_enabled': true,
     'claude.default_max_tokens': '',
@@ -34,7 +35,8 @@ const ModelSetting = () => {
           item.key === 'gemini.safety_settings' ||
           item.key === 'gemini.version_settings' ||
           item.key === 'claude.model_headers_settings'||
-          item.key === 'claude.default_max_tokens'
+          item.key === 'claude.default_max_tokens'||
+          item.key === 'gemini.supported_imagine_models'
         ) {
           item.value = JSON.stringify(JSON.parse(item.value), null, 2);
         }

+ 11 - 0
web/src/pages/Setting/Model/SettingGeminiModel.js

@@ -26,6 +26,7 @@ export default function SettingGeminiModel(props) {
   const [inputs, setInputs] = useState({
     'gemini.safety_settings': '',
     'gemini.version_settings': '',
+    'gemini.supported_imagine_models': [],
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -125,6 +126,16 @@ export default function SettingGeminiModel(props) {
                 />
               </Col>
             </Row>
+            <Row>
+              <Col xs={24} sm={12} md={8} lg={8} xl={8}>
+                <Form.TextArea
+                  field={'gemini.supported_imagine_models'}
+                  label={t('支持的图像模型')}
+                  placeholder={t('例如:') + '\n' + JSON.stringify(['gemini-2.0-flash-exp-image-generation'], null, 2)}
+                  onChange={(value) => setInputs({ ...inputs, 'gemini.supported_imagine_models': value })}
+                />
+              </Col>
+            </Row>
 
             <Row>
               <Button size='default' onClick={onSubmit}>
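
On the admin side, the new `gemini.supported_imagine_models` textarea expects a JSON array of model names; a value matching this commit's defaults would look like:

```json
[
  "gemini-2.0-flash-exp-image-generation",
  "gemini-2.0-flash-exp"
]
```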