api_request_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package channel
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. relaycommon "github.com/QuantumNous/new-api/relay/common"
  7. "github.com/gin-gonic/gin"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
  11. t.Parallel()
  12. gin.SetMode(gin.TestMode)
  13. recorder := httptest.NewRecorder()
  14. ctx, _ := gin.CreateTestContext(recorder)
  15. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  16. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  17. info := &relaycommon.RelayInfo{
  18. IsChannelTest: true,
  19. ChannelMeta: &relaycommon.ChannelMeta{
  20. HeadersOverride: map[string]any{
  21. "*": "",
  22. },
  23. },
  24. }
  25. headers, err := processHeaderOverride(info, ctx)
  26. require.NoError(t, err)
  27. require.Empty(t, headers)
  28. }
  29. func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
  30. t.Parallel()
  31. gin.SetMode(gin.TestMode)
  32. recorder := httptest.NewRecorder()
  33. ctx, _ := gin.CreateTestContext(recorder)
  34. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  35. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  36. info := &relaycommon.RelayInfo{
  37. IsChannelTest: true,
  38. ChannelMeta: &relaycommon.ChannelMeta{
  39. HeadersOverride: map[string]any{
  40. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  41. },
  42. },
  43. }
  44. headers, err := processHeaderOverride(info, ctx)
  45. require.NoError(t, err)
  46. _, ok := headers["X-Upstream-Trace"]
  47. require.False(t, ok)
  48. }
  49. func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
  50. t.Parallel()
  51. gin.SetMode(gin.TestMode)
  52. recorder := httptest.NewRecorder()
  53. ctx, _ := gin.CreateTestContext(recorder)
  54. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  55. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  56. info := &relaycommon.RelayInfo{
  57. IsChannelTest: false,
  58. ChannelMeta: &relaycommon.ChannelMeta{
  59. HeadersOverride: map[string]any{
  60. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  61. },
  62. },
  63. }
  64. headers, err := processHeaderOverride(info, ctx)
  65. require.NoError(t, err)
  66. require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
  67. }
  68. func TestProcessHeaderOverride_RuntimeOverrideMergesWithChannelOverride(t *testing.T) {
  69. t.Parallel()
  70. gin.SetMode(gin.TestMode)
  71. recorder := httptest.NewRecorder()
  72. ctx, _ := gin.CreateTestContext(recorder)
  73. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  74. info := &relaycommon.RelayInfo{
  75. IsChannelTest: false,
  76. UseRuntimeHeadersOverride: true,
  77. RuntimeHeadersOverride: map[string]any{
  78. "X-Static": "runtime-value",
  79. "X-Runtime": "runtime-only",
  80. },
  81. ChannelMeta: &relaycommon.ChannelMeta{
  82. HeadersOverride: map[string]any{
  83. "X-Static": "legacy-value",
  84. "X-Legacy": "legacy-only",
  85. },
  86. },
  87. }
  88. headers, err := processHeaderOverride(info, ctx)
  89. require.NoError(t, err)
  90. require.Equal(t, "runtime-value", headers["X-Static"])
  91. require.Equal(t, "runtime-only", headers["X-Runtime"])
  92. require.Equal(t, "legacy-only", headers["X-Legacy"])
  93. }
  94. func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
  95. t.Parallel()
  96. gin.SetMode(gin.TestMode)
  97. recorder := httptest.NewRecorder()
  98. ctx, _ := gin.CreateTestContext(recorder)
  99. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  100. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  101. ctx.Request.Header.Set("Accept-Encoding", "gzip")
  102. info := &relaycommon.RelayInfo{
  103. IsChannelTest: false,
  104. ChannelMeta: &relaycommon.ChannelMeta{
  105. HeadersOverride: map[string]any{
  106. "*": "",
  107. },
  108. },
  109. }
  110. headers, err := processHeaderOverride(info, ctx)
  111. require.NoError(t, err)
  112. require.Equal(t, "trace-123", headers["X-Trace-Id"])
  113. _, hasAcceptEncoding := headers["Accept-Encoding"]
  114. require.False(t, hasAcceptEncoding)
  115. }
  116. func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
  117. t.Parallel()
  118. gin.SetMode(gin.TestMode)
  119. recorder := httptest.NewRecorder()
  120. ctx, _ := gin.CreateTestContext(recorder)
  121. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
  122. ctx.Request.Header.Set("Originator", "Codex CLI")
  123. ctx.Request.Header.Set("Session_id", "sess-123")
  124. info := &relaycommon.RelayInfo{
  125. IsChannelTest: false,
  126. RequestHeaders: map[string]string{
  127. "Originator": "Codex CLI",
  128. "Session_id": "sess-123",
  129. },
  130. ChannelMeta: &relaycommon.ChannelMeta{
  131. ParamOverride: map[string]any{
  132. "operations": []any{
  133. map[string]any{
  134. "mode": "pass_headers",
  135. "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
  136. },
  137. },
  138. },
  139. HeadersOverride: map[string]any{
  140. "X-Static": "legacy-value",
  141. },
  142. },
  143. }
  144. _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
  145. require.NoError(t, err)
  146. require.True(t, info.UseRuntimeHeadersOverride)
  147. require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["Originator"])
  148. require.Equal(t, "sess-123", info.RuntimeHeadersOverride["Session_id"])
  149. _, exists := info.RuntimeHeadersOverride["X-Codex-Beta-Features"]
  150. require.False(t, exists)
  151. headers, err := processHeaderOverride(info, ctx)
  152. require.NoError(t, err)
  153. require.Equal(t, "Codex CLI", headers["Originator"])
  154. require.Equal(t, "sess-123", headers["Session_id"])
  155. _, exists = headers["X-Codex-Beta-Features"]
  156. require.False(t, exists)
  157. upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
  158. applyHeaderOverrideToRequest(upstreamReq, headers)
  159. require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
  160. require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
  161. require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
  162. }