relay_task.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/model"
  13. relaycommon "one-api/relay/common"
  14. relayconstant "one-api/relay/constant"
  15. "one-api/service"
  16. "one-api/setting/ratio_setting"
  17. "strconv"
  18. "strings"
  19. "github.com/gin-gonic/gin"
  20. )
  21. /*
  22. Task 任务通过平台、Action 区分任务
  23. */
  24. func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
  25. platform := constant.TaskPlatform(c.GetString("platform"))
  26. if platform == "" {
  27. platform = GetTaskPlatform(c)
  28. }
  29. relayInfo, err := relaycommon.GenTaskRelayInfo(c)
  30. if err != nil {
  31. return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
  32. }
  33. relayInfo.InitChannelMeta(c)
  34. adaptor := GetTaskAdaptor(platform)
  35. if adaptor == nil {
  36. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  37. }
  38. adaptor.Init(relayInfo)
  39. // get & validate taskRequest 获取并验证文本请求
  40. taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
  41. if taskErr != nil {
  42. return
  43. }
  44. modelName := relayInfo.OriginModelName
  45. if modelName == "" {
  46. modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
  47. }
  48. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  49. if !success {
  50. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  51. if !ok {
  52. modelPrice = 0.1
  53. } else {
  54. modelPrice = defaultPrice
  55. }
  56. }
  57. // 预扣
  58. groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
  59. var ratio float64
  60. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
  61. if hasUserGroupRatio {
  62. ratio = modelPrice * userGroupRatio
  63. } else {
  64. ratio = modelPrice * groupRatio
  65. }
  66. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  67. if err != nil {
  68. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  69. return
  70. }
  71. quota := int(ratio * common.QuotaPerUnit)
  72. if userQuota-quota < 0 {
  73. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  74. return
  75. }
  76. if relayInfo.OriginTaskID != "" {
  77. originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
  78. if err != nil {
  79. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  80. return
  81. }
  82. if !exist {
  83. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  84. return
  85. }
  86. if originTask.ChannelId != relayInfo.ChannelId {
  87. channel, err := model.GetChannelById(originTask.ChannelId, true)
  88. if err != nil {
  89. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  90. return
  91. }
  92. if channel.Status != common.ChannelStatusEnabled {
  93. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  94. }
  95. c.Set("base_url", channel.GetBaseURL())
  96. c.Set("channel_id", originTask.ChannelId)
  97. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  98. relayInfo.ChannelBaseUrl = channel.GetBaseURL()
  99. relayInfo.ChannelId = originTask.ChannelId
  100. }
  101. }
  102. // build body
  103. requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
  104. if err != nil {
  105. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  106. return
  107. }
  108. // do request
  109. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  110. if err != nil {
  111. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  112. return
  113. }
  114. // handle response
  115. if resp != nil && resp.StatusCode != http.StatusOK {
  116. responseBody, _ := io.ReadAll(resp.Body)
  117. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  118. return
  119. }
  120. defer func() {
  121. // release quota
  122. if relayInfo.ConsumeQuota && taskErr == nil {
  123. err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
  124. if err != nil {
  125. common.SysLog("error consuming token remain quota: " + err.Error())
  126. }
  127. if quota != 0 {
  128. tokenName := c.GetString("token_name")
  129. gRatio := groupRatio
  130. if hasUserGroupRatio {
  131. gRatio = userGroupRatio
  132. }
  133. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
  134. other := make(map[string]interface{})
  135. other["model_price"] = modelPrice
  136. other["group_ratio"] = groupRatio
  137. if hasUserGroupRatio {
  138. other["user_group_ratio"] = userGroupRatio
  139. }
  140. model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
  141. ChannelId: relayInfo.ChannelId,
  142. ModelName: modelName,
  143. TokenName: tokenName,
  144. Quota: quota,
  145. Content: logContent,
  146. TokenId: relayInfo.TokenId,
  147. Group: relayInfo.UsingGroup,
  148. Other: other,
  149. })
  150. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  151. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  152. }
  153. }
  154. }()
  155. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
  156. if taskErr != nil {
  157. return
  158. }
  159. relayInfo.ConsumeQuota = true
  160. // insert task
  161. task := model.InitTask(platform, relayInfo)
  162. task.TaskID = taskID
  163. task.Quota = quota
  164. task.Data = taskData
  165. task.Action = relayInfo.Action
  166. err = task.Insert()
  167. if err != nil {
  168. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  169. return
  170. }
  171. return nil
  172. }
  173. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  174. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  175. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  176. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  177. }
  178. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  179. respBuilder, ok := fetchRespBuilders[relayMode]
  180. if !ok {
  181. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  182. }
  183. respBody, taskErr := respBuilder(c)
  184. if taskErr != nil {
  185. return taskErr
  186. }
  187. if len(respBody) == 0 {
  188. respBody = []byte("{\"code\":\"success\",\"data\":null}")
  189. }
  190. c.Writer.Header().Set("Content-Type", "application/json")
  191. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  192. if err != nil {
  193. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  194. return
  195. }
  196. return
  197. }
  198. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  199. userId := c.GetInt("id")
  200. var condition = struct {
  201. IDs []any `json:"ids"`
  202. Action string `json:"action"`
  203. }{}
  204. err := c.BindJSON(&condition)
  205. if err != nil {
  206. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  207. return
  208. }
  209. var tasks []any
  210. if len(condition.IDs) > 0 {
  211. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  212. if err != nil {
  213. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  214. return
  215. }
  216. for _, task := range taskModels {
  217. tasks = append(tasks, TaskModel2Dto(task))
  218. }
  219. } else {
  220. tasks = make([]any, 0)
  221. }
  222. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  223. Code: "success",
  224. Data: tasks,
  225. })
  226. return
  227. }
  228. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  229. taskId := c.Param("id")
  230. userId := c.GetInt("id")
  231. originTask, exist, err := model.GetByTaskId(userId, taskId)
  232. if err != nil {
  233. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  234. return
  235. }
  236. if !exist {
  237. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  238. return
  239. }
  240. respBody, err = json.Marshal(dto.TaskResponse[any]{
  241. Code: "success",
  242. Data: TaskModel2Dto(originTask),
  243. })
  244. return
  245. }
  246. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  247. taskId := c.Param("task_id")
  248. if taskId == "" {
  249. taskId = c.GetString("task_id")
  250. }
  251. userId := c.GetInt("id")
  252. originTask, exist, err := model.GetByTaskId(userId, taskId)
  253. if err != nil {
  254. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  255. return
  256. }
  257. if !exist {
  258. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  259. return
  260. }
  261. func() {
  262. channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
  263. if err2 != nil {
  264. return
  265. }
  266. if channelModel.Type != constant.ChannelTypeVertexAi {
  267. return
  268. }
  269. baseURL := constant.ChannelBaseURLs[channelModel.Type]
  270. if channelModel.GetBaseURL() != "" {
  271. baseURL = channelModel.GetBaseURL()
  272. }
  273. adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
  274. if adaptor == nil {
  275. return
  276. }
  277. resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
  278. "task_id": originTask.TaskID,
  279. "action": originTask.Action,
  280. })
  281. if err2 != nil || resp == nil {
  282. return
  283. }
  284. defer resp.Body.Close()
  285. body, err2 := io.ReadAll(resp.Body)
  286. if err2 != nil {
  287. return
  288. }
  289. ti, err2 := adaptor.ParseTaskResult(body)
  290. if err2 == nil && ti != nil {
  291. if ti.Status != "" {
  292. originTask.Status = model.TaskStatus(ti.Status)
  293. }
  294. if ti.Progress != "" {
  295. originTask.Progress = ti.Progress
  296. }
  297. if ti.Url != "" {
  298. originTask.FailReason = ti.Url
  299. }
  300. _ = originTask.Update()
  301. var raw map[string]any
  302. _ = json.Unmarshal(body, &raw)
  303. format := "mp4"
  304. if respObj, ok := raw["response"].(map[string]any); ok {
  305. if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
  306. if v0, ok := vids[0].(map[string]any); ok {
  307. if mt, ok := v0["mimeType"].(string); ok && mt != "" {
  308. if strings.Contains(mt, "mp4") {
  309. format = "mp4"
  310. } else {
  311. format = mt
  312. }
  313. }
  314. }
  315. }
  316. }
  317. status := "processing"
  318. switch originTask.Status {
  319. case model.TaskStatusSuccess:
  320. status = "succeeded"
  321. case model.TaskStatusFailure:
  322. status = "failed"
  323. case model.TaskStatusQueued, model.TaskStatusSubmitted:
  324. status = "queued"
  325. }
  326. out := map[string]any{
  327. "error": nil,
  328. "format": format,
  329. "metadata": nil,
  330. "status": status,
  331. "task_id": originTask.TaskID,
  332. "url": originTask.FailReason,
  333. }
  334. respBody, _ = json.Marshal(dto.TaskResponse[any]{
  335. Code: "success",
  336. Data: out,
  337. })
  338. }
  339. }()
  340. if len(respBody) == 0 {
  341. respBody, err = json.Marshal(dto.TaskResponse[any]{
  342. Code: "success",
  343. Data: TaskModel2Dto(originTask),
  344. })
  345. }
  346. return
  347. }
  348. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  349. return &dto.TaskDto{
  350. TaskID: task.TaskID,
  351. Action: task.Action,
  352. Status: string(task.Status),
  353. FailReason: task.FailReason,
  354. SubmitTime: task.SubmitTime,
  355. StartTime: task.StartTime,
  356. FinishTime: task.FinishTime,
  357. Progress: task.Progress,
  358. Data: task.Data,
  359. }
  360. }