cache.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "math/rand"
  7. "one-api/common"
  8. "one-api/setting"
  9. "sort"
  10. "strings"
  11. "sync"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. )
  15. var group2model2channels map[string]map[string][]*Channel
  16. var channelsIDM map[int]*Channel
  17. var channelSyncLock sync.RWMutex
  18. func InitChannelCache() {
  19. if !common.MemoryCacheEnabled {
  20. return
  21. }
  22. newChannelId2channel := make(map[int]*Channel)
  23. var channels []*Channel
  24. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  25. for _, channel := range channels {
  26. newChannelId2channel[channel.Id] = channel
  27. }
  28. var abilities []*Ability
  29. DB.Find(&abilities)
  30. groups := make(map[string]bool)
  31. for _, ability := range abilities {
  32. groups[ability.Group] = true
  33. }
  34. newGroup2model2channels := make(map[string]map[string][]*Channel)
  35. newChannelsIDM := make(map[int]*Channel)
  36. for group := range groups {
  37. newGroup2model2channels[group] = make(map[string][]*Channel)
  38. }
  39. for _, channel := range channels {
  40. newChannelsIDM[channel.Id] = channel
  41. groups := strings.Split(channel.Group, ",")
  42. for _, group := range groups {
  43. models := strings.Split(channel.Models, ",")
  44. for _, model := range models {
  45. if _, ok := newGroup2model2channels[group][model]; !ok {
  46. newGroup2model2channels[group][model] = make([]*Channel, 0)
  47. }
  48. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  49. }
  50. }
  51. }
  52. // sort by priority
  53. for group, model2channels := range newGroup2model2channels {
  54. for model, channels := range model2channels {
  55. sort.Slice(channels, func(i, j int) bool {
  56. return channels[i].GetPriority() > channels[j].GetPriority()
  57. })
  58. newGroup2model2channels[group][model] = channels
  59. }
  60. }
  61. channelSyncLock.Lock()
  62. group2model2channels = newGroup2model2channels
  63. channelsIDM = newChannelsIDM
  64. channelSyncLock.Unlock()
  65. common.SysLog("channels synced from database")
  66. }
  67. func SyncChannelCache(frequency int) {
  68. for {
  69. time.Sleep(time.Duration(frequency) * time.Second)
  70. common.SysLog("syncing channels from database")
  71. InitChannelCache()
  72. }
  73. }
  74. func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
  75. var channel *Channel
  76. var err error
  77. selectGroup := group
  78. if group == "auto" {
  79. if len(setting.AutoGroups) == 0 {
  80. return nil, selectGroup, errors.New("auto groups is not enabled")
  81. }
  82. for _, autoGroup := range setting.AutoGroups {
  83. log.Printf("autoGroup: %s", autoGroup)
  84. channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
  85. if channel == nil {
  86. continue
  87. } else {
  88. c.Set("auto_group", autoGroup)
  89. selectGroup = autoGroup
  90. log.Printf("selectGroup: %s", selectGroup)
  91. break
  92. }
  93. }
  94. } else {
  95. channel, err = getRandomSatisfiedChannel(group, model, retry)
  96. if err != nil {
  97. return nil, group, err
  98. }
  99. }
  100. if channel == nil {
  101. return nil, group, errors.New("channel not found")
  102. }
  103. return channel, selectGroup, nil
  104. }
  105. func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  106. if strings.HasPrefix(model, "gpt-4-gizmo") {
  107. model = "gpt-4-gizmo-*"
  108. }
  109. if strings.HasPrefix(model, "gpt-4o-gizmo") {
  110. model = "gpt-4o-gizmo-*"
  111. }
  112. // if memory cache is disabled, get channel directly from database
  113. if !common.MemoryCacheEnabled {
  114. return GetRandomSatisfiedChannel(group, model, retry)
  115. }
  116. channelSyncLock.RLock()
  117. channels := group2model2channels[group][model]
  118. channelSyncLock.RUnlock()
  119. if len(channels) == 0 {
  120. return nil, errors.New("channel not found")
  121. }
  122. uniquePriorities := make(map[int]bool)
  123. for _, channel := range channels {
  124. uniquePriorities[int(channel.GetPriority())] = true
  125. }
  126. var sortedUniquePriorities []int
  127. for priority := range uniquePriorities {
  128. sortedUniquePriorities = append(sortedUniquePriorities, priority)
  129. }
  130. sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
  131. if retry >= len(uniquePriorities) {
  132. retry = len(uniquePriorities) - 1
  133. }
  134. targetPriority := int64(sortedUniquePriorities[retry])
  135. // get the priority for the given retry number
  136. var targetChannels []*Channel
  137. for _, channel := range channels {
  138. if channel.GetPriority() == targetPriority {
  139. targetChannels = append(targetChannels, channel)
  140. }
  141. }
  142. // 平滑系数
  143. smoothingFactor := 10
  144. // Calculate the total weight of all channels up to endIdx
  145. totalWeight := 0
  146. for _, channel := range targetChannels {
  147. totalWeight += channel.GetWeight() + smoothingFactor
  148. }
  149. // Generate a random value in the range [0, totalWeight)
  150. randomWeight := rand.Intn(totalWeight)
  151. // Find a channel based on its weight
  152. for _, channel := range targetChannels {
  153. randomWeight -= channel.GetWeight() + smoothingFactor
  154. if randomWeight < 0 {
  155. return channel, nil
  156. }
  157. }
  158. // return null if no channel is not found
  159. return nil, errors.New("channel not found")
  160. }
  161. func CacheGetChannel(id int) (*Channel, error) {
  162. if !common.MemoryCacheEnabled {
  163. return GetChannelById(id, true)
  164. }
  165. channelSyncLock.RLock()
  166. defer channelSyncLock.RUnlock()
  167. c, ok := channelsIDM[id]
  168. if !ok {
  169. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  170. }
  171. return c, nil
  172. }
  173. func CacheUpdateChannelStatus(id int, status int) {
  174. if !common.MemoryCacheEnabled {
  175. return
  176. }
  177. channelSyncLock.Lock()
  178. defer channelSyncLock.Unlock()
  179. if channel, ok := channelsIDM[id]; ok {
  180. channel.Status = status
  181. }
  182. }