auth.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package middleware
  2. import (
  3. "github.com/gin-contrib/sessions"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strings"
  9. )
  10. func authHelper(c *gin.Context, minRole int) {
  11. session := sessions.Default(c)
  12. username := session.Get("username")
  13. role := session.Get("role")
  14. id := session.Get("id")
  15. status := session.Get("status")
  16. if username == nil {
  17. c.JSON(http.StatusUnauthorized, gin.H{
  18. "success": false,
  19. "message": "无权进行此操作,未登录",
  20. })
  21. c.Abort()
  22. return
  23. }
  24. if status.(int) == common.UserStatusDisabled {
  25. c.JSON(http.StatusOK, gin.H{
  26. "success": false,
  27. "message": "用户已被封禁",
  28. })
  29. c.Abort()
  30. return
  31. }
  32. if role.(int) < minRole {
  33. c.JSON(http.StatusOK, gin.H{
  34. "success": false,
  35. "message": "无权进行此操作,权限不足",
  36. })
  37. c.Abort()
  38. return
  39. }
  40. c.Set("username", username)
  41. c.Set("role", role)
  42. c.Set("id", id)
  43. c.Next()
  44. }
  45. func UserAuth() func(c *gin.Context) {
  46. return func(c *gin.Context) {
  47. authHelper(c, common.RoleCommonUser)
  48. }
  49. }
  50. func AdminAuth() func(c *gin.Context) {
  51. return func(c *gin.Context) {
  52. authHelper(c, common.RoleAdminUser)
  53. }
  54. }
  55. func RootAuth() func(c *gin.Context) {
  56. return func(c *gin.Context) {
  57. authHelper(c, common.RoleRootUser)
  58. }
  59. }
  60. func TokenAuth() func(c *gin.Context) {
  61. return func(c *gin.Context) {
  62. key := c.Request.Header.Get("Authorization")
  63. parts := strings.Split(key, "-")
  64. key = parts[0]
  65. token, err := model.ValidateUserToken(key)
  66. if err != nil {
  67. c.JSON(http.StatusOK, gin.H{
  68. "error": gin.H{
  69. "message": err.Error(),
  70. "type": "one_api_error",
  71. },
  72. })
  73. c.Abort()
  74. return
  75. }
  76. c.Set("id", token.UserId)
  77. c.Set("token_id", token.Id)
  78. c.Set("unlimited_times", token.UnlimitedTimes)
  79. if len(parts) > 1 {
  80. c.Set("channelId", parts[1])
  81. }
  82. c.Next()
  83. }
  84. }