stream_scanner_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. package helper
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "sync"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/QuantumNous/new-api/constant"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. "github.com/QuantumNous/new-api/setting/operation_setting"
  15. "github.com/gin-gonic/gin"
  16. "github.com/stretchr/testify/assert"
  17. "github.com/stretchr/testify/require"
  18. )
  19. func init() {
  20. gin.SetMode(gin.TestMode)
  21. }
  // setupStreamTest builds the fixtures shared by most tests in this file:
// a gin test context (POST /v1/chat/completions) backed by an httptest
// recorder, an *http.Response whose Body is the given reader wrapped in
// io.NopCloser, and a RelayInfo with an empty ChannelMeta. It also pins
// constant.StreamingTimeout to 30 for the duration of the test and restores
// the previous value through t.Cleanup.
//
// NOTE(review): callers invoke this after t.Parallel(), so several tests
// write the package-level constant.StreamingTimeout concurrently. The value
// written is always 30, but the accesses are unsynchronized and may be
// reported by `go test -race`.
func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
	t.Helper()

	prevTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 30
	t.Cleanup(func() { constant.StreamingTimeout = prevTimeout })

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	return ctx,
		&http.Response{Body: io.NopCloser(body)},
		&relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
}
  // buildSSEBody returns an SSE payload made of n "data:" lines, one for each
// id in [0, n), each carrying {"id":id,"choices":[{"delta":{"content":"token_id"}}]}.
// The payload ends with a terminating "data: [DONE]" line. Every line ends
// in a single '\n'. A non-positive n yields only the [DONE] line.
func buildSSEBody(n int) string {
	var sb strings.Builder
	for id := 0; id < n; id++ {
		sb.WriteString("data: ")
		sb.WriteString(fmt.Sprintf(`{"id":%d,"choices":[{"delta":{"content":"token_%d"}}]}`, id, id))
		sb.WriteByte('\n')
	}
	sb.WriteString("data: [DONE]\n")
	return sb.String()
}
  48. type slowReader struct {
  49. r io.Reader
  50. delay time.Duration
  51. }
  52. func (s *slowReader) Read(p []byte) (int, error) {
  53. time.Sleep(s.delay)
  54. return s.r.Read(p)
  55. }
  56. // ---------- Basic correctness ----------
  // TestStreamScannerHandler_NilInputs is a smoke test: StreamScannerHandler
// must return without panicking when given a nil *http.Response, and again
// when given a nil data handler. Any panic fails the test.
func TestStreamScannerHandler_NilInputs(t *testing.T) {
	t.Parallel()

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}

	// Nil response, valid handler.
	StreamScannerHandler(ctx, nil, info, func(string, *StreamResult) {})

	// Valid (empty) response, nil handler.
	emptyResp := &http.Response{Body: io.NopCloser(strings.NewReader(""))}
	StreamScannerHandler(ctx, emptyResp, info, nil)
}
  // TestStreamScannerHandler_EmptyBody checks that an upstream body containing
// no bytes at all never reaches the data handler.
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
	t.Parallel()

	var invoked atomic.Bool
	c, resp, info := setupStreamTest(t, strings.NewReader(""))
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) { invoked.Store(true) })

	assert.False(t, invoked.Load(), "handler should not be called for empty body")
}
  // TestStreamScannerHandler_1000Chunks feeds a 1000-chunk SSE body and
// expects every chunk to reach the handler and to be reflected in
// info.ReceivedResponseCount.
func TestStreamScannerHandler_1000Chunks(t *testing.T) {
	t.Parallel()
	const numChunks = 1000

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(numChunks)))

	var delivered atomic.Int64
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) { delivered.Add(1) })

	assert.Equal(t, int64(numChunks), delivered.Load())
	assert.Equal(t, numChunks, info.ReceivedResponseCount)
}
  // TestStreamScannerHandler_10000Chunks is the larger variant of the
// chunk-count test: all 10000 chunks must be delivered and counted. The wall
// time is only logged, not asserted.
func TestStreamScannerHandler_10000Chunks(t *testing.T) {
	t.Parallel()
	const numChunks = 10000

	payload := buildSSEBody(numChunks)
	c, resp, info := setupStreamTest(t, strings.NewReader(payload))

	var delivered atomic.Int64
	began := time.Now()
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) { delivered.Add(1) })
	took := time.Since(began)

	assert.Equal(t, int64(numChunks), delivered.Load())
	assert.Equal(t, numChunks, info.ReceivedResponseCount)
	t.Logf("10000 chunks processed in %v", took)
}
  // TestStreamScannerHandler_OrderPreserved verifies that chunks reach the
// handler in exactly the order they appear upstream, with the "data: "
// prefix stripped.
func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
	t.Parallel()
	const numChunks = 500

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(numChunks)))

	var (
		mu       sync.Mutex
		received = make([]string, 0, numChunks)
	)
	StreamScannerHandler(c, resp, info, func(data string, _ *StreamResult) {
		mu.Lock()
		defer mu.Unlock()
		received = append(received, data)
	})

	require.Len(t, received, numChunks)
	for i, got := range received {
		want := fmt.Sprintf(`{"id":%d,"choices":[{"delta":{"content":"token_%d"}}]}`, i, i)
		assert.Equal(t, want, got, "chunk %d out of order", i)
	}
}
  // TestStreamScannerHandler_DoneStopsScanner checks that nothing after the
// "data: [DONE]" terminator is handed to the handler.
func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
	t.Parallel()

	// 50 real chunks, then [DONE], then a trailing line that must be ignored.
	payload := buildSSEBody(50) + "data: should_not_appear\n"
	c, resp, info := setupStreamTest(t, strings.NewReader(payload))

	var seen atomic.Int64
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) { seen.Add(1) })

	assert.Equal(t, int64(50), seen.Load(), "data after [DONE] must not be processed")
}
  // TestStreamScannerHandler_StopStopsStream has the handler call sr.Stop on
// the 50th of 200 chunks. It expects no further handler calls and an
// EndReason of StreamEndReasonHandlerStop.
func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
	t.Parallel()
	const (
		numChunks       = 200
		stopAt    int64 = 50
	)

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(numChunks)))

	var calls atomic.Int64
	StreamScannerHandler(c, resp, info, func(_ string, sr *StreamResult) {
		if n := calls.Add(1); n >= stopAt {
			sr.Stop(fmt.Errorf("fatal at %d", n))
		}
	})

	assert.Equal(t, stopAt, calls.Load())
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
}
  // TestStreamScannerHandler_SkipsNonDataLines mixes SSE comments and
// event/id/retry fields with 100 data lines. It expects only the data lines
// to reach the handler.
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
	t.Parallel()
	const numData = 100

	var sb strings.Builder
	// Non-data SSE lines preceding the payload.
	for _, line := range []string{": comment line\n", "event: message\n", "id: 12345\n", "retry: 5000\n"} {
		sb.WriteString(line)
	}
	// Each data line is followed by a comment line.
	for i := 0; i < numData; i++ {
		sb.WriteString(fmt.Sprintf("data: payload_%d\n", i))
		sb.WriteString(": interleaved comment\n")
	}
	sb.WriteString("data: [DONE]\n")

	c, resp, info := setupStreamTest(t, strings.NewReader(sb.String()))

	var dataLines atomic.Int64
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) { dataLines.Add(1) })

	assert.Equal(t, int64(numData), dataLines.Load())
}
  // TestStreamScannerHandler_DataWithExtraSpaces checks that trailing
// whitespace after a data payload is trimmed before the handler sees it.
func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
	t.Parallel()

	const payload = "data: {\"trimmed\":true} \ndata: [DONE]\n"
	c, resp, info := setupStreamTest(t, strings.NewReader(payload))

	var last string
	StreamScannerHandler(c, resp, info, func(data string, _ *StreamResult) { last = data })

	assert.Equal(t, `{"trimmed":true}`, last)
}
  176. // ---------- Decoupling ----------
  // TestStreamScannerHandler_ScannerDecoupledFromSlowHandler checks that
// reading from upstream overlaps with a slow data handler. 50 chunks arrive
// 10ms apart while the handler spends 20ms on each one. The total wall time
// must stay below 85% of the fully serialized 50 x (10ms + 20ms) estimate.
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
	t.Parallel()
	const (
		numChunks     = 50
		upstreamDelay = 10 * time.Millisecond
		handlerDelay  = 20 * time.Millisecond
	)

	// Upstream: numChunks lines, upstreamDelay apart, then [DONE].
	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		for id := 0; id < numChunks; id++ {
			fmt.Fprintf(pw, "data: {\"id\":%d}\n", id)
			time.Sleep(upstreamDelay)
		}
		fmt.Fprint(pw, "data: [DONE]\n")
	}()

	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	prevTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 30
	t.Cleanup(func() { constant.StreamingTimeout = prevTimeout })

	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}

	var handled atomic.Int64
	start := time.Now()
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		StreamScannerHandler(ctx, resp, info, func(string, *StreamResult) {
			time.Sleep(handlerDelay)
			handled.Add(1)
		})
	}()

	select {
	case <-finished:
	case <-time.After(15 * time.Second):
		t.Fatal("StreamScannerHandler did not complete in time")
	}
	elapsed := time.Since(start)

	assert.Equal(t, int64(numChunks), handled.Load())

	// Serialized read+handle would cost upstreamDelay+handlerDelay per chunk.
	coupledTime := numChunks * (upstreamDelay + handlerDelay)
	t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
	assert.Less(t, elapsed, coupledTime*85/100,
		"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
}
  // TestStreamScannerHandler_SlowUpstreamFastHandler wraps a 50-chunk body in
// a slowReader that sleeps 2ms per Read. It expects every chunk to arrive
// within the 15s guard. The elapsed time is only logged.
func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
	t.Parallel()
	const numChunks = 50

	upstream := &slowReader{r: strings.NewReader(buildSSEBody(numChunks)), delay: 2 * time.Millisecond}
	c, resp, info := setupStreamTest(t, upstream)

	var handled atomic.Int64
	start := time.Now()
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		StreamScannerHandler(c, resp, info, func(string, *StreamResult) { handled.Add(1) })
	}()

	select {
	case <-finished:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out with slow upstream")
	}
	elapsed := time.Since(start)

	assert.Equal(t, int64(numChunks), handled.Load())
	t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
}
  245. // ---------- Ping tests ----------
  // TestStreamScannerHandler_PingSentDuringSlowUpstream enables 1s pings and
// streams 7 chunks spaced 500ms apart (~3.5s). It then expects at least two
// ": PING" lines in the response body.
//
// Not parallel: it mutates process-wide state. That state is the
// PingIntervalEnabled/PingIntervalSeconds fields of the object returned by
// operation_setting.GetGeneralSetting(), plus constant.StreamingTimeout.
// Under t.Parallel(), each ping test could snapshot an "old" value that
// another test had already overwritten. Whichever test finished first would
// then restore PingIntervalEnabled underneath a test that was still
// streaming. The writes were also unsynchronized with any parallel test
// reading these globals.
func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
	setting := operation_setting.GetGeneralSetting()
	oldEnabled := setting.PingIntervalEnabled
	oldSeconds := setting.PingIntervalSeconds
	setting.PingIntervalEnabled = true
	setting.PingIntervalSeconds = 1
	t.Cleanup(func() {
		setting.PingIntervalEnabled = oldEnabled
		setting.PingIntervalSeconds = oldSeconds
	})

	// Upstream: 7 chunks, 500ms apart, then [DONE].
	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		for i := 0; i < 7; i++ {
			fmt.Fprintf(pw, "data: chunk_%d\n", i)
			time.Sleep(500 * time.Millisecond)
		}
		fmt.Fprint(pw, "data: [DONE]\n")
	}()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	oldTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 30
	t.Cleanup(func() {
		constant.StreamingTimeout = oldTimeout
	})
	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}

	var count atomic.Int64
	done := make(chan struct{})
	go func() {
		StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
			count.Add(1)
		})
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out waiting for stream to finish")
	}

	assert.Equal(t, int64(7), count.Load())
	body := recorder.Body.String()
	pingCount := strings.Count(body, ": PING")
	t.Logf("received %d pings in response body", pingCount)
	assert.GreaterOrEqual(t, pingCount, 2,
		"expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
}
  // TestStreamScannerHandler_PingDisabledByRelayInfo enables 1s pings globally
// but sets RelayInfo.DisablePing. It streams 5 chunks spaced 500ms apart
// (~2.5s) and expects no ": PING" lines in the response body.
//
// Not parallel: it mutates process-wide state. That state is the
// PingIntervalEnabled/PingIntervalSeconds fields of the object returned by
// operation_setting.GetGeneralSetting(), plus constant.StreamingTimeout.
// Under t.Parallel(), this test's t.Cleanup could restore
// PingIntervalEnabled from a snapshot taken after another ping test had
// changed it, while that other test was still streaming. The writes were
// also unsynchronized with any parallel test reading these globals.
func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
	setting := operation_setting.GetGeneralSetting()
	oldEnabled := setting.PingIntervalEnabled
	oldSeconds := setting.PingIntervalSeconds
	setting.PingIntervalEnabled = true
	setting.PingIntervalSeconds = 1
	t.Cleanup(func() {
		setting.PingIntervalEnabled = oldEnabled
		setting.PingIntervalSeconds = oldSeconds
	})

	// Upstream: 5 chunks, 500ms apart, then [DONE].
	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		for i := 0; i < 5; i++ {
			fmt.Fprintf(pw, "data: chunk_%d\n", i)
			time.Sleep(500 * time.Millisecond)
		}
		fmt.Fprint(pw, "data: [DONE]\n")
	}()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	oldTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 30
	t.Cleanup(func() {
		constant.StreamingTimeout = oldTimeout
	})
	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{
		DisablePing: true,
		ChannelMeta: &relaycommon.ChannelMeta{},
	}

	var count atomic.Int64
	done := make(chan struct{})
	go func() {
		StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
			count.Add(1)
		})
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out")
	}

	assert.Equal(t, int64(5), count.Load())
	body := recorder.Body.String()
	pingCount := strings.Count(body, ": PING")
	assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
}
  347. // ---------- StreamStatus integration ----------
  // TestStreamScannerHandler_StreamStatus_DoneReason: a body terminated by
// "data: [DONE]" should end with StreamEndReasonDone, no EndError, and no
// recorded errors.
func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
	t.Parallel()

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(10)))
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) {})

	status := info.StreamStatus
	require.NotNil(t, status)
	assert.Equal(t, relaycommon.StreamEndReasonDone, status.EndReason)
	assert.Nil(t, status.EndError)
	assert.True(t, status.IsNormalEnd())
	assert.False(t, status.HasErrors())
}
  // TestStreamScannerHandler_StreamStatus_EOFWithoutDone: a body that simply
// ends (no [DONE] line) should finish with StreamEndReasonEOF, which still
// counts as a normal end.
func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
	t.Parallel()

	var sb strings.Builder
	for id := 0; id < 5; id++ {
		sb.WriteString(fmt.Sprintf("data: {\"id\":%d}\n", id))
	}
	c, resp, info := setupStreamTest(t, strings.NewReader(sb.String()))
	StreamScannerHandler(c, resp, info, func(string, *StreamResult) {})

	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
	assert.True(t, info.StreamStatus.IsNormalEnd())
}
  // TestStreamScannerHandler_StreamStatus_HandlerStop: a handler that calls
// sr.Stop (here on the 10th chunk) should end the stream with
// StreamEndReasonHandlerStop and leave errors recorded.
func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
	t.Parallel()
	const stopAfter = 10

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(100)))

	var calls atomic.Int64
	StreamScannerHandler(c, resp, info, func(_ string, sr *StreamResult) {
		if calls.Add(1) >= stopAfter {
			sr.Stop(fmt.Errorf("stop at 10"))
		}
	})

	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
	assert.True(t, info.StreamStatus.HasErrors())
}
  // TestStreamScannerHandler_StreamStatus_HandlerDone: a handler that calls
// sr.Done on the 5th of 20 chunks should receive no more chunks. The stream
// should end with StreamEndReasonDone and no errors.
func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
	t.Parallel()
	const doneAfter = 5

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(20)))

	var calls atomic.Int64
	StreamScannerHandler(c, resp, info, func(_ string, sr *StreamResult) {
		if calls.Add(1) >= doneAfter {
			sr.Done()
		}
	})

	assert.Equal(t, int64(doneAfter), calls.Load())
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
	assert.False(t, info.StreamStatus.HasErrors())
}
  // TestStreamScannerHandler_StreamStatus_Timeout lowers
// constant.StreamingTimeout to 2 and sends a single chunk. The upstream then
// stays open but silent, and the handler is expected to give up with
// StreamEndReasonTimeout (not a normal end).
func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
	// Not parallel: modifies global constant.StreamingTimeout
	oldTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 2
	t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })

	pr, pw := io.Pipe()
	// Keep the upstream open (but silent) only until the test ends. Closing the
	// writer in Cleanup replaces a fixed 10s sleep that kept the writer
	// goroutine and the open pipe alive ~8s past the end of this test. Closing
	// also unblocks any reader still waiting on pr, and any Write still pending
	// on pw.
	t.Cleanup(func() { _ = pw.Close() })
	go func() {
		// Blocks until the first chunk is read, or fails once pw is closed.
		_, _ = fmt.Fprint(pw, "data: {\"id\":1}\n")
	}()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}

	done := make(chan struct{})
	go func() {
		defer close(done)
		StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
	}()
	select {
	case <-done:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out waiting for stream timeout")
	}

	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
	assert.False(t, info.StreamStatus.IsNormalEnd())
}
  432. func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
  433. t.Parallel()
  434. body := buildSSEBody(10)
  435. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  436. StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
  437. sr.Error(fmt.Errorf("soft error for chunk"))
  438. })
  439. require.NotNil(t, info.StreamStatus)
  440. assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
  441. assert.True(t, info.StreamStatus.HasErrors())
  442. assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
  443. }
  444. func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
  445. t.Parallel()
  446. body := buildSSEBody(5)
  447. c, resp, info := setupStreamTest(t, strings.NewReader(body))
  448. StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
  449. sr.Error(fmt.Errorf("error A"))
  450. sr.Error(fmt.Errorf("error B"))
  451. })
  452. require.NotNil(t, info.StreamStatus)
  453. assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
  454. assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
  455. }
  // TestStreamScannerHandler_StreamStatus_ErrorThenStop: on the first chunk the
// handler reports a soft error and then calls sr.Stop. It expects exactly one
// handler call, EndReason HandlerStop, and two recorded errors.
func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
	t.Parallel()

	// The body deliberately has no [DONE]. With a terminator present, the
	// scanner's [DONE] and the handler's Stop would race on the
	// sync.Once-guarded EndReason.
	var sb strings.Builder
	for id := 0; id < 100; id++ {
		sb.WriteString(fmt.Sprintf("data: {\"id\":%d}\n", id))
	}
	c, resp, info := setupStreamTest(t, strings.NewReader(sb.String()))

	var calls atomic.Int64
	StreamScannerHandler(c, resp, info, func(_ string, sr *StreamResult) {
		calls.Add(1)
		sr.Error(fmt.Errorf("soft error"))
		sr.Stop(fmt.Errorf("fatal"))
	})

	assert.Equal(t, int64(1), calls.Load())
	require.NotNil(t, info.StreamStatus)
	assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
	assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
}
  // TestStreamScannerHandler_StreamStatus_InitializedIfNil: RelayInfo starts
// with a nil StreamStatus, and StreamScannerHandler should populate it.
func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
	t.Parallel()

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(1)))
	assert.Nil(t, info.StreamStatus)

	StreamScannerHandler(c, resp, info, func(string, *StreamResult) {})

	assert.NotNil(t, info.StreamStatus)
}
  // TestStreamScannerHandler_StreamStatus_PreInitialized: a StreamStatus that
// already holds one error must be kept, not replaced. After a clean run it
// still reports exactly that one error and ends with StreamEndReasonDone.
func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
	t.Parallel()

	c, resp, info := setupStreamTest(t, strings.NewReader(buildSSEBody(5)))

	status := relaycommon.NewStreamStatus()
	status.RecordError("pre-existing error")
	info.StreamStatus = status

	StreamScannerHandler(c, resp, info, func(string, *StreamResult) {})

	assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
	assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
}
  // TestStreamScannerHandler_PingInterleavesWithSlowUpstream enables 1s pings
// and streams 10 chunks spaced 500ms apart (~5s). It expects all 10 chunks
// plus at least three ": PING" lines in the response body.
//
// Not parallel: it mutates process-wide state. That state is the
// PingIntervalEnabled/PingIntervalSeconds fields of the object returned by
// operation_setting.GetGeneralSetting(), plus constant.StreamingTimeout.
// As the longest ping test, under t.Parallel() it was the most exposed. A
// shorter ping test could finish first and restore PingIntervalEnabled from
// its own snapshot while this test was still streaming. The writes were also
// unsynchronized with any parallel test reading these globals.
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
	setting := operation_setting.GetGeneralSetting()
	oldEnabled := setting.PingIntervalEnabled
	oldSeconds := setting.PingIntervalSeconds
	setting.PingIntervalEnabled = true
	setting.PingIntervalSeconds = 1
	t.Cleanup(func() {
		setting.PingIntervalEnabled = oldEnabled
		setting.PingIntervalSeconds = oldSeconds
	})

	// Upstream: 10 chunks, 500ms apart, then [DONE].
	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		for i := 0; i < 10; i++ {
			fmt.Fprintf(pw, "data: chunk_%d\n", i)
			time.Sleep(500 * time.Millisecond)
		}
		fmt.Fprint(pw, "data: [DONE]\n")
	}()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	oldTimeout := constant.StreamingTimeout
	constant.StreamingTimeout = 30
	t.Cleanup(func() {
		constant.StreamingTimeout = oldTimeout
	})
	resp := &http.Response{Body: pr}
	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}

	var count atomic.Int64
	done := make(chan struct{})
	go func() {
		StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
			count.Add(1)
		})
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(15 * time.Second):
		t.Fatal("timed out")
	}

	assert.Equal(t, int64(10), count.Load())
	body := recorder.Body.String()
	pingCount := strings.Count(body, ": PING")
	t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
	assert.GreaterOrEqual(t, pingCount, 3,
		"expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
}