relay_task.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. package relay
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel"
  15. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  18. "github.com/QuantumNous/new-api/relay/helper"
  19. "github.com/QuantumNous/new-api/service"
  20. "github.com/gin-gonic/gin"
  21. )
  22. type TaskSubmitResult struct {
  23. UpstreamTaskID string
  24. TaskData []byte
  25. Platform constant.TaskPlatform
  26. ModelName string
  27. Quota int
  28. //PerCallPrice types.PriceData
  29. }
  30. // ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
  31. // 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
  32. // (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
  33. // 以及提取 OtherRatios(时长、分辨率)。
  34. // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
  35. func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
  36. // 检测 remix action
  37. path := c.Request.URL.Path
  38. if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
  39. info.Action = constant.TaskActionRemix
  40. }
  41. if info.Action == constant.TaskActionRemix {
  42. videoID := c.Param("video_id")
  43. if strings.TrimSpace(videoID) == "" {
  44. return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
  45. }
  46. info.OriginTaskID = videoID
  47. }
  48. if info.OriginTaskID == "" {
  49. return nil
  50. }
  51. // 查找原始任务
  52. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  53. if err != nil {
  54. return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  55. }
  56. if !exist {
  57. return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  58. }
  59. // 从原始任务推导模型名称
  60. if info.OriginModelName == "" {
  61. if originTask.Properties.OriginModelName != "" {
  62. info.OriginModelName = originTask.Properties.OriginModelName
  63. } else if originTask.Properties.UpstreamModelName != "" {
  64. info.OriginModelName = originTask.Properties.UpstreamModelName
  65. } else {
  66. var taskData map[string]interface{}
  67. _ = common.Unmarshal(originTask.Data, &taskData)
  68. if m, ok := taskData["model"].(string); ok && m != "" {
  69. info.OriginModelName = m
  70. }
  71. }
  72. }
  73. // 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
  74. ch, err := model.GetChannelById(originTask.ChannelId, true)
  75. if err != nil {
  76. return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  77. }
  78. if ch.Status != common.ChannelStatusEnabled {
  79. return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
  80. }
  81. info.LockedChannel = ch
  82. if originTask.ChannelId != info.ChannelId {
  83. key, _, newAPIError := ch.GetNextEnabledKey()
  84. if newAPIError != nil {
  85. return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
  86. }
  87. common.SetContextKey(c, constant.ContextKeyChannelKey, key)
  88. common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
  89. common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
  90. common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
  91. info.ChannelBaseUrl = ch.GetBaseURL()
  92. info.ChannelId = originTask.ChannelId
  93. info.ChannelType = ch.Type
  94. info.ApiKey = key
  95. }
  96. // 提取 remix 参数(时长、分辨率 → OtherRatios)
  97. if info.Action == constant.TaskActionRemix {
  98. if originTask.PrivateData.BillingContext != nil {
  99. // 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
  100. for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
  101. info.PriceData.AddOtherRatio(s, f)
  102. }
  103. } else {
  104. // 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
  105. var taskData map[string]interface{}
  106. _ = common.Unmarshal(originTask.Data, &taskData)
  107. secondsStr, _ := taskData["seconds"].(string)
  108. seconds, _ := strconv.Atoi(secondsStr)
  109. if seconds <= 0 {
  110. seconds = 4
  111. }
  112. sizeStr, _ := taskData["size"].(string)
  113. if info.PriceData.OtherRatios == nil {
  114. info.PriceData.OtherRatios = map[string]float64{}
  115. }
  116. info.PriceData.OtherRatios["seconds"] = float64(seconds)
  117. info.PriceData.OtherRatios["size"] = 1
  118. if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
  119. info.PriceData.OtherRatios["size"] = 1.666667
  120. }
  121. }
  122. }
  123. return nil
  124. }
  125. // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
  126. // 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
  127. // 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
  128. // 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
  129. // 控制器负责 defer Refund 和成功后 Settle。
  130. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
  131. info.InitChannelMeta(c)
  132. // 1. 确定 platform → 创建适配器 → 验证请求
  133. platform := constant.TaskPlatform(c.GetString("platform"))
  134. if platform == "" {
  135. platform = GetTaskPlatform(c)
  136. }
  137. adaptor := GetTaskAdaptor(platform)
  138. if adaptor == nil {
  139. return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  140. }
  141. adaptor.Init(info)
  142. if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
  143. return nil, taskErr
  144. }
  145. // 2. 确定模型名称
  146. modelName := info.OriginModelName
  147. if modelName == "" {
  148. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  149. }
  150. // 3. 预生成公开 task ID(仅首次)
  151. if info.PublicTaskID == "" {
  152. info.PublicTaskID = model.GenerateTaskID()
  153. }
  154. // 4. 价格计算:基础模型价格
  155. info.OriginModelName = modelName
  156. info.PriceData = helper.ModelPriceHelperPerCall(c, info)
  157. // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
  158. // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
  159. // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
  160. if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
  161. for k, v := range estimatedRatios {
  162. info.PriceData.AddOtherRatio(k, v)
  163. }
  164. }
  165. // 6. 将 OtherRatios 应用到基础额度
  166. if !common.StringsContains(constant.TaskPricePatches, modelName) {
  167. for _, ra := range info.PriceData.OtherRatios {
  168. if ra != 1.0 {
  169. info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
  170. }
  171. }
  172. }
  173. // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
  174. if info.Billing == nil && !info.PriceData.FreeModel {
  175. info.ForcePreConsume = true
  176. if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
  177. return nil, service.TaskErrorFromAPIError(apiErr)
  178. }
  179. }
  180. // 8. 构建请求体
  181. requestBody, err := adaptor.BuildRequestBody(c, info)
  182. if err != nil {
  183. return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  184. }
  185. // 9. 发送请求
  186. resp, err := adaptor.DoRequest(c, info, requestBody)
  187. if err != nil {
  188. return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  189. }
  190. if resp != nil && resp.StatusCode != http.StatusOK {
  191. responseBody, _ := io.ReadAll(resp.Body)
  192. return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  193. }
  194. // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
  195. otherRatios := info.PriceData.OtherRatios
  196. if otherRatios == nil {
  197. otherRatios = map[string]float64{}
  198. }
  199. ratiosJSON, _ := common.Marshal(otherRatios)
  200. c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
  201. // 11. 解析响应
  202. upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  203. if taskErr != nil {
  204. return nil, taskErr
  205. }
  206. // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
  207. finalQuota := info.PriceData.Quota
  208. if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
  209. // 基于调整后的 ratios 重新计算 quota
  210. finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
  211. info.PriceData.OtherRatios = adjustedRatios
  212. info.PriceData.Quota = finalQuota
  213. }
  214. return &TaskSubmitResult{
  215. UpstreamTaskID: upstreamTaskID,
  216. TaskData: taskData,
  217. Platform: platform,
  218. ModelName: modelName,
  219. Quota: finalQuota,
  220. }, nil
  221. }
  222. // recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
  223. // 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
  224. func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
  225. // 从 PriceData 获取不含 OtherRatios 的基础价格
  226. baseQuota := info.PriceData.Quota
  227. // 先除掉原有的 OtherRatios 恢复基础额度
  228. for _, ra := range info.PriceData.OtherRatios {
  229. if ra != 1.0 && ra > 0 {
  230. baseQuota = int(float64(baseQuota) / ra)
  231. }
  232. }
  233. // 应用新的 ratios
  234. result := float64(baseQuota)
  235. for _, ra := range ratios {
  236. if ra != 1.0 {
  237. result *= ra
  238. }
  239. }
  240. return int(result)
  241. }
  242. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  243. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  244. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  245. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  246. }
  247. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  248. respBuilder, ok := fetchRespBuilders[relayMode]
  249. if !ok {
  250. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  251. }
  252. respBody, taskErr := respBuilder(c)
  253. if taskErr != nil {
  254. return taskErr
  255. }
  256. if len(respBody) == 0 {
  257. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  258. }
  259. c.Writer.Header().Set("Content-Type", "application/json")
  260. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  261. if err != nil {
  262. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  263. return
  264. }
  265. return
  266. }
  267. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  268. userId := c.GetInt("id")
  269. var condition = struct {
  270. IDs []any `json:"ids"`
  271. Action string `json:"action"`
  272. }{}
  273. err := c.BindJSON(&condition)
  274. if err != nil {
  275. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  276. return
  277. }
  278. var tasks []any
  279. if len(condition.IDs) > 0 {
  280. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  281. if err != nil {
  282. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  283. return
  284. }
  285. for _, task := range taskModels {
  286. tasks = append(tasks, TaskModel2Dto(task))
  287. }
  288. } else {
  289. tasks = make([]any, 0)
  290. }
  291. respBody, err = common.Marshal(dto.TaskResponse[[]any]{
  292. Code: "success",
  293. Data: tasks,
  294. })
  295. return
  296. }
  297. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  298. taskId := c.Param("id")
  299. userId := c.GetInt("id")
  300. originTask, exist, err := model.GetByTaskId(userId, taskId)
  301. if err != nil {
  302. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  303. return
  304. }
  305. if !exist {
  306. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  307. return
  308. }
  309. respBody, err = common.Marshal(dto.TaskResponse[any]{
  310. Code: "success",
  311. Data: TaskModel2Dto(originTask),
  312. })
  313. return
  314. }
  315. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  316. taskId := c.Param("task_id")
  317. if taskId == "" {
  318. taskId = c.GetString("task_id")
  319. }
  320. userId := c.GetInt("id")
  321. originTask, exist, err := model.GetByTaskId(userId, taskId)
  322. if err != nil {
  323. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  324. return
  325. }
  326. if !exist {
  327. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  328. return
  329. }
  330. isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
  331. // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
  332. if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
  333. respBody = realtimeResp
  334. return
  335. }
  336. // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
  337. if isOpenAIVideoAPI {
  338. adaptor := GetTaskAdaptor(originTask.Platform)
  339. if adaptor == nil {
  340. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
  341. return
  342. }
  343. if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
  344. openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
  345. if err != nil {
  346. taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
  347. return
  348. }
  349. respBody = openAIVideoData
  350. return
  351. }
  352. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
  353. return
  354. }
  355. // 通用 TaskDto 格式
  356. respBody, err = common.Marshal(dto.TaskResponse[any]{
  357. Code: "success",
  358. Data: TaskModel2Dto(originTask),
  359. })
  360. if err != nil {
  361. taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
  362. }
  363. return
  364. }
  365. // tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
  366. // 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
  367. // 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
  368. func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
  369. channelModel, err := model.GetChannelById(task.ChannelId, true)
  370. if err != nil {
  371. return nil
  372. }
  373. if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
  374. return nil
  375. }
  376. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  377. if channelModel.GetBaseURL() != "" {
  378. baseURL = channelModel.GetBaseURL()
  379. }
  380. proxy := channelModel.GetSetting().Proxy
  381. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  382. if adaptor == nil {
  383. return nil
  384. }
  385. resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  386. "task_id": task.GetUpstreamTaskID(),
  387. "action": task.Action,
  388. }, proxy)
  389. if err != nil || resp == nil {
  390. return nil
  391. }
  392. defer resp.Body.Close()
  393. body, err := io.ReadAll(resp.Body)
  394. if err != nil {
  395. return nil
  396. }
  397. ti, err := adaptor.ParseTaskResult(body)
  398. if err != nil || ti == nil {
  399. return nil
  400. }
  401. // 将上游最新状态更新到 task
  402. if ti.Status != "" {
  403. task.Status = model.TaskStatus(ti.Status)
  404. }
  405. if ti.Progress != "" {
  406. task.Progress = ti.Progress
  407. }
  408. if strings.HasPrefix(ti.Url, "data:") {
  409. // data: URI — kept in Data, not ResultURL
  410. } else if ti.Url != "" {
  411. task.PrivateData.ResultURL = ti.Url
  412. } else if task.Status == model.TaskStatusSuccess {
  413. // No URL from adaptor — construct proxy URL using public task ID
  414. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  415. }
  416. _ = task.Update()
  417. // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
  418. if isOpenAIVideoAPI {
  419. return nil
  420. }
  421. // 非 OpenAI Video API: 构建自定义格式响应
  422. format := detectVideoFormat(body)
  423. out := map[string]any{
  424. "error": nil,
  425. "format": format,
  426. "metadata": nil,
  427. "status": mapTaskStatusToSimple(task.Status),
  428. "task_id": task.TaskID,
  429. "url": task.GetResultURL(),
  430. }
  431. respBody, _ := common.Marshal(dto.TaskResponse[any]{
  432. Code: "success",
  433. Data: out,
  434. })
  435. return respBody
  436. }
  437. // detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
  438. func detectVideoFormat(rawBody []byte) string {
  439. var raw map[string]any
  440. if err := common.Unmarshal(rawBody, &raw); err != nil {
  441. return "mp4"
  442. }
  443. respObj, ok := raw["response"].(map[string]any)
  444. if !ok {
  445. return "mp4"
  446. }
  447. vids, ok := respObj["videos"].([]any)
  448. if !ok || len(vids) == 0 {
  449. return "mp4"
  450. }
  451. v0, ok := vids[0].(map[string]any)
  452. if !ok {
  453. return "mp4"
  454. }
  455. mt, ok := v0["mimeType"].(string)
  456. if !ok || mt == "" || strings.Contains(mt, "mp4") {
  457. return "mp4"
  458. }
  459. return mt
  460. }
  461. // mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
  462. func mapTaskStatusToSimple(status model.TaskStatus) string {
  463. switch status {
  464. case model.TaskStatusSuccess:
  465. return "succeeded"
  466. case model.TaskStatusFailure:
  467. return "failed"
  468. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  469. return "queued"
  470. default:
  471. return "processing"
  472. }
  473. }
  474. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  475. return &dto.TaskDto{
  476. ID: task.ID,
  477. CreatedAt: task.CreatedAt,
  478. UpdatedAt: task.UpdatedAt,
  479. TaskID: task.TaskID,
  480. Platform: string(task.Platform),
  481. UserId: task.UserId,
  482. Group: task.Group,
  483. ChannelId: task.ChannelId,
  484. Quota: task.Quota,
  485. Action: task.Action,
  486. Status: string(task.Status),
  487. FailReason: task.FailReason,
  488. ResultURL: task.GetResultURL(),
  489. SubmitTime: task.SubmitTime,
  490. StartTime: task.StartTime,
  491. FinishTime: task.FinishTime,
  492. Progress: task.Progress,
  493. Properties: task.Properties,
  494. Username: task.Username,
  495. Data: task.Data,
  496. }
  497. }