Browse Source

fix: 流式请求ping

creamlike1024 9 months ago
parent
commit
c51a30b862
1 changed files with 66 additions and 48 deletions
  1. 66 48
      relay/channel/api_request.go

+ 66 - 48
relay/channel/api_request.go

@@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return targetConn, nil
 }
 
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+	pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+	gopool.Go(func() {
+		defer func() {
+			if common2.DebugEnabled {
+				println("SSE ping goroutine stopped.")
+			}
+		}()
+
+		if pingInterval <= 0 {
+			pingInterval = helper.DefaultPingInterval
+		}
+
+		ticker := time.NewTicker(pingInterval)
+		// 退出时清理 ticker
+		defer ticker.Stop()
+
+		var pingMutex sync.Mutex
+		if common2.DebugEnabled {
+			println("SSE ping goroutine started")
+		}
+
+		for {
+			select {
+			// 发送 ping 数据
+			case <-ticker.C:
+				if err := sendPingData(c, &pingMutex); err != nil {
+					return
+				}
+			// 收到退出信号
+			case <-pingerCtx.Done():
+				return
+			// request 结束
+			case <-c.Request.Context().Done():
+				return
+			}
+		}
+	})
+
+	return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+	mutex.Lock()
+	defer mutex.Unlock()
+
+	err := helper.PingData(c)
+	if err != nil {
+		common2.LogError(c, "SSE ping error: "+err.Error())
+		return err
+	}
+
+	if common2.DebugEnabled {
+		println("SSE ping data sent.")
+	}
+	return nil
+}
+
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error
@@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
-	// 流式请求 ping 保活
-	var stopPinger func()
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	var pingerWg sync.WaitGroup
+
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
 
-		if pingEnabled {
+		// 处理流式请求的 ping 保活
+		generalSettings := operation_setting.GetGeneralSetting()
+		if generalSettings.PingIntervalEnabled {
 			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-			var pingerCtx context.Context
-			pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
-			// 退出时清理 pingerCtx 防止泄露
+			stopPinger := startPingKeepAlive(c, pingInterval)
 			defer stopPinger()
-			pingerWg.Add(1)
-			gopool.Go(func() {
-				defer pingerWg.Done()
-				if pingInterval <= 0 {
-					pingInterval = helper.DefaultPingInterval
-				}
-
-				ticker := time.NewTicker(pingInterval)
-				defer ticker.Stop()
-				var pingMutex sync.Mutex
-				if common2.DebugEnabled {
-					println("SSE ping goroutine started")
-				}
-
-				for {
-					select {
-					case <-ticker.C:
-						pingMutex.Lock()
-						err2 := helper.PingData(c)
-						pingMutex.Unlock()
-						if err2 != nil {
-							common2.LogError(c, "SSE ping error: "+err.Error())
-							return
-						}
-						if common2.DebugEnabled {
-							println("SSE ping data sent.")
-						}
-					case <-pingerCtx.Done():
-						if common2.DebugEnabled {
-							println("SSE ping goroutine stopped.")
-						}
-						return
-					}
-				}
-			})
 		}
 	}
 
 	resp, err := client.Do(req)
-	// request结束后等待 ping goroutine 完成
-	if info.IsStream && pingEnabled {
-		pingerWg.Wait()
-	}
+
 	if err != nil {
 		return nil, err
 	}
 	if resp == nil {
 		return nil, errors.New("resp is nil")
 	}
+
 	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil