adaptor.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package hailuo
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/gin-gonic/gin"
  13. "github.com/pkg/errors"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/dto"
  16. "github.com/QuantumNous/new-api/relay/channel"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. "github.com/QuantumNous/new-api/service"
  19. )
  // TaskAdaptor is the task adaptor for the Hailuo (MiniMax) video generation
  // API. Its fields are populated by Init.
  //
  // API reference: https://platform.minimaxi.com/docs/api-reference/video-generation-intro
  type TaskAdaptor struct {
  	// ChannelType mirrors RelayInfo.ChannelType.
  	ChannelType int
  	// apiKey is sent as the Bearer token; baseURL prefixes upstream endpoint paths.
  	apiKey, baseURL string
  }
  26. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  27. a.ChannelType = info.ChannelType
  28. a.baseURL = info.ChannelBaseUrl
  29. a.apiKey = info.ApiKey
  30. }
  31. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  32. return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
  33. }
  34. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  35. return fmt.Sprintf("%s%s", a.baseURL, TextToVideoEndpoint), nil
  36. }
  37. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  38. req.Header.Set("Content-Type", "application/json")
  39. req.Header.Set("Accept", "application/json")
  40. req.Header.Set("Authorization", "Bearer "+a.apiKey)
  41. return nil
  42. }
  43. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  44. v, exists := c.Get("task_request")
  45. if !exists {
  46. return nil, fmt.Errorf("request not found in context")
  47. }
  48. req, ok := v.(relaycommon.TaskSubmitReq)
  49. if !ok {
  50. return nil, fmt.Errorf("invalid request type in context")
  51. }
  52. body, err := a.convertToRequestPayload(&req)
  53. if err != nil {
  54. return nil, errors.Wrap(err, "convert request payload failed")
  55. }
  56. data, err := common.Marshal(body)
  57. if err != nil {
  58. return nil, err
  59. }
  60. return bytes.NewReader(data), nil
  61. }
  62. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  63. return channel.DoTaskApiRequest(a, c, info, requestBody)
  64. }
  65. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  66. responseBody, err := io.ReadAll(resp.Body)
  67. if err != nil {
  68. taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  69. return
  70. }
  71. _ = resp.Body.Close()
  72. var hResp VideoResponse
  73. if err := common.Unmarshal(responseBody, &hResp); err != nil {
  74. taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
  75. return
  76. }
  77. if hResp.BaseResp.StatusCode != StatusSuccess {
  78. taskErr = service.TaskErrorWrapper(
  79. fmt.Errorf("hailuo api error: %s", hResp.BaseResp.StatusMsg),
  80. strconv.Itoa(hResp.BaseResp.StatusCode),
  81. http.StatusBadRequest,
  82. )
  83. return
  84. }
  85. ov := dto.NewOpenAIVideo()
  86. ov.ID = info.PublicTaskID
  87. ov.TaskID = info.PublicTaskID
  88. ov.CreatedAt = time.Now().Unix()
  89. ov.Model = info.OriginModelName
  90. c.JSON(http.StatusOK, ov)
  91. return hResp.TaskID, responseBody, nil
  92. }
  93. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
  94. taskID, ok := body["task_id"].(string)
  95. if !ok {
  96. return nil, fmt.Errorf("invalid task_id")
  97. }
  98. uri := fmt.Sprintf("%s%s?task_id=%s", baseUrl, QueryTaskEndpoint, taskID)
  99. req, err := http.NewRequest(http.MethodGet, uri, nil)
  100. if err != nil {
  101. return nil, err
  102. }
  103. req.Header.Set("Accept", "application/json")
  104. req.Header.Set("Authorization", "Bearer "+key)
  105. client, err := service.GetHttpClientWithProxy(proxy)
  106. if err != nil {
  107. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  108. }
  109. return client.Do(req)
  110. }
  111. func (a *TaskAdaptor) GetModelList() []string {
  112. return ModelList
  113. }
  114. func (a *TaskAdaptor) GetChannelName() string {
  115. return ChannelName
  116. }
  117. func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
  118. modelConfig := GetModelConfig(req.Model)
  119. duration := DefaultDuration
  120. if req.Duration > 0 {
  121. duration = req.Duration
  122. }
  123. resolution := modelConfig.DefaultResolution
  124. if req.Size != "" {
  125. resolution = a.parseResolutionFromSize(req.Size, modelConfig)
  126. }
  127. videoRequest := &VideoRequest{
  128. Model: req.Model,
  129. Prompt: req.Prompt,
  130. Duration: &duration,
  131. Resolution: resolution,
  132. }
  133. if err := req.UnmarshalMetadata(&videoRequest); err != nil {
  134. return nil, errors.Wrap(err, "unmarshal metadata to video request failed")
  135. }
  136. return videoRequest, nil
  137. }
  138. func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConfig) string {
  139. switch {
  140. case strings.Contains(size, "1080"):
  141. return Resolution1080P
  142. case strings.Contains(size, "768"):
  143. return Resolution768P
  144. case strings.Contains(size, "720"):
  145. return Resolution720P
  146. case strings.Contains(size, "512"):
  147. return Resolution512P
  148. default:
  149. return modelConfig.DefaultResolution
  150. }
  151. }
  152. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  153. resTask := QueryTaskResponse{}
  154. if err := common.Unmarshal(respBody, &resTask); err != nil {
  155. return nil, errors.Wrap(err, "unmarshal task result failed")
  156. }
  157. taskResult := relaycommon.TaskInfo{}
  158. if resTask.BaseResp.StatusCode == StatusSuccess {
  159. taskResult.Code = 0
  160. } else {
  161. taskResult.Code = resTask.BaseResp.StatusCode
  162. taskResult.Reason = resTask.BaseResp.StatusMsg
  163. taskResult.Status = model.TaskStatusFailure
  164. taskResult.Progress = "100%"
  165. }
  166. switch resTask.Status {
  167. case TaskStatusPreparing, TaskStatusQueueing, TaskStatusProcessing:
  168. taskResult.Status = model.TaskStatusInProgress
  169. taskResult.Progress = "30%"
  170. if resTask.Status == TaskStatusProcessing {
  171. taskResult.Progress = "50%"
  172. }
  173. case TaskStatusSuccess:
  174. taskResult.Status = model.TaskStatusSuccess
  175. taskResult.Progress = "100%"
  176. taskResult.Url = a.buildVideoURL(resTask.TaskID, resTask.FileID)
  177. case TaskStatusFailed:
  178. taskResult.Status = model.TaskStatusFailure
  179. taskResult.Progress = "100%"
  180. if taskResult.Reason == "" {
  181. taskResult.Reason = "task failed"
  182. }
  183. default:
  184. taskResult.Status = model.TaskStatusInProgress
  185. taskResult.Progress = "30%"
  186. }
  187. return &taskResult, nil
  188. }
  189. func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
  190. var hailuoResp QueryTaskResponse
  191. if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
  192. return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
  193. }
  194. openAIVideo := originTask.ToOpenAIVideo()
  195. if hailuoResp.BaseResp.StatusCode != StatusSuccess {
  196. openAIVideo.Error = &dto.OpenAIVideoError{
  197. Message: hailuoResp.BaseResp.StatusMsg,
  198. Code: strconv.Itoa(hailuoResp.BaseResp.StatusCode),
  199. }
  200. }
  201. jsonData, err := common.Marshal(openAIVideo)
  202. if err != nil {
  203. return nil, errors.Wrap(err, "marshal openai video failed")
  204. }
  205. return jsonData, nil
  206. }
  207. func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
  208. if a.apiKey == "" || a.baseURL == "" {
  209. return ""
  210. }
  211. url := fmt.Sprintf("%s/v1/files/retrieve?file_id=%s", a.baseURL, fileID)
  212. req, err := http.NewRequest(http.MethodGet, url, nil)
  213. if err != nil {
  214. return ""
  215. }
  216. req.Header.Set("Accept", "application/json")
  217. req.Header.Set("Authorization", "Bearer "+a.apiKey)
  218. resp, err := service.GetHttpClient().Do(req)
  219. if err != nil {
  220. return ""
  221. }
  222. defer resp.Body.Close()
  223. responseBody, err := io.ReadAll(resp.Body)
  224. if err != nil {
  225. return ""
  226. }
  227. var retrieveResp RetrieveFileResponse
  228. if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
  229. return ""
  230. }
  231. if retrieveResp.BaseResp.StatusCode != StatusSuccess {
  232. return ""
  233. }
  234. return retrieveResp.File.DownloadURL
  235. }
  // contains reports whether item is an element of slice. A nil or empty slice
  // never contains anything.
  func contains(slice []string, item string) bool {
  	found := false
  	for idx := 0; idx < len(slice) && !found; idx++ {
  		found = slice[idx] == item
  	}
  	return found
  }
  // containsInt reports whether item is an element of slice. A nil or empty
  // slice never contains anything.
  func containsInt(slice []int, item int) bool {
  	found := false
  	for idx := 0; idx < len(slice) && !found; idx++ {
  		found = slice[idx] == item
  	}
  	return found
  }