image.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package ali
import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
  20. func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
  21. var imageRequest AliImageRequest
  22. imageRequest.Model = request.Model
  23. imageRequest.ResponseFormat = request.ResponseFormat
  24. logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
  25. logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
  26. if request.Extra != nil {
  27. if val, ok := request.Extra["parameters"]; ok {
  28. err := common.Unmarshal(val, &imageRequest.Parameters)
  29. if err != nil {
  30. return nil, fmt.Errorf("invalid parameters field: %w", err)
  31. }
  32. } else {
  33. // 兼容没有parameters字段的情况,从openai标准字段中提取参数
  34. imageRequest.Parameters = AliImageParameters{
  35. Size: strings.Replace(request.Size, "x", "*", -1),
  36. N: int(request.N),
  37. Watermark: request.Watermark,
  38. }
  39. }
  40. if val, ok := request.Extra["input"]; ok {
  41. err := common.Unmarshal(val, &imageRequest.Input)
  42. if err != nil {
  43. return nil, fmt.Errorf("invalid input field: %w", err)
  44. }
  45. }
  46. }
  47. if strings.Contains(request.Model, "z-image") {
  48. // z-image 开启prompt_extend后,按2倍计费
  49. if imageRequest.Parameters.PromptExtendValue() {
  50. info.PriceData.AddOtherRatio("prompt_extend", 2)
  51. }
  52. }
  53. // 检查n参数
  54. if imageRequest.Parameters.N != 0 {
  55. info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
  56. }
  57. // 同步图片模型和异步图片模型请求格式不一样
  58. if isSync {
  59. if imageRequest.Input == nil {
  60. imageRequest.Input = AliImageInput{
  61. Messages: []AliMessage{
  62. {
  63. Role: "user",
  64. Content: []AliMediaContent{
  65. {
  66. Text: request.Prompt,
  67. },
  68. },
  69. },
  70. },
  71. }
  72. }
  73. } else {
  74. if imageRequest.Input == nil {
  75. imageRequest.Input = AliImageInput{
  76. Prompt: request.Prompt,
  77. }
  78. }
  79. }
  80. return &imageRequest, nil
  81. }
  82. func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
  83. mf := c.Request.MultipartForm
  84. if mf == nil {
  85. if _, err := c.MultipartForm(); err != nil {
  86. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  87. }
  88. mf = c.Request.MultipartForm
  89. }
  90. var imageFiles []*multipart.FileHeader
  91. var exists bool
  92. // First check for standard "image" field
  93. if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
  94. // If not found, check for "image[]" field
  95. if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
  96. // If still not found, iterate through all fields to find any that start with "image["
  97. foundArrayImages := false
  98. for fieldName, files := range mf.File {
  99. if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
  100. foundArrayImages = true
  101. imageFiles = append(imageFiles, files...)
  102. }
  103. }
  104. // If no image fields found at all
  105. if !foundArrayImages && (len(imageFiles) == 0) {
  106. return nil, errors.New("image is required")
  107. }
  108. }
  109. }
  110. if len(imageFiles) == 0 {
  111. return nil, errors.New("image is required")
  112. }
  113. //if len(imageFiles) > 1 {
  114. // return nil, errors.New("only one image is supported for qwen edit")
  115. //}
  116. // 获取base64编码的图片
  117. var imageBase64s []string
  118. for _, file := range imageFiles {
  119. image, err := file.Open()
  120. if err != nil {
  121. return nil, errors.New("failed to open image file")
  122. }
  123. // 读取文件内容
  124. imageData, err := io.ReadAll(image)
  125. if err != nil {
  126. return nil, errors.New("failed to read image file")
  127. }
  128. // 获取MIME类型
  129. mimeType := http.DetectContentType(imageData)
  130. // 编码为base64
  131. base64Data := base64.StdEncoding.EncodeToString(imageData)
  132. // 构造data URL格式
  133. dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
  134. imageBase64s = append(imageBase64s, dataURL)
  135. image.Close()
  136. }
  137. return imageBase64s, nil
  138. }
  139. func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
  140. var imageRequest AliImageRequest
  141. imageRequest.Model = request.Model
  142. imageRequest.ResponseFormat = request.ResponseFormat
  143. imageBase64s, err := getImageBase64sFromForm(c, "image")
  144. if err != nil {
  145. return nil, fmt.Errorf("get image base64s from form failed: %w", err)
  146. }
  147. //dto.MediaContent{}
  148. mediaContents := make([]AliMediaContent, len(imageBase64s))
  149. for i, b64 := range imageBase64s {
  150. mediaContents[i] = AliMediaContent{
  151. Image: b64,
  152. }
  153. }
  154. mediaContents = append(mediaContents, AliMediaContent{
  155. Text: request.Prompt,
  156. })
  157. imageRequest.Input = AliImageInput{
  158. Messages: []AliMessage{
  159. {
  160. Role: "user",
  161. Content: mediaContents,
  162. },
  163. },
  164. }
  165. imageRequest.Parameters = AliImageParameters{
  166. Watermark: request.Watermark,
  167. }
  168. return &imageRequest, nil
  169. }
  170. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  171. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  172. var aliResponse AliResponse
  173. req, err := http.NewRequest("GET", url, nil)
  174. if err != nil {
  175. return &aliResponse, err, nil
  176. }
  177. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  178. client := &http.Client{}
  179. resp, err := client.Do(req)
  180. if err != nil {
  181. common.SysLog("updateTask client.Do err: " + err.Error())
  182. return &aliResponse, err, nil
  183. }
  184. defer resp.Body.Close()
  185. responseBody, err := io.ReadAll(resp.Body)
  186. var response AliResponse
  187. err = common.Unmarshal(responseBody, &response)
  188. if err != nil {
  189. common.SysLog("updateTask NewDecoder err: " + err.Error())
  190. return &aliResponse, err, nil
  191. }
  192. return &response, nil, responseBody
  193. }
  194. func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  195. waitSeconds := 10
  196. step := 0
  197. maxStep := 20
  198. var taskResponse AliResponse
  199. var responseBody []byte
  200. time.Sleep(time.Duration(5) * time.Second)
  201. for {
  202. logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
  203. step++
  204. rsp, err, body := updateTask(info, taskID)
  205. responseBody = body
  206. if err != nil {
  207. logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
  208. time.Sleep(time.Duration(waitSeconds) * time.Second)
  209. continue
  210. }
  211. if rsp.Output.TaskStatus == "" {
  212. return &taskResponse, responseBody, nil
  213. }
  214. switch rsp.Output.TaskStatus {
  215. case "FAILED":
  216. fallthrough
  217. case "CANCELED":
  218. fallthrough
  219. case "SUCCEEDED":
  220. fallthrough
  221. case "UNKNOWN":
  222. return rsp, responseBody, nil
  223. }
  224. if step >= maxStep {
  225. break
  226. }
  227. time.Sleep(time.Duration(waitSeconds) * time.Second)
  228. }
  229. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  230. }
  231. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  232. imageResponse := dto.ImageResponse{
  233. Created: info.StartTime.Unix(),
  234. }
  235. if len(response.Output.Results) > 0 {
  236. imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
  237. } else if len(response.Output.Choices) > 0 {
  238. imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
  239. }
  240. imageResponse.Metadata = originBody
  241. return &imageResponse
  242. }
  243. func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  244. responseFormat := c.GetString("response_format")
  245. var aliTaskResponse AliResponse
  246. responseBody, err := io.ReadAll(resp.Body)
  247. if err != nil {
  248. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  249. }
  250. service.CloseResponseBodyGracefully(resp)
  251. err = common.Unmarshal(responseBody, &aliTaskResponse)
  252. if err != nil {
  253. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  254. }
  255. if aliTaskResponse.Message != "" {
  256. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  257. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  258. }
  259. var (
  260. aliResponse *AliResponse
  261. originRespBody []byte
  262. )
  263. if a.IsSyncImageModel {
  264. aliResponse = &aliTaskResponse
  265. originRespBody = responseBody
  266. } else {
  267. // 异步图片模型需要轮询任务结果
  268. aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
  269. if err != nil {
  270. return types.NewError(err, types.ErrorCodeBadResponse), nil
  271. }
  272. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  273. return types.WithOpenAIError(types.OpenAIError{
  274. Message: aliResponse.Output.Message,
  275. Type: "ali_error",
  276. Param: "",
  277. Code: aliResponse.Output.Code,
  278. }, resp.StatusCode), nil
  279. }
  280. }
  281. //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
  282. if a.IsSyncImageModel {
  283. logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
  284. } else {
  285. logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
  286. }
  287. imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
  288. // 可能生成多张图片,修正计费数量n
  289. if aliResponse.Usage.ImageCount != 0 {
  290. info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
  291. } else if len(imageResponses.Data) != 0 {
  292. info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
  293. }
  294. jsonResponse, err := common.Marshal(imageResponses)
  295. if err != nil {
  296. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  297. }
  298. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  299. return nil, &dto.Usage{}
  300. }