quota.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/bytedance/gopkg/util/gopool"
  6. "math"
  7. "one-api/common"
  8. constant2 "one-api/constant"
  9. "one-api/dto"
  10. "one-api/model"
  11. relaycommon "one-api/relay/common"
  12. "one-api/setting"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. )
// TokenDetails splits a token count by modality so that text and audio tokens
// can be weighted differently when quota is calculated.
type TokenDetails struct {
	// TextTokens is the number of text tokens.
	TextTokens int
	// AudioTokens is the number of audio tokens.
	AudioTokens int
}
// QuotaInfo bundles everything calculateAudioQuota needs to price one request.
type QuotaInfo struct {
	// InputDetails holds the text/audio token counts of the input side.
	InputDetails TokenDetails
	// OutputDetails holds the text/audio token counts of the output side.
	OutputDetails TokenDetails
	// ModelName is the key used to look up the completion, audio and
	// audio-completion ratios.
	ModelName string
	// UsePrice selects fixed-price billing: when true the token counts are
	// ignored and the quota is ModelPrice * common.QuotaPerUnit * GroupRatio.
	UsePrice bool
	// ModelPrice is only read when UsePrice is true.
	ModelPrice float64
	// ModelRatio is multiplied with GroupRatio to scale the weighted token
	// total when UsePrice is false.
	ModelRatio float64
	// GroupRatio is the group multiplier; it applies in both billing modes.
	GroupRatio float64
}
  30. func calculateAudioQuota(info QuotaInfo) int {
  31. if info.UsePrice {
  32. return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
  33. }
  34. completionRatio := common.GetCompletionRatio(info.ModelName)
  35. audioRatio := common.GetAudioRatio(info.ModelName)
  36. audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
  37. ratio := info.GroupRatio * info.ModelRatio
  38. quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
  39. quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
  40. int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
  41. quota = int(math.Round(float64(quota) * ratio))
  42. if ratio != 0 && quota <= 0 {
  43. quota = 1
  44. }
  45. return quota
  46. }
  47. func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
  48. if relayInfo.UsePrice {
  49. return nil
  50. }
  51. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  52. if err != nil {
  53. return err
  54. }
  55. token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
  56. if err != nil {
  57. return err
  58. }
  59. modelName := relayInfo.UpstreamModelName
  60. textInputTokens := usage.InputTokenDetails.TextTokens
  61. textOutTokens := usage.OutputTokenDetails.TextTokens
  62. audioInputTokens := usage.InputTokenDetails.AudioTokens
  63. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  64. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  65. modelRatio := common.GetModelRatio(modelName)
  66. quotaInfo := QuotaInfo{
  67. InputDetails: TokenDetails{
  68. TextTokens: textInputTokens,
  69. AudioTokens: audioInputTokens,
  70. },
  71. OutputDetails: TokenDetails{
  72. TextTokens: textOutTokens,
  73. AudioTokens: audioOutTokens,
  74. },
  75. ModelName: modelName,
  76. UsePrice: relayInfo.UsePrice,
  77. ModelRatio: modelRatio,
  78. GroupRatio: groupRatio,
  79. }
  80. quota := calculateAudioQuota(quotaInfo)
  81. if userQuota < quota {
  82. return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
  83. }
  84. if !token.UnlimitedQuota && token.RemainQuota < quota {
  85. return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
  86. }
  87. err = PostConsumeQuota(relayInfo, quota, 0, false)
  88. if err != nil {
  89. return err
  90. }
  91. common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
  92. return nil
  93. }
  94. func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
  95. usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  96. modelPrice float64, usePrice bool, extraContent string) {
  97. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  98. textInputTokens := usage.InputTokenDetails.TextTokens
  99. textOutTokens := usage.OutputTokenDetails.TextTokens
  100. audioInputTokens := usage.InputTokenDetails.AudioTokens
  101. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  102. tokenName := ctx.GetString("token_name")
  103. completionRatio := common.GetCompletionRatio(modelName)
  104. audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
  105. audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
  106. quotaInfo := QuotaInfo{
  107. InputDetails: TokenDetails{
  108. TextTokens: textInputTokens,
  109. AudioTokens: audioInputTokens,
  110. },
  111. OutputDetails: TokenDetails{
  112. TextTokens: textOutTokens,
  113. AudioTokens: audioOutTokens,
  114. },
  115. ModelName: modelName,
  116. UsePrice: usePrice,
  117. ModelRatio: modelRatio,
  118. GroupRatio: groupRatio,
  119. }
  120. quota := calculateAudioQuota(quotaInfo)
  121. totalTokens := usage.TotalTokens
  122. var logContent string
  123. if !usePrice {
  124. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  125. } else {
  126. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  127. }
  128. // record all the consume log even if quota is 0
  129. if totalTokens == 0 {
  130. // in this case, must be some error happened
  131. // we cannot just return, because we may have to return the pre-consumed quota
  132. quota = 0
  133. logContent += fmt.Sprintf("(可能是上游超时)")
  134. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  135. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
  136. } else {
  137. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  138. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  139. }
  140. logModel := modelName
  141. if extraContent != "" {
  142. logContent += ", " + extraContent
  143. }
  144. other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  145. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
  146. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  147. }
  148. func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
  149. usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  150. modelPrice float64, usePrice bool, extraContent string) {
  151. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  152. textInputTokens := usage.PromptTokensDetails.TextTokens
  153. textOutTokens := usage.CompletionTokenDetails.TextTokens
  154. audioInputTokens := usage.PromptTokensDetails.AudioTokens
  155. audioOutTokens := usage.CompletionTokenDetails.AudioTokens
  156. tokenName := ctx.GetString("token_name")
  157. completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
  158. audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
  159. audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
  160. quotaInfo := QuotaInfo{
  161. InputDetails: TokenDetails{
  162. TextTokens: textInputTokens,
  163. AudioTokens: audioInputTokens,
  164. },
  165. OutputDetails: TokenDetails{
  166. TextTokens: textOutTokens,
  167. AudioTokens: audioOutTokens,
  168. },
  169. ModelName: relayInfo.RecodeModelName,
  170. UsePrice: usePrice,
  171. ModelRatio: modelRatio,
  172. GroupRatio: groupRatio,
  173. }
  174. quota := calculateAudioQuota(quotaInfo)
  175. totalTokens := usage.TotalTokens
  176. var logContent string
  177. if !usePrice {
  178. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  179. } else {
  180. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  181. }
  182. // record all the consume log even if quota is 0
  183. if totalTokens == 0 {
  184. // in this case, must be some error happened
  185. // we cannot just return, because we may have to return the pre-consumed quota
  186. quota = 0
  187. logContent += fmt.Sprintf("(可能是上游超时)")
  188. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  189. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
  190. } else {
  191. quotaDelta := quota - preConsumedQuota
  192. if quotaDelta != 0 {
  193. err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
  194. if err != nil {
  195. common.LogError(ctx, "error consuming token remain quota: "+err.Error())
  196. }
  197. }
  198. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  199. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  200. }
  201. logModel := relayInfo.RecodeModelName
  202. if extraContent != "" {
  203. logContent += ", " + extraContent
  204. }
  205. other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  206. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
  207. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  208. }
  209. func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
  210. if quota < 0 {
  211. return errors.New("quota 不能为负数!")
  212. }
  213. if relayInfo.IsPlayground {
  214. return nil
  215. }
  216. //if relayInfo.TokenUnlimited {
  217. // return nil
  218. //}
  219. token, err := model.GetTokenById(relayInfo.TokenId)
  220. if err != nil {
  221. return err
  222. }
  223. if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
  224. return errors.New("令牌额度不足")
  225. }
  226. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  227. if err != nil {
  228. return err
  229. }
  230. return nil
  231. }
  232. func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
  233. if quota > 0 {
  234. err = model.DecreaseUserQuota(relayInfo.UserId, quota)
  235. } else {
  236. err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
  237. }
  238. if err != nil {
  239. return err
  240. }
  241. if !relayInfo.IsPlayground {
  242. if quota > 0 {
  243. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  244. } else {
  245. err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
  246. }
  247. if err != nil {
  248. return err
  249. }
  250. }
  251. if sendEmail {
  252. if (quota + preConsumedQuota) != 0 {
  253. checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
  254. }
  255. }
  256. return nil
  257. }
  258. func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
  259. gopool.Go(func() {
  260. userCache, err := model.GetUserCache(userId)
  261. if err != nil {
  262. common.SysError("failed to get user cache: " + err.Error())
  263. }
  264. userSetting := userCache.GetSetting()
  265. threshold := common.QuotaRemindThreshold
  266. if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
  267. threshold = int(userCustomThreshold.(float64))
  268. }
  269. //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
  270. quotaTooLow := false
  271. consumeQuota := quota + preConsumedQuota
  272. if userCache.Quota-consumeQuota < threshold {
  273. quotaTooLow = true
  274. }
  275. if quotaTooLow {
  276. prompt := "您的额度即将用尽"
  277. topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
  278. content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
  279. err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
  280. if err != nil {
  281. common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
  282. }
  283. }
  284. })
  285. }