relay_task.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. "one-api/model"
  13. relaycommon "one-api/relay/common"
  14. relayconstant "one-api/relay/constant"
  15. "one-api/service"
  16. "one-api/setting/ratio_setting"
  17. "github.com/gin-gonic/gin"
  18. )
  19. /*
  20. Task 任务通过平台、Action 区分任务
  21. */
  22. func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  23. info.InitChannelMeta(c)
  24. // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
  25. if info.TaskRelayInfo == nil {
  26. info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
  27. }
  28. platform := constant.TaskPlatform(c.GetString("platform"))
  29. if platform == "" {
  30. platform = GetTaskPlatform(c)
  31. }
  32. adaptor := GetTaskAdaptor(platform)
  33. if adaptor == nil {
  34. return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
  35. }
  36. adaptor.Init(info)
  37. // get & validate taskRequest 获取并验证文本请求
  38. taskErr = adaptor.ValidateRequestAndSetAction(c, info)
  39. if taskErr != nil {
  40. return
  41. }
  42. modelName := info.OriginModelName
  43. if modelName == "" {
  44. modelName = service.CoverTaskActionToModelName(platform, info.Action)
  45. }
  46. modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
  47. if !success {
  48. defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
  49. if !ok {
  50. modelPrice = 0.1
  51. } else {
  52. modelPrice = defaultPrice
  53. }
  54. }
  55. // 预扣
  56. groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
  57. var ratio float64
  58. userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
  59. if hasUserGroupRatio {
  60. ratio = modelPrice * userGroupRatio
  61. } else {
  62. ratio = modelPrice * groupRatio
  63. }
  64. userQuota, err := model.GetUserQuota(info.UserId, false)
  65. if err != nil {
  66. taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
  67. return
  68. }
  69. quota := int(ratio * common.QuotaPerUnit)
  70. if userQuota-quota < 0 {
  71. taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
  72. return
  73. }
  74. if info.OriginTaskID != "" {
  75. originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
  76. if err != nil {
  77. taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
  78. return
  79. }
  80. if !exist {
  81. taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
  82. return
  83. }
  84. if originTask.ChannelId != info.ChannelId {
  85. channel, err := model.GetChannelById(originTask.ChannelId, true)
  86. if err != nil {
  87. taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
  88. return
  89. }
  90. if channel.Status != common.ChannelStatusEnabled {
  91. return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
  92. }
  93. c.Set("base_url", channel.GetBaseURL())
  94. c.Set("channel_id", originTask.ChannelId)
  95. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  96. info.ChannelBaseUrl = channel.GetBaseURL()
  97. info.ChannelId = originTask.ChannelId
  98. }
  99. }
  100. // build body
  101. requestBody, err := adaptor.BuildRequestBody(c, info)
  102. if err != nil {
  103. taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
  104. return
  105. }
  106. // do request
  107. resp, err := adaptor.DoRequest(c, info, requestBody)
  108. if err != nil {
  109. taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
  110. return
  111. }
  112. // handle response
  113. if resp != nil && resp.StatusCode != http.StatusOK {
  114. responseBody, _ := io.ReadAll(resp.Body)
  115. taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
  116. return
  117. }
  118. defer func() {
  119. // release quota
  120. if info.ConsumeQuota && taskErr == nil {
  121. err := service.PostConsumeQuota(info, quota, 0, true)
  122. if err != nil {
  123. common.SysLog("error consuming token remain quota: " + err.Error())
  124. }
  125. if quota != 0 {
  126. tokenName := c.GetString("token_name")
  127. gRatio := groupRatio
  128. if hasUserGroupRatio {
  129. gRatio = userGroupRatio
  130. }
  131. logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
  132. other := make(map[string]interface{})
  133. other["model_price"] = modelPrice
  134. other["group_ratio"] = groupRatio
  135. if hasUserGroupRatio {
  136. other["user_group_ratio"] = userGroupRatio
  137. }
  138. model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
  139. ChannelId: info.ChannelId,
  140. ModelName: modelName,
  141. TokenName: tokenName,
  142. Quota: quota,
  143. Content: logContent,
  144. TokenId: info.TokenId,
  145. Group: info.UsingGroup,
  146. Other: other,
  147. })
  148. model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
  149. model.UpdateChannelUsedQuota(info.ChannelId, quota)
  150. }
  151. }
  152. }()
  153. taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
  154. if taskErr != nil {
  155. return
  156. }
  157. info.ConsumeQuota = true
  158. // insert task
  159. task := model.InitTask(platform, info)
  160. task.TaskID = taskID
  161. task.Quota = quota
  162. task.Data = taskData
  163. task.Action = info.Action
  164. err = task.Insert()
  165. if err != nil {
  166. taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
  167. return
  168. }
  169. return nil
  170. }
  171. var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
  172. relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
  173. relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
  174. relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
  175. }
  176. func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
  177. respBuilder, ok := fetchRespBuilders[relayMode]
  178. if !ok {
  179. taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
  180. }
  181. respBody, taskErr := respBuilder(c)
  182. if taskErr != nil {
  183. return taskErr
  184. }
  185. c.Writer.Header().Set("Content-Type", "application/json")
  186. _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
  187. if err != nil {
  188. taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
  189. return
  190. }
  191. return
  192. }
  193. func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  194. userId := c.GetInt("id")
  195. var condition = struct {
  196. IDs []any `json:"ids"`
  197. Action string `json:"action"`
  198. }{}
  199. err := c.BindJSON(&condition)
  200. if err != nil {
  201. taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
  202. return
  203. }
  204. var tasks []any
  205. if len(condition.IDs) > 0 {
  206. taskModels, err := model.GetByTaskIds(userId, condition.IDs)
  207. if err != nil {
  208. taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
  209. return
  210. }
  211. for _, task := range taskModels {
  212. tasks = append(tasks, TaskModel2Dto(task))
  213. }
  214. } else {
  215. tasks = make([]any, 0)
  216. }
  217. respBody, err = json.Marshal(dto.TaskResponse[[]any]{
  218. Code: "success",
  219. Data: tasks,
  220. })
  221. return
  222. }
  223. func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  224. taskId := c.Param("id")
  225. userId := c.GetInt("id")
  226. originTask, exist, err := model.GetByTaskId(userId, taskId)
  227. if err != nil {
  228. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  229. return
  230. }
  231. if !exist {
  232. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  233. return
  234. }
  235. respBody, err = json.Marshal(dto.TaskResponse[any]{
  236. Code: "success",
  237. Data: TaskModel2Dto(originTask),
  238. })
  239. return
  240. }
  241. func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
  242. taskId := c.Param("task_id")
  243. if taskId == "" {
  244. taskId = c.GetString("task_id")
  245. }
  246. userId := c.GetInt("id")
  247. originTask, exist, err := model.GetByTaskId(userId, taskId)
  248. if err != nil {
  249. taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
  250. return
  251. }
  252. if !exist {
  253. taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
  254. return
  255. }
  256. respBody, err = json.Marshal(dto.TaskResponse[any]{
  257. Code: "success",
  258. Data: TaskModel2Dto(originTask),
  259. })
  260. return
  261. }
  262. func TaskModel2Dto(task *model.Task) *dto.TaskDto {
  263. return &dto.TaskDto{
  264. TaskID: task.TaskID,
  265. Action: task.Action,
  266. Status: string(task.Status),
  267. FailReason: task.FailReason,
  268. SubmitTime: task.SubmitTime,
  269. StartTime: task.StartTime,
  270. FinishTime: task.FinishTime,
  271. Progress: task.Progress,
  272. Data: task.Data,
  273. }
  274. }