cache.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "one-api/constant"
  9. "sort"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  // token2UserId records, for every token key written to Redis by
  // cacheSetToken, the ID of the user that owns it. Within this file it is
  // consumed only by the periodic SyncTokenCache job.
  //
  // token2UserIdLock guards every read and write of token2UserId.
  var (
  	token2UserId     = make(map[string]int)
  	token2UserIdLock sync.RWMutex
  )
  17. func cacheSetToken(token *Token) error {
  18. jsonBytes, err := json.Marshal(token)
  19. if err != nil {
  20. return err
  21. }
  22. err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(constant.TokenCacheSeconds)*time.Second)
  23. if err != nil {
  24. common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
  25. return err
  26. }
  27. token2UserIdLock.Lock()
  28. defer token2UserIdLock.Unlock()
  29. token2UserId[token.Key] = token.UserId
  30. return nil
  31. }
  32. // CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
  33. func CacheGetTokenByKey(key string) (*Token, error) {
  34. if !common.RedisEnabled {
  35. return GetTokenByKey(key)
  36. }
  37. var token *Token
  38. tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  39. if err != nil {
  40. // 如果缓存中不存在,则从数据库中获取
  41. token, err = GetTokenByKey(key)
  42. if err != nil {
  43. return nil, err
  44. }
  45. err = cacheSetToken(token)
  46. return token, nil
  47. }
  48. // 如果缓存中存在,则续期时间
  49. err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(constant.TokenCacheSeconds)*time.Second)
  50. err = json.Unmarshal([]byte(tokenObjectString), &token)
  51. return token, err
  52. }
  53. func SyncTokenCache(frequency int) {
  54. for {
  55. time.Sleep(time.Duration(frequency) * time.Second)
  56. common.SysLog("syncing tokens from database")
  57. token2UserIdLock.Lock()
  58. // 从token2UserId中获取所有的key
  59. var copyToken2UserId = make(map[string]int)
  60. for s, i := range token2UserId {
  61. copyToken2UserId[s] = i
  62. }
  63. token2UserId = make(map[string]int)
  64. token2UserIdLock.Unlock()
  65. for key := range copyToken2UserId {
  66. token, err := GetTokenByKey(key)
  67. if err != nil {
  68. // 如果数据库中不存在,则删除缓存
  69. common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
  70. //delete redis
  71. err := common.RedisDel(fmt.Sprintf("token:%s", key))
  72. if err != nil {
  73. common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
  74. }
  75. } else {
  76. // 如果数据库中存在,先检查redis
  77. _, err = common.RedisGet(fmt.Sprintf("token:%s", key))
  78. if err != nil {
  79. // 如果redis中不存在,则跳过
  80. continue
  81. }
  82. err = cacheSetToken(token)
  83. if err != nil {
  84. common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
  85. }
  86. }
  87. }
  88. }
  89. }
  90. //func CacheGetUserGroup(id int) (group string, err error) {
  91. // if !common.RedisEnabled {
  92. // return GetUserGroup(id)
  93. // }
  94. // group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  95. // if err != nil {
  96. // group, err = GetUserGroup(id)
  97. // if err != nil {
  98. // return "", err
  99. // }
  100. // err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
  101. // if err != nil {
  102. // common.SysError("Redis set user group error: " + err.Error())
  103. // }
  104. // }
  105. // return group, err
  106. //}
  107. //
  108. //func CacheGetUsername(id int) (username string, err error) {
  109. // if !common.RedisEnabled {
  110. // return GetUsernameById(id)
  111. // }
  112. // username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  113. // if err != nil {
  114. // username, err = GetUsernameById(id)
  115. // if err != nil {
  116. // return "", err
  117. // }
  118. // err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
  119. // if err != nil {
  120. // common.SysError("Redis set user group error: " + err.Error())
  121. // }
  122. // }
  123. // return username, err
  124. //}
  125. //
  126. //func CacheGetUserQuota(id int) (quota int, err error) {
  127. // if !common.RedisEnabled {
  128. // return GetUserQuota(id)
  129. // }
  130. // quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  131. // if err != nil {
  132. // quota, err = GetUserQuota(id)
  133. // if err != nil {
  134. // return 0, err
  135. // }
  136. // return quota, nil
  137. // }
  138. // quota, err = strconv.Atoi(quotaString)
  139. // return quota, nil
  140. //}
  141. //
  142. //func CacheUpdateUserQuota(id int) error {
  143. // if !common.RedisEnabled {
  144. // return nil
  145. // }
  146. // quota, err := GetUserQuota(id)
  147. // if err != nil {
  148. // return err
  149. // }
  150. // return cacheSetUserQuota(id, quota)
  151. //}
  152. //
  153. //func cacheSetUserQuota(id int, quota int) error {
  154. // err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
  155. // return err
  156. //}
  157. //
  158. //func CacheDecreaseUserQuota(id int, quota int) error {
  159. // if !common.RedisEnabled {
  160. // return nil
  161. // }
  162. // err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  163. // return err
  164. //}
  165. //
  166. //func CacheIsUserEnabled(userId int) (bool, error) {
  167. // if !common.RedisEnabled {
  168. // return IsUserEnabled(userId)
  169. // }
  170. // enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  171. // if err == nil {
  172. // return enabled == "1", nil
  173. // }
  174. //
  175. // userEnabled, err := IsUserEnabled(userId)
  176. // if err != nil {
  177. // return false, err
  178. // }
  179. // enabled = "0"
  180. // if userEnabled {
  181. // enabled = "1"
  182. // }
  183. // err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
  184. // if err != nil {
  185. // common.SysError("Redis set user enabled error: " + err.Error())
  186. // }
  187. // return userEnabled, err
  188. //}
  189. var group2model2channels map[string]map[string][]*Channel
  190. var channelsIDM map[int]*Channel
  191. var channelSyncLock sync.RWMutex
  192. func InitChannelCache() {
  193. newChannelId2channel := make(map[int]*Channel)
  194. var channels []*Channel
  195. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  196. for _, channel := range channels {
  197. newChannelId2channel[channel.Id] = channel
  198. }
  199. var abilities []*Ability
  200. DB.Find(&abilities)
  201. groups := make(map[string]bool)
  202. for _, ability := range abilities {
  203. groups[ability.Group] = true
  204. }
  205. newGroup2model2channels := make(map[string]map[string][]*Channel)
  206. newChannelsIDM := make(map[int]*Channel)
  207. for group := range groups {
  208. newGroup2model2channels[group] = make(map[string][]*Channel)
  209. }
  210. for _, channel := range channels {
  211. newChannelsIDM[channel.Id] = channel
  212. groups := strings.Split(channel.Group, ",")
  213. for _, group := range groups {
  214. models := strings.Split(channel.Models, ",")
  215. for _, model := range models {
  216. if _, ok := newGroup2model2channels[group][model]; !ok {
  217. newGroup2model2channels[group][model] = make([]*Channel, 0)
  218. }
  219. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  220. }
  221. }
  222. }
  223. // sort by priority
  224. for group, model2channels := range newGroup2model2channels {
  225. for model, channels := range model2channels {
  226. sort.Slice(channels, func(i, j int) bool {
  227. return channels[i].GetPriority() > channels[j].GetPriority()
  228. })
  229. newGroup2model2channels[group][model] = channels
  230. }
  231. }
  232. channelSyncLock.Lock()
  233. group2model2channels = newGroup2model2channels
  234. channelsIDM = newChannelsIDM
  235. channelSyncLock.Unlock()
  236. common.SysLog("channels synced from database")
  237. }
  238. func SyncChannelCache(frequency int) {
  239. for {
  240. time.Sleep(time.Duration(frequency) * time.Second)
  241. common.SysLog("syncing channels from database")
  242. InitChannelCache()
  243. }
  244. }
  245. func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  246. if strings.HasPrefix(model, "gpt-4-gizmo") {
  247. model = "gpt-4-gizmo-*"
  248. }
  249. if strings.HasPrefix(model, "gpt-4o-gizmo") {
  250. model = "gpt-4o-gizmo-*"
  251. }
  252. // if memory cache is disabled, get channel directly from database
  253. if !common.MemoryCacheEnabled {
  254. return GetRandomSatisfiedChannel(group, model, retry)
  255. }
  256. channelSyncLock.RLock()
  257. defer channelSyncLock.RUnlock()
  258. channels := group2model2channels[group][model]
  259. if len(channels) == 0 {
  260. return nil, errors.New("channel not found")
  261. }
  262. uniquePriorities := make(map[int]bool)
  263. for _, channel := range channels {
  264. uniquePriorities[int(channel.GetPriority())] = true
  265. }
  266. var sortedUniquePriorities []int
  267. for priority := range uniquePriorities {
  268. sortedUniquePriorities = append(sortedUniquePriorities, priority)
  269. }
  270. sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
  271. if retry >= len(uniquePriorities) {
  272. retry = len(uniquePriorities) - 1
  273. }
  274. targetPriority := int64(sortedUniquePriorities[retry])
  275. // get the priority for the given retry number
  276. var targetChannels []*Channel
  277. for _, channel := range channels {
  278. if channel.GetPriority() == targetPriority {
  279. targetChannels = append(targetChannels, channel)
  280. }
  281. }
  282. // 平滑系数
  283. smoothingFactor := 10
  284. // Calculate the total weight of all channels up to endIdx
  285. totalWeight := 0
  286. for _, channel := range targetChannels {
  287. totalWeight += channel.GetWeight() + smoothingFactor
  288. }
  289. // Generate a random value in the range [0, totalWeight)
  290. randomWeight := rand.Intn(totalWeight)
  291. // Find a channel based on its weight
  292. for _, channel := range targetChannels {
  293. randomWeight -= channel.GetWeight() + smoothingFactor
  294. if randomWeight < 0 {
  295. return channel, nil
  296. }
  297. }
  298. // return null if no channel is not found
  299. return nil, errors.New("channel not found")
  300. }
  301. func CacheGetChannel(id int) (*Channel, error) {
  302. if !common.MemoryCacheEnabled {
  303. return GetChannelById(id, true)
  304. }
  305. channelSyncLock.RLock()
  306. defer channelSyncLock.RUnlock()
  307. c, ok := channelsIDM[id]
  308. if !ok {
  309. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  310. }
  311. return c, nil
  312. }
  313. func CacheUpdateChannelStatus(id int, status int) {
  314. if !common.MemoryCacheEnabled {
  315. return
  316. }
  317. channelSyncLock.Lock()
  318. defer channelSyncLock.Unlock()
  319. if channel, ok := channelsIDM[id]; ok {
  320. channel.Status = status
  321. }
  322. }