|
|
@@ -56,8 +56,6 @@ func buildSSEBody(n int) string {
|
|
|
return b.String()
|
|
|
}
|
|
|
|
|
|
-// slowReader wraps a reader and injects a delay before each Read call,
|
|
|
-// simulating a slow upstream that trickles data.
|
|
|
type slowReader struct {
|
|
|
r io.Reader
|
|
|
delay time.Duration
|
|
|
@@ -79,7 +77,7 @@ func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
|
|
|
|
|
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
|
|
|
|
|
- StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
|
|
+ StreamScannerHandler(c, nil, info, func(data string, sr *StreamResult) {})
|
|
|
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
|
|
}
|
|
|
|
|
|
@@ -89,9 +87,8 @@ func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
|
|
|
|
|
var called atomic.Bool
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
called.Store(true)
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
assert.False(t, called.Load(), "handler should not be called for empty body")
|
|
|
@@ -105,9 +102,8 @@ func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
|
|
|
var count atomic.Int64
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
assert.Equal(t, int64(numChunks), count.Load())
|
|
|
@@ -124,9 +120,8 @@ func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
|
|
var count atomic.Int64
|
|
|
start := time.Now()
|
|
|
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
elapsed := time.Since(start)
|
|
|
@@ -145,11 +140,10 @@ func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
|
|
var mu sync.Mutex
|
|
|
received := make([]string, 0, numChunks)
|
|
|
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
mu.Lock()
|
|
|
received = append(received, data)
|
|
|
mu.Unlock()
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
require.Equal(t, numChunks, len(received))
|
|
|
@@ -166,31 +160,32 @@ func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
|
|
|
var count atomic.Int64
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
|
|
}
|
|
|
|
|
|
-func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
|
|
+func TestStreamScannerHandler_StopStopsStream(t *testing.T) {
|
|
|
t.Parallel()
|
|
|
|
|
|
const numChunks = 200
|
|
|
body := buildSSEBody(numChunks)
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
|
|
|
- const failAt = 50
|
|
|
+ const stopAt int64 = 50
|
|
|
var count atomic.Int64
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
n := count.Add(1)
|
|
|
- return n < failAt
|
|
|
+ if n >= stopAt {
|
|
|
+ sr.Stop(fmt.Errorf("fatal at %d", n))
|
|
|
+ }
|
|
|
})
|
|
|
|
|
|
- // The worker stops at failAt; the scanner may have read ahead,
|
|
|
- // but the handler should not be called beyond failAt.
|
|
|
- assert.Equal(t, int64(failAt), count.Load())
|
|
|
+ assert.Equal(t, stopAt, count.Load())
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
|
|
}
|
|
|
|
|
|
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
|
|
@@ -210,9 +205,8 @@ func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
|
|
|
|
|
var count atomic.Int64
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
assert.Equal(t, int64(100), count.Load())
|
|
|
@@ -225,25 +219,18 @@ func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
|
|
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
|
|
|
var got string
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
got = data
|
|
|
- return true
|
|
|
})
|
|
|
|
|
|
assert.Equal(t, "{\"trimmed\":true}", got)
|
|
|
}
|
|
|
|
|
|
-// ---------- Decoupling: scanner not blocked by slow handler ----------
|
|
|
+// ---------- Decoupling ----------
|
|
|
|
|
|
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
|
|
t.Parallel()
|
|
|
|
|
|
- // Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
|
|
- // If the scanner were synchronously coupled to the handler, total time would be
|
|
|
- // ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
|
|
- // With decoupling, total time should be closer to
|
|
|
- // ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
|
|
- // because the scanner reads ahead into the buffer while the handler processes.
|
|
|
const numChunks = 50
|
|
|
const upstreamDelay = 10 * time.Millisecond
|
|
|
const handlerDelay = 20 * time.Millisecond
|
|
|
@@ -273,10 +260,9 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
|
|
start := time.Now()
|
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
time.Sleep(handlerDelay)
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
close(done)
|
|
|
}()
|
|
|
@@ -293,7 +279,6 @@ func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
|
|
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
|
|
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
|
|
|
|
|
- // If decoupled, elapsed should be well under the coupled estimate.
|
|
|
assert.Less(t, elapsed, coupledTime*85/100,
|
|
|
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
|
|
}
|
|
|
@@ -311,9 +296,8 @@ func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
|
|
|
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
close(done)
|
|
|
}()
|
|
|
@@ -344,8 +328,6 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
|
|
setting.PingIntervalSeconds = oldSeconds
|
|
|
})
|
|
|
|
|
|
- // Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
|
|
- // The ping interval is 1s, so we should see at least 2 pings.
|
|
|
pr, pw := io.Pipe()
|
|
|
go func() {
|
|
|
defer pw.Close()
|
|
|
@@ -372,9 +354,8 @@ func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
|
|
var count atomic.Int64
|
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
close(done)
|
|
|
}()
|
|
|
@@ -436,9 +417,8 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
|
|
var count atomic.Int64
|
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
close(done)
|
|
|
}()
|
|
|
@@ -456,6 +436,199 @@ func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
|
|
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
|
|
}
|
|
|
|
|
|
+// ---------- StreamStatus integration ----------
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_DoneReason(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(10)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
|
|
+ assert.Nil(t, info.StreamStatus.EndError)
|
|
|
+ assert.True(t, info.StreamStatus.IsNormalEnd())
|
|
|
+ assert.False(t, info.StreamStatus.HasErrors())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_EOFWithoutDone(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ var b strings.Builder
|
|
|
+ for i := 0; i < 5; i++ {
|
|
|
+ fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
|
|
+ }
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonEOF, info.StreamStatus.EndReason)
|
|
|
+ assert.True(t, info.StreamStatus.IsNormalEnd())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_HandlerStop(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(100)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ var count atomic.Int64
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
+ n := count.Add(1)
|
|
|
+ if n >= 10 {
|
|
|
+ sr.Stop(fmt.Errorf("stop at 10"))
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
|
|
+ assert.True(t, info.StreamStatus.HasErrors())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_HandlerDone(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(20)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ var count atomic.Int64
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
+ n := count.Add(1)
|
|
|
+ if n >= 5 {
|
|
|
+ sr.Done()
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ assert.Equal(t, int64(5), count.Load())
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
|
|
+ assert.False(t, info.StreamStatus.HasErrors())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_Timeout(t *testing.T) {
|
|
|
+ // Not parallel: modifies global constant.StreamingTimeout
|
|
|
+ oldTimeout := constant.StreamingTimeout
|
|
|
+ constant.StreamingTimeout = 2
|
|
|
+ t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
|
|
+
|
|
|
+ pr, pw := io.Pipe()
|
|
|
+ go func() {
|
|
|
+ fmt.Fprint(pw, "data: {\"id\":1}\n")
|
|
|
+ time.Sleep(10 * time.Second)
|
|
|
+ pw.Close()
|
|
|
+ }()
|
|
|
+
|
|
|
+ recorder := httptest.NewRecorder()
|
|
|
+ c, _ := gin.CreateTestContext(recorder)
|
|
|
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
|
+
|
|
|
+ resp := &http.Response{Body: pr}
|
|
|
+ info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
|
|
+
|
|
|
+ done := make(chan struct{})
|
|
|
+ go func() {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
|
|
+ close(done)
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-done:
|
|
|
+ case <-time.After(15 * time.Second):
|
|
|
+ t.Fatal("timed out waiting for stream timeout")
|
|
|
+ }
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonTimeout, info.StreamStatus.EndReason)
|
|
|
+ assert.False(t, info.StreamStatus.IsNormalEnd())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_SoftErrors(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(10)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
+ sr.Error(fmt.Errorf("soft error for chunk"))
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
|
|
+ assert.True(t, info.StreamStatus.HasErrors())
|
|
|
+ assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_MultipleErrorsPerChunk(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(5)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
+ sr.Error(fmt.Errorf("error A"))
|
|
|
+ sr.Error(fmt.Errorf("error B"))
|
|
|
+ })
|
|
|
+
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
|
|
+ assert.Equal(t, 10, info.StreamStatus.TotalErrorCount())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_ErrorThenStop(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ // Use a large body without [DONE] to avoid race between scanner's [DONE]
|
|
|
+ // and handler's Stop on the sync.Once EndReason.
|
|
|
+ var b strings.Builder
|
|
|
+ for i := 0; i < 100; i++ {
|
|
|
+ fmt.Fprintf(&b, "data: {\"id\":%d}\n", i)
|
|
|
+ }
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
|
|
+
|
|
|
+ var count atomic.Int64
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
+ count.Add(1)
|
|
|
+ sr.Error(fmt.Errorf("soft error"))
|
|
|
+ sr.Stop(fmt.Errorf("fatal"))
|
|
|
+ })
|
|
|
+
|
|
|
+ assert.Equal(t, int64(1), count.Load())
|
|
|
+ require.NotNil(t, info.StreamStatus)
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonHandlerStop, info.StreamStatus.EndReason)
|
|
|
+ assert.Equal(t, 2, info.StreamStatus.TotalErrorCount())
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_InitializedIfNil(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(1)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ assert.Nil(t, info.StreamStatus)
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
|
|
+
|
|
|
+ assert.NotNil(t, info.StreamStatus)
|
|
|
+}
|
|
|
+
|
|
|
+func TestStreamScannerHandler_StreamStatus_PreInitialized(t *testing.T) {
|
|
|
+ t.Parallel()
|
|
|
+
|
|
|
+ body := buildSSEBody(5)
|
|
|
+ c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
|
|
+
|
|
|
+ info.StreamStatus = relaycommon.NewStreamStatus()
|
|
|
+ info.StreamStatus.RecordError("pre-existing error")
|
|
|
+
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {})
|
|
|
+
|
|
|
+ assert.Equal(t, relaycommon.StreamEndReasonDone, info.StreamStatus.EndReason)
|
|
|
+ assert.Equal(t, 1, info.StreamStatus.TotalErrorCount())
|
|
|
+}
|
|
|
+
|
|
|
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
|
|
t.Parallel()
|
|
|
|
|
|
@@ -469,9 +642,6 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
|
|
setting.PingIntervalSeconds = oldSeconds
|
|
|
})
|
|
|
|
|
|
- // Slow upstream + slow handler. Total stream takes ~5 seconds.
|
|
|
- // The ping goroutine stays alive as long as the scanner is reading,
|
|
|
- // so pings should fire between data writes.
|
|
|
pr, pw := io.Pipe()
|
|
|
go func() {
|
|
|
defer pw.Close()
|
|
|
@@ -498,9 +668,8 @@ func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
|
|
var count atomic.Int64
|
|
|
done := make(chan struct{})
|
|
|
go func() {
|
|
|
- StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ StreamScannerHandler(c, resp, info, func(data string, sr *StreamResult) {
|
|
|
count.Add(1)
|
|
|
- return true
|
|
|
})
|
|
|
close(done)
|
|
|
}()
|