billing_session.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. package service
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strings"
  6. "sync"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/QuantumNous/new-api/logger"
  9. "github.com/QuantumNous/new-api/model"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/types"
  12. "github.com/bytedance/gopkg/util/gopool"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // ---------------------------------------------------------------------------
  16. // BillingSession — 统一计费会话
  17. // ---------------------------------------------------------------------------
  18. // BillingSession 封装单次请求的预扣费/结算/退款生命周期。
  19. // 实现 relaycommon.BillingSettler 接口。
  20. type BillingSession struct {
  21. relayInfo *relaycommon.RelayInfo
  22. funding FundingSource
  23. preConsumedQuota int // 实际预扣额度(信任用户可能为 0)
  24. tokenConsumed int // 令牌额度实际扣减量
  25. settled bool // Settle 已调用
  26. refunded bool // Refund 已调用
  27. mu sync.Mutex
  28. }
  29. // Settle 根据实际消耗额度进行结算。
  30. func (s *BillingSession) Settle(actualQuota int) error {
  31. s.mu.Lock()
  32. defer s.mu.Unlock()
  33. if s.settled {
  34. return nil
  35. }
  36. delta := actualQuota - s.preConsumedQuota
  37. if delta == 0 {
  38. s.settled = true
  39. return nil
  40. }
  41. // 1) 调整资金来源
  42. if err := s.funding.Settle(delta); err != nil {
  43. return err
  44. }
  45. // 2) 调整令牌额度
  46. if !s.relayInfo.IsPlayground {
  47. if delta > 0 {
  48. if err := model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta); err != nil {
  49. return err
  50. }
  51. } else {
  52. if err := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta); err != nil {
  53. return err
  54. }
  55. }
  56. }
  57. // 3) 更新 relayInfo 上的订阅 PostDelta(用于日志)
  58. if s.funding.Source() == BillingSourceSubscription {
  59. s.relayInfo.SubscriptionPostDelta += int64(delta)
  60. }
  61. s.settled = true
  62. return nil
  63. }
  64. // Refund 退还所有预扣费,幂等安全,异步执行。
  65. func (s *BillingSession) Refund(c *gin.Context) {
  66. s.mu.Lock()
  67. if s.settled || s.refunded || !s.needsRefundLocked() {
  68. s.mu.Unlock()
  69. return
  70. }
  71. s.refunded = true
  72. s.mu.Unlock()
  73. logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, funding=%s)",
  74. s.relayInfo.UserId,
  75. logger.FormatQuota(s.tokenConsumed),
  76. s.funding.Source(),
  77. ))
  78. // 复制需要的值到闭包中
  79. tokenId := s.relayInfo.TokenId
  80. tokenKey := s.relayInfo.TokenKey
  81. isPlayground := s.relayInfo.IsPlayground
  82. tokenConsumed := s.tokenConsumed
  83. funding := s.funding
  84. gopool.Go(func() {
  85. // 1) 退还资金来源
  86. if err := funding.Refund(); err != nil {
  87. common.SysLog("error refunding billing source: " + err.Error())
  88. }
  89. // 2) 退还令牌额度
  90. if tokenConsumed > 0 && !isPlayground {
  91. if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
  92. common.SysLog("error refunding token quota: " + err.Error())
  93. }
  94. }
  95. })
  96. }
  97. // NeedsRefund 返回是否存在需要退还的预扣状态。
  98. func (s *BillingSession) NeedsRefund() bool {
  99. s.mu.Lock()
  100. defer s.mu.Unlock()
  101. return s.needsRefundLocked()
  102. }
  103. func (s *BillingSession) needsRefundLocked() bool {
  104. if s.settled || s.refunded {
  105. return false
  106. }
  107. if s.tokenConsumed > 0 {
  108. return true
  109. }
  110. // 订阅可能在 tokenConsumed=0 时仍预扣了额度
  111. if sub, ok := s.funding.(*SubscriptionFunding); ok && sub.preConsumed > 0 {
  112. return true
  113. }
  114. return false
  115. }
  116. // GetPreConsumedQuota 返回实际预扣的额度。
  117. func (s *BillingSession) GetPreConsumedQuota() int {
  118. return s.preConsumedQuota
  119. }
  120. // ---------------------------------------------------------------------------
  121. // PreConsume — 统一预扣费入口(含信任额度旁路)
  122. // ---------------------------------------------------------------------------
  123. // preConsume 执行预扣费:信任检查 -> 令牌预扣 -> 资金来源预扣。
  124. // 任一步骤失败时原子回滚已完成的步骤。
  125. func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError {
  126. effectiveQuota := quota
  127. // ---- 信任额度旁路 ----
  128. if s.shouldTrust(c) {
  129. effectiveQuota = 0
  130. logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
  131. } else if effectiveQuota > 0 {
  132. logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source()))
  133. }
  134. // ---- 1) 预扣令牌额度 ----
  135. if effectiveQuota > 0 {
  136. if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil {
  137. return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  138. }
  139. s.tokenConsumed = effectiveQuota
  140. }
  141. // ---- 2) 预扣资金来源 ----
  142. if err := s.funding.PreConsume(effectiveQuota); err != nil {
  143. // 回滚令牌额度
  144. if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground {
  145. _ = model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed)
  146. s.tokenConsumed = 0
  147. }
  148. errMsg := err.Error()
  149. if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
  150. return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  151. }
  152. if strings.Contains(errMsg, "用户额度不足") || strings.Contains(errMsg, "预扣费额度失败") {
  153. return types.NewErrorWithStatusCode(err, types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  154. }
  155. return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
  156. }
  157. s.preConsumedQuota = effectiveQuota
  158. // ---- 同步 RelayInfo 兼容字段 ----
  159. s.syncRelayInfo()
  160. return nil
  161. }
  162. // shouldTrust 统一信任额度检查,适用于钱包和订阅。
  163. func (s *BillingSession) shouldTrust(c *gin.Context) bool {
  164. trustQuota := common.GetTrustQuota()
  165. if trustQuota <= 0 {
  166. return false
  167. }
  168. // 检查令牌是否充足
  169. tokenTrusted := s.relayInfo.TokenUnlimited
  170. if !tokenTrusted {
  171. tokenQuota := c.GetInt("token_quota")
  172. tokenTrusted = tokenQuota > trustQuota
  173. }
  174. if !tokenTrusted {
  175. return false
  176. }
  177. switch s.funding.Source() {
  178. case BillingSourceWallet:
  179. return s.relayInfo.UserQuota > trustQuota
  180. case BillingSourceSubscription:
  181. // 订阅暂不支持信任旁路(订阅剩余额度需要额外查询,且预扣开销小)
  182. // 后续可以在此处添加订阅信任逻辑
  183. return false
  184. default:
  185. return false
  186. }
  187. }
  188. // syncRelayInfo 将 BillingSession 的状态同步到 RelayInfo 的兼容字段上。
  189. func (s *BillingSession) syncRelayInfo() {
  190. info := s.relayInfo
  191. info.FinalPreConsumedQuota = s.preConsumedQuota
  192. info.BillingSource = s.funding.Source()
  193. if sub, ok := s.funding.(*SubscriptionFunding); ok {
  194. info.SubscriptionId = sub.subscriptionId
  195. info.SubscriptionPreConsumed = sub.preConsumed
  196. info.SubscriptionPostDelta = 0
  197. info.SubscriptionAmountTotal = sub.AmountTotal
  198. info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter
  199. info.SubscriptionPlanId = sub.PlanId
  200. info.SubscriptionPlanTitle = sub.PlanTitle
  201. } else {
  202. info.SubscriptionId = 0
  203. info.SubscriptionPreConsumed = 0
  204. }
  205. }
  206. // ---------------------------------------------------------------------------
  207. // NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退
  208. // ---------------------------------------------------------------------------
  209. // NewBillingSession 根据用户计费偏好创建 BillingSession,处理 subscription_first / wallet_first 的回退。
  210. func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) {
  211. if relayInfo == nil {
  212. return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
  213. }
  214. pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
  215. // 钱包路径需要先检查用户额度
  216. tryWallet := func() (*BillingSession, *types.NewAPIError) {
  217. userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
  218. if err != nil {
  219. return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
  220. }
  221. if userQuota <= 0 {
  222. return nil, types.NewErrorWithStatusCode(
  223. fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)),
  224. types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
  225. types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  226. }
  227. if userQuota-preConsumedQuota < 0 {
  228. return nil, types.NewErrorWithStatusCode(
  229. fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)),
  230. types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
  231. types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
  232. }
  233. relayInfo.UserQuota = userQuota
  234. session := &BillingSession{
  235. relayInfo: relayInfo,
  236. funding: &WalletFunding{userId: relayInfo.UserId},
  237. }
  238. if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
  239. return nil, apiErr
  240. }
  241. return session, nil
  242. }
  243. trySubscription := func() (*BillingSession, *types.NewAPIError) {
  244. subConsume := int64(preConsumedQuota)
  245. if subConsume <= 0 {
  246. subConsume = 1
  247. }
  248. session := &BillingSession{
  249. relayInfo: relayInfo,
  250. funding: &SubscriptionFunding{
  251. requestId: relayInfo.RequestId,
  252. userId: relayInfo.UserId,
  253. modelName: relayInfo.OriginModelName,
  254. amount: subConsume,
  255. },
  256. }
  257. if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
  258. return nil, apiErr
  259. }
  260. return session, nil
  261. }
  262. switch pref {
  263. case "subscription_only":
  264. return trySubscription()
  265. case "wallet_only":
  266. return tryWallet()
  267. case "wallet_first":
  268. session, err := tryWallet()
  269. if err != nil {
  270. if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
  271. return trySubscription()
  272. }
  273. return nil, err
  274. }
  275. return session, nil
  276. case "subscription_first":
  277. fallthrough
  278. default:
  279. session, err := trySubscription()
  280. if err != nil {
  281. if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
  282. return tryWallet()
  283. }
  284. return nil, err
  285. }
  286. return session, nil
  287. }
  288. }