Просмотр исходного кода

feat: azure realtime

(cherry picked from commit 75ff3d98f06103dc2df1f8817bd3fcbf433e0f20)
1808837298@qq.com 1 год назад
Родитель
Коммит
8de79382f0
Изменено 2 файла: 30 добавлений и 11 удалений
  1. 15 11
      relay/channel/openai/relay-openai.go
  2. 15 0
      service/token_counter.go

+ 15 - 11
relay/channel/openai/relay-openai.go

@@ -478,6 +478,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
 						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
 						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
+					} else {
+						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+						if err != nil {
+							errChan <- fmt.Errorf("error counting text token: %v", err)
+							return
+						}
+						log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+						localUsage.TotalTokens += textToken + audioToken
+						info.IsFirstRequest = false
+						localUsage.InputTokens += textToken + audioToken
+						localUsage.InputTokenDetails.TextTokens += textToken
+						localUsage.InputTokenDetails.AudioTokens += audioToken
 					}
 				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
 					realtimeSession := realtimeEvent.Session
@@ -494,17 +506,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 					}
 					log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
 					localUsage.TotalTokens += textToken + audioToken
-
-					if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
-						info.IsFirstRequest = false
-						localUsage.InputTokens += textToken + audioToken
-						localUsage.InputTokenDetails.TextTokens += textToken
-						localUsage.InputTokenDetails.AudioTokens += audioToken
-					} else {
-						localUsage.OutputTokens += textToken + audioToken
-						localUsage.OutputTokenDetails.TextTokens += textToken
-						localUsage.OutputTokenDetails.AudioTokens += audioToken
-					}
+					localUsage.OutputTokens += textToken + audioToken
+					localUsage.OutputTokenDetails.TextTokens += textToken
+					localUsage.OutputTokenDetails.AudioTokens += audioToken
 				}
 
 				err = service.WssString(c, clientConn, string(message))

+ 15 - 0
service/token_counter.go

@@ -225,6 +225,21 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
 		}
 		audioToken += atk
+	case dto.RealtimeEventConversationItemCreated:
+		if request.Item != nil {
+			switch request.Item.Type {
+			case "message":
+				for _, content := range request.Item.Content {
+					if content.Type == "input_text" {
+						tokens, err := CountTextToken(content.Text, model)
+						if err != nil {
+							return 0, 0, err
+						}
+						textToken += tokens
+					}
+				}
+			}
+		}
 	case dto.RealtimeEventTypeResponseDone:
 		// count tools token
 		if !info.IsFirstRequest {