responses_handler.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "one-api/setting/model_setting"
  14. "one-api/types"
  15. "strings"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
  19. info.InitChannelMeta(c)
  20. request, ok := info.Request.(*dto.OpenAIResponsesRequest)
  21. if !ok {
  22. common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request))
  23. }
  24. err := helper.ModelMappedHelper(c, info, request)
  25. if err != nil {
  26. return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
  27. }
  28. adaptor := GetAdaptor(info.ApiType)
  29. if adaptor == nil {
  30. return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
  31. }
  32. adaptor.Init(info)
  33. var requestBody io.Reader
  34. if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
  35. body, err := common.GetRequestBody(c)
  36. if err != nil {
  37. return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
  38. }
  39. requestBody = bytes.NewBuffer(body)
  40. } else {
  41. convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
  42. if err != nil {
  43. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  44. }
  45. jsonData, err := json.Marshal(convertedRequest)
  46. if err != nil {
  47. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  48. }
  49. // apply param override
  50. if len(info.ParamOverride) > 0 {
  51. reqMap := make(map[string]interface{})
  52. err = json.Unmarshal(jsonData, &reqMap)
  53. if err != nil {
  54. return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
  55. }
  56. for key, value := range info.ParamOverride {
  57. reqMap[key] = value
  58. }
  59. jsonData, err = json.Marshal(reqMap)
  60. if err != nil {
  61. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  62. }
  63. }
  64. if common.DebugEnabled {
  65. println("requestBody: ", string(jsonData))
  66. }
  67. requestBody = bytes.NewBuffer(jsonData)
  68. }
  69. var httpResp *http.Response
  70. resp, err := adaptor.DoRequest(c, info, requestBody)
  71. if err != nil {
  72. return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
  73. }
  74. statusCodeMappingStr := c.GetString("status_code_mapping")
  75. if resp != nil {
  76. httpResp = resp.(*http.Response)
  77. if httpResp.StatusCode != http.StatusOK {
  78. newAPIError = service.RelayErrorHandler(httpResp, false)
  79. // reset status code 重置状态码
  80. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  81. return newAPIError
  82. }
  83. }
  84. usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
  85. if newAPIError != nil {
  86. // reset status code 重置状态码
  87. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  88. return newAPIError
  89. }
  90. if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
  91. service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
  92. } else {
  93. postConsumeQuota(c, info, usage.(*dto.Usage), "")
  94. }
  95. return nil
  96. }