relay_task.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. package relay
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel"
  15. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  18. "github.com/QuantumNous/new-api/relay/helper"
  19. "github.com/QuantumNous/new-api/service"
  20. "github.com/gin-gonic/gin"
  21. )
  22. type TaskSubmitResult struct {
  23. UpstreamTaskID string
  24. TaskData []byte
  25. Platform constant.TaskPlatform
  26. ModelName string
  27. Quota int
  28. //PerCallPrice types.PriceData
  29. }
  30. // ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
  31. // 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道(并通过
  32. // specific_channel_id 禁止重试),以及提取 OtherRatios(时长、分辨率)。
  33. // 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
  34. func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
  35. // 检测 remix action
  36. path := c.Request.URL.Path
  37. if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
  38. info.Action = constant.TaskActionRemix
  39. }
  40. if info.Action == constant.TaskActionRemix {
  41. videoID := c.Param("video_id")
  42. if strings.TrimSpace(videoID) == "" {
  43. return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
  44. }
  45. info.OriginTaskID = videoID
  46. }
  47. if info.OriginTaskID == "" {
  48. return nil
  49. }
  50. // 查找原始任务
  51. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  52. if err != nil {
  53. return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  54. }
  55. if !exist {
  56. return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  57. }
  58. // 从原始任务推导模型名称
  59. if info.OriginModelName == "" {
  60. if originTask.Properties.OriginModelName != "" {
  61. info.OriginModelName = originTask.Properties.OriginModelName
  62. } else if originTask.Properties.UpstreamModelName != "" {
  63. info.OriginModelName = originTask.Properties.UpstreamModelName
  64. } else {
  65. var taskData map[string]interface{}
  66. _ = common.Unmarshal(originTask.Data, &taskData)
  67. if m, ok := taskData["model"].(string); ok && m != "" {
  68. info.OriginModelName = m
  69. }
  70. }
  71. }
  72. // 锁定到原始任务的渠道(如果与当前选中的不同)
  73. if originTask.ChannelId != info.ChannelId {
  74. ch, err := model.GetChannelById(originTask.ChannelId, true)
  75. if err != nil {
  76. return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  77. }
  78. if ch.Status != common.ChannelStatusEnabled {
  79. return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
  80. }
  81. key, _, newAPIError := ch.GetNextEnabledKey()
  82. if newAPIError != nil {
  83. return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
  84. }
  85. common.SetContextKey(c, constant.ContextKeyChannelKey, key)
  86. common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
  87. common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
  88. common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
  89. info.ChannelBaseUrl = ch.GetBaseURL()
  90. info.ChannelId = originTask.ChannelId
  91. info.ChannelType = ch.Type
  92. info.ApiKey = key
  93. }
  94. // 渠道已锁定到原始任务 → 禁止重试切换到其他渠道
  95. c.Set("specific_channel_id", fmt.Sprintf("%d", originTask.ChannelId))
  96. // 提取 remix 参数(时长、分辨率 → OtherRatios)
  97. if info.Action == constant.TaskActionRemix {
  98. var taskData map[string]interface{}
  99. _ = common.Unmarshal(originTask.Data, &taskData)
  100. secondsStr, _ := taskData["seconds"].(string)
  101. seconds, _ := strconv.Atoi(secondsStr)
  102. if seconds <= 0 {
  103. seconds = 4
  104. }
  105. sizeStr, _ := taskData["size"].(string)
  106. if info.PriceData.OtherRatios == nil {
  107. info.PriceData.OtherRatios = map[string]float64{}
  108. }
  109. info.PriceData.OtherRatios["seconds"] = float64(seconds)
  110. info.PriceData.OtherRatios["size"] = 1
  111. if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
  112. info.PriceData.OtherRatios["size"] = 1.666667
  113. }
  114. }
  115. return nil
  116. }
  117. // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
  118. // 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
  119. // 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
  120. // 控制器负责 defer Refund 和成功后 Settle。
  121. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
  122. info.InitChannelMeta(c)
  123. // 1. 确定 platform → 创建适配器 → 验证请求
  124. platform := constant.TaskPlatform(c.GetString("platform"))
  125. if platform == "" {
  126. platform = GetTaskPlatform(c)
  127. }
  128. adaptor := GetTaskAdaptor(platform)
  129. if adaptor == nil {
  130. return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  131. }
  132. adaptor.Init(info)
  133. if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
  134. return nil, taskErr
  135. }
  136. // 2. 确定模型名称
  137. modelName := info.OriginModelName
  138. if modelName == "" {
  139. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  140. }
  141. // 3. 预生成公开 task ID(仅首次)
  142. if info.PublicTaskID == "" {
  143. info.PublicTaskID = model.GenerateTaskID()
  144. }
  145. // 4. 价格计算
  146. info.OriginModelName = modelName
  147. info.PriceData = helper.ModelPriceHelperPerCall(c, info)
  148. if !common.StringsContains(constant.TaskPricePatches, modelName) {
  149. for _, ra := range info.PriceData.OtherRatios {
  150. if ra != 1.0 {
  151. info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
  152. }
  153. }
  154. }
  155. // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
  156. if info.Billing == nil && !info.PriceData.FreeModel {
  157. info.ForcePreConsume = true
  158. if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
  159. return nil, service.TaskErrorFromAPIError(apiErr)
  160. }
  161. }
  162. // 6. 构建请求体
  163. requestBody, err := adaptor.BuildRequestBody(c, info)
  164. if err != nil {
  165. return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  166. }
  167. // 7. 发送请求
  168. resp, err := adaptor.DoRequest(c, info, requestBody)
  169. if err != nil {
  170. return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  171. }
  172. if resp != nil && resp.StatusCode != http.StatusOK {
  173. responseBody, _ := io.ReadAll(resp.Body)
  174. return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  175. }
  176. // 8. 解析响应
  177. upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  178. if taskErr != nil {
  179. return nil, taskErr
  180. }
  181. return &TaskSubmitResult{
  182. UpstreamTaskID: upstreamTaskID,
  183. TaskData: taskData,
  184. Platform: platform,
  185. ModelName: modelName,
  186. }, nil
  187. }
  188. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  189. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  190. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  191. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  192. }
  193. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  194. respBuilder, ok := fetchRespBuilders[relayMode]
  195. if !ok {
  196. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  197. }
  198. respBody, taskErr := respBuilder(c)
  199. if taskErr != nil {
  200. return taskErr
  201. }
  202. if len(respBody) == 0 {
  203. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  204. }
  205. c.Writer.Header().Set("Content-Type", "application/json")
  206. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  207. if err != nil {
  208. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  209. return
  210. }
  211. return
  212. }
  213. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  214. userId := c.GetInt("id")
  215. var condition = struct {
  216. IDs []any `json:"ids"`
  217. Action string `json:"action"`
  218. }{}
  219. err := c.BindJSON(&condition)
  220. if err != nil {
  221. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  222. return
  223. }
  224. var tasks []any
  225. if len(condition.IDs) > 0 {
  226. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  227. if err != nil {
  228. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  229. return
  230. }
  231. for _, task := range taskModels {
  232. tasks = append(tasks, TaskModel2Dto(task))
  233. }
  234. } else {
  235. tasks = make([]any, 0)
  236. }
  237. respBody, err = common.Marshal(dto.TaskResponse[[]any]{
  238. Code: "success",
  239. Data: tasks,
  240. })
  241. return
  242. }
  243. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  244. taskId := c.Param("id")
  245. userId := c.GetInt("id")
  246. originTask, exist, err := model.GetByTaskId(userId, taskId)
  247. if err != nil {
  248. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  249. return
  250. }
  251. if !exist {
  252. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  253. return
  254. }
  255. respBody, err = common.Marshal(dto.TaskResponse[any]{
  256. Code: "success",
  257. Data: TaskModel2Dto(originTask),
  258. })
  259. return
  260. }
  261. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  262. taskId := c.Param("task_id")
  263. if taskId == "" {
  264. taskId = c.GetString("task_id")
  265. }
  266. userId := c.GetInt("id")
  267. originTask, exist, err := model.GetByTaskId(userId, taskId)
  268. if err != nil {
  269. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  270. return
  271. }
  272. if !exist {
  273. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  274. return
  275. }
  276. isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
  277. // Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
  278. if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
  279. respBody = realtimeResp
  280. return
  281. }
  282. // OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
  283. if isOpenAIVideoAPI {
  284. adaptor := GetTaskAdaptor(originTask.Platform)
  285. if adaptor == nil {
  286. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
  287. return
  288. }
  289. if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
  290. openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
  291. if err != nil {
  292. taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
  293. return
  294. }
  295. respBody = openAIVideoData
  296. return
  297. }
  298. taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
  299. return
  300. }
  301. // 通用 TaskDto 格式
  302. respBody, err = common.Marshal(dto.TaskResponse[any]{
  303. Code: "success",
  304. Data: TaskModel2Dto(originTask),
  305. })
  306. if err != nil {
  307. taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
  308. }
  309. return
  310. }
  311. // tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
  312. // 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
  313. // 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
  314. func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
  315. channelModel, err := model.GetChannelById(task.ChannelId, true)
  316. if err != nil {
  317. return nil
  318. }
  319. if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
  320. return nil
  321. }
  322. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  323. if channelModel.GetBaseURL() != "" {
  324. baseURL = channelModel.GetBaseURL()
  325. }
  326. proxy := channelModel.GetSetting().Proxy
  327. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  328. if adaptor == nil {
  329. return nil
  330. }
  331. resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  332. "task_id": task.GetUpstreamTaskID(),
  333. "action": task.Action,
  334. }, proxy)
  335. if err != nil || resp == nil {
  336. return nil
  337. }
  338. defer resp.Body.Close()
  339. body, err := io.ReadAll(resp.Body)
  340. if err != nil {
  341. return nil
  342. }
  343. ti, err := adaptor.ParseTaskResult(body)
  344. if err != nil || ti == nil {
  345. return nil
  346. }
  347. // 将上游最新状态更新到 task
  348. if ti.Status != "" {
  349. task.Status = model.TaskStatus(ti.Status)
  350. }
  351. if ti.Progress != "" {
  352. task.Progress = ti.Progress
  353. }
  354. if strings.HasPrefix(ti.Url, "data:") {
  355. // data: URI — kept in Data, not ResultURL
  356. } else if ti.Url != "" {
  357. task.PrivateData.ResultURL = ti.Url
  358. } else if task.Status == model.TaskStatusSuccess {
  359. // No URL from adaptor — construct proxy URL using public task ID
  360. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  361. }
  362. _ = task.Update()
  363. // OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
  364. if isOpenAIVideoAPI {
  365. return nil
  366. }
  367. // 非 OpenAI Video API: 构建自定义格式响应
  368. format := detectVideoFormat(body)
  369. out := map[string]any{
  370. "error": nil,
  371. "format": format,
  372. "metadata": nil,
  373. "status": mapTaskStatusToSimple(task.Status),
  374. "task_id": task.TaskID,
  375. "url": task.GetResultURL(),
  376. }
  377. respBody, _ := common.Marshal(dto.TaskResponse[any]{
  378. Code: "success",
  379. Data: out,
  380. })
  381. return respBody
  382. }
  383. // detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
  384. func detectVideoFormat(rawBody []byte) string {
  385. var raw map[string]any
  386. if err := common.Unmarshal(rawBody, &raw); err != nil {
  387. return "mp4"
  388. }
  389. respObj, ok := raw["response"].(map[string]any)
  390. if !ok {
  391. return "mp4"
  392. }
  393. vids, ok := respObj["videos"].([]any)
  394. if !ok || len(vids) == 0 {
  395. return "mp4"
  396. }
  397. v0, ok := vids[0].(map[string]any)
  398. if !ok {
  399. return "mp4"
  400. }
  401. mt, ok := v0["mimeType"].(string)
  402. if !ok || mt == "" || strings.Contains(mt, "mp4") {
  403. return "mp4"
  404. }
  405. return mt
  406. }
  407. // mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
  408. func mapTaskStatusToSimple(status model.TaskStatus) string {
  409. switch status {
  410. case model.TaskStatusSuccess:
  411. return "succeeded"
  412. case model.TaskStatusFailure:
  413. return "failed"
  414. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  415. return "queued"
  416. default:
  417. return "processing"
  418. }
  419. }
  420. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  421. return &dto.TaskDto{
  422. ID: task.ID,
  423. CreatedAt: task.CreatedAt,
  424. UpdatedAt: task.UpdatedAt,
  425. TaskID: task.TaskID,
  426. Platform: string(task.Platform),
  427. UserId: task.UserId,
  428. Group: task.Group,
  429. ChannelId: task.ChannelId,
  430. Quota: task.Quota,
  431. Action: task.Action,
  432. Status: string(task.Status),
  433. FailReason: task.FailReason,
  434. ResultURL: task.GetResultURL(),
  435. SubmitTime: task.SubmitTime,
  436. StartTime: task.StartTime,
  437. FinishTime: task.FinishTime,
  438. Progress: task.Progress,
  439. Properties: task.Properties,
  440. Username: task.Username,
  441. Data: task.Data,
  442. }
  443. }