relay_task.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/gin-gonic/gin"
  8. "io"
  9. "net/http"
  10. "one-api/common"
  11. "one-api/constant"
  12. "one-api/dto"
  13. "one-api/model"
  14. relaycommon "one-api/relay/common"
  15. relayconstant "one-api/relay/constant"
  16. "one-api/service"
  17. "one-api/setting/ratio_setting"
  18. )
  19. /*
  20. Task 任务通过平台、Action 区分任务
  21. */
  22. func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
  23. platform := constant.TaskPlatform(c.GetString("platform"))
  24. relayInfo := relaycommon.GenTaskRelayInfo(c)
  25. adaptor := GetTaskAdaptor(platform)
  26. if adaptor == nil {
  27. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  28. }
  29. adaptor.Init(relayInfo)
  30. // get & validate taskRequest 获取并验证文本请求
  31. taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
  32. if taskErr != nil {
  33. return
  34. }
  35. modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
  36. if platform == constant.TaskPlatformKling {
  37. modelName = relayInfo.OriginModelName
  38. }
  39. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  40. if !success {
  41. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  42. if !ok {
  43. modelPrice = 0.1
  44. } else {
  45. modelPrice = defaultPrice
  46. }
  47. }
  48. // 预扣
  49. groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
  50. ratio := modelPrice * groupRatio
  51. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  52. if err != nil {
  53. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  54. return
  55. }
  56. quota := int(ratio * common.QuotaPerUnit)
  57. if userQuota-quota < 0 {
  58. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  59. return
  60. }
  61. if relayInfo.OriginTaskID != "" {
  62. originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
  63. if err != nil {
  64. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  65. return
  66. }
  67. if !exist {
  68. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  69. return
  70. }
  71. if originTask.ChannelId != relayInfo.ChannelId {
  72. channel, err := model.GetChannelById(originTask.ChannelId, true)
  73. if err != nil {
  74. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  75. return
  76. }
  77. if channel.Status != common.ChannelStatusEnabled {
  78. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  79. }
  80. c.Set("base_url", channel.GetBaseURL())
  81. c.Set("channel_id", originTask.ChannelId)
  82. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  83. relayInfo.BaseUrl = channel.GetBaseURL()
  84. relayInfo.ChannelId = originTask.ChannelId
  85. }
  86. }
  87. // build body
  88. requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
  89. if err != nil {
  90. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  91. return
  92. }
  93. // do request
  94. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  95. if err != nil {
  96. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  97. return
  98. }
  99. // handle response
  100. if resp != nil && resp.StatusCode != http.StatusOK {
  101. responseBody, _ := io.ReadAll(resp.Body)
  102. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  103. return
  104. }
  105. defer func() {
  106. // release quota
  107. if relayInfo.ConsumeQuota && taskErr == nil {
  108. err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
  109. if err != nil {
  110. common.SysError("error consuming token remain quota: " + err.Error())
  111. }
  112. if quota != 0 {
  113. tokenName := c.GetString("token_name")
  114. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
  115. other := make(map[string]interface{})
  116. other["model_price"] = modelPrice
  117. other["group_ratio"] = groupRatio
  118. model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
  119. modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
  120. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  121. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  122. }
  123. }
  124. }()
  125. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
  126. if taskErr != nil {
  127. return
  128. }
  129. relayInfo.ConsumeQuota = true
  130. // insert task
  131. task := model.InitTask(platform, relayInfo)
  132. task.TaskID = taskID
  133. task.Quota = quota
  134. task.Data = taskData
  135. task.Action = relayInfo.Action
  136. err = task.Insert()
  137. if err != nil {
  138. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  139. return
  140. }
  141. return nil
  142. }
// fetchRespBuilders maps a relay mode to the function that builds the
// response body for that kind of task-fetch request. RelayTaskFetch looks up
// the builder here; relay modes without an entry are not supported.
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
	relayconstant.RelayModeSunoFetchByID:  sunoFetchByIDRespBodyBuilder,
	relayconstant.RelayModeSunoFetch:      sunoFetchRespBodyBuilder,
	relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
}
  148. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  149. respBuilder, ok := fetchRespBuilders[relayMode]
  150. if !ok {
  151. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  152. }
  153. respBody, taskErr := respBuilder(c)
  154. if taskErr != nil {
  155. return taskErr
  156. }
  157. c.Writer.Header().Set("Content-Type", "application/json")
  158. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  159. if err != nil {
  160. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  161. return
  162. }
  163. return
  164. }
  165. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  166. userId := c.GetInt("id")
  167. var condition = struct {
  168. IDs []any `json:"ids"`
  169. Action string `json:"action"`
  170. }{}
  171. err := c.BindJSON(&condition)
  172. if err != nil {
  173. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  174. return
  175. }
  176. var tasks []any
  177. if len(condition.IDs) > 0 {
  178. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  179. if err != nil {
  180. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  181. return
  182. }
  183. for _, task := range taskModels {
  184. tasks = append(tasks, TaskModel2Dto(task))
  185. }
  186. } else {
  187. tasks = make([]any, 0)
  188. }
  189. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  190. Code: "success",
  191. Data: tasks,
  192. })
  193. return
  194. }
  195. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  196. taskId := c.Param("id")
  197. userId := c.GetInt("id")
  198. originTask, exist, err := model.GetByTaskId(userId, taskId)
  199. if err != nil {
  200. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  201. return
  202. }
  203. if !exist {
  204. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  205. return
  206. }
  207. respBody, err = json.Marshal(dto.TaskResponse[any]{
  208. Code: "success",
  209. Data: TaskModel2Dto(originTask),
  210. })
  211. return
  212. }
  213. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  214. taskId := c.Param("id")
  215. userId := c.GetInt("id")
  216. originTask, exist, err := model.GetByTaskId(userId, taskId)
  217. if err != nil {
  218. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  219. return
  220. }
  221. if !exist {
  222. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  223. return
  224. }
  225. respBody, err = json.Marshal(dto.TaskResponse[any]{
  226. Code: "success",
  227. Data: TaskModel2Dto(originTask),
  228. })
  229. return
  230. }
  231. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  232. return &dto.TaskDto{
  233. TaskID: task.TaskID,
  234. Action: task.Action,
  235. Status: string(task.Status),
  236. FailReason: task.FailReason,
  237. SubmitTime: task.SubmitTime,
  238. StartTime: task.StartTime,
  239. FinishTime: task.FinishTime,
  240. Progress: task.Progress,
  241. Data: task.Data,
  242. }
  243. }