image.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package ali
import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
  20. func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
  21. var imageRequest AliImageRequest
  22. imageRequest.Model = request.Model
  23. imageRequest.ResponseFormat = request.ResponseFormat
  24. logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
  25. if request.Extra != nil {
  26. if val, ok := request.Extra["parameters"]; ok {
  27. err := common.Unmarshal(val, &imageRequest.Parameters)
  28. if err != nil {
  29. return nil, fmt.Errorf("invalid parameters field: %w", err)
  30. }
  31. }
  32. if val, ok := request.Extra["input"]; ok {
  33. err := common.Unmarshal(val, &imageRequest.Input)
  34. if err != nil {
  35. return nil, fmt.Errorf("invalid input field: %w", err)
  36. }
  37. }
  38. }
  39. if imageRequest.Parameters == nil {
  40. imageRequest.Parameters = AliImageParameters{
  41. Size: strings.Replace(request.Size, "x", "*", -1),
  42. N: int(request.N),
  43. Watermark: request.Watermark,
  44. }
  45. }
  46. if imageRequest.Input == nil {
  47. imageRequest.Input = AliImageInput{
  48. Prompt: request.Prompt,
  49. }
  50. }
  51. return &imageRequest, nil
  52. }
  53. func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
  54. mf := c.Request.MultipartForm
  55. if mf == nil {
  56. if _, err := c.MultipartForm(); err != nil {
  57. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  58. }
  59. mf = c.Request.MultipartForm
  60. }
  61. var imageFiles []*multipart.FileHeader
  62. var exists bool
  63. // First check for standard "image" field
  64. if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
  65. // If not found, check for "image[]" field
  66. if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
  67. // If still not found, iterate through all fields to find any that start with "image["
  68. foundArrayImages := false
  69. for fieldName, files := range mf.File {
  70. if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
  71. foundArrayImages = true
  72. imageFiles = append(imageFiles, files...)
  73. }
  74. }
  75. // If no image fields found at all
  76. if !foundArrayImages && (len(imageFiles) == 0) {
  77. return nil, errors.New("image is required")
  78. }
  79. }
  80. }
  81. if len(imageFiles) == 0 {
  82. return nil, errors.New("image is required")
  83. }
  84. //if len(imageFiles) > 1 {
  85. // return nil, errors.New("only one image is supported for qwen edit")
  86. //}
  87. // 获取base64编码的图片
  88. var imageBase64s []string
  89. for _, file := range imageFiles {
  90. image, err := file.Open()
  91. if err != nil {
  92. return nil, errors.New("failed to open image file")
  93. }
  94. // 读取文件内容
  95. imageData, err := io.ReadAll(image)
  96. if err != nil {
  97. return nil, errors.New("failed to read image file")
  98. }
  99. // 获取MIME类型
  100. mimeType := http.DetectContentType(imageData)
  101. // 编码为base64
  102. base64Data := base64.StdEncoding.EncodeToString(imageData)
  103. // 构造data URL格式
  104. dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
  105. imageBase64s = append(imageBase64s, dataURL)
  106. image.Close()
  107. }
  108. return imageBase64s, nil
  109. }
  110. func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
  111. var imageRequest AliImageRequest
  112. imageRequest.Model = request.Model
  113. imageRequest.ResponseFormat = request.ResponseFormat
  114. imageBase64s, err := getImageBase64sFromForm(c, "image")
  115. if err != nil {
  116. return nil, fmt.Errorf("get image base64s from form failed: %w", err)
  117. }
  118. //dto.MediaContent{}
  119. mediaContents := make([]AliMediaContent, len(imageBase64s))
  120. for i, b64 := range imageBase64s {
  121. mediaContents[i] = AliMediaContent{
  122. Image: b64,
  123. }
  124. }
  125. mediaContents = append(mediaContents, AliMediaContent{
  126. Text: request.Prompt,
  127. })
  128. imageRequest.Input = AliImageInput{
  129. Messages: []AliMessage{
  130. {
  131. Role: "user",
  132. Content: mediaContents,
  133. },
  134. },
  135. }
  136. imageRequest.Parameters = AliImageParameters{
  137. Watermark: request.Watermark,
  138. }
  139. return &imageRequest, nil
  140. }
  141. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  142. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  143. var aliResponse AliResponse
  144. req, err := http.NewRequest("GET", url, nil)
  145. if err != nil {
  146. return &aliResponse, err, nil
  147. }
  148. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  149. client := &http.Client{}
  150. resp, err := client.Do(req)
  151. if err != nil {
  152. common.SysLog("updateTask client.Do err: " + err.Error())
  153. return &aliResponse, err, nil
  154. }
  155. defer resp.Body.Close()
  156. responseBody, err := io.ReadAll(resp.Body)
  157. var response AliResponse
  158. err = common.Unmarshal(responseBody, &response)
  159. if err != nil {
  160. common.SysLog("updateTask NewDecoder err: " + err.Error())
  161. return &aliResponse, err, nil
  162. }
  163. return &response, nil, responseBody
  164. }
  165. func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  166. waitSeconds := 10
  167. step := 0
  168. maxStep := 20
  169. var taskResponse AliResponse
  170. var responseBody []byte
  171. for {
  172. logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
  173. step++
  174. rsp, err, body := updateTask(info, taskID)
  175. responseBody = body
  176. if err != nil {
  177. logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
  178. time.Sleep(time.Duration(waitSeconds) * time.Second)
  179. continue
  180. }
  181. if rsp.Output.TaskStatus == "" {
  182. return &taskResponse, responseBody, nil
  183. }
  184. switch rsp.Output.TaskStatus {
  185. case "FAILED":
  186. fallthrough
  187. case "CANCELED":
  188. fallthrough
  189. case "SUCCEEDED":
  190. fallthrough
  191. case "UNKNOWN":
  192. return rsp, responseBody, nil
  193. }
  194. if step >= maxStep {
  195. break
  196. }
  197. time.Sleep(time.Duration(waitSeconds) * time.Second)
  198. }
  199. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  200. }
  201. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  202. imageResponse := dto.ImageResponse{
  203. Created: info.StartTime.Unix(),
  204. }
  205. for _, data := range response.Output.Results {
  206. var b64Json string
  207. if responseFormat == "b64_json" {
  208. _, b64, err := service.GetImageFromUrl(data.Url)
  209. if err != nil {
  210. logger.LogError(c, "get_image_data_failed: "+err.Error())
  211. continue
  212. }
  213. b64Json = b64
  214. } else {
  215. b64Json = data.B64Image
  216. }
  217. imageResponse.Data = append(imageResponse.Data, dto.ImageData{
  218. Url: data.Url,
  219. B64Json: b64Json,
  220. RevisedPrompt: "",
  221. })
  222. }
  223. var mapResponse map[string]any
  224. _ = common.Unmarshal(originBody, &mapResponse)
  225. imageResponse.Extra = mapResponse
  226. return &imageResponse
  227. }
  228. func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  229. responseFormat := c.GetString("response_format")
  230. var aliTaskResponse AliResponse
  231. responseBody, err := io.ReadAll(resp.Body)
  232. if err != nil {
  233. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  234. }
  235. service.CloseResponseBodyGracefully(resp)
  236. err = common.Unmarshal(responseBody, &aliTaskResponse)
  237. if err != nil {
  238. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  239. }
  240. if aliTaskResponse.Message != "" {
  241. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  242. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  243. }
  244. aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
  245. if err != nil {
  246. return types.NewError(err, types.ErrorCodeBadResponse), nil
  247. }
  248. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  249. return types.WithOpenAIError(types.OpenAIError{
  250. Message: aliResponse.Output.Message,
  251. Type: "ali_error",
  252. Param: "",
  253. Code: aliResponse.Output.Code,
  254. }, resp.StatusCode), nil
  255. }
  256. fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
  257. jsonResponse, err := common.Marshal(fullTextResponse)
  258. if err != nil {
  259. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  260. }
  261. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  262. return nil, &dto.Usage{}
  263. }
  264. func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  265. var aliResponse AliResponse
  266. responseBody, err := io.ReadAll(resp.Body)
  267. if err != nil {
  268. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  269. }
  270. service.CloseResponseBodyGracefully(resp)
  271. err = common.Unmarshal(responseBody, &aliResponse)
  272. if err != nil {
  273. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  274. }
  275. if aliResponse.Message != "" {
  276. logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
  277. return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
  278. }
  279. var fullTextResponse dto.ImageResponse
  280. if len(aliResponse.Output.Choices) > 0 {
  281. fullTextResponse = dto.ImageResponse{
  282. Created: info.StartTime.Unix(),
  283. Data: []dto.ImageData{
  284. {
  285. Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
  286. B64Json: "",
  287. },
  288. },
  289. }
  290. }
  291. var mapResponse map[string]any
  292. _ = common.Unmarshal(responseBody, &mapResponse)
  293. fullTextResponse.Extra = mapResponse
  294. jsonResponse, err := common.Marshal(fullTextResponse)
  295. if err != nil {
  296. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  297. }
  298. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  299. return nil, &dto.Usage{}
  300. }