package helper import ( "fmt" "io" "net/http" "net/http/httptest" "strings" "sync" "sync/atomic" "testing" "time" "github.com/QuantumNous/new-api/constant" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func init() { gin.SetMode(gin.TestMode) } func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) { t.Helper() oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) resp := &http.Response{ Body: io.NopCloser(body), } info := &relaycommon.RelayInfo{ ChannelMeta: &relaycommon.ChannelMeta{}, } return c, resp, info } func buildSSEBody(n int) string { var b strings.Builder for i := 0; i < n; i++ { fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i) } b.WriteString("data: [DONE]\n") return b.String() } type slowReader struct { r io.Reader delay time.Duration } func (s *slowReader) Read(p []byte) (int, error) { time.Sleep(s.delay) return s.r.Read(p) } // ---------- Basic correctness ---------- func TestStreamScannerHandler_NilInputs(t *testing.T) { t.Parallel() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {}) StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil) } func TestStreamScannerHandler_EmptyBody(t *testing.T) { t.Parallel() c, resp, info := setupStreamTest(t, strings.NewReader("")) var called atomic.Bool StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { called.Store(true) }) assert.False(t, called.Load(), "handler should not be called for empty body") } func TestStreamScannerHandler_1000Chunks(t *testing.T) { t.Parallel() const numChunks = 1000 body := buildSSEBody(numChunks) c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) assert.Equal(t, int64(numChunks), count.Load()) assert.Equal(t, numChunks, info.ReceivedResponseCount) } func TestStreamScannerHandler_10000Chunks(t *testing.T) { t.Parallel() const numChunks = 10000 body := buildSSEBody(numChunks) c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 start := time.Now() StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) elapsed := time.Since(start) assert.Equal(t, int64(numChunks), count.Load()) assert.Equal(t, numChunks, info.ReceivedResponseCount) t.Logf("10000 chunks processed in %v", elapsed) } func TestStreamScannerHandler_OrderPreserved(t *testing.T) { t.Parallel() const numChunks = 500 body := buildSSEBody(numChunks) c, resp, info := setupStreamTest(t, strings.NewReader(body)) var mu sync.Mutex received := make([]string, 0, numChunks) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { mu.Lock() received = append(received, data) mu.Unlock() }) require.Equal(t, numChunks, len(received)) for i := 0; i < numChunks; i++ { expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i) assert.Equal(t, expected, received[i], "chunk %d out of order", i) } } func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) { t.Parallel() body := buildSSEBody(50) + "data: should_not_appear\n" c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed") } func TestStreamScannerHandler_StopStopsStream(t *testing.T) { t.Parallel() const numChunks = 200 body := buildSSEBody(numChunks) c, resp, info := setupStreamTest(t, strings.NewReader(body)) const stopAt int64 = 50 var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { n := count.Add(1) if n >= stopAt { sr.Stop(fmt.Errorf("fatal at %d", n)) } }) assert.Equal(t, stopAt, count.Load()) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) } func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) { t.Parallel() var b strings.Builder b.WriteString(": comment line\n") b.WriteString("event: message\n") b.WriteString("id: 12345\n") b.WriteString("retry: 5000\n") for i := 0; i < 100; i++ { fmt.Fprintf(&b, "data: payload_%d\n", i) b.WriteString(": interleaved comment\n") } b.WriteString("data: [DONE]\n") c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) assert.Equal(t, int64(100), count.Load()) } func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) { t.Parallel() body := "data: {\"trimmed\":true} \ndata: [DONE]\n" c, resp, info := setupStreamTest(t, strings.NewReader(body)) var got string StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { got = data }) assert.Equal(t, "{\"trimmed\":true}", got) } // ---------- Decoupling ---------- func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) { t.Parallel() const numChunks = 50 const upstreamDelay = 10 * time.Millisecond const handlerDelay = 20 * time.Millisecond pr, pw := io.Pipe() go func() { defer pw.Close() for i := 0; i < numChunks; i++ { fmt.Fprintf(pw, "data: {\"id\":%d}\n", i) time.Sleep(upstreamDelay) } fmt.Fprint(pw, "data: [DONE]\n") }() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) resp := &http.Response{Body: pr} info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} var count atomic.Int64 start := time.Now() done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { time.Sleep(handlerDelay) count.Add(1) }) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("StreamScannerHandler did not complete in time") } elapsed := time.Since(start) assert.Equal(t, int64(numChunks), count.Load()) coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay) t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime) assert.Less(t, elapsed, coupledTime*85/100, "decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime) } func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) { t.Parallel() const numChunks = 50 body := buildSSEBody(numChunks) reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond} c, resp, info := setupStreamTest(t, reader) var count atomic.Int64 start := time.Now() done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("timed out with slow upstream") } elapsed := time.Since(start) assert.Equal(t, int64(numChunks), count.Load()) t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed) } // ---------- Ping tests ---------- func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) { t.Parallel() setting := operation_setting.GetGeneralSetting() oldEnabled := setting.PingIntervalEnabled oldSeconds := setting.PingIntervalSeconds setting.PingIntervalEnabled = true setting.PingIntervalSeconds = 1 t.Cleanup(func() { setting.PingIntervalEnabled = oldEnabled setting.PingIntervalSeconds = oldSeconds }) pr, pw := io.Pipe() go func() { defer pw.Close() for i := 0; i < 7; i++ { fmt.Fprintf(pw, "data: chunk_%d\n", i) time.Sleep(500 * time.Millisecond) } fmt.Fprint(pw, "data: [DONE]\n") }() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) resp := &http.Response{Body: pr} info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} var count atomic.Int64 done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("timed out waiting for stream to finish") } assert.Equal(t, int64(7), count.Load()) body := recorder.Body.String() pingCount := strings.Count(body, ": PING") t.Logf("received %d pings in response body", pingCount) assert.GreaterOrEqual(t, pingCount, 2, "expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount) } func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) { t.Parallel() setting := operation_setting.GetGeneralSetting() oldEnabled := setting.PingIntervalEnabled oldSeconds := setting.PingIntervalSeconds setting.PingIntervalEnabled = true setting.PingIntervalSeconds = 1 t.Cleanup(func() { setting.PingIntervalEnabled = oldEnabled setting.PingIntervalSeconds = oldSeconds }) pr, pw := io.Pipe() go func() { defer pw.Close() for i := 0; i < 5; i++ { fmt.Fprintf(pw, "data: chunk_%d\n", i) time.Sleep(500 * time.Millisecond) } fmt.Fprint(pw, "data: [DONE]\n") }() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) resp := &http.Response{Body: pr} info := &relaycommon.RelayInfo{ DisablePing: true, ChannelMeta: &relaycommon.ChannelMeta{}, } var count atomic.Int64 done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("timed out") } assert.Equal(t, int64(5), count.Load()) body := recorder.Body.String() pingCount := strings.Count(body, ": PING") assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true") } // ---------- StreamStatus integration ---------- func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) { t.Parallel() body := buildSSEBody(10) c, resp, info := setupStreamTest(t, strings.NewReader(body)) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) assert.Nil(t, info.StreamStatus.EndError) assert.True(t, info.StreamStatus.IsNormalEnd()) assert.False(t, info.StreamStatus.HasErrors()) } func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) { t.Parallel() var b strings.Builder for i := 0; i < 5; i++ { fmt.Fprintf(&b, "data: {\"id\":%d}\n", i) } c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason) assert.True(t, info.StreamStatus.IsNormalEnd()) } func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) { t.Parallel() body := buildSSEBody(100) c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { n := count.Add(1) if n >= 10 { sr.Stop(fmt.Errorf("stop at 10")) } }) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) assert.True(t, info.StreamStatus.HasErrors()) } func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) { t.Parallel() body := buildSSEBody(20) c, resp, info := setupStreamTest(t, strings.NewReader(body)) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { n := count.Add(1) if n >= 5 { sr.Done() } }) assert.Equal(t, int64(5), count.Load()) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) assert.False(t, info.StreamStatus.HasErrors()) } func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) { // Not parallel: modifies global constant.StreamingTimeout oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 2 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) pr, pw := io.Pipe() go func() { fmt.Fprint(pw, "data: {\"id\":1}\n") time.Sleep(10 * time.Second) pw.Close() }() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) resp := &http.Response{Body: pr} info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("timed out waiting for stream timeout") } require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason) assert.False(t, info.StreamStatus.IsNormalEnd()) } func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) { t.Parallel() body := buildSSEBody(10) c, resp, info := setupStreamTest(t, strings.NewReader(body)) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { sr.Error(fmt.Errorf("soft error for chunk")) }) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) assert.True(t, info.StreamStatus.HasErrors()) assert.Equal(t, 10, info.StreamStatus.TotalErrorCount()) } func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) { t.Parallel() body := buildSSEBody(5) c, resp, info := setupStreamTest(t, strings.NewReader(body)) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { sr.Error(fmt.Errorf("error A")) sr.Error(fmt.Errorf("error B")) }) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) assert.Equal(t, 10, info.StreamStatus.TotalErrorCount()) } func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) { t.Parallel() // Use a large body without [DONE] to avoid race between scanner's [DONE] // and handler's Stop on the sync.Once EndReason. var b strings.Builder for i := 0; i < 100; i++ { fmt.Fprintf(&b, "data: {\"id\":%d}\n", i) } c, resp, info := setupStreamTest(t, strings.NewReader(b.String())) var count atomic.Int64 StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) sr.Error(fmt.Errorf("soft error")) sr.Stop(fmt.Errorf("fatal")) }) assert.Equal(t, int64(1), count.Load()) require.NotNil(t, info.StreamStatus) assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason) assert.Equal(t, 2, info.StreamStatus.TotalErrorCount()) } func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) { t.Parallel() body := buildSSEBody(1) c, resp, info := setupStreamTest(t, strings.NewReader(body)) assert.Nil(t, info.StreamStatus) StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) assert.NotNil(t, info.StreamStatus) } func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) { t.Parallel() body := buildSSEBody(5) c, resp, info := setupStreamTest(t, strings.NewReader(body)) info.StreamStatus = relaycommon.NewStreamStatus() info.StreamStatus.RecordError("pre-existing error") StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {}) assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason) assert.Equal(t, 1, info.StreamStatus.TotalErrorCount()) } func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) { t.Parallel() setting := operation_setting.GetGeneralSetting() oldEnabled := setting.PingIntervalEnabled oldSeconds := setting.PingIntervalSeconds setting.PingIntervalEnabled = true setting.PingIntervalSeconds = 1 t.Cleanup(func() { setting.PingIntervalEnabled = oldEnabled setting.PingIntervalSeconds = oldSeconds }) pr, pw := io.Pipe() go func() { defer pw.Close() for i := 0; i < 10; i++ { fmt.Fprintf(pw, "data: chunk_%d\n", i) time.Sleep(500 * time.Millisecond) } fmt.Fprint(pw, "data: [DONE]\n") }() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) oldTimeout := constant.StreamingTimeout constant.StreamingTimeout = 30 t.Cleanup(func() { constant.StreamingTimeout = oldTimeout }) resp := &http.Response{Body: pr} info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}} var count atomic.Int64 done := make(chan struct{}) go func() { StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) { count.Add(1) }) close(done) }() select { case <-done: case <-time.After(15 * time.Second): t.Fatal("timed out") } assert.Equal(t, int64(10), count.Load()) body := recorder.Body.String() pingCount := strings.Count(body, ": PING") t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount) assert.GreaterOrEqual(t, pingCount, 3, "expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount) }