quota.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "one-api/common"
  7. "one-api/dto"
  8. "one-api/model"
  9. relaycommon "one-api/relay/common"
  10. "one-api/setting"
  11. "strings"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. )
  // TokenDetails splits a token count into its text part and its audio part.
// It is used for both the input side and the output side of a request.
type TokenDetails struct {
	TextTokens, AudioTokens int
}
  19. type QuotaInfo struct {
  20. InputDetails TokenDetails
  21. OutputDetails TokenDetails
  22. ModelName string
  23. UsePrice bool
  24. ModelPrice float64
  25. ModelRatio float64
  26. GroupRatio float64
  27. }
  28. func calculateAudioQuota(info QuotaInfo) int {
  29. if info.UsePrice {
  30. return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio)
  31. }
  32. completionRatio := common.GetCompletionRatio(info.ModelName)
  33. audioRatio := common.GetAudioRatio(info.ModelName)
  34. audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName)
  35. ratio := info.GroupRatio * info.ModelRatio
  36. quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio))
  37. quota += int(math.Round(float64(info.InputDetails.AudioTokens)*audioRatio)) +
  38. int(math.Round(float64(info.OutputDetails.AudioTokens)*audioRatio*audioCompletionRatio))
  39. quota = int(math.Round(float64(quota) * ratio))
  40. if ratio != 0 && quota <= 0 {
  41. quota = 1
  42. }
  43. return quota
  44. }
  45. func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
  46. if relayInfo.UsePrice {
  47. return nil
  48. }
  49. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  50. if err != nil {
  51. return err
  52. }
  53. token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
  54. if err != nil {
  55. return err
  56. }
  57. modelName := relayInfo.UpstreamModelName
  58. textInputTokens := usage.InputTokenDetails.TextTokens
  59. textOutTokens := usage.OutputTokenDetails.TextTokens
  60. audioInputTokens := usage.InputTokenDetails.AudioTokens
  61. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  62. groupRatio := setting.GetGroupRatio(relayInfo.Group)
  63. modelRatio := common.GetModelRatio(modelName)
  64. quotaInfo := QuotaInfo{
  65. InputDetails: TokenDetails{
  66. TextTokens: textInputTokens,
  67. AudioTokens: audioInputTokens,
  68. },
  69. OutputDetails: TokenDetails{
  70. TextTokens: textOutTokens,
  71. AudioTokens: audioOutTokens,
  72. },
  73. ModelName: modelName,
  74. UsePrice: relayInfo.UsePrice,
  75. ModelRatio: modelRatio,
  76. GroupRatio: groupRatio,
  77. }
  78. quota := calculateAudioQuota(quotaInfo)
  79. if userQuota < quota {
  80. return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
  81. }
  82. if !token.UnlimitedQuota && token.RemainQuota < quota {
  83. return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
  84. }
  85. err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
  86. if err != nil {
  87. return err
  88. }
  89. common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
  90. return nil
  91. }
  92. func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
  93. usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  94. modelPrice float64, usePrice bool, extraContent string) {
  95. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  96. textInputTokens := usage.InputTokenDetails.TextTokens
  97. textOutTokens := usage.OutputTokenDetails.TextTokens
  98. audioInputTokens := usage.InputTokenDetails.AudioTokens
  99. audioOutTokens := usage.OutputTokenDetails.AudioTokens
  100. tokenName := ctx.GetString("token_name")
  101. completionRatio := common.GetCompletionRatio(modelName)
  102. audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
  103. audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
  104. quotaInfo := QuotaInfo{
  105. InputDetails: TokenDetails{
  106. TextTokens: textInputTokens,
  107. AudioTokens: audioInputTokens,
  108. },
  109. OutputDetails: TokenDetails{
  110. TextTokens: textOutTokens,
  111. AudioTokens: audioOutTokens,
  112. },
  113. ModelName: modelName,
  114. UsePrice: usePrice,
  115. ModelRatio: modelRatio,
  116. GroupRatio: groupRatio,
  117. }
  118. quota := calculateAudioQuota(quotaInfo)
  119. totalTokens := usage.TotalTokens
  120. var logContent string
  121. if !usePrice {
  122. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  123. } else {
  124. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  125. }
  126. // record all the consume log even if quota is 0
  127. if totalTokens == 0 {
  128. // in this case, must be some error happened
  129. // we cannot just return, because we may have to return the pre-consumed quota
  130. quota = 0
  131. logContent += fmt.Sprintf("(可能是上游超时)")
  132. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  133. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
  134. } else {
  135. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  136. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  137. }
  138. logModel := modelName
  139. if extraContent != "" {
  140. logContent += ", " + extraContent
  141. }
  142. other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  143. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
  144. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  145. }
  146. func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
  147. usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
  148. modelPrice float64, usePrice bool, extraContent string) {
  149. useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
  150. textInputTokens := usage.PromptTokensDetails.TextTokens
  151. textOutTokens := usage.CompletionTokenDetails.TextTokens
  152. audioInputTokens := usage.PromptTokensDetails.AudioTokens
  153. audioOutTokens := usage.CompletionTokenDetails.AudioTokens
  154. tokenName := ctx.GetString("token_name")
  155. completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
  156. audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
  157. audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
  158. quotaInfo := QuotaInfo{
  159. InputDetails: TokenDetails{
  160. TextTokens: textInputTokens,
  161. AudioTokens: audioInputTokens,
  162. },
  163. OutputDetails: TokenDetails{
  164. TextTokens: textOutTokens,
  165. AudioTokens: audioOutTokens,
  166. },
  167. ModelName: relayInfo.RecodeModelName,
  168. UsePrice: usePrice,
  169. ModelRatio: modelRatio,
  170. GroupRatio: groupRatio,
  171. }
  172. quota := calculateAudioQuota(quotaInfo)
  173. totalTokens := usage.TotalTokens
  174. var logContent string
  175. if !usePrice {
  176. logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
  177. } else {
  178. logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
  179. }
  180. // record all the consume log even if quota is 0
  181. if totalTokens == 0 {
  182. // in this case, must be some error happened
  183. // we cannot just return, because we may have to return the pre-consumed quota
  184. quota = 0
  185. logContent += fmt.Sprintf("(可能是上游超时)")
  186. common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
  187. "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
  188. } else {
  189. quotaDelta := quota - preConsumedQuota
  190. if quotaDelta != 0 {
  191. err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
  192. if err != nil {
  193. common.LogError(ctx, "error consuming token remain quota: "+err.Error())
  194. }
  195. }
  196. model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
  197. model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
  198. }
  199. logModel := relayInfo.RecodeModelName
  200. if extraContent != "" {
  201. logContent += ", " + extraContent
  202. }
  203. other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
  204. model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
  205. tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
  206. }