Sfoglia il codice sorgente

Update dto

(cherry picked from commit 030187ff75c64c40017cda2fa98ef2b3c01f0bd5)
1808837298@qq.com 1 anno fa
parent
commit
e3c85572d4
6 ha cambiato i file con 182 aggiunte e 26 eliminazioni
  1. 57 0
      controller/relay.go
  2. 59 0
      dto/realtime.go
  3. 4 0
      middleware/distributor.go
  4. 4 0
      relay/constant/relay_mode.go
  5. 34 25
      router/relay-router.go
  6. 24 1
      service/relay.go

+ 57 - 0
controller/relay.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"io"
 	"log"
 	"net/http"
@@ -134,6 +135,62 @@ func Relay(c *gin.Context) {
 	}
 }
 
+var upgrader = websocket.Upgrader{
+	CheckOrigin: func(r *http.Request) bool {
+		return true // 允许跨域
+	},
+}
+
+func WssRelay(c *gin.Context) {
+	// 将 HTTP 连接升级为 WebSocket 连接
+	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	if err != nil {
+		openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
+		service.WssError(c, ws, openaiErr.Error)
+		return
+	}
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+	originalModel := c.GetString("original_model")
+	var openaiErr *dto.OpenAIErrorWithStatusCode
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
+			break
+		}
+
+		openaiErr = relayRequest(c, relayMode, channel)
+
+		if openaiErr == nil {
+			return // 成功处理请求,直接返回
+		}
+
+		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+
+		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if openaiErr != nil {
+		if openaiErr.StatusCode == http.StatusTooManyRequests {
+			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+		service.WssError(c, ws, openaiErr.Error)
+	}
+}
+
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
 	addUsedChannel(c, channel.Id)
 	requestBody, _ := common.GetRequestBody(c)

+ 59 - 0
dto/realtime.go

@@ -0,0 +1,59 @@
+package dto
+
+const (
+	RealtimeEventTypeError              = "error"
+	RealtimeEventTypeSessionUpdate      = "session.update"
+	RealtimeEventTypeConversationCreate = "conversation.item.create"
+	RealtimeEventTypeResponseCreate     = "response.create"
+)
+
+type RealtimeEvent struct {
+	EventId string `json:"event_id"`
+	Type    string `json:"type"`
+	//PreviousItemId string `json:"previous_item_id"`
+	Session *RealtimeSession `json:"session,omitempty"`
+	Item    *RealtimeItem    `json:"item,omitempty"`
+	Error   *OpenAIError     `json:"error,omitempty"`
+}
+
+type RealtimeSession struct {
+	Modalities              []string                `json:"modalities"`
+	Instructions            string                  `json:"instructions"`
+	Voice                   string                  `json:"voice"`
+	InputAudioFormat        string                  `json:"input_audio_format"`
+	OutputAudioFormat       string                  `json:"output_audio_format"`
+	InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
+	TurnDetection           interface{}             `json:"turn_detection"`
+	Tools                   []RealTimeTool          `json:"tools"`
+	ToolChoice              string                  `json:"tool_choice"`
+	Temperature             float64                 `json:"temperature"`
+	MaxResponseOutputTokens int                     `json:"max_response_output_tokens"`
+}
+
+type InputAudioTranscription struct {
+	Model string `json:"model"`
+}
+
+type RealTimeTool struct {
+	Type        string `json:"type"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  any    `json:"parameters"`
+}
+
+type RealtimeItem struct {
+	Id        string          `json:"id"`
+	Type      string          `json:"type"`
+	Status    string          `json:"status"`
+	Role      string          `json:"role"`
+	Content   RealtimeContent `json:"content"`
+	Name      *string         `json:"name,omitempty"`
+	ToolCalls any             `json:"tool_calls,omitempty"`
+	CallId    string          `json:"call_id,omitempty"`
+}
+type RealtimeContent struct {
+	Type       string `json:"type"`
+	Text       string `json:"text,omitempty"`
+	Audio      string `json:"audio,omitempty"` // Base64-encoded audio bytes.
+	Transcript string `json:"transcript,omitempty"`
+}

+ 4 - 0
middleware/distributor.go

@@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 		return nil, false, errors.New("无效的请求, " + err.Error())
 	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
+		//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+		modelRequest.Model = c.Query("model")
+	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 		if modelRequest.Model == "" {
 			modelRequest.Model = "text-moderation-stable"

+ 4 - 0
relay/constant/relay_mode.go

@@ -38,6 +38,8 @@ const (
 	RelayModeSunoSubmit
 
 	RelayModeRerank
+
+	RelayModeRealtime
 )
 
 func Path2RelayMode(path string) int {
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeAudioTranslation
 	} else if strings.HasPrefix(path, "/v1/rerank") {
 		relayMode = RelayModeRerank
+	} else if strings.HasPrefix(path, "/v1/realtime") {
+		relayMode = RelayModeRealtime
 	}
 	return relayMode
 }

+ 34 - 25
router/relay-router.go

@@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 	}
 	relayV1Router := router.Group("/v1")
-	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+	relayV1Router.Use(middleware.TokenAuth())
 	{
-		relayV1Router.POST("/completions", controller.Relay)
-		relayV1Router.POST("/chat/completions", controller.Relay)
-		relayV1Router.POST("/edits", controller.Relay)
-		relayV1Router.POST("/images/generations", controller.Relay)
-		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
-		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
-		relayV1Router.POST("/embeddings", controller.Relay)
-		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.Relay)
-		relayV1Router.POST("/audio/translations", controller.Relay)
-		relayV1Router.POST("/audio/speech", controller.Relay)
-		relayV1Router.GET("/files", controller.RelayNotImplemented)
-		relayV1Router.POST("/files", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
-		relayV1Router.POST("/moderations", controller.Relay)
-		relayV1Router.POST("/rerank", controller.Relay)
+		// WebSocket 路由
+		wsRouter := relayV1Router.Group("")
+		wsRouter.Use(middleware.Distribute())
+		wsRouter.GET("/realtime", controller.WssRelay)
+	}
+	{
+		//http router
+		httpRouter := relayV1Router.Group("")
+		httpRouter.Use(middleware.Distribute())
+		httpRouter.POST("/completions", controller.Relay)
+		httpRouter.POST("/chat/completions", controller.Relay)
+		httpRouter.POST("/edits", controller.Relay)
+		httpRouter.POST("/images/generations", controller.Relay)
+		httpRouter.POST("/images/edits", controller.RelayNotImplemented)
+		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
+		httpRouter.POST("/embeddings", controller.Relay)
+		httpRouter.POST("/engines/:model/embeddings", controller.Relay)
+		httpRouter.POST("/audio/transcriptions", controller.Relay)
+		httpRouter.POST("/audio/translations", controller.Relay)
+		httpRouter.POST("/audio/speech", controller.Relay)
+		httpRouter.GET("/files", controller.RelayNotImplemented)
+		httpRouter.POST("/files", controller.RelayNotImplemented)
+		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
+		httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
+		httpRouter.POST("/moderations", controller.Relay)
+		httpRouter.POST("/rerank", controller.Relay)
 	}
 
 	relayMjRouter := router.Group("/mj")

+ 24 - 1
service/relay.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
@@ -42,11 +43,33 @@ func Done(c *gin.Context) {
 	_ = StringData(c, "[DONE]")
 }
 
+func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	return ws.WriteMessage(1, jsonData)
+}
+
+func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
+	errorObj := &dto.RealtimeEvent{
+		Type:    "error",
+		EventId: GetLocalRealtimeID(c),
+		Error:   &openaiError,
+	}
+	_ = WssObject(c, ws, errorObj)
+}
+
 func GetResponseID(c *gin.Context) string {
-	logID := c.GetString("X-Oneapi-Request-Id")
+	logID := c.GetString(common.RequestIdKey)
 	return fmt.Sprintf("chatcmpl-%s", logID)
 }
 
+func GetLocalRealtimeID(c *gin.Context) string {
+	logID := c.GetString(common.RequestIdKey)
+	return fmt.Sprintf("evt_%s", logID)
+}
+
 func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
 	return &dto.ChatCompletionsStreamResponse{
 		Id:                id,