adaptor.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. package ali
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/QuantumNous/new-api/relay/channel"
  14. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  15. relaycommon "github.com/QuantumNous/new-api/relay/common"
  16. "github.com/QuantumNous/new-api/service"
  17. "github.com/samber/lo"
  18. "github.com/gin-gonic/gin"
  19. "github.com/pkg/errors"
  20. )
  21. // ============================
  22. // Request / Response structures
  23. // ============================
  24. // AliVideoRequest 阿里通义万相视频生成请求
  25. type AliVideoRequest struct {
  26. Model string `json:"model"`
  27. Input AliVideoInput `json:"input"`
  28. Parameters *AliVideoParameters `json:"parameters,omitempty"`
  29. }
  30. // AliVideoInput 视频输入参数
  31. type AliVideoInput struct {
  32. Prompt string `json:"prompt,omitempty"` // 文本提示词
  33. ImgURL string `json:"img_url,omitempty"` // 首帧图像URL或Base64(图生视频)
  34. FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频)
  35. LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频)
  36. AudioURL string `json:"audio_url,omitempty"` // 音频URL(wan2.5支持)
  37. NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
  38. Template string `json:"template,omitempty"` // 视频特效模板
  39. }
  40. // AliVideoParameters 视频参数
  41. type AliVideoParameters struct {
  42. Resolution string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P(图生视频、首尾帧生视频)
  43. Size string `json:"size,omitempty"` // 尺寸: 如 "832*480"(文生视频)
  44. Duration int `json:"duration,omitempty"` // 时长: 3-10秒
  45. PromptExtend bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
  46. Watermark bool `json:"watermark,omitempty"` // 是否添加水印
  47. Audio *bool `json:"audio,omitempty"` // 是否添加音频(wan2.5)
  48. Seed int `json:"seed,omitempty"` // 随机数种子
  49. }
  50. // AliVideoResponse 阿里通义万相响应
  51. type AliVideoResponse struct {
  52. Output AliVideoOutput `json:"output"`
  53. RequestID string `json:"request_id"`
  54. Code string `json:"code,omitempty"`
  55. Message string `json:"message,omitempty"`
  56. Usage *AliUsage `json:"usage,omitempty"`
  57. }
  58. // AliVideoOutput 输出信息
  59. type AliVideoOutput struct {
  60. TaskID string `json:"task_id"`
  61. TaskStatus string `json:"task_status"`
  62. SubmitTime string `json:"submit_time,omitempty"`
  63. ScheduledTime string `json:"scheduled_time,omitempty"`
  64. EndTime string `json:"end_time,omitempty"`
  65. OrigPrompt string `json:"orig_prompt,omitempty"`
  66. ActualPrompt string `json:"actual_prompt,omitempty"`
  67. VideoURL string `json:"video_url,omitempty"`
  68. Code string `json:"code,omitempty"`
  69. Message string `json:"message,omitempty"`
  70. }
  71. // AliUsage 使用统计
  72. type AliUsage struct {
  73. Duration int `json:"duration,omitempty"`
  74. VideoCount int `json:"video_count,omitempty"`
  75. SR int `json:"SR,omitempty"`
  76. }
  77. type AliMetadata struct {
  78. // Input 相关
  79. AudioURL string `json:"audio_url,omitempty"` // 音频URL
  80. ImgURL string `json:"img_url,omitempty"` // 图片URL(图生视频)
  81. FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频)
  82. LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频)
  83. NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
  84. Template string `json:"template,omitempty"` // 视频特效模板
  85. // Parameters 相关
  86. Resolution *string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P
  87. Size *string `json:"size,omitempty"` // 尺寸: 如 "832*480"
  88. Duration *int `json:"duration,omitempty"` // 时长
  89. PromptExtend *bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
  90. Watermark *bool `json:"watermark,omitempty"` // 是否添加水印
  91. Audio *bool `json:"audio,omitempty"` // 是否添加音频
  92. Seed *int `json:"seed,omitempty"` // 随机数种子
  93. }
  94. // ============================
  95. // Adaptor implementation
  96. // ============================
  97. type TaskAdaptor struct {
  98. taskcommon.BaseBilling
  99. ChannelType int
  100. apiKey string
  101. baseURL string
  102. }
  103. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  104. a.ChannelType = info.ChannelType
  105. a.baseURL = info.ChannelBaseUrl
  106. a.apiKey = info.ApiKey
  107. }
  108. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  109. // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
  110. return relaycommon.ValidateMultipartDirect(c, info)
  111. }
  112. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  113. return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil
  114. }
  115. // BuildRequestHeader sets required headers for Ali API
  116. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  117. req.Header.Set("Authorization", "Bearer "+a.apiKey)
  118. req.Header.Set("Content-Type", "application/json")
  119. req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置
  120. return nil
  121. }
  122. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  123. taskReq, err := relaycommon.GetTaskRequest(c)
  124. if err != nil {
  125. return nil, errors.Wrap(err, "get_task_request_failed")
  126. }
  127. aliReq, err := a.convertToAliRequest(info, taskReq)
  128. if err != nil {
  129. return nil, errors.Wrap(err, "convert_to_ali_request_failed")
  130. }
  131. logger.LogJson(c, "ali video request body", aliReq)
  132. bodyBytes, err := common.Marshal(aliReq)
  133. if err != nil {
  134. return nil, errors.Wrap(err, "marshal_ali_request_failed")
  135. }
  136. return bytes.NewReader(bodyBytes), nil
  137. }
  138. var (
  139. size480p = []string{
  140. "832*480",
  141. "480*832",
  142. "624*624",
  143. }
  144. size720p = []string{
  145. "1280*720",
  146. "720*1280",
  147. "960*960",
  148. "1088*832",
  149. "832*1088",
  150. }
  151. size1080p = []string{
  152. "1920*1080",
  153. "1080*1920",
  154. "1440*1440",
  155. "1632*1248",
  156. "1248*1632",
  157. }
  158. )
  159. func sizeToResolution(size string) (string, error) {
  160. if lo.Contains(size480p, size) {
  161. return "480P", nil
  162. } else if lo.Contains(size720p, size) {
  163. return "720P", nil
  164. } else if lo.Contains(size1080p, size) {
  165. return "1080P", nil
  166. }
  167. return "", fmt.Errorf("invalid size: %s", size)
  168. }
  169. func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
  170. otherRatios := make(map[string]float64)
  171. aliRatios := map[string]map[string]float64{
  172. "wan2.6-i2v": {
  173. "720P": 1,
  174. "1080P": 1 / 0.6,
  175. },
  176. "wan2.5-t2v-preview": {
  177. "480P": 1,
  178. "720P": 2,
  179. "1080P": 1 / 0.3,
  180. },
  181. "wan2.2-t2v-plus": {
  182. "480P": 1,
  183. "1080P": 0.7 / 0.14,
  184. },
  185. "wan2.5-i2v-preview": {
  186. "480P": 1,
  187. "720P": 2,
  188. "1080P": 1 / 0.3,
  189. },
  190. "wan2.2-i2v-plus": {
  191. "480P": 1,
  192. "1080P": 0.7 / 0.14,
  193. },
  194. "wan2.2-kf2v-flash": {
  195. "480P": 1,
  196. "720P": 2,
  197. "1080P": 4.8,
  198. },
  199. "wan2.2-i2v-flash": {
  200. "480P": 1,
  201. "720P": 2,
  202. },
  203. "wan2.2-s2v": {
  204. "480P": 1,
  205. "720P": 0.9 / 0.5,
  206. },
  207. }
  208. var resolution string
  209. // size match
  210. if aliReq.Parameters.Size != "" {
  211. toResolution, err := sizeToResolution(aliReq.Parameters.Size)
  212. if err != nil {
  213. return nil, err
  214. }
  215. resolution = toResolution
  216. } else {
  217. resolution = strings.ToUpper(aliReq.Parameters.Resolution)
  218. if !strings.HasSuffix(resolution, "P") {
  219. resolution = resolution + "P"
  220. }
  221. }
  222. if otherRatio, ok := aliRatios[aliReq.Model]; ok {
  223. if ratio, ok := otherRatio[resolution]; ok {
  224. otherRatios[fmt.Sprintf("resolution-%s", resolution)] = ratio
  225. }
  226. }
  227. return otherRatios, nil
  228. }
  229. func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
  230. upstreamModel := req.Model
  231. if info.IsModelMapped {
  232. upstreamModel = info.UpstreamModelName
  233. }
  234. aliReq := &AliVideoRequest{
  235. Model: upstreamModel,
  236. Input: AliVideoInput{
  237. Prompt: req.Prompt,
  238. ImgURL: req.InputReference,
  239. },
  240. Parameters: &AliVideoParameters{
  241. PromptExtend: true, // 默认开启智能改写
  242. Watermark: false,
  243. },
  244. }
  245. // 处理分辨率映射
  246. if req.Size != "" {
  247. // text to video size must be contained *
  248. if strings.Contains(req.Model, "t2v") && !strings.Contains(req.Size, "*") {
  249. return nil, fmt.Errorf("invalid size: %s, example: %s", req.Size, "1920*1080")
  250. }
  251. if strings.Contains(req.Size, "*") {
  252. aliReq.Parameters.Size = req.Size
  253. } else {
  254. resolution := strings.ToUpper(req.Size)
  255. // 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
  256. if !strings.HasSuffix(resolution, "P") {
  257. resolution = resolution + "P"
  258. }
  259. aliReq.Parameters.Resolution = resolution
  260. }
  261. } else {
  262. // 根据模型设置默认分辨率
  263. if strings.Contains(req.Model, "t2v") { // image to video
  264. if strings.HasPrefix(req.Model, "wan2.5") {
  265. aliReq.Parameters.Size = "1920*1080"
  266. } else if strings.HasPrefix(req.Model, "wan2.2") {
  267. aliReq.Parameters.Size = "1920*1080"
  268. } else {
  269. aliReq.Parameters.Size = "1280*720"
  270. }
  271. } else {
  272. if strings.HasPrefix(req.Model, "wan2.6") {
  273. aliReq.Parameters.Resolution = "1080P"
  274. } else if strings.HasPrefix(req.Model, "wan2.5") {
  275. aliReq.Parameters.Resolution = "1080P"
  276. } else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
  277. aliReq.Parameters.Resolution = "720P"
  278. } else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
  279. aliReq.Parameters.Resolution = "1080P"
  280. } else {
  281. aliReq.Parameters.Resolution = "720P"
  282. }
  283. }
  284. }
  285. // 处理时长
  286. if req.Duration > 0 {
  287. aliReq.Parameters.Duration = req.Duration
  288. } else if req.Seconds != "" {
  289. seconds, err := strconv.Atoi(req.Seconds)
  290. if err != nil {
  291. return nil, errors.Wrap(err, "convert seconds to int failed")
  292. } else {
  293. aliReq.Parameters.Duration = seconds
  294. }
  295. } else {
  296. aliReq.Parameters.Duration = 5 // 默认5秒
  297. }
  298. // 从 metadata 中提取额外参数
  299. if req.Metadata != nil {
  300. if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
  301. err = common.Unmarshal(metadataBytes, aliReq)
  302. if err != nil {
  303. return nil, errors.Wrap(err, "unmarshal metadata failed")
  304. }
  305. } else {
  306. return nil, errors.Wrap(err, "marshal metadata failed")
  307. }
  308. }
  309. if aliReq.Model != upstreamModel {
  310. return nil, errors.New("can't change model with metadata")
  311. }
  312. return aliReq, nil
  313. }
  314. // EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
  315. // 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
  316. func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
  317. taskReq, err := relaycommon.GetTaskRequest(c)
  318. if err != nil {
  319. return nil
  320. }
  321. aliReq, err := a.convertToAliRequest(info, taskReq)
  322. if err != nil {
  323. return nil
  324. }
  325. otherRatios := map[string]float64{
  326. "seconds": float64(aliReq.Parameters.Duration),
  327. }
  328. ratios, err := ProcessAliOtherRatios(aliReq)
  329. if err != nil {
  330. return otherRatios
  331. }
  332. for k, v := range ratios {
  333. otherRatios[k] = v
  334. }
  335. return otherRatios
  336. }
  337. // DoRequest delegates to common helper
  338. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  339. return channel.DoTaskApiRequest(a, c, info, requestBody)
  340. }
  341. // DoResponse handles upstream response
  342. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  343. responseBody, err := io.ReadAll(resp.Body)
  344. if err != nil {
  345. taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  346. return
  347. }
  348. _ = resp.Body.Close()
  349. // 解析阿里响应
  350. var aliResp AliVideoResponse
  351. if err := common.Unmarshal(responseBody, &aliResp); err != nil {
  352. taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
  353. return
  354. }
  355. // 检查错误
  356. if aliResp.Code != "" {
  357. taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode)
  358. return
  359. }
  360. if aliResp.Output.TaskID == "" {
  361. taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
  362. return
  363. }
  364. // 转换为 OpenAI 格式响应
  365. openAIResp := dto.NewOpenAIVideo()
  366. openAIResp.ID = info.PublicTaskID
  367. openAIResp.TaskID = info.PublicTaskID
  368. openAIResp.Model = c.GetString("model")
  369. if openAIResp.Model == "" && info != nil {
  370. openAIResp.Model = info.OriginModelName
  371. }
  372. openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
  373. openAIResp.CreatedAt = common.GetTimestamp()
  374. // 返回 OpenAI 格式
  375. c.JSON(http.StatusOK, openAIResp)
  376. return aliResp.Output.TaskID, responseBody, nil
  377. }
  378. // FetchTask 查询任务状态
  379. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
  380. taskID, ok := body["task_id"].(string)
  381. if !ok {
  382. return nil, fmt.Errorf("invalid task_id")
  383. }
  384. uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID)
  385. req, err := http.NewRequest(http.MethodGet, uri, nil)
  386. if err != nil {
  387. return nil, err
  388. }
  389. req.Header.Set("Authorization", "Bearer "+key)
  390. client, err := service.GetHttpClientWithProxy(proxy)
  391. if err != nil {
  392. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  393. }
  394. return client.Do(req)
  395. }
  396. func (a *TaskAdaptor) GetModelList() []string {
  397. return ModelList
  398. }
  399. func (a *TaskAdaptor) GetChannelName() string {
  400. return ChannelName
  401. }
  402. // ParseTaskResult 解析任务结果
  403. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  404. var aliResp AliVideoResponse
  405. if err := common.Unmarshal(respBody, &aliResp); err != nil {
  406. return nil, errors.Wrap(err, "unmarshal task result failed")
  407. }
  408. taskResult := relaycommon.TaskInfo{
  409. Code: 0,
  410. }
  411. // 状态映射
  412. switch aliResp.Output.TaskStatus {
  413. case "PENDING":
  414. taskResult.Status = model.TaskStatusQueued
  415. case "RUNNING":
  416. taskResult.Status = model.TaskStatusInProgress
  417. case "SUCCEEDED":
  418. taskResult.Status = model.TaskStatusSuccess
  419. // 阿里直接返回视频URL,不需要额外的代理端点
  420. taskResult.Url = aliResp.Output.VideoURL
  421. case "FAILED", "CANCELED", "UNKNOWN":
  422. taskResult.Status = model.TaskStatusFailure
  423. if aliResp.Message != "" {
  424. taskResult.Reason = aliResp.Message
  425. } else if aliResp.Output.Message != "" {
  426. taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message)
  427. } else {
  428. taskResult.Reason = "task failed"
  429. }
  430. default:
  431. taskResult.Status = model.TaskStatusQueued
  432. }
  433. return &taskResult, nil
  434. }
  435. func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
  436. var aliResp AliVideoResponse
  437. if err := common.Unmarshal(task.Data, &aliResp); err != nil {
  438. return nil, errors.Wrap(err, "unmarshal ali response failed")
  439. }
  440. openAIResp := dto.NewOpenAIVideo()
  441. openAIResp.ID = task.TaskID
  442. openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
  443. openAIResp.Model = task.Properties.OriginModelName
  444. openAIResp.SetProgressStr(task.Progress)
  445. openAIResp.CreatedAt = task.CreatedAt
  446. openAIResp.CompletedAt = task.UpdatedAt
  447. // 设置视频URL(核心字段)
  448. openAIResp.SetMetadata("url", aliResp.Output.VideoURL)
  449. // 错误处理
  450. if aliResp.Code != "" {
  451. openAIResp.Error = &dto.OpenAIVideoError{
  452. Code: aliResp.Code,
  453. Message: aliResp.Message,
  454. }
  455. } else if aliResp.Output.Code != "" {
  456. openAIResp.Error = &dto.OpenAIVideoError{
  457. Code: aliResp.Output.Code,
  458. Message: aliResp.Output.Message,
  459. }
  460. }
  461. return common.Marshal(openAIResp)
  462. }
  463. func convertAliStatus(aliStatus string) string {
  464. switch aliStatus {
  465. case "PENDING":
  466. return dto.VideoStatusQueued
  467. case "RUNNING":
  468. return dto.VideoStatusInProgress
  469. case "SUCCEEDED":
  470. return dto.VideoStatusCompleted
  471. case "FAILED", "CANCELED", "UNKNOWN":
  472. return dto.VideoStatusFailed
  473. default:
  474. return dto.VideoStatusUnknown
  475. }
  476. }