api_request_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package channel
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. relaycommon "github.com/QuantumNous/new-api/relay/common"
  7. "github.com/gin-gonic/gin"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
  11. t.Parallel()
  12. gin.SetMode(gin.TestMode)
  13. recorder := httptest.NewRecorder()
  14. ctx, _ := gin.CreateTestContext(recorder)
  15. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  16. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  17. info := &relaycommon.RelayInfo{
  18. IsChannelTest: true,
  19. ChannelMeta: &relaycommon.ChannelMeta{
  20. HeadersOverride: map[string]any{
  21. "*": "",
  22. },
  23. },
  24. }
  25. headers, err := processHeaderOverride(info, ctx)
  26. require.NoError(t, err)
  27. require.Empty(t, headers)
  28. }
  29. func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
  30. t.Parallel()
  31. gin.SetMode(gin.TestMode)
  32. recorder := httptest.NewRecorder()
  33. ctx, _ := gin.CreateTestContext(recorder)
  34. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  35. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  36. info := &relaycommon.RelayInfo{
  37. IsChannelTest: true,
  38. ChannelMeta: &relaycommon.ChannelMeta{
  39. HeadersOverride: map[string]any{
  40. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  41. },
  42. },
  43. }
  44. headers, err := processHeaderOverride(info, ctx)
  45. require.NoError(t, err)
  46. _, ok := headers["x-upstream-trace"]
  47. require.False(t, ok)
  48. }
  49. func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
  50. t.Parallel()
  51. gin.SetMode(gin.TestMode)
  52. recorder := httptest.NewRecorder()
  53. ctx, _ := gin.CreateTestContext(recorder)
  54. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  55. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  56. info := &relaycommon.RelayInfo{
  57. IsChannelTest: false,
  58. ChannelMeta: &relaycommon.ChannelMeta{
  59. HeadersOverride: map[string]any{
  60. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  61. },
  62. },
  63. }
  64. headers, err := processHeaderOverride(info, ctx)
  65. require.NoError(t, err)
  66. require.Equal(t, "trace-123", headers["x-upstream-trace"])
  67. }
  68. func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) {
  69. t.Parallel()
  70. gin.SetMode(gin.TestMode)
  71. recorder := httptest.NewRecorder()
  72. ctx, _ := gin.CreateTestContext(recorder)
  73. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  74. info := &relaycommon.RelayInfo{
  75. IsChannelTest: false,
  76. UseRuntimeHeadersOverride: true,
  77. RuntimeHeadersOverride: map[string]any{
  78. "x-static": "runtime-value",
  79. "x-runtime": "runtime-only",
  80. },
  81. ChannelMeta: &relaycommon.ChannelMeta{
  82. HeadersOverride: map[string]any{
  83. "X-Static": "legacy-value",
  84. "X-Legacy": "legacy-only",
  85. },
  86. },
  87. }
  88. headers, err := processHeaderOverride(info, ctx)
  89. require.NoError(t, err)
  90. require.Equal(t, "runtime-value", headers["x-static"])
  91. require.Equal(t, "runtime-only", headers["x-runtime"])
  92. _, exists := headers["x-legacy"]
  93. require.False(t, exists)
  94. }
  95. func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
  96. t.Parallel()
  97. gin.SetMode(gin.TestMode)
  98. recorder := httptest.NewRecorder()
  99. ctx, _ := gin.CreateTestContext(recorder)
  100. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  101. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  102. ctx.Request.Header.Set("Accept-Encoding", "gzip")
  103. info := &relaycommon.RelayInfo{
  104. IsChannelTest: false,
  105. ChannelMeta: &relaycommon.ChannelMeta{
  106. HeadersOverride: map[string]any{
  107. "*": "",
  108. },
  109. },
  110. }
  111. headers, err := processHeaderOverride(info, ctx)
  112. require.NoError(t, err)
  113. require.Equal(t, "trace-123", headers["x-trace-id"])
  114. _, hasAcceptEncoding := headers["accept-encoding"]
  115. require.False(t, hasAcceptEncoding)
  116. }
  117. func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
  118. t.Parallel()
  119. gin.SetMode(gin.TestMode)
  120. recorder := httptest.NewRecorder()
  121. ctx, _ := gin.CreateTestContext(recorder)
  122. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
  123. ctx.Request.Header.Set("Originator", "Codex CLI")
  124. ctx.Request.Header.Set("Session_id", "sess-123")
  125. info := &relaycommon.RelayInfo{
  126. IsChannelTest: false,
  127. RequestHeaders: map[string]string{
  128. "Originator": "Codex CLI",
  129. "Session_id": "sess-123",
  130. },
  131. ChannelMeta: &relaycommon.ChannelMeta{
  132. ParamOverride: map[string]any{
  133. "operations": []any{
  134. map[string]any{
  135. "mode": "pass_headers",
  136. "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"},
  137. },
  138. },
  139. },
  140. HeadersOverride: map[string]any{
  141. "X-Static": "legacy-value",
  142. },
  143. },
  144. }
  145. _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info)
  146. require.NoError(t, err)
  147. require.True(t, info.UseRuntimeHeadersOverride)
  148. require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"])
  149. require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"])
  150. _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"]
  151. require.False(t, exists)
  152. require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"])
  153. headers, err := processHeaderOverride(info, ctx)
  154. require.NoError(t, err)
  155. require.Equal(t, "Codex CLI", headers["originator"])
  156. require.Equal(t, "sess-123", headers["session_id"])
  157. _, exists = headers["x-codex-beta-features"]
  158. require.False(t, exists)
  159. upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil)
  160. applyHeaderOverrideToRequest(upstreamReq, headers)
  161. require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator"))
  162. require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id"))
  163. require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features"))
  164. }