Procházet zdrojové kódy

Merge pull request #551 from Calcium-Ion/realtime

feat: support openai realtime api
Calcium-Ion před 1 rokem
rodič
revize
a859ff5985
48 změnil soubory, kde provedl 1266 přidání a 198 odebrání
  1. 28 0
      common/model-ratio.go
  2. 10 5
      controller/channel-test.go
  3. 78 0
      controller/relay.go
  4. 97 0
      dto/realtime.go
  5. 19 0
      middleware/auth.go
  6. 4 0
      middleware/distributor.go
  7. 3 3
      relay/channel/adapter.go
  8. 6 6
      relay/channel/ali/adaptor.go
  9. 30 6
      relay/channel/api_request.go
  10. 3 3
      relay/channel/aws/adaptor.go
  11. 4 4
      relay/channel/baidu/adaptor.go
  12. 5 5
      relay/channel/claude/adaptor.go
  13. 1 1
      relay/channel/claude/relay-claude.go
  14. 4 4
      relay/channel/cloudflare/adaptor.go
  15. 1 1
      relay/channel/cloudflare/relay_cloudflare.go
  16. 4 4
      relay/channel/cohere/adaptor.go
  17. 4 4
      relay/channel/dify/adaptor.go
  18. 1 1
      relay/channel/dify/relay-dify.go
  19. 4 4
      relay/channel/gemini/adaptor.go
  20. 4 4
      relay/channel/jina/adaptor.go
  21. 4 4
      relay/channel/mistral/adaptor.go
  22. 3 3
      relay/channel/ollama/adaptor.go
  23. 39 8
      relay/channel/openai/adaptor.go
  24. 210 2
      relay/channel/openai/relay-openai.go
  25. 4 4
      relay/channel/palm/adaptor.go
  26. 1 1
      relay/channel/palm/relay-palm.go
  27. 4 4
      relay/channel/perplexity/adaptor.go
  28. 4 4
      relay/channel/siliconflow/adaptor.go
  29. 7 7
      relay/channel/tencent/adaptor.go
  30. 4 4
      relay/channel/vertex/adaptor.go
  31. 3 3
      relay/channel/xunfei/adaptor.go
  32. 4 4
      relay/channel/zhipu/adaptor.go
  33. 4 4
      relay/channel/zhipu_4v/adaptor.go
  34. 18 0
      relay/common/relay_info.go
  35. 4 0
      relay/constant/relay_mode.go
  36. 14 8
      relay/relay-audio.go
  37. 7 5
      relay/relay-image.go
  38. 13 9
      relay/relay-text.go
  39. 14 7
      relay/relay_rerank.go
  40. 159 0
      relay/websocket.go
  41. 34 25
      router/relay-router.go
  42. 31 0
      service/audio.go
  43. 11 0
      service/log.go
  44. 140 0
      service/quota.go
  45. 38 1
      service/relay.go
  46. 101 6
      service/token_counter.go
  47. 1 1
      service/usage_helpr.go
  48. 80 29
      web/src/components/LogsTable.js

+ 28 - 0
common/model-ratio.go

@@ -421,6 +421,34 @@ func GetCompletionRatio(name string) float64 {
 	return 1
 	return 1
 }
 }
 
 
+func GetAudioRatio(name string) float64 {
+	if strings.HasPrefix(name, "gpt-4o-realtime") {
+		return 20
+	}
+	return 20
+}
+
+func GetAudioCompletionRatio(name string) float64 {
+	if strings.HasPrefix(name, "gpt-4o-realtime") {
+		return 10
+	}
+	return 2
+}
+
+//func GetAudioPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.06
+//	}
+//	return 0.06
+//}
+//
+//func GetAudioCompletionPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.24
+//	}
+//	return 0.24
+//}
+
 func GetCompletionRatioMap() map[string]float64 {
 func GetCompletionRatioMap() map[string]float64 {
 	if CompletionRatio == nil {
 	if CompletionRatio == nil {
 		CompletionRatio = defaultCompletionRatio
 		CompletionRatio = defaultCompletionRatio

+ 10 - 5
controller/channel-test.go

@@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	if err != nil {
 	if err != nil {
 		return err, nil
 		return err, nil
 	}
 	}
-	if resp != nil && resp.StatusCode != http.StatusOK {
-		err := service.RelayErrorHandler(resp)
-		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			err := service.RelayErrorHandler(httpResp)
+			return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
+		}
 	}
 	}
-	usage, respErr := adaptor.DoResponse(c, resp, meta)
+	usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
 	if respErr != nil {
 	if respErr != nil {
 		return fmt.Errorf("%s", respErr.Error.Message), respErr
 		return fmt.Errorf("%s", respErr.Error.Message), respErr
 	}
 	}
-	if usage == nil {
+	if usageA == nil {
 		return errors.New("usage is nil"), nil
 		return errors.New("usage is nil"), nil
 	}
 	}
+	usage := usageA.(*dto.Usage)
 	result := w.Result()
 	result := w.Result()
 	respBody, err := io.ReadAll(result.Body)
 	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 	if err != nil {

+ 78 - 0
controller/relay.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"io"
 	"io"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
@@ -38,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	return err
 	return err
 }
 }
 
 
+func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
+	var err *dto.OpenAIErrorWithStatusCode
+	switch relayMode {
+	default:
+		err = relay.TextHelper(c)
+	}
+	return err
+}
+
 func Playground(c *gin.Context) {
 func Playground(c *gin.Context) {
 	var openaiErr *dto.OpenAIErrorWithStatusCode
 	var openaiErr *dto.OpenAIErrorWithStatusCode
 
 
@@ -134,6 +144,67 @@ func Relay(c *gin.Context) {
 	}
 	}
 }
 }
 
 
+var upgrader = websocket.Upgrader{
+	Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
+	CheckOrigin: func(r *http.Request) bool {
+		return true // 允许跨域
+	},
+}
+
+func WssRelay(c *gin.Context) {
+	// 将 HTTP 连接升级为 WebSocket 连接
+
+	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	defer ws.Close()
+
+	if err != nil {
+		openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
+		service.WssError(c, ws, openaiErr.Error)
+		return
+	}
+
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+	originalModel := c.GetString("original_model")
+	var openaiErr *dto.OpenAIErrorWithStatusCode
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
+			break
+		}
+
+		openaiErr = wssRequest(c, ws, relayMode, channel)
+
+		if openaiErr == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+
+		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if openaiErr != nil {
+		if openaiErr.StatusCode == http.StatusTooManyRequests {
+			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+		service.WssError(c, ws, openaiErr.Error)
+	}
+}
+
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
 	addUsedChannel(c, channel.Id)
 	addUsedChannel(c, channel.Id)
 	requestBody, _ := common.GetRequestBody(c)
 	requestBody, _ := common.GetRequestBody(c)
@@ -141,6 +212,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
 	return relayHandler(c, relayMode)
 	return relayHandler(c, relayMode)
 }
 }
 
 
+func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
+	addUsedChannel(c, channel.Id)
+	requestBody, _ := common.GetRequestBody(c)
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return relay.WssHelper(c, ws)
+}
+
 func addUsedChannel(c *gin.Context, channelId int) {
 func addUsedChannel(c *gin.Context, channelId int) {
 	useChannel := c.GetStringSlice("use_channel")
 	useChannel := c.GetStringSlice("use_channel")
 	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 	useChannel = append(useChannel, fmt.Sprintf("%d", channelId))

+ 97 - 0
dto/realtime.go

@@ -0,0 +1,97 @@
+package dto
+
+const (
+	RealtimeEventTypeError              = "error"
+	RealtimeEventTypeSessionUpdate      = "session.update"
+	RealtimeEventTypeConversationCreate = "conversation.item.create"
+	RealtimeEventTypeResponseCreate     = "response.create"
+	RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
+)
+
+const (
+	RealtimeEventTypeResponseDone                   = "response.done"
+	RealtimeEventTypeSessionUpdated                 = "session.updated"
+	RealtimeEventTypeSessionCreated                 = "session.created"
+	RealtimeEventResponseAudioDelta                 = "response.audio.delta"
+	RealtimeEventResponseAudioTranscriptionDelta    = "response.audio_transcript.delta"
+	RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
+	RealtimeEventResponseFunctionCallArgumentsDone  = "response.function_call_arguments.done"
+	RealtimeEventConversationItemCreated            = "conversation.item.created"
+)
+
+type RealtimeEvent struct {
+	EventId string `json:"event_id"`
+	Type    string `json:"type"`
+	//PreviousItemId string `json:"previous_item_id"`
+	Session  *RealtimeSession  `json:"session,omitempty"`
+	Item     *RealtimeItem     `json:"item,omitempty"`
+	Error    *OpenAIError      `json:"error,omitempty"`
+	Response *RealtimeResponse `json:"response,omitempty"`
+	Delta    string            `json:"delta,omitempty"`
+	Audio    string            `json:"audio,omitempty"`
+}
+
+type RealtimeResponse struct {
+	Usage *RealtimeUsage `json:"usage"`
+}
+
+type RealtimeUsage struct {
+	TotalTokens        int                `json:"total_tokens"`
+	InputTokens        int                `json:"input_tokens"`
+	OutputTokens       int                `json:"output_tokens"`
+	InputTokenDetails  InputTokenDetails  `json:"input_token_details"`
+	OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
+}
+
+type InputTokenDetails struct {
+	CachedTokens int `json:"cached_tokens"`
+	TextTokens   int `json:"text_tokens"`
+	AudioTokens  int `json:"audio_tokens"`
+}
+
+type OutputTokenDetails struct {
+	TextTokens  int `json:"text_tokens"`
+	AudioTokens int `json:"audio_tokens"`
+}
+
+type RealtimeSession struct {
+	Modalities              []string                `json:"modalities"`
+	Instructions            string                  `json:"instructions"`
+	Voice                   string                  `json:"voice"`
+	InputAudioFormat        string                  `json:"input_audio_format"`
+	OutputAudioFormat       string                  `json:"output_audio_format"`
+	InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
+	TurnDetection           interface{}             `json:"turn_detection"`
+	Tools                   []RealTimeTool          `json:"tools"`
+	ToolChoice              string                  `json:"tool_choice"`
+	Temperature             float64                 `json:"temperature"`
+	//MaxResponseOutputTokens int                     `json:"max_response_output_tokens"`
+}
+
+type InputAudioTranscription struct {
+	Model string `json:"model"`
+}
+
+type RealTimeTool struct {
+	Type        string `json:"type"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  any    `json:"parameters"`
+}
+
+type RealtimeItem struct {
+	Id        string            `json:"id"`
+	Type      string            `json:"type"`
+	Status    string            `json:"status"`
+	Role      string            `json:"role"`
+	Content   []RealtimeContent `json:"content"`
+	Name      *string           `json:"name,omitempty"`
+	ToolCalls any               `json:"tool_calls,omitempty"`
+	CallId    string            `json:"call_id,omitempty"`
+}
+type RealtimeContent struct {
+	Type       string `json:"type"`
+	Text       string `json:"text,omitempty"`
+	Audio      string `json:"audio,omitempty"` // Base64-encoded audio bytes.
+	Transcript string `json:"transcript,omitempty"`
+}

+ 19 - 0
middleware/auth.go

@@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
 	}
 	}
 }
 }
 
 
+func WssAuth(c *gin.Context) {
+
+}
+
 func TokenAuth() func(c *gin.Context) {
 func TokenAuth() func(c *gin.Context) {
 	return func(c *gin.Context) {
 	return func(c *gin.Context) {
+		// 先检测是否为ws
+		if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
+			// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
+			// read sk from Sec-WebSocket-Protocol
+			key := c.Request.Header.Get("Sec-WebSocket-Protocol")
+			parts := strings.Split(key, ",")
+			for _, part := range parts {
+				part = strings.TrimSpace(part)
+				if strings.HasPrefix(part, "openai-insecure-api-key") {
+					key = strings.TrimPrefix(part, "openai-insecure-api-key.")
+					break
+				}
+			}
+			c.Request.Header.Set("Authorization", "Bearer "+key)
+		}
 		key := c.Request.Header.Get("Authorization")
 		key := c.Request.Header.Get("Authorization")
 		parts := make([]string, 0)
 		parts := make([]string, 0)
 		key = strings.TrimPrefix(key, "Bearer ")
 		key = strings.TrimPrefix(key, "Bearer ")

+ 4 - 0
middleware/distributor.go

@@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 		return nil, false, errors.New("无效的请求, " + err.Error())
 		return nil, false, errors.New("无效的请求, " + err.Error())
 	}
 	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
+		//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+		modelRequest.Model = c.Query("model")
+	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 		if modelRequest.Model == "" {
 		if modelRequest.Model == "" {
 			modelRequest.Model = "text-moderation-stable"
 			modelRequest.Model = "text-moderation-stable"

+ 3 - 3
relay/channel/adapter.go

@@ -12,13 +12,13 @@ type Adaptor interface {
 	// Init IsStream bool
 	// Init IsStream bool
 	Init(info *relaycommon.RelayInfo)
 	Init(info *relaycommon.RelayInfo)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
 	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
-	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
+	SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
 	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
 	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
 	ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
 	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
 	ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
-	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
-	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
+	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
 	GetModelList() []string
 	GetModelList() []string
 	GetChannelName() string
 	GetChannelName() string
 }
 }

+ 6 - 6
relay/channel/ali/adaptor.go

@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fullRequestURL, nil
 	return fullRequestURL, nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
 	if info.IsStream {
 	if info.IsStream {
-		req.Header.Set("X-DashScope-SSE", "enable")
+		req.Set("X-DashScope-SSE", "enable")
 	}
 	}
 	if c.GetString("plugin") != "" {
 	if c.GetString("plugin") != "" {
-		req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
+		req.Set("X-DashScope-Plugin", c.GetString("plugin"))
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 	return nil, errors.New("not implemented")
 	return nil, errors.New("not implemented")
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	switch info.RelayMode {
 	case constant.RelayModeImagesGenerations:
 	case constant.RelayModeImagesGenerations:
 		err, usage = aliImageHandler(c, resp, info)
 		err, usage = aliImageHandler(c, resp, info)

+ 30 - 6
relay/channel/api_request.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/relay/common"
 	"one-api/relay/common"
@@ -11,14 +12,16 @@ import (
 	"one-api/service"
 	"one-api/service"
 )
 )
 
 
-func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
+func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
 	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 		// multipart/form-data
 		// multipart/form-data
+	} else if info.RelayMode == constant.RelayModeRealtime {
+		// websocket
 	} else {
 	} else {
-		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Set("Accept", c.Request.Header.Get("Accept"))
 		if info.IsStream && c.Request.Header.Get("Accept") == "" {
 		if info.IsStream && c.Request.Header.Get("Accept") == "" {
-			req.Header.Set("Accept", "text/event-stream")
+			req.Set("Accept", "text/event-stream")
 		}
 		}
 	}
 	}
 }
 }
@@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("new request failed: %w", err)
 		return nil, fmt.Errorf("new request failed: %w", err)
 	}
 	}
-	err = a.SetupRequestHeader(c, req, info)
+	err = a.SetupRequestHeader(c, &req.Header, info)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 	}
 	}
@@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	// set form data
 	// set form data
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 
 
-	err = a.SetupRequestHeader(c, req, info)
+	err = a.SetupRequestHeader(c, &req.Header, info)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 		return nil, fmt.Errorf("setup request header failed: %w", err)
 	}
 	}
@@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
 	return resp, nil
 	return resp, nil
 }
 }
 
 
+func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	targetHeader := http.Header{}
+	err = a.SetupRequestHeader(c, &targetHeader, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
+	if err != nil {
+		return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
+	}
+	// send request body
+	//all, err := io.ReadAll(requestBody)
+	//err = service.WssString(c, targetConn, string(all))
+	return targetConn, nil
+}
+
 func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	resp, err := service.GetHttpClient().Do(req)
 	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
 	if err != nil {

+ 3 - 3
relay/channel/aws/adaptor.go

@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return "", nil
 	return "", nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	return nil
 	return nil
 }
 }
 
 
@@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 		err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 	} else {

+ 4 - 4
relay/channel/baidu/adaptor.go

@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fullRequestURL, nil
 	return fullRequestURL, nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = baiduStreamHandler(c, resp)
 		err, usage = baiduStreamHandler(c, resp)
 	} else {
 	} else {

+ 5 - 5
relay/channel/claude/adaptor.go

@@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("x-api-key", info.ApiKey)
+	req.Set("x-api-key", info.ApiKey)
 	anthropicVersion := c.Request.Header.Get("anthropic-version")
 	anthropicVersion := c.Request.Header.Get("anthropic-version")
 	if anthropicVersion == "" {
 	if anthropicVersion == "" {
 		anthropicVersion = "2023-06-01"
 		anthropicVersion = "2023-06-01"
 	}
 	}
-	req.Header.Set("anthropic-version", anthropicVersion)
+	req.Set("anthropic-version", anthropicVersion)
 	return nil
 	return nil
 }
 }
 
 
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
 		err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 	} else {

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 		}, nil
 		}, nil
 	}
 	}
 	fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
 	fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
+	completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}
 	}

+ 4 - 4
relay/channel/cloudflare/adaptor.go

@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 	return nil
 	return nil
 }
 }
 
 
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 	return nil, errors.New("not implemented")
 	return nil, errors.New("not implemented")
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 	case constant.RelayModeEmbeddings:
 		fallthrough
 		fallthrough

+ 1 - 1
relay/channel/cloudflare/relay_cloudflare.go

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 
 	return nil, usage
 	return nil, usage

+ 4 - 4
relay/channel/cohere/adaptor.go

@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 	return nil
 	return nil
 }
 }
 
 
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	return requestOpenAI2Cohere(*request), nil
 	return requestOpenAI2Cohere(*request), nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return requestConvertRerank2Cohere(request), nil
 	return requestConvertRerank2Cohere(request), nil
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = cohereRerankHandler(c, resp, info)
 		err, usage = cohereRerankHandler(c, resp, info)
 	} else {
 	} else {

+ 4 - 4
relay/channel/dify/adaptor.go

@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
 	return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = difyStreamHandler(c, resp, info)
 		err, usage = difyStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	}
 	if usage.TotalTokens == 0 {
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	}
 	return nil, usage
 	return nil, usage

+ 4 - 4
relay/channel/gemini/adaptor.go

@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("x-goog-api-key", info.ApiKey)
+	req.Set("x-goog-api-key", info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = GeminiChatStreamHandler(c, resp, info)
 		err, usage = GeminiChatStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 4 - 4
relay/channel/jina/adaptor.go

@@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return "", errors.New("invalid relay mode")
 	return "", errors.New("invalid relay mode")
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 	return nil
 	return nil
 }
 }
 
 
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	return request, nil
 	return request, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
@@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 	return request, nil
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = jinaRerankHandler(c, resp)
 		err, usage = jinaRerankHandler(c, resp)
 	} else if info.RelayMode == constant.RelayModeEmbeddings {
 	} else if info.RelayMode == constant.RelayModeEmbeddings {

+ 4 - 4
relay/channel/mistral/adaptor.go

@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
 	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -50,11 +50,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 3 - 3
relay/channel/ollama/adaptor.go

@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
 	return nil
 	return nil
 }
 }
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 39 - 8
relay/channel/openai/adaptor.go

@@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 }
 
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	if info.RelayMode == constant.RelayModeRealtime {
+		// trim https
+		baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
+		baseUrl = strings.TrimPrefix(baseUrl, "http://")
+		baseUrl = "wss://" + baseUrl
+		info.BaseUrl = baseUrl
+	}
 	switch info.ChannelType {
 	switch info.ChannelType {
 	case common.ChannelTypeAzure:
 	case common.ChannelTypeAzure:
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -40,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		model_ := info.UpstreamModelName
 		model_ := info.UpstreamModelName
 		model_ = strings.Replace(model_, ".", "", -1)
 		model_ = strings.Replace(model_, ".", "", -1)
 		// https://github.com/songquanpeng/one-api/issues/67
 		// https://github.com/songquanpeng/one-api/issues/67
-
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		if info.RelayMode == constant.RelayModeRealtime {
+			requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
+		}
 		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
 		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
 	case common.ChannelTypeMiniMax:
 	case common.ChannelTypeMiniMax:
 		return minimax.GetRequestURL(info)
 		return minimax.GetRequestURL(info)
@@ -54,16 +63,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
-	channel.SetupApiRequestHeader(info, c, req)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, header)
 	if info.ChannelType == common.ChannelTypeAzure {
 	if info.ChannelType == common.ChannelTypeAzure {
-		req.Header.Set("api-key", info.ApiKey)
+		header.Set("api-key", info.ApiKey)
 		return nil
 		return nil
 	}
 	}
 	if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
 	if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
-		req.Header.Set("OpenAI-Organization", info.Organization)
+		header.Set("OpenAI-Organization", info.Organization)
+	}
+	if info.RelayMode == constant.RelayModeRealtime {
+		swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
+		if swp != "" {
+			items := []string{
+				"realtime",
+				"openai-insecure-api-key." + info.ApiKey,
+				"openai-beta.realtime-v1",
+			}
+			header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
+			//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
+			//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
+			//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
+		} else {
+			header.Set("openai-beta", "realtime=v1")
+			header.Set("Authorization", "Bearer "+info.ApiKey)
+		}
+	} else {
+		header.Set("Authorization", "Bearer "+info.ApiKey)
 	}
 	}
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
 	//if info.ChannelType == common.ChannelTypeOpenRouter {
 	//if info.ChannelType == common.ChannelTypeOpenRouter {
 	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 	//	req.Header.Set("X-Title", "One API")
 	//	req.Header.Set("X-Title", "One API")
@@ -131,16 +158,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 	return request, nil
 	return request, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
 		return channel.DoFormRequest(a, c, info, requestBody)
 		return channel.DoFormRequest(a, c, info, requestBody)
+	} else if info.RelayMode == constant.RelayModeRealtime {
+		return channel.DoWssRequest(a, c, info, requestBody)
 	} else {
 	} else {
 		return channel.DoApiRequest(a, c, info, requestBody)
 		return channel.DoApiRequest(a, c, info, requestBody)
 	}
 	}
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	switch info.RelayMode {
+	case constant.RelayModeRealtime:
+		err, usage = OpenaiRealtimeHandler(c, info)
 	case constant.RelayModeAudioSpeech:
 	case constant.RelayModeAudioSpeech:
 		err, usage = OpenaiTTSHandler(c, resp, info)
 		err, usage = OpenaiTTSHandler(c, resp, info)
 	case constant.RelayModeAudioTranslation:
 	case constant.RelayModeAudioTranslation:

+ 210 - 2
relay/channel/openai/relay-openai.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -231,7 +232,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
+			ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
 			completionTokens += ctkm
 			completionTokens += ctkm
 		}
 		}
 		simpleResponse.Usage = dto.Usage{
 		simpleResponse.Usage = dto.Usage{
@@ -324,7 +325,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 
 
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 	return nil, usage
 }
 }
@@ -373,3 +374,210 @@ func getTextFromJSON(body []byte) (string, error) {
 	}
 	}
 	return whisperResponse.Text, nil
 	return whisperResponse.Text, nil
 }
 }
+
+func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
+	info.IsStream = true
+	clientConn := info.ClientWs
+	targetConn := info.TargetWs
+
+	clientClosed := make(chan struct{})
+	targetClosed := make(chan struct{})
+	sendChan := make(chan []byte, 100)
+	receiveChan := make(chan []byte, 100)
+	errChan := make(chan error, 2)
+
+	usage := &dto.RealtimeUsage{}
+	localUsage := &dto.RealtimeUsage{}
+	sumUsage := &dto.RealtimeUsage{}
+
+	gopool.Go(func() {
+		for {
+			select {
+			case <-c.Done():
+				return
+			default:
+				_, message, err := clientConn.ReadMessage()
+				if err != nil {
+					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+						errChan <- fmt.Errorf("error reading from client: %v", err)
+					}
+					close(clientClosed)
+					return
+				}
+
+				realtimeEvent := &dto.RealtimeEvent{}
+				err = json.Unmarshal(message, realtimeEvent)
+				if err != nil {
+					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+					return
+				}
+
+				if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
+					if realtimeEvent.Session != nil {
+						if realtimeEvent.Session.Tools != nil {
+							info.RealtimeTools = realtimeEvent.Session.Tools
+						}
+					}
+				}
+
+				textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+				if err != nil {
+					errChan <- fmt.Errorf("error counting text token: %v", err)
+					return
+				}
+				common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+				localUsage.TotalTokens += textToken + audioToken
+				localUsage.InputTokens += textToken + audioToken
+				localUsage.InputTokenDetails.TextTokens += textToken
+				localUsage.InputTokenDetails.AudioTokens += audioToken
+
+				err = service.WssString(c, targetConn, string(message))
+				if err != nil {
+					errChan <- fmt.Errorf("error writing to target: %v", err)
+					return
+				}
+
+				select {
+				case sendChan <- message:
+				default:
+				}
+			}
+		}
+	})
+
+	gopool.Go(func() {
+		for {
+			select {
+			case <-c.Done():
+				return
+			default:
+				_, message, err := targetConn.ReadMessage()
+				if err != nil {
+					if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+						errChan <- fmt.Errorf("error reading from target: %v", err)
+					}
+					close(targetClosed)
+					return
+				}
+				info.SetFirstResponseTime()
+				realtimeEvent := &dto.RealtimeEvent{}
+				err = json.Unmarshal(message, realtimeEvent)
+				if err != nil {
+					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+					return
+				}
+
+				if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
+					realtimeUsage := realtimeEvent.Response.Usage
+					if realtimeUsage != nil {
+						usage.TotalTokens += realtimeUsage.TotalTokens
+						usage.InputTokens += realtimeUsage.InputTokens
+						usage.OutputTokens += realtimeUsage.OutputTokens
+						usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
+						usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
+						usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
+						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
+						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
+						err := preConsumeUsage(c, info, usage, sumUsage)
+						if err != nil {
+							errChan <- fmt.Errorf("error consume usage: %v", err)
+							return
+						}
+						// 本次计费完成,清除
+						usage = &dto.RealtimeUsage{}
+
+						localUsage = &dto.RealtimeUsage{}
+					} else {
+						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+						if err != nil {
+							errChan <- fmt.Errorf("error counting text token: %v", err)
+							return
+						}
+						common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+						localUsage.TotalTokens += textToken + audioToken
+						info.IsFirstRequest = false
+						localUsage.InputTokens += textToken + audioToken
+						localUsage.InputTokenDetails.TextTokens += textToken
+						localUsage.InputTokenDetails.AudioTokens += audioToken
+						err = preConsumeUsage(c, info, localUsage, sumUsage)
+						if err != nil {
+							errChan <- fmt.Errorf("error consume usage: %v", err)
+							return
+						}
+						// 本次计费完成,清除
+						localUsage = &dto.RealtimeUsage{}
+						// print now usage
+					}
+					//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+					//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+					//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+
+				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
+					realtimeSession := realtimeEvent.Session
+					if realtimeSession != nil {
+						// update audio format
+						info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
+						info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
+					}
+				} else {
+					textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+					if err != nil {
+						errChan <- fmt.Errorf("error counting text token: %v", err)
+						return
+					}
+					common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+					localUsage.TotalTokens += textToken + audioToken
+					localUsage.OutputTokens += textToken + audioToken
+					localUsage.OutputTokenDetails.TextTokens += textToken
+					localUsage.OutputTokenDetails.AudioTokens += audioToken
+				}
+
+				err = service.WssString(c, clientConn, string(message))
+				if err != nil {
+					errChan <- fmt.Errorf("error writing to client: %v", err)
+					return
+				}
+
+				select {
+				case receiveChan <- message:
+				default:
+				}
+			}
+		}
+	})
+
+	select {
+	case <-clientClosed:
+	case <-targetClosed:
+	case err := <-errChan:
+		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
+		common.LogError(c, "realtime error: "+err.Error())
+	case <-c.Done():
+	}
+
+	if usage.TotalTokens != 0 {
+		_ = preConsumeUsage(c, info, usage, sumUsage)
+	}
+
+	if localUsage.TotalTokens != 0 {
+		_ = preConsumeUsage(c, info, localUsage, sumUsage)
+	}
+
+	// check usage total tokens, if 0, use local usage
+
+	return nil, sumUsage
+}
+
+func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
+	totalUsage.TotalTokens += usage.TotalTokens
+	totalUsage.InputTokens += usage.InputTokens
+	totalUsage.OutputTokens += usage.OutputTokens
+	totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
+	totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
+	totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
+	totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
+	totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
+	// clear usage
+	err := service.PreWssConsumeQuota(ctx, info, usage)
+	return err
+}

+ 4 - 4
relay/channel/palm/adaptor.go

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
 	return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("x-goog-api-key", info.ApiKey)
+	req.Set("x-goog-api-key", info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		var responseText string
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
 		err, responseText = palmStreamHandler(c, resp)

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 		}, nil
 	}
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		CompletionTokens: completionTokens,

+ 4 - 4
relay/channel/perplexity/adaptor.go

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
 	return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
 	return nil
 	return nil
 }
 }
 
 
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 4 - 4
relay/channel/siliconflow/adaptor.go

@@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return "", errors.New("invalid relay mode")
 	return "", errors.New("invalid relay mode")
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
 	return nil
 	return nil
 }
 }
 
 
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	return request, nil
 	return request, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return request, nil
 	return request, nil
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	switch info.RelayMode {
 	case constant.RelayModeRerank:
 	case constant.RelayModeRerank:
 		err, usage = siliconflowRerankHandler(c, resp)
 		err, usage = siliconflowRerankHandler(c, resp)

+ 7 - 7
relay/channel/tencent/adaptor.go

@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/", info.BaseUrl), nil
 	return fmt.Sprintf("%s/", info.BaseUrl), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
-	req.Header.Set("Authorization", a.Sign)
-	req.Header.Set("X-TC-Action", a.Action)
-	req.Header.Set("X-TC-Version", a.Version)
-	req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
+	req.Set("Authorization", a.Sign)
+	req.Set("X-TC-Action", a.Action)
+	req.Set("X-TC-Version", a.Version)
+	req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
 	return nil
 	return nil
 }
 }
 
 
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		var responseText string
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
 		err, responseText = tencentStreamHandler(c, resp)

+ 4 - 4
relay/channel/vertex/adaptor.go

@@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return "", errors.New("unsupported request mode")
 	return "", errors.New("unsupported request mode")
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
 	accessToken, err := getAccessToken(a, info)
 	accessToken, err := getAccessToken(a, info)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	req.Header.Set("Authorization", "Bearer "+accessToken)
+	req.Set("Authorization", "Bearer "+accessToken)
 	return nil
 	return nil
 }
 }
 
 
@@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		switch a.RequestMode {
 		switch a.RequestMode {
 		case RequestModeClaude:
 		case RequestModeClaude:

+ 3 - 3
relay/channel/xunfei/adaptor.go

@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return "", nil
 	return "", nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
 	return nil
 	return nil
 }
 }
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	// xunfei's request is not http request, so we don't need to do anything here
 	// xunfei's request is not http request, so we don't need to do anything here
 	dummyResp := &http.Response{}
 	dummyResp := &http.Response{}
 	dummyResp.StatusCode = http.StatusOK
 	dummyResp.StatusCode = http.StatusOK
 	return dummyResp, nil
 	return dummyResp, nil
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	splits := strings.Split(info.ApiKey, "|")
 	splits := strings.Split(info.ApiKey, "|")
 	if len(splits) != 3 {
 	if len(splits) != 3 {
 		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)

+ 4 - 4
relay/channel/zhipu/adaptor.go

@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
 	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
 	token := getZhipuToken(info.ApiKey)
 	token := getZhipuToken(info.ApiKey)
-	req.Header.Set("Authorization", token)
+	req.Set("Authorization", token)
 	return nil
 	return nil
 }
 }
 
 
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = zhipuStreamHandler(c, resp)
 		err, usage = zhipuStreamHandler(c, resp)
 	} else {
 	} else {

+ 4 - 4
relay/channel/zhipu_4v/adaptor.go

@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
 	return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
 }
 }
 
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
 	channel.SetupApiRequestHeader(info, c, req)
 	channel.SetupApiRequestHeader(info, c, req)
 	token := getZhipuToken(info.ApiKey)
 	token := getZhipuToken(info.ApiKey)
-	req.Header.Set("Authorization", token)
+	req.Set("Authorization", token)
 	return nil
 	return nil
 }
 }
 
 
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 	return channel.DoApiRequest(a, c, info, requestBody)
 }
 }
 
 
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 	if info.IsStream {
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 		err, usage = openai.OaiStreamHandler(c, resp, info)
 	} else {
 	} else {

+ 18 - 0
relay/common/relay_info.go

@@ -2,7 +2,9 @@ package common
 
 
 import (
 import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"one-api/common"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/relay/constant"
 	"one-api/relay/constant"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -21,6 +23,7 @@ type RelayInfo struct {
 	ApiType              int
 	ApiType              int
 	IsStream             bool
 	IsStream             bool
 	IsPlayground         bool
 	IsPlayground         bool
+	UsePrice             bool
 	RelayMode            int
 	RelayMode            int
 	UpstreamModelName    string
 	UpstreamModelName    string
 	OriginModelName      string
 	OriginModelName      string
@@ -32,6 +35,21 @@ type RelayInfo struct {
 	BaseUrl              string
 	BaseUrl              string
 	SupportStreamOptions bool
 	SupportStreamOptions bool
 	ShouldIncludeUsage   bool
 	ShouldIncludeUsage   bool
+	ClientWs             *websocket.Conn
+	TargetWs             *websocket.Conn
+	InputAudioFormat     string
+	OutputAudioFormat    string
+	RealtimeTools        []dto.RealTimeTool
+	IsFirstRequest       bool
+}
+
+func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.ClientWs = ws
+	info.InputAudioFormat = "pcm16"
+	info.OutputAudioFormat = "pcm16"
+	info.IsFirstRequest = true
+	return info
 }
 }
 
 
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 func GenRelayInfo(c *gin.Context) *RelayInfo {

+ 4 - 0
relay/constant/relay_mode.go

@@ -38,6 +38,8 @@ const (
 	RelayModeSunoSubmit
 	RelayModeSunoSubmit
 
 
 	RelayModeRerank
 	RelayModeRerank
+
+	RelayModeRealtime
 )
 )
 
 
 func Path2RelayMode(path string) int {
 func Path2RelayMode(path string) int {
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeAudioTranslation
 		relayMode = RelayModeAudioTranslation
 	} else if strings.HasPrefix(path, "/v1/rerank") {
 	} else if strings.HasPrefix(path, "/v1/rerank") {
 		relayMode = RelayModeRerank
 		relayMode = RelayModeRerank
+	} else if strings.HasPrefix(path, "/v1/realtime") {
+		relayMode = RelayModeRealtime
 	}
 	}
 	return relayMode
 	return relayMode
 }
 }

+ 14 - 8
relay/relay-audio.go

@@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	return audioRequest, nil
 	return audioRequest, nil
 }
 }
 
 
-func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 
 
@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
+		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
 		}
@@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 			return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 		}
 	}
 	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
 
 
 	// map model name
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	modelMapping := c.GetString("model_mapping")
@@ -122,19 +127,20 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 	}
-
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
+
+	var httpResp *http.Response
 	if resp != nil {
 	if resp != nil {
-		if resp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(resp)
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
 		}
 		}
 	}
 	}
 
 
-	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
 		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
@@ -142,7 +148,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 		return openaiErr
 	}
 	}
 
 
-	postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
+	postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
 
 
 	return nil
 	return nil
 }
 }

+ 7 - 5
relay/relay-image.go

@@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	requestBody = bytes.NewBuffer(jsonData)
 	requestBody = bytes.NewBuffer(jsonData)
 
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
+
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 	}
-
+	var httpResp *http.Response
 	if resp != nil {
 	if resp != nil {
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-		if resp.StatusCode != http.StatusOK {
-			openaiErr := service.RelayErrorHandler(resp)
+		httpResp = resp.(*http.Response)
+		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr := service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
 		}
 		}
 	}
 	}
 
 
-	_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)

+ 13 - 9
relay/relay-text.go

@@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
 	return textRequest, nil
 	return textRequest, nil
 }
 }
 
 
-func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
+func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 
@@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if openaiErr != nil {
 	if openaiErr != nil {
 		return openaiErr
 		return openaiErr
 	}
 	}
-
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
 	includeUsage := false
 	includeUsage := false
 	// 判断用户是否需要返回使用情况
 	// 判断用户是否需要返回使用情况
 	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
 	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
@@ -180,30 +184,30 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 	}
 
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
+	var httpResp *http.Response
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 	}
 
 
 	if resp != nil {
 	if resp != nil {
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-		if resp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(resp)
+		httpResp = resp.(*http.Response)
+		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
 		}
 		}
 	}
 	}
 
 
-	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
-		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr
 	}
 	}
-	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	return nil
 	return nil
 }
 }
 
 

+ 14 - 7
relay/relay_rerank.go

@@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 	return token
 	return token
 }
 }
 
 
-func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
+func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 
 	var rerankRequest *dto.RerankRequest
 	var rerankRequest *dto.RerankRequest
@@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if openaiErr != nil {
 	if openaiErr != nil {
 		return openaiErr
 		return openaiErr
 	}
 	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
@@ -99,23 +105,24 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
 	}
+
+	var httpResp *http.Response
 	if resp != nil {
 	if resp != nil {
-		if resp.StatusCode != http.StatusOK {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			openaiErr := service.RelayErrorHandler(resp)
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			openaiErr = service.RelayErrorHandler(httpResp)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 			return openaiErr
 			return openaiErr
 		}
 		}
 	}
 	}
 
 
-	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 	if openaiErr != nil {
-		returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr
 	}
 	}
-	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
 	return nil
 	return nil
 }
 }

+ 159 - 0
relay/websocket.go

@@ -0,0 +1,159 @@
+package relay
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
+//	_, p, err := ws.ReadMessage()
+//	if err != nil {
+//		return nil, err
+//	}
+//	realtimeEvent := &dto.RealtimeEvent{}
+//	err = json.Unmarshal(p, realtimeEvent)
+//	if err != nil {
+//		return nil, err
+//	}
+//	// save the original request
+//	if realtimeEvent.Session == nil {
+//		return nil, errors.New("session object is nil")
+//	}
+//	c.Set("first_wss_request", p)
+//	return realtimeEvent, nil
+//}
+
+func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
+
+	// get & validate textRequest 获取并验证文本请求
+	//realtimeEvent, err := getAndValidateWssRequest(c, ws)
+	//if err != nil {
+	//	common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
+	//	return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+	//}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	//isModelMapped := false
+	if modelMapping != "" && modelMapping != "{}" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[relayInfo.OriginModelName] != "" {
+			relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
+			// set upstream model name
+			//isModelMapped = true
+		}
+	}
+	//relayInfo.UpstreamModelName = textRequest.Model
+	modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+
+	var preConsumedQuota int
+	var ratio float64
+	var modelRatio float64
+	//err := service.SensitiveWordsCheck(textRequest)
+
+	//if constant.ShouldCheckPromptSensitive() {
+	//	err = checkRequestSensitive(textRequest, relayInfo)
+	//	if err != nil {
+	//		return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
+	//	}
+	//}
+
+	//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
+	//// count messages token error 计算promptTokens错误
+	//if err != nil {
+	//	return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
+	//}
+	//
+	if !getModelPriceSuccess {
+		preConsumedTokens := common.PreConsumedQuota
+		//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
+		//	preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
+		//}
+		modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
+		ratio = modelRatio * groupRatio
+		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+	} else {
+		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+		relayInfo.UsePrice = true
+	}
+
+	// pre-consume quota 预消耗配额
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+	//var requestBody io.Reader
+	//firstWssRequest, _ := c.Get("first_wss_request")
+	//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, nil)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	if resp != nil {
+		relayInfo.TargetWs = resp.(*websocket.Conn)
+		defer relayInfo.TargetWs.Close()
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
+	if openaiErr != nil {
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+	return nil
+}
+
+//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
+//	var promptTokens int
+//	var err error
+//	switch info.RelayMode {
+//	default:
+//		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
+//	}
+//	info.PromptTokens = promptTokens
+//	return promptTokens, err
+//}
+
+//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
+//	var err error
+//	switch info.RelayMode {
+//	case relayconstant.RelayModeChatCompletions:
+//		err = service.CheckSensitiveMessages(textRequest.Messages)
+//	case relayconstant.RelayModeCompletions:
+//		err = service.CheckSensitiveInput(textRequest.Prompt)
+//	case relayconstant.RelayModeModerations:
+//		err = service.CheckSensitiveInput(textRequest.Input)
+//	case relayconstant.RelayModeEmbeddings:
+//		err = service.CheckSensitiveInput(textRequest.Input)
+//	}
+//	return err
+//}

+ 34 - 25
router/relay-router.go

@@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 	}
 	}
 	relayV1Router := router.Group("/v1")
 	relayV1Router := router.Group("/v1")
-	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+	relayV1Router.Use(middleware.TokenAuth())
 	{
 	{
-		relayV1Router.POST("/completions", controller.Relay)
-		relayV1Router.POST("/chat/completions", controller.Relay)
-		relayV1Router.POST("/edits", controller.Relay)
-		relayV1Router.POST("/images/generations", controller.Relay)
-		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
-		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
-		relayV1Router.POST("/embeddings", controller.Relay)
-		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.Relay)
-		relayV1Router.POST("/audio/translations", controller.Relay)
-		relayV1Router.POST("/audio/speech", controller.Relay)
-		relayV1Router.GET("/files", controller.RelayNotImplemented)
-		relayV1Router.POST("/files", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
-		relayV1Router.POST("/moderations", controller.Relay)
-		relayV1Router.POST("/rerank", controller.Relay)
+		// WebSocket 路由
+		wsRouter := relayV1Router.Group("")
+		wsRouter.Use(middleware.Distribute())
+		wsRouter.GET("/realtime", controller.WssRelay)
+	}
+	{
+		//http router
+		httpRouter := relayV1Router.Group("")
+		httpRouter.Use(middleware.Distribute())
+		httpRouter.POST("/completions", controller.Relay)
+		httpRouter.POST("/chat/completions", controller.Relay)
+		httpRouter.POST("/edits", controller.Relay)
+		httpRouter.POST("/images/generations", controller.Relay)
+		httpRouter.POST("/images/edits", controller.RelayNotImplemented)
+		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
+		httpRouter.POST("/embeddings", controller.Relay)
+		httpRouter.POST("/engines/:model/embeddings", controller.Relay)
+		httpRouter.POST("/audio/transcriptions", controller.Relay)
+		httpRouter.POST("/audio/translations", controller.Relay)
+		httpRouter.POST("/audio/speech", controller.Relay)
+		httpRouter.GET("/files", controller.RelayNotImplemented)
+		httpRouter.POST("/files", controller.RelayNotImplemented)
+		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
+		httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
+		httpRouter.POST("/moderations", controller.Relay)
+		httpRouter.POST("/rerank", controller.Relay)
 	}
 	}
 
 
 	relayMjRouter := router.Group("/mj")
 	relayMjRouter := router.Group("/mj")

+ 31 - 0
service/audio.go

@@ -0,0 +1,31 @@
+package service
+
+import (
+	"encoding/base64"
+	"fmt"
+)
+
+func parseAudio(audioBase64 string, format string) (duration float64, err error) {
+	audioData, err := base64.StdEncoding.DecodeString(audioBase64)
+	if err != nil {
+		return 0, fmt.Errorf("base64 decode error: %v", err)
+	}
+
+	var samplesCount int
+	var sampleRate int
+
+	switch format {
+	case "pcm16":
+		samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
+		sampleRate = 24000                // 24kHz
+	case "g711_ulaw", "g711_alaw":
+		samplesCount = len(audioData) // 8位 = 1字节每样本
+		sampleRate = 8000             // 8kHz
+	default:
+		samplesCount = len(audioData) // 8位 = 1字节每样本
+		sampleRate = 8000             // 8kHz
+	}
+
+	duration = float64(samplesCount) / float64(sampleRate)
+	return duration, nil
+}

+ 11 - 0
service/log.go

@@ -2,6 +2,7 @@ package service
 
 
 import (
 import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 )
 )
 
 
@@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
 	other["admin_info"] = adminInfo
 	other["admin_info"] = adminInfo
 	return other
 	return other
 }
 }
+
+func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
+	info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
+	info["ws"] = true
+	info["audio_input"] = usage.InputTokenDetails.AudioTokens
+	info["audio_output"] = usage.OutputTokenDetails.AudioTokens
+	info["text_input"] = usage.InputTokenDetails.TextTokens
+	info["text_output"] = usage.OutputTokenDetails.TextTokens
+	return info
+}

+ 140 - 0
service/quota.go

@@ -0,0 +1,140 @@
+package service
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"math"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+	relaycommon "one-api/relay/common"
+	"strings"
+	"time"
+)
+
+func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
+	if relayInfo.UsePrice {
+		return nil
+	}
+	userQuota, err := model.GetUserQuota(relayInfo.UserId)
+	if err != nil {
+		return err
+	}
+	modelName := relayInfo.UpstreamModelName
+	textInputTokens := usage.InputTokenDetails.TextTokens
+	textOutTokens := usage.OutputTokenDetails.TextTokens
+	audioInputTokens := usage.InputTokenDetails.AudioTokens
+	audioOutTokens := usage.OutputTokenDetails.AudioTokens
+
+	completionRatio := common.GetCompletionRatio(modelName)
+	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	modelRatio := common.GetModelRatio(modelName)
+
+	ratio := groupRatio * modelRatio
+
+	quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
+	quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
+
+	quota = int(math.Round(float64(quota) * ratio))
+	if ratio != 0 && quota <= 0 {
+		quota = 1
+	}
+
+	if userQuota < quota {
+		return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
+	}
+
+	err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
+	if err != nil {
+		return err
+	}
+	common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
+	err = model.CacheUpdateUserQuota(relayInfo.UserId)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
+	usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
+	groupRatio float64,
+	modelPrice float64, usePrice bool, extraContent string) {
+
+	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+	textInputTokens := usage.InputTokenDetails.TextTokens
+	textOutTokens := usage.OutputTokenDetails.TextTokens
+
+	audioInputTokens := usage.InputTokenDetails.AudioTokens
+	audioOutTokens := usage.OutputTokenDetails.AudioTokens
+
+	tokenName := ctx.GetString("token_name")
+	completionRatio := common.GetCompletionRatio(modelName)
+	audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
+	audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
+
+	quota := 0
+	if !usePrice {
+		quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio))
+		quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio))
+		if ratio != 0 && quota <= 0 {
+			quota = 1
+		}
+	} else {
+		quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+	}
+	totalTokens := usage.TotalTokens
+	var logContent string
+	if !usePrice {
+		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
+	} else {
+		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+	}
+
+	// record all the consume log even if quota is 0
+	if totalTokens == 0 {
+		// in this case, must be some error happened
+		// we cannot just return, because we may have to return the pre-consumed quota
+		quota = 0
+		logContent += fmt.Sprintf("(可能是上游超时)")
+		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+	} else {
+		//if sensitiveResp != nil {
+		//	logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
+		//}
+		//quotaDelta := quota - preConsumedQuota
+		//if quotaDelta != 0 {
+		//	err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+		//	if err != nil {
+		//		common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+		//	}
+		//}
+
+		//err := model.CacheUpdateUserQuota(relayInfo.UserId)
+		//if err != nil {
+		//	common.LogError(ctx, "error update user quota cache: "+err.Error())
+		//}
+		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+	}
+
+	logModel := modelName
+	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
+		logModel = "gpt-4-gizmo-*"
+		logContent += fmt.Sprintf(",模型 %s", modelName)
+	}
+	if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
+		logModel = "gpt-4o-gizmo-*"
+		logContent += fmt.Sprintf(",模型 %s", modelName)
+	}
+	if extraContent != "" {
+		logContent += ", " + extraContent
+	}
+	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
+	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
+		tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
+}

+ 38 - 1
service/relay.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
@@ -42,11 +43,47 @@ func Done(c *gin.Context) {
 	_ = StringData(c, "[DONE]")
 	_ = StringData(c, "[DONE]")
 }
 }
 
 
+func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
+	if ws == nil {
+		common.LogError(c, "websocket connection is nil")
+		return errors.New("websocket connection is nil")
+	}
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
+	return ws.WriteMessage(1, []byte(str))
+}
+
+func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	if ws == nil {
+		common.LogError(c, "websocket connection is nil")
+		return errors.New("websocket connection is nil")
+	}
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
+	return ws.WriteMessage(1, jsonData)
+}
+
+func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
+	errorObj := &dto.RealtimeEvent{
+		Type:    "error",
+		EventId: GetLocalRealtimeID(c),
+		Error:   &openaiError,
+	}
+	_ = WssObject(c, ws, errorObj)
+}
+
 func GetResponseID(c *gin.Context) string {
 func GetResponseID(c *gin.Context) string {
-	logID := c.GetString("X-Oneapi-Request-Id")
+	logID := c.GetString(common.RequestIdKey)
 	return fmt.Sprintf("chatcmpl-%s", logID)
 	return fmt.Sprintf("chatcmpl-%s", logID)
 }
 }
 
 
+func GetLocalRealtimeID(c *gin.Context) string {
+	logID := c.GetString(common.RequestIdKey)
+	return fmt.Sprintf("evt_%s", logID)
+}
+
 func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
 func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
 	return &dto.ChatCompletionsStreamResponse{
 	return &dto.ChatCompletionsStreamResponse{
 		Id:                id,
 		Id:                id,

+ 101 - 6
service/token_counter.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	relaycommon "one-api/relay/common"
 	"strings"
 	"strings"
 	"unicode/utf8"
 	"unicode/utf8"
 )
 )
@@ -191,6 +192,72 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
 	return tkm, nil
 	return tkm, nil
 }
 }
 
 
+func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
+	audioToken := 0
+	textToken := 0
+	switch request.Type {
+	case dto.RealtimeEventTypeSessionUpdate:
+		if request.Session != nil {
+			msgTokens, err := CountTextToken(request.Session.Instructions, model)
+			if err != nil {
+				return 0, 0, err
+			}
+			textToken += msgTokens
+		}
+	case dto.RealtimeEventResponseAudioDelta:
+		// count audio token
+		atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
+		// count text token
+		tkm, err := CountTextToken(request.Delta, model)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting text token: %v", err)
+		}
+		textToken += tkm
+	case dto.RealtimeEventInputAudioBufferAppend:
+		// count audio token
+		atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventConversationItemCreated:
+		if request.Item != nil {
+			switch request.Item.Type {
+			case "message":
+				for _, content := range request.Item.Content {
+					if content.Type == "input_text" {
+						tokens, err := CountTextToken(content.Text, model)
+						if err != nil {
+							return 0, 0, err
+						}
+						textToken += tokens
+					}
+				}
+			}
+		}
+	case dto.RealtimeEventTypeResponseDone:
+		// count tools token
+		if !info.IsFirstRequest {
+			if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
+				for _, tool := range info.RealtimeTools {
+					toolTokens, err := CountTokenInput(tool, model)
+					if err != nil {
+						return 0, 0, err
+					}
+					textToken += 8
+					textToken += toolTokens
+				}
+			}
+		}
+	}
+	return textToken, audioToken, nil
+}
+
 func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
 func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
 	//recover when panic
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	tokenEncoder := getTokenEncoder(model)
@@ -248,13 +315,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
 func CountTokenInput(input any, model string) (int, error) {
 func CountTokenInput(input any, model string) (int, error) {
 	switch v := input.(type) {
 	switch v := input.(type) {
 	case string:
 	case string:
-		return CountTokenText(v, model)
+		return CountTextToken(v, model)
 	case []string:
 	case []string:
 		text := ""
 		text := ""
 		for _, s := range v {
 		for _, s := range v {
 			text += s
 			text += s
 		}
 		}
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
 	}
 	}
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
 }
@@ -276,16 +343,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 	return tokens
 }
 }
 
 
-func CountAudioToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) (int, error) {
 	if strings.HasPrefix(model, "tts") {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text), nil
 		return utf8.RuneCountInString(text), nil
 	} else {
 	} else {
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
 	}
 	}
 }
 }
 
 
-// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string) (int, error) {
+func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
+	}
+	return int(duration / 60 * 100 / 0.06), nil
+}
+
+func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
+	}
+	return int(duration / 60 * 200 / 0.24), nil
+}
+
+//func CountAudioToken(sec float64, audioType string) {
+//	if audioType == "input" {
+//
+//	}
+//}
+
+// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
+func CountTextToken(text string, model string) (int, error) {
 	var err error
 	var err error
 	tokenEncoder := getTokenEncoder(model)
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text), err
 	return getTokenNum(tokenEncoder, text), err

+ 1 - 1
service/usage_helpr.go

@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTokenText(responseText, modeName)
+	ctkm, err := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err
 	return usage, err

+ 80 - 29
web/src/components/LogsTable.js

@@ -11,7 +11,7 @@ import {
 
 
 import {
 import {
   Avatar,
   Avatar,
-  Button,
+  Button, Descriptions,
   Form,
   Form,
   Layout,
   Layout,
   Modal,
   Modal,
@@ -20,7 +20,7 @@ import {
   Spin,
   Spin,
   Table,
   Table,
   Tag,
   Tag,
-  Tooltip,
+  Tooltip
 } from '@douyinfe/semi-ui';
 } from '@douyinfe/semi-ui';
 import { ITEMS_PER_PAGE } from '../constants';
 import { ITEMS_PER_PAGE } from '../constants';
 import {
 import {
@@ -336,33 +336,33 @@ const LogsTable = () => {
         );
         );
       },
       },
     },
     },
-    {
-      title: '重试',
-      dataIndex: 'retry',
-      className: isAdmin() ? 'tableShow' : 'tableHiddle',
-      render: (text, record, index) => {
-        let content = '渠道:' + record.channel;
-        if (record.other !== '') {
-          let other = JSON.parse(record.other);
-          if (other === null) {
-            return <></>;
-          }
-          if (other.admin_info !== undefined) {
-            if (
-              other.admin_info.use_channel !== null &&
-              other.admin_info.use_channel !== undefined &&
-              other.admin_info.use_channel !== ''
-            ) {
-              // channel id array
-              let useChannel = other.admin_info.use_channel;
-              let useChannelStr = useChannel.join('->');
-              content = `渠道:${useChannelStr}`;
-            }
-          }
-        }
-        return isAdminUser ? <div>{content}</div> : <></>;
-      },
-    },
+    // {
+    //   title: '重试',
+    //   dataIndex: 'retry',
+    //   className: isAdmin() ? 'tableShow' : 'tableHiddle',
+    //   render: (text, record, index) => {
+    //     let content = '渠道:' + record.channel;
+    //     if (record.other !== '') {
+    //       let other = JSON.parse(record.other);
+    //       if (other === null) {
+    //         return <></>;
+    //       }
+    //       if (other.admin_info !== undefined) {
+    //         if (
+    //           other.admin_info.use_channel !== null &&
+    //           other.admin_info.use_channel !== undefined &&
+    //           other.admin_info.use_channel !== ''
+    //         ) {
+    //           // channel id array
+    //           let useChannel = other.admin_info.use_channel;
+    //           let useChannelStr = useChannel.join('->');
+    //           content = `渠道:${useChannelStr}`;
+    //         }
+    //       }
+    //     }
+    //     return isAdminUser ? <div>{content}</div> : <></>;
+    //   },
+    // },
     {
     {
       title: '详情',
       title: '详情',
       dataIndex: 'content',
       dataIndex: 'content',
@@ -409,6 +409,7 @@ const LogsTable = () => {
   ];
   ];
 
 
   const [logs, setLogs] = useState([]);
   const [logs, setLogs] = useState([]);
+  const [expandData, setExpandData] = useState({});
   const [showStat, setShowStat] = useState(false);
   const [showStat, setShowStat] = useState(false);
   const [loading, setLoading] = useState(false);
   const [loading, setLoading] = useState(false);
   const [loadingStat, setLoadingStat] = useState(false);
   const [loadingStat, setLoadingStat] = useState(false);
@@ -512,10 +513,54 @@ const LogsTable = () => {
   };
   };
 
 
   const setLogsFormat = (logs) => {
   const setLogsFormat = (logs) => {
+    let expandDatesLocal = {};
     for (let i = 0; i < logs.length; i++) {
     for (let i = 0; i < logs.length; i++) {
       logs[i].timestamp2string = timestamp2string(logs[i].created_at);
       logs[i].timestamp2string = timestamp2string(logs[i].created_at);
       logs[i].key = '' + logs[i].id;
       logs[i].key = '' + logs[i].id;
+      let other = getLogOther(logs[i].other);
+      let expandDataLocal = [];
+      if (isAdmin()) {
+        let content = '渠道:' + logs[i].channel;
+        if (other.admin_info !== undefined) {
+          if (
+            other.admin_info.use_channel !== null &&
+            other.admin_info.use_channel !== undefined &&
+            other.admin_info.use_channel !== ''
+          ) {
+            // channel id array
+            let useChannel = other.admin_info.use_channel;
+            let useChannelStr = useChannel.join('->');
+            content = `渠道:${useChannelStr}`;
+          }
+        }
+        expandDataLocal.push({
+          key: '重试',
+          value: content,
+        })
+      }
+      if (other.ws) {
+        expandDataLocal.push({
+          key: '语音输入',
+          value: other.audio_input,
+        });
+        expandDataLocal.push({
+          key: '语音输出',
+          value: other.audio_output,
+        });
+        expandDataLocal.push({
+          key: '文字输入',
+          value: other.text_input,
+        });
+        expandDataLocal.push({
+          key: '文字输出',
+          value: other.text_output,
+        });
+      }
+      expandDatesLocal[logs[i].key] = expandDataLocal;
     }
     }
+    console.log(expandDatesLocal);
+    setExpandData(expandDatesLocal);
+
     setLogs(logs);
     setLogs(logs);
   };
   };
 
 
@@ -588,6 +633,10 @@ const LogsTable = () => {
     handleEyeClick();
     handleEyeClick();
   }, []);
   }, []);
 
 
+  const expandRowRender = (record, index) => {
+    return <Descriptions align="justify" data={expandData[record.key]} />;
+  };
+
   return (
   return (
     <>
     <>
       <Layout>
       <Layout>
@@ -686,7 +735,9 @@ const LogsTable = () => {
         <Table
         <Table
           style={{ marginTop: 5 }}
           style={{ marginTop: 5 }}
           columns={columns}
           columns={columns}
+          expandedRowRender={expandRowRender}
           dataSource={logs}
           dataSource={logs}
+          rowKey="key"
           pagination={{
           pagination={{
             currentPage: activePage,
             currentPage: activePage,
             pageSize: pageSize,
             pageSize: pageSize,