auth.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package middleware
  2. import (
  3. "github.com/gin-contrib/sessions"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strings"
  9. )
  10. func authHelper(c *gin.Context, minRole int) {
  11. session := sessions.Default(c)
  12. username := session.Get("username")
  13. role := session.Get("role")
  14. id := session.Get("id")
  15. status := session.Get("status")
  16. if username == nil {
  17. c.JSON(http.StatusOK, gin.H{
  18. "success": false,
  19. "message": "无权进行此操作,未登录",
  20. })
  21. c.Abort()
  22. return
  23. }
  24. if status.(int) == common.UserStatusDisabled {
  25. c.JSON(http.StatusOK, gin.H{
  26. "success": false,
  27. "message": "用户已被封禁",
  28. })
  29. c.Abort()
  30. return
  31. }
  32. if role.(int) < minRole {
  33. c.JSON(http.StatusOK, gin.H{
  34. "success": false,
  35. "message": "无权进行此操作,权限不足",
  36. })
  37. c.Abort()
  38. return
  39. }
  40. c.Set("username", username)
  41. c.Set("role", role)
  42. c.Set("id", id)
  43. c.Next()
  44. }
  45. func UserAuth() func(c *gin.Context) {
  46. return func(c *gin.Context) {
  47. authHelper(c, common.RoleCommonUser)
  48. }
  49. }
  50. func AdminAuth() func(c *gin.Context) {
  51. return func(c *gin.Context) {
  52. authHelper(c, common.RoleAdminUser)
  53. }
  54. }
  55. func RootAuth() func(c *gin.Context) {
  56. return func(c *gin.Context) {
  57. authHelper(c, common.RoleRootUser)
  58. }
  59. }
  60. func TokenAuth() func(c *gin.Context) {
  61. return func(c *gin.Context) {
  62. key := c.Request.Header.Get("Authorization")
  63. parts := strings.Split(key, "-")
  64. key = parts[0]
  65. token, err := model.ValidateUserToken(key)
  66. if err != nil {
  67. c.JSON(http.StatusOK, gin.H{
  68. "error": gin.H{
  69. "message": err.Error(),
  70. "type": "one_api_error",
  71. },
  72. })
  73. c.Abort()
  74. return
  75. }
  76. c.Set("id", token.UserId)
  77. if len(parts) > 1 {
  78. c.Set("channelId", parts[1])
  79. }
  80. c.Next()
  81. }
  82. }