stream_scanner_test.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. package helper
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/QuantumNous/new-api/constant"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. "github.com/QuantumNous/new-api/setting/operation_setting"
  15. "github.com/gin-gonic/gin"
  16. "github.com/stretchr/testify/assert"
  17. "github.com/stretchr/testify/require"
  18. )
  19. func init() {
  20. gin.SetMode(gin.TestMode)
  21. }
  22. func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
  23. t.Helper()
  24. oldTimeout := constant.StreamingTimeout
  25. constant.StreamingTimeout = 30
  26. t.Cleanup(func() {
  27. constant.StreamingTimeout = oldTimeout
  28. })
  29. recorder := httptest.NewRecorder()
  30. c, _ := gin.CreateTestContext(recorder)
  31. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  32. resp := &http.Response{
  33. Body: io.NopCloser(body),
  34. }
  35. info := &relaycommon.RelayInfo{
  36. ChannelMeta: &relaycommon.ChannelMeta{},
  37. }
  38. return c, resp, info
  39. }
  40. func buildSSEBody(n int) string {
  41. var b strings.Builder
  42. for i := 0; i < n; i++ {
  43. fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
  44. }
  45. b.WriteString("data: [DONE]\n")
  46. return b.String()
  47. }
  48. // slowReader wraps a reader and injects a delay before each Read call,
  49. // simulating a slow upstream that trickles data.
  50. type slowReader struct {
  51. r io.Reader
  52. delay time.Duration
  53. }
  54. func (s *slowReader) Read(p []byte) (int, error) {
  55. time.Sleep(s.delay)
  56. return s.r.Read(p)
  57. }
  58. // ---------- Basic correctness ----------
  59. func TestStreamScannerHandler_NilInputs(t *testing.T) {
  60. t.Parallel()
  61. recorder := httptest.NewRecorder()
  62. c, _ := gin.CreateTestContext(recorder)
  63. c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
  64. info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
  65. StreamScannerHandler(c, nil, info, func(data string) bool { return true })
  66. StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
  67. }
  68. func TestStreamScannerHandler_EmptyBody(t *testing.T) {
  69. t.Parallel()
  70. c, resp, info := setupStreamTest(t, strings.NewReader(""))
  71. var called atomic.Bool
  72. StreamScannerHandler(c, resp, info, func(data string) bool {
  73. called.Store(true)
  74. return true
  75. })
  76. assert.False(t, called.Load(), "handler should not be called for empty body")
  77. }
  78. func TestStreamScannerHandler_1000Chunks(t *testing.T) {
  79. t.Parallel()
  80. const numChunks = 1000
  81. body := buildSSEBody(numChunks)
  82. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  83. var count atomic.Int64
  84. StreamScannerHandler(c, resp, info, func(data string) bool {
  85. count.Add(1)
  86. return true
  87. })
  88. assert.Equal(t, int64(numChunks), count.Load())
  89. assert.Equal(t, numChunks, info.ReceivedResponseCount)
  90. }
  91. func TestStreamScannerHandler_10000Chunks(t *testing.T) {
  92. t.Parallel()
  93. const numChunks = 10000
  94. body := buildSSEBody(numChunks)
  95. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  96. var count atomic.Int64
  97. start := time.Now()
  98. StreamScannerHandler(c, resp, info, func(data string) bool {
  99. count.Add(1)
  100. return true
  101. })
  102. elapsed := time.Since(start)
  103. assert.Equal(t, int64(numChunks), count.Load())
  104. assert.Equal(t, numChunks, info.ReceivedResponseCount)
  105. t.Logf("10000 chunks processed in %v", elapsed)
  106. }
  107. func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
  108. t.Parallel()
  109. const numChunks = 500
  110. body := buildSSEBody(numChunks)
  111. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  112. var mu sync.Mutex
  113. received := make([]string, 0, numChunks)
  114. StreamScannerHandler(c, resp, info, func(data string) bool {
  115. mu.Lock()
  116. received = append(received, data)
  117. mu.Unlock()
  118. return true
  119. })
  120. require.Equal(t, numChunks, len(received))
  121. for i := 0; i < numChunks; i++ {
  122. expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
  123. assert.Equal(t, expected, received[i], "chunk %d out of order", i)
  124. }
  125. }
  126. func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
  127. t.Parallel()
  128. body := buildSSEBody(50) + "data: should_not_appear\n"
  129. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  130. var count atomic.Int64
  131. StreamScannerHandler(c, resp, info, func(data string) bool {
  132. count.Add(1)
  133. return true
  134. })
  135. assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
  136. }
  137. func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
  138. t.Parallel()
  139. const numChunks = 200
  140. body := buildSSEBody(numChunks)
  141. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  142. const failAt = 50
  143. var count atomic.Int64
  144. StreamScannerHandler(c, resp, info, func(data string) bool {
  145. n := count.Add(1)
  146. return n < failAt
  147. })
  148. // The worker stops at failAt; the scanner may have read ahead,
  149. // but the handler should not be called beyond failAt.
  150. assert.Equal(t, int64(failAt), count.Load())
  151. }
  152. func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
  153. t.Parallel()
  154. var b strings.Builder
  155. b.WriteString(": comment line\n")
  156. b.WriteString("event: message\n")
  157. b.WriteString("id: 12345\n")
  158. b.WriteString("retry: 5000\n")
  159. for i := 0; i < 100; i++ {
  160. fmt.Fprintf(&b, "data: payload_%d\n", i)
  161. b.WriteString(": interleaved comment\n")
  162. }
  163. b.WriteString("data: [DONE]\n")
  164. c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
  165. var count atomic.Int64
  166. StreamScannerHandler(c, resp, info, func(data string) bool {
  167. count.Add(1)
  168. return true
  169. })
  170. assert.Equal(t, int64(100), count.Load())
  171. }
  172. func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
  173. t.Parallel()
  174. body := "data: {\"trimmed\":true} \ndata: [DONE]\n"
  175. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  176. var got string
  177. StreamScannerHandler(c, resp, info, func(data string) bool {
  178. got = data
  179. return true
  180. })
  181. assert.Equal(t, "{\"trimmed\":true}", got)
  182. }
  183. // ---------- Decoupling: scanner not blocked by slow handler ----------
  184. func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
  185. t.Parallel()
  186. // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
  187. // If the scanner were synchronously coupled to the handler, total time would be
  188. // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
  189. // With decoupling, total time should be closer to
  190. // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
  191. // because the scanner reads ahead into the buffer while the handler processes.
  192. const numChunks = 50
  193. const upstreamDelay = 10 * time.Millisecond
  194. const handlerDelay = 20 * time.Millisecond
  195. pr, pw := io.Pipe()
  196. go func() {
  197. defer pw.Close()
  198. for i := 0; i < numChunks; i++ {
  199. fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
  200. time.Sleep(upstreamDelay)
  201. }
  202. fmt.Fprint(pw, "data: [DONE]\n")
  203. }()
  204. recorder := httptest.NewRecorder()
  205. c, _ := gin.CreateTestContext(recorder)
  206. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  207. oldTimeout := constant.StreamingTimeout
  208. constant.StreamingTimeout = 30
  209. t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
  210. resp := &http.Response{Body: pr}
  211. info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
  212. var count atomic.Int64
  213. start := time.Now()
  214. done := make(chan struct{})
  215. go func() {
  216. StreamScannerHandler(c, resp, info, func(data string) bool {
  217. time.Sleep(handlerDelay)
  218. count.Add(1)
  219. return true
  220. })
  221. close(done)
  222. }()
  223. select {
  224. case <-done:
  225. case <-time.After(15 * time.Second):
  226. t.Fatal("StreamScannerHandler did not complete in time")
  227. }
  228. elapsed := time.Since(start)
  229. assert.Equal(t, int64(numChunks), count.Load())
  230. coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
  231. t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
  232. // If decoupled, elapsed should be well under the coupled estimate.
  233. assert.Less(t, elapsed, coupledTime*85/100,
  234. "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
  235. }
  236. func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
  237. t.Parallel()
  238. const numChunks = 50
  239. body := buildSSEBody(numChunks)
  240. reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
  241. c, resp, info := setupStreamTest(t, reader)
  242. var count atomic.Int64
  243. start := time.Now()
  244. done := make(chan struct{})
  245. go func() {
  246. StreamScannerHandler(c, resp, info, func(data string) bool {
  247. count.Add(1)
  248. return true
  249. })
  250. close(done)
  251. }()
  252. select {
  253. case <-done:
  254. case <-time.After(15 * time.Second):
  255. t.Fatal("timed out with slow upstream")
  256. }
  257. elapsed := time.Since(start)
  258. assert.Equal(t, int64(numChunks), count.Load())
  259. t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
  260. }
  261. // ---------- Ping tests ----------
  262. func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
  263. t.Parallel()
  264. setting := operation_setting.GetGeneralSetting()
  265. oldEnabled := setting.PingIntervalEnabled
  266. oldSeconds := setting.PingIntervalSeconds
  267. setting.PingIntervalEnabled = true
  268. setting.PingIntervalSeconds = 1
  269. t.Cleanup(func() {
  270. setting.PingIntervalEnabled = oldEnabled
  271. setting.PingIntervalSeconds = oldSeconds
  272. })
  273. // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
  274. // The ping interval is 1s, so we should see at least 2 pings.
  275. pr, pw := io.Pipe()
  276. go func() {
  277. defer pw.Close()
  278. for i := 0; i < 7; i++ {
  279. fmt.Fprintf(pw, "data: chunk_%d\n", i)
  280. time.Sleep(500 * time.Millisecond)
  281. }
  282. fmt.Fprint(pw, "data: [DONE]\n")
  283. }()
  284. recorder := httptest.NewRecorder()
  285. c, _ := gin.CreateTestContext(recorder)
  286. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  287. oldTimeout := constant.StreamingTimeout
  288. constant.StreamingTimeout = 30
  289. t.Cleanup(func() {
  290. constant.StreamingTimeout = oldTimeout
  291. })
  292. resp := &http.Response{Body: pr}
  293. info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
  294. var count atomic.Int64
  295. done := make(chan struct{})
  296. go func() {
  297. StreamScannerHandler(c, resp, info, func(data string) bool {
  298. count.Add(1)
  299. return true
  300. })
  301. close(done)
  302. }()
  303. select {
  304. case <-done:
  305. case <-time.After(15 * time.Second):
  306. t.Fatal("timed out waiting for stream to finish")
  307. }
  308. assert.Equal(t, int64(7), count.Load())
  309. body := recorder.Body.String()
  310. pingCount := strings.Count(body, ": PING")
  311. t.Logf("received %d pings in response body", pingCount)
  312. assert.GreaterOrEqual(t, pingCount, 2,
  313. "expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
  314. }
  315. func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
  316. t.Parallel()
  317. setting := operation_setting.GetGeneralSetting()
  318. oldEnabled := setting.PingIntervalEnabled
  319. oldSeconds := setting.PingIntervalSeconds
  320. setting.PingIntervalEnabled = true
  321. setting.PingIntervalSeconds = 1
  322. t.Cleanup(func() {
  323. setting.PingIntervalEnabled = oldEnabled
  324. setting.PingIntervalSeconds = oldSeconds
  325. })
  326. pr, pw := io.Pipe()
  327. go func() {
  328. defer pw.Close()
  329. for i := 0; i < 5; i++ {
  330. fmt.Fprintf(pw, "data: chunk_%d\n", i)
  331. time.Sleep(500 * time.Millisecond)
  332. }
  333. fmt.Fprint(pw, "data: [DONE]\n")
  334. }()
  335. recorder := httptest.NewRecorder()
  336. c, _ := gin.CreateTestContext(recorder)
  337. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  338. oldTimeout := constant.StreamingTimeout
  339. constant.StreamingTimeout = 30
  340. t.Cleanup(func() {
  341. constant.StreamingTimeout = oldTimeout
  342. })
  343. resp := &http.Response{Body: pr}
  344. info := &relaycommon.RelayInfo{
  345. DisablePing: true,
  346. ChannelMeta: &relaycommon.ChannelMeta{},
  347. }
  348. var count atomic.Int64
  349. done := make(chan struct{})
  350. go func() {
  351. StreamScannerHandler(c, resp, info, func(data string) bool {
  352. count.Add(1)
  353. return true
  354. })
  355. close(done)
  356. }()
  357. select {
  358. case <-done:
  359. case <-time.After(15 * time.Second):
  360. t.Fatal("timed out")
  361. }
  362. assert.Equal(t, int64(5), count.Load())
  363. body := recorder.Body.String()
  364. pingCount := strings.Count(body, ": PING")
  365. assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
  366. }
  367. func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
  368. t.Parallel()
  369. setting := operation_setting.GetGeneralSetting()
  370. oldEnabled := setting.PingIntervalEnabled
  371. oldSeconds := setting.PingIntervalSeconds
  372. setting.PingIntervalEnabled = true
  373. setting.PingIntervalSeconds = 1
  374. t.Cleanup(func() {
  375. setting.PingIntervalEnabled = oldEnabled
  376. setting.PingIntervalSeconds = oldSeconds
  377. })
  378. // Slow upstream + slow handler. Total stream takes ~5 seconds.
  379. // The ping goroutine stays alive as long as the scanner is reading,
  380. // so pings should fire between data writes.
  381. pr, pw := io.Pipe()
  382. go func() {
  383. defer pw.Close()
  384. for i := 0; i < 10; i++ {
  385. fmt.Fprintf(pw, "data: chunk_%d\n", i)
  386. time.Sleep(500 * time.Millisecond)
  387. }
  388. fmt.Fprint(pw, "data: [DONE]\n")
  389. }()
  390. recorder := httptest.NewRecorder()
  391. c, _ := gin.CreateTestContext(recorder)
  392. c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  393. oldTimeout := constant.StreamingTimeout
  394. constant.StreamingTimeout = 30
  395. t.Cleanup(func() {
  396. constant.StreamingTimeout = oldTimeout
  397. })
  398. resp := &http.Response{Body: pr}
  399. info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
  400. var count atomic.Int64
  401. done := make(chan struct{})
  402. go func() {
  403. StreamScannerHandler(c, resp, info, func(data string) bool {
  404. count.Add(1)
  405. return true
  406. })
  407. close(done)
  408. }()
  409. select {
  410. case <-done:
  411. case <-time.After(15 * time.Second):
  412. t.Fatal("timed out")
  413. }
  414. assert.Equal(t, int64(10), count.Load())
  415. body := recorder.Body.String()
  416. pingCount := strings.Count(body, ": PING")
  417. t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
  418. assert.GreaterOrEqual(t, pingCount, 3,
  419. "expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
  420. }