quota.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/bytedance/gopkg/util/gopool"
  6. "one-api/common"
  7. constant2 "one-api/constant"
  8. "one-api/dto"
  9. "one-api/model"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/setting"
  13. "one-api/setting/operation_setting"
  14. "strings"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. )
  // TokenDetails holds a token count broken down by modality.
  type TokenDetails struct {
  	// TextTokens and AudioTokens are the text and audio token counts.
  	TextTokens, AudioTokens int
  }
  22. type QuotaInfo struct {
  23. InputDetails TokenDetails
  24. OutputDetails TokenDetails
  25. ModelName string
  26. UsePrice bool
  27. ModelPrice float64
  28. ModelRatio float64
  29. GroupRatio float64
  30. }
  31. func calculateAudioQuota(info QuotaInfo) int {
  32. if info.UsePrice {
  33. return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
  34. }
  35. completionRatio := operation_setting.GetCompletionRatio(info.ModelName)
  36. audioRatio := operation_setting.GetAudioRatio(info.ModelName)
  37. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(info.ModelName)
  38. ratio := info.GroupRatio * info.ModelRatio
  39. quota := 0.0
  40. quota += float64(info.InputDetails.TextTokens)
  41. quota += float64(info.OutputDetails.TextTokens) * completionRatio
  42. quota += float64(info.InputDetails.AudioTokens) * audioRatio
  43. quota += float64(info.OutputDetails.AudioTokens) * audioRatio * audioCompletionRatio
  44. quota = quota * ratio
  45. if ratio != 0 && quota <= 0 {
  46. quota = 1
  47. }
  48. return int(quota)
  49. }
  50. func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
  51. if relayInfo.UsePrice {
  52. return nil
  53. }
  54. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  55. if err != nil {
  56. return err
  57. }
  58. token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
  59. if err != nil {
  60. return err
  61. }
  62. modelName := relayInfo.OriginModelName
  63. textInputTokens := usage.InputTokenDetails.TextTokens
  64. textOutTokens := usage.OutputTokenDetails.TextTokens
  65. audioInputTokens := usage.InputTokenDetails.AudioTokens
  66. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  67. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  68. modelRatio, _ := operation_setting.GetModelRatio(modelName)
  69. quotaInfo := QuotaInfo{
  70. InputDetails: TokenDetails{
  71. TextTokens: textInputTokens,
  72. AudioTokens: audioInputTokens,
  73. },
  74. OutputDetails: TokenDetails{
  75. TextTokens: textOutTokens,
  76. AudioTokens: audioOutTokens,
  77. },
  78. ModelName: modelName,
  79. UsePrice: relayInfo.UsePrice,
  80. ModelRatio: modelRatio,
  81. GroupRatio: groupRatio,
  82. }
  83. quota := calculateAudioQuota(quotaInfo)
  84. if userQuota < quota {
  85. return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
  86. }
  87. if !token.UnlimitedQuota && token.RemainQuota < quota {
  88. return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
  89. }
  90. err = PostConsumeQuota(relayInfo, quota, 0, false)
  91. if err != nil {
  92. return err
  93. }
  94. common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
  95. return nil
  96. }
  97. func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
  98. usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  99. modelPrice float64, usePrice bool, extraContent string) {
  100. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  101. textInputTokens := usage.InputTokenDetails.TextTokens
  102. textOutTokens := usage.OutputTokenDetails.TextTokens
  103. audioInputTokens := usage.InputTokenDetails.AudioTokens
  104. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  105. tokenName := ctx.GetString("token_name")
  106. completionRatio := operation_setting.GetCompletionRatio(modelName)
  107. audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
  108. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(modelName)
  109. quotaInfo := QuotaInfo{
  110. InputDetails: TokenDetails{
  111. TextTokens: textInputTokens,
  112. AudioTokens: audioInputTokens,
  113. },
  114. OutputDetails: TokenDetails{
  115. TextTokens: textOutTokens,
  116. AudioTokens: audioOutTokens,
  117. },
  118. ModelName: modelName,
  119. UsePrice: usePrice,
  120. ModelRatio: modelRatio,
  121. GroupRatio: groupRatio,
  122. }
  123. quota := calculateAudioQuota(quotaInfo)
  124. totalTokens := usage.TotalTokens
  125. var logContent string
  126. if !usePrice {
  127. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  128. } else {
  129. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  130. }
  131. // record all the consume log even if quota is 0
  132. if totalTokens == 0 {
  133. // in this case, must be some error happened
  134. // we cannot just return, because we may have to return the pre-consumed quota
  135. quota = 0
  136. logContent += fmt.Sprintf("(可能是上游超时)")
  137. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  138. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
  139. } else {
  140. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  141. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  142. }
  143. logModel := modelName
  144. if extraContent != "" {
  145. logContent += ", " + extraContent
  146. }
  147. other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  148. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
  149. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  150. }
  151. func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
  152. usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
  153. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  154. textInputTokens := usage.PromptTokensDetails.TextTokens
  155. textOutTokens := usage.CompletionTokenDetails.TextTokens
  156. audioInputTokens := usage.PromptTokensDetails.AudioTokens
  157. audioOutTokens := usage.CompletionTokenDetails.AudioTokens
  158. tokenName := ctx.GetString("token_name")
  159. completionRatio := operation_setting.GetCompletionRatio(relayInfo.OriginModelName)
  160. audioRatio := operation_setting.GetAudioRatio(relayInfo.OriginModelName)
  161. audioCompletionRatio := operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)
  162. modelRatio := priceData.ModelRatio
  163. groupRatio := priceData.GroupRatio
  164. modelPrice := priceData.ModelPrice
  165. usePrice := priceData.UsePrice
  166. quotaInfo := QuotaInfo{
  167. InputDetails: TokenDetails{
  168. TextTokens: textInputTokens,
  169. AudioTokens: audioInputTokens,
  170. },
  171. OutputDetails: TokenDetails{
  172. TextTokens: textOutTokens,
  173. AudioTokens: audioOutTokens,
  174. },
  175. ModelName: relayInfo.OriginModelName,
  176. UsePrice: usePrice,
  177. ModelRatio: modelRatio,
  178. GroupRatio: groupRatio,
  179. }
  180. quota := calculateAudioQuota(quotaInfo)
  181. totalTokens := usage.TotalTokens
  182. var logContent string
  183. if !usePrice {
  184. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  185. } else {
  186. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  187. }
  188. // record all the consume log even if quota is 0
  189. if totalTokens == 0 {
  190. // in this case, must be some error happened
  191. // we cannot just return, because we may have to return the pre-consumed quota
  192. quota = 0
  193. logContent += fmt.Sprintf("(可能是上游超时)")
  194. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  195. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
  196. } else {
  197. quotaDelta := quota - preConsumedQuota
  198. if quotaDelta != 0 {
  199. err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
  200. if err != nil {
  201. common.LogError(ctx, "error consuming token remain quota: "+err.Error())
  202. }
  203. }
  204. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  205. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  206. }
  207. logModel := relayInfo.OriginModelName
  208. if extraContent != "" {
  209. logContent += ", " + extraContent
  210. }
  211. other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  212. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
  213. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  214. }
  215. func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
  216. if quota < 0 {
  217. return errors.New("quota 不能为负数!")
  218. }
  219. if relayInfo.IsPlayground {
  220. return nil
  221. }
  222. //if relayInfo.TokenUnlimited {
  223. // return nil
  224. //}
  225. token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
  226. if err != nil {
  227. return err
  228. }
  229. if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
  230. return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
  231. }
  232. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  233. if err != nil {
  234. return err
  235. }
  236. return nil
  237. }
  238. func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
  239. if quota > 0 {
  240. err = model.DecreaseUserQuota(relayInfo.UserId, quota)
  241. } else {
  242. err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
  243. }
  244. if err != nil {
  245. return err
  246. }
  247. if !relayInfo.IsPlayground {
  248. if quota > 0 {
  249. err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
  250. } else {
  251. err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
  252. }
  253. if err != nil {
  254. return err
  255. }
  256. }
  257. if sendEmail {
  258. if (quota + preConsumedQuota) != 0 {
  259. checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
  260. }
  261. }
  262. return nil
  263. }
  264. func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
  265. gopool.Go(func() {
  266. userSetting := relayInfo.UserSetting
  267. threshold := common.QuotaRemindThreshold
  268. if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
  269. threshold = int(userCustomThreshold.(float64))
  270. }
  271. //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
  272. quotaTooLow := false
  273. consumeQuota := quota + preConsumedQuota
  274. if relayInfo.UserQuota-consumeQuota < threshold {
  275. quotaTooLow = true
  276. }
  277. if quotaTooLow {
  278. prompt := "您的额度即将用尽"
  279. topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
  280. content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
  281. err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
  282. if err != nil {
  283. common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
  284. }
  285. }
  286. })
  287. }