relay_task.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/logger"
  13. "one-api/model"
  14. relaycommon "one-api/relay/common"
  15. relayconstant "one-api/relay/constant"
  16. "one-api/service"
  17. "one-api/setting/ratio_setting"
  18. "github.com/gin-gonic/gin"
  19. )
  20. /*
  21. Task 任务通过平台、Action 区分任务
  22. */
  23. func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
  24. platform := constant.TaskPlatform(c.GetString("platform"))
  25. if platform == "" {
  26. platform = GetTaskPlatform(c)
  27. }
  28. relayInfo := relaycommon.GenTaskRelayInfo(c)
  29. adaptor := GetTaskAdaptor(platform)
  30. if adaptor == nil {
  31. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  32. }
  33. adaptor.Init(relayInfo)
  34. // get & validate taskRequest 获取并验证文本请求
  35. taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
  36. if taskErr != nil {
  37. return
  38. }
  39. modelName := relayInfo.OriginModelName
  40. if modelName == "" {
  41. modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
  42. }
  43. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  44. if !success {
  45. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  46. if !ok {
  47. modelPrice = 0.1
  48. } else {
  49. modelPrice = defaultPrice
  50. }
  51. }
  52. // 预扣
  53. groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
  54. var ratio float64
  55. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
  56. if hasUserGroupRatio {
  57. ratio = modelPrice * userGroupRatio
  58. } else {
  59. ratio = modelPrice * groupRatio
  60. }
  61. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  62. if err != nil {
  63. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  64. return
  65. }
  66. quota := int(ratio * common.QuotaPerUnit)
  67. if userQuota-quota < 0 {
  68. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  69. return
  70. }
  71. if relayInfo.OriginTaskID != "" {
  72. originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
  73. if err != nil {
  74. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  75. return
  76. }
  77. if !exist {
  78. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  79. return
  80. }
  81. if originTask.ChannelId != relayInfo.ChannelId {
  82. channel, err := model.GetChannelById(originTask.ChannelId, true)
  83. if err != nil {
  84. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  85. return
  86. }
  87. if channel.Status != common.ChannelStatusEnabled {
  88. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  89. }
  90. c.Set("base_url", channel.GetBaseURL())
  91. c.Set("channel_id", originTask.ChannelId)
  92. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  93. relayInfo.BaseUrl = channel.GetBaseURL()
  94. relayInfo.ChannelId = originTask.ChannelId
  95. }
  96. }
  97. // build body
  98. requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
  99. if err != nil {
  100. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  101. return
  102. }
  103. // do request
  104. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  105. if err != nil {
  106. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  107. return
  108. }
  109. // handle response
  110. if resp != nil && resp.StatusCode != http.StatusOK {
  111. responseBody, _ := io.ReadAll(resp.Body)
  112. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  113. return
  114. }
  115. defer func() {
  116. // release quota
  117. if relayInfo.ConsumeQuota && taskErr == nil {
  118. err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
  119. if err != nil {
  120. logger.SysError("error consuming token remain quota: " + err.Error())
  121. }
  122. if quota != 0 {
  123. tokenName := c.GetString("token_name")
  124. gRatio := groupRatio
  125. if hasUserGroupRatio {
  126. gRatio = userGroupRatio
  127. }
  128. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
  129. other := make(map[string]interface{})
  130. other["model_price"] = modelPrice
  131. other["group_ratio"] = groupRatio
  132. if hasUserGroupRatio {
  133. other["user_group_ratio"] = userGroupRatio
  134. }
  135. model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
  136. ChannelId: relayInfo.ChannelId,
  137. ModelName: modelName,
  138. TokenName: tokenName,
  139. Quota: quota,
  140. Content: logContent,
  141. TokenId: relayInfo.TokenId,
  142. UserQuota: userQuota,
  143. Group: relayInfo.UsingGroup,
  144. Other: other,
  145. })
  146. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  147. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  148. }
  149. }
  150. }()
  151. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
  152. if taskErr != nil {
  153. return
  154. }
  155. relayInfo.ConsumeQuota = true
  156. // insert task
  157. task := model.InitTask(platform, relayInfo)
  158. task.TaskID = taskID
  159. task.Quota = quota
  160. task.Data = taskData
  161. task.Action = relayInfo.Action
  162. err = task.Insert()
  163. if err != nil {
  164. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  165. return
  166. }
  167. return nil
  168. }
  169. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  170. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  171. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  172. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  173. }
  174. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  175. respBuilder, ok := fetchRespBuilders[relayMode]
  176. if !ok {
  177. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  178. }
  179. respBody, taskErr := respBuilder(c)
  180. if taskErr != nil {
  181. return taskErr
  182. }
  183. c.Writer.Header().Set("Content-Type", "application/json")
  184. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  185. if err != nil {
  186. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  187. return
  188. }
  189. return
  190. }
  191. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  192. userId := c.GetInt("id")
  193. var condition = struct {
  194. IDs []any `json:"ids"`
  195. Action string `json:"action"`
  196. }{}
  197. err := c.BindJSON(&condition)
  198. if err != nil {
  199. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  200. return
  201. }
  202. var tasks []any
  203. if len(condition.IDs) > 0 {
  204. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  205. if err != nil {
  206. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  207. return
  208. }
  209. for _, task := range taskModels {
  210. tasks = append(tasks, TaskModel2Dto(task))
  211. }
  212. } else {
  213. tasks = make([]any, 0)
  214. }
  215. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  216. Code: "success",
  217. Data: tasks,
  218. })
  219. return
  220. }
  221. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  222. taskId := c.Param("id")
  223. userId := c.GetInt("id")
  224. originTask, exist, err := model.GetByTaskId(userId, taskId)
  225. if err != nil {
  226. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  227. return
  228. }
  229. if !exist {
  230. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  231. return
  232. }
  233. respBody, err = json.Marshal(dto.TaskResponse[any]{
  234. Code: "success",
  235. Data: TaskModel2Dto(originTask),
  236. })
  237. return
  238. }
  239. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  240. taskId := c.Param("task_id")
  241. if taskId == "" {
  242. taskId = c.GetString("task_id")
  243. }
  244. userId := c.GetInt("id")
  245. originTask, exist, err := model.GetByTaskId(userId, taskId)
  246. if err != nil {
  247. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  248. return
  249. }
  250. if !exist {
  251. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  252. return
  253. }
  254. respBody, err = json.Marshal(dto.TaskResponse[any]{
  255. Code: "success",
  256. Data: TaskModel2Dto(originTask),
  257. })
  258. return
  259. }
  260. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  261. return &dto.TaskDto{
  262. TaskID: task.TaskID,
  263. Action: task.Action,
  264. Status: string(task.Status),
  265. FailReason: task.FailReason,
  266. SubmitTime: task.SubmitTime,
  267. StartTime: task.StartTime,
  268. FinishTime: task.FinishTime,
  269. Progress: task.Progress,
  270. Data: task.Data,
  271. }
  272. }