image.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. package ali
import (
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
  19. func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
  20. var imageRequest AliImageRequest
  21. imageRequest.Model = request.Model
  22. imageRequest.ResponseFormat = request.ResponseFormat
  23. if request.Extra != nil {
  24. if val, ok := request.Extra["parameters"]; ok {
  25. err := common.Unmarshal(val, &imageRequest.Parameters)
  26. if err != nil {
  27. return nil, fmt.Errorf("invalid parameters field: %w", err)
  28. }
  29. } else {
  30. // 兼容没有parameters字段的情况,从openai标准字段中提取参数
  31. imageRequest.Parameters = AliImageParameters{
  32. Size: strings.Replace(request.Size, "x", "*", -1),
  33. N: int(request.N),
  34. Watermark: request.Watermark,
  35. }
  36. }
  37. if val, ok := request.Extra["input"]; ok {
  38. err := common.Unmarshal(val, &imageRequest.Input)
  39. if err != nil {
  40. return nil, fmt.Errorf("invalid input field: %w", err)
  41. }
  42. }
  43. }
  44. if strings.Contains(request.Model, "z-image") {
  45. // z-image 开启prompt_extend后,按2倍计费
  46. if imageRequest.Parameters.PromptExtendValue() {
  47. info.PriceData.AddOtherRatio("prompt_extend", 2)
  48. }
  49. }
  50. // 检查n参数
  51. if imageRequest.Parameters.N != 0 {
  52. info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
  53. }
  54. // 同步图片模型和异步图片模型请求格式不一样
  55. if isSync {
  56. if imageRequest.Input == nil {
  57. imageRequest.Input = AliImageInput{
  58. Messages: []AliMessage{
  59. {
  60. Role: "user",
  61. Content: []AliMediaContent{
  62. {
  63. Text: request.Prompt,
  64. },
  65. },
  66. },
  67. },
  68. }
  69. }
  70. } else {
  71. if imageRequest.Input == nil {
  72. imageRequest.Input = AliImageInput{
  73. Prompt: request.Prompt,
  74. }
  75. }
  76. }
  77. return &imageRequest, nil
  78. }
  79. func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
  80. mf := c.Request.MultipartForm
  81. if mf == nil {
  82. if _, err := c.MultipartForm(); err != nil {
  83. return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
  84. }
  85. mf = c.Request.MultipartForm
  86. }
  87. var imageFiles []*multipart.FileHeader
  88. var exists bool
  89. // First check for standard "image" field
  90. if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
  91. // If not found, check for "image[]" field
  92. if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
  93. // If still not found, iterate through all fields to find any that start with "image["
  94. foundArrayImages := false
  95. for fieldName, files := range mf.File {
  96. if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
  97. foundArrayImages = true
  98. imageFiles = append(imageFiles, files...)
  99. }
  100. }
  101. // If no image fields found at all
  102. if !foundArrayImages && (len(imageFiles) == 0) {
  103. return nil, errors.New("image is required")
  104. }
  105. }
  106. }
  107. if len(imageFiles) == 0 {
  108. return nil, errors.New("image is required")
  109. }
  110. //if len(imageFiles) > 1 {
  111. // return nil, errors.New("only one image is supported for qwen edit")
  112. //}
  113. // 获取base64编码的图片
  114. var imageBase64s []string
  115. for _, file := range imageFiles {
  116. image, err := file.Open()
  117. if err != nil {
  118. return nil, errors.New("failed to open image file")
  119. }
  120. // 读取文件内容
  121. imageData, err := io.ReadAll(image)
  122. if err != nil {
  123. return nil, errors.New("failed to read image file")
  124. }
  125. // 获取MIME类型
  126. mimeType := http.DetectContentType(imageData)
  127. // 编码为base64
  128. base64Data := base64.StdEncoding.EncodeToString(imageData)
  129. // 构造data URL格式
  130. dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
  131. imageBase64s = append(imageBase64s, dataURL)
  132. image.Close()
  133. }
  134. return imageBase64s, nil
  135. }
  136. func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
  137. var imageRequest AliImageRequest
  138. imageRequest.Model = request.Model
  139. imageRequest.ResponseFormat = request.ResponseFormat
  140. imageBase64s, err := getImageBase64sFromForm(c, "image")
  141. if err != nil {
  142. return nil, fmt.Errorf("get image base64s from form failed: %w", err)
  143. }
  144. //dto.MediaContent{}
  145. mediaContents := make([]AliMediaContent, len(imageBase64s))
  146. for i, b64 := range imageBase64s {
  147. mediaContents[i] = AliMediaContent{
  148. Image: b64,
  149. }
  150. }
  151. mediaContents = append(mediaContents, AliMediaContent{
  152. Text: request.Prompt,
  153. })
  154. imageRequest.Input = AliImageInput{
  155. Messages: []AliMessage{
  156. {
  157. Role: "user",
  158. Content: mediaContents,
  159. },
  160. },
  161. }
  162. imageRequest.Parameters = AliImageParameters{
  163. Watermark: request.Watermark,
  164. }
  165. return &imageRequest, nil
  166. }
  167. func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
  168. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
  169. var aliResponse AliResponse
  170. req, err := http.NewRequest("GET", url, nil)
  171. if err != nil {
  172. return &aliResponse, err, nil
  173. }
  174. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  175. client := &http.Client{}
  176. resp, err := client.Do(req)
  177. if err != nil {
  178. common.SysLog("updateTask client.Do err: " + err.Error())
  179. return &aliResponse, err, nil
  180. }
  181. defer resp.Body.Close()
  182. responseBody, err := io.ReadAll(resp.Body)
  183. var response AliResponse
  184. err = common.Unmarshal(responseBody, &response)
  185. if err != nil {
  186. common.SysLog("updateTask NewDecoder err: " + err.Error())
  187. return &aliResponse, err, nil
  188. }
  189. return &response, nil, responseBody
  190. }
  191. func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
  192. waitSeconds := 10
  193. step := 0
  194. maxStep := 20
  195. var taskResponse AliResponse
  196. var responseBody []byte
  197. time.Sleep(time.Duration(5) * time.Second)
  198. for {
  199. logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
  200. step++
  201. rsp, err, body := updateTask(info, taskID)
  202. responseBody = body
  203. if err != nil {
  204. logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
  205. time.Sleep(time.Duration(waitSeconds) * time.Second)
  206. continue
  207. }
  208. if rsp.Output.TaskStatus == "" {
  209. return &taskResponse, responseBody, nil
  210. }
  211. switch rsp.Output.TaskStatus {
  212. case "FAILED":
  213. fallthrough
  214. case "CANCELED":
  215. fallthrough
  216. case "SUCCEEDED":
  217. fallthrough
  218. case "UNKNOWN":
  219. return rsp, responseBody, nil
  220. }
  221. if step >= maxStep {
  222. break
  223. }
  224. time.Sleep(time.Duration(waitSeconds) * time.Second)
  225. }
  226. return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
  227. }
  228. func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
  229. imageResponse := dto.ImageResponse{
  230. Created: info.StartTime.Unix(),
  231. }
  232. if len(response.Output.Results) > 0 {
  233. imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
  234. } else if len(response.Output.Choices) > 0 {
  235. imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
  236. }
  237. imageResponse.Metadata = originBody
  238. return &imageResponse
  239. }
  240. func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
  241. responseFormat := c.GetString("response_format")
  242. var aliTaskResponse AliResponse
  243. responseBody, err := io.ReadAll(resp.Body)
  244. if err != nil {
  245. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  246. }
  247. service.CloseResponseBodyGracefully(resp)
  248. err = common.Unmarshal(responseBody, &aliTaskResponse)
  249. if err != nil {
  250. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  251. }
  252. if aliTaskResponse.Message != "" {
  253. logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
  254. return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
  255. }
  256. var (
  257. aliResponse *AliResponse
  258. originRespBody []byte
  259. )
  260. if a.IsSyncImageModel {
  261. aliResponse = &aliTaskResponse
  262. originRespBody = responseBody
  263. } else {
  264. // 异步图片模型需要轮询任务结果
  265. aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
  266. if err != nil {
  267. return types.NewError(err, types.ErrorCodeBadResponse), nil
  268. }
  269. if aliResponse.Output.TaskStatus != "SUCCEEDED" {
  270. return types.WithOpenAIError(types.OpenAIError{
  271. Message: aliResponse.Output.Message,
  272. Type: "ali_error",
  273. Param: "",
  274. Code: aliResponse.Output.Code,
  275. }, resp.StatusCode), nil
  276. }
  277. }
  278. //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
  279. if a.IsSyncImageModel {
  280. logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
  281. } else {
  282. logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
  283. }
  284. imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
  285. // 可能生成多张图片,修正计费数量n
  286. if aliResponse.Usage.ImageCount != 0 {
  287. info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
  288. } else if len(imageResponses.Data) != 0 {
  289. info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
  290. }
  291. jsonResponse, err := common.Marshal(imageResponses)
  292. if err != nil {
  293. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  294. }
  295. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  296. return nil, &dto.Usage{}
  297. }