cache.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/rand"
  6. "one-api/common"
  7. "sort"
  8. "strings"
  9. "sync"
  10. "time"
  11. )
  12. //func CacheGetUserGroup(id int) (group string, err error) {
  13. // if !common.RedisEnabled {
  14. // return GetUserGroup(id)
  15. // }
  16. // group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  17. // if err != nil {
  18. // group, err = GetUserGroup(id)
  19. // if err != nil {
  20. // return "", err
  21. // }
  22. // err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
  23. // if err != nil {
  24. // common.SysError("Redis set user group error: " + err.Error())
  25. // }
  26. // }
  27. // return group, err
  28. //}
  29. //
  30. //func CacheGetUsername(id int) (username string, err error) {
  31. // if !common.RedisEnabled {
  32. // return GetUsernameById(id)
  33. // }
  34. // username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  35. // if err != nil {
  36. // username, err = GetUsernameById(id)
  37. // if err != nil {
  38. // return "", err
  39. // }
  40. // err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
  41. // if err != nil {
  42. // common.SysError("Redis set user group error: " + err.Error())
  43. // }
  44. // }
  45. // return username, err
  46. //}
  47. //
  48. //func CacheGetUserQuota(id int) (quota int, err error) {
  49. // if !common.RedisEnabled {
  50. // return GetUserQuota(id)
  51. // }
  52. // quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  53. // if err != nil {
  54. // quota, err = GetUserQuota(id)
  55. // if err != nil {
  56. // return 0, err
  57. // }
  58. // return quota, nil
  59. // }
  60. // quota, err = strconv.Atoi(quotaString)
  61. // return quota, nil
  62. //}
  63. //
  64. //func CacheUpdateUserQuota(id int) error {
  65. // if !common.RedisEnabled {
  66. // return nil
  67. // }
  68. // quota, err := GetUserQuota(id)
  69. // if err != nil {
  70. // return err
  71. // }
  72. // return cacheSetUserQuota(id, quota)
  73. //}
  74. //
  75. //func cacheSetUserQuota(id int, quota int) error {
  76. // err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
  77. // return err
  78. //}
  79. //
  80. //func CacheDecreaseUserQuota(id int, quota int) error {
  81. // if !common.RedisEnabled {
  82. // return nil
  83. // }
  84. // err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  85. // return err
  86. //}
  87. //
  88. //func CacheIsUserEnabled(userId int) (bool, error) {
  89. // if !common.RedisEnabled {
  90. // return IsUserEnabled(userId)
  91. // }
  92. // enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  93. // if err == nil {
  94. // return enabled == "1", nil
  95. // }
  96. //
  97. // userEnabled, err := IsUserEnabled(userId)
  98. // if err != nil {
  99. // return false, err
  100. // }
  101. // enabled = "0"
  102. // if userEnabled {
  103. // enabled = "1"
  104. // }
  105. // err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
  106. // if err != nil {
  107. // common.SysError("Redis set user enabled error: " + err.Error())
  108. // }
  109. // return userEnabled, err
  110. //}
  111. var group2model2channels map[string]map[string][]*Channel
  112. var channelsIDM map[int]*Channel
  113. var channelSyncLock sync.RWMutex
  114. func InitChannelCache() {
  115. newChannelId2channel := make(map[int]*Channel)
  116. var channels []*Channel
  117. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  118. for _, channel := range channels {
  119. newChannelId2channel[channel.Id] = channel
  120. }
  121. var abilities []*Ability
  122. DB.Find(&abilities)
  123. groups := make(map[string]bool)
  124. for _, ability := range abilities {
  125. groups[ability.Group] = true
  126. }
  127. newGroup2model2channels := make(map[string]map[string][]*Channel)
  128. newChannelsIDM := make(map[int]*Channel)
  129. for group := range groups {
  130. newGroup2model2channels[group] = make(map[string][]*Channel)
  131. }
  132. for _, channel := range channels {
  133. newChannelsIDM[channel.Id] = channel
  134. groups := strings.Split(channel.Group, ",")
  135. for _, group := range groups {
  136. models := strings.Split(channel.Models, ",")
  137. for _, model := range models {
  138. if _, ok := newGroup2model2channels[group][model]; !ok {
  139. newGroup2model2channels[group][model] = make([]*Channel, 0)
  140. }
  141. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  142. }
  143. }
  144. }
  145. // sort by priority
  146. for group, model2channels := range newGroup2model2channels {
  147. for model, channels := range model2channels {
  148. sort.Slice(channels, func(i, j int) bool {
  149. return channels[i].GetPriority() > channels[j].GetPriority()
  150. })
  151. newGroup2model2channels[group][model] = channels
  152. }
  153. }
  154. channelSyncLock.Lock()
  155. group2model2channels = newGroup2model2channels
  156. channelsIDM = newChannelsIDM
  157. channelSyncLock.Unlock()
  158. common.SysLog("channels synced from database")
  159. }
  160. func SyncChannelCache(frequency int) {
  161. for {
  162. time.Sleep(time.Duration(frequency) * time.Second)
  163. common.SysLog("syncing channels from database")
  164. InitChannelCache()
  165. }
  166. }
  167. func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  168. if strings.HasPrefix(model, "gpt-4-gizmo") {
  169. model = "gpt-4-gizmo-*"
  170. }
  171. if strings.HasPrefix(model, "gpt-4o-gizmo") {
  172. model = "gpt-4o-gizmo-*"
  173. }
  174. // if memory cache is disabled, get channel directly from database
  175. if (!common.MemoryCacheEnabled) {
  176. return GetRandomSatisfiedChannel(group, model, retry)
  177. }
  178. channelSyncLock.RLock()
  179. defer channelSyncLock.RUnlock()
  180. channels := group2model2channels[group][model]
  181. if len(channels) == 0 {
  182. return nil, errors.New("channel not found")
  183. }
  184. uniquePriorities := make(map[int]bool)
  185. for _, channel := range channels {
  186. uniquePriorities[int(channel.GetPriority())] = true
  187. }
  188. var sortedUniquePriorities []int
  189. for priority := range uniquePriorities {
  190. sortedUniquePriorities = append(sortedUniquePriorities, priority)
  191. }
  192. sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
  193. if retry >= len(uniquePriorities) {
  194. retry = len(uniquePriorities) - 1
  195. }
  196. targetPriority := int64(sortedUniquePriorities[retry])
  197. // get the priority for the given retry number
  198. var targetChannels []*Channel
  199. for _, channel := range channels {
  200. if channel.GetPriority() == targetPriority {
  201. targetChannels = append(targetChannels, channel)
  202. }
  203. }
  204. // 平滑系数
  205. smoothingFactor := 10
  206. // Calculate the total weight of all channels up to endIdx
  207. totalWeight := 0
  208. for _, channel := range targetChannels {
  209. totalWeight += channel.GetWeight() + smoothingFactor
  210. }
  211. // Generate a random value in the range [0, totalWeight)
  212. randomWeight := rand.Intn(totalWeight)
  213. // Find a channel based on its weight
  214. for _, channel := range targetChannels {
  215. randomWeight -= channel.GetWeight() + smoothingFactor
  216. if randomWeight < 0 {
  217. return channel, nil
  218. }
  219. }
  220. // return null if no channel is not found
  221. return nil, errors.New("channel not found")
  222. }
  223. func CacheGetChannel(id int) (*Channel, error) {
  224. if !common.MemoryCacheEnabled {
  225. return GetChannelById(id, true)
  226. }
  227. channelSyncLock.RLock()
  228. defer channelSyncLock.RUnlock()
  229. c, ok := channelsIDM[id]
  230. if !ok {
  231. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  232. }
  233. return c, nil
  234. }
  235. func CacheUpdateChannelStatus(id int, status int) {
  236. if !common.MemoryCacheEnabled {
  237. return
  238. }
  239. channelSyncLock.Lock()
  240. defer channelSyncLock.Unlock()
  241. if channel, ok := channelsIDM[id]; ok {
  242. channel.Status = status
  243. }
  244. }