|
|
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
|
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
|
|
}
|
|
|
}
|
|
|
- toolTokens, err := CountTokenInput(countStr, request.Model)
|
|
|
+ toolTokens := CountTokenInput(countStr, request.Model)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
|
|
|
|
|
// Count tokens in system message
|
|
|
if request.System != "" {
|
|
|
- systemTokens, err := CountTokenInput(request.System, model)
|
|
|
+ systemTokens := CountTokenInput(request.System, model)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|
|
switch request.Type {
|
|
|
case dto.RealtimeEventTypeSessionUpdate:
|
|
|
if request.Session != nil {
|
|
|
- msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
|
|
- if err != nil {
|
|
|
- return 0, 0, err
|
|
|
- }
|
|
|
+ msgTokens := CountTextToken(request.Session.Instructions, model)
|
|
|
textToken += msgTokens
|
|
|
}
|
|
|
case dto.RealtimeEventResponseAudioDelta:
|
|
|
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|
|
audioToken += atk
|
|
|
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
|
|
// count text token
|
|
|
- tkm, err := CountTextToken(request.Delta, model)
|
|
|
- if err != nil {
|
|
|
- return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
|
|
- }
|
|
|
+ tkm := CountTextToken(request.Delta, model)
|
|
|
textToken += tkm
|
|
|
case dto.RealtimeEventInputAudioBufferAppend:
|
|
|
// count audio token
|
|
|
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|
|
case "message":
|
|
|
for _, content := range request.Item.Content {
|
|
|
if content.Type == "input_text" {
|
|
|
- tokens, err := CountTextToken(content.Text, model)
|
|
|
- if err != nil {
|
|
|
- return 0, 0, err
|
|
|
- }
|
|
|
+ tokens := CountTextToken(content.Text, model)
|
|
|
textToken += tokens
|
|
|
}
|
|
|
}
|
|
|
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|
|
if !info.IsFirstRequest {
|
|
|
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
|
|
for _, tool := range info.RealtimeTools {
|
|
|
- toolTokens, err := CountTokenInput(tool, model)
|
|
|
- if err != nil {
|
|
|
- return 0, 0, err
|
|
|
- }
|
|
|
+ toolTokens := CountTokenInput(tool, model)
|
|
|
textToken += 8
|
|
|
textToken += toolTokens
|
|
|
}
|
|
|
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
|
|
return tokenNum, nil
|
|
|
}
|
|
|
|
|
|
-func CountTokenInput(input any, model string) (int, error) {
|
|
|
+func CountTokenInput(input any, model string) int {
|
|
|
switch v := input.(type) {
|
|
|
case string:
|
|
|
return CountTextToken(v, model)
|
|
|
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
|
|
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
|
|
tokens := 0
|
|
|
for _, message := range messages {
|
|
|
- tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
|
|
+ tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
|
|
tokens += tkm
|
|
|
if message.Delta.ToolCalls != nil {
|
|
|
for _, tool := range message.Delta.ToolCalls {
|
|
|
- tkm, _ := CountTokenInput(tool.Function.Name, model)
|
|
|
+ tkm := CountTokenInput(tool.Function.Name, model)
|
|
|
tokens += tkm
|
|
|
- tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
|
|
+ tkm = CountTokenInput(tool.Function.Arguments, model)
|
|
|
tokens += tkm
|
|
|
}
|
|
|
}
|
|
|
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
|
|
return tokens
|
|
|
}
|
|
|
|
|
|
-func CountTTSToken(text string, model string) (int, error) {
|
|
|
+func CountTTSToken(text string, model string) int {
|
|
|
if strings.HasPrefix(model, "tts") {
|
|
|
- return utf8.RuneCountInString(text), nil
|
|
|
+ return utf8.RuneCountInString(text)
|
|
|
} else {
|
|
|
return CountTextToken(text, model)
|
|
|
}
|
|
|
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
|
|
//}
|
|
|
|
|
|
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
|
|
-func CountTextToken(text string, model string) (int, error) {
|
|
|
- var err error
|
|
|
+func CountTextToken(text string, model string) int {
|
|
|
+ if text == "" {
|
|
|
+ return 0
|
|
|
+ }
|
|
|
tokenEncoder := getTokenEncoder(model)
|
|
|
- return getTokenNum(tokenEncoder, text), err
|
|
|
+ return getTokenNum(tokenEncoder, text)
|
|
|
}
|