stream_scanner.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package helper
  2. import (
  3. "bufio"
  4. "context"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. relaycommon "one-api/relay/common"
  10. "strings"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. )
  14. const (
  15. InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
  16. MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
  17. )
  18. func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
  19. if resp == nil {
  20. return
  21. }
  22. defer resp.Body.Close()
  23. streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
  24. if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
  25. // twice timeout for thinking model
  26. streamingTimeout *= 2
  27. }
  28. var (
  29. stopChan = make(chan bool, 2)
  30. scanner = bufio.NewScanner(resp.Body)
  31. ticker = time.NewTicker(streamingTimeout)
  32. )
  33. defer func() {
  34. ticker.Stop()
  35. close(stopChan)
  36. }()
  37. scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
  38. scanner.Split(bufio.ScanLines)
  39. SetEventStreamHeaders(c)
  40. ctx, cancel := context.WithCancel(context.Background())
  41. defer cancel()
  42. ctx = context.WithValue(ctx, "stop_chan", stopChan)
  43. common.RelayCtxGo(ctx, func() {
  44. for scanner.Scan() {
  45. ticker.Reset(streamingTimeout)
  46. data := scanner.Text()
  47. if common.DebugEnabled {
  48. println(data)
  49. }
  50. if len(data) < 6 {
  51. continue
  52. }
  53. if data[:5] != "data:" && data[:6] != "[DONE]" {
  54. continue
  55. }
  56. data = data[5:]
  57. data = strings.TrimLeft(data, " ")
  58. data = strings.TrimSuffix(data, "\"")
  59. if !strings.HasPrefix(data, "[DONE]") {
  60. info.SetFirstResponseTime()
  61. success := dataHandler(data)
  62. if !success {
  63. break
  64. }
  65. }
  66. }
  67. if err := scanner.Err(); err != nil {
  68. if err != io.EOF {
  69. common.LogError(c, "scanner error: "+err.Error())
  70. }
  71. }
  72. common.SafeSendBool(stopChan, true)
  73. })
  74. select {
  75. case <-ticker.C:
  76. // 超时处理逻辑
  77. common.LogError(c, "streaming timeout")
  78. case <-stopChan:
  79. // 正常结束
  80. }
  81. }