custom_oauth.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. package controller
  2. import (
  3. "context"
  4. "io"
  5. "net/http"
  6. "net/url"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/oauth"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // CustomOAuthProviderResponse is the response structure for custom OAuth providers
  16. // It excludes sensitive fields like client_secret
  17. type CustomOAuthProviderResponse struct {
  18. Id int `json:"id"`
  19. Name string `json:"name"`
  20. Slug string `json:"slug"`
  21. Icon string `json:"icon"`
  22. Enabled bool `json:"enabled"`
  23. ClientId string `json:"client_id"`
  24. AuthorizationEndpoint string `json:"authorization_endpoint"`
  25. TokenEndpoint string `json:"token_endpoint"`
  26. UserInfoEndpoint string `json:"user_info_endpoint"`
  27. Scopes string `json:"scopes"`
  28. UserIdField string `json:"user_id_field"`
  29. UsernameField string `json:"username_field"`
  30. DisplayNameField string `json:"display_name_field"`
  31. EmailField string `json:"email_field"`
  32. WellKnown string `json:"well_known"`
  33. AuthStyle int `json:"auth_style"`
  34. AccessPolicy string `json:"access_policy"`
  35. AccessDeniedMessage string `json:"access_denied_message"`
  36. }
  37. func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
  38. return &CustomOAuthProviderResponse{
  39. Id: p.Id,
  40. Name: p.Name,
  41. Slug: p.Slug,
  42. Icon: p.Icon,
  43. Enabled: p.Enabled,
  44. ClientId: p.ClientId,
  45. AuthorizationEndpoint: p.AuthorizationEndpoint,
  46. TokenEndpoint: p.TokenEndpoint,
  47. UserInfoEndpoint: p.UserInfoEndpoint,
  48. Scopes: p.Scopes,
  49. UserIdField: p.UserIdField,
  50. UsernameField: p.UsernameField,
  51. DisplayNameField: p.DisplayNameField,
  52. EmailField: p.EmailField,
  53. WellKnown: p.WellKnown,
  54. AuthStyle: p.AuthStyle,
  55. AccessPolicy: p.AccessPolicy,
  56. AccessDeniedMessage: p.AccessDeniedMessage,
  57. }
  58. }
  59. // GetCustomOAuthProviders returns all custom OAuth providers
  60. func GetCustomOAuthProviders(c *gin.Context) {
  61. providers, err := model.GetAllCustomOAuthProviders()
  62. if err != nil {
  63. common.ApiError(c, err)
  64. return
  65. }
  66. response := make([]*CustomOAuthProviderResponse, len(providers))
  67. for i, p := range providers {
  68. response[i] = toCustomOAuthProviderResponse(p)
  69. }
  70. c.JSON(http.StatusOK, gin.H{
  71. "success": true,
  72. "message": "",
  73. "data": response,
  74. })
  75. }
  76. // GetCustomOAuthProvider returns a single custom OAuth provider by ID
  77. func GetCustomOAuthProvider(c *gin.Context) {
  78. idStr := c.Param("id")
  79. id, err := strconv.Atoi(idStr)
  80. if err != nil {
  81. common.ApiErrorMsg(c, "无效的 ID")
  82. return
  83. }
  84. provider, err := model.GetCustomOAuthProviderById(id)
  85. if err != nil {
  86. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  87. return
  88. }
  89. c.JSON(http.StatusOK, gin.H{
  90. "success": true,
  91. "message": "",
  92. "data": toCustomOAuthProviderResponse(provider),
  93. })
  94. }
  95. // CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
  96. type CreateCustomOAuthProviderRequest struct {
  97. Name string `json:"name" binding:"required"`
  98. Slug string `json:"slug" binding:"required"`
  99. Icon string `json:"icon"`
  100. Enabled bool `json:"enabled"`
  101. ClientId string `json:"client_id" binding:"required"`
  102. ClientSecret string `json:"client_secret" binding:"required"`
  103. AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
  104. TokenEndpoint string `json:"token_endpoint" binding:"required"`
  105. UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
  106. Scopes string `json:"scopes"`
  107. UserIdField string `json:"user_id_field"`
  108. UsernameField string `json:"username_field"`
  109. DisplayNameField string `json:"display_name_field"`
  110. EmailField string `json:"email_field"`
  111. WellKnown string `json:"well_known"`
  112. AuthStyle int `json:"auth_style"`
  113. AccessPolicy string `json:"access_policy"`
  114. AccessDeniedMessage string `json:"access_denied_message"`
  115. }
  116. type FetchCustomOAuthDiscoveryRequest struct {
  117. WellKnownURL string `json:"well_known_url"`
  118. IssuerURL string `json:"issuer_url"`
  119. }
  120. // FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
  121. func FetchCustomOAuthDiscovery(c *gin.Context) {
  122. var req FetchCustomOAuthDiscoveryRequest
  123. if err := c.ShouldBindJSON(&req); err != nil {
  124. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  125. return
  126. }
  127. wellKnownURL := strings.TrimSpace(req.WellKnownURL)
  128. issuerURL := strings.TrimSpace(req.IssuerURL)
  129. if wellKnownURL == "" && issuerURL == "" {
  130. common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
  131. return
  132. }
  133. targetURL := wellKnownURL
  134. if targetURL == "" {
  135. targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
  136. }
  137. targetURL = strings.TrimSpace(targetURL)
  138. parsedURL, err := url.Parse(targetURL)
  139. if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
  140. common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
  141. return
  142. }
  143. ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
  144. defer cancel()
  145. httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
  146. if err != nil {
  147. common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
  148. return
  149. }
  150. httpReq.Header.Set("Accept", "application/json")
  151. client := &http.Client{Timeout: 20 * time.Second}
  152. resp, err := client.Do(httpReq)
  153. if err != nil {
  154. common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
  155. return
  156. }
  157. defer resp.Body.Close()
  158. if resp.StatusCode != http.StatusOK {
  159. body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
  160. message := strings.TrimSpace(string(body))
  161. if message == "" {
  162. message = resp.Status
  163. }
  164. common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
  165. return
  166. }
  167. var discovery map[string]any
  168. if err = common.DecodeJson(resp.Body, &discovery); err != nil {
  169. common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
  170. return
  171. }
  172. c.JSON(http.StatusOK, gin.H{
  173. "success": true,
  174. "message": "",
  175. "data": gin.H{
  176. "well_known_url": targetURL,
  177. "discovery": discovery,
  178. },
  179. })
  180. }
  181. // CreateCustomOAuthProvider creates a new custom OAuth provider
  182. func CreateCustomOAuthProvider(c *gin.Context) {
  183. var req CreateCustomOAuthProviderRequest
  184. if err := c.ShouldBindJSON(&req); err != nil {
  185. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  186. return
  187. }
  188. // Check if slug is already taken
  189. if model.IsSlugTaken(req.Slug, 0) {
  190. common.ApiErrorMsg(c, "该 Slug 已被使用")
  191. return
  192. }
  193. // Check if slug conflicts with built-in providers
  194. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  195. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  196. return
  197. }
  198. provider := &model.CustomOAuthProvider{
  199. Name: req.Name,
  200. Slug: req.Slug,
  201. Icon: req.Icon,
  202. Enabled: req.Enabled,
  203. ClientId: req.ClientId,
  204. ClientSecret: req.ClientSecret,
  205. AuthorizationEndpoint: req.AuthorizationEndpoint,
  206. TokenEndpoint: req.TokenEndpoint,
  207. UserInfoEndpoint: req.UserInfoEndpoint,
  208. Scopes: req.Scopes,
  209. UserIdField: req.UserIdField,
  210. UsernameField: req.UsernameField,
  211. DisplayNameField: req.DisplayNameField,
  212. EmailField: req.EmailField,
  213. WellKnown: req.WellKnown,
  214. AuthStyle: req.AuthStyle,
  215. AccessPolicy: req.AccessPolicy,
  216. AccessDeniedMessage: req.AccessDeniedMessage,
  217. }
  218. if err := model.CreateCustomOAuthProvider(provider); err != nil {
  219. common.ApiError(c, err)
  220. return
  221. }
  222. // Register the provider in the OAuth registry
  223. oauth.RegisterOrUpdateCustomProvider(provider)
  224. c.JSON(http.StatusOK, gin.H{
  225. "success": true,
  226. "message": "创建成功",
  227. "data": toCustomOAuthProviderResponse(provider),
  228. })
  229. }
  230. // UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
  231. type UpdateCustomOAuthProviderRequest struct {
  232. Name string `json:"name"`
  233. Slug string `json:"slug"`
  234. Icon *string `json:"icon"` // Optional: if nil, keep existing
  235. Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
  236. ClientId string `json:"client_id"`
  237. ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
  238. AuthorizationEndpoint string `json:"authorization_endpoint"`
  239. TokenEndpoint string `json:"token_endpoint"`
  240. UserInfoEndpoint string `json:"user_info_endpoint"`
  241. Scopes string `json:"scopes"`
  242. UserIdField string `json:"user_id_field"`
  243. UsernameField string `json:"username_field"`
  244. DisplayNameField string `json:"display_name_field"`
  245. EmailField string `json:"email_field"`
  246. WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
  247. AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
  248. AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
  249. AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
  250. }
  251. // UpdateCustomOAuthProvider updates an existing custom OAuth provider
  252. func UpdateCustomOAuthProvider(c *gin.Context) {
  253. idStr := c.Param("id")
  254. id, err := strconv.Atoi(idStr)
  255. if err != nil {
  256. common.ApiErrorMsg(c, "无效的 ID")
  257. return
  258. }
  259. var req UpdateCustomOAuthProviderRequest
  260. if err := c.ShouldBindJSON(&req); err != nil {
  261. common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
  262. return
  263. }
  264. // Get existing provider
  265. provider, err := model.GetCustomOAuthProviderById(id)
  266. if err != nil {
  267. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  268. return
  269. }
  270. oldSlug := provider.Slug
  271. // Check if new slug is taken by another provider
  272. if req.Slug != "" && req.Slug != provider.Slug {
  273. if model.IsSlugTaken(req.Slug, id) {
  274. common.ApiErrorMsg(c, "该 Slug 已被使用")
  275. return
  276. }
  277. // Check if slug conflicts with built-in providers
  278. if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
  279. common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
  280. return
  281. }
  282. }
  283. // Update fields
  284. if req.Name != "" {
  285. provider.Name = req.Name
  286. }
  287. if req.Slug != "" {
  288. provider.Slug = req.Slug
  289. }
  290. if req.Icon != nil {
  291. provider.Icon = *req.Icon
  292. }
  293. if req.Enabled != nil {
  294. provider.Enabled = *req.Enabled
  295. }
  296. if req.ClientId != "" {
  297. provider.ClientId = req.ClientId
  298. }
  299. if req.ClientSecret != "" {
  300. provider.ClientSecret = req.ClientSecret
  301. }
  302. if req.AuthorizationEndpoint != "" {
  303. provider.AuthorizationEndpoint = req.AuthorizationEndpoint
  304. }
  305. if req.TokenEndpoint != "" {
  306. provider.TokenEndpoint = req.TokenEndpoint
  307. }
  308. if req.UserInfoEndpoint != "" {
  309. provider.UserInfoEndpoint = req.UserInfoEndpoint
  310. }
  311. if req.Scopes != "" {
  312. provider.Scopes = req.Scopes
  313. }
  314. if req.UserIdField != "" {
  315. provider.UserIdField = req.UserIdField
  316. }
  317. if req.UsernameField != "" {
  318. provider.UsernameField = req.UsernameField
  319. }
  320. if req.DisplayNameField != "" {
  321. provider.DisplayNameField = req.DisplayNameField
  322. }
  323. if req.EmailField != "" {
  324. provider.EmailField = req.EmailField
  325. }
  326. if req.WellKnown != nil {
  327. provider.WellKnown = *req.WellKnown
  328. }
  329. if req.AuthStyle != nil {
  330. provider.AuthStyle = *req.AuthStyle
  331. }
  332. if req.AccessPolicy != nil {
  333. provider.AccessPolicy = *req.AccessPolicy
  334. }
  335. if req.AccessDeniedMessage != nil {
  336. provider.AccessDeniedMessage = *req.AccessDeniedMessage
  337. }
  338. if err := model.UpdateCustomOAuthProvider(provider); err != nil {
  339. common.ApiError(c, err)
  340. return
  341. }
  342. // Update the provider in the OAuth registry
  343. if oldSlug != provider.Slug {
  344. oauth.UnregisterCustomProvider(oldSlug)
  345. }
  346. oauth.RegisterOrUpdateCustomProvider(provider)
  347. c.JSON(http.StatusOK, gin.H{
  348. "success": true,
  349. "message": "更新成功",
  350. "data": toCustomOAuthProviderResponse(provider),
  351. })
  352. }
  353. // DeleteCustomOAuthProvider deletes a custom OAuth provider
  354. func DeleteCustomOAuthProvider(c *gin.Context) {
  355. idStr := c.Param("id")
  356. id, err := strconv.Atoi(idStr)
  357. if err != nil {
  358. common.ApiErrorMsg(c, "无效的 ID")
  359. return
  360. }
  361. // Get existing provider to get slug
  362. provider, err := model.GetCustomOAuthProviderById(id)
  363. if err != nil {
  364. common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
  365. return
  366. }
  367. // Check if there are any user bindings
  368. count, err := model.GetBindingCountByProviderId(id)
  369. if err != nil {
  370. common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
  371. common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
  372. return
  373. }
  374. if count > 0 {
  375. common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
  376. return
  377. }
  378. if err := model.DeleteCustomOAuthProvider(id); err != nil {
  379. common.ApiError(c, err)
  380. return
  381. }
  382. // Unregister the provider from the OAuth registry
  383. oauth.UnregisterCustomProvider(provider.Slug)
  384. c.JSON(http.StatusOK, gin.H{
  385. "success": true,
  386. "message": "删除成功",
  387. })
  388. }
  389. // GetUserOAuthBindings returns all OAuth bindings for the current user
  390. func GetUserOAuthBindings(c *gin.Context) {
  391. userId := c.GetInt("id")
  392. if userId == 0 {
  393. common.ApiErrorMsg(c, "未登录")
  394. return
  395. }
  396. bindings, err := model.GetUserOAuthBindingsByUserId(userId)
  397. if err != nil {
  398. common.ApiError(c, err)
  399. return
  400. }
  401. // Build response with provider info
  402. type BindingResponse struct {
  403. ProviderId int `json:"provider_id"`
  404. ProviderName string `json:"provider_name"`
  405. ProviderSlug string `json:"provider_slug"`
  406. ProviderIcon string `json:"provider_icon"`
  407. ProviderUserId string `json:"provider_user_id"`
  408. }
  409. response := make([]BindingResponse, 0)
  410. for _, binding := range bindings {
  411. provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
  412. if err != nil {
  413. continue // Skip if provider not found
  414. }
  415. response = append(response, BindingResponse{
  416. ProviderId: binding.ProviderId,
  417. ProviderName: provider.Name,
  418. ProviderSlug: provider.Slug,
  419. ProviderIcon: provider.Icon,
  420. ProviderUserId: binding.ProviderUserId,
  421. })
  422. }
  423. c.JSON(http.StatusOK, gin.H{
  424. "success": true,
  425. "message": "",
  426. "data": response,
  427. })
  428. }
  429. // UnbindCustomOAuth unbinds a custom OAuth provider from the current user
  430. func UnbindCustomOAuth(c *gin.Context) {
  431. userId := c.GetInt("id")
  432. if userId == 0 {
  433. common.ApiErrorMsg(c, "未登录")
  434. return
  435. }
  436. providerIdStr := c.Param("provider_id")
  437. providerId, err := strconv.Atoi(providerIdStr)
  438. if err != nil {
  439. common.ApiErrorMsg(c, "无效的提供商 ID")
  440. return
  441. }
  442. if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
  443. common.ApiError(c, err)
  444. return
  445. }
  446. c.JSON(http.StatusOK, gin.H{
  447. "success": true,
  448. "message": "解绑成功",
  449. })
  450. }