relay_task.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package relay
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/gin-gonic/gin"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. "one-api/constant"
  13. "one-api/dto"
  14. "one-api/model"
  15. relaycommon "one-api/relay/common"
  16. relayconstant "one-api/relay/constant"
  17. "one-api/service"
  18. "one-api/setting"
  19. )
  20. /*
  21. Task 任务通过平台、Action 区分任务
  22. */
  23. func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
  24. platform := constant.TaskPlatform(c.GetString("platform"))
  25. relayInfo := relaycommon.GenTaskRelayInfo(c)
  26. adaptor := GetTaskAdaptor(platform)
  27. if adaptor == nil {
  28. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  29. }
  30. adaptor.Init(relayInfo)
  31. // get & validate taskRequest 获取并验证文本请求
  32. taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
  33. if taskErr != nil {
  34. return
  35. }
  36. modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
  37. modelPrice, success := common.GetModelPrice(modelName, true)
  38. if !success {
  39. defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
  40. if !ok {
  41. modelPrice = 0.1
  42. } else {
  43. modelPrice = defaultPrice
  44. }
  45. }
  46. // 预扣
  47. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  48. ratio := modelPrice * groupRatio
  49. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  50. if err != nil {
  51. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  52. return
  53. }
  54. quota := int(ratio * common.QuotaPerUnit)
  55. if userQuota-quota < 0 {
  56. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  57. return
  58. }
  59. if relayInfo.OriginTaskID != "" {
  60. originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
  61. if err != nil {
  62. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  63. return
  64. }
  65. if !exist {
  66. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  67. return
  68. }
  69. if originTask.ChannelId != relayInfo.ChannelId {
  70. channel, err := model.GetChannelById(originTask.ChannelId, true)
  71. if err != nil {
  72. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  73. return
  74. }
  75. if channel.Status != common.ChannelStatusEnabled {
  76. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  77. }
  78. c.Set("base_url", channel.GetBaseURL())
  79. c.Set("channel_id", originTask.ChannelId)
  80. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  81. relayInfo.BaseUrl = channel.GetBaseURL()
  82. relayInfo.ChannelId = originTask.ChannelId
  83. }
  84. }
  85. // build body
  86. requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
  87. if err != nil {
  88. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  89. return
  90. }
  91. // do request
  92. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  93. if err != nil {
  94. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  95. return
  96. }
  97. // handle response
  98. if resp != nil && resp.StatusCode != http.StatusOK {
  99. responseBody, _ := io.ReadAll(resp.Body)
  100. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  101. return
  102. }
  103. defer func(ctx context.Context) {
  104. // release quota
  105. if relayInfo.ConsumeQuota && taskErr == nil {
  106. err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
  107. if err != nil {
  108. common.SysError("error consuming token remain quota: " + err.Error())
  109. }
  110. if quota != 0 {
  111. tokenName := c.GetString("token_name")
  112. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
  113. other := make(map[string]interface{})
  114. other["model_price"] = modelPrice
  115. other["group_ratio"] = groupRatio
  116. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
  117. modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
  118. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  119. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  120. }
  121. }
  122. }(c.Request.Context())
  123. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
  124. if taskErr != nil {
  125. return
  126. }
  127. relayInfo.ConsumeQuota = true
  128. // insert task
  129. task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
  130. task.TaskID = taskID
  131. task.Quota = quota
  132. task.Data = taskData
  133. err = task.Insert()
  134. if err != nil {
  135. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  136. return
  137. }
  138. return nil
  139. }
  140. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  141. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  142. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  143. }
  144. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  145. respBuilder, ok := fetchRespBuilders[relayMode]
  146. if !ok {
  147. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  148. }
  149. respBody, taskErr := respBuilder(c)
  150. if taskErr != nil {
  151. return taskErr
  152. }
  153. c.Writer.Header().Set("Content-Type", "application/json")
  154. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  155. if err != nil {
  156. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  157. return
  158. }
  159. return
  160. }
  161. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  162. userId := c.GetInt("id")
  163. var condition = struct {
  164. IDs []any `json:"ids"`
  165. Action string `json:"action"`
  166. }{}
  167. err := c.BindJSON(&condition)
  168. if err != nil {
  169. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  170. return
  171. }
  172. var tasks []any
  173. if len(condition.IDs) > 0 {
  174. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  175. if err != nil {
  176. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  177. return
  178. }
  179. for _, task := range taskModels {
  180. tasks = append(tasks, TaskModel2Dto(task))
  181. }
  182. } else {
  183. tasks = make([]any, 0)
  184. }
  185. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  186. Code: "success",
  187. Data: tasks,
  188. })
  189. return
  190. }
  191. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  192. taskId := c.Param("id")
  193. userId := c.GetInt("id")
  194. originTask, exist, err := model.GetByTaskId(userId, taskId)
  195. if err != nil {
  196. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  197. return
  198. }
  199. if !exist {
  200. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  201. return
  202. }
  203. respBody, err = json.Marshal(dto.TaskResponse[any]{
  204. Code: "success",
  205. Data: TaskModel2Dto(originTask),
  206. })
  207. return
  208. }
  209. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  210. return &dto.TaskDto{
  211. TaskID: task.TaskID,
  212. Action: task.Action,
  213. Status: string(task.Status),
  214. FailReason: task.FailReason,
  215. SubmitTime: task.SubmitTime,
  216. StartTime: task.StartTime,
  217. FinishTime: task.FinishTime,
  218. Progress: task.Progress,
  219. Data: task.Data,
  220. }
  221. }