quota.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/bytedance/gopkg/util/gopool"
  6. "math"
  7. "one-api/common"
  8. constant2 "one-api/constant"
  9. "one-api/dto"
  10. "one-api/model"
  11. relaycommon "one-api/relay/common"
  12. "one-api/relay/helper"
  13. "one-api/setting"
  14. "strings"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. )
  // TokenDetails is a token count broken down by modality.
  type TokenDetails struct {
  	// TextTokens and AudioTokens are the text and audio token counts.
  	TextTokens, AudioTokens int
  }
  22. type QuotaInfo struct {
  23. InputDetails TokenDetails
  24. OutputDetails TokenDetails
  25. ModelName string
  26. UsePrice bool
  27. ModelPrice float64
  28. ModelRatio float64
  29. GroupRatio float64
  30. }
  31. func calculateAudioQuota(info QuotaInfo) int {
  32. if info.UsePrice {
  33. return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
  34. }
  35. completionRatio := common.GetCompletionRatio(info.ModelName)
  36. audioRatio := common.GetAudioRatio(info.ModelName)
  37. audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
  38. ratio := info.GroupRatio * info.ModelRatio
  39. quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
  40. quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
  41. int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
  42. quota = int(math.Round(float64(quota) * ratio))
  43. if ratio != 0 && quota <= 0 {
  44. quota = 1
  45. }
  46. return quota
  47. }
  48. func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
  49. if relayInfo.UsePrice {
  50. return nil
  51. }
  52. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  53. if err != nil {
  54. return err
  55. }
  56. token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
  57. if err != nil {
  58. return err
  59. }
  60. modelName := relayInfo.OriginModelName
  61. textInputTokens := usage.InputTokenDetails.TextTokens
  62. textOutTokens := usage.OutputTokenDetails.TextTokens
  63. audioInputTokens := usage.InputTokenDetails.AudioTokens
  64. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  65. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  66. modelRatio := common.GetModelRatio(modelName)
  67. quotaInfo := QuotaInfo{
  68. InputDetails: TokenDetails{
  69. TextTokens: textInputTokens,
  70. AudioTokens: audioInputTokens,
  71. },
  72. OutputDetails: TokenDetails{
  73. TextTokens: textOutTokens,
  74. AudioTokens: audioOutTokens,
  75. },
  76. ModelName: modelName,
  77. UsePrice: relayInfo.UsePrice,
  78. ModelRatio: modelRatio,
  79. GroupRatio: groupRatio,
  80. }
  81. quota := calculateAudioQuota(quotaInfo)
  82. if userQuota < quota {
  83. return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
  84. }
  85. if !token.UnlimitedQuota && token.RemainQuota < quota {
  86. return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
  87. }
  88. err = PostConsumeQuota(relayInfo, quota, 0, false)
  89. if err != nil {
  90. return err
  91. }
  92. common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
  93. return nil
  94. }
  95. func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
  96. usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  97. modelPrice float64, usePrice bool, extraContent string) {
  98. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  99. textInputTokens := usage.InputTokenDetails.TextTokens
  100. textOutTokens := usage.OutputTokenDetails.TextTokens
  101. audioInputTokens := usage.InputTokenDetails.AudioTokens
  102. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  103. tokenName := ctx.GetString("token_name")
  104. completionRatio := common.GetCompletionRatio(modelName)
  105. audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
  106. audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
  107. quotaInfo := QuotaInfo{
  108. InputDetails: TokenDetails{
  109. TextTokens: textInputTokens,
  110. AudioTokens: audioInputTokens,
  111. },
  112. OutputDetails: TokenDetails{
  113. TextTokens: textOutTokens,
  114. AudioTokens: audioOutTokens,
  115. },
  116. ModelName: modelName,
  117. UsePrice: usePrice,
  118. ModelRatio: modelRatio,
  119. GroupRatio: groupRatio,
  120. }
  121. quota := calculateAudioQuota(quotaInfo)
  122. totalTokens := usage.TotalTokens
  123. var logContent string
  124. if !usePrice {
  125. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  126. } else {
  127. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  128. }
  129. // record all the consume log even if quota is 0
  130. if totalTokens == 0 {
  131. // in this case, must be some error happened
  132. // we cannot just return, because we may have to return the pre-consumed quota
  133. quota = 0
  134. logContent += fmt.Sprintf("(可能是上游超时)")
  135. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  136. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
  137. } else {
  138. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  139. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  140. }
  141. logModel := modelName
  142. if extraContent != "" {
  143. logContent += ", " + extraContent
  144. }
  145. other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  146. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
  147. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  148. }
  149. func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
  150. usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
  151. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  152. textInputTokens := usage.PromptTokensDetails.TextTokens
  153. textOutTokens := usage.CompletionTokenDetails.TextTokens
  154. audioInputTokens := usage.PromptTokensDetails.AudioTokens
  155. audioOutTokens := usage.CompletionTokenDetails.AudioTokens
  156. tokenName := ctx.GetString("token_name")
  157. completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
  158. audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
  159. audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
  160. modelRatio := priceData.ModelRatio
  161. groupRatio := priceData.GroupRatio
  162. modelPrice := priceData.ModelPrice
  163. usePrice := priceData.UsePrice
  164. quotaInfo := QuotaInfo{
  165. InputDetails: TokenDetails{
  166. TextTokens: textInputTokens,
  167. AudioTokens: audioInputTokens,
  168. },
  169. OutputDetails: TokenDetails{
  170. TextTokens: textOutTokens,
  171. AudioTokens: audioOutTokens,
  172. },
  173. ModelName: relayInfo.OriginModelName,
  174. UsePrice: usePrice,
  175. ModelRatio: modelRatio,
  176. GroupRatio: groupRatio,
  177. }
  178. quota := calculateAudioQuota(quotaInfo)
  179. totalTokens := usage.TotalTokens
  180. var logContent string
  181. if !usePrice {
  182. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  183. } else {
  184. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  185. }
  186. // record all the consume log even if quota is 0
  187. if totalTokens == 0 {
  188. // in this case, must be some error happened
  189. // we cannot just return, because we may have to return the pre-consumed quota
  190. quota = 0
  191. logContent += fmt.Sprintf("(可能是上游超时)")
  192. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  193. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
  194. } else {
  195. quotaDelta := quota - preConsumedQuota
  196. if quotaDelta != 0 {
  197. err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
  198. if err != nil {
  199. common.LogError(ctx, "error consuming token remain quota: "+err.Error())
  200. }
  201. }
  202. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  203. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  204. }
  205. logModel := relayInfo.OriginModelName
  206. if extraContent != "" {
  207. logContent += ", " + extraContent
  208. }
  209. other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  210. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
  211. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  212. }
  213. func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
  214. if quota < 0 {
  215. return errors.New("quota 不能为负数!")
  216. }
  217. if relayInfo.IsPlayground {
  218. return nil
  219. }
  220. //if relayInfo.TokenUnlimited {
  221. // return nil
  222. //}
  223. token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
  224. if err != nil {
  225. return err
  226. }
  227. if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
  228. return errors.New("令牌额度不足")
  229. }
  230. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  231. if err != nil {
  232. return err
  233. }
  234. return nil
  235. }
  236. func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
  237. if quota > 0 {
  238. err = model.DecreaseUserQuota(relayInfo.UserId, quota)
  239. } else {
  240. err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
  241. }
  242. if err != nil {
  243. return err
  244. }
  245. if !relayInfo.IsPlayground {
  246. if quota > 0 {
  247. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  248. } else {
  249. err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
  250. }
  251. if err != nil {
  252. return err
  253. }
  254. }
  255. if sendEmail {
  256. if (quota + preConsumedQuota) != 0 {
  257. checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
  258. }
  259. }
  260. return nil
  261. }
  262. func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
  263. gopool.Go(func() {
  264. userCache, err := model.GetUserCache(userId)
  265. if err != nil {
  266. common.SysError("failed to get user cache: " + err.Error())
  267. }
  268. userSetting := userCache.GetSetting()
  269. threshold := common.QuotaRemindThreshold
  270. if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
  271. threshold = int(userCustomThreshold.(float64))
  272. }
  273. //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
  274. quotaTooLow := false
  275. consumeQuota := quota + preConsumedQuota
  276. if userCache.Quota-consumeQuota < threshold {
  277. quotaTooLow = true
  278. }
  279. if quotaTooLow {
  280. prompt := "您的额度即将用尽"
  281. topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
  282. content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
  283. err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
  284. if err != nil {
  285. common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
  286. }
  287. }
  288. })
  289. }