stream_scanner.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. package helper
  2. import (
  3. "bufio"
  4. "context"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "sync"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/logger"
  14. relaycommon "github.com/QuantumNous/new-api/relay/common"
  15. "github.com/QuantumNous/new-api/setting/operation_setting"
  16. "github.com/bytedance/gopkg/util/gopool"
  17. "github.com/gin-gonic/gin"
  18. )
  19. const (
  20. InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
  21. DefaultMaxScannerBufferSize = 64 << 20 // 64MB (64*1024*1024) default SSE buffer size
  22. DefaultPingInterval = 10 * time.Second
  23. )
  24. func getScannerBufferSize() int {
  25. if constant.StreamScannerMaxBufferMB > 0 {
  26. return constant.StreamScannerMaxBufferMB << 20
  27. }
  28. return DefaultMaxScannerBufferSize
  29. }
  30. func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string, sr *StreamResult)) {
  31. if resp == nil || dataHandler == nil {
  32. return
  33. }
  34. // 无条件新建 StreamStatus
  35. info.StreamStatus = relaycommon.NewStreamStatus()
  36. // 确保响应体总是被关闭
  37. defer func() {
  38. if resp.Body != nil {
  39. resp.Body.Close()
  40. }
  41. }()
  42. streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
  43. var (
  44. stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
  45. scanner = bufio.NewScanner(resp.Body)
  46. ticker = time.NewTicker(streamingTimeout)
  47. pingTicker *time.Ticker
  48. writeMutex sync.Mutex // Mutex to protect concurrent writes
  49. wg sync.WaitGroup // 用于等待所有 goroutine 退出
  50. )
  51. generalSettings := operation_setting.GetGeneralSetting()
  52. pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
  53. pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
  54. if pingInterval <= 0 {
  55. pingInterval = DefaultPingInterval
  56. }
  57. if pingEnabled {
  58. pingTicker = time.NewTicker(pingInterval)
  59. }
  60. if common.DebugEnabled {
  61. // print timeout and ping interval for debugging
  62. println("relay timeout seconds:", common.RelayTimeout)
  63. println("relay max idle conns:", common.RelayMaxIdleConns)
  64. println("relay max idle conns per host:", common.RelayMaxIdleConnsPerHost)
  65. println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
  66. println("ping interval seconds:", int64(pingInterval.Seconds()))
  67. }
  68. // 改进资源清理,确保所有 goroutine 正确退出
  69. defer func() {
  70. // 通知所有 goroutine 停止
  71. common.SafeSendBool(stopChan, true)
  72. ticker.Stop()
  73. if pingTicker != nil {
  74. pingTicker.Stop()
  75. }
  76. // 等待所有 goroutine 退出,最多等待5秒
  77. done := make(chan struct{})
  78. gopool.Go(func() {
  79. wg.Wait()
  80. close(done)
  81. })
  82. select {
  83. case <-done:
  84. case <-time.After(5 * time.Second):
  85. logger.LogError(c, "timeout waiting for goroutines to exit")
  86. }
  87. close(stopChan)
  88. }()
  89. scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize())
  90. scanner.Split(bufio.ScanLines)
  91. SetEventStreamHeaders(c)
  92. ctx, cancel := context.WithCancel(context.Background())
  93. defer cancel()
  94. ctx = context.WithValue(ctx, "stop_chan", stopChan)
  95. // Handle ping data sending with improved error handling
  96. if pingEnabled && pingTicker != nil {
  97. wg.Add(1)
  98. gopool.Go(func() {
  99. defer func() {
  100. wg.Done()
  101. if r := recover(); r != nil {
  102. logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
  103. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("ping panic: %v", r))
  104. common.SafeSendBool(stopChan, true)
  105. }
  106. if common.DebugEnabled {
  107. println("ping goroutine exited")
  108. }
  109. }()
  110. // 添加超时保护,防止 goroutine 无限运行
  111. maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
  112. pingTimeout := time.NewTimer(maxPingDuration)
  113. defer pingTimeout.Stop()
  114. for {
  115. select {
  116. case <-pingTicker.C:
  117. // 使用超时机制防止写操作阻塞
  118. done := make(chan error, 1)
  119. gopool.Go(func() {
  120. writeMutex.Lock()
  121. defer writeMutex.Unlock()
  122. done <- PingData(c)
  123. })
  124. select {
  125. case err := <-done:
  126. if err != nil {
  127. logger.LogError(c, "ping data error: "+err.Error())
  128. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, err)
  129. return
  130. }
  131. if common.DebugEnabled {
  132. println("ping data sent")
  133. }
  134. case <-time.After(10 * time.Second):
  135. logger.LogError(c, "ping data send timeout")
  136. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPingFail, fmt.Errorf("ping send timeout"))
  137. return
  138. case <-ctx.Done():
  139. return
  140. case <-stopChan:
  141. return
  142. }
  143. case <-ctx.Done():
  144. return
  145. case <-stopChan:
  146. return
  147. case <-c.Request.Context().Done():
  148. // 监听客户端断开连接
  149. return
  150. case <-pingTimeout.C:
  151. logger.LogError(c, "ping goroutine max duration reached")
  152. return
  153. }
  154. }
  155. })
  156. }
  157. dataChan := make(chan string, 10)
  158. wg.Add(1)
  159. gopool.Go(func() {
  160. defer func() {
  161. wg.Done()
  162. if r := recover(); r != nil {
  163. logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
  164. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("handler panic: %v", r))
  165. }
  166. common.SafeSendBool(stopChan, true)
  167. }()
  168. sr := newStreamResult(info.StreamStatus)
  169. for data := range dataChan {
  170. sr.reset()
  171. writeMutex.Lock()
  172. dataHandler(data, sr)
  173. writeMutex.Unlock()
  174. if sr.IsStopped() {
  175. return
  176. }
  177. }
  178. })
  179. // Scanner goroutine with improved error handling
  180. wg.Add(1)
  181. common.RelayCtxGo(ctx, func() {
  182. defer func() {
  183. close(dataChan)
  184. wg.Done()
  185. if r := recover(); r != nil {
  186. logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
  187. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonPanic, fmt.Errorf("scanner panic: %v", r))
  188. }
  189. common.SafeSendBool(stopChan, true)
  190. if common.DebugEnabled {
  191. println("scanner goroutine exited")
  192. }
  193. }()
  194. for scanner.Scan() {
  195. // 检查是否需要停止
  196. select {
  197. case <-stopChan:
  198. return
  199. case <-ctx.Done():
  200. return
  201. case <-c.Request.Context().Done():
  202. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
  203. return
  204. default:
  205. }
  206. ticker.Reset(streamingTimeout)
  207. data := scanner.Text()
  208. if common.DebugEnabled {
  209. println(data)
  210. }
  211. if len(data) < 6 {
  212. continue
  213. }
  214. if data[:5] != "data:" && data[:6] != "[DONE]" {
  215. continue
  216. }
  217. data = data[5:]
  218. data = strings.TrimSpace(data)
  219. if data == "" {
  220. continue
  221. }
  222. if !strings.HasPrefix(data, "[DONE]") {
  223. info.SetFirstResponseTime()
  224. info.ReceivedResponseCount++
  225. select {
  226. case dataChan <- data:
  227. case <-ctx.Done():
  228. return
  229. case <-stopChan:
  230. return
  231. }
  232. } else {
  233. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonDone, nil)
  234. if common.DebugEnabled {
  235. println("received [DONE], stopping scanner")
  236. }
  237. return
  238. }
  239. }
  240. if err := scanner.Err(); err != nil {
  241. if err != io.EOF {
  242. logger.LogError(c, "scanner error: "+err.Error())
  243. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonScannerErr, err)
  244. }
  245. }
  246. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonEOF, nil)
  247. })
  248. // 主循环等待完成或超时
  249. select {
  250. case <-ticker.C:
  251. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonTimeout, nil)
  252. case <-stopChan:
  253. // EndReason already set by the goroutine that triggered stopChan
  254. case <-c.Request.Context().Done():
  255. info.StreamStatus.SetEndReason(relaycommon.StreamEndReasonClientGone, c.Request.Context().Err())
  256. }
  257. if info.StreamStatus.IsNormalEnd() && !info.StreamStatus.HasErrors() {
  258. logger.LogInfo(c, fmt.Sprintf("stream ended: %s", info.StreamStatus.Summary()))
  259. } else {
  260. logger.LogError(c, fmt.Sprintf("stream ended: %s, received=%d", info.StreamStatus.Summary(), info.ReceivedResponseCount))
  261. }
  262. }