Procházet zdrojové kódy

refactor: token counter logic

CaIon před 8 měsíci
rodič
revize
a9e5d99ea3

+ 1 - 1
relay/audio_handler.go

@@ -66,7 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
+		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}

+ 3 - 3
relay/channel/claude/relay-claude.go

@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 
 	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if claudeInfo.Usage.PromptTokens == 0 {
 			//上游出错
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 			if common.DebugEnabled {
 				common.SysError("claude response usage is not complete, maybe upstream error")
 			}
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 	}
 
@@ -618,7 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 	}
 	if requestMode == RequestModeCompletion {
-		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
+		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
 		}

+ 3 - 3
relay/channel/cloudflare/relay_cloudflare.go

@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err := scanner.Err(); err != nil {
 		common.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 	return nil, usage

+ 1 - 1
relay/channel/cohere/relay-cohere.go

@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	return nil, usage
 }

+ 1 - 1
relay/channel/coze/relay-coze.go

@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -250,7 +250,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	usage.CompletionTokens += nodeToken

+ 5 - 12
relay/channel/gemini/relay-gemini-native.go

@@ -9,6 +9,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -75,8 +76,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	helper.SetEventStreamHeaders(c)
 
-	// 本地统计的completion tokens
-	localCompletionTokens := 0
+	responseText := strings.Builder{}
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
@@ -92,12 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 				if part.InlineData != nil && part.InlineData.MimeType != "" {
 					imageCount++
 				}
-				// 本地统计completion tokens
-				textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName)
-				if err != nil {
-					common.LogError(c, "error counting text token: "+err.Error())
+				if part.Text != "" {
+					responseText.WriteString(part.Text)
 				}
-				localCompletionTokens += textTokens
 			}
 		}
 
@@ -133,13 +130,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
 	if usage.CompletionTokens == 0 {
-		usage.CompletionTokens = localCompletionTokens
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
 	}
 
-	// 计算最终使用量
-	// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
 	// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
 	//helper.Done(c)
 

+ 9 - 9
relay/channel/openai/relay-openai.go

@@ -8,7 +8,6 @@ import (
 	"math"
 	"mime/multipart"
 	"net/http"
-	"path/filepath"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
@@ -16,6 +15,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"os"
+	"path/filepath"
 	"strings"
 
 	"github.com/bytedance/gopkg/util/gopool"
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	} else {
 		if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	
+
 	forceFormat := false
 	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
 		forceFormat = forceFmt
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	// the status code has been judged before, if there is a body reading failure,
 	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
-	// Analogous to nginx's load balancing, it will only retry if it can't be requested or 
-	// if the upstream returns a specific status code, once the upstream has already written the header, 
-	// the subsequent failure of the response body should be regarded as a non-recoverable error, 
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// and can be terminated directly.
 	defer resp.Body.Close()
 	usage := &dto.Usage{}
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
 	if err = c.ShouldBind(&reqBody); err != nil {
 		return 0, errors.WithStack(err)
 	}
-  ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
+	ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
 	reqFp, err := reqBody.File.Open()
 	if err != nil {
 		return 0, errors.WithStack(err)
 	}
-  defer reqFp.Close()
+	defer reqFp.Close()
 
 	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
 	if err != nil {

+ 1 - 1
relay/channel/openai/relay_responses.go

@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 		tempStr := responseTextBuilder.String()
 		if len(tempStr) > 0 {
 			// 非正常结束,使用输出文本的 token 数量
-			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
 			usage.CompletionTokens = completionTokens
 		}
 	}

+ 1 - 1
relay/channel/palm/adaptor.go

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 1 - 1
relay/channel/tencent/adaptor.go

@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = tencentHandler(c, resp)
 	}

+ 1 - 1
relay/channel/xai/text.go

@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	})
 
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 	}
 

+ 1 - 1
relay/embedding_handler.go

@@ -15,7 +15,7 @@ import (
 )
 
 func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
-	token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+	token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
 	return token
 }
 

+ 4 - 4
relay/gemini_handler.go

@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
 	// 计算输入 token 数量
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
 	}
 
 	inputText := strings.Join(inputTexts, "\n")
-	inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+	inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getGeminiInputTokens(req, relayInfo)
+		promptTokens := getGeminiInputTokens(req, relayInfo)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
 		}

+ 3 - 3
relay/relay-text.go

@@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	case relayconstant.RelayModeChatCompletions:
 		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
 	case relayconstant.RelayModeModerations:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	case relayconstant.RelayModeEmbeddings:
-		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	default:
 		err = errors.New("unknown relay mode")
 		promptTokens = 0

+ 3 - 5
relay/rerank_handler.go

@@ -14,12 +14,10 @@ import (
 )
 
 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
-	token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+	token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
 	for _, document := range rerankRequest.Documents {
-		tkm, err := service.CountTokenInput(document, rerankRequest.Model)
-		if err == nil {
-			token += tkm
-		}
+		tkm := service.CountTokenInput(document, rerankRequest.Model)
+		token += tkm
 	}
 	return token
 }

+ 4 - 4
relay/responses_handler.go

@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
 	return sensitiveWords, err
 }
 
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
-	inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+	inputTokens := service.CountTokenInput(req.Input, req.Model)
 	info.PromptTokens = inputTokens
-	return inputTokens, err
+	return inputTokens
 }
 
 func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -72,7 +72,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
 		promptTokens := value.(int)
 		relayInfo.SetPromptTokens(promptTokens)
 	} else {
-		promptTokens, err := getInputTokens(req, relayInfo)
+		promptTokens := getInputTokens(req, relayInfo)
 		if err != nil {
 			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
 		}

+ 17 - 27
service/token_counter.go

@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
 				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
 			}
 		}
-		toolTokens, err := CountTokenInput(countStr, request.Model)
+		toolTokens := CountTokenInput(countStr, request.Model)
 		if err != nil {
 			return 0, err
 		}
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
 
 	// Count tokens in system message
 	if request.System != "" {
-		systemTokens, err := CountTokenInput(request.System, model)
+		systemTokens := CountTokenInput(request.System, model)
 		if err != nil {
 			return 0, err
 		}
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 	switch request.Type {
 	case dto.RealtimeEventTypeSessionUpdate:
 		if request.Session != nil {
-			msgTokens, err := CountTextToken(request.Session.Instructions, model)
-			if err != nil {
-				return 0, 0, err
-			}
+			msgTokens := CountTextToken(request.Session.Instructions, model)
 			textToken += msgTokens
 		}
 	case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 		audioToken += atk
 	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
 		// count text token
-		tkm, err := CountTextToken(request.Delta, model)
-		if err != nil {
-			return 0, 0, fmt.Errorf("error counting text token: %v", err)
-		}
+		tkm := CountTextToken(request.Delta, model)
 		textToken += tkm
 	case dto.RealtimeEventInputAudioBufferAppend:
 		// count audio token
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 			case "message":
 				for _, content := range request.Item.Content {
 					if content.Type == "input_text" {
-						tokens, err := CountTextToken(content.Text, model)
-						if err != nil {
-							return 0, 0, err
-						}
+						tokens := CountTextToken(content.Text, model)
 						textToken += tokens
 					}
 				}
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 		if !info.IsFirstRequest {
 			if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
 				for _, tool := range info.RealtimeTools {
-					toolTokens, err := CountTokenInput(tool, model)
-					if err != nil {
-						return 0, 0, err
-					}
+					toolTokens := CountTokenInput(tool, model)
 					textToken += 8
 					textToken += toolTokens
 				}
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 	return tokenNum, nil
 }
 
-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	case string:
 		return CountTextToken(v, model)
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
 func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
 	tokens := 0
 	for _, message := range messages {
-		tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+		tkm := CountTokenInput(message.Delta.GetContentString(), model)
 		tokens += tkm
 		if message.Delta.ToolCalls != nil {
 			for _, tool := range message.Delta.ToolCalls {
-				tkm, _ := CountTokenInput(tool.Function.Name, model)
+				tkm := CountTokenInput(tool.Function.Name, model)
 				tokens += tkm
-				tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+				tkm = CountTokenInput(tool.Function.Arguments, model)
 				tokens += tkm
 			}
 		}
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
 	if strings.HasPrefix(model, "tts") {
-		return utf8.RuneCountInString(text), nil
+		return utf8.RuneCountInString(text)
 	} else {
 		return CountTextToken(text, model)
 	}
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
 //}
 
 // CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTextToken(text string, model string) (int, error) {
-	var err error
+func CountTextToken(text string, model string) int {
+	if text == "" {
+		return 0
+	}
 	tokenEncoder := getTokenEncoder(model)
-	return getTokenNum(tokenEncoder, text), err
+	return getTokenNum(tokenEncoder, text)
 }

+ 3 - 3
service/usage_helpr.go

@@ -16,13 +16,13 @@ import (
 //	return 0, errors.New("unknown relay mode")
 //}
 
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTextToken(responseText, modeName)
+	ctkm := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-	return usage, err
+	return usage
 }
 
 func ValidUsage(usage *dto.Usage) bool {