image.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package ali
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/logger"
  11. relaycommon "one-api/relay/common"
  12. "one-api/service"
  13. "one-api/types"
  14. "strings"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
  19. var imageRequest AliImageRequest
  20. imageRequest.Input.Prompt = request.Prompt
  21. imageRequest.Model = request.Model
  22. imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
  23. imageRequest.Parameters.N = int(request.N)
  24. imageRequest.ResponseFormat = request.ResponseFormat
  25. return &imageRequest
  26. }
  27. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  28. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  29. var aliResponse AliResponse
  30. req, err := http.NewRequest("GET", url, nil)
  31. if err != nil {
  32. return &aliResponse, err, nil
  33. }
  34. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  35. client := &http.Client{}
  36. resp, err := client.Do(req)
  37. if err != nil {
  38. common.SysLog("updateTask client.Do err: " + err.Error())
  39. return &aliResponse, err, nil
  40. }
  41. defer resp.Body.Close()
  42. responseBody, err := io.ReadAll(resp.Body)
  43. var response AliResponse
  44. err = json.Unmarshal(responseBody, &response)
  45. if err != nil {
  46. common.SysLog("updateTask NewDecoder err: " + err.Error())
  47. return &aliResponse, err, nil
  48. }
  49. return &response, nil, responseBody
  50. }
  51. func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  52. waitSeconds := 3
  53. step := 0
  54. maxStep := 20
  55. var taskResponse AliResponse
  56. var responseBody []byte
  57. for {
  58. step++
  59. rsp, err, body := updateTask(info, taskID)
  60. responseBody = body
  61. if err != nil {
  62. return &taskResponse, responseBody, err
  63. }
  64. if rsp.Output.TaskStatus == "" {
  65. return &taskResponse, responseBody, nil
  66. }
  67. switch rsp.Output.TaskStatus {
  68. case "FAILED":
  69. fallthrough
  70. case "CANCELED":
  71. fallthrough
  72. case "SUCCEEDED":
  73. fallthrough
  74. case "UNKNOWN":
  75. return rsp, responseBody, nil
  76. }
  77. if step >= maxStep {
  78. break
  79. }
  80. time.Sleep(time.Duration(waitSeconds) * time.Second)
  81. }
  82. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  83. }
  84. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  85. imageResponse := dto.ImageResponse{
  86. Created: info.StartTime.Unix(),
  87. }
  88. for _, data := range response.Output.Results {
  89. var b64Json string
  90. if responseFormat == "b64_json" {
  91. _, b64, err := service.GetImageFromUrl(data.Url)
  92. if err != nil {
  93. logger.LogError(c, "get_image_data_failed: "+err.Error())
  94. continue
  95. }
  96. b64Json = b64
  97. } else {
  98. b64Json = data.B64Image
  99. }
  100. imageResponse.Data = append(imageResponse.Data, dto.ImageData{
  101. Url: data.Url,
  102. B64Json: b64Json,
  103. RevisedPrompt: "",
  104. })
  105. }
  106. return &imageResponse
  107. }
  108. func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  109. responseFormat := c.GetString("response_format")
  110. var aliTaskResponse AliResponse
  111. responseBody, err := io.ReadAll(resp.Body)
  112. if err != nil {
  113. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  114. }
  115. service.CloseResponseBodyGracefully(resp)
  116. err = json.Unmarshal(responseBody, &aliTaskResponse)
  117. if err != nil {
  118. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  119. }
  120. if aliTaskResponse.Message != "" {
  121. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  122. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  123. }
  124. aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
  125. if err != nil {
  126. return types.NewError(err, types.ErrorCodeBadResponse), nil
  127. }
  128. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  129. return types.WithOpenAIError(types.OpenAIError{
  130. Message: aliResponse.Output.Message,
  131. Type: "ali_error",
  132. Param: "",
  133. Code: aliResponse.Output.Code,
  134. }, resp.StatusCode), nil
  135. }
  136. fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
  137. jsonResponse, err := json.Marshal(fullTextResponse)
  138. if err != nil {
  139. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  140. }
  141. c.Writer.Header().Set("Content-Type", "application/json")
  142. c.Writer.WriteHeader(resp.StatusCode)
  143. c.Writer.Write(jsonResponse)
  144. return nil, &dto.Usage{}
  145. }