relay_task.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/model"
  13. relaycommon "one-api/relay/common"
  14. relayconstant "one-api/relay/constant"
  15. "one-api/service"
  16. "one-api/setting/ratio_setting"
  17. "github.com/gin-gonic/gin"
  18. )
  19. /*
  20. Task 任务通过平台、Action 区分任务
  21. */
  22. func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
  23. platform := constant.TaskPlatform(c.GetString("platform"))
  24. if platform == "" {
  25. platform = GetTaskPlatform(c)
  26. }
  27. relayInfo, err := relaycommon.GenTaskRelayInfo(c)
  28. if err != nil {
  29. return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
  30. }
  31. adaptor := GetTaskAdaptor(platform)
  32. if adaptor == nil {
  33. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  34. }
  35. adaptor.Init(relayInfo)
  36. // get & validate taskRequest 获取并验证文本请求
  37. taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
  38. if taskErr != nil {
  39. return
  40. }
  41. modelName := relayInfo.OriginModelName
  42. if modelName == "" {
  43. modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
  44. }
  45. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  46. if !success {
  47. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  48. if !ok {
  49. modelPrice = 0.1
  50. } else {
  51. modelPrice = defaultPrice
  52. }
  53. }
  54. // 预扣
  55. groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
  56. var ratio float64
  57. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
  58. if hasUserGroupRatio {
  59. ratio = modelPrice * userGroupRatio
  60. } else {
  61. ratio = modelPrice * groupRatio
  62. }
  63. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  64. if err != nil {
  65. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  66. return
  67. }
  68. quota := int(ratio * common.QuotaPerUnit)
  69. if userQuota-quota < 0 {
  70. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  71. return
  72. }
  73. if relayInfo.OriginTaskID != "" {
  74. originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
  75. if err != nil {
  76. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  77. return
  78. }
  79. if !exist {
  80. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  81. return
  82. }
  83. if originTask.ChannelId != relayInfo.ChannelId {
  84. channel, err := model.GetChannelById(originTask.ChannelId, true)
  85. if err != nil {
  86. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  87. return
  88. }
  89. if channel.Status != common.ChannelStatusEnabled {
  90. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  91. }
  92. c.Set("base_url", channel.GetBaseURL())
  93. c.Set("channel_id", originTask.ChannelId)
  94. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  95. relayInfo.ChannelBaseUrl = channel.GetBaseURL()
  96. relayInfo.ChannelId = originTask.ChannelId
  97. }
  98. }
  99. // build body
  100. requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
  101. if err != nil {
  102. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  103. return
  104. }
  105. // do request
  106. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  107. if err != nil {
  108. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  109. return
  110. }
  111. // handle response
  112. if resp != nil && resp.StatusCode != http.StatusOK {
  113. responseBody, _ := io.ReadAll(resp.Body)
  114. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  115. return
  116. }
  117. defer func() {
  118. // release quota
  119. if relayInfo.ConsumeQuota && taskErr == nil {
  120. err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
  121. if err != nil {
  122. common.SysLog("error consuming token remain quota: " + err.Error())
  123. }
  124. if quota != 0 {
  125. tokenName := c.GetString("token_name")
  126. gRatio := groupRatio
  127. if hasUserGroupRatio {
  128. gRatio = userGroupRatio
  129. }
  130. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
  131. other := make(map[string]interface{})
  132. other["model_price"] = modelPrice
  133. other["group_ratio"] = groupRatio
  134. if hasUserGroupRatio {
  135. other["user_group_ratio"] = userGroupRatio
  136. }
  137. model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
  138. ChannelId: relayInfo.ChannelId,
  139. ModelName: modelName,
  140. TokenName: tokenName,
  141. Quota: quota,
  142. Content: logContent,
  143. TokenId: relayInfo.TokenId,
  144. Group: relayInfo.UsingGroup,
  145. Other: other,
  146. })
  147. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  148. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  149. }
  150. }
  151. }()
  152. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
  153. if taskErr != nil {
  154. return
  155. }
  156. relayInfo.ConsumeQuota = true
  157. // insert task
  158. task := model.InitTask(platform, relayInfo)
  159. task.TaskID = taskID
  160. task.Quota = quota
  161. task.Data = taskData
  162. task.Action = relayInfo.Action
  163. err = task.Insert()
  164. if err != nil {
  165. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  166. return
  167. }
  168. return nil
  169. }
  170. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  171. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  172. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  173. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  174. }
  175. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  176. respBuilder, ok := fetchRespBuilders[relayMode]
  177. if !ok {
  178. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  179. }
  180. respBody, taskErr := respBuilder(c)
  181. if taskErr != nil {
  182. return taskErr
  183. }
  184. c.Writer.Header().Set("Content-Type", "application/json")
  185. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  186. if err != nil {
  187. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  188. return
  189. }
  190. return
  191. }
  192. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  193. userId := c.GetInt("id")
  194. var condition = struct {
  195. IDs []any `json:"ids"`
  196. Action string `json:"action"`
  197. }{}
  198. err := c.BindJSON(&condition)
  199. if err != nil {
  200. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  201. return
  202. }
  203. var tasks []any
  204. if len(condition.IDs) > 0 {
  205. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  206. if err != nil {
  207. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  208. return
  209. }
  210. for _, task := range taskModels {
  211. tasks = append(tasks, TaskModel2Dto(task))
  212. }
  213. } else {
  214. tasks = make([]any, 0)
  215. }
  216. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  217. Code: "success",
  218. Data: tasks,
  219. })
  220. return
  221. }
  222. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  223. taskId := c.Param("id")
  224. userId := c.GetInt("id")
  225. originTask, exist, err := model.GetByTaskId(userId, taskId)
  226. if err != nil {
  227. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  228. return
  229. }
  230. if !exist {
  231. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  232. return
  233. }
  234. respBody, err = json.Marshal(dto.TaskResponse[any]{
  235. Code: "success",
  236. Data: TaskModel2Dto(originTask),
  237. })
  238. return
  239. }
  240. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  241. taskId := c.Param("task_id")
  242. if taskId == "" {
  243. taskId = c.GetString("task_id")
  244. }
  245. userId := c.GetInt("id")
  246. originTask, exist, err := model.GetByTaskId(userId, taskId)
  247. if err != nil {
  248. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  249. return
  250. }
  251. if !exist {
  252. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  253. return
  254. }
  255. respBody, err = json.Marshal(dto.TaskResponse[any]{
  256. Code: "success",
  257. Data: TaskModel2Dto(originTask),
  258. })
  259. return
  260. }
  261. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  262. return &dto.TaskDto{
  263. TaskID: task.TaskID,
  264. Action: task.Action,
  265. Status: string(task.Status),
  266. FailReason: task.FailReason,
  267. SubmitTime: task.SubmitTime,
  268. StartTime: task.StartTime,
  269. FinishTime: task.FinishTime,
  270. Progress: task.Progress,
  271. Data: task.Data,
  272. }
  273. }