| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- package channel
- import (
- "net/http"
- "net/http/httptest"
- "testing"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- )
- func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: true,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "*": "",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Empty(t, headers)
- }
- func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: true,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Upstream-Trace": "{client_header:X-Trace-Id}",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- _, ok := headers["X-Upstream-Trace"]
- require.False(t, ok)
- }
- func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Upstream-Trace": "{client_header:X-Trace-Id}",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
- }
- func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- UseRuntimeHeadersOverride: true,
- RuntimeHeadersOverride: map[string]any{
- "X-Static": "runtime-value",
- "X-Runtime": "runtime-only",
- },
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "X-Static": "legacy-value",
- "X-Legacy": "legacy-only",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "runtime-value", headers["X-Static"])
- require.Equal(t, "runtime-only", headers["X-Runtime"])
- _, ok := headers["X-Legacy"]
- require.False(t, ok)
- }
- func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
- ctx.Request.Header.Set("X-Trace-Id", "trace-123")
- ctx.Request.Header.Set("Accept-Encoding", "gzip")
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- ChannelMeta: &relaycommon.ChannelMeta{
- HeadersOverride: map[string]any{
- "*": "",
- },
- },
- }
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "trace-123", headers["X-Trace-Id"])
- _, hasAcceptEncoding := headers["Accept-Encoding"]
- require.False(t, hasAcceptEncoding)
- }
- func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
- t.Parallel()
- gin.SetMode(gin.TestMode)
- recorder := httptest.NewRecorder()
- ctx, _ := gin.CreateTestContext(recorder)
- ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
- ctx.Request.Header.Set("Originator", "Codex CLI")
- ctx.Request.Header.Set("Session_id", "sess-123")
- info := &relaycommon.RelayInfo{
- IsChannelTest: false,
- RequestHeaders: map[string]string{
- "Originator": "Codex CLI",
- "Session_id": "sess-123",
- },
- ChannelMeta: &relaycommon.ChannelMeta{
- ParamOverride: map[string]any{
- "operations": []any{
- map[string]any{
- "mode": "pass_headers",
- "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
- },
- },
- },
- HeadersOverride: map[string]any{
- "X-Static": "legacy-value",
- },
- },
- }
- _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
- require.NoError(t, err)
- require.True(t, info.UseRuntimeHeadersOverride)
- require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["Originator"])
- require.Equal(t, "sess-123", info.RuntimeHeadersOverride["Session_id"])
- _, exists := info.RuntimeHeadersOverride["X-Codex-Beta-Features"]
- require.False(t, exists)
- headers, err := processHeaderOverride(info, ctx)
- require.NoError(t, err)
- require.Equal(t, "Codex CLI", headers["Originator"])
- require.Equal(t, "sess-123", headers["Session_id"])
- _, exists = headers["X-Codex-Beta-Features"]
- require.False(t, exists)
- upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
- applyHeaderOverrideToRequest(upstreamReq, headers)
- require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
- require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
- require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
- }
|