billing_expr_request_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. package helper
  2. import (
  3. "bytes"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/dto"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/gin-gonic/gin"
  12. "github.com/samber/lo"
  13. "github.com/stretchr/testify/require"
  14. "github.com/tidwall/gjson"
  15. )
  16. func TestResolveIncomingBillingExprRequestInput(t *testing.T) {
  17. gin.SetMode(gin.TestMode)
  18. recorder := httptest.NewRecorder()
  19. ctx, _ := gin.CreateTestContext(recorder)
  20. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
  21. ctx.Request.Header.Set("Content-Type", "application/json")
  22. body := []byte(`{"service_tier":"fast"}`)
  23. ctx.Request.Body = io.NopCloser(bytes.NewReader(body))
  24. ctx.Set(common.KeyRequestBody, body)
  25. info := &relaycommon.RelayInfo{
  26. RequestHeaders: map[string]string{"Content-Type": "application/json"},
  27. }
  28. input, err := ResolveIncomingBillingExprRequestInput(ctx, info)
  29. require.NoError(t, err)
  30. require.Equal(t, body, input.Body)
  31. require.Equal(t, "application/json", input.Headers["Content-Type"])
  32. }
  33. func TestBuildBillingExprRequestInputFromRequest(t *testing.T) {
  34. request := &dto.GeneralOpenAIRequest{
  35. Model: "gemini-3.1-pro-preview",
  36. Stream: lo.ToPtr(true),
  37. Messages: []dto.Message{
  38. {
  39. Role: "user",
  40. Content: "hi",
  41. },
  42. },
  43. MaxTokens: lo.ToPtr(uint(3000)),
  44. }
  45. input, err := BuildBillingExprRequestInputFromRequest(request, map[string]string{
  46. "Content-Type": "application/json",
  47. "X-Test": "1",
  48. })
  49. require.NoError(t, err)
  50. require.Equal(t, "application/json", input.Headers["Content-Type"])
  51. require.Equal(t, "1", input.Headers["X-Test"])
  52. require.True(t, gjson.GetBytes(input.Body, "stream").Bool())
  53. require.Equal(t, "user", gjson.GetBytes(input.Body, "messages.0.role").String())
  54. require.Equal(t, float64(3000), gjson.GetBytes(input.Body, "max_tokens").Float())
  55. }