custom_oauth.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. package controller
  2. import (
  3. "context"
  4. "io"
  5. "net/http"
  6. "net/url"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/oauth"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // CustomOAuthProviderResponse is the response structure for custom OAuth providers
  16. // It excludes sensitive fields like client_secret
  17. type CustomOAuthProviderResponse struct {
  18. Id int `json:"id"`
  19. Name string `json:"name"`
  20. Slug string `json:"slug"`
  21. Icon string `json:"icon"`
  22. Enabled bool `json:"enabled"`
  23. ClientId string `json:"client_id"`
  24. AuthorizationEndpoint string `json:"authorization_endpoint"`
  25. TokenEndpoint string `json:"token_endpoint"`
  26. UserInfoEndpoint string `json:"user_info_endpoint"`
  27. Scopes string `json:"scopes"`
  28. UserIdField string `json:"user_id_field"`
  29. UsernameField string `json:"username_field"`
  30. DisplayNameField string `json:"display_name_field"`
  31. EmailField string `json:"email_field"`
  32. WellKnown string `json:"well_known"`
  33. AuthStyle int `json:"auth_style"`
  34. AccessPolicy string `json:"access_policy"`
  35. AccessDeniedMessage string `json:"access_denied_message"`
  36. }
  37. type UserOAuthBindingResponse struct {
  38. ProviderId int `json:"provider_id"`
  39. ProviderName string `json:"provider_name"`
  40. ProviderSlug string `json:"provider_slug"`
  41. ProviderIcon string `json:"provider_icon"`
  42. ProviderUserId string `json:"provider_user_id"`
  43. }
  44. func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
  45. return &CustomOAuthProviderResponse{
  46. Id: p.Id,
  47. Name: p.Name,
  48. Slug: p.Slug,
  49. Icon: p.Icon,
  50. Enabled: p.Enabled,
  51. ClientId: p.ClientId,
  52. AuthorizationEndpoint: p.AuthorizationEndpoint,
  53. TokenEndpoint: p.TokenEndpoint,
  54. UserInfoEndpoint: p.UserInfoEndpoint,
  55. Scopes: p.Scopes,
  56. UserIdField: p.UserIdField,
  57. UsernameField: p.UsernameField,
  58. DisplayNameField: p.DisplayNameField,
  59. EmailField: p.EmailField,
  60. WellKnown: p.WellKnown,
  61. AuthStyle: p.AuthStyle,
  62. AccessPolicy: p.AccessPolicy,
  63. AccessDeniedMessage: p.AccessDeniedMessage,
  64. }
  65. }
  66. // GetCustomOAuthProviders returns all custom OAuth providers
  67. func GetCustomOAuthProviders(c *gin.Context) {
  68. providers, err := model.GetAllCustomOAuthProviders()
  69. if err != nil {
  70. common.ApiError(c, err)
  71. return
  72. }
  73. response := make([]*CustomOAuthProviderResponse, len(providers))
  74. for i, p := range providers {
  75. response[i] = toCustomOAuthProviderResponse(p)
  76. }
  77. c.JSON(http.StatusOK, gin.H{
  78. "success": true,
  79. "message": "",
  80. "data": response,
  81. })
  82. }
  83. // GetCustomOAuthProvider returns a single custom OAuth provider by ID
  84. func GetCustomOAuthProvider(c *gin.Context) {
  85. idStr := c.Param("id")
  86. id, err := strconv.Atoi(idStr)
  87. if err != nil {
  88. common.ApiErrorMsg(c, "无效的 ID")
  89. return
  90. }
  91. provider, err := model.GetCustomOAuthProviderById(id)
  92. if err != nil {
  93. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  94. return
  95. }
  96. c.JSON(http.StatusOK, gin.H{
  97. "success": true,
  98. "message": "",
  99. "data": toCustomOAuthProviderResponse(provider),
  100. })
  101. }
  102. // CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
  103. type CreateCustomOAuthProviderRequest struct {
  104. Name string `json:"name" binding:"required"`
  105. Slug string `json:"slug" binding:"required"`
  106. Icon string `json:"icon"`
  107. Enabled bool `json:"enabled"`
  108. ClientId string `json:"client_id" binding:"required"`
  109. ClientSecret string `json:"client_secret" binding:"required"`
  110. AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
  111. TokenEndpoint string `json:"token_endpoint" binding:"required"`
  112. UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
  113. Scopes string `json:"scopes"`
  114. UserIdField string `json:"user_id_field"`
  115. UsernameField string `json:"username_field"`
  116. DisplayNameField string `json:"display_name_field"`
  117. EmailField string `json:"email_field"`
  118. WellKnown string `json:"well_known"`
  119. AuthStyle int `json:"auth_style"`
  120. AccessPolicy string `json:"access_policy"`
  121. AccessDeniedMessage string `json:"access_denied_message"`
  122. }
  123. type FetchCustomOAuthDiscoveryRequest struct {
  124. WellKnownURL string `json:"well_known_url"`
  125. IssuerURL string `json:"issuer_url"`
  126. }
  127. // FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
  128. func FetchCustomOAuthDiscovery(c *gin.Context) {
  129. var req FetchCustomOAuthDiscoveryRequest
  130. if err := c.ShouldBindJSON(&req); err != nil {
  131. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  132. return
  133. }
  134. wellKnownURL := strings.TrimSpace(req.WellKnownURL)
  135. issuerURL := strings.TrimSpace(req.IssuerURL)
  136. if wellKnownURL == "" && issuerURL == "" {
  137. common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
  138. return
  139. }
  140. targetURL := wellKnownURL
  141. if targetURL == "" {
  142. targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
  143. }
  144. targetURL = strings.TrimSpace(targetURL)
  145. parsedURL, err := url.Parse(targetURL)
  146. if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
  147. common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
  148. return
  149. }
  150. ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
  151. defer cancel()
  152. httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
  153. if err != nil {
  154. common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
  155. return
  156. }
  157. httpReq.Header.Set("Accept", "application/json")
  158. client := &http.Client{Timeout: 20 * time.Second}
  159. resp, err := client.Do(httpReq)
  160. if err != nil {
  161. common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
  162. return
  163. }
  164. defer resp.Body.Close()
  165. if resp.StatusCode != http.StatusOK {
  166. body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
  167. message := strings.TrimSpace(string(body))
  168. if message == "" {
  169. message = resp.Status
  170. }
  171. common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
  172. return
  173. }
  174. var discovery map[string]any
  175. if err = common.DecodeJson(resp.Body, &discovery); err != nil {
  176. common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
  177. return
  178. }
  179. c.JSON(http.StatusOK, gin.H{
  180. "success": true,
  181. "message": "",
  182. "data": gin.H{
  183. "well_known_url": targetURL,
  184. "discovery": discovery,
  185. },
  186. })
  187. }
  188. // CreateCustomOAuthProvider creates a new custom OAuth provider
  189. func CreateCustomOAuthProvider(c *gin.Context) {
  190. var req CreateCustomOAuthProviderRequest
  191. if err := c.ShouldBindJSON(&req); err != nil {
  192. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  193. return
  194. }
  195. // Check if slug is already taken
  196. if model.IsSlugTaken(req.Slug, 0) {
  197. common.ApiErrorMsg(c, "该 Slug 已被使用")
  198. return
  199. }
  200. // Check if slug conflicts with built-in providers
  201. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  202. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  203. return
  204. }
  205. provider := &model.CustomOAuthProvider{
  206. Name: req.Name,
  207. Slug: req.Slug,
  208. Icon: req.Icon,
  209. Enabled: req.Enabled,
  210. ClientId: req.ClientId,
  211. ClientSecret: req.ClientSecret,
  212. AuthorizationEndpoint: req.AuthorizationEndpoint,
  213. TokenEndpoint: req.TokenEndpoint,
  214. UserInfoEndpoint: req.UserInfoEndpoint,
  215. Scopes: req.Scopes,
  216. UserIdField: req.UserIdField,
  217. UsernameField: req.UsernameField,
  218. DisplayNameField: req.DisplayNameField,
  219. EmailField: req.EmailField,
  220. WellKnown: req.WellKnown,
  221. AuthStyle: req.AuthStyle,
  222. AccessPolicy: req.AccessPolicy,
  223. AccessDeniedMessage: req.AccessDeniedMessage,
  224. }
  225. if err := model.CreateCustomOAuthProvider(provider); err != nil {
  226. common.ApiError(c, err)
  227. return
  228. }
  229. // Register the provider in the OAuth registry
  230. oauth.RegisterOrUpdateCustomProvider(provider)
  231. c.JSON(http.StatusOK, gin.H{
  232. "success": true,
  233. "message": "创建成功",
  234. "data": toCustomOAuthProviderResponse(provider),
  235. })
  236. }
  237. // UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
  238. type UpdateCustomOAuthProviderRequest struct {
  239. Name string `json:"name"`
  240. Slug string `json:"slug"`
  241. Icon *string `json:"icon"` // Optional: if nil, keep existing
  242. Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
  243. ClientId string `json:"client_id"`
  244. ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
  245. AuthorizationEndpoint string `json:"authorization_endpoint"`
  246. TokenEndpoint string `json:"token_endpoint"`
  247. UserInfoEndpoint string `json:"user_info_endpoint"`
  248. Scopes string `json:"scopes"`
  249. UserIdField string `json:"user_id_field"`
  250. UsernameField string `json:"username_field"`
  251. DisplayNameField string `json:"display_name_field"`
  252. EmailField string `json:"email_field"`
  253. WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
  254. AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
  255. AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
  256. AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
  257. }
  258. // UpdateCustomOAuthProvider updates an existing custom OAuth provider
  259. func UpdateCustomOAuthProvider(c *gin.Context) {
  260. idStr := c.Param("id")
  261. id, err := strconv.Atoi(idStr)
  262. if err != nil {
  263. common.ApiErrorMsg(c, "无效的 ID")
  264. return
  265. }
  266. var req UpdateCustomOAuthProviderRequest
  267. if err := c.ShouldBindJSON(&req); err != nil {
  268. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  269. return
  270. }
  271. // Get existing provider
  272. provider, err := model.GetCustomOAuthProviderById(id)
  273. if err != nil {
  274. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  275. return
  276. }
  277. oldSlug := provider.Slug
  278. // Check if new slug is taken by another provider
  279. if req.Slug != "" && req.Slug != provider.Slug {
  280. if model.IsSlugTaken(req.Slug, id) {
  281. common.ApiErrorMsg(c, "该 Slug 已被使用")
  282. return
  283. }
  284. // Check if slug conflicts with built-in providers
  285. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  286. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  287. return
  288. }
  289. }
  290. // Update fields
  291. if req.Name != "" {
  292. provider.Name = req.Name
  293. }
  294. if req.Slug != "" {
  295. provider.Slug = req.Slug
  296. }
  297. if req.Icon != nil {
  298. provider.Icon = *req.Icon
  299. }
  300. if req.Enabled != nil {
  301. provider.Enabled = *req.Enabled
  302. }
  303. if req.ClientId != "" {
  304. provider.ClientId = req.ClientId
  305. }
  306. if req.ClientSecret != "" {
  307. provider.ClientSecret = req.ClientSecret
  308. }
  309. if req.AuthorizationEndpoint != "" {
  310. provider.AuthorizationEndpoint = req.AuthorizationEndpoint
  311. }
  312. if req.TokenEndpoint != "" {
  313. provider.TokenEndpoint = req.TokenEndpoint
  314. }
  315. if req.UserInfoEndpoint != "" {
  316. provider.UserInfoEndpoint = req.UserInfoEndpoint
  317. }
  318. if req.Scopes != "" {
  319. provider.Scopes = req.Scopes
  320. }
  321. if req.UserIdField != "" {
  322. provider.UserIdField = req.UserIdField
  323. }
  324. if req.UsernameField != "" {
  325. provider.UsernameField = req.UsernameField
  326. }
  327. if req.DisplayNameField != "" {
  328. provider.DisplayNameField = req.DisplayNameField
  329. }
  330. if req.EmailField != "" {
  331. provider.EmailField = req.EmailField
  332. }
  333. if req.WellKnown != nil {
  334. provider.WellKnown = *req.WellKnown
  335. }
  336. if req.AuthStyle != nil {
  337. provider.AuthStyle = *req.AuthStyle
  338. }
  339. if req.AccessPolicy != nil {
  340. provider.AccessPolicy = *req.AccessPolicy
  341. }
  342. if req.AccessDeniedMessage != nil {
  343. provider.AccessDeniedMessage = *req.AccessDeniedMessage
  344. }
  345. if err := model.UpdateCustomOAuthProvider(provider); err != nil {
  346. common.ApiError(c, err)
  347. return
  348. }
  349. // Update the provider in the OAuth registry
  350. if oldSlug != provider.Slug {
  351. oauth.UnregisterCustomProvider(oldSlug)
  352. }
  353. oauth.RegisterOrUpdateCustomProvider(provider)
  354. c.JSON(http.StatusOK, gin.H{
  355. "success": true,
  356. "message": "更新成功",
  357. "data": toCustomOAuthProviderResponse(provider),
  358. })
  359. }
  360. // DeleteCustomOAuthProvider deletes a custom OAuth provider
  361. func DeleteCustomOAuthProvider(c *gin.Context) {
  362. idStr := c.Param("id")
  363. id, err := strconv.Atoi(idStr)
  364. if err != nil {
  365. common.ApiErrorMsg(c, "无效的 ID")
  366. return
  367. }
  368. // Get existing provider to get slug
  369. provider, err := model.GetCustomOAuthProviderById(id)
  370. if err != nil {
  371. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  372. return
  373. }
  374. // Check if there are any user bindings
  375. count, err := model.GetBindingCountByProviderId(id)
  376. if err != nil {
  377. common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
  378. common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
  379. return
  380. }
  381. if count > 0 {
  382. common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
  383. return
  384. }
  385. if err := model.DeleteCustomOAuthProvider(id); err != nil {
  386. common.ApiError(c, err)
  387. return
  388. }
  389. // Unregister the provider from the OAuth registry
  390. oauth.UnregisterCustomProvider(provider.Slug)
  391. c.JSON(http.StatusOK, gin.H{
  392. "success": true,
  393. "message": "删除成功",
  394. })
  395. }
  396. func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) {
  397. bindings, err := model.GetUserOAuthBindingsByUserId(userId)
  398. if err != nil {
  399. return nil, err
  400. }
  401. response := make([]UserOAuthBindingResponse, 0, len(bindings))
  402. for _, binding := range bindings {
  403. provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
  404. if err != nil {
  405. continue
  406. }
  407. response = append(response, UserOAuthBindingResponse{
  408. ProviderId: binding.ProviderId,
  409. ProviderName: provider.Name,
  410. ProviderSlug: provider.Slug,
  411. ProviderIcon: provider.Icon,
  412. ProviderUserId: binding.ProviderUserId,
  413. })
  414. }
  415. return response, nil
  416. }
  417. // GetUserOAuthBindings returns all OAuth bindings for the current user
  418. func GetUserOAuthBindings(c *gin.Context) {
  419. userId := c.GetInt("id")
  420. if userId == 0 {
  421. common.ApiErrorMsg(c, "未登录")
  422. return
  423. }
  424. response, err := buildUserOAuthBindingsResponse(userId)
  425. if err != nil {
  426. common.ApiError(c, err)
  427. return
  428. }
  429. c.JSON(http.StatusOK, gin.H{
  430. "success": true,
  431. "message": "",
  432. "data": response,
  433. })
  434. }
  435. func GetUserOAuthBindingsByAdmin(c *gin.Context) {
  436. userIdStr := c.Param("id")
  437. userId, err := strconv.Atoi(userIdStr)
  438. if err != nil {
  439. common.ApiErrorMsg(c, "invalid user id")
  440. return
  441. }
  442. targetUser, err := model.GetUserById(userId, false)
  443. if err != nil {
  444. common.ApiError(c, err)
  445. return
  446. }
  447. myRole := c.GetInt("role")
  448. if myRole <= targetUser.Role && myRole != common.RoleRootUser {
  449. common.ApiErrorMsg(c, "no permission")
  450. return
  451. }
  452. response, err := buildUserOAuthBindingsResponse(userId)
  453. if err != nil {
  454. common.ApiError(c, err)
  455. return
  456. }
  457. c.JSON(http.StatusOK, gin.H{
  458. "success": true,
  459. "message": "",
  460. "data": response,
  461. })
  462. }
  463. // UnbindCustomOAuth unbinds a custom OAuth provider from the current user
  464. func UnbindCustomOAuth(c *gin.Context) {
  465. userId := c.GetInt("id")
  466. if userId == 0 {
  467. common.ApiErrorMsg(c, "未登录")
  468. return
  469. }
  470. providerIdStr := c.Param("provider_id")
  471. providerId, err := strconv.Atoi(providerIdStr)
  472. if err != nil {
  473. common.ApiErrorMsg(c, "无效的提供商 ID")
  474. return
  475. }
  476. if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
  477. common.ApiError(c, err)
  478. return
  479. }
  480. c.JSON(http.StatusOK, gin.H{
  481. "success": true,
  482. "message": "解绑成功",
  483. })
  484. }
  485. func UnbindCustomOAuthByAdmin(c *gin.Context) {
  486. userIdStr := c.Param("id")
  487. userId, err := strconv.Atoi(userIdStr)
  488. if err != nil {
  489. common.ApiErrorMsg(c, "invalid user id")
  490. return
  491. }
  492. targetUser, err := model.GetUserById(userId, false)
  493. if err != nil {
  494. common.ApiError(c, err)
  495. return
  496. }
  497. myRole := c.GetInt("role")
  498. if myRole <= targetUser.Role && myRole != common.RoleRootUser {
  499. common.ApiErrorMsg(c, "no permission")
  500. return
  501. }
  502. providerIdStr := c.Param("provider_id")
  503. providerId, err := strconv.Atoi(providerIdStr)
  504. if err != nil {
  505. common.ApiErrorMsg(c, "invalid provider id")
  506. return
  507. }
  508. if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
  509. common.ApiError(c, err)
  510. return
  511. }
  512. c.JSON(http.StatusOK, gin.H{
  513. "success": true,
  514. "message": "success",
  515. })
  516. }