auth.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package middleware
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/constant"
  11. "github.com/QuantumNous/new-api/i18n"
  12. "github.com/QuantumNous/new-api/logger"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/QuantumNous/new-api/setting/ratio_setting"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/gin-contrib/sessions"
  18. "github.com/gin-gonic/gin"
  19. "gorm.io/gorm"
  20. )
  21. func validUserInfo(username string, role int) bool {
  22. // check username is empty
  23. if strings.TrimSpace(username) == "" {
  24. return false
  25. }
  26. if !common.IsValidateRole(role) {
  27. return false
  28. }
  29. return true
  30. }
  31. func authHelper(c *gin.Context, minRole int) {
  32. session := sessions.Default(c)
  33. username := session.Get("username")
  34. role := session.Get("role")
  35. id := session.Get("id")
  36. status := session.Get("status")
  37. useAccessToken := false
  38. if username == nil {
  39. // Check access token
  40. accessToken := c.Request.Header.Get("Authorization")
  41. if accessToken == "" {
  42. c.JSON(http.StatusUnauthorized, gin.H{
  43. "success": false,
  44. "message": common.TranslateMessage(c, i18n.MsgAuthNotLoggedIn),
  45. })
  46. c.Abort()
  47. return
  48. }
  49. user, authErr := model.ValidateAccessToken(accessToken)
  50. if authErr != nil {
  51. if errors.Is(authErr, model.ErrDatabase) {
  52. common.SysLog("ValidateAccessToken database error: " + authErr.Error())
  53. c.JSON(http.StatusInternalServerError, gin.H{
  54. "success": false,
  55. "message": common.TranslateMessage(c, i18n.MsgDatabaseError),
  56. })
  57. } else {
  58. c.JSON(http.StatusOK, gin.H{
  59. "success": false,
  60. "message": common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid),
  61. })
  62. }
  63. c.Abort()
  64. return
  65. }
  66. if user != nil && user.Username != "" {
  67. if !validUserInfo(user.Username, user.Role) {
  68. c.JSON(http.StatusOK, gin.H{
  69. "success": false,
  70. "message": common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid),
  71. })
  72. c.Abort()
  73. return
  74. }
  75. // Token is valid
  76. username = user.Username
  77. role = user.Role
  78. id = user.Id
  79. status = user.Status
  80. useAccessToken = true
  81. } else {
  82. c.JSON(http.StatusOK, gin.H{
  83. "success": false,
  84. "message": common.TranslateMessage(c, i18n.MsgAuthAccessTokenInvalid),
  85. })
  86. c.Abort()
  87. return
  88. }
  89. }
  90. // get header New-Api-User
  91. apiUserIdStr := c.Request.Header.Get("New-Api-User")
  92. if apiUserIdStr == "" {
  93. c.JSON(http.StatusUnauthorized, gin.H{
  94. "success": false,
  95. "message": common.TranslateMessage(c, i18n.MsgAuthUserIdNotProvided),
  96. })
  97. c.Abort()
  98. return
  99. }
  100. apiUserId, err := strconv.Atoi(apiUserIdStr)
  101. if err != nil {
  102. c.JSON(http.StatusUnauthorized, gin.H{
  103. "success": false,
  104. "message": common.TranslateMessage(c, i18n.MsgAuthUserIdFormatError),
  105. })
  106. c.Abort()
  107. return
  108. }
  109. if id != apiUserId {
  110. c.JSON(http.StatusUnauthorized, gin.H{
  111. "success": false,
  112. "message": common.TranslateMessage(c, i18n.MsgAuthUserIdMismatch),
  113. })
  114. c.Abort()
  115. return
  116. }
  117. if status.(int) == common.UserStatusDisabled {
  118. c.JSON(http.StatusOK, gin.H{
  119. "success": false,
  120. "message": common.TranslateMessage(c, i18n.MsgAuthUserBanned),
  121. })
  122. c.Abort()
  123. return
  124. }
  125. if role.(int) < minRole {
  126. c.JSON(http.StatusOK, gin.H{
  127. "success": false,
  128. "message": common.TranslateMessage(c, i18n.MsgAuthInsufficientPrivilege),
  129. })
  130. c.Abort()
  131. return
  132. }
  133. if !validUserInfo(username.(string), role.(int)) {
  134. c.JSON(http.StatusOK, gin.H{
  135. "success": false,
  136. "message": common.TranslateMessage(c, i18n.MsgAuthUserInfoInvalid),
  137. })
  138. c.Abort()
  139. return
  140. }
  141. // 防止不同newapi版本冲突,导致数据不通用
  142. c.Header("Auth-Version", "864b7076dbcd0a3c01b5520316720ebf")
  143. c.Set("username", username)
  144. c.Set("role", role)
  145. c.Set("id", id)
  146. c.Set("group", session.Get("group"))
  147. c.Set("user_group", session.Get("group"))
  148. c.Set("use_access_token", useAccessToken)
  149. c.Next()
  150. }
  151. func TryUserAuth() func(c *gin.Context) {
  152. return func(c *gin.Context) {
  153. session := sessions.Default(c)
  154. id := session.Get("id")
  155. if id != nil {
  156. c.Set("id", id)
  157. }
  158. c.Next()
  159. }
  160. }
  161. func UserAuth() func(c *gin.Context) {
  162. return func(c *gin.Context) {
  163. authHelper(c, common.RoleCommonUser)
  164. }
  165. }
  166. func AdminAuth() func(c *gin.Context) {
  167. return func(c *gin.Context) {
  168. authHelper(c, common.RoleAdminUser)
  169. }
  170. }
  171. func RootAuth() func(c *gin.Context) {
  172. return func(c *gin.Context) {
  173. authHelper(c, common.RoleRootUser)
  174. }
  175. }
  176. func WssAuth(c *gin.Context) {
  177. }
  178. // TokenOrUserAuth allows either session-based user auth or API token auth.
  179. // Used for endpoints that need to be accessible from both the dashboard and API clients.
  180. func TokenOrUserAuth() func(c *gin.Context) {
  181. return func(c *gin.Context) {
  182. // Try session auth first (dashboard users)
  183. session := sessions.Default(c)
  184. if id := session.Get("id"); id != nil {
  185. if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
  186. c.Set("id", id)
  187. c.Next()
  188. return
  189. }
  190. }
  191. // Fall back to token auth (API clients)
  192. TokenAuth()(c)
  193. }
  194. }
  195. // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
  196. // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
  197. // 即使令牌已过期、已耗尽或已禁用,也允许访问。
  198. // 仍然检查用户是否被封禁。
  199. func TokenAuthReadOnly() func(c *gin.Context) {
  200. return func(c *gin.Context) {
  201. key := c.Request.Header.Get("Authorization")
  202. if key == "" {
  203. c.JSON(http.StatusUnauthorized, gin.H{
  204. "success": false,
  205. "message": common.TranslateMessage(c, i18n.MsgTokenNotProvided),
  206. })
  207. c.Abort()
  208. return
  209. }
  210. if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
  211. key = strings.TrimSpace(key[7:])
  212. }
  213. key = strings.TrimPrefix(key, "sk-")
  214. parts := strings.Split(key, "-")
  215. key = parts[0]
  216. token, err := model.GetTokenByKey(key, false)
  217. if err != nil {
  218. if errors.Is(err, gorm.ErrRecordNotFound) {
  219. c.JSON(http.StatusUnauthorized, gin.H{
  220. "success": false,
  221. "message": common.TranslateMessage(c, i18n.MsgTokenInvalid),
  222. })
  223. } else {
  224. common.SysLog("TokenAuthReadOnly GetTokenByKey database error: " + err.Error())
  225. c.JSON(http.StatusInternalServerError, gin.H{
  226. "success": false,
  227. "message": common.TranslateMessage(c, i18n.MsgDatabaseError),
  228. })
  229. }
  230. c.Abort()
  231. return
  232. }
  233. userCache, err := model.GetUserCache(token.UserId)
  234. if err != nil {
  235. common.SysLog(fmt.Sprintf("TokenAuthReadOnly GetUserCache error for user %d: %v", token.UserId, err))
  236. c.JSON(http.StatusInternalServerError, gin.H{
  237. "success": false,
  238. "message": common.TranslateMessage(c, i18n.MsgDatabaseError),
  239. })
  240. c.Abort()
  241. return
  242. }
  243. if userCache.Status != common.UserStatusEnabled {
  244. c.JSON(http.StatusForbidden, gin.H{
  245. "success": false,
  246. "message": common.TranslateMessage(c, i18n.MsgAuthUserBanned),
  247. })
  248. c.Abort()
  249. return
  250. }
  251. c.Set("id", token.UserId)
  252. c.Set("token_id", token.Id)
  253. c.Set("token_key", token.Key)
  254. c.Next()
  255. }
  256. }
  257. func TokenAuth() func(c *gin.Context) {
  258. return func(c *gin.Context) {
  259. // 先检测是否为ws
  260. if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
  261. // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
  262. // read sk from Sec-WebSocket-Protocol
  263. key := c.Request.Header.Get("Sec-WebSocket-Protocol")
  264. parts := strings.Split(key, ",")
  265. for _, part := range parts {
  266. part = strings.TrimSpace(part)
  267. if strings.HasPrefix(part, "openai-insecure-api-key") {
  268. key = strings.TrimPrefix(part, "openai-insecure-api-key.")
  269. break
  270. }
  271. }
  272. c.Request.Header.Set("Authorization", "Bearer "+key)
  273. }
  274. // 检查path包含/v1/messages 或 /v1/models
  275. if strings.Contains(c.Request.URL.Path, "/v1/messages") || strings.Contains(c.Request.URL.Path, "/v1/models") {
  276. anthropicKey := c.Request.Header.Get("x-api-key")
  277. if anthropicKey != "" {
  278. c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
  279. }
  280. }
  281. // gemini api 从query中获取key
  282. if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
  283. strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
  284. strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
  285. skKey := c.Query("key")
  286. if skKey != "" {
  287. c.Request.Header.Set("Authorization", "Bearer "+skKey)
  288. }
  289. // 从x-goog-api-key header中获取key
  290. xGoogKey := c.Request.Header.Get("x-goog-api-key")
  291. if xGoogKey != "" {
  292. c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
  293. }
  294. }
  295. key := c.Request.Header.Get("Authorization")
  296. parts := make([]string, 0)
  297. if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
  298. key = strings.TrimSpace(key[7:])
  299. }
  300. if key == "" || key == "midjourney-proxy" {
  301. key = c.Request.Header.Get("mj-api-secret")
  302. if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
  303. key = strings.TrimSpace(key[7:])
  304. }
  305. key = strings.TrimPrefix(key, "sk-")
  306. parts = strings.Split(key, "-")
  307. key = parts[0]
  308. } else {
  309. key = strings.TrimPrefix(key, "sk-")
  310. parts = strings.Split(key, "-")
  311. key = parts[0]
  312. }
  313. token, err := model.ValidateUserToken(key)
  314. if token != nil {
  315. id := c.GetInt("id")
  316. if id == 0 {
  317. c.Set("id", token.UserId)
  318. }
  319. }
  320. if err != nil {
  321. if errors.Is(err, model.ErrDatabase) {
  322. common.SysLog("TokenAuth ValidateUserToken database error: " + err.Error())
  323. abortWithOpenAiMessage(c, http.StatusInternalServerError,
  324. common.TranslateMessage(c, i18n.MsgDatabaseError))
  325. } else {
  326. abortWithOpenAiMessage(c, http.StatusUnauthorized,
  327. common.TranslateMessage(c, i18n.MsgTokenInvalid))
  328. }
  329. return
  330. }
  331. allowIps := token.GetIpLimits()
  332. if len(allowIps) > 0 {
  333. clientIp := c.ClientIP()
  334. logger.LogDebug(c, "Token has IP restrictions, checking client IP %s", clientIp)
  335. ip := net.ParseIP(clientIp)
  336. if ip == nil {
  337. abortWithOpenAiMessage(c, http.StatusForbidden, "无法解析客户端 IP 地址")
  338. return
  339. }
  340. if common.IsIpInCIDRList(ip, allowIps) == false {
  341. abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中", types.ErrorCodeAccessDenied)
  342. return
  343. }
  344. logger.LogDebug(c, "Client IP %s passed the token IP restrictions check", clientIp)
  345. }
  346. userCache, err := model.GetUserCache(token.UserId)
  347. if err != nil {
  348. common.SysLog(fmt.Sprintf("TokenAuth GetUserCache error for user %d: %v", token.UserId, err))
  349. abortWithOpenAiMessage(c, http.StatusInternalServerError,
  350. common.TranslateMessage(c, i18n.MsgDatabaseError))
  351. return
  352. }
  353. userEnabled := userCache.Status == common.UserStatusEnabled
  354. if !userEnabled {
  355. abortWithOpenAiMessage(c, http.StatusForbidden, common.TranslateMessage(c, i18n.MsgAuthUserBanned))
  356. return
  357. }
  358. userCache.WriteContext(c)
  359. userGroup := userCache.Group
  360. tokenGroup := token.Group
  361. if tokenGroup != "" {
  362. // check common.UserUsableGroups[userGroup]
  363. if _, ok := service.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
  364. abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("无权访问 %s 分组", tokenGroup))
  365. return
  366. }
  367. // check group in common.GroupRatio
  368. if !ratio_setting.ContainsGroupRatio(tokenGroup) {
  369. if tokenGroup != "auto" {
  370. abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
  371. return
  372. }
  373. }
  374. userGroup = tokenGroup
  375. }
  376. common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
  377. err = SetupContextForToken(c, token, parts...)
  378. if err != nil {
  379. return
  380. }
  381. c.Next()
  382. }
  383. }
  384. func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
  385. if token == nil {
  386. return fmt.Errorf("token is nil")
  387. }
  388. c.Set("id", token.UserId)
  389. c.Set("token_id", token.Id)
  390. c.Set("token_key", token.Key)
  391. c.Set("token_name", token.Name)
  392. c.Set("token_unlimited_quota", token.UnlimitedQuota)
  393. if !token.UnlimitedQuota {
  394. c.Set("token_quota", token.RemainQuota)
  395. }
  396. if token.ModelLimitsEnabled {
  397. c.Set("token_model_limit_enabled", true)
  398. c.Set("token_model_limit", token.GetModelLimitsMap())
  399. } else {
  400. c.Set("token_model_limit_enabled", false)
  401. }
  402. common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
  403. common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)
  404. if len(parts) > 1 {
  405. if model.IsAdmin(token.UserId) {
  406. c.Set("specific_channel_id", parts[1])
  407. } else {
  408. c.Header("specific_channel_version", "701e3ae1dc3f7975556d354e0675168d004891c8")
  409. abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
  410. return fmt.Errorf("普通用户不支持指定渠道")
  411. }
  412. }
  413. return nil
  414. }