Explorar el Código

fix: gemini usage (close #354)

CalciumIon hace 1 año
padre
commit
4e7e206290

+ 1 - 4
relay/channel/gemini/adaptor.go

@@ -9,7 +9,6 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )
 
 type Adaptor struct {
@@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, responseText = geminiChatStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = geminiChatStreamHandler(c, resp, info)
 	} else {
 		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 7 - 0
relay/channel/gemini/dto.go

@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
 type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+	UsageMetadata  GeminiUsageMetadata      `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+	PromptTokenCount     int `json:"promptTokenCount"`
+	CandidatesTokenCount int `json:"candidatesTokenCount"`
+	TotalTokenCount      int `json:"totalTokenCount"`
 }

+ 36 - 12
relay/channel/gemini/relay-gemini.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
+	responseJson := ""
+	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createAt := common.GetTimestamp()
+	var usage = &dto.Usage{}
 	dataChan := make(chan string, 5)
 	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
+			responseJson += data
 			data = strings.TrimSpace(data)
 			if !strings.HasPrefix(data, "\"text\": \"") {
 				continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.SetContentString(dummy.Content)
 			response := dto.ChatCompletionsStreamResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+				Id:      id,
 				Object:  "chat.completion.chunk",
-				Created: common.GetTimestamp(),
-				Model:   "gemini-pro",
+				Created: createAt,
+				Model:   info.UpstreamModelName,
 				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
+	var geminiChatResponses []GeminiChatResponse
+	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		log.Printf("cannot get gemini usage: %s", err.Error())
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		for _, response := range geminiChatResponses {
+			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		}
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
 	}
-	return nil, responseText
+	return nil, usage
 }
 
 func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
 	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)

+ 2 - 1
relay/common/relay_info.go

@@ -67,7 +67,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
-	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws {
+	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
+		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
 		info.SupportStreamOptions = true
 	}
 	return info