claude_token_refresh.go 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package service
  2. import (
  3. "fmt"
  4. "one-api/common"
  5. "one-api/constant"
  6. "one-api/model"
  7. "strings"
  8. "time"
  9. "github.com/bytedance/gopkg/util/gopool"
  10. )
  11. // StartClaudeTokenRefreshScheduler starts the scheduled token refresh for Claude Code channels
  12. func StartClaudeTokenRefreshScheduler() {
  13. ticker := time.NewTicker(5 * time.Minute)
  14. gopool.Go(func() {
  15. defer ticker.Stop()
  16. for range ticker.C {
  17. RefreshClaudeCodeTokens()
  18. }
  19. })
  20. common.SysLog("Claude Code token refresh scheduler started (5 minute interval)")
  21. }
  22. // RefreshClaudeCodeTokens refreshes tokens for all active Claude Code channels
  23. func RefreshClaudeCodeTokens() {
  24. var channels []model.Channel
  25. // Get all active Claude Code channels
  26. err := model.DB.Where("type = ? AND status = ?", constant.ChannelTypeClaudeCode, common.ChannelStatusEnabled).Find(&channels).Error
  27. if err != nil {
  28. common.SysError("Failed to get Claude Code channels: " + err.Error())
  29. return
  30. }
  31. refreshCount := 0
  32. for _, channel := range channels {
  33. if refreshTokenForChannel(&channel) {
  34. refreshCount++
  35. }
  36. }
  37. if refreshCount > 0 {
  38. common.SysLog(fmt.Sprintf("Successfully refreshed %d Claude Code channel tokens", refreshCount))
  39. }
  40. }
  41. // refreshTokenForChannel attempts to refresh token for a single channel
  42. func refreshTokenForChannel(channel *model.Channel) bool {
  43. // Parse key in format: accesstoken|refreshtoken
  44. if channel.Key == "" || !strings.Contains(channel.Key, "|") {
  45. common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
  46. return false
  47. }
  48. parts := strings.Split(channel.Key, "|")
  49. if len(parts) < 2 {
  50. common.SysError(fmt.Sprintf("Channel %d has invalid key format, expected accesstoken|refreshtoken", channel.Id))
  51. return false
  52. }
  53. accessToken := parts[0]
  54. refreshToken := parts[1]
  55. if refreshToken == "" {
  56. common.SysError(fmt.Sprintf("Channel %d has empty refresh token", channel.Id))
  57. return false
  58. }
  59. // Check if token needs refresh (refresh 30 minutes before expiry)
  60. // if !shouldRefreshToken(accessToken) {
  61. // return false
  62. // }
  63. // Use shared refresh function
  64. newToken, err := RefreshClaudeToken(accessToken, refreshToken)
  65. if err != nil {
  66. common.SysError(fmt.Sprintf("Failed to refresh token for channel %d: %s", channel.Id, err.Error()))
  67. return false
  68. }
  69. // Update channel with new tokens
  70. newKey := fmt.Sprintf("%s|%s", newToken.AccessToken, newToken.RefreshToken)
  71. err = model.DB.Model(channel).Update("key", newKey).Error
  72. if err != nil {
  73. common.SysError(fmt.Sprintf("Failed to update channel %d with new token: %s", channel.Id, err.Error()))
  74. return false
  75. }
  76. common.SysLog(fmt.Sprintf("Successfully refreshed token for Claude Code channel %d (%s)", channel.Id, channel.Name))
  77. return true
  78. }