Procházet zdrojové kódy

refactor: extract FlushWriter function for improved stream flushing

CaIon před 6 měsíci
rodič
revize
c18414cbe4
2 změnil soubory, kde provedl 17 přidání a 32 odebrání
  1. 4 11
      relay/channel/openai/helper.go
  2. 13 21
      relay/helper/common.go

+ 4 - 11
relay/channel/openai/helper.go

@@ -2,9 +2,6 @@ package openai
 
 import (
 	"encoding/json"
-	"errors"
-	"github.com/samber/lo"
-	"net/http"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/logger"
@@ -15,6 +12,8 @@ import (
 	"one-api/types"
 	"strings"
 
+	"github.com/samber/lo"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -71,11 +70,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 
 	// send gemini format response
 	c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = helper.FlushWriter(c)
 	return nil
 }
 
@@ -253,9 +248,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 
 		// 发送最终的 Gemini 响应
 		c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
-		if flusher, ok := c.Writer.(http.Flusher); ok {
-			flusher.Flush()
-		}
+		_ = helper.FlushWriter(c)
 	}
 }
 

+ 13 - 21
relay/helper/common.go

@@ -14,6 +14,14 @@ import (
 	"github.com/gorilla/websocket"
 )
 
+func FlushWriter(c *gin.Context) error {
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+		return nil
+	}
+	return errors.New("streaming error: flusher not found")
+}
+
 func SetEventStreamHeaders(c *gin.Context) {
 	// 检查是否已经设置过头部
 	if _, exists := c.Get("event_stream_headers_set"); exists {
@@ -38,49 +46,33 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
 		c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 		c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
 	}
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }
 
 func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	}
+	_ = FlushWriter(c)
 }
 
 func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	}
+	_ = FlushWriter(c)
 }
 
 func StringData(c *gin.Context, str string) error {
 	//str = strings.TrimPrefix(str, "data: ")
 	//str = strings.TrimSuffix(str, "\r")
 	c.Render(-1, common.CustomEvent{Data: "data: " + str})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }
 
 func PingData(c *gin.Context) error {
 	c.Writer.Write([]byte(": PING\n\n"))
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }