stream_scanner.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package helper
  2. import (
  3. "bufio"
  4. "context"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. relaycommon "one-api/relay/common"
  10. "strings"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
  15. if resp == nil {
  16. return
  17. }
  18. defer resp.Body.Close()
  19. streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
  20. if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
  21. // twice timeout for thinking model
  22. streamingTimeout *= 2
  23. }
  24. var (
  25. stopChan = make(chan bool, 2)
  26. scanner = bufio.NewScanner(resp.Body)
  27. ticker = time.NewTicker(streamingTimeout)
  28. )
  29. defer func() {
  30. ticker.Stop()
  31. close(stopChan)
  32. }()
  33. scanner.Split(bufio.ScanLines)
  34. SetEventStreamHeaders(c)
  35. ctx, cancel := context.WithCancel(context.Background())
  36. defer cancel()
  37. ctx = context.WithValue(ctx, "stop_chan", stopChan)
  38. common.RelayCtxGo(ctx, func() {
  39. for scanner.Scan() {
  40. ticker.Reset(streamingTimeout)
  41. data := scanner.Text()
  42. if common.DebugEnabled {
  43. println(data)
  44. }
  45. if len(data) < 6 {
  46. continue
  47. }
  48. if data[:5] != "data:" && data[:6] != "[DONE]" {
  49. continue
  50. }
  51. data = data[5:]
  52. data = strings.TrimLeft(data, " ")
  53. data = strings.TrimSuffix(data, "\"")
  54. if !strings.HasPrefix(data, "[DONE]") {
  55. info.SetFirstResponseTime()
  56. success := dataHandler(data)
  57. if !success {
  58. break
  59. }
  60. }
  61. }
  62. if err := scanner.Err(); err != nil {
  63. if err != io.EOF {
  64. common.LogError(c, "scanner error: "+err.Error())
  65. }
  66. }
  67. common.SafeSendBool(stopChan, true)
  68. })
  69. select {
  70. case <-ticker.C:
  71. // 超时处理逻辑
  72. common.LogError(c, "streaming timeout")
  73. case <-stopChan:
  74. // 正常结束
  75. }
  76. }