auth.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. package middleware
  2. import (
  3. "github.com/gin-contrib/sessions"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strings"
  9. )
  10. func authHelper(c *gin.Context, minRole int) {
  11. session := sessions.Default(c)
  12. username := session.Get("username")
  13. role := session.Get("role")
  14. id := session.Get("id")
  15. status := session.Get("status")
  16. if username == nil {
  17. // Check access token
  18. accessToken := c.Request.Header.Get("Authorization")
  19. if accessToken == "" {
  20. c.JSON(http.StatusUnauthorized, gin.H{
  21. "success": false,
  22. "message": "无权进行此操作,未登录且未提供 access token",
  23. })
  24. c.Abort()
  25. return
  26. }
  27. user := model.ValidateAccessToken(accessToken)
  28. if user != nil && user.Username != "" {
  29. // Token is valid
  30. username = user.Username
  31. role = user.Role
  32. id = user.Id
  33. status = user.Status
  34. } else {
  35. c.JSON(http.StatusOK, gin.H{
  36. "success": false,
  37. "message": "无权进行此操作,access token 无效",
  38. })
  39. c.Abort()
  40. return
  41. }
  42. }
  43. if status.(int) == common.UserStatusDisabled {
  44. c.JSON(http.StatusOK, gin.H{
  45. "success": false,
  46. "message": "用户已被封禁",
  47. })
  48. c.Abort()
  49. return
  50. }
  51. if role.(int) < minRole {
  52. c.JSON(http.StatusOK, gin.H{
  53. "success": false,
  54. "message": "无权进行此操作,权限不足",
  55. })
  56. c.Abort()
  57. return
  58. }
  59. c.Set("username", username)
  60. c.Set("role", role)
  61. c.Set("id", id)
  62. c.Next()
  63. }
  64. func UserAuth() func(c *gin.Context) {
  65. return func(c *gin.Context) {
  66. authHelper(c, common.RoleCommonUser)
  67. }
  68. }
  69. func AdminAuth() func(c *gin.Context) {
  70. return func(c *gin.Context) {
  71. authHelper(c, common.RoleAdminUser)
  72. }
  73. }
  74. func RootAuth() func(c *gin.Context) {
  75. return func(c *gin.Context) {
  76. authHelper(c, common.RoleRootUser)
  77. }
  78. }
  79. func TokenAuth() func(c *gin.Context) {
  80. return func(c *gin.Context) {
  81. key := c.Request.Header.Get("Authorization")
  82. parts := make([]string, 0)
  83. if key == "" {
  84. key = c.Request.Header.Get("mj-api-secret")
  85. key = strings.TrimPrefix(key, "Bearer ")
  86. key = strings.TrimPrefix(key, "sk-")
  87. parts := strings.Split(key, "-")
  88. key = parts[0]
  89. } else {
  90. key = strings.TrimPrefix(key, "Bearer ")
  91. key = strings.TrimPrefix(key, "sk-")
  92. parts := strings.Split(key, "-")
  93. key = parts[0]
  94. }
  95. token, err := model.ValidateUserToken(key)
  96. if err != nil {
  97. c.JSON(http.StatusUnauthorized, gin.H{
  98. "error": gin.H{
  99. "message": err.Error(),
  100. "type": "one_api_error",
  101. },
  102. })
  103. c.Abort()
  104. return
  105. }
  106. userEnabled, err := model.IsUserEnabled(token.UserId)
  107. if err != nil {
  108. c.JSON(http.StatusInternalServerError, gin.H{
  109. "error": gin.H{
  110. "message": err.Error(),
  111. "type": "one_api_error",
  112. },
  113. })
  114. c.Abort()
  115. return
  116. }
  117. if !userEnabled {
  118. c.JSON(http.StatusForbidden, gin.H{
  119. "error": gin.H{
  120. "message": "用户已被封禁",
  121. "type": "one_api_error",
  122. },
  123. })
  124. c.Abort()
  125. return
  126. }
  127. c.Set("id", token.UserId)
  128. c.Set("token_id", token.Id)
  129. c.Set("token_name", token.Name)
  130. requestURL := c.Request.URL.String()
  131. consumeQuota := true
  132. if strings.HasPrefix(requestURL, "/v1/models") {
  133. consumeQuota = false
  134. }
  135. c.Set("consume_quota", consumeQuota)
  136. if len(parts) > 1 {
  137. if model.IsAdmin(token.UserId) {
  138. c.Set("channelId", parts[1])
  139. } else {
  140. c.JSON(http.StatusForbidden, gin.H{
  141. "error": gin.H{
  142. "message": "普通用户不支持指定渠道",
  143. "type": "one_api_error",
  144. },
  145. })
  146. c.Abort()
  147. return
  148. }
  149. }
  150. c.Next()
  151. }
  152. }