cache.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. TokenCacheSeconds = common.SyncFrequency
  16. UserId2GroupCacheSeconds = common.SyncFrequency
  17. UserId2QuotaCacheSeconds = common.SyncFrequency
  18. UserId2StatusCacheSeconds = common.SyncFrequency
  19. )
  20. // 仅用于定时同步缓存
  21. var token2UserId = make(map[string]int)
  22. var token2UserIdLock sync.RWMutex
  23. func cacheSetToken(token *Token) error {
  24. if !common.RedisEnabled {
  25. return token.SelectUpdate()
  26. }
  27. jsonBytes, err := json.Marshal(token)
  28. if err != nil {
  29. return err
  30. }
  31. err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
  32. if err != nil {
  33. common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
  34. return err
  35. }
  36. token2UserIdLock.Lock()
  37. defer token2UserIdLock.Unlock()
  38. token2UserId[token.Key] = token.UserId
  39. return nil
  40. }
  41. // CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
  42. func CacheGetTokenByKey(key string) (*Token, error) {
  43. if !common.RedisEnabled {
  44. return GetTokenByKey(key)
  45. }
  46. var token *Token
  47. tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  48. if err != nil {
  49. // 如果缓存中不存在,则从数据库中获取
  50. token, err = GetTokenByKey(key)
  51. if err != nil {
  52. return nil, err
  53. }
  54. err = cacheSetToken(token)
  55. return token, nil
  56. }
  57. // 如果缓存中存在,则续期时间
  58. err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
  59. err = json.Unmarshal([]byte(tokenObjectString), &token)
  60. return token, err
  61. }
  62. func SyncTokenCache(frequency int) {
  63. for {
  64. time.Sleep(time.Duration(frequency) * time.Second)
  65. common.SysLog("syncing tokens from database")
  66. token2UserIdLock.Lock()
  67. // 从token2UserId中获取所有的key
  68. var copyToken2UserId = make(map[string]int)
  69. for s, i := range token2UserId {
  70. copyToken2UserId[s] = i
  71. }
  72. token2UserId = make(map[string]int)
  73. token2UserIdLock.Unlock()
  74. for key := range copyToken2UserId {
  75. token, err := GetTokenByKey(key)
  76. if err != nil {
  77. // 如果数据库中不存在,则删除缓存
  78. common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
  79. //delete redis
  80. err := common.RedisDel(fmt.Sprintf("token:%s", key))
  81. if err != nil {
  82. common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
  83. }
  84. } else {
  85. // 如果数据库中存在,则更新缓存
  86. err := cacheSetToken(token)
  87. if err != nil {
  88. common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
  89. }
  90. }
  91. }
  92. }
  93. }
  94. func CacheGetUserGroup(id int) (group string, err error) {
  95. if !common.RedisEnabled {
  96. return GetUserGroup(id)
  97. }
  98. group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  99. if err != nil {
  100. group, err = GetUserGroup(id)
  101. if err != nil {
  102. return "", err
  103. }
  104. err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  105. if err != nil {
  106. common.SysError("Redis set user group error: " + err.Error())
  107. }
  108. }
  109. return group, err
  110. }
  111. func CacheGetUsername(id int) (username string, err error) {
  112. if !common.RedisEnabled {
  113. return GetUsernameById(id)
  114. }
  115. username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  116. if err != nil {
  117. username, err = GetUsernameById(id)
  118. if err != nil {
  119. return "", err
  120. }
  121. err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  122. if err != nil {
  123. common.SysError("Redis set user group error: " + err.Error())
  124. }
  125. }
  126. return username, err
  127. }
  128. func CacheGetUserQuota(id int) (quota int, err error) {
  129. if !common.RedisEnabled {
  130. return GetUserQuota(id)
  131. }
  132. quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  133. if err != nil {
  134. quota, err = GetUserQuota(id)
  135. if err != nil {
  136. return 0, err
  137. }
  138. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  139. if err != nil {
  140. common.SysError("Redis set user quota error: " + err.Error())
  141. }
  142. return quota, err
  143. }
  144. quota, err = strconv.Atoi(quotaString)
  145. return quota, err
  146. }
  147. func CacheUpdateUserQuota(id int) error {
  148. if !common.RedisEnabled {
  149. return nil
  150. }
  151. quota, err := GetUserQuota(id)
  152. if err != nil {
  153. return err
  154. }
  155. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  156. return err
  157. }
  158. func CacheDecreaseUserQuota(id int, quota int) error {
  159. if !common.RedisEnabled {
  160. return nil
  161. }
  162. err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  163. return err
  164. }
  165. func CacheIsUserEnabled(userId int) (bool, error) {
  166. if !common.RedisEnabled {
  167. return IsUserEnabled(userId)
  168. }
  169. enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  170. if err == nil {
  171. return enabled == "1", nil
  172. }
  173. userEnabled, err := IsUserEnabled(userId)
  174. if err != nil {
  175. return false, err
  176. }
  177. enabled = "0"
  178. if userEnabled {
  179. enabled = "1"
  180. }
  181. err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
  182. if err != nil {
  183. common.SysError("Redis set user enabled error: " + err.Error())
  184. }
  185. return userEnabled, err
  186. }
  187. var group2model2channels map[string]map[string][]*Channel
  188. var channelsIDM map[int]*Channel
  189. var channelSyncLock sync.RWMutex
  190. func InitChannelCache() {
  191. newChannelId2channel := make(map[int]*Channel)
  192. var channels []*Channel
  193. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  194. for _, channel := range channels {
  195. newChannelId2channel[channel.Id] = channel
  196. }
  197. var abilities []*Ability
  198. DB.Find(&abilities)
  199. groups := make(map[string]bool)
  200. for _, ability := range abilities {
  201. groups[ability.Group] = true
  202. }
  203. newGroup2model2channels := make(map[string]map[string][]*Channel)
  204. newChannelsIDM := make(map[int]*Channel)
  205. for group := range groups {
  206. newGroup2model2channels[group] = make(map[string][]*Channel)
  207. }
  208. for _, channel := range channels {
  209. newChannelsIDM[channel.Id] = channel
  210. groups := strings.Split(channel.Group, ",")
  211. for _, group := range groups {
  212. models := strings.Split(channel.Models, ",")
  213. for _, model := range models {
  214. if _, ok := newGroup2model2channels[group][model]; !ok {
  215. newGroup2model2channels[group][model] = make([]*Channel, 0)
  216. }
  217. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  218. }
  219. }
  220. }
  221. // sort by priority
  222. for group, model2channels := range newGroup2model2channels {
  223. for model, channels := range model2channels {
  224. sort.Slice(channels, func(i, j int) bool {
  225. return channels[i].GetPriority() > channels[j].GetPriority()
  226. })
  227. newGroup2model2channels[group][model] = channels
  228. }
  229. }
  230. channelSyncLock.Lock()
  231. group2model2channels = newGroup2model2channels
  232. channelsIDM = newChannelsIDM
  233. channelSyncLock.Unlock()
  234. common.SysLog("channels synced from database")
  235. }
  236. func SyncChannelCache(frequency int) {
  237. for {
  238. time.Sleep(time.Duration(frequency) * time.Second)
  239. common.SysLog("syncing channels from database")
  240. InitChannelCache()
  241. }
  242. }
  243. func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
  244. if strings.HasPrefix(model, "gpt-4-gizmo") {
  245. model = "gpt-4-gizmo-*"
  246. }
  247. // if memory cache is disabled, get channel directly from database
  248. if !common.MemoryCacheEnabled {
  249. return GetRandomSatisfiedChannel(group, model)
  250. }
  251. channelSyncLock.RLock()
  252. defer channelSyncLock.RUnlock()
  253. channels := group2model2channels[group][model]
  254. if len(channels) == 0 {
  255. return nil, errors.New("channel not found")
  256. }
  257. endIdx := len(channels)
  258. // choose by priority
  259. firstChannel := channels[0]
  260. if firstChannel.GetPriority() > 0 {
  261. for i := range channels {
  262. if channels[i].GetPriority() != firstChannel.GetPriority() {
  263. endIdx = i
  264. break
  265. }
  266. }
  267. }
  268. // Calculate the total weight of all channels up to endIdx
  269. totalWeight := 0
  270. for _, channel := range channels[:endIdx] {
  271. totalWeight += channel.GetWeight()
  272. }
  273. if totalWeight == 0 {
  274. // If all weights are 0, select a channel randomly
  275. return channels[rand.Intn(endIdx)], nil
  276. }
  277. // Generate a random value in the range [0, totalWeight)
  278. randomWeight := rand.Intn(totalWeight)
  279. // Find a channel based on its weight
  280. for _, channel := range channels[:endIdx] {
  281. randomWeight -= channel.GetWeight()
  282. if randomWeight <= 0 {
  283. return channel, nil
  284. }
  285. }
  286. // return null if no channel is not found
  287. return nil, errors.New("channel not found")
  288. }
  289. func CacheGetChannel(id int) (*Channel, error) {
  290. if !common.MemoryCacheEnabled {
  291. return GetChannelById(id, true)
  292. }
  293. channelSyncLock.RLock()
  294. defer channelSyncLock.RUnlock()
  295. c, ok := channelsIDM[id]
  296. if !ok {
  297. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  298. }
  299. return c, nil
  300. }