Просмотр исходного кода

refactor: simplify streamResponseGeminiChat2OpenAI by removing hasImage return value and optimizing response text handling

CaIon 7 месяцев назад
Родитель
Сommit
f0945da4fb
1 измененных файлов с 25 добавлено и 7 удалено
  1. 25 7
      relay/channel/gemini/relay-gemini.go

+ 25 - 7
relay/channel/gemini/relay-gemini.go

@@ -725,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
 	return &fullTextResponse
 	return &fullTextResponse
 }
 }
 
 
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
 	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
 	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
 	isStop := false
 	isStop := false
-	hasImage := false
 	for _, candidate := range geminiResponse.Candidates {
 	for _, candidate := range geminiResponse.Candidates {
 		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
 		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
 			isStop = true
 			isStop = true
@@ -759,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 				if strings.HasPrefix(part.InlineData.MimeType, "image") {
 				if strings.HasPrefix(part.InlineData.MimeType, "image") {
 					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
 					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
 					texts = append(texts, imgText)
 					texts = append(texts, imgText)
-					hasImage = true
 				}
 				}
 			} else if part.FunctionCall != nil {
 			} else if part.FunctionCall != nil {
 				isTools = true
 				isTools = true
@@ -796,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 	var response dto.ChatCompletionsStreamResponse
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
 	response.Object = "chat.completion.chunk"
 	response.Choices = choices
 	response.Choices = choices
-	return &response, isStop, hasImage
+	return &response, isStop
 }
 }
 
 
 func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
 func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
@@ -824,6 +822,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 	// responseText := ""
 	// responseText := ""
 	id := helper.GetResponseID(c)
 	id := helper.GetResponseID(c)
 	createAt := common.GetTimestamp()
 	createAt := common.GetTimestamp()
+	responseText := strings.Builder{}
 	var usage = &dto.Usage{}
 	var usage = &dto.Usage{}
 	var imageCount int
 	var imageCount int
 
 
@@ -835,10 +834,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 			return false
 			return false
 		}
 		}
 
 
-		response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
-		if hasImage {
-			imageCount++
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+				if part.Text != "" {
+					responseText.WriteString(part.Text)
+				}
+			}
 		}
 		}
+
+		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+
 		response.Id = id
 		response.Id = id
 		response.Created = createAt
 		response.Created = createAt
 		response.Model = info.UpstreamModelName
 		response.Model = info.UpstreamModelName
@@ -889,6 +897,16 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
 	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
 	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 
 
+	if usage.CompletionTokens == 0 {
+		str := responseText.String()
+		if len(str) > 0 {
+			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			// 空补全,不需要使用量
+			usage = &dto.Usage{}
+		}
+	}
+
 	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 	err := handleFinalStream(c, info, response)
 	err := handleFinalStream(c, info, response)
 	if err != nil {
 	if err != nil {