api_request_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package channel
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. relaycommon "github.com/QuantumNous/new-api/relay/common"
  7. "github.com/gin-gonic/gin"
  8. "github.com/stretchr/testify/require"
  9. )
  10. func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) {
  11. t.Parallel()
  12. gin.SetMode(gin.TestMode)
  13. recorder := httptest.NewRecorder()
  14. ctx, _ := gin.CreateTestContext(recorder)
  15. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  16. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  17. info := &relaycommon.RelayInfo{
  18. IsChannelTest: true,
  19. ChannelMeta: &relaycommon.ChannelMeta{
  20. HeadersOverride: map[string]any{
  21. "*": "",
  22. },
  23. },
  24. }
  25. headers, err := processHeaderOverride(info, ctx)
  26. require.NoError(t, err)
  27. require.Empty(t, headers)
  28. }
  29. func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) {
  30. t.Parallel()
  31. gin.SetMode(gin.TestMode)
  32. recorder := httptest.NewRecorder()
  33. ctx, _ := gin.CreateTestContext(recorder)
  34. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  35. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  36. info := &relaycommon.RelayInfo{
  37. IsChannelTest: true,
  38. ChannelMeta: &relaycommon.ChannelMeta{
  39. HeadersOverride: map[string]any{
  40. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  41. },
  42. },
  43. }
  44. headers, err := processHeaderOverride(info, ctx)
  45. require.NoError(t, err)
  46. _, ok := headers["X-Upstream-Trace"]
  47. require.False(t, ok)
  48. }
  49. func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) {
  50. t.Parallel()
  51. gin.SetMode(gin.TestMode)
  52. recorder := httptest.NewRecorder()
  53. ctx, _ := gin.CreateTestContext(recorder)
  54. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  55. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  56. info := &relaycommon.RelayInfo{
  57. IsChannelTest: false,
  58. ChannelMeta: &relaycommon.ChannelMeta{
  59. HeadersOverride: map[string]any{
  60. "X-Upstream-Trace": "{client_header:X-Trace-Id}",
  61. },
  62. },
  63. }
  64. headers, err := processHeaderOverride(info, ctx)
  65. require.NoError(t, err)
  66. require.Equal(t, "trace-123", headers["X-Upstream-Trace"])
  67. }
  68. func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
  69. t.Parallel()
  70. gin.SetMode(gin.TestMode)
  71. recorder := httptest.NewRecorder()
  72. ctx, _ := gin.CreateTestContext(recorder)
  73. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  74. ctx.Request.Header.Set("X-Trace-Id", "trace-123")
  75. ctx.Request.Header.Set("Accept-Encoding", "gzip")
  76. info := &relaycommon.RelayInfo{
  77. IsChannelTest: false,
  78. ChannelMeta: &relaycommon.ChannelMeta{
  79. HeadersOverride: map[string]any{
  80. "*": "",
  81. },
  82. },
  83. }
  84. headers, err := processHeaderOverride(info, ctx)
  85. require.NoError(t, err)
  86. require.Equal(t, "trace-123", headers["X-Trace-Id"])
  87. _, hasAcceptEncoding := headers["Accept-Encoding"]
  88. require.False(t, hasAcceptEncoding)
  89. }