adaptor.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. package ali
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/relay/channel"
  14. relaycommon "github.com/QuantumNous/new-api/relay/common"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/gin-gonic/gin"
  17. "github.com/pkg/errors"
  18. )
  19. // ============================
  20. // Request / Response structures
  21. // ============================
// AliVideoRequest is the JSON request body for an Ali Tongyi Wanxiang video
// generation task.
type AliVideoRequest struct {
	Model      string              `json:"model"`
	Input      AliVideoInput       `json:"input"`
	Parameters *AliVideoParameters `json:"parameters,omitempty"`
}
// AliVideoInput holds the input section of a video generation request.
type AliVideoInput struct {
	Prompt         string `json:"prompt,omitempty"`          // text prompt
	ImgURL         string `json:"img_url,omitempty"`         // first-frame image URL or Base64 (image-to-video)
	FirstFrameURL  string `json:"first_frame_url,omitempty"` // first-frame image URL (first/last-frame video)
	LastFrameURL   string `json:"last_frame_url,omitempty"`  // last-frame image URL (first/last-frame video)
	AudioURL       string `json:"audio_url,omitempty"`       // audio URL (supported by wan2.5)
	NegativePrompt string `json:"negative_prompt,omitempty"` // negative prompt
	Template       string `json:"template,omitempty"`        // video effect template
}
// AliVideoParameters holds the generation parameters of a video request.
//
// NOTE(review): every non-pointer field carries omitempty, so false and 0 are
// not serialized (assuming common.Marshal follows encoding/json tag
// semantics). A caller therefore cannot send prompt_extend=false, duration=0
// or seed=0 explicitly — confirm that this is intended.
type AliVideoParameters struct {
	Resolution   string `json:"resolution,omitempty"`    // resolution: 480P/720P/1080P (image-to-video, first/last-frame video)
	Size         string `json:"size,omitempty"`          // size, e.g. "832*480" (text-to-video)
	Duration     int    `json:"duration,omitempty"`      // duration: 3-10 seconds
	PromptExtend bool   `json:"prompt_extend,omitempty"` // whether to enable smart prompt rewriting
	Watermark    bool   `json:"watermark,omitempty"`     // whether to add a watermark
	Audio        *bool  `json:"audio,omitempty"`         // whether to add audio (wan2.5)
	Seed         int    `json:"seed,omitempty"`          // random seed
}
// AliVideoResponse is the JSON envelope returned by the Ali Tongyi Wanxiang
// API. Code and Message are the top-level error fields; the task itself is
// described by Output.
type AliVideoResponse struct {
	Output    AliVideoOutput `json:"output"`
	RequestID string         `json:"request_id"`
	Code      string         `json:"code,omitempty"`
	Message   string         `json:"message,omitempty"`
	Usage     *AliUsage      `json:"usage,omitempty"`
}
// AliVideoOutput is the "output" object of a response: the task identifier,
// its status, timing and prompt information, the resulting video URL, and a
// task-level error code and message.
type AliVideoOutput struct {
	TaskID        string `json:"task_id"`
	TaskStatus    string `json:"task_status"`
	SubmitTime    string `json:"submit_time,omitempty"`
	ScheduledTime string `json:"scheduled_time,omitempty"`
	EndTime       string `json:"end_time,omitempty"`
	OrigPrompt    string `json:"orig_prompt,omitempty"`
	ActualPrompt  string `json:"actual_prompt,omitempty"`
	VideoURL      string `json:"video_url,omitempty"`
	Code          string `json:"code,omitempty"`
	Message       string `json:"message,omitempty"`
}
// AliUsage carries the usage statistics reported in a response.
type AliUsage struct {
	Duration   int `json:"duration,omitempty"`
	VideoCount int `json:"video_count,omitempty"`
	SR         int `json:"SR,omitempty"`
}
// AliMetadata lists optional request fields in a flat layout: input-related
// fields first, then parameter-related fields (pointers, presumably so that
// "not provided" can be told apart from a zero value).
//
// NOTE(review): no code visible in this chunk references this type — confirm
// where it is used.
type AliMetadata struct {
	// Input-related fields.
	AudioURL       string `json:"audio_url,omitempty"`       // audio URL
	ImgURL         string `json:"img_url,omitempty"`         // image URL (image-to-video)
	FirstFrameURL  string `json:"first_frame_url,omitempty"` // first-frame image URL (first/last-frame video)
	LastFrameURL   string `json:"last_frame_url,omitempty"`  // last-frame image URL (first/last-frame video)
	NegativePrompt string `json:"negative_prompt,omitempty"` // negative prompt
	Template       string `json:"template,omitempty"`        // video effect template
	// Parameter-related fields.
	Resolution   *string `json:"resolution,omitempty"`    // resolution: 480P/720P/1080P
	Size         *string `json:"size,omitempty"`          // size, e.g. "832*480"
	Duration     *int    `json:"duration,omitempty"`      // duration
	PromptExtend *bool   `json:"prompt_extend,omitempty"` // whether to enable smart prompt rewriting
	Watermark    *bool   `json:"watermark,omitempty"`     // whether to add a watermark
	Audio        *bool   `json:"audio,omitempty"`         // whether to add audio
	Seed         *int    `json:"seed,omitempty"`          // random seed
}
  92. // ============================
  93. // Adaptor implementation
  94. // ============================
// TaskAdaptor is the task adaptor for Ali video generation. Init fills in the
// channel fields; ValidateRequestAndSetAction stores the converted upstream
// request in aliReq, which BuildRequestBody later serializes.
type TaskAdaptor struct {
	ChannelType int
	apiKey      string
	baseURL     string
	aliReq      *AliVideoRequest // upstream request prepared during validation
}
  101. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  102. a.ChannelType = info.ChannelType
  103. a.baseURL = info.ChannelBaseUrl
  104. a.apiKey = info.ApiKey
  105. }
  106. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  107. // 阿里通义万相支持 JSON 格式,不使用 multipart
  108. var taskReq relaycommon.TaskSubmitReq
  109. if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
  110. return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
  111. }
  112. aliReq, err := a.convertToAliRequest(info, taskReq)
  113. if err != nil {
  114. return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
  115. }
  116. a.aliReq = aliReq
  117. logger.LogJson(c, "ali video request body", aliReq)
  118. return relaycommon.ValidateMultipartDirect(c, info)
  119. }
  120. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  121. return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil
  122. }
  123. // BuildRequestHeader sets required headers for Ali API
  124. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  125. req.Header.Set("Authorization", "Bearer "+a.apiKey)
  126. req.Header.Set("Content-Type", "application/json")
  127. req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置
  128. return nil
  129. }
  130. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  131. bodyBytes, err := common.Marshal(a.aliReq)
  132. if err != nil {
  133. return nil, errors.Wrap(err, "marshal_ali_request_failed")
  134. }
  135. return bytes.NewReader(bodyBytes), nil
  136. }
  137. func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
  138. otherRatios := map[string]map[string]float64{
  139. "wan2.5-i2v-preview": {
  140. "480P": 1,
  141. "720P": 2,
  142. "1080P": 1 / 0.3,
  143. },
  144. "wan2.2-i2v-plus": {
  145. "480P": 1,
  146. "1080P": 0.7 / 0.14,
  147. },
  148. "wan2.2-kf2v-flash": {
  149. "480P": 1,
  150. "720P": 2,
  151. "1080P": 4.8,
  152. },
  153. "wan2.2-i2v-flash": {
  154. "480P": 1,
  155. "720P": 2,
  156. },
  157. "wan2.2-s2v": {
  158. "480P": 1,
  159. "720P": 0.9 / 0.5,
  160. },
  161. }
  162. aliReq := &AliVideoRequest{
  163. Model: req.Model,
  164. Input: AliVideoInput{
  165. Prompt: req.Prompt,
  166. ImgURL: req.InputReference,
  167. },
  168. Parameters: &AliVideoParameters{
  169. PromptExtend: true, // 默认开启智能改写
  170. Watermark: false,
  171. },
  172. }
  173. // 处理分辨率映射
  174. if req.Size != "" {
  175. resolution := strings.ToUpper(req.Size)
  176. // 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
  177. if !strings.HasSuffix(resolution, "P") {
  178. resolution = resolution + "P"
  179. }
  180. aliReq.Parameters.Resolution = resolution
  181. } else {
  182. // 根据模型设置默认分辨率
  183. if strings.HasPrefix(req.Model, "wan2.5") {
  184. aliReq.Parameters.Resolution = "1080P"
  185. } else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
  186. aliReq.Parameters.Resolution = "720P"
  187. } else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
  188. aliReq.Parameters.Resolution = "1080P"
  189. } else {
  190. aliReq.Parameters.Resolution = "720P"
  191. }
  192. }
  193. // 处理时长
  194. if req.Duration > 0 {
  195. aliReq.Parameters.Duration = req.Duration
  196. } else if req.Seconds != "" {
  197. seconds, err := strconv.Atoi(req.Seconds)
  198. if err != nil {
  199. return nil, errors.Wrap(err, "convert seconds to int failed")
  200. } else {
  201. aliReq.Parameters.Duration = seconds
  202. }
  203. } else {
  204. aliReq.Parameters.Duration = 5 // 默认5秒
  205. }
  206. // 从 metadata 中提取额外参数
  207. if req.Metadata != nil {
  208. if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
  209. err = common.Unmarshal(metadataBytes, aliReq)
  210. if err != nil {
  211. return nil, errors.Wrap(err, "unmarshal metadata failed")
  212. }
  213. } else {
  214. return nil, errors.Wrap(err, "marshal metadata failed")
  215. }
  216. }
  217. if aliReq.Model != req.Model {
  218. return nil, errors.New("can't change model with metadata")
  219. }
  220. info.PriceData.OtherRatios = map[string]float64{
  221. "seconds": float64(aliReq.Parameters.Duration),
  222. }
  223. if otherRatio, ok := otherRatios[req.Model]; ok {
  224. if ratio, ok := otherRatio[aliReq.Parameters.Resolution]; ok {
  225. info.PriceData.OtherRatios[fmt.Sprintf("resolution-%s", aliReq.Parameters.Resolution)] = ratio
  226. }
  227. }
  228. // println(fmt.Sprintf("other ratios: %v", info.PriceData.OtherRatios))
  229. return aliReq, nil
  230. }
// DoRequest sends the prepared request by delegating to
// channel.DoTaskApiRequest, passing the adaptor itself along with the request
// context, relay info and body.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	return channel.DoTaskApiRequest(a, c, info, requestBody)
}
  235. // DoResponse handles upstream response
  236. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  237. responseBody, err := io.ReadAll(resp.Body)
  238. if err != nil {
  239. taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  240. return
  241. }
  242. _ = resp.Body.Close()
  243. // 解析阿里响应
  244. var aliResp AliVideoResponse
  245. if err := common.Unmarshal(responseBody, &aliResp); err != nil {
  246. taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
  247. return
  248. }
  249. // 检查错误
  250. if aliResp.Code != "" {
  251. taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode)
  252. return
  253. }
  254. if aliResp.Output.TaskID == "" {
  255. taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
  256. return
  257. }
  258. // 转换为 OpenAI 格式响应
  259. openAIResp := dto.NewOpenAIVideo()
  260. openAIResp.ID = aliResp.Output.TaskID
  261. openAIResp.Model = c.GetString("model")
  262. if openAIResp.Model == "" && info != nil {
  263. openAIResp.Model = info.OriginModelName
  264. }
  265. openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
  266. openAIResp.CreatedAt = common.GetTimestamp()
  267. // 返回 OpenAI 格式
  268. c.JSON(http.StatusOK, openAIResp)
  269. return aliResp.Output.TaskID, responseBody, nil
  270. }
  271. // FetchTask 查询任务状态
  272. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
  273. taskID, ok := body["task_id"].(string)
  274. if !ok {
  275. return nil, fmt.Errorf("invalid task_id")
  276. }
  277. uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID)
  278. req, err := http.NewRequest(http.MethodGet, uri, nil)
  279. if err != nil {
  280. return nil, err
  281. }
  282. req.Header.Set("Authorization", "Bearer "+key)
  283. return service.GetHttpClient().Do(req)
  284. }
// GetModelList returns the package-level ModelList (declared outside this
// chunk).
func (a *TaskAdaptor) GetModelList() []string {
	return ModelList
}
// GetChannelName returns the package-level ChannelName (declared outside this
// chunk).
func (a *TaskAdaptor) GetChannelName() string {
	return ChannelName
}
  291. // ParseTaskResult 解析任务结果
  292. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  293. var aliResp AliVideoResponse
  294. if err := common.Unmarshal(respBody, &aliResp); err != nil {
  295. return nil, errors.Wrap(err, "unmarshal task result failed")
  296. }
  297. taskResult := relaycommon.TaskInfo{
  298. Code: 0,
  299. }
  300. // 状态映射
  301. switch aliResp.Output.TaskStatus {
  302. case "PENDING":
  303. taskResult.Status = model.TaskStatusQueued
  304. case "RUNNING":
  305. taskResult.Status = model.TaskStatusInProgress
  306. case "SUCCEEDED":
  307. taskResult.Status = model.TaskStatusSuccess
  308. // 阿里直接返回视频URL,不需要额外的代理端点
  309. taskResult.Url = aliResp.Output.VideoURL
  310. case "FAILED", "CANCELED", "UNKNOWN":
  311. taskResult.Status = model.TaskStatusFailure
  312. if aliResp.Message != "" {
  313. taskResult.Reason = aliResp.Message
  314. } else if aliResp.Output.Message != "" {
  315. taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message)
  316. } else {
  317. taskResult.Reason = "task failed"
  318. }
  319. default:
  320. taskResult.Status = model.TaskStatusQueued
  321. }
  322. return &taskResult, nil
  323. }
  324. func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
  325. var aliResp AliVideoResponse
  326. if err := common.Unmarshal(task.Data, &aliResp); err != nil {
  327. return nil, errors.Wrap(err, "unmarshal ali response failed")
  328. }
  329. openAIResp := dto.NewOpenAIVideo()
  330. openAIResp.ID = task.TaskID
  331. openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
  332. openAIResp.Model = task.Properties.OriginModelName
  333. openAIResp.SetProgressStr(task.Progress)
  334. openAIResp.CreatedAt = task.CreatedAt
  335. openAIResp.CompletedAt = task.UpdatedAt
  336. // 设置视频URL(核心字段)
  337. openAIResp.SetMetadata("url", aliResp.Output.VideoURL)
  338. // 错误处理
  339. if aliResp.Code != "" {
  340. openAIResp.Error = &dto.OpenAIVideoError{
  341. Code: aliResp.Code,
  342. Message: aliResp.Message,
  343. }
  344. } else if aliResp.Output.Code != "" {
  345. openAIResp.Error = &dto.OpenAIVideoError{
  346. Code: aliResp.Output.Code,
  347. Message: aliResp.Output.Message,
  348. }
  349. }
  350. return common.Marshal(openAIResp)
  351. }
  352. func convertAliStatus(aliStatus string) string {
  353. switch aliStatus {
  354. case "PENDING":
  355. return dto.VideoStatusQueued
  356. case "RUNNING":
  357. return dto.VideoStatusInProgress
  358. case "SUCCEEDED":
  359. return dto.VideoStatusCompleted
  360. case "FAILED", "CANCELED", "UNKNOWN":
  361. return dto.VideoStatusFailed
  362. default:
  363. return dto.VideoStatusUnknown
  364. }
  365. }