codex_credential_refresh_task.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package service
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/constant"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/model"
  13. "github.com/bytedance/gopkg/util/gopool"
  14. )
  // Tunables for the Codex credential auto-refresh task.
const (
	// codexCredentialRefreshTickInterval is the delay between scheduled refresh passes.
	codexCredentialRefreshTickInterval = 10 * time.Minute
	// codexCredentialRefreshThreshold: a credential whose parsed expiry lies further
	// in the future than this is skipped during a pass.
	codexCredentialRefreshThreshold = 24 * time.Hour
	// codexCredentialRefreshBatchSize caps how many channels are loaded per query.
	codexCredentialRefreshBatchSize = 200
	// codexCredentialRefreshTimeout bounds each individual credential refresh call.
	codexCredentialRefreshTimeout = 15 * time.Second
)

// codexCredentialRefreshOnce ensures the background task is started at most once.
var codexCredentialRefreshOnce sync.Once

// codexCredentialRefreshRunning is set while a refresh pass is in progress so that
// overlapping passes are skipped rather than run concurrently.
var codexCredentialRefreshRunning atomic.Bool
  25. func shouldAutoRefreshCodexChannelStatus(status int) bool {
  26. return status == common.ChannelStatusEnabled || status == common.ChannelStatusAutoDisabled
  27. }
  28. func StartCodexCredentialAutoRefreshTask() {
  29. codexCredentialRefreshOnce.Do(func() {
  30. if !common.IsMasterNode {
  31. return
  32. }
  33. gopool.Go(func() {
  34. logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))
  35. ticker := time.NewTicker(codexCredentialRefreshTickInterval)
  36. defer ticker.Stop()
  37. runCodexCredentialAutoRefreshOnce()
  38. for range ticker.C {
  39. runCodexCredentialAutoRefreshOnce()
  40. }
  41. })
  42. })
  43. }
  44. func runCodexCredentialAutoRefreshOnce() {
  45. if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
  46. return
  47. }
  48. defer codexCredentialRefreshRunning.Store(false)
  49. ctx := context.Background()
  50. now := time.Now()
  51. var refreshed int
  52. var scanned int
  53. offset := 0
  54. for {
  55. var channels []*model.Channel
  56. err := model.DB.
  57. Select("id", "name", "key", "status", "channel_info").
  58. Where("type = ? AND (status = ? OR status = ?)",
  59. constant.ChannelTypeCodex,
  60. common.ChannelStatusEnabled,
  61. common.ChannelStatusAutoDisabled,
  62. ).
  63. Order("id asc").
  64. Limit(codexCredentialRefreshBatchSize).
  65. Offset(offset).
  66. Find(&channels).Error
  67. if err != nil {
  68. logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
  69. return
  70. }
  71. if len(channels) == 0 {
  72. break
  73. }
  74. offset += codexCredentialRefreshBatchSize
  75. for _, ch := range channels {
  76. if ch == nil {
  77. continue
  78. }
  79. scanned++
  80. if ch.ChannelInfo.IsMultiKey {
  81. continue
  82. }
  83. rawKey := strings.TrimSpace(ch.Key)
  84. if rawKey == "" {
  85. continue
  86. }
  87. oauthKey, err := parseCodexOAuthKey(rawKey)
  88. if err != nil {
  89. continue
  90. }
  91. refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
  92. if refreshToken == "" {
  93. continue
  94. }
  95. expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
  96. expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
  97. if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
  98. continue
  99. }
  100. refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
  101. newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
  102. cancel()
  103. if err != nil {
  104. logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
  105. continue
  106. }
  107. refreshed++
  108. logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
  109. }
  110. }
  111. if refreshed > 0 {
  112. func() {
  113. defer func() {
  114. if r := recover(); r != nil {
  115. logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
  116. }
  117. }()
  118. model.InitChannelCache()
  119. }()
  120. ResetProxyClientCache()
  121. }
  122. if common.DebugEnabled {
  123. logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
  124. }
  125. }