cache.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. TokenCacheSeconds = common.SyncFrequency
  16. UserId2GroupCacheSeconds = common.SyncFrequency
  17. UserId2QuotaCacheSeconds = common.SyncFrequency
  18. UserId2StatusCacheSeconds = common.SyncFrequency
  19. )
  20. // 仅用于定时同步缓存
  21. var token2UserId = make(map[string]int)
  22. var token2UserIdLock sync.RWMutex
  23. func cacheSetToken(token *Token) error {
  24. if !common.RedisEnabled {
  25. return token.SelectUpdate()
  26. }
  27. jsonBytes, err := json.Marshal(token)
  28. if err != nil {
  29. return err
  30. }
  31. err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
  32. if err != nil {
  33. common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
  34. return err
  35. }
  36. token2UserIdLock.Lock()
  37. defer token2UserIdLock.Unlock()
  38. token2UserId[token.Key] = token.UserId
  39. return nil
  40. }
  41. // CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
  42. func CacheGetTokenByKey(key string) (*Token, error) {
  43. if !common.RedisEnabled {
  44. return GetTokenByKey(key)
  45. }
  46. var token *Token
  47. tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  48. if err != nil {
  49. // 如果缓存中不存在,则从数据库中获取
  50. token, err = GetTokenByKey(key)
  51. if err != nil {
  52. return nil, err
  53. }
  54. err = cacheSetToken(token)
  55. return token, nil
  56. }
  57. // 如果缓存中存在,则续期时间
  58. err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
  59. err = json.Unmarshal([]byte(tokenObjectString), &token)
  60. return token, err
  61. }
  62. func SyncTokenCache(frequency int) {
  63. for {
  64. time.Sleep(time.Duration(frequency) * time.Second)
  65. common.SysLog("syncing tokens from database")
  66. token2UserIdLock.Lock()
  67. // 从token2UserId中获取所有的key
  68. var copyToken2UserId = make(map[string]int)
  69. for s, i := range token2UserId {
  70. copyToken2UserId[s] = i
  71. }
  72. token2UserId = make(map[string]int)
  73. token2UserIdLock.Unlock()
  74. for key := range copyToken2UserId {
  75. token, err := GetTokenByKey(key)
  76. if err != nil {
  77. // 如果数据库中不存在,则删除缓存
  78. common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
  79. //delete redis
  80. err := common.RedisDel(fmt.Sprintf("token:%s", key))
  81. if err != nil {
  82. common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
  83. }
  84. } else {
  85. // 如果数据库中存在,先检查redis
  86. _, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  87. if err != nil {
  88. // 如果redis中不存在,则跳过
  89. continue
  90. }
  91. err = cacheSetToken(token)
  92. if err != nil {
  93. common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
  94. }
  95. }
  96. }
  97. }
  98. }
  99. func CacheGetUserGroup(id int) (group string, err error) {
  100. if !common.RedisEnabled {
  101. return GetUserGroup(id)
  102. }
  103. group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  104. if err != nil {
  105. group, err = GetUserGroup(id)
  106. if err != nil {
  107. return "", err
  108. }
  109. err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  110. if err != nil {
  111. common.SysError("Redis set user group error: " + err.Error())
  112. }
  113. }
  114. return group, err
  115. }
  116. func CacheGetUsername(id int) (username string, err error) {
  117. if !common.RedisEnabled {
  118. return GetUsernameById(id)
  119. }
  120. username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  121. if err != nil {
  122. username, err = GetUsernameById(id)
  123. if err != nil {
  124. return "", err
  125. }
  126. err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  127. if err != nil {
  128. common.SysError("Redis set user group error: " + err.Error())
  129. }
  130. }
  131. return username, err
  132. }
  133. func CacheGetUserQuota(id int) (quota int, err error) {
  134. if !common.RedisEnabled {
  135. return GetUserQuota(id)
  136. }
  137. quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  138. if err != nil {
  139. quota, err = GetUserQuota(id)
  140. if err != nil {
  141. return 0, err
  142. }
  143. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  144. if err != nil {
  145. common.SysError("Redis set user quota error: " + err.Error())
  146. }
  147. return quota, err
  148. }
  149. quota, err = strconv.Atoi(quotaString)
  150. return quota, err
  151. }
  152. func CacheUpdateUserQuota(id int) error {
  153. if !common.RedisEnabled {
  154. return nil
  155. }
  156. quota, err := GetUserQuota(id)
  157. if err != nil {
  158. return err
  159. }
  160. return CacheSetUserQuota(id, quota)
  161. }
  162. func CacheSetUserQuota(id int, quota int) error {
  163. err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  164. return err
  165. }
  166. func CacheDecreaseUserQuota(id int, quota int) error {
  167. if !common.RedisEnabled {
  168. return nil
  169. }
  170. err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  171. return err
  172. }
  173. func CacheIsUserEnabled(userId int) (bool, error) {
  174. if !common.RedisEnabled {
  175. return IsUserEnabled(userId)
  176. }
  177. enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  178. if err == nil {
  179. return enabled == "1", nil
  180. }
  181. userEnabled, err := IsUserEnabled(userId)
  182. if err != nil {
  183. return false, err
  184. }
  185. enabled = "0"
  186. if userEnabled {
  187. enabled = "1"
  188. }
  189. err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
  190. if err != nil {
  191. common.SysError("Redis set user enabled error: " + err.Error())
  192. }
  193. return userEnabled, err
  194. }
  195. var group2model2channels map[string]map[string][]*Channel
  196. var channelsIDM map[int]*Channel
  197. var channelSyncLock sync.RWMutex
  198. func InitChannelCache() {
  199. newChannelId2channel := make(map[int]*Channel)
  200. var channels []*Channel
  201. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  202. for _, channel := range channels {
  203. newChannelId2channel[channel.Id] = channel
  204. }
  205. var abilities []*Ability
  206. DB.Find(&abilities)
  207. groups := make(map[string]bool)
  208. for _, ability := range abilities {
  209. groups[ability.Group] = true
  210. }
  211. newGroup2model2channels := make(map[string]map[string][]*Channel)
  212. newChannelsIDM := make(map[int]*Channel)
  213. for group := range groups {
  214. newGroup2model2channels[group] = make(map[string][]*Channel)
  215. }
  216. for _, channel := range channels {
  217. newChannelsIDM[channel.Id] = channel
  218. groups := strings.Split(channel.Group, ",")
  219. for _, group := range groups {
  220. models := strings.Split(channel.Models, ",")
  221. for _, model := range models {
  222. if _, ok := newGroup2model2channels[group][model]; !ok {
  223. newGroup2model2channels[group][model] = make([]*Channel, 0)
  224. }
  225. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  226. }
  227. }
  228. }
  229. // sort by priority
  230. for group, model2channels := range newGroup2model2channels {
  231. for model, channels := range model2channels {
  232. sort.Slice(channels, func(i, j int) bool {
  233. return channels[i].GetPriority() > channels[j].GetPriority()
  234. })
  235. newGroup2model2channels[group][model] = channels
  236. }
  237. }
  238. channelSyncLock.Lock()
  239. group2model2channels = newGroup2model2channels
  240. channelsIDM = newChannelsIDM
  241. channelSyncLock.Unlock()
  242. common.SysLog("channels synced from database")
  243. }
  244. func SyncChannelCache(frequency int) {
  245. for {
  246. time.Sleep(time.Duration(frequency) * time.Second)
  247. common.SysLog("syncing channels from database")
  248. InitChannelCache()
  249. }
  250. }
  251. func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  252. if strings.HasPrefix(model, "gpt-4-gizmo") {
  253. model = "gpt-4-gizmo-*"
  254. }
  255. // if memory cache is disabled, get channel directly from database
  256. if !common.MemoryCacheEnabled {
  257. return GetRandomSatisfiedChannel(group, model, retry)
  258. }
  259. channelSyncLock.RLock()
  260. defer channelSyncLock.RUnlock()
  261. channels := group2model2channels[group][model]
  262. if len(channels) == 0 {
  263. return nil, errors.New("channel not found")
  264. }
  265. uniquePriorities := make(map[int]bool)
  266. for _, channel := range channels {
  267. uniquePriorities[int(channel.GetPriority())] = true
  268. }
  269. var sortedUniquePriorities []int
  270. for priority := range uniquePriorities {
  271. sortedUniquePriorities = append(sortedUniquePriorities, priority)
  272. }
  273. sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
  274. if retry >= len(uniquePriorities) {
  275. retry = len(uniquePriorities) - 1
  276. }
  277. targetPriority := int64(sortedUniquePriorities[retry])
  278. // get the priority for the given retry number
  279. var targetChannels []*Channel
  280. for _, channel := range channels {
  281. if channel.GetPriority() == targetPriority {
  282. targetChannels = append(targetChannels, channel)
  283. }
  284. }
  285. // 平滑系数
  286. smoothingFactor := 10
  287. // Calculate the total weight of all channels up to endIdx
  288. totalWeight := 0
  289. for _, channel := range targetChannels {
  290. totalWeight += channel.GetWeight() + smoothingFactor
  291. }
  292. // Generate a random value in the range [0, totalWeight)
  293. randomWeight := rand.Intn(totalWeight)
  294. // Find a channel based on its weight
  295. for _, channel := range targetChannels {
  296. randomWeight -= channel.GetWeight() + smoothingFactor
  297. if randomWeight < 0 {
  298. return channel, nil
  299. }
  300. }
  301. // return null if no channel is not found
  302. return nil, errors.New("channel not found")
  303. }
  304. func CacheGetChannel(id int) (*Channel, error) {
  305. if !common.MemoryCacheEnabled {
  306. return GetChannelById(id, true)
  307. }
  308. channelSyncLock.RLock()
  309. defer channelSyncLock.RUnlock()
  310. c, ok := channelsIDM[id]
  311. if !ok {
  312. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  313. }
  314. return c, nil
  315. }