stream_scanner.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package helper
  2. import (
  3. "bufio"
  4. "context"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. relaycommon "one-api/relay/common"
  10. "strings"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
  15. streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
  16. if strings.HasPrefix(info.UpstreamModelName, "o1") || strings.HasPrefix(info.UpstreamModelName, "o3") {
  17. // twice timeout for thinking model
  18. streamingTimeout *= 2
  19. }
  20. var (
  21. stopChan = make(chan bool, 2)
  22. scanner = bufio.NewScanner(resp.Body)
  23. ticker = time.NewTicker(streamingTimeout)
  24. )
  25. defer func() {
  26. ticker.Stop()
  27. close(stopChan)
  28. }()
  29. scanner.Split(bufio.ScanLines)
  30. SetEventStreamHeaders(c)
  31. ctx, cancel := context.WithCancel(context.Background())
  32. defer cancel()
  33. ctx = context.WithValue(ctx, "stop_chan", stopChan)
  34. common.RelayCtxGo(ctx, func() {
  35. for scanner.Scan() {
  36. ticker.Reset(streamingTimeout)
  37. data := scanner.Text()
  38. if common.DebugEnabled {
  39. println(data)
  40. }
  41. if len(data) < 6 {
  42. continue
  43. }
  44. if data[:5] != "data:" && data[:6] != "[DONE]" {
  45. continue
  46. }
  47. data = data[5:]
  48. data = strings.TrimLeft(data, " ")
  49. data = strings.TrimSuffix(data, "\"")
  50. if !strings.HasPrefix(data, "[DONE]") {
  51. info.SetFirstResponseTime()
  52. success := dataHandler(data)
  53. if !success {
  54. break
  55. }
  56. }
  57. }
  58. if err := scanner.Err(); err != nil {
  59. if err != io.EOF {
  60. common.LogError(c, "scanner error: "+err.Error())
  61. }
  62. }
  63. common.SafeSendBool(stopChan, true)
  64. })
  65. select {
  66. case <-ticker.C:
  67. // 超时处理逻辑
  68. common.LogError(c, "streaming timeout")
  69. case <-stopChan:
  70. // 正常结束
  71. }
  72. }