oidc.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package controller
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/url"
  8. "one-api/common"
  9. "one-api/logger"
  10. "one-api/model"
  11. "one-api/setting"
  12. "one-api/setting/system_setting"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "github.com/gin-contrib/sessions"
  17. "github.com/gin-gonic/gin"
  18. )
  // OidcResponse is the JSON body returned by the OIDC token endpoint when an
  // authorization code is exchanged. Only AccessToken is consumed in this file.
  type OidcResponse struct {
  	// AccessToken is sent as a Bearer credential to the userinfo endpoint.
  	AccessToken string `json:"access_token"`
  	// IDToken is decoded but not used or verified here.
  	IDToken string `json:"id_token"`
  	// RefreshToken is decoded but not used here.
  	RefreshToken string `json:"refresh_token"`
  	TokenType    string `json:"token_type"`
  	ExpiresIn    int    `json:"expires_in"`
  	Scope        string `json:"scope"`
  }
  // OidcUser holds the claims read from the OIDC userinfo endpoint.
  type OidcUser struct {
  	// OpenID is the provider's subject identifier ("sub"); it is stored as the
  	// local user's OidcId.
  	OpenID string `json:"sub"`
  	// Email is required: getOidcUserInfoByCode rejects a response without it.
  	Email string `json:"email"`
  	// Name, when non-empty, becomes the display name of a newly registered user.
  	Name string `json:"name"`
  	// PreferredUsername, when non-empty, becomes the username of a newly
  	// registered user.
  	PreferredUsername string `json:"preferred_username"`
  	Picture           string `json:"picture"`
  }
  34. func getOidcUserInfoByCode(code string) (*OidcUser, error) {
  35. if code == "" {
  36. return nil, errors.New("无效的参数")
  37. }
  38. values := url.Values{}
  39. values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
  40. values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
  41. values.Set("code", code)
  42. values.Set("grant_type", "authorization_code")
  43. values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
  44. formData := values.Encode()
  45. req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
  46. if err != nil {
  47. return nil, err
  48. }
  49. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  50. req.Header.Set("Accept", "application/json")
  51. client := http.Client{
  52. Timeout: 5 * time.Second,
  53. }
  54. res, err := client.Do(req)
  55. if err != nil {
  56. logger.SysLog(err.Error())
  57. return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
  58. }
  59. defer res.Body.Close()
  60. var oidcResponse OidcResponse
  61. err = json.NewDecoder(res.Body).Decode(&oidcResponse)
  62. if err != nil {
  63. return nil, err
  64. }
  65. if oidcResponse.AccessToken == "" {
  66. logger.SysError("OIDC 获取 Token 失败,请检查设置!")
  67. return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
  68. }
  69. req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
  70. if err != nil {
  71. return nil, err
  72. }
  73. req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
  74. res2, err := client.Do(req)
  75. if err != nil {
  76. logger.SysLog(err.Error())
  77. return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
  78. }
  79. defer res2.Body.Close()
  80. if res2.StatusCode != http.StatusOK {
  81. logger.SysError("OIDC 获取用户信息失败!请检查设置!")
  82. return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
  83. }
  84. var oidcUser OidcUser
  85. err = json.NewDecoder(res2.Body).Decode(&oidcUser)
  86. if err != nil {
  87. return nil, err
  88. }
  89. if oidcUser.OpenID == "" || oidcUser.Email == "" {
  90. logger.SysError("OIDC 获取用户信息为空!请检查设置!")
  91. return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
  92. }
  93. return &oidcUser, nil
  94. }
  95. func OidcAuth(c *gin.Context) {
  96. session := sessions.Default(c)
  97. state := c.Query("state")
  98. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  99. c.JSON(http.StatusForbidden, gin.H{
  100. "success": false,
  101. "message": "state is empty or not same",
  102. })
  103. return
  104. }
  105. username := session.Get("username")
  106. if username != nil {
  107. OidcBind(c)
  108. return
  109. }
  110. if !system_setting.GetOIDCSettings().Enabled {
  111. c.JSON(http.StatusOK, gin.H{
  112. "success": false,
  113. "message": "管理员未开启通过 OIDC 登录以及注册",
  114. })
  115. return
  116. }
  117. code := c.Query("code")
  118. oidcUser, err := getOidcUserInfoByCode(code)
  119. if err != nil {
  120. common.ApiError(c, err)
  121. return
  122. }
  123. user := model.User{
  124. OidcId: oidcUser.OpenID,
  125. }
  126. if model.IsOidcIdAlreadyTaken(user.OidcId) {
  127. err := user.FillUserByOidcId()
  128. if err != nil {
  129. c.JSON(http.StatusOK, gin.H{
  130. "success": false,
  131. "message": err.Error(),
  132. })
  133. return
  134. }
  135. } else {
  136. if common.RegisterEnabled {
  137. user.Email = oidcUser.Email
  138. if oidcUser.PreferredUsername != "" {
  139. user.Username = oidcUser.PreferredUsername
  140. } else {
  141. user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
  142. }
  143. if oidcUser.Name != "" {
  144. user.DisplayName = oidcUser.Name
  145. } else {
  146. user.DisplayName = "OIDC User"
  147. }
  148. err := user.Insert(0)
  149. if err != nil {
  150. c.JSON(http.StatusOK, gin.H{
  151. "success": false,
  152. "message": err.Error(),
  153. })
  154. return
  155. }
  156. } else {
  157. c.JSON(http.StatusOK, gin.H{
  158. "success": false,
  159. "message": "管理员关闭了新用户注册",
  160. })
  161. return
  162. }
  163. }
  164. if user.Status != common.UserStatusEnabled {
  165. c.JSON(http.StatusOK, gin.H{
  166. "message": "用户已被封禁",
  167. "success": false,
  168. })
  169. return
  170. }
  171. setupLogin(&user, c)
  172. }
  173. func OidcBind(c *gin.Context) {
  174. if !system_setting.GetOIDCSettings().Enabled {
  175. c.JSON(http.StatusOK, gin.H{
  176. "success": false,
  177. "message": "管理员未开启通过 OIDC 登录以及注册",
  178. })
  179. return
  180. }
  181. code := c.Query("code")
  182. oidcUser, err := getOidcUserInfoByCode(code)
  183. if err != nil {
  184. common.ApiError(c, err)
  185. return
  186. }
  187. user := model.User{
  188. OidcId: oidcUser.OpenID,
  189. }
  190. if model.IsOidcIdAlreadyTaken(user.OidcId) {
  191. c.JSON(http.StatusOK, gin.H{
  192. "success": false,
  193. "message": "该 OIDC 账户已被绑定",
  194. })
  195. return
  196. }
  197. session := sessions.Default(c)
  198. id := session.Get("id")
  199. // id := c.GetInt("id") // critical bug!
  200. user.Id = id.(int)
  201. err = user.FillUserById()
  202. if err != nil {
  203. common.ApiError(c, err)
  204. return
  205. }
  206. user.OidcId = oidcUser.OpenID
  207. err = user.Update(false)
  208. if err != nil {
  209. common.ApiError(c, err)
  210. return
  211. }
  212. c.JSON(http.StatusOK, gin.H{
  213. "success": true,
  214. "message": "bind",
  215. })
  216. return
  217. }