github.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package controller
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/logger"
  10. "one-api/model"
  11. "strconv"
  12. "time"
  13. "github.com/gin-contrib/sessions"
  14. "github.com/gin-gonic/gin"
  15. )
  // GitHub API payloads decoded by getGitHubUserInfoByCode.
type (
	// GitHubOAuthResponse is the JSON body returned by GitHub's
	// login/oauth/access_token endpoint for a successful code exchange.
	GitHubOAuthResponse struct {
		AccessToken string `json:"access_token"`
		Scope       string `json:"scope"`
		TokenType   string `json:"token_type"`
	}

	// GitHubUser is the subset of the https://api.github.com/user response
	// used for login and registration; Login is stored as the user's GitHubId.
	GitHubUser struct {
		Login string `json:"login"`
		Name  string `json:"name"`
		Email string `json:"email"`
	}
)
  26. func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
  27. if code == "" {
  28. return nil, errors.New("无效的参数")
  29. }
  30. values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
  31. jsonData, err := json.Marshal(values)
  32. if err != nil {
  33. return nil, err
  34. }
  35. req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
  36. if err != nil {
  37. return nil, err
  38. }
  39. req.Header.Set("Content-Type", "application/json")
  40. req.Header.Set("Accept", "application/json")
  41. client := http.Client{
  42. Timeout: 5 * time.Second,
  43. }
  44. res, err := client.Do(req)
  45. if err != nil {
  46. logger.SysLog(err.Error())
  47. return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
  48. }
  49. defer res.Body.Close()
  50. var oAuthResponse GitHubOAuthResponse
  51. err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
  52. if err != nil {
  53. return nil, err
  54. }
  55. req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
  56. if err != nil {
  57. return nil, err
  58. }
  59. req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
  60. res2, err := client.Do(req)
  61. if err != nil {
  62. logger.SysLog(err.Error())
  63. return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
  64. }
  65. defer res2.Body.Close()
  66. var githubUser GitHubUser
  67. err = json.NewDecoder(res2.Body).Decode(&githubUser)
  68. if err != nil {
  69. return nil, err
  70. }
  71. if githubUser.Login == "" {
  72. return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
  73. }
  74. return &githubUser, nil
  75. }
  76. func GitHubOAuth(c *gin.Context) {
  77. session := sessions.Default(c)
  78. state := c.Query("state")
  79. if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
  80. c.JSON(http.StatusForbidden, gin.H{
  81. "success": false,
  82. "message": "state is empty or not same",
  83. })
  84. return
  85. }
  86. username := session.Get("username")
  87. if username != nil {
  88. GitHubBind(c)
  89. return
  90. }
  91. if !common.GitHubOAuthEnabled {
  92. c.JSON(http.StatusOK, gin.H{
  93. "success": false,
  94. "message": "管理员未开启通过 GitHub 登录以及注册",
  95. })
  96. return
  97. }
  98. code := c.Query("code")
  99. githubUser, err := getGitHubUserInfoByCode(code)
  100. if err != nil {
  101. common.ApiError(c, err)
  102. return
  103. }
  104. user := model.User{
  105. GitHubId: githubUser.Login,
  106. }
  107. // IsGitHubIdAlreadyTaken is unscoped
  108. if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
  109. // FillUserByGitHubId is scoped
  110. err := user.FillUserByGitHubId()
  111. if err != nil {
  112. c.JSON(http.StatusOK, gin.H{
  113. "success": false,
  114. "message": err.Error(),
  115. })
  116. return
  117. }
  118. // if user.Id == 0 , user has been deleted
  119. if user.Id == 0 {
  120. c.JSON(http.StatusOK, gin.H{
  121. "success": false,
  122. "message": "用户已注销",
  123. })
  124. return
  125. }
  126. } else {
  127. if common.RegisterEnabled {
  128. user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
  129. if githubUser.Name != "" {
  130. user.DisplayName = githubUser.Name
  131. } else {
  132. user.DisplayName = "GitHub User"
  133. }
  134. user.Email = githubUser.Email
  135. user.Role = common.RoleCommonUser
  136. user.Status = common.UserStatusEnabled
  137. affCode := session.Get("aff")
  138. inviterId := 0
  139. if affCode != nil {
  140. inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
  141. }
  142. if err := user.Insert(inviterId); err != nil {
  143. c.JSON(http.StatusOK, gin.H{
  144. "success": false,
  145. "message": err.Error(),
  146. })
  147. return
  148. }
  149. } else {
  150. c.JSON(http.StatusOK, gin.H{
  151. "success": false,
  152. "message": "管理员关闭了新用户注册",
  153. })
  154. return
  155. }
  156. }
  157. if user.Status != common.UserStatusEnabled {
  158. c.JSON(http.StatusOK, gin.H{
  159. "message": "用户已被封禁",
  160. "success": false,
  161. })
  162. return
  163. }
  164. setupLogin(&user, c)
  165. }
  166. func GitHubBind(c *gin.Context) {
  167. if !common.GitHubOAuthEnabled {
  168. c.JSON(http.StatusOK, gin.H{
  169. "success": false,
  170. "message": "管理员未开启通过 GitHub 登录以及注册",
  171. })
  172. return
  173. }
  174. code := c.Query("code")
  175. githubUser, err := getGitHubUserInfoByCode(code)
  176. if err != nil {
  177. common.ApiError(c, err)
  178. return
  179. }
  180. user := model.User{
  181. GitHubId: githubUser.Login,
  182. }
  183. if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
  184. c.JSON(http.StatusOK, gin.H{
  185. "success": false,
  186. "message": "该 GitHub 账户已被绑定",
  187. })
  188. return
  189. }
  190. session := sessions.Default(c)
  191. id := session.Get("id")
  192. // id := c.GetInt("id") // critical bug!
  193. user.Id = id.(int)
  194. err = user.FillUserById()
  195. if err != nil {
  196. common.ApiError(c, err)
  197. return
  198. }
  199. user.GitHubId = githubUser.Login
  200. err = user.Update(false)
  201. if err != nil {
  202. common.ApiError(c, err)
  203. return
  204. }
  205. c.JSON(http.StatusOK, gin.H{
  206. "success": true,
  207. "message": "bind",
  208. })
  209. return
  210. }
  211. func GenerateOAuthCode(c *gin.Context) {
  212. session := sessions.Default(c)
  213. state := common.GetRandomString(12)
  214. affCode := c.Query("aff")
  215. if affCode != "" {
  216. session.Set("aff", affCode)
  217. }
  218. session.Set("oauth_state", state)
  219. err := session.Save()
  220. if err != nil {
  221. common.ApiError(c, err)
  222. return
  223. }
  224. c.JSON(http.StatusOK, gin.H{
  225. "success": true,
  226. "message": "",
  227. "data": state,
  228. })
  229. }