quota.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/bytedance/gopkg/util/gopool"
  6. "math"
  7. "one-api/common"
  8. constant2 "one-api/constant"
  9. "one-api/dto"
  10. "one-api/model"
  11. relaycommon "one-api/relay/common"
  12. "one-api/relay/helper"
  13. "one-api/setting"
  14. "one-api/setting/operation_setting"
  15. "strings"
  16. "time"
  17. "github.com/gin-gonic/gin"
  18. )
  19. type TokenDetails struct {
  20. TextTokens int
  21. AudioTokens int
  22. }
  23. type QuotaInfo struct {
  24. InputDetails TokenDetails
  25. OutputDetails TokenDetails
  26. ModelName string
  27. UsePrice bool
  28. ModelPrice float64
  29. ModelRatio float64
  30. GroupRatio float64
  31. }
  32. func calculateAudioQuota(info QuotaInfo) int {
  33. if info.UsePrice {
  34. return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
  35. }
  36. completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
  37. audioRatio := operation_setting.GetAudioRatio(info.ModelName)
  38. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
  39. ratio := info.GroupRatio * info.ModelRatio
  40. quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
  41. quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
  42. int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
  43. quota = int(math.Round(float64(quota) * ratio))
  44. if ratio != 0 && quota <= 0 {
  45. quota = 1
  46. }
  47. return quota
  48. }
  49. func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
  50. if relayInfo.UsePrice {
  51. return nil
  52. }
  53. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  54. if err != nil {
  55. return err
  56. }
  57. token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
  58. if err != nil {
  59. return err
  60. }
  61. modelName := relayInfo.OriginModelName
  62. textInputTokens := usage.InputTokenDetails.TextTokens
  63. textOutTokens := usage.OutputTokenDetails.TextTokens
  64. audioInputTokens := usage.InputTokenDetails.AudioTokens
  65. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  66. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  67. modelRatio, _ := operation_setting.GetModelRatio(modelName)
  68. quotaInfo := QuotaInfo{
  69. InputDetails: TokenDetails{
  70. TextTokens: textInputTokens,
  71. AudioTokens: audioInputTokens,
  72. },
  73. OutputDetails: TokenDetails{
  74. TextTokens: textOutTokens,
  75. AudioTokens: audioOutTokens,
  76. },
  77. ModelName: modelName,
  78. UsePrice: relayInfo.UsePrice,
  79. ModelRatio: modelRatio,
  80. GroupRatio: groupRatio,
  81. }
  82. quota := calculateAudioQuota(quotaInfo)
  83. if userQuota < quota {
  84. return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
  85. }
  86. if !token.UnlimitedQuota && token.RemainQuota < quota {
  87. return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
  88. }
  89. err = PostConsumeQuota(relayInfo, quota, 0, false)
  90. if err != nil {
  91. return err
  92. }
  93. common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
  94. return nil
  95. }
  96. func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
  97. usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  98. modelPrice float64, usePrice bool, extraContent string) {
  99. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  100. textInputTokens := usage.InputTokenDetails.TextTokens
  101. textOutTokens := usage.OutputTokenDetails.TextTokens
  102. audioInputTokens := usage.InputTokenDetails.AudioTokens
  103. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  104. tokenName := ctx.GetString("token_name")
  105. completionRatio := operation_setting.GetCompletionRatio(modelName)
  106. audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
  107. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)
  108. quotaInfo := QuotaInfo{
  109. InputDetails: TokenDetails{
  110. TextTokens: textInputTokens,
  111. AudioTokens: audioInputTokens,
  112. },
  113. OutputDetails: TokenDetails{
  114. TextTokens: textOutTokens,
  115. AudioTokens: audioOutTokens,
  116. },
  117. ModelName: modelName,
  118. UsePrice: usePrice,
  119. ModelRatio: modelRatio,
  120. GroupRatio: groupRatio,
  121. }
  122. quota := calculateAudioQuota(quotaInfo)
  123. totalTokens := usage.TotalTokens
  124. var logContent string
  125. if !usePrice {
  126. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  127. } else {
  128. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  129. }
  130. // record all the consume log even if quota is 0
  131. if totalTokens == 0 {
  132. // in this case, must be some error happened
  133. // we cannot just return, because we may have to return the pre-consumed quota
  134. quota = 0
  135. logContent += fmt.Sprintf("(可能是上游超时)")
  136. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  137. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
  138. } else {
  139. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  140. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  141. }
  142. logModel := modelName
  143. if extraContent != "" {
  144. logContent += ", " + extraContent
  145. }
  146. other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  147. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
  148. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  149. }
  150. func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
  151. usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
  152. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  153. textInputTokens := usage.PromptTokensDetails.TextTokens
  154. textOutTokens := usage.CompletionTokenDetails.TextTokens
  155. audioInputTokens := usage.PromptTokensDetails.AudioTokens
  156. audioOutTokens := usage.CompletionTokenDetails.AudioTokens
  157. tokenName := ctx.GetString("token_name")
  158. completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
  159. audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
  160. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
  161. modelRatio := priceData.ModelRatio
  162. groupRatio := priceData.GroupRatio
  163. modelPrice := priceData.ModelPrice
  164. usePrice := priceData.UsePrice
  165. quotaInfo := QuotaInfo{
  166. InputDetails: TokenDetails{
  167. TextTokens: textInputTokens,
  168. AudioTokens: audioInputTokens,
  169. },
  170. OutputDetails: TokenDetails{
  171. TextTokens: textOutTokens,
  172. AudioTokens: audioOutTokens,
  173. },
  174. ModelName: relayInfo.OriginModelName,
  175. UsePrice: usePrice,
  176. ModelRatio: modelRatio,
  177. GroupRatio: groupRatio,
  178. }
  179. quota := calculateAudioQuota(quotaInfo)
  180. totalTokens := usage.TotalTokens
  181. var logContent string
  182. if !usePrice {
  183. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  184. } else {
  185. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  186. }
  187. // record all the consume log even if quota is 0
  188. if totalTokens == 0 {
  189. // in this case, must be some error happened
  190. // we cannot just return, because we may have to return the pre-consumed quota
  191. quota = 0
  192. logContent += fmt.Sprintf("(可能是上游超时)")
  193. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  194. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
  195. } else {
  196. quotaDelta := quota - preConsumedQuota
  197. if quotaDelta != 0 {
  198. err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
  199. if err != nil {
  200. common.LogError(ctx, "error consuming token remain quota: "+err.Error())
  201. }
  202. }
  203. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  204. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  205. }
  206. logModel := relayInfo.OriginModelName
  207. if extraContent != "" {
  208. logContent += ", " + extraContent
  209. }
  210. other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  211. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
  212. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  213. }
  214. func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
  215. if quota < 0 {
  216. return errors.New("quota 不能为负数!")
  217. }
  218. if relayInfo.IsPlayground {
  219. return nil
  220. }
  221. //if relayInfo.TokenUnlimited {
  222. // return nil
  223. //}
  224. token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
  225. if err != nil {
  226. return err
  227. }
  228. if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
  229. return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
  230. }
  231. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  232. if err != nil {
  233. return err
  234. }
  235. return nil
  236. }
  237. func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
  238. if quota > 0 {
  239. err = model.DecreaseUserQuota(relayInfo.UserId, quota)
  240. } else {
  241. err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
  242. }
  243. if err != nil {
  244. return err
  245. }
  246. if !relayInfo.IsPlayground {
  247. if quota > 0 {
  248. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  249. } else {
  250. err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
  251. }
  252. if err != nil {
  253. return err
  254. }
  255. }
  256. if sendEmail {
  257. if (quota + preConsumedQuota) != 0 {
  258. checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
  259. }
  260. }
  261. return nil
  262. }
  263. func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
  264. gopool.Go(func() {
  265. userSetting := relayInfo.UserSetting
  266. threshold := common.QuotaRemindThreshold
  267. if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
  268. threshold = int(userCustomThreshold.(float64))
  269. }
  270. //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
  271. quotaTooLow := false
  272. consumeQuota := quota + preConsumedQuota
  273. if relayInfo.UserQuota-consumeQuota < threshold {
  274. quotaTooLow = true
  275. }
  276. if quotaTooLow {
  277. prompt := "您的额度即将用尽"
  278. topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
  279. content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
  280. err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
  281. if err != nil {
  282. common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
  283. }
  284. }
  285. })
  286. }