user.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. package model
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/bytedance/gopkg/util/gopool"
  13. "gorm.io/gorm"
  14. )
  15. const UserNameMaxLength = 20
  16. var (
  17. ErrDatabase = errors.New("database error")
  18. ErrInvalidCredentials = errors.New("invalid credentials")
  19. ErrUserEmptyCredentials = errors.New("empty credentials")
  20. )
// User is the GORM model for a user account; it is also serialized to JSON
// according to the json tags below.
//
// If you add sensitive fields, don't forget to clean them in the setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
	Id       int    `json:"id"`
	Username string `json:"username" gorm:"unique;index" validate:"max=20"`
	// Password holds the hash produced by common.Password2Hash (see Insert/Update);
	// most read paths exclude it with Omit("password").
	Password         string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
	OriginalPassword string `json:"original_password" gorm:"-:all"` // only used to verify the current password on a password change; never saved to the database
	DisplayName      string `json:"display_name" gorm:"index" validate:"max=20"`
	Role             int    `json:"role" gorm:"type:int;default:1"`   // admin, common; compared against common.RoleAdminUser / common.RoleRootUser
	Status           int    `json:"status" gorm:"type:int;default:1"` // enabled, disabled; ValidateAndFill requires common.UserStatusEnabled
	Email            string `json:"email" gorm:"index" validate:"max=50"`
	// External account bindings; ClearBinding resets them to "".
	GitHubId         string  `json:"github_id" gorm:"column:github_id;index"`
	DiscordId        string  `json:"discord_id" gorm:"column:discord_id;index"`
	OidcId           string  `json:"oidc_id" gorm:"column:oidc_id;index"`
	WeChatId         string  `json:"wechat_id" gorm:"column:wechat_id;index"`
	TelegramId       string  `json:"telegram_id" gorm:"column:telegram_id;index"`
	VerificationCode string  `json:"verification_code" gorm:"-:all"` // only used for email verification; never saved to the database
	// AccessToken is the system-management token looked up by ValidateAccessToken.
	// It is a pointer so an unset token can be nil (presumably stored as NULL so
	// the unique index tolerates many users without a token — confirm).
	AccessToken     *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"`
	Quota           int     `json:"quota" gorm:"type:int;default:0"`
	UsedQuota       int     `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
	RequestCount    int     `json:"request_count" gorm:"type:int;default:0;"`              // request number
	Group           string  `json:"group" gorm:"type:varchar(64);default:'default'"`
	AffCode         string  `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` // invite code; Insert assigns a random 4-character value
	AffCount        int     `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`         // number of successful invites (incremented by inviteUser)
	AffQuota        int     `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"`         // remaining invite-reward quota (transferable via TransferAffQuotaToQuota)
	AffHistoryQuota int     `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // total invite-reward quota ever earned; note the column is named aff_history
	InviterId       int     `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
	// DeletedAt enables GORM soft delete: default-scoped queries skip rows
	// where it is set, Unscoped() queries include them.
	DeletedAt      gorm.DeletedAt `gorm:"index"`
	LinuxDOId      string         `json:"linux_do_id" gorm:"column:linux_do_id;index"`
	Setting        string         `json:"setting" gorm:"type:text;column:setting"` // JSON-encoded dto.UserSetting (see GetSetting/SetSetting)
	Remark         string         `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
	StripeCustomer string         `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"` // presumably the Stripe customer id — confirm
}
  54. func (user *User) ToBaseUser() *UserBase {
  55. cache := &UserBase{
  56. Id: user.Id,
  57. Group: user.Group,
  58. Quota: user.Quota,
  59. Status: user.Status,
  60. Username: user.Username,
  61. Setting: user.Setting,
  62. Email: user.Email,
  63. }
  64. return cache
  65. }
  66. func (user *User) GetAccessToken() string {
  67. if user.AccessToken == nil {
  68. return ""
  69. }
  70. return *user.AccessToken
  71. }
  72. func (user *User) SetAccessToken(token string) {
  73. user.AccessToken = &token
  74. }
  75. func (user *User) GetSetting() dto.UserSetting {
  76. setting := dto.UserSetting{}
  77. if user.Setting != "" {
  78. err := json.Unmarshal([]byte(user.Setting), &setting)
  79. if err != nil {
  80. common.SysLog("failed to unmarshal setting: " + err.Error())
  81. }
  82. }
  83. return setting
  84. }
  85. func (user *User) SetSetting(setting dto.UserSetting) {
  86. settingBytes, err := json.Marshal(setting)
  87. if err != nil {
  88. common.SysLog("failed to marshal setting: " + err.Error())
  89. return
  90. }
  91. user.Setting = string(settingBytes)
  92. }
  93. // 根据用户角色生成默认的边栏配置
  94. func generateDefaultSidebarConfigForRole(userRole int) string {
  95. defaultConfig := map[string]interface{}{}
  96. // 聊天区域 - 所有用户都可以访问
  97. defaultConfig["chat"] = map[string]interface{}{
  98. "enabled": true,
  99. "playground": true,
  100. "chat": true,
  101. }
  102. // 控制台区域 - 所有用户都可以访问
  103. defaultConfig["console"] = map[string]interface{}{
  104. "enabled": true,
  105. "detail": true,
  106. "token": true,
  107. "log": true,
  108. "midjourney": true,
  109. "task": true,
  110. }
  111. // 个人中心区域 - 所有用户都可以访问
  112. defaultConfig["personal"] = map[string]interface{}{
  113. "enabled": true,
  114. "topup": true,
  115. "personal": true,
  116. }
  117. // 管理员区域 - 根据角色决定
  118. if userRole == common.RoleAdminUser {
  119. // 管理员可以访问管理员区域,但不能访问系统设置
  120. defaultConfig["admin"] = map[string]interface{}{
  121. "enabled": true,
  122. "channel": true,
  123. "models": true,
  124. "redemption": true,
  125. "user": true,
  126. "setting": false, // 管理员不能访问系统设置
  127. }
  128. } else if userRole == common.RoleRootUser {
  129. // 超级管理员可以访问所有功能
  130. defaultConfig["admin"] = map[string]interface{}{
  131. "enabled": true,
  132. "channel": true,
  133. "models": true,
  134. "redemption": true,
  135. "user": true,
  136. "setting": true,
  137. }
  138. }
  139. // 普通用户不包含admin区域
  140. // 转换为JSON字符串
  141. configBytes, err := json.Marshal(defaultConfig)
  142. if err != nil {
  143. common.SysLog("生成默认边栏配置失败: " + err.Error())
  144. return ""
  145. }
  146. return string(configBytes)
  147. }
  148. // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
  149. func CheckUserExistOrDeleted(username string, email string) (bool, error) {
  150. var user User
  151. // err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
  152. // check email if empty
  153. var err error
  154. if email == "" {
  155. err = DB.Unscoped().First(&user, "username = ?", username).Error
  156. } else {
  157. err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
  158. }
  159. if err != nil {
  160. if errors.Is(err, gorm.ErrRecordNotFound) {
  161. // not exist, return false, nil
  162. return false, nil
  163. }
  164. // other error, return false, err
  165. return false, err
  166. }
  167. // exist, return true, nil
  168. return true, nil
  169. }
  170. func GetMaxUserId() int {
  171. var user User
  172. DB.Unscoped().Last(&user)
  173. return user.Id
  174. }
  175. func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
  176. // Start transaction
  177. tx := DB.Begin()
  178. if tx.Error != nil {
  179. return nil, 0, tx.Error
  180. }
  181. defer func() {
  182. if r := recover(); r != nil {
  183. tx.Rollback()
  184. }
  185. }()
  186. // Get total count within transaction
  187. err = tx.Unscoped().Model(&User{}).Count(&total).Error
  188. if err != nil {
  189. tx.Rollback()
  190. return nil, 0, err
  191. }
  192. // Get paginated users within same transaction
  193. err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
  194. if err != nil {
  195. tx.Rollback()
  196. return nil, 0, err
  197. }
  198. // Commit transaction
  199. if err = tx.Commit().Error; err != nil {
  200. return nil, 0, err
  201. }
  202. return users, total, nil
  203. }
  204. func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
  205. var users []*User
  206. var total int64
  207. var err error
  208. // 开始事务
  209. tx := DB.Begin()
  210. if tx.Error != nil {
  211. return nil, 0, tx.Error
  212. }
  213. defer func() {
  214. if r := recover(); r != nil {
  215. tx.Rollback()
  216. }
  217. }()
  218. // 构建基础查询
  219. query := tx.Unscoped().Model(&User{})
  220. // 构建搜索条件
  221. likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
  222. // 尝试将关键字转换为整数ID
  223. keywordInt, err := strconv.Atoi(keyword)
  224. if err == nil {
  225. // 如果是数字,同时搜索ID和其他字段
  226. likeCondition = "id = ? OR " + likeCondition
  227. if group != "" {
  228. query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
  229. keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
  230. } else {
  231. query = query.Where(likeCondition,
  232. keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
  233. }
  234. } else {
  235. // 非数字关键字,只搜索字符串字段
  236. if group != "" {
  237. query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
  238. "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
  239. } else {
  240. query = query.Where(likeCondition,
  241. "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
  242. }
  243. }
  244. // 获取总数
  245. err = query.Count(&total).Error
  246. if err != nil {
  247. tx.Rollback()
  248. return nil, 0, err
  249. }
  250. // 获取分页数据
  251. err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
  252. if err != nil {
  253. tx.Rollback()
  254. return nil, 0, err
  255. }
  256. // 提交事务
  257. if err = tx.Commit().Error; err != nil {
  258. return nil, 0, err
  259. }
  260. return users, total, nil
  261. }
  262. func GetUserById(id int, selectAll bool) (*User, error) {
  263. if id == 0 {
  264. return nil, errors.New("id 为空!")
  265. }
  266. user := User{Id: id}
  267. var err error = nil
  268. if selectAll {
  269. err = DB.First(&user, "id = ?", id).Error
  270. } else {
  271. err = DB.Omit("password").First(&user, "id = ?", id).Error
  272. }
  273. return &user, err
  274. }
  275. func GetUserIdByAffCode(affCode string) (int, error) {
  276. if affCode == "" {
  277. return 0, errors.New("affCode 为空!")
  278. }
  279. var user User
  280. err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
  281. return user.Id, err
  282. }
  283. func DeleteUserById(id int) (err error) {
  284. if id == 0 {
  285. return errors.New("id 为空!")
  286. }
  287. user := User{Id: id}
  288. return user.Delete()
  289. }
  290. func HardDeleteUserById(id int) error {
  291. if id == 0 {
  292. return errors.New("id 为空!")
  293. }
  294. err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
  295. return err
  296. }
  297. func inviteUser(inviterId int) (err error) {
  298. user, err := GetUserById(inviterId, true)
  299. if err != nil {
  300. return err
  301. }
  302. user.AffCount++
  303. user.AffQuota += common.QuotaForInviter
  304. user.AffHistoryQuota += common.QuotaForInviter
  305. return DB.Save(user).Error
  306. }
  307. func (user *User) TransferAffQuotaToQuota(quota int) error {
  308. // 检查quota是否小于最小额度
  309. if float64(quota) < common.QuotaPerUnit {
  310. return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
  311. }
  312. // 开始数据库事务
  313. tx := DB.Begin()
  314. if tx.Error != nil {
  315. return tx.Error
  316. }
  317. defer tx.Rollback() // 确保在函数退出时事务能回滚
  318. // 加锁查询用户以确保数据一致性
  319. err := tx.Set("gorm:query_option", "FOR UPDATE").First(&user, user.Id).Error
  320. if err != nil {
  321. return err
  322. }
  323. // 再次检查用户的AffQuota是否足够
  324. if user.AffQuota < quota {
  325. return errors.New("邀请额度不足!")
  326. }
  327. // 更新用户额度
  328. user.AffQuota -= quota
  329. user.Quota += quota
  330. // 保存用户状态
  331. if err := tx.Save(user).Error; err != nil {
  332. return err
  333. }
  334. // 提交事务
  335. return tx.Commit().Error
  336. }
  337. func (user *User) Insert(inviterId int) error {
  338. var err error
  339. if user.Password != "" {
  340. user.Password, err = common.Password2Hash(user.Password)
  341. if err != nil {
  342. return err
  343. }
  344. }
  345. user.Quota = common.QuotaForNewUser
  346. //user.SetAccessToken(common.GetUUID())
  347. user.AffCode = common.GetRandomString(4)
  348. // 初始化用户设置,包括默认的边栏配置
  349. if user.Setting == "" {
  350. defaultSetting := dto.UserSetting{}
  351. // 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
  352. user.SetSetting(defaultSetting)
  353. }
  354. result := DB.Create(user)
  355. if result.Error != nil {
  356. return result.Error
  357. }
  358. // 用户创建成功后,根据角色初始化边栏配置
  359. // 需要重新获取用户以确保有正确的ID和Role
  360. var createdUser User
  361. if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
  362. // 生成基于角色的默认边栏配置
  363. defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
  364. if defaultSidebarConfig != "" {
  365. currentSetting := createdUser.GetSetting()
  366. currentSetting.SidebarModules = defaultSidebarConfig
  367. createdUser.SetSetting(currentSetting)
  368. createdUser.Update(false)
  369. common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
  370. }
  371. }
  372. if common.QuotaForNewUser > 0 {
  373. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
  374. }
  375. if inviterId != 0 {
  376. if common.QuotaForInvitee > 0 {
  377. _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
  378. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
  379. }
  380. if common.QuotaForInviter > 0 {
  381. //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
  382. RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
  383. _ = inviteUser(inviterId)
  384. }
  385. }
  386. return nil
  387. }
  388. // InsertWithTx inserts a new user within an existing transaction.
  389. // This is used for OAuth registration where user creation and binding need to be atomic.
  390. // Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
  391. func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
  392. var err error
  393. if user.Password != "" {
  394. user.Password, err = common.Password2Hash(user.Password)
  395. if err != nil {
  396. return err
  397. }
  398. }
  399. user.Quota = common.QuotaForNewUser
  400. user.AffCode = common.GetRandomString(4)
  401. // 初始化用户设置
  402. if user.Setting == "" {
  403. defaultSetting := dto.UserSetting{}
  404. user.SetSetting(defaultSetting)
  405. }
  406. result := tx.Create(user)
  407. if result.Error != nil {
  408. return result.Error
  409. }
  410. return nil
  411. }
  412. // FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
  413. // This should be called after the transaction commits successfully.
  414. func (user *User) FinalizeOAuthUserCreation(inviterId int) {
  415. // 用户创建成功后,根据角色初始化边栏配置
  416. var createdUser User
  417. if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
  418. defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
  419. if defaultSidebarConfig != "" {
  420. currentSetting := createdUser.GetSetting()
  421. currentSetting.SidebarModules = defaultSidebarConfig
  422. createdUser.SetSetting(currentSetting)
  423. createdUser.Update(false)
  424. common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
  425. }
  426. }
  427. if common.QuotaForNewUser > 0 {
  428. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
  429. }
  430. if inviterId != 0 {
  431. if common.QuotaForInvitee > 0 {
  432. _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
  433. RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
  434. }
  435. if common.QuotaForInviter > 0 {
  436. RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
  437. _ = inviteUser(inviterId)
  438. }
  439. }
  440. }
  441. func (user *User) Update(updatePassword bool) error {
  442. var err error
  443. if updatePassword {
  444. user.Password, err = common.Password2Hash(user.Password)
  445. if err != nil {
  446. return err
  447. }
  448. }
  449. newUser := *user
  450. DB.First(&user, user.Id)
  451. if err = DB.Model(user).Updates(newUser).Error; err != nil {
  452. return err
  453. }
  454. // Update cache
  455. return updateUserCache(*user)
  456. }
  457. func (user *User) Edit(updatePassword bool) error {
  458. var err error
  459. if updatePassword {
  460. user.Password, err = common.Password2Hash(user.Password)
  461. if err != nil {
  462. return err
  463. }
  464. }
  465. newUser := *user
  466. updates := map[string]interface{}{
  467. "username": newUser.Username,
  468. "display_name": newUser.DisplayName,
  469. "group": newUser.Group,
  470. "remark": newUser.Remark,
  471. }
  472. if updatePassword {
  473. updates["password"] = newUser.Password
  474. }
  475. DB.First(&user, user.Id)
  476. if err = DB.Model(user).Updates(updates).Error; err != nil {
  477. return err
  478. }
  479. // Update cache
  480. return updateUserCache(*user)
  481. }
  482. func (user *User) ClearBinding(bindingType string) error {
  483. if user.Id == 0 {
  484. return errors.New("user id is empty")
  485. }
  486. bindingColumnMap := map[string]string{
  487. "email": "email",
  488. "github": "github_id",
  489. "discord": "discord_id",
  490. "oidc": "oidc_id",
  491. "wechat": "wechat_id",
  492. "telegram": "telegram_id",
  493. "linuxdo": "linux_do_id",
  494. }
  495. column, ok := bindingColumnMap[bindingType]
  496. if !ok {
  497. return errors.New("invalid binding type")
  498. }
  499. if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil {
  500. return err
  501. }
  502. if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil {
  503. return err
  504. }
  505. return updateUserCache(*user)
  506. }
  507. func (user *User) Delete() error {
  508. if user.Id == 0 {
  509. return errors.New("id 为空!")
  510. }
  511. if err := DB.Delete(user).Error; err != nil {
  512. return err
  513. }
  514. // 清除缓存
  515. return invalidateUserCache(user.Id)
  516. }
  517. func (user *User) HardDelete() error {
  518. if user.Id == 0 {
  519. return errors.New("id 为空!")
  520. }
  521. err := DB.Unscoped().Delete(user).Error
  522. return err
  523. }
  524. // ValidateAndFill check password & user status
  525. func (user *User) ValidateAndFill() (err error) {
  526. // When querying with struct, GORM will only query with non-zero fields,
  527. // that means if your field's value is 0, '', false or other zero values,
  528. // it won't be used to build query conditions
  529. password := user.Password
  530. username := strings.TrimSpace(user.Username)
  531. if username == "" || password == "" {
  532. return ErrUserEmptyCredentials
  533. }
  534. // find by username or email
  535. err = DB.Where("username = ? OR email = ?", username, username).First(user).Error
  536. if err != nil {
  537. if errors.Is(err, gorm.ErrRecordNotFound) {
  538. return ErrInvalidCredentials
  539. }
  540. return fmt.Errorf("%w: %v", ErrDatabase, err)
  541. }
  542. okay := common.ValidatePasswordAndHash(password, user.Password)
  543. if !okay || user.Status != common.UserStatusEnabled {
  544. return ErrInvalidCredentials
  545. }
  546. return nil
  547. }
  548. func (user *User) FillUserById() error {
  549. if user.Id == 0 {
  550. return errors.New("id 为空!")
  551. }
  552. DB.Where(User{Id: user.Id}).First(user)
  553. return nil
  554. }
  555. func (user *User) FillUserByEmail() error {
  556. if user.Email == "" {
  557. return errors.New("email 为空!")
  558. }
  559. DB.Where(User{Email: user.Email}).First(user)
  560. return nil
  561. }
  562. func (user *User) FillUserByGitHubId() error {
  563. if user.GitHubId == "" {
  564. return errors.New("GitHub id 为空!")
  565. }
  566. DB.Where(User{GitHubId: user.GitHubId}).First(user)
  567. return nil
  568. }
  569. // UpdateGitHubId updates the user's GitHub ID (used for migration from login to numeric ID)
  570. func (user *User) UpdateGitHubId(newGitHubId string) error {
  571. if user.Id == 0 {
  572. return errors.New("user id is empty")
  573. }
  574. return DB.Model(user).Update("github_id", newGitHubId).Error
  575. }
  576. func (user *User) FillUserByDiscordId() error {
  577. if user.DiscordId == "" {
  578. return errors.New("discord id 为空!")
  579. }
  580. DB.Where(User{DiscordId: user.DiscordId}).First(user)
  581. return nil
  582. }
  583. func (user *User) FillUserByOidcId() error {
  584. if user.OidcId == "" {
  585. return errors.New("oidc id 为空!")
  586. }
  587. DB.Where(User{OidcId: user.OidcId}).First(user)
  588. return nil
  589. }
  590. func (user *User) FillUserByWeChatId() error {
  591. if user.WeChatId == "" {
  592. return errors.New("WeChat id 为空!")
  593. }
  594. DB.Where(User{WeChatId: user.WeChatId}).First(user)
  595. return nil
  596. }
  597. func (user *User) FillUserByTelegramId() error {
  598. if user.TelegramId == "" {
  599. return errors.New("Telegram id 为空!")
  600. }
  601. err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
  602. if errors.Is(err, gorm.ErrRecordNotFound) {
  603. return errors.New("该 Telegram 账户未绑定")
  604. }
  605. return nil
  606. }
  607. func IsEmailAlreadyTaken(email string) bool {
  608. return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
  609. }
  610. func IsWeChatIdAlreadyTaken(wechatId string) bool {
  611. return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
  612. }
  613. func IsGitHubIdAlreadyTaken(githubId string) bool {
  614. return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
  615. }
  616. func IsDiscordIdAlreadyTaken(discordId string) bool {
  617. return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
  618. }
  619. func IsOidcIdAlreadyTaken(oidcId string) bool {
  620. return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
  621. }
  622. func IsTelegramIdAlreadyTaken(telegramId string) bool {
  623. return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
  624. }
  625. func ResetUserPasswordByEmail(email string, password string) error {
  626. if email == "" || password == "" {
  627. return errors.New("邮箱地址或密码为空!")
  628. }
  629. hashedPassword, err := common.Password2Hash(password)
  630. if err != nil {
  631. return err
  632. }
  633. err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
  634. return err
  635. }
  636. func IsAdmin(userId int) bool {
  637. if userId == 0 {
  638. return false
  639. }
  640. var user User
  641. err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
  642. if err != nil {
  643. common.SysLog("no such user " + err.Error())
  644. return false
  645. }
  646. return user.Role >= common.RoleAdminUser
  647. }
  648. //// IsUserEnabled checks user status from Redis first, falls back to DB if needed
  649. //func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
  650. // defer func() {
  651. // // Update Redis cache asynchronously on successful DB read
  652. // if shouldUpdateRedis(fromDB, err) {
  653. // gopool.Go(func() {
  654. // if err := updateUserStatusCache(id, status); err != nil {
  655. // common.SysError("failed to update user status cache: " + err.Error())
  656. // }
  657. // })
  658. // }
  659. // }()
  660. // if !fromDB && common.RedisEnabled {
  661. // // Try Redis first
  662. // status, err := getUserStatusCache(id)
  663. // if err == nil {
  664. // return status == common.UserStatusEnabled, nil
  665. // }
  666. // // Don't return error - fall through to DB
  667. // }
  668. // fromDB = true
  669. // var user User
  670. // err = DB.Where("id = ?", id).Select("status").Find(&user).Error
  671. // if err != nil {
  672. // return false, err
  673. // }
  674. //
  675. // return user.Status == common.UserStatusEnabled, nil
  676. //}
  677. func ValidateAccessToken(token string) (user *User) {
  678. if token == "" {
  679. return nil
  680. }
  681. token = strings.Replace(token, "Bearer ", "", 1)
  682. user = &User{}
  683. if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
  684. return user
  685. }
  686. return nil
  687. }
  688. // GetUserQuota gets quota from Redis first, falls back to DB if needed
  689. func GetUserQuota(id int, fromDB bool) (quota int, err error) {
  690. defer func() {
  691. // Update Redis cache asynchronously on successful DB read
  692. if shouldUpdateRedis(fromDB, err) {
  693. gopool.Go(func() {
  694. if err := updateUserQuotaCache(id, quota); err != nil {
  695. common.SysLog("failed to update user quota cache: " + err.Error())
  696. }
  697. })
  698. }
  699. }()
  700. if !fromDB && common.RedisEnabled {
  701. quota, err := getUserQuotaCache(id)
  702. if err == nil {
  703. return quota, nil
  704. }
  705. // Don't return error - fall through to DB
  706. }
  707. fromDB = true
  708. err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
  709. if err != nil {
  710. return 0, err
  711. }
  712. return quota, nil
  713. }
  714. func GetUserUsedQuota(id int) (quota int, err error) {
  715. err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
  716. return quota, err
  717. }
  718. func GetUserEmail(id int) (email string, err error) {
  719. err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
  720. return email, err
  721. }
  722. // GetUserGroup gets group from Redis first, falls back to DB if needed
  723. func GetUserGroup(id int, fromDB bool) (group string, err error) {
  724. defer func() {
  725. // Update Redis cache asynchronously on successful DB read
  726. if shouldUpdateRedis(fromDB, err) {
  727. gopool.Go(func() {
  728. if err := updateUserGroupCache(id, group); err != nil {
  729. common.SysLog("failed to update user group cache: " + err.Error())
  730. }
  731. })
  732. }
  733. }()
  734. if !fromDB && common.RedisEnabled {
  735. group, err := getUserGroupCache(id)
  736. if err == nil {
  737. return group, nil
  738. }
  739. // Don't return error - fall through to DB
  740. }
  741. fromDB = true
  742. err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
  743. if err != nil {
  744. return "", err
  745. }
  746. return group, nil
  747. }
  748. // GetUserSetting gets setting from Redis first, falls back to DB if needed
  749. func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
  750. var setting string
  751. defer func() {
  752. // Update Redis cache asynchronously on successful DB read
  753. if shouldUpdateRedis(fromDB, err) {
  754. gopool.Go(func() {
  755. if err := updateUserSettingCache(id, setting); err != nil {
  756. common.SysLog("failed to update user setting cache: " + err.Error())
  757. }
  758. })
  759. }
  760. }()
  761. if !fromDB && common.RedisEnabled {
  762. setting, err := getUserSettingCache(id)
  763. if err == nil {
  764. return setting, nil
  765. }
  766. // Don't return error - fall through to DB
  767. }
  768. fromDB = true
  769. // can be nil setting
  770. var safeSetting sql.NullString
  771. err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error
  772. if err != nil {
  773. return settingMap, err
  774. }
  775. if safeSetting.Valid {
  776. setting = safeSetting.String
  777. } else {
  778. setting = ""
  779. }
  780. userBase := &UserBase{
  781. Setting: setting,
  782. }
  783. return userBase.GetSetting(), nil
  784. }
  785. func IncreaseUserQuota(id int, quota int, db bool) (err error) {
  786. if quota < 0 {
  787. return errors.New("quota 不能为负数!")
  788. }
  789. gopool.Go(func() {
  790. err := cacheIncrUserQuota(id, int64(quota))
  791. if err != nil {
  792. common.SysLog("failed to increase user quota: " + err.Error())
  793. }
  794. })
  795. if !db && common.BatchUpdateEnabled {
  796. addNewRecord(BatchUpdateTypeUserQuota, id, quota)
  797. return nil
  798. }
  799. return increaseUserQuota(id, quota)
  800. }
  801. func increaseUserQuota(id int, quota int) (err error) {
  802. err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
  803. if err != nil {
  804. return err
  805. }
  806. return err
  807. }
  808. func DecreaseUserQuota(id int, quota int, db bool) (err error) {
  809. if quota < 0 {
  810. return errors.New("quota 不能为负数!")
  811. }
  812. gopool.Go(func() {
  813. err := cacheDecrUserQuota(id, int64(quota))
  814. if err != nil {
  815. common.SysLog("failed to decrease user quota: " + err.Error())
  816. }
  817. })
  818. if !db && common.BatchUpdateEnabled {
  819. addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
  820. return nil
  821. }
  822. return decreaseUserQuota(id, quota)
  823. }
  824. func decreaseUserQuota(id int, quota int) (err error) {
  825. err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
  826. if err != nil {
  827. return err
  828. }
  829. return err
  830. }
  831. func DeltaUpdateUserQuota(id int, delta int) (err error) {
  832. if delta == 0 {
  833. return nil
  834. }
  835. if delta > 0 {
  836. return IncreaseUserQuota(id, delta, false)
  837. } else {
  838. return DecreaseUserQuota(id, -delta, false)
  839. }
  840. }
  841. //func GetRootUserEmail() (email string) {
  842. // DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
  843. // return email
  844. //}
  845. func GetRootUser() (user *User) {
  846. DB.Where("role = ?", common.RoleRootUser).First(&user)
  847. return user
  848. }
  849. func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
  850. if common.BatchUpdateEnabled {
  851. addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
  852. addNewRecord(BatchUpdateTypeRequestCount, id, 1)
  853. return
  854. }
  855. updateUserUsedQuotaAndRequestCount(id, quota, 1)
  856. }
  857. func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
  858. err := DB.Model(&User{}).Where("id = ?", id).Updates(
  859. map[string]interface{}{
  860. "used_quota": gorm.Expr("used_quota + ?", quota),
  861. "request_count": gorm.Expr("request_count + ?", count),
  862. },
  863. ).Error
  864. if err != nil {
  865. common.SysLog("failed to update user used quota and request count: " + err.Error())
  866. return
  867. }
  868. //// 更新缓存
  869. //if err := invalidateUserCache(id); err != nil {
  870. // common.SysError("failed to invalidate user cache: " + err.Error())
  871. //}
  872. }
  873. func updateUserUsedQuota(id int, quota int) {
  874. err := DB.Model(&User{}).Where("id = ?", id).Updates(
  875. map[string]interface{}{
  876. "used_quota": gorm.Expr("used_quota + ?", quota),
  877. },
  878. ).Error
  879. if err != nil {
  880. common.SysLog("failed to update user used quota: " + err.Error())
  881. }
  882. }
  883. func updateUserRequestCount(id int, count int) {
  884. err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
  885. if err != nil {
  886. common.SysLog("failed to update user request count: " + err.Error())
  887. }
  888. }
  889. // GetUsernameById gets username from Redis first, falls back to DB if needed
  890. func GetUsernameById(id int, fromDB bool) (username string, err error) {
  891. defer func() {
  892. // Update Redis cache asynchronously on successful DB read
  893. if shouldUpdateRedis(fromDB, err) {
  894. gopool.Go(func() {
  895. if err := updateUserNameCache(id, username); err != nil {
  896. common.SysLog("failed to update user name cache: " + err.Error())
  897. }
  898. })
  899. }
  900. }()
  901. if !fromDB && common.RedisEnabled {
  902. username, err := getUserNameCache(id)
  903. if err == nil {
  904. return username, nil
  905. }
  906. // Don't return error - fall through to DB
  907. }
  908. fromDB = true
  909. err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
  910. if err != nil {
  911. return "", err
  912. }
  913. return username, nil
  914. }
  915. func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
  916. var user User
  917. err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
  918. return !errors.Is(err, gorm.ErrRecordNotFound)
  919. }
  920. func (user *User) FillUserByLinuxDOId() error {
  921. if user.LinuxDOId == "" {
  922. return errors.New("linux do id is empty")
  923. }
  924. err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
  925. return err
  926. }
  927. func RootUserExists() bool {
  928. var user User
  929. err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
  930. if err != nil {
  931. return false
  932. }
  933. return true
  934. }