| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- package channel
- import (
- "net/http"
- "net/http/httptest"
- "testing"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- )
- func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: true,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "*": "",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Empty(t, headers)
- }
- func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: true,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Upstream-Trace": "{client_header:X-Trace-Id}",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- _, ok := headers["X-Upstream-Trace"]
- require.False(t, ok)
- }
- func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Upstream-Trace": "{client_header:X-Trace-Id}",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
- }
- func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- UseRuntimeHeadersOverride: true,
- RuntimeHeadersOverride: map[string]any{
- "X-Static": "runtime-value",
- "X-Runtime": "runtime-only",
- },
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Static": "legacy-value",
- "X-Legacy": "legacy-only",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "runtime-value", headers["X-Static"])
- require.Equal(t, "runtime-only", headers["X-Runtime"])
- _, ok := headers["X-Legacy"]
- require.False(t, ok)
- }
- func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- ctx.Request.Header.Set("Accept-Encoding", "gzip")
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "*": "",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "trace-123", headers["X-Trace-Id"])
- _, hasAcceptEncoding := headers["Accept-Encoding"]
- require.False(t, hasAcceptEncoding)
- }
|