redis.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. package common
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "reflect"
  8. "strconv"
  9. "time"
  10. "github.com/go-redis/redis/v8"
  11. "gorm.io/gorm"
  12. )
  13. var RDB *redis.Client
  14. var RedisEnabled = true
  15. // InitRedisClient This function is called after init()
  16. func InitRedisClient() (err error) {
  17. if os.Getenv("REDIS_CONN_STRING") == "" {
  18. RedisEnabled = false
  19. SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
  20. return nil
  21. }
  22. if os.Getenv("SYNC_FREQUENCY") == "" {
  23. SysLog("SYNC_FREQUENCY not set, use default value 60")
  24. SyncFrequency = 60
  25. }
  26. SysLog("Redis is enabled")
  27. opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
  28. if err != nil {
  29. FatalLog("failed to parse Redis connection string: " + err.Error())
  30. }
  31. RDB = redis.NewClient(opt)
  32. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  33. defer cancel()
  34. _, err = RDB.Ping(ctx).Result()
  35. if err != nil {
  36. FatalLog("Redis ping test failed: " + err.Error())
  37. }
  38. return err
  39. }
  40. func ParseRedisOption() *redis.Options {
  41. opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
  42. if err != nil {
  43. FatalLog("failed to parse Redis connection string: " + err.Error())
  44. }
  45. return opt
  46. }
  47. func RedisSet(key string, value string, expiration time.Duration) error {
  48. ctx := context.Background()
  49. return RDB.Set(ctx, key, value, expiration).Err()
  50. }
  51. func RedisGet(key string) (string, error) {
  52. ctx := context.Background()
  53. return RDB.Get(ctx, key).Result()
  54. }
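  //func RedisExpire(key string, expiration time.Duration) error {
  // ctx := context.Background()
  // return RDB.Expire(ctx, key, expiration).Err()
  //}
  //
  // NOTE: do not revive RedisGetEx as written. GetSet replaces the stored value
  // with its argument (here, the duration) instead of setting a TTL. A
  // get-and-expire helper should use GETEX (Redis >= 6.2) instead.
  //func RedisGetEx(key string, expiration time.Duration) (string, error) {
  // ctx := context.Background()
  // return RDB.GetSet(ctx, key, expiration).Result()
  //}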
  64. func RedisDel(key string) error {
  65. ctx := context.Background()
  66. return RDB.Del(ctx, key).Err()
  67. }
  68. func RedisHDelObj(key string) error {
  69. ctx := context.Background()
  70. return RDB.HDel(ctx, key).Err()
  71. }
  72. func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
  73. ctx := context.Background()
  74. data := make(map[string]interface{})
  75. // 使用反射遍历结构体字段
  76. v := reflect.ValueOf(obj).Elem()
  77. t := v.Type()
  78. for i := 0; i < v.NumField(); i++ {
  79. field := t.Field(i)
  80. value := v.Field(i)
  81. // Skip DeletedAt field
  82. if field.Type.String() == "gorm.DeletedAt" {
  83. continue
  84. }
  85. // 处理指针类型
  86. if value.Kind() == reflect.Ptr {
  87. if value.IsNil() {
  88. data[field.Name] = ""
  89. continue
  90. }
  91. value = value.Elem()
  92. }
  93. // 处理布尔类型
  94. if value.Kind() == reflect.Bool {
  95. data[field.Name] = strconv.FormatBool(value.Bool())
  96. continue
  97. }
  98. // 其他类型直接转换为字符串
  99. data[field.Name] = fmt.Sprintf("%v", value.Interface())
  100. }
  101. txn := RDB.TxPipeline()
  102. txn.HSet(ctx, key, data)
  103. txn.Expire(ctx, key, expiration)
  104. _, err := txn.Exec(ctx)
  105. if err != nil {
  106. return fmt.Errorf("failed to execute transaction: %w", err)
  107. }
  108. return nil
  109. }
  110. func RedisHGetObj(key string, obj interface{}) error {
  111. ctx := context.Background()
  112. result, err := RDB.HGetAll(ctx, key).Result()
  113. if err != nil {
  114. return fmt.Errorf("failed to load hash from Redis: %w", err)
  115. }
  116. if len(result) == 0 {
  117. return fmt.Errorf("key %s not found in Redis", key)
  118. }
  119. // Handle both pointer and non-pointer values
  120. val := reflect.ValueOf(obj)
  121. if val.Kind() != reflect.Ptr {
  122. return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
  123. }
  124. v := val.Elem()
  125. if v.Kind() != reflect.Struct {
  126. return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
  127. }
  128. t := v.Type()
  129. for i := 0; i < v.NumField(); i++ {
  130. field := t.Field(i)
  131. fieldName := field.Name
  132. if value, ok := result[fieldName]; ok {
  133. fieldValue := v.Field(i)
  134. // Handle pointer types
  135. if fieldValue.Kind() == reflect.Ptr {
  136. if value == "" {
  137. continue
  138. }
  139. if fieldValue.IsNil() {
  140. fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
  141. }
  142. fieldValue = fieldValue.Elem()
  143. }
  144. // Enhanced type handling for Token struct
  145. switch fieldValue.Kind() {
  146. case reflect.String:
  147. fieldValue.SetString(value)
  148. case reflect.Int, reflect.Int64:
  149. intValue, err := strconv.ParseInt(value, 10, 64)
  150. if err != nil {
  151. return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
  152. }
  153. fieldValue.SetInt(intValue)
  154. case reflect.Bool:
  155. boolValue, err := strconv.ParseBool(value)
  156. if err != nil {
  157. return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
  158. }
  159. fieldValue.SetBool(boolValue)
  160. case reflect.Struct:
  161. // Special handling for gorm.DeletedAt
  162. if fieldValue.Type().String() == "gorm.DeletedAt" {
  163. if value != "" {
  164. timeValue, err := time.Parse(time.RFC3339, value)
  165. if err != nil {
  166. return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
  167. }
  168. fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
  169. }
  170. }
  171. default:
  172. return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
  173. }
  174. }
  175. }
  176. return nil
  177. }
  178. // RedisIncr Add this function to handle atomic increments
  179. func RedisIncr(key string, delta int64) error {
  180. // 检查键的剩余生存时间
  181. ttlCmd := RDB.TTL(context.Background(), key)
  182. ttl, err := ttlCmd.Result()
  183. if err != nil && !errors.Is(err, redis.Nil) {
  184. return fmt.Errorf("failed to get TTL: %w", err)
  185. }
  186. // 只有在 key 存在且有 TTL 时才需要特殊处理
  187. if ttl > 0 {
  188. ctx := context.Background()
  189. // 开始一个Redis事务
  190. txn := RDB.TxPipeline()
  191. // 减少余额
  192. decrCmd := txn.IncrBy(ctx, key, delta)
  193. if err := decrCmd.Err(); err != nil {
  194. return err // 如果减少失败,则直接返回错误
  195. }
  196. // 重新设置过期时间,使用原来的过期时间
  197. txn.Expire(ctx, key, ttl)
  198. // 执行事务
  199. _, err = txn.Exec(ctx)
  200. return err
  201. }
  202. return nil
  203. }
  204. func RedisHIncrBy(key, field string, delta int64) error {
  205. ttlCmd := RDB.TTL(context.Background(), key)
  206. ttl, err := ttlCmd.Result()
  207. if err != nil && !errors.Is(err, redis.Nil) {
  208. return fmt.Errorf("failed to get TTL: %w", err)
  209. }
  210. if ttl > 0 {
  211. ctx := context.Background()
  212. txn := RDB.TxPipeline()
  213. incrCmd := txn.HIncrBy(ctx, key, field, delta)
  214. if err := incrCmd.Err(); err != nil {
  215. return err
  216. }
  217. txn.Expire(ctx, key, ttl)
  218. _, err = txn.Exec(ctx)
  219. return err
  220. }
  221. return nil
  222. }
  223. func RedisHSetField(key, field string, value interface{}) error {
  224. ttlCmd := RDB.TTL(context.Background(), key)
  225. ttl, err := ttlCmd.Result()
  226. if err != nil && !errors.Is(err, redis.Nil) {
  227. return fmt.Errorf("failed to get TTL: %w", err)
  228. }
  229. if ttl > 0 {
  230. ctx := context.Background()
  231. txn := RDB.TxPipeline()
  232. hsetCmd := txn.HSet(ctx, key, field, value)
  233. if err := hsetCmd.Err(); err != nil {
  234. return err
  235. }
  236. txn.Expire(ctx, key, ttl)
  237. _, err = txn.Exec(ctx)
  238. return err
  239. }
  240. return nil
  241. }