Yan пре 1 година
родитељ
комит
d40e6ec25d
4 измењених фајлова са 234 додато и 111 уклоњено
  1. +1 -3
      common/str.go
  2. +5 -2
      constant/finish_reason.go
  3. +37 -11
      relay/channel/gemini/dto.go
  4. +191 -95
      relay/channel/gemini/relay-gemini.go

+ 1 - 3
common/str.go

@@ -35,9 +35,7 @@ func StrToMap(str string) map[string]interface{} {
 	m := make(map[string]interface{})
 	err := json.Unmarshal([]byte(str), &m)
 	if err != nil {
-		return map[string]interface{}{
-			"result": str,
-		}
+		return nil
 	}
 	return m
 }

+ 5 - 2
constant/finish_reason.go

@@ -1,6 +1,9 @@
 package constant
 
 var (
-	FinishReasonStop      = "stop"
-	FinishReasonToolCalls = "tool_calls"
+	FinishReasonStop          = "stop"
+	FinishReasonToolCalls     = "tool_calls"
+	FinishReasonLength        = "length"
+	FinishReasonFunctionCall  = "function_call"
+	FinishReasonContentFilter = "content_filter"
 )

+ 37 - 11
relay/channel/gemini/dto.go

@@ -4,7 +4,7 @@ type GeminiChatRequest struct {
 	Contents           []GeminiChatContent        `json:"contents"`
 	SafetySettings     []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
 	GenerationConfig   GeminiChatGenerationConfig `json:"generation_config,omitempty"`
-	Tools              []GeminiChatTools          `json:"tools,omitempty"`
+	Tools              []GeminiChatTool           `json:"tools,omitempty"`
 	SystemInstructions *GeminiChatContent         `json:"system_instruction,omitempty"`
 }
 
@@ -18,16 +18,39 @@ type FunctionCall struct {
 	Arguments    any    `json:"args"`
 }
 
+type GeminiFunctionResponseContent struct {
+	Name    string `json:"name"`
+	Content any    `json:"content"`
+}
+
 type FunctionResponse struct {
-	Name     string `json:"name"`
-	Response any    `json:"response"`
+	Name     string                        `json:"name"`
+	Response GeminiFunctionResponseContent `json:"response"`
+}
+
+type GeminiPartExecutableCode struct {
+	Language string `json:"language,omitempty"`
+	Code     string `json:"code,omitempty"`
+}
+
+type GeminiPartCodeExecutionResult struct {
+	Outcome string `json:"outcome,omitempty"`
+	Output  string `json:"output,omitempty"`
+}
+
+type GeminiFileData struct {
+	MimeType string `json:"mimeType,omitempty"`
+	FileUri  string `json:"fileUri,omitempty"`
 }
 
 type GeminiPart struct {
-	Text             string            `json:"text,omitempty"`
-	InlineData       *GeminiInlineData `json:"inlineData,omitempty"`
-	FunctionCall     *FunctionCall     `json:"functionCall,omitempty"`
-	FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
+	Text                string                         `json:"text,omitempty"`
+	InlineData          *GeminiInlineData              `json:"inlineData,omitempty"`
+	FunctionCall        *FunctionCall                  `json:"functionCall,omitempty"`
+	FunctionResponse    *FunctionResponse              `json:"functionResponse,omitempty"`
+	FileData            *GeminiFileData                `json:"fileData,omitempty"`
+	ExecutableCode      *GeminiPartExecutableCode      `json:"executableCode,omitempty"`
+	CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
 }
 
 type GeminiChatContent struct {
@@ -40,9 +63,11 @@ type GeminiChatSafetySettings struct {
 	Threshold string `json:"threshold"`
 }
 
-type GeminiChatTools struct {
-	GoogleSearch         any `json:"googleSearch,omitempty"`
-	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+type GeminiChatTool struct {
+	GoogleSearch          any `json:"googleSearch,omitempty"`
+	GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
+	CodeExecution         any `json:"codeExecution,omitempty"`
+	FunctionDeclarations  any `json:"functionDeclarations,omitempty"`
 }
 
 type GeminiChatGenerationConfig struct {
@@ -54,11 +79,12 @@ type GeminiChatGenerationConfig struct {
 	StopSequences    []string `json:"stopSequences,omitempty"`
 	ResponseMimeType string   `json:"responseMimeType,omitempty"`
 	ResponseSchema   any      `json:"responseSchema,omitempty"`
+	Seed             int64    `json:"seed,omitempty"`
 }
 
 type GeminiChatCandidate struct {
 	Content       GeminiChatContent        `json:"content"`
-	FinishReason  string                   `json:"finishReason"`
+	FinishReason  *string                  `json:"finishReason"`
 	Index         int64                    `json:"index"`
 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 }

+ 191 - 95
relay/channel/gemini/relay-gemini.go

@@ -18,6 +18,7 @@ import (
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
+
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 		SafetySettings: []GeminiChatSafetySettings{
@@ -46,16 +47,24 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
 			MaxOutputTokens: textRequest.MaxTokens,
+			Seed:            int64(textRequest.Seed),
 		},
 	}
+
+	// openaiContent.FuncToToolCalls()
 	if textRequest.Tools != nil {
 		functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
 		googleSearch := false
+		codeExecution := false
 		for _, tool := range textRequest.Tools {
 			if tool.Function.Name == "googleSearch" {
 				googleSearch = true
 				continue
 			}
+			if tool.Function.Name == "codeExecution" {
+				codeExecution = true
+				continue
+			}
 			if tool.Function.Parameters != nil {
 				params, ok := tool.Function.Parameters.(map[string]interface{})
 				if ok {
@@ -68,25 +77,32 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			}
 			functions = append(functions, tool.Function)
 		}
-		if len(functions) > 0 {
-			geminiRequest.Tools = []GeminiChatTools{
-				{
-					FunctionDeclarations: functions,
-				},
-			}
+		if codeExecution {
+			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+				CodeExecution: make(map[string]string),
+			})
 		}
 		if googleSearch {
-			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{
+			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
 				GoogleSearch: make(map[string]string),
 			})
 		}
+		if len(functions) > 0 {
+			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+				FunctionDeclarations: functions,
+			})
+		}
+		// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
+		// json_data, _ := json.Marshal(geminiRequest.Tools)
+		// common.SysLog("tools_json: " + string(json_data))
 	} else if textRequest.Functions != nil {
-		geminiRequest.Tools = []GeminiChatTools{
+		geminiRequest.Tools = []GeminiChatTool{
 			{
 				FunctionDeclarations: textRequest.Functions,
 			},
 		}
 	}
+
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
 		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
 
@@ -96,20 +112,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		}
 	}
 	tool_call_ids := make(map[string]string)
+	var system_content []string
 	//shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
-
 		if message.Role == "system" {
-			geminiRequest.SystemInstructions = &GeminiChatContent{
-				Parts: []GeminiPart{
-					{
-						Text: message.StringContent(),
-					},
-				},
-			}
+			system_content = append(system_content, message.StringContent())
 			continue
-		} else if message.Role == "tool" {
-			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
+		} else if message.Role == "tool" || message.Role == "function" {
+			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
 				geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
 					Role: "user",
 				})
@@ -121,9 +131,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
 				name = val
 			}
+			content := common.StrToMap(message.StringContent())
 			functionResp := &FunctionResponse{
-				Name:     name,
-				Response: common.StrToMap(message.StringContent()),
+				Name: name,
+				Response: GeminiFunctionResponseContent{
+					Name:    name,
+					Content: content,
+				},
+			}
+			if content == nil {
+				functionResp.Response.Content = message.StringContent()
 			}
 			*parts = append(*parts, GeminiPart{
 				FunctionResponse: functionResp,
@@ -134,57 +151,65 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		content := GeminiChatContent{
 			Role: message.Role,
 		}
-		isToolCall := false
+		// isToolCall := false
 		if message.ToolCalls != nil {
-			message.Role = "model"
-			isToolCall = true
+			// message.Role = "model"
+			// isToolCall = true
 			for _, call := range message.ParseToolCalls() {
+				args := map[string]interface{}{}
+				if call.Function.Arguments != "" {
+					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
+						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
+					}
+				}
 				toolCall := GeminiPart{
 					FunctionCall: &FunctionCall{
 						FunctionName: call.Function.Name,
-						Arguments:    call.Function.Parameters,
+						Arguments:    args,
 					},
 				}
 				parts = append(parts, toolCall)
 				tool_call_ids[call.ID] = call.Function.Name
 			}
 		}
-		if !isToolCall {
-			openaiContent := message.ParseContent()
-			imageNum := 0
-			for _, part := range openaiContent {
-				if part.Type == dto.ContentTypeText {
+
+		openaiContent := message.ParseContent()
+		imageNum := 0
+		for _, part := range openaiContent {
+			if part.Type == dto.ContentTypeText {
+				if part.Text == "" {
+					continue
+				}
+				parts = append(parts, GeminiPart{
+					Text: part.Text,
+				})
+			} else if part.Type == dto.ContentTypeImageURL {
+				imageNum += 1
+
+				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+				}
+				// 判断是否是url
+				if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
+					// 是url,获取图片的类型和base64编码的数据
+					mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
 					parts = append(parts, GeminiPart{
-						Text: part.Text,
+						InlineData: &GeminiInlineData{
+							MimeType: mimeType,
+							Data:     data,
+						},
 					})
-				} else if part.Type == dto.ContentTypeImageURL {
-					imageNum += 1
-
-					if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
-						return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
-					}
-					// 判断是否是url
-					if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
-						// 是url,获取图片的类型和base64编码的数据
-						mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
-						parts = append(parts, GeminiPart{
-							InlineData: &GeminiInlineData{
-								MimeType: mimeType,
-								Data:     data,
-							},
-						})
-					} else {
-						_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
-						if err != nil {
-							return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
-						}
-						parts = append(parts, GeminiPart{
-							InlineData: &GeminiInlineData{
-								MimeType: "image/" + format,
-								Data:     base64String,
-							},
-						})
+				} else {
+					_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
+					if err != nil {
+						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
 					}
+					parts = append(parts, GeminiPart{
+						InlineData: &GeminiInlineData{
+							MimeType: "image/" + format,
+							Data:     base64String,
+						},
+					})
 				}
 			}
 		}
@@ -197,6 +222,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		}
 		geminiRequest.Contents = append(geminiRequest.Contents, content)
 	}
+
+	if len(system_content) > 0 {
+		geminiRequest.SystemInstructions = &GeminiChatContent{
+			Parts: []GeminiPart{
+				{
+					Text: strings.Join(system_content, "\n"),
+				},
+			},
+		}
+	}
+
 	return &geminiRequest, nil
 }
 
@@ -240,15 +276,15 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
 	return v
 }
 
-func (g *GeminiChatResponse) GetResponseText() string {
-	if g == nil {
-		return ""
-	}
-	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
-		return g.Candidates[0].Content.Parts[0].Text
-	}
-	return ""
-}
+// func (g *GeminiChatResponse) GetResponseText() string {
+// 	if g == nil {
+// 		return ""
+// 	}
+// 	if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
+// 		return g.Candidates[0].Content.Parts[0].Text
+// 	}
+// 	return ""
+// }
 
 func getToolCall(item *GeminiPart) *dto.ToolCall {
 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
@@ -298,11 +334,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	content, _ := json.Marshal("")
-	for i, candidate := range response.Candidates {
-		// jsonData, _ := json.MarshalIndent(candidate, "", "  ")
-		// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
+	is_tool_call := false
+	for _, candidate := range response.Candidates {
 		choice := dto.OpenAITextResponseChoice{
-			Index: i,
+			Index: int(candidate.Index),
 			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
@@ -319,48 +354,107 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 						tool_calls = append(tool_calls, *call)
 					}
 				} else {
-					texts = append(texts, part.Text)
+					if part.ExecutableCode != nil {
+						texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
+					} else if part.CodeExecutionResult != nil {
+						texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
+					} else {
+						// 过滤掉空行
+						if part.Text != "\n" {
+							texts = append(texts, part.Text)
+						}
+					}
 				}
 			}
+			if len(tool_calls) > 0 {
+				choice.Message.SetToolCalls(tool_calls)
+				is_tool_call = true
+			}
+			// 过滤掉空行
+
 			choice.Message.SetStringContent(strings.Join(texts, "\n"))
-			choice.Message.SetToolCalls(tool_calls)
+
 		}
+		if candidate.FinishReason != nil {
+			switch *candidate.FinishReason {
+			case "STOP":
+				choice.FinishReason = constant.FinishReasonStop
+			case "MAX_TOKENS":
+				choice.FinishReason = constant.FinishReasonLength
+			default:
+				choice.FinishReason = constant.FinishReasonContentFilter
+			}
+		}
+		if is_tool_call {
+			choice.FinishReason = constant.FinishReasonToolCalls
+		}
+
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
 	return &fullTextResponse
 }
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
-	var choice dto.ChatCompletionsStreamResponseChoice
-	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
-	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
+	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
+	is_stop := false
+	for _, candidate := range geminiResponse.Candidates {
+		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
+			is_stop = true
+			candidate.FinishReason = nil
+		}
+		choice := dto.ChatCompletionsStreamResponseChoice{
+			Index: int(candidate.Index),
+			Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+				Role: "assistant",
+			},
+		}
 		var texts []string
-		var tool_calls []dto.ToolCall
-		for _, part := range geminiResponse.Candidates[0].Content.Parts {
+		isTools := false
+		if candidate.FinishReason != nil {
+			// p := GeminiConvertFinishReason(*candidate.FinishReason)
+			switch *candidate.FinishReason {
+			case "STOP":
+				choice.FinishReason = &constant.FinishReasonStop
+			case "MAX_TOKENS":
+				choice.FinishReason = &constant.FinishReasonLength
+			default:
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			}
+		}
+		for _, part := range candidate.Content.Parts {
 			if part.FunctionCall != nil {
+				isTools = true
 				if call := getToolCall(&part); call != nil {
-					tool_calls = append(tool_calls, *call)
+					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 				}
 			} else {
-				texts = append(texts, part.Text)
+				if part.ExecutableCode != nil {
+					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
+				} else if part.CodeExecutionResult != nil {
+					texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
+				} else {
+					if part.Text != "\n" {
+						texts = append(texts, part.Text)
+					}
+				}
 			}
 		}
-		if len(texts) > 0 {
-			choice.Delta.SetContentString(strings.Join(texts, "\n"))
-		}
-		if len(tool_calls) > 0 {
-			choice.Delta.ToolCalls = tool_calls
+		choice.Delta.SetContentString(strings.Join(texts, "\n"))
+		if isTools {
+			choice.FinishReason = &constant.FinishReasonToolCalls
 		}
+		choices = append(choices, choice)
 	}
+
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Model = "gemini"
-	response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
-	return &response
+	response.Choices = choices
+	return &response, is_stop
 }
 
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseText := ""
+	// responseText := ""
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
@@ -384,14 +478,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			continue
 		}
 
-		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
-		if response == nil {
-			continue
-		}
+		response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
 		response.Id = id
 		response.Created = createAt
 		response.Model = info.UpstreamModelName
-		responseText += response.Choices[0].Delta.GetContentString()
+		// responseText += response.Choices[0].Delta.GetContentString()
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
 			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -400,12 +491,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 		if err != nil {
 			common.LogError(c, err.Error())
 		}
+		if is_stop {
+			response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
+			service.ObjectData(c, response)
+		}
 	}
 
-	response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
-	service.ObjectData(c, response)
+	var response *dto.ChatCompletionsStreamResponse
 
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
 
 	if info.ShouldIncludeUsage {
 		response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)