codex_usage.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package controller
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/constant"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/relay/channel/codex"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func GetCodexChannelUsage(c *gin.Context) {
  17. channelId, err := strconv.Atoi(c.Param("id"))
  18. if err != nil {
  19. common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
  20. return
  21. }
  22. ch, err := model.GetChannelById(channelId, true)
  23. if err != nil {
  24. common.ApiError(c, err)
  25. return
  26. }
  27. if ch == nil {
  28. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
  29. return
  30. }
  31. if ch.Type != constant.ChannelTypeCodex {
  32. c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
  33. return
  34. }
  35. if ch.ChannelInfo.IsMultiKey {
  36. c.JSON(http.StatusOK, gin.H{"success": false, "message": "multi-key channel is not supported"})
  37. return
  38. }
  39. oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
  40. if err != nil {
  41. common.SysError("failed to parse oauth key: " + err.Error())
  42. c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析凭证失败,请检查渠道配置"})
  43. return
  44. }
  45. accessToken := strings.TrimSpace(oauthKey.AccessToken)
  46. accountID := strings.TrimSpace(oauthKey.AccountID)
  47. if accessToken == "" {
  48. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: access_token is required"})
  49. return
  50. }
  51. if accountID == "" {
  52. c.JSON(http.StatusOK, gin.H{"success": false, "message": "codex channel: account_id is required"})
  53. return
  54. }
  55. client, err := service.NewProxyHttpClient(ch.GetSetting().Proxy)
  56. if err != nil {
  57. common.ApiError(c, err)
  58. return
  59. }
  60. ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
  61. defer cancel()
  62. statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
  63. if err != nil {
  64. common.SysError("failed to fetch codex usage: " + err.Error())
  65. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
  66. return
  67. }
  68. if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) && strings.TrimSpace(oauthKey.RefreshToken) != "" {
  69. refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
  70. defer refreshCancel()
  71. res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
  72. if refreshErr == nil {
  73. oauthKey.AccessToken = res.AccessToken
  74. oauthKey.RefreshToken = res.RefreshToken
  75. oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
  76. oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
  77. if strings.TrimSpace(oauthKey.Type) == "" {
  78. oauthKey.Type = "codex"
  79. }
  80. encoded, encErr := common.Marshal(oauthKey)
  81. if encErr == nil {
  82. _ = model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error
  83. model.InitChannelCache()
  84. service.ResetProxyClientCache()
  85. }
  86. ctx2, cancel2 := context.WithTimeout(c.Request.Context(), 15*time.Second)
  87. defer cancel2()
  88. statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
  89. if err != nil {
  90. common.SysError("failed to fetch codex usage after refresh: " + err.Error())
  91. c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
  92. return
  93. }
  94. }
  95. }
  96. var payload any
  97. if common.Unmarshal(body, &payload) != nil {
  98. payload = string(body)
  99. }
  100. ok := statusCode >= 200 && statusCode < 300
  101. resp := gin.H{
  102. "success": ok,
  103. "message": "",
  104. "upstream_status": statusCode,
  105. "data": payload,
  106. }
  107. if !ok {
  108. resp["message"] = fmt.Sprintf("upstream status: %d", statusCode)
  109. }
  110. c.JSON(http.StatusOK, resp)
  111. }