image.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package ali
  2. import (
  3. "encoding/base64"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "mime/multipart"
  8. "net/http"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/logger"
  14. relaycommon "github.com/QuantumNous/new-api/relay/common"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/gin-gonic/gin"
  18. "github.com/samber/lo"
  19. )
  20. func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
  21. var imageRequest AliImageRequest
  22. imageRequest.Model = request.Model
  23. imageRequest.ResponseFormat = request.ResponseFormat
  24. if request.Extra != nil {
  25. if val, ok := request.Extra["parameters"]; ok {
  26. err := common.Unmarshal(val, &imageRequest.Parameters)
  27. if err != nil {
  28. return nil, fmt.Errorf("invalid parameters field: %w", err)
  29. }
  30. } else {
  31. // 兼容没有parameters字段的情况,从openai标准字段中提取参数
  32. imageRequest.Parameters = AliImageParameters{
  33. Size: strings.Replace(request.Size, "x", "*", -1),
  34. N: int(lo.FromPtrOr(request.N, uint(1))),
  35. Watermark: request.Watermark,
  36. }
  37. }
  38. if val, ok := request.Extra["input"]; ok {
  39. err := common.Unmarshal(val, &imageRequest.Input)
  40. if err != nil {
  41. return nil, fmt.Errorf("invalid input field: %w", err)
  42. }
  43. }
  44. }
  45. if strings.Contains(request.Model, "z-image") {
  46. // z-image 开启prompt_extend后,按2倍计费
  47. if imageRequest.Parameters.PromptExtendValue() {
  48. info.PriceData.AddOtherRatio("prompt_extend", 2)
  49. }
  50. }
  51. // 检查n参数
  52. if imageRequest.Parameters.N != 0 {
  53. info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
  54. }
  55. // 同步图片模型和异步图片模型请求格式不一样
  56. if isSync {
  57. if imageRequest.Input == nil {
  58. imageRequest.Input = AliImageInput{
  59. Messages: []AliMessage{
  60. {
  61. Role: "user",
  62. Content: []AliMediaContent{
  63. {
  64. Text: request.Prompt,
  65. },
  66. },
  67. },
  68. },
  69. }
  70. }
  71. } else {
  72. if imageRequest.Input == nil {
  73. imageRequest.Input = AliImageInput{
  74. Prompt: request.Prompt,
  75. }
  76. }
  77. }
  78. return &imageRequest, nil
  79. }
  80. func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
  81. mf := c.Request.MultipartForm
  82. if mf == nil {
  83. if _, err := c.MultipartForm(); err != nil {
  84. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  85. }
  86. mf = c.Request.MultipartForm
  87. }
  88. var imageFiles []*multipart.FileHeader
  89. var exists bool
  90. // First check for standard "image" field
  91. if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
  92. // If not found, check for "image[]" field
  93. if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
  94. // If still not found, iterate through all fields to find any that start with "image["
  95. foundArrayImages := false
  96. for fieldName, files := range mf.File {
  97. if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
  98. foundArrayImages = true
  99. imageFiles = append(imageFiles, files...)
  100. }
  101. }
  102. // If no image fields found at all
  103. if !foundArrayImages && (len(imageFiles) == 0) {
  104. return nil, errors.New("image is required")
  105. }
  106. }
  107. }
  108. if len(imageFiles) == 0 {
  109. return nil, errors.New("image is required")
  110. }
  111. //if len(imageFiles) > 1 {
  112. // return nil, errors.New("only one image is supported for qwen edit")
  113. //}
  114. // 获取base64编码的图片
  115. var imageBase64s []string
  116. for _, file := range imageFiles {
  117. image, err := file.Open()
  118. if err != nil {
  119. return nil, errors.New("failed to open image file")
  120. }
  121. // 读取文件内容
  122. imageData, err := io.ReadAll(image)
  123. if err != nil {
  124. return nil, errors.New("failed to read image file")
  125. }
  126. // 获取MIME类型
  127. mimeType := http.DetectContentType(imageData)
  128. // 编码为base64
  129. base64Data := base64.StdEncoding.EncodeToString(imageData)
  130. // 构造data URL格式
  131. dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
  132. imageBase64s = append(imageBase64s, dataURL)
  133. image.Close()
  134. }
  135. return imageBase64s, nil
  136. }
  137. func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
  138. var imageRequest AliImageRequest
  139. imageRequest.Model = request.Model
  140. imageRequest.ResponseFormat = request.ResponseFormat
  141. imageBase64s, err := getImageBase64sFromForm(c, "image")
  142. if err != nil {
  143. return nil, fmt.Errorf("get image base64s from form failed: %w", err)
  144. }
  145. //dto.MediaContent{}
  146. mediaContents := make([]AliMediaContent, len(imageBase64s))
  147. for i, b64 := range imageBase64s {
  148. mediaContents[i] = AliMediaContent{
  149. Image: b64,
  150. }
  151. }
  152. mediaContents = append(mediaContents, AliMediaContent{
  153. Text: request.Prompt,
  154. })
  155. imageRequest.Input = AliImageInput{
  156. Messages: []AliMessage{
  157. {
  158. Role: "user",
  159. Content: mediaContents,
  160. },
  161. },
  162. }
  163. imageRequest.Parameters = AliImageParameters{
  164. Watermark: request.Watermark,
  165. }
  166. return &imageRequest, nil
  167. }
  168. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  169. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  170. var aliResponse AliResponse
  171. req, err := http.NewRequest("GET", url, nil)
  172. if err != nil {
  173. return &aliResponse, err, nil
  174. }
  175. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  176. client := &http.Client{}
  177. resp, err := client.Do(req)
  178. if err != nil {
  179. common.SysLog("updateTask client.Do err: " + err.Error())
  180. return &aliResponse, err, nil
  181. }
  182. defer resp.Body.Close()
  183. responseBody, err := io.ReadAll(resp.Body)
  184. var response AliResponse
  185. err = common.Unmarshal(responseBody, &response)
  186. if err != nil {
  187. common.SysLog("updateTask NewDecoder err: " + err.Error())
  188. return &aliResponse, err, nil
  189. }
  190. return &response, nil, responseBody
  191. }
  192. func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  193. waitSeconds := 10
  194. step := 0
  195. maxStep := 20
  196. var taskResponse AliResponse
  197. var responseBody []byte
  198. time.Sleep(time.Duration(5) * time.Second)
  199. for {
  200. logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
  201. step++
  202. rsp, err, body := updateTask(info, taskID)
  203. responseBody = body
  204. if err != nil {
  205. logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
  206. time.Sleep(time.Duration(waitSeconds) * time.Second)
  207. continue
  208. }
  209. if rsp.Output.TaskStatus == "" {
  210. return &taskResponse, responseBody, nil
  211. }
  212. switch rsp.Output.TaskStatus {
  213. case "FAILED":
  214. fallthrough
  215. case "CANCELED":
  216. fallthrough
  217. case "SUCCEEDED":
  218. fallthrough
  219. case "UNKNOWN":
  220. return rsp, responseBody, nil
  221. }
  222. if step >= maxStep {
  223. break
  224. }
  225. time.Sleep(time.Duration(waitSeconds) * time.Second)
  226. }
  227. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  228. }
  229. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  230. imageResponse := dto.ImageResponse{
  231. Created: info.StartTime.Unix(),
  232. }
  233. if len(response.Output.Results) > 0 {
  234. imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
  235. } else if len(response.Output.Choices) > 0 {
  236. imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
  237. }
  238. imageResponse.Metadata = originBody
  239. return &imageResponse
  240. }
  241. func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  242. responseFormat := c.GetString("response_format")
  243. var aliTaskResponse AliResponse
  244. responseBody, err := io.ReadAll(resp.Body)
  245. if err != nil {
  246. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  247. }
  248. service.CloseResponseBodyGracefully(resp)
  249. err = common.Unmarshal(responseBody, &aliTaskResponse)
  250. if err != nil {
  251. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  252. }
  253. if aliTaskResponse.Message != "" {
  254. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  255. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  256. }
  257. var (
  258. aliResponse *AliResponse
  259. originRespBody []byte
  260. )
  261. if a.IsSyncImageModel {
  262. aliResponse = &aliTaskResponse
  263. originRespBody = responseBody
  264. } else {
  265. // 异步图片模型需要轮询任务结果
  266. aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
  267. if err != nil {
  268. return types.NewError(err, types.ErrorCodeBadResponse), nil
  269. }
  270. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  271. return types.WithOpenAIError(types.OpenAIError{
  272. Message: aliResponse.Output.Message,
  273. Type: "ali_error",
  274. Param: "",
  275. Code: aliResponse.Output.Code,
  276. }, resp.StatusCode), nil
  277. }
  278. }
  279. //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
  280. if a.IsSyncImageModel {
  281. logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
  282. } else {
  283. logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
  284. }
  285. imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
  286. // 可能生成多张图片,修正计费数量n
  287. if aliResponse.Usage.ImageCount != 0 {
  288. info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
  289. } else if len(imageResponses.Data) != 0 {
  290. info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
  291. }
  292. jsonResponse, err := common.Marshal(imageResponses)
  293. if err != nil {
  294. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  295. }
  296. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  297. return nil, &dto.Usage{}
  298. }