user.go 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056
  1. package model
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/bytedance/gopkg/util/gopool"
  13. "gorm.io/gorm"
  14. )
  15. const UserNameMaxLength = 20
  16. // User if you add sensitive fields, don't forget to clean them in setupLogin function.
  17. // Otherwise, the sensitive information will be saved on local storage in plain text!
  18. type User struct {
  19. Id int `json:"id"`
  20. Username string `json:"username" gorm:"unique;index" validate:"max=20"`
  21. Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
  22. OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
  23. DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
  24. Role int `json:"role" gorm:"type:int;default:1"` // admin, common
  25. Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
  26. Email string `json:"email" gorm:"index" validate:"max=50"`
  27. GitHubId string `json:"github_id" gorm:"column:github_id;index"`
  28. DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
  29. OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
  30. WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
  31. TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
  32. VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
  33. AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
  34. Quota int `json:"quota" gorm:"type:int;default:0"`
  35. UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
  36. RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
  37. Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
  38. AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
  39. AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
  40. AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
  41. AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
  42. InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
  43. DeletedAt gorm.DeletedAt `gorm:"index"`
  44. LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
  45. Setting string `json:"setting" gorm:"type:text;column:setting"`
  46. Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
  47. StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
  48. CreatedAt int64 `json:"created_at" gorm:"autoCreateTime;column:created_at"`
  49. LastLoginAt int64 `json:"last_login_at" gorm:"default:0;column:last_login_at"`
  50. }
  51. func (user *User) ToBaseUser() *UserBase {
  52. cache := &UserBase{
  53. Id: user.Id,
  54. Group: user.Group,
  55. Quota: user.Quota,
  56. Status: user.Status,
  57. Username: user.Username,
  58. Setting: user.Setting,
  59. Email: user.Email,
  60. }
  61. return cache
  62. }
  63. func (user *User) GetAccessToken() string {
  64. if user.AccessToken == nil {
  65. return ""
  66. }
  67. return *user.AccessToken
  68. }
  69. func (user *User) SetAccessToken(token string) {
  70. user.AccessToken = &token
  71. }
  72. func (user *User) GetSetting() dto.UserSetting {
  73. setting := dto.UserSetting{}
  74. if user.Setting != "" {
  75. err := json.Unmarshal([]byte(user.Setting), &setting)
  76. if err != nil {
  77. common.SysLog("failed to unmarshal setting: " + err.Error())
  78. }
  79. }
  80. return setting
  81. }
  82. func (user *User) SetSetting(setting dto.UserSetting) {
  83. settingBytes, err := json.Marshal(setting)
  84. if err != nil {
  85. common.SysLog("failed to marshal setting: " + err.Error())
  86. return
  87. }
  88. user.Setting = string(settingBytes)
  89. }
  90. // 根据用户角色生成默认的边栏配置
  91. func generateDefaultSidebarConfigForRole(userRole int) string {
  92. defaultConfig := map[string]interface{}{}
  93. // 聊天区域 - 所有用户都可以访问
  94. defaultConfig["chat"] = map[string]interface{}{
  95. "enabled": true,
  96. "playground": true,
  97. "chat": true,
  98. }
  99. // 控制台区域 - 所有用户都可以访问
  100. defaultConfig["console"] = map[string]interface{}{
  101. "enabled": true,
  102. "detail": true,
  103. "token": true,
  104. "log": true,
  105. "midjourney": true,
  106. "task": true,
  107. }
  108. // 个人中心区域 - 所有用户都可以访问
  109. defaultConfig["personal"] = map[string]interface{}{
  110. "enabled": true,
  111. "topup": true,
  112. "personal": true,
  113. }
  114. // 管理员区域 - 根据角色决定
  115. if userRole == common.RoleAdminUser {
  116. // 管理员可以访问管理员区域,但不能访问系统设置
  117. defaultConfig["admin"] = map[string]interface{}{
  118. "enabled": true,
  119. "channel": true,
  120. "models": true,
  121. "redemption": true,
  122. "user": true,
  123. "setting": false, // 管理员不能访问系统设置
  124. }
  125. } else if userRole == common.RoleRootUser {
  126. // 超级管理员可以访问所有功能
  127. defaultConfig["admin"] = map[string]interface{}{
  128. "enabled": true,
  129. "channel": true,
  130. "models": true,
  131. "redemption": true,
  132. "user": true,
  133. "setting": true,
  134. }
  135. }
  136. // 普通用户不包含admin区域
  137. // 转换为JSON字符串
  138. configBytes, err := json.Marshal(defaultConfig)
  139. if err != nil {
  140. common.SysLog("生成默认边栏配置失败: " + err.Error())
  141. return ""
  142. }
  143. return string(configBytes)
  144. }
  145. // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
  146. func CheckUserExistOrDeleted(username string, email string) (bool, error) {
  147. var user User
  148. // err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
  149. // check email if empty
  150. var err error
  151. if email == "" {
  152. err = DB.Unscoped().First(&user, "username = ?", username).Error
  153. } else {
  154. err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
  155. }
  156. if err != nil {
  157. if errors.Is(err, gorm.ErrRecordNotFound) {
  158. // not exist, return false, nil
  159. return false, nil
  160. }
  161. // other error, return false, err
  162. return false, err
  163. }
  164. // exist, return true, nil
  165. return true, nil
  166. }
  167. func GetMaxUserId() int {
  168. var user User
  169. DB.Unscoped().Last(&user)
  170. return user.Id
  171. }
  172. func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
  173. // Start transaction
  174. tx := DB.Begin()
  175. if tx.Error != nil {
  176. return nil, 0, tx.Error
  177. }
  178. defer func() {
  179. if r := recover(); r != nil {
  180. tx.Rollback()
  181. }
  182. }()
  183. // Get total count within transaction
  184. err = tx.Unscoped().Model(&User{}).Count(&total).Error
  185. if err != nil {
  186. tx.Rollback()
  187. return nil, 0, err
  188. }
  189. // Get paginated users within same transaction
  190. err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
  191. if err != nil {
  192. tx.Rollback()
  193. return nil, 0, err
  194. }
  195. // Commit transaction
  196. if err = tx.Commit().Error; err != nil {
  197. return nil, 0, err
  198. }
  199. return users, total, nil
  200. }
  201. func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
  202. var users []*User
  203. var total int64
  204. var err error
  205. // 开始事务
  206. tx := DB.Begin()
  207. if tx.Error != nil {
  208. return nil, 0, tx.Error
  209. }
  210. defer func() {
  211. if r := recover(); r != nil {
  212. tx.Rollback()
  213. }
  214. }()
  215. // 构建基础查询
  216. query := tx.Unscoped().Model(&User{})
  217. // 构建搜索条件
  218. likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
  219. // 尝试将关键字转换为整数ID
  220. keywordInt, err := strconv.Atoi(keyword)
  221. if err == nil {
  222. // 如果是数字,同时搜索ID和其他字段
  223. likeCondition = "id = ? OR " + likeCondition
  224. if group != "" {
  225. query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
  226. keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
  227. } else {
  228. query = query.Where(likeCondition,
  229. keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
  230. }
  231. } else {
  232. // 非数字关键字,只搜索字符串字段
  233. if group != "" {
  234. query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
  235. "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
  236. } else {
  237. query = query.Where(likeCondition,
  238. "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
  239. }
  240. }
  241. // 获取总数
  242. err = query.Count(&total).Error
  243. if err != nil {
  244. tx.Rollback()
  245. return nil, 0, err
  246. }
  247. // 获取分页数据
  248. err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
  249. if err != nil {
  250. tx.Rollback()
  251. return nil, 0, err
  252. }
  253. // 提交事务
  254. if err = tx.Commit().Error; err != nil {
  255. return nil, 0, err
  256. }
  257. return users, total, nil
  258. }
  259. func GetUserById(id int, selectAll bool) (*User, error) {
  260. if id == 0 {
  261. return nil, errors.New("id 为空!")
  262. }
  263. user := User{Id: id}
  264. var err error = nil
  265. if selectAll {
  266. err = DB.First(&user, "id = ?", id).Error
  267. } else {
  268. err = DB.Omit("password").First(&user, "id = ?", id).Error
  269. }
  270. return &user, err
  271. }
  272. func GetUserIdByAffCode(affCode string) (int, error) {
  273. if affCode == "" {
  274. return 0, errors.New("affCode 为空!")
  275. }
  276. var user User
  277. err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
  278. return user.Id, err
  279. }
  280. func DeleteUserById(id int) (err error) {
  281. if id == 0 {
  282. return errors.New("id 为空!")
  283. }
  284. user := User{Id: id}
  285. return user.Delete()
  286. }
  287. func HardDeleteUserById(id int) error {
  288. if id == 0 {
  289. return errors.New("id 为空!")
  290. }
  291. err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
  292. return err
  293. }
  294. func inviteUser(inviterId int) (err error) {
  295. user, err := GetUserById(inviterId, true)
  296. if err != nil {
  297. return err
  298. }
  299. user.AffCount++
  300. user.AffQuota += common.QuotaForInviter
  301. user.AffHistoryQuota += common.QuotaForInviter
  302. return DB.Save(user).Error
  303. }
  304. func (user *User) TransferAffQuotaToQuota(quota int) error {
  305. // 检查quota是否小于最小额度
  306. if float64(quota) < common.QuotaPerUnit {
  307. return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
  308. }
  309. // 开始数据库事务
  310. tx := DB.Begin()
  311. if tx.Error != nil {
  312. return tx.Error
  313. }
  314. defer tx.Rollback() // 确保在函数退出时事务能回滚
  315. // 加锁查询用户以确保数据一致性
  316. err := tx.Set("gorm:query_option", "FOR UPDATE").First(&user, user.Id).Error
  317. if err != nil {
  318. return err
  319. }
  320. // 再次检查用户的AffQuota是否足够
  321. if user.AffQuota < quota {
  322. return errors.New("邀请额度不足!")
  323. }
  324. // 更新用户额度
  325. user.AffQuota -= quota
  326. user.Quota += quota
  327. // 保存用户状态
  328. if err := tx.Save(user).Error; err != nil {
  329. return err
  330. }
  331. // 提交事务
  332. return tx.Commit().Error
  333. }
  334. func (user *User) Insert(inviterId int) error {
  335. var err error
  336. if user.Password != "" {
  337. user.Password, err = common.Password2Hash(user.Password)
  338. if err != nil {
  339. return err
  340. }
  341. }
  342. user.Quota = common.QuotaForNewUser
  343. //user.SetAccessToken(common.GetUUID())
  344. user.AffCode = common.GetRandomString(4)
  345. // 初始化用户设置,包括默认的边栏配置
  346. if user.Setting == "" {
  347. defaultSetting := dto.UserSetting{}
  348. // 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
  349. user.SetSetting(defaultSetting)
  350. }
  351. result := DB.Create(user)
  352. if result.Error != nil {
  353. return result.Error
  354. }
  355. // 用户创建成功后,根据角色初始化边栏配置
  356. // 需要重新获取用户以确保有正确的ID和Role
  357. var createdUser User
  358. if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
  359. // 生成基于角色的默认边栏配置
  360. defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
  361. if defaultSidebarConfig != "" {
  362. currentSetting := createdUser.GetSetting()
  363. currentSetting.SidebarModules = defaultSidebarConfig
  364. createdUser.SetSetting(currentSetting)
  365. createdUser.Update(false)
  366. common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
  367. }
  368. }
  369. if common.QuotaForNewUser > 0 {
  370. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
  371. }
  372. if inviterId != 0 {
  373. if common.QuotaForInvitee > 0 {
  374. _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
  375. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
  376. }
  377. if common.QuotaForInviter > 0 {
  378. //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
  379. RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
  380. _ = inviteUser(inviterId)
  381. }
  382. }
  383. return nil
  384. }
  385. // InsertWithTx inserts a new user within an existing transaction.
  386. // This is used for OAuth registration where user creation and binding need to be atomic.
  387. // Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
  388. func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
  389. var err error
  390. if user.Password != "" {
  391. user.Password, err = common.Password2Hash(user.Password)
  392. if err != nil {
  393. return err
  394. }
  395. }
  396. user.Quota = common.QuotaForNewUser
  397. user.AffCode = common.GetRandomString(4)
  398. // 初始化用户设置
  399. if user.Setting == "" {
  400. defaultSetting := dto.UserSetting{}
  401. user.SetSetting(defaultSetting)
  402. }
  403. result := tx.Create(user)
  404. if result.Error != nil {
  405. return result.Error
  406. }
  407. return nil
  408. }
  409. // FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
  410. // This should be called after the transaction commits successfully.
  411. func (user *User) FinalizeOAuthUserCreation(inviterId int) {
  412. // 用户创建成功后,根据角色初始化边栏配置
  413. var createdUser User
  414. if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
  415. defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
  416. if defaultSidebarConfig != "" {
  417. currentSetting := createdUser.GetSetting()
  418. currentSetting.SidebarModules = defaultSidebarConfig
  419. createdUser.SetSetting(currentSetting)
  420. createdUser.Update(false)
  421. common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
  422. }
  423. }
  424. if common.QuotaForNewUser > 0 {
  425. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
  426. }
  427. if inviterId != 0 {
  428. if common.QuotaForInvitee > 0 {
  429. _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
  430. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
  431. }
  432. if common.QuotaForInviter > 0 {
  433. RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
  434. _ = inviteUser(inviterId)
  435. }
  436. }
  437. }
  438. func (user *User) Update(updatePassword bool) error {
  439. var err error
  440. if updatePassword {
  441. user.Password, err = common.Password2Hash(user.Password)
  442. if err != nil {
  443. return err
  444. }
  445. }
  446. newUser := *user
  447. DB.First(&user, user.Id)
  448. if err = DB.Model(user).Updates(newUser).Error; err != nil {
  449. return err
  450. }
  451. // Update cache
  452. return updateUserCache(*user)
  453. }
  454. func (user *User) Edit(updatePassword bool) error {
  455. var err error
  456. if updatePassword {
  457. user.Password, err = common.Password2Hash(user.Password)
  458. if err != nil {
  459. return err
  460. }
  461. }
  462. newUser := *user
  463. updates := map[string]interface{}{
  464. "username": newUser.Username,
  465. "display_name": newUser.DisplayName,
  466. "group": newUser.Group,
  467. "remark": newUser.Remark,
  468. }
  469. if updatePassword {
  470. updates["password"] = newUser.Password
  471. }
  472. DB.First(&user, user.Id)
  473. if err = DB.Model(user).Updates(updates).Error; err != nil {
  474. return err
  475. }
  476. // Update cache
  477. return updateUserCache(*user)
  478. }
  479. func (user *User) ClearBinding(bindingType string) error {
  480. if user.Id == 0 {
  481. return errors.New("user id is empty")
  482. }
  483. bindingColumnMap := map[string]string{
  484. "email": "email",
  485. "github": "github_id",
  486. "discord": "discord_id",
  487. "oidc": "oidc_id",
  488. "wechat": "wechat_id",
  489. "telegram": "telegram_id",
  490. "linuxdo": "linux_do_id",
  491. }
  492. column, ok := bindingColumnMap[bindingType]
  493. if !ok {
  494. return errors.New("invalid binding type")
  495. }
  496. if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil {
  497. return err
  498. }
  499. if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil {
  500. return err
  501. }
  502. return updateUserCache(*user)
  503. }
  504. func (user *User) Delete() error {
  505. if user.Id == 0 {
  506. return errors.New("id 为空!")
  507. }
  508. if err := DB.Delete(user).Error; err != nil {
  509. return err
  510. }
  511. // 清除缓存
  512. return invalidateUserCache(user.Id)
  513. }
  514. func (user *User) HardDelete() error {
  515. if user.Id == 0 {
  516. return errors.New("id 为空!")
  517. }
  518. err := DB.Unscoped().Delete(user).Error
  519. return err
  520. }
  521. // ValidateAndFill check password & user status
  522. func (user *User) ValidateAndFill() (err error) {
  523. // When querying with struct, GORM will only query with non-zero fields,
  524. // that means if your field's value is 0, '', false or other zero values,
  525. // it won't be used to build query conditions
  526. password := user.Password
  527. username := strings.TrimSpace(user.Username)
  528. if username == "" || password == "" {
  529. return ErrUserEmptyCredentials
  530. }
  531. // find by username or email
  532. err = DB.Where("username = ? OR email = ?", username, username).First(user).Error
  533. if err != nil {
  534. if errors.Is(err, gorm.ErrRecordNotFound) {
  535. return ErrInvalidCredentials
  536. }
  537. return fmt.Errorf("%w: %v", ErrDatabase, err)
  538. }
  539. okay := common.ValidatePasswordAndHash(password, user.Password)
  540. if !okay || user.Status != common.UserStatusEnabled {
  541. return ErrInvalidCredentials
  542. }
  543. return nil
  544. }
  545. func (user *User) FillUserById() error {
  546. if user.Id == 0 {
  547. return errors.New("id 为空!")
  548. }
  549. DB.Where(User{Id: user.Id}).First(user)
  550. return nil
  551. }
  552. func (user *User) FillUserByEmail() error {
  553. if user.Email == "" {
  554. return errors.New("email 为空!")
  555. }
  556. DB.Where(User{Email: user.Email}).First(user)
  557. return nil
  558. }
  559. func (user *User) FillUserByGitHubId() error {
  560. if user.GitHubId == "" {
  561. return errors.New("GitHub id 为空!")
  562. }
  563. DB.Where(User{GitHubId: user.GitHubId}).First(user)
  564. return nil
  565. }
  566. // UpdateGitHubId updates the user's GitHub ID (used for migration from login to numeric ID)
  567. func (user *User) UpdateGitHubId(newGitHubId string) error {
  568. if user.Id == 0 {
  569. return errors.New("user id is empty")
  570. }
  571. return DB.Model(user).Update("github_id", newGitHubId).Error
  572. }
  573. func (user *User) FillUserByDiscordId() error {
  574. if user.DiscordId == "" {
  575. return errors.New("discord id 为空!")
  576. }
  577. DB.Where(User{DiscordId: user.DiscordId}).First(user)
  578. return nil
  579. }
  580. func (user *User) FillUserByOidcId() error {
  581. if user.OidcId == "" {
  582. return errors.New("oidc id 为空!")
  583. }
  584. DB.Where(User{OidcId: user.OidcId}).First(user)
  585. return nil
  586. }
  587. func (user *User) FillUserByWeChatId() error {
  588. if user.WeChatId == "" {
  589. return errors.New("WeChat id 为空!")
  590. }
  591. DB.Where(User{WeChatId: user.WeChatId}).First(user)
  592. return nil
  593. }
  594. func (user *User) FillUserByTelegramId() error {
  595. if user.TelegramId == "" {
  596. return errors.New("Telegram id 为空!")
  597. }
  598. err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
  599. if errors.Is(err, gorm.ErrRecordNotFound) {
  600. return errors.New("该 Telegram 账户未绑定")
  601. }
  602. return nil
  603. }
  604. func IsEmailAlreadyTaken(email string) bool {
  605. return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
  606. }
  607. func IsWeChatIdAlreadyTaken(wechatId string) bool {
  608. return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
  609. }
  610. func IsGitHubIdAlreadyTaken(githubId string) bool {
  611. return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
  612. }
  613. func IsDiscordIdAlreadyTaken(discordId string) bool {
  614. return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
  615. }
  616. func IsOidcIdAlreadyTaken(oidcId string) bool {
  617. return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
  618. }
  619. func IsTelegramIdAlreadyTaken(telegramId string) bool {
  620. return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
  621. }
  622. func ResetUserPasswordByEmail(email string, password string) error {
  623. if email == "" || password == "" {
  624. return errors.New("邮箱地址或密码为空!")
  625. }
  626. hashedPassword, err := common.Password2Hash(password)
  627. if err != nil {
  628. return err
  629. }
  630. err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
  631. return err
  632. }
  633. func IsAdmin(userId int) bool {
  634. if userId == 0 {
  635. return false
  636. }
  637. var user User
  638. err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
  639. if err != nil {
  640. common.SysLog("no such user " + err.Error())
  641. return false
  642. }
  643. return user.Role >= common.RoleAdminUser
  644. }
  645. //// IsUserEnabled checks user status from Redis first, falls back to DB if needed
  646. //func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
  647. // defer func() {
  648. // // Update Redis cache asynchronously on successful DB read
  649. // if shouldUpdateRedis(fromDB, err) {
  650. // gopool.Go(func() {
  651. // if err := updateUserStatusCache(id, status); err != nil {
  652. // common.SysError("failed to update user status cache: " + err.Error())
  653. // }
  654. // })
  655. // }
  656. // }()
  657. // if !fromDB && common.RedisEnabled {
  658. // // Try Redis first
  659. // status, err := getUserStatusCache(id)
  660. // if err == nil {
  661. // return status == common.UserStatusEnabled, nil
  662. // }
  663. // // Don't return error - fall through to DB
  664. // }
  665. // fromDB = true
  666. // var user User
  667. // err = DB.Where("id = ?", id).Select("status").Find(&user).Error
  668. // if err != nil {
  669. // return false, err
  670. // }
  671. //
  672. // return user.Status == common.UserStatusEnabled, nil
  673. //}
  674. func ValidateAccessToken(token string) (*User, error) {
  675. if token == "" {
  676. return nil, nil
  677. }
  678. token = strings.Replace(token, "Bearer ", "", 1)
  679. user := &User{}
  680. err := DB.Where("access_token = ?", token).First(user).Error
  681. if err != nil {
  682. if errors.Is(err, gorm.ErrRecordNotFound) {
  683. return nil, nil
  684. }
  685. return nil, fmt.Errorf("%w: %v", ErrDatabase, err)
  686. }
  687. return user, nil
  688. }
  689. // GetUserQuota gets quota from Redis first, falls back to DB if needed
  690. func GetUserQuota(id int, fromDB bool) (quota int, err error) {
  691. defer func() {
  692. // Update Redis cache asynchronously on successful DB read
  693. if shouldUpdateRedis(fromDB, err) {
  694. gopool.Go(func() {
  695. if err := updateUserQuotaCache(id, quota); err != nil {
  696. common.SysLog("failed to update user quota cache: " + err.Error())
  697. }
  698. })
  699. }
  700. }()
  701. if !fromDB && common.RedisEnabled {
  702. quota, err := getUserQuotaCache(id)
  703. if err == nil {
  704. return quota, nil
  705. }
  706. // Don't return error - fall through to DB
  707. }
  708. fromDB = true
  709. err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
  710. if err != nil {
  711. return 0, err
  712. }
  713. return quota, nil
  714. }
  715. func GetUserUsedQuota(id int) (quota int, err error) {
  716. err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
  717. return quota, err
  718. }
  719. func GetUserEmail(id int) (email string, err error) {
  720. err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
  721. return email, err
  722. }
  723. // GetUserGroup gets group from Redis first, falls back to DB if needed
  724. func GetUserGroup(id int, fromDB bool) (group string, err error) {
  725. defer func() {
  726. // Update Redis cache asynchronously on successful DB read
  727. if shouldUpdateRedis(fromDB, err) {
  728. gopool.Go(func() {
  729. if err := updateUserGroupCache(id, group); err != nil {
  730. common.SysLog("failed to update user group cache: " + err.Error())
  731. }
  732. })
  733. }
  734. }()
  735. if !fromDB && common.RedisEnabled {
  736. group, err := getUserGroupCache(id)
  737. if err == nil {
  738. return group, nil
  739. }
  740. // Don't return error - fall through to DB
  741. }
  742. fromDB = true
  743. err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
  744. if err != nil {
  745. return "", err
  746. }
  747. return group, nil
  748. }
  749. // GetUserSetting gets setting from Redis first, falls back to DB if needed
  750. func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
  751. var setting string
  752. defer func() {
  753. // Update Redis cache asynchronously on successful DB read
  754. if shouldUpdateRedis(fromDB, err) {
  755. gopool.Go(func() {
  756. if err := updateUserSettingCache(id, setting); err != nil {
  757. common.SysLog("failed to update user setting cache: " + err.Error())
  758. }
  759. })
  760. }
  761. }()
  762. if !fromDB && common.RedisEnabled {
  763. setting, err := getUserSettingCache(id)
  764. if err == nil {
  765. return setting, nil
  766. }
  767. // Don't return error - fall through to DB
  768. }
  769. fromDB = true
  770. // can be nil setting
  771. var safeSetting sql.NullString
  772. err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error
  773. if err != nil {
  774. return settingMap, err
  775. }
  776. if safeSetting.Valid {
  777. setting = safeSetting.String
  778. } else {
  779. setting = ""
  780. }
  781. userBase := &UserBase{
  782. Setting: setting,
  783. }
  784. return userBase.GetSetting(), nil
  785. }
  786. func IncreaseUserQuota(id int, quota int, db bool) (err error) {
  787. if quota < 0 {
  788. return errors.New("quota 不能为负数!")
  789. }
  790. gopool.Go(func() {
  791. err := cacheIncrUserQuota(id, int64(quota))
  792. if err != nil {
  793. common.SysLog("failed to increase user quota: " + err.Error())
  794. }
  795. })
  796. if !db && common.BatchUpdateEnabled {
  797. addNewRecord(BatchUpdateTypeUserQuota, id, quota)
  798. return nil
  799. }
  800. return increaseUserQuota(id, quota)
  801. }
  802. func increaseUserQuota(id int, quota int) (err error) {
  803. err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
  804. if err != nil {
  805. return err
  806. }
  807. return err
  808. }
  809. func DecreaseUserQuota(id int, quota int, db bool) (err error) {
  810. if quota < 0 {
  811. return errors.New("quota 不能为负数!")
  812. }
  813. gopool.Go(func() {
  814. err := cacheDecrUserQuota(id, int64(quota))
  815. if err != nil {
  816. common.SysLog("failed to decrease user quota: " + err.Error())
  817. }
  818. })
  819. if !db && common.BatchUpdateEnabled {
  820. addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
  821. return nil
  822. }
  823. return decreaseUserQuota(id, quota)
  824. }
  825. func decreaseUserQuota(id int, quota int) (err error) {
  826. err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
  827. if err != nil {
  828. return err
  829. }
  830. return err
  831. }
  832. func DeltaUpdateUserQuota(id int, delta int) (err error) {
  833. if delta == 0 {
  834. return nil
  835. }
  836. if delta > 0 {
  837. return IncreaseUserQuota(id, delta, false)
  838. } else {
  839. return DecreaseUserQuota(id, -delta, false)
  840. }
  841. }
  842. //func GetRootUserEmail() (email string) {
  843. // DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
  844. // return email
  845. //}
  846. func GetRootUser() (user *User) {
  847. DB.Where("role = ?", common.RoleRootUser).First(&user)
  848. return user
  849. }
  850. func UpdateUserLastLoginAt(id int) {
  851. if err := DB.Model(&User{}).Where("id = ?", id).Update("last_login_at", common.GetTimestamp()).Error; err != nil {
  852. common.SysLog("failed to update user last_login_at: " + err.Error())
  853. }
  854. }
  855. func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
  856. if common.BatchUpdateEnabled {
  857. addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
  858. addNewRecord(BatchUpdateTypeRequestCount, id, 1)
  859. return
  860. }
  861. updateUserUsedQuotaAndRequestCount(id, quota, 1)
  862. }
  863. func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
  864. err := DB.Model(&User{}).Where("id = ?", id).Updates(
  865. map[string]interface{}{
  866. "used_quota": gorm.Expr("used_quota + ?", quota),
  867. "request_count": gorm.Expr("request_count + ?", count),
  868. },
  869. ).Error
  870. if err != nil {
  871. common.SysLog("failed to update user used quota and request count: " + err.Error())
  872. return
  873. }
  874. //// 更新缓存
  875. //if err := invalidateUserCache(id); err != nil {
  876. // common.SysError("failed to invalidate user cache: " + err.Error())
  877. //}
  878. }
  879. func updateUserUsedQuota(id int, quota int) {
  880. err := DB.Model(&User{}).Where("id = ?", id).Updates(
  881. map[string]interface{}{
  882. "used_quota": gorm.Expr("used_quota + ?", quota),
  883. },
  884. ).Error
  885. if err != nil {
  886. common.SysLog("failed to update user used quota: " + err.Error())
  887. }
  888. }
  889. func updateUserRequestCount(id int, count int) {
  890. err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
  891. if err != nil {
  892. common.SysLog("failed to update user request count: " + err.Error())
  893. }
  894. }
  895. // GetUsernameById gets username from Redis first, falls back to DB if needed
  896. func GetUsernameById(id int, fromDB bool) (username string, err error) {
  897. defer func() {
  898. // Update Redis cache asynchronously on successful DB read
  899. if shouldUpdateRedis(fromDB, err) {
  900. gopool.Go(func() {
  901. if err := updateUserNameCache(id, username); err != nil {
  902. common.SysLog("failed to update user name cache: " + err.Error())
  903. }
  904. })
  905. }
  906. }()
  907. if !fromDB && common.RedisEnabled {
  908. username, err := getUserNameCache(id)
  909. if err == nil {
  910. return username, nil
  911. }
  912. // Don't return error - fall through to DB
  913. }
  914. fromDB = true
  915. err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
  916. if err != nil {
  917. return "", err
  918. }
  919. return username, nil
  920. }
  921. func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
  922. var user User
  923. err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
  924. return !errors.Is(err, gorm.ErrRecordNotFound)
  925. }
  926. func (user *User) FillUserByLinuxDOId() error {
  927. if user.LinuxDOId == "" {
  928. return errors.New("linux do id is empty")
  929. }
  930. err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
  931. return err
  932. }
  933. func RootUserExists() bool {
  934. var user User
  935. err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
  936. if err != nil {
  937. return false
  938. }
  939. return true
  940. }