channel-billing.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. package controller
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. )
  14. type OpenAISubscriptionResponse struct {
  15. HasPaymentMethod bool `json:"has_payment_method"`
  16. HardLimitUSD float64 `json:"hard_limit_usd"`
  17. }
  18. type OpenAIUsageResponse struct {
  19. TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
  20. }
  21. func updateChannelBalance(channel *model.Channel) (float64, error) {
  22. baseURL := common.ChannelBaseURLs[channel.Type]
  23. switch channel.Type {
  24. case common.ChannelTypeAzure:
  25. return 0, errors.New("尚未实现")
  26. }
  27. url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
  28. client := &http.Client{}
  29. req, err := http.NewRequest("GET", url, nil)
  30. if err != nil {
  31. return 0, err
  32. }
  33. auth := fmt.Sprintf("Bearer %s", channel.Key)
  34. req.Header.Add("Authorization", auth)
  35. res, err := client.Do(req)
  36. if err != nil {
  37. return 0, err
  38. }
  39. body, err := io.ReadAll(res.Body)
  40. if err != nil {
  41. return 0, err
  42. }
  43. err = res.Body.Close()
  44. if err != nil {
  45. return 0, err
  46. }
  47. subscription := OpenAISubscriptionResponse{}
  48. err = json.Unmarshal(body, &subscription)
  49. if err != nil {
  50. return 0, err
  51. }
  52. now := time.Now()
  53. startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
  54. //endDate := now.Format("2006-01-02")
  55. url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, "2023-06-01")
  56. req, err = http.NewRequest("GET", url, nil)
  57. if err != nil {
  58. return 0, err
  59. }
  60. req.Header.Add("Authorization", auth)
  61. res, err = client.Do(req)
  62. if err != nil {
  63. return 0, err
  64. }
  65. body, err = io.ReadAll(res.Body)
  66. if err != nil {
  67. return 0, err
  68. }
  69. err = res.Body.Close()
  70. if err != nil {
  71. return 0, err
  72. }
  73. usage := OpenAIUsageResponse{}
  74. err = json.Unmarshal(body, &usage)
  75. if err != nil {
  76. return 0, err
  77. }
  78. balance := subscription.HardLimitUSD - usage.TotalUsage/100
  79. channel.UpdateBalance(balance)
  80. return balance, nil
  81. }
  82. func UpdateChannelBalance(c *gin.Context) {
  83. id, err := strconv.Atoi(c.Param("id"))
  84. if err != nil {
  85. c.JSON(http.StatusOK, gin.H{
  86. "success": false,
  87. "message": err.Error(),
  88. })
  89. return
  90. }
  91. channel, err := model.GetChannelById(id, true)
  92. if err != nil {
  93. c.JSON(http.StatusOK, gin.H{
  94. "success": false,
  95. "message": err.Error(),
  96. })
  97. return
  98. }
  99. balance, err := updateChannelBalance(channel)
  100. if err != nil {
  101. c.JSON(http.StatusOK, gin.H{
  102. "success": false,
  103. "message": err.Error(),
  104. })
  105. return
  106. }
  107. c.JSON(http.StatusOK, gin.H{
  108. "success": true,
  109. "message": "",
  110. "balance": balance,
  111. })
  112. return
  113. }
  114. func updateAllChannelsBalance() error {
  115. channels, err := model.GetAllChannels(0, 0, true)
  116. if err != nil {
  117. return err
  118. }
  119. for _, channel := range channels {
  120. if channel.Status != common.ChannelStatusEnabled {
  121. continue
  122. }
  123. balance, err := updateChannelBalance(channel)
  124. if err != nil {
  125. continue
  126. } else {
  127. // err is nil & balance <= 0 means quota is used up
  128. if balance <= 0 {
  129. disableChannel(channel.Id, channel.Name, "余额不足")
  130. }
  131. }
  132. }
  133. return nil
  134. }
  135. func UpdateAllChannelsBalance(c *gin.Context) {
  136. // TODO: make it async
  137. err := updateAllChannelsBalance()
  138. if err != nil {
  139. c.JSON(http.StatusOK, gin.H{
  140. "success": false,
  141. "message": err.Error(),
  142. })
  143. return
  144. }
  145. c.JSON(http.StatusOK, gin.H{
  146. "success": true,
  147. "message": "",
  148. })
  149. return
  150. }