stream_scanner.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package helper
  2. import (
  3. "bufio"
  4. "context"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. relaycommon "one-api/relay/common"
  10. "one-api/setting/operation_setting"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/bytedance/gopkg/util/gopool"
  15. "github.com/gin-gonic/gin"
  16. )
  17. const (
  18. InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
  19. MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
  20. DefaultPingInterval = 10 * time.Second
  21. )
  22. func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
  23. if resp == nil || dataHandler == nil {
  24. return
  25. }
  26. defer resp.Body.Close()
  27. streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
  28. if strings.HasPrefix(info.UpstreamModelName, "o") {
  29. // twice timeout for thinking model
  30. streamingTimeout *= 2
  31. }
  32. var (
  33. stopChan = make(chan bool, 2)
  34. scanner = bufio.NewScanner(resp.Body)
  35. ticker = time.NewTicker(streamingTimeout)
  36. pingTicker *time.Ticker
  37. writeMutex sync.Mutex // Mutex to protect concurrent writes
  38. )
  39. generalSettings := operation_setting.GetGeneralSetting()
  40. pingEnabled := generalSettings.PingIntervalEnabled
  41. pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
  42. if pingInterval <= 0 {
  43. pingInterval = DefaultPingInterval
  44. }
  45. if pingEnabled {
  46. pingTicker = time.NewTicker(pingInterval)
  47. }
  48. defer func() {
  49. ticker.Stop()
  50. if pingTicker != nil {
  51. pingTicker.Stop()
  52. }
  53. close(stopChan)
  54. }()
  55. scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
  56. scanner.Split(bufio.ScanLines)
  57. SetEventStreamHeaders(c)
  58. ctx, cancel := context.WithCancel(context.Background())
  59. defer cancel()
  60. ctx = context.WithValue(ctx, "stop_chan", stopChan)
  61. // Handle ping data sending
  62. if pingEnabled && pingTicker != nil {
  63. gopool.Go(func() {
  64. for {
  65. select {
  66. case <-pingTicker.C:
  67. writeMutex.Lock() // Lock before writing
  68. err := PingData(c)
  69. writeMutex.Unlock() // Unlock after writing
  70. if err != nil {
  71. common.LogError(c, "ping data error: "+err.Error())
  72. common.SafeSendBool(stopChan, true)
  73. return
  74. }
  75. if common.DebugEnabled {
  76. println("ping data sent")
  77. }
  78. case <-ctx.Done():
  79. if common.DebugEnabled {
  80. println("ping data goroutine stopped")
  81. }
  82. return
  83. }
  84. }
  85. })
  86. }
  87. common.RelayCtxGo(ctx, func() {
  88. for scanner.Scan() {
  89. ticker.Reset(streamingTimeout)
  90. data := scanner.Text()
  91. if common.DebugEnabled {
  92. println(data)
  93. }
  94. if len(data) < 6 {
  95. continue
  96. }
  97. if data[:5] != "data:" && data[:6] != "[DONE]" {
  98. continue
  99. }
  100. data = data[5:]
  101. data = strings.TrimLeft(data, " ")
  102. data = strings.TrimSuffix(data, "\r")
  103. if !strings.HasPrefix(data, "[DONE]") {
  104. info.SetFirstResponseTime()
  105. writeMutex.Lock() // Lock before writing
  106. success := dataHandler(data)
  107. writeMutex.Unlock() // Unlock after writing
  108. if !success {
  109. break
  110. }
  111. }
  112. }
  113. if err := scanner.Err(); err != nil {
  114. if err != io.EOF {
  115. common.LogError(c, "scanner error: "+err.Error())
  116. }
  117. }
  118. common.SafeSendBool(stopChan, true)
  119. })
  120. select {
  121. case <-ticker.C:
  122. // 超时处理逻辑
  123. common.LogError(c, "streaming timeout")
  124. common.SafeSendBool(stopChan, true)
  125. case <-stopChan:
  126. // 正常结束
  127. common.LogInfo(c, "streaming finished")
  128. }
  129. }