relay_task.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/model"
  13. relaycommon "one-api/relay/common"
  14. relayconstant "one-api/relay/constant"
  15. "one-api/service"
  16. "one-api/setting/ratio_setting"
  17. "strconv"
  18. "strings"
  19. "github.com/gin-gonic/gin"
  20. )
  21. /*
  22. Task 任务通过平台、Action 区分任务
  23. */
  24. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  25. info.InitChannelMeta(c)
  26. // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
  27. if info.TaskRelayInfo == nil {
  28. info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
  29. }
  30. platform := constant.TaskPlatform(c.GetString("platform"))
  31. if platform == "" {
  32. platform = GetTaskPlatform(c)
  33. }
  34. info.InitChannelMeta(c)
  35. adaptor := GetTaskAdaptor(platform)
  36. if adaptor == nil {
  37. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  38. }
  39. adaptor.Init(info)
  40. // get & validate taskRequest 获取并验证文本请求
  41. taskErr = adaptor.ValidateRequestAndSetAction(c, info)
  42. if taskErr != nil {
  43. return
  44. }
  45. modelName := info.OriginModelName
  46. if modelName == "" {
  47. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  48. }
  49. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  50. if !success {
  51. defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
  52. if !ok {
  53. modelPrice = 0.1
  54. } else {
  55. modelPrice = defaultPrice
  56. }
  57. }
  58. // 预扣
  59. groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
  60. var ratio float64
  61. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
  62. if hasUserGroupRatio {
  63. ratio = modelPrice * userGroupRatio
  64. } else {
  65. ratio = modelPrice * groupRatio
  66. }
  67. if len(info.PriceData.OtherRatios) > 0 {
  68. for _, ra := range info.PriceData.OtherRatios {
  69. if 1.0 != ra {
  70. ratio *= ra
  71. }
  72. }
  73. }
  74. userQuota, err := model.GetUserQuota(info.UserId, false)
  75. if err != nil {
  76. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  77. return
  78. }
  79. quota := int(ratio * common.QuotaPerUnit)
  80. if userQuota-quota < 0 {
  81. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  82. return
  83. }
  84. if info.OriginTaskID != "" {
  85. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  86. if err != nil {
  87. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  88. return
  89. }
  90. if !exist {
  91. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  92. return
  93. }
  94. if originTask.ChannelId != info.ChannelId {
  95. channel, err := model.GetChannelById(originTask.ChannelId, true)
  96. if err != nil {
  97. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  98. return
  99. }
  100. if channel.Status != common.ChannelStatusEnabled {
  101. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  102. }
  103. c.Set("base_url", channel.GetBaseURL())
  104. c.Set("channel_id", originTask.ChannelId)
  105. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  106. info.ChannelBaseUrl = channel.GetBaseURL()
  107. info.ChannelId = originTask.ChannelId
  108. }
  109. }
  110. // build body
  111. requestBody, err := adaptor.BuildRequestBody(c, info)
  112. if err != nil {
  113. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  114. return
  115. }
  116. // do request
  117. resp, err := adaptor.DoRequest(c, info, requestBody)
  118. if err != nil {
  119. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  120. return
  121. }
  122. // handle response
  123. if resp != nil && resp.StatusCode != http.StatusOK {
  124. responseBody, _ := io.ReadAll(resp.Body)
  125. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  126. return
  127. }
  128. defer func() {
  129. // release quota
  130. if info.ConsumeQuota && taskErr == nil {
  131. err := service.PostConsumeQuota(info, quota, 0, true)
  132. if err != nil {
  133. common.SysLog("error consuming token remain quota: " + err.Error())
  134. }
  135. if quota != 0 {
  136. tokenName := c.GetString("token_name")
  137. gRatio := groupRatio
  138. if hasUserGroupRatio {
  139. gRatio = userGroupRatio
  140. }
  141. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
  142. if len(info.PriceData.OtherRatios) > 0 {
  143. var contents []string
  144. for key, ra := range info.PriceData.OtherRatios {
  145. if 1.0 != ra {
  146. contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
  147. }
  148. }
  149. if len(contents) > 0 {
  150. logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
  151. }
  152. }
  153. other := make(map[string]interface{})
  154. other["model_price"] = modelPrice
  155. other["group_ratio"] = groupRatio
  156. if hasUserGroupRatio {
  157. other["user_group_ratio"] = userGroupRatio
  158. }
  159. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  160. ChannelId: info.ChannelId,
  161. ModelName: modelName,
  162. TokenName: tokenName,
  163. Quota: quota,
  164. Content: logContent,
  165. TokenId: info.TokenId,
  166. Group: info.UsingGroup,
  167. Other: other,
  168. })
  169. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
  170. model.UpdateChannelUsedQuota(info.ChannelId, quota)
  171. }
  172. }
  173. }()
  174. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  175. if taskErr != nil {
  176. return
  177. }
  178. info.ConsumeQuota = true
  179. // insert task
  180. task := model.InitTask(platform, info)
  181. task.TaskID = taskID
  182. task.Quota = quota
  183. task.Data = taskData
  184. task.Action = info.Action
  185. err = task.Insert()
  186. if err != nil {
  187. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  188. return
  189. }
  190. return nil
  191. }
  192. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  193. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  194. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  195. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  196. }
  197. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  198. respBuilder, ok := fetchRespBuilders[relayMode]
  199. if !ok {
  200. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  201. }
  202. respBody, taskErr := respBuilder(c)
  203. if taskErr != nil {
  204. return taskErr
  205. }
  206. if len(respBody) == 0 {
  207. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  208. }
  209. c.Writer.Header().Set("Content-Type", "application/json")
  210. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  211. if err != nil {
  212. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  213. return
  214. }
  215. return
  216. }
  217. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  218. userId := c.GetInt("id")
  219. var condition = struct {
  220. IDs []any `json:"ids"`
  221. Action string `json:"action"`
  222. }{}
  223. err := c.BindJSON(&condition)
  224. if err != nil {
  225. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  226. return
  227. }
  228. var tasks []any
  229. if len(condition.IDs) > 0 {
  230. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  231. if err != nil {
  232. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  233. return
  234. }
  235. for _, task := range taskModels {
  236. tasks = append(tasks, TaskModel2Dto(task))
  237. }
  238. } else {
  239. tasks = make([]any, 0)
  240. }
  241. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  242. Code: "success",
  243. Data: tasks,
  244. })
  245. return
  246. }
  247. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  248. taskId := c.Param("id")
  249. userId := c.GetInt("id")
  250. originTask, exist, err := model.GetByTaskId(userId, taskId)
  251. if err != nil {
  252. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  253. return
  254. }
  255. if !exist {
  256. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  257. return
  258. }
  259. respBody, err = json.Marshal(dto.TaskResponse[any]{
  260. Code: "success",
  261. Data: TaskModel2Dto(originTask),
  262. })
  263. return
  264. }
// videoFetchByIDRespBodyBuilder answers a single video-task query for the
// calling user (c.GetInt("id")). The task ID comes from the "task_id" route
// parameter, falling back to the "task_id" context value.
//
// For a task whose channel is Vertex AI, it first polls the upstream through
// the channel's task adaptor. Any status, progress or URL it parses is written
// back onto the stored task, and the function returns a simplified envelope
// with keys error/format/metadata/status/task_id/url.
//
// Every failure in that refresh step is swallowed. The function then falls
// back to the stored task: requests whose URI starts with "/v1/videos/" get
// the stored task.Data as-is, and all others get the task wrapped in a
// dto.TaskResponse.
//
// Errors:
//   - "get_task_failed" (500): the lookup fails.
//   - "task_not_exist" (400): the task is missing.
//   - "marshal_response_failed" (500): the fallback encoding fails.
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
	taskId := c.Param("task_id")
	if taskId == "" {
		taskId = c.GetString("task_id")
	}
	userId := c.GetInt("id")
	originTask, exist, err := model.GetByTaskId(userId, taskId)
	if err != nil {
		taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
		return
	}
	if !exist {
		taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
		return
	}
	// Best-effort live refresh. Any early return leaves respBody empty, so the
	// stored-task fallback after this closure is used instead.
	func() {
		channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
		if err2 != nil {
			return
		}
		// Only Vertex AI channels are polled here.
		if channelModel.Type != constant.ChannelTypeVertexAi {
			return
		}
		// Prefer the channel's configured base URL over the per-type default.
		baseURL := constant.ChannelBaseURLs[channelModel.Type]
		if channelModel.GetBaseURL() != "" {
			baseURL = channelModel.GetBaseURL()
		}
		// The task adaptor is looked up by the channel type rendered as a decimal string.
		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
		if adaptor == nil {
			return
		}
		resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
			"task_id": originTask.TaskID,
			"action":  originTask.Action,
		})
		if err2 != nil || resp == nil {
			return
		}
		defer resp.Body.Close()
		body, err2 := io.ReadAll(resp.Body)
		if err2 != nil {
			return
		}
		ti, err2 := adaptor.ParseTaskResult(body)
		if err2 == nil && ti != nil {
			// Copy back only the fields the parsed result actually populated.
			if ti.Status != "" {
				originTask.Status = model.TaskStatus(ti.Status)
			}
			if ti.Progress != "" {
				originTask.Progress = ti.Progress
			}
			if ti.Url != "" {
				// NOTE(review): the result URL is stored in FailReason and is
				// emitted as "url" below; confirm this field reuse is intended.
				originTask.FailReason = ti.Url
			}
			// Persistence failure is deliberately ignored (best effort).
			_ = originTask.Update()
			// Derive the output format from response.videos[0].mimeType.
			// A MIME type containing "mp4" maps to "mp4", any other non-empty
			// MIME type is passed through verbatim, and the default is "mp4".
			// A decode failure leaves raw nil, which keeps the default.
			var raw map[string]any
			_ = json.Unmarshal(body, &raw)
			format := "mp4"
			if respObj, ok := raw["response"].(map[string]any); ok {
				if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
					if v0, ok := vids[0].(map[string]any); ok {
						if mt, ok := v0["mimeType"].(string); ok && mt != "" {
							if strings.Contains(mt, "mp4") {
								format = "mp4"
							} else {
								format = mt
							}
						}
					}
				}
			}
			// Map internal task statuses onto the response vocabulary;
			// anything not listed is reported as "processing".
			status := "processing"
			switch originTask.Status {
			case model.TaskStatusSuccess:
				status = "succeeded"
			case model.TaskStatusFailure:
				status = "failed"
			case model.TaskStatusQueued, model.TaskStatusSubmitted:
				status = "queued"
			}
			out := map[string]any{
				"error":    nil,
				"format":   format,
				"metadata": nil,
				"status":   status,
				"task_id":  originTask.TaskID,
				"url":      originTask.FailReason,
			}
			// On a marshal error respBody stays nil and the fallback is used.
			respBody, _ = json.Marshal(dto.TaskResponse[any]{
				Code: "success",
				Data: out,
			})
		}
	}()
	if len(respBody) != 0 {
		return
	}
	// Fallback: the /v1/videos/ routes return the stored task.Data as-is.
	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
		respBody = originTask.Data
		return
	}
	respBody, err = json.Marshal(dto.TaskResponse[any]{
		Code: "success",
		Data: TaskModel2Dto(originTask),
	})
	if err != nil {
		taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
	}
	return
}
  375. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  376. return &dto.TaskDto{
  377. TaskID: task.TaskID,
  378. Action: task.Action,
  379. Status: string(task.Status),
  380. FailReason: task.FailReason,
  381. SubmitTime: task.SubmitTime,
  382. StartTime: task.StartTime,
  383. FinishTime: task.FinishTime,
  384. Progress: task.Progress,
  385. Data: task.Data,
  386. }
  387. }