relay_aws_test.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. package aws
  2. import (
  3. "bytes"
  4. "net/http"
  5. "net/http/httptest"
  6. "testing"
  7. "github.com/QuantumNous/new-api/common"
  8. relaycommon "github.com/QuantumNous/new-api/relay/common"
  9. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  10. "github.com/gin-gonic/gin"
  11. "github.com/stretchr/testify/require"
  12. )
  13. func TestDoAwsClientRequest_AppliesRuntimeHeaderOverrideToAnthropicBeta(t *testing.T) {
  14. t.Parallel()
  15. gin.SetMode(gin.TestMode)
  16. recorder := httptest.NewRecorder()
  17. ctx, _ := gin.CreateTestContext(recorder)
  18. ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
  19. info := &relaycommon.RelayInfo{
  20. OriginModelName: "claude-3-5-sonnet-20240620",
  21. IsStream: false,
  22. UseRuntimeHeadersOverride: true,
  23. RuntimeHeadersOverride: map[string]any{
  24. "anthropic-beta": "computer-use-2025-01-24",
  25. },
  26. ChannelMeta: &relaycommon.ChannelMeta{
  27. ApiKey: "access-key|secret-key|us-east-1",
  28. UpstreamModelName: "claude-3-5-sonnet-20240620",
  29. },
  30. }
  31. requestBody := bytes.NewBufferString(`{"messages":[{"role":"user","content":"hello"}],"max_tokens":128}`)
  32. adaptor := &Adaptor{}
  33. _, err := doAwsClientRequest(ctx, info, adaptor, requestBody)
  34. require.NoError(t, err)
  35. awsReq, ok := adaptor.AwsReq.(*bedrockruntime.InvokeModelInput)
  36. require.True(t, ok)
  37. var payload map[string]any
  38. require.NoError(t, common.Unmarshal(awsReq.Body, &payload))
  39. anthropicBeta, exists := payload["anthropic_beta"]
  40. require.True(t, exists)
  41. values, ok := anthropicBeta.([]any)
  42. require.True(t, ok)
  43. require.Equal(t, []any{"computer-use-2025-01-24"}, values)
  44. }