cache.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. TokenCacheSeconds = common.SyncFrequency
  16. UserId2GroupCacheSeconds = common.SyncFrequency
  17. UserId2QuotaCacheSeconds = common.SyncFrequency
  18. UserId2StatusCacheSeconds = common.SyncFrequency
  19. )
  20. // 仅用于定时同步缓存
  21. var token2UserId = make(map[string]int)
  22. var token2UserIdLock sync.RWMutex
  23. func cacheSetToken(token *Token) error {
  24. if !common.RedisEnabled {
  25. return token.SelectUpdate()
  26. }
  27. jsonBytes, err := json.Marshal(token)
  28. if err != nil {
  29. return err
  30. }
  31. err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
  32. if err != nil {
  33. common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
  34. return err
  35. }
  36. token2UserIdLock.Lock()
  37. defer token2UserIdLock.Unlock()
  38. token2UserId[token.Key] = token.UserId
  39. return nil
  40. }
  41. // CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
  42. func CacheGetTokenByKey(key string) (*Token, error) {
  43. if !common.RedisEnabled {
  44. return GetTokenByKey(key)
  45. }
  46. var token *Token
  47. tokenObjectString, err := common.RedisGetEx(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
  48. if err != nil {
  49. // 如果缓存中不存在,则从数据库中获取
  50. token, err = GetTokenByKey(key)
  51. if err != nil {
  52. return nil, err
  53. }
  54. err = cacheSetToken(token)
  55. return token, nil
  56. }
  57. err = json.Unmarshal([]byte(tokenObjectString), &token)
  58. return token, err
  59. }
  60. func SyncTokenCache(frequency int) {
  61. for {
  62. time.Sleep(time.Duration(frequency) * time.Second)
  63. common.SysLog("syncing tokens from database")
  64. token2UserIdLock.Lock()
  65. // 从token2UserId中获取所有的key
  66. var copyToken2UserId = make(map[string]int)
  67. for s, i := range token2UserId {
  68. copyToken2UserId[s] = i
  69. }
  70. token2UserId = make(map[string]int)
  71. token2UserIdLock.Unlock()
  72. for key := range copyToken2UserId {
  73. token, err := GetTokenByKey(key)
  74. if err != nil {
  75. // 如果数据库中不存在,则删除缓存
  76. common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
  77. //delete redis
  78. err := common.RedisDel(fmt.Sprintf("token:%s", key))
  79. if err != nil {
  80. common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
  81. }
  82. } else {
  83. // 如果数据库中存在,则更新缓存
  84. err := cacheSetToken(token)
  85. if err != nil {
  86. common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
  87. }
  88. }
  89. }
  90. }
  91. }
  92. func CacheGetUserGroup(id int) (group string, err error) {
  93. if !common.RedisEnabled {
  94. return GetUserGroup(id)
  95. }
  96. group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  97. if err != nil {
  98. group, err = GetUserGroup(id)
  99. if err != nil {
  100. return "", err
  101. }
  102. err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  103. if err != nil {
  104. common.SysError("Redis set user group error: " + err.Error())
  105. }
  106. }
  107. return group, err
  108. }
  109. func CacheGetUsername(id int) (username string, err error) {
  110. if !common.RedisEnabled {
  111. return GetUsernameById(id)
  112. }
  113. username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  114. if err != nil {
  115. username, err = GetUsernameById(id)
  116. if err != nil {
  117. return "", err
  118. }
  119. err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  120. if err != nil {
  121. common.SysError("Redis set user group error: " + err.Error())
  122. }
  123. }
  124. return username, err
  125. }
  126. func CacheGetUserQuota(id int) (quota int, err error) {
  127. if !common.RedisEnabled {
  128. return GetUserQuota(id)
  129. }
  130. quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  131. if err != nil {
  132. quota, err = GetUserQuota(id)
  133. if err != nil {
  134. return 0, err
  135. }
  136. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  137. if err != nil {
  138. common.SysError("Redis set user quota error: " + err.Error())
  139. }
  140. return quota, err
  141. }
  142. quota, err = strconv.Atoi(quotaString)
  143. return quota, err
  144. }
  145. func CacheUpdateUserQuota(id int) error {
  146. if !common.RedisEnabled {
  147. return nil
  148. }
  149. quota, err := GetUserQuota(id)
  150. if err != nil {
  151. return err
  152. }
  153. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  154. return err
  155. }
  156. func CacheDecreaseUserQuota(id int, quota int) error {
  157. if !common.RedisEnabled {
  158. return nil
  159. }
  160. err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  161. return err
  162. }
  163. func CacheIsUserEnabled(userId int) (bool, error) {
  164. if !common.RedisEnabled {
  165. return IsUserEnabled(userId)
  166. }
  167. enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  168. if err == nil {
  169. return enabled == "1", nil
  170. }
  171. userEnabled, err := IsUserEnabled(userId)
  172. if err != nil {
  173. return false, err
  174. }
  175. enabled = "0"
  176. if userEnabled {
  177. enabled = "1"
  178. }
  179. err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
  180. if err != nil {
  181. common.SysError("Redis set user enabled error: " + err.Error())
  182. }
  183. return userEnabled, err
  184. }
  185. var group2model2channels map[string]map[string][]*Channel
  186. var channelsIDM map[int]*Channel
  187. var channelSyncLock sync.RWMutex
  188. func InitChannelCache() {
  189. newChannelId2channel := make(map[int]*Channel)
  190. var channels []*Channel
  191. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  192. for _, channel := range channels {
  193. newChannelId2channel[channel.Id] = channel
  194. }
  195. var abilities []*Ability
  196. DB.Find(&abilities)
  197. groups := make(map[string]bool)
  198. for _, ability := range abilities {
  199. groups[ability.Group] = true
  200. }
  201. newGroup2model2channels := make(map[string]map[string][]*Channel)
  202. newChannelsIDM := make(map[int]*Channel)
  203. for group := range groups {
  204. newGroup2model2channels[group] = make(map[string][]*Channel)
  205. }
  206. for _, channel := range channels {
  207. newChannelsIDM[channel.Id] = channel
  208. groups := strings.Split(channel.Group, ",")
  209. for _, group := range groups {
  210. models := strings.Split(channel.Models, ",")
  211. for _, model := range models {
  212. if _, ok := newGroup2model2channels[group][model]; !ok {
  213. newGroup2model2channels[group][model] = make([]*Channel, 0)
  214. }
  215. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  216. }
  217. }
  218. }
  219. // sort by priority
  220. for group, model2channels := range newGroup2model2channels {
  221. for model, channels := range model2channels {
  222. sort.Slice(channels, func(i, j int) bool {
  223. return channels[i].GetPriority() > channels[j].GetPriority()
  224. })
  225. newGroup2model2channels[group][model] = channels
  226. }
  227. }
  228. channelSyncLock.Lock()
  229. group2model2channels = newGroup2model2channels
  230. channelsIDM = newChannelsIDM
  231. channelSyncLock.Unlock()
  232. common.SysLog("channels synced from database")
  233. }
  234. func SyncChannelCache(frequency int) {
  235. for {
  236. time.Sleep(time.Duration(frequency) * time.Second)
  237. common.SysLog("syncing channels from database")
  238. InitChannelCache()
  239. }
  240. }
  241. func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
  242. if strings.HasPrefix(model, "gpt-4-gizmo") {
  243. model = "gpt-4-gizmo-*"
  244. }
  245. // if memory cache is disabled, get channel directly from database
  246. if !common.MemoryCacheEnabled {
  247. return GetRandomSatisfiedChannel(group, model)
  248. }
  249. channelSyncLock.RLock()
  250. defer channelSyncLock.RUnlock()
  251. channels := group2model2channels[group][model]
  252. if len(channels) == 0 {
  253. return nil, errors.New("channel not found")
  254. }
  255. endIdx := len(channels)
  256. // choose by priority
  257. firstChannel := channels[0]
  258. if firstChannel.GetPriority() > 0 {
  259. for i := range channels {
  260. if channels[i].GetPriority() != firstChannel.GetPriority() {
  261. endIdx = i
  262. break
  263. }
  264. }
  265. }
  266. // Calculate the total weight of all channels up to endIdx
  267. totalWeight := 0
  268. for _, channel := range channels[:endIdx] {
  269. totalWeight += channel.GetWeight()
  270. }
  271. if totalWeight == 0 {
  272. // If all weights are 0, select a channel randomly
  273. return channels[rand.Intn(endIdx)], nil
  274. }
  275. // Generate a random value in the range [0, totalWeight)
  276. randomWeight := rand.Intn(totalWeight)
  277. // Find a channel based on its weight
  278. for _, channel := range channels[:endIdx] {
  279. randomWeight -= channel.GetWeight()
  280. if randomWeight <= 0 {
  281. return channel, nil
  282. }
  283. }
  284. // return null if no channel is not found
  285. return nil, errors.New("channel not found")
  286. }
  287. func CacheGetChannel(id int) (*Channel, error) {
  288. if !common.MemoryCacheEnabled {
  289. return GetChannelById(id, true)
  290. }
  291. channelSyncLock.RLock()
  292. defer channelSyncLock.RUnlock()
  293. c, ok := channelsIDM[id]
  294. if !ok {
  295. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  296. }
  297. return c, nil
  298. }