custom_oauth_provider.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/QuantumNous/new-api/common"
  8. )
  9. type accessPolicyPayload struct {
  10. Logic string `json:"logic"`
  11. Conditions []accessConditionItem `json:"conditions"`
  12. Groups []accessPolicyPayload `json:"groups"`
  13. }
  14. type accessConditionItem struct {
  15. Field string `json:"field"`
  16. Op string `json:"op"`
  17. Value any `json:"value"`
  18. }
  19. var supportedAccessPolicyOps = map[string]struct{}{
  20. "eq": {},
  21. "ne": {},
  22. "gt": {},
  23. "gte": {},
  24. "lt": {},
  25. "lte": {},
  26. "in": {},
  27. "not_in": {},
  28. "contains": {},
  29. "not_contains": {},
  30. "exists": {},
  31. "not_exists": {},
  32. }
  33. // CustomOAuthProvider stores configuration for custom OAuth providers
  34. type CustomOAuthProvider struct {
  35. Id int `json:"id" gorm:"primaryKey"`
  36. Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
  37. Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
  38. Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons
  39. Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
  40. ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
  41. ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
  42. AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
  43. TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
  44. UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
  45. Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
  46. // Field mapping configuration (supports JSONPath via gjson)
  47. UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
  48. UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
  49. DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
  50. EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
  51. // Advanced options
  52. WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
  53. AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
  54. AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info
  55. AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied
  56. CreatedAt time.Time `json:"created_at"`
  57. UpdatedAt time.Time `json:"updated_at"`
  58. }
  59. func (CustomOAuthProvider) TableName() string {
  60. return "custom_oauth_providers"
  61. }
  62. // GetAllCustomOAuthProviders returns all custom OAuth providers
  63. func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
  64. var providers []*CustomOAuthProvider
  65. err := DB.Order("id asc").Find(&providers).Error
  66. return providers, err
  67. }
  68. // GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
  69. func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
  70. var providers []*CustomOAuthProvider
  71. err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
  72. return providers, err
  73. }
  74. // GetCustomOAuthProviderById returns a custom OAuth provider by ID
  75. func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
  76. var provider CustomOAuthProvider
  77. err := DB.First(&provider, id).Error
  78. if err != nil {
  79. return nil, err
  80. }
  81. return &provider, nil
  82. }
  83. // GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
  84. func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
  85. var provider CustomOAuthProvider
  86. err := DB.Where("slug = ?", slug).First(&provider).Error
  87. if err != nil {
  88. return nil, err
  89. }
  90. return &provider, nil
  91. }
  92. // CreateCustomOAuthProvider creates a new custom OAuth provider
  93. func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  94. if err := validateCustomOAuthProvider(provider); err != nil {
  95. return err
  96. }
  97. return DB.Create(provider).Error
  98. }
  99. // UpdateCustomOAuthProvider updates an existing custom OAuth provider
  100. func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  101. if err := validateCustomOAuthProvider(provider); err != nil {
  102. return err
  103. }
  104. return DB.Save(provider).Error
  105. }
  106. // DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
  107. func DeleteCustomOAuthProvider(id int) error {
  108. // First, delete all user bindings for this provider
  109. if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
  110. return err
  111. }
  112. return DB.Delete(&CustomOAuthProvider{}, id).Error
  113. }
  114. // IsSlugTaken checks if a slug is already taken by another provider
  115. // Returns true on DB errors (fail-closed) to prevent slug conflicts
  116. func IsSlugTaken(slug string, excludeId int) bool {
  117. var count int64
  118. query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
  119. if excludeId > 0 {
  120. query = query.Where("id != ?", excludeId)
  121. }
  122. res := query.Count(&count)
  123. if res.Error != nil {
  124. // Fail-closed: treat DB errors as slug being taken to prevent conflicts
  125. return true
  126. }
  127. return count > 0
  128. }
  129. // validateCustomOAuthProvider validates a custom OAuth provider configuration
  130. func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
  131. if provider.Name == "" {
  132. return errors.New("provider name is required")
  133. }
  134. if provider.Slug == "" {
  135. return errors.New("provider slug is required")
  136. }
  137. // Slug must be lowercase and contain only alphanumeric characters and hyphens
  138. slug := strings.ToLower(provider.Slug)
  139. for _, c := range slug {
  140. if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
  141. return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
  142. }
  143. }
  144. provider.Slug = slug
  145. if provider.ClientId == "" {
  146. return errors.New("client ID is required")
  147. }
  148. if provider.AuthorizationEndpoint == "" {
  149. return errors.New("authorization endpoint is required")
  150. }
  151. if provider.TokenEndpoint == "" {
  152. return errors.New("token endpoint is required")
  153. }
  154. if provider.UserInfoEndpoint == "" {
  155. return errors.New("user info endpoint is required")
  156. }
  157. // Set defaults for field mappings if empty
  158. if provider.UserIdField == "" {
  159. provider.UserIdField = "sub"
  160. }
  161. if provider.UsernameField == "" {
  162. provider.UsernameField = "preferred_username"
  163. }
  164. if provider.DisplayNameField == "" {
  165. provider.DisplayNameField = "name"
  166. }
  167. if provider.EmailField == "" {
  168. provider.EmailField = "email"
  169. }
  170. if provider.Scopes == "" {
  171. provider.Scopes = "openid profile email"
  172. }
  173. if strings.TrimSpace(provider.AccessPolicy) != "" {
  174. var policy accessPolicyPayload
  175. if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil {
  176. return errors.New("access_policy must be valid JSON")
  177. }
  178. if err := validateAccessPolicyPayload(&policy); err != nil {
  179. return fmt.Errorf("access_policy is invalid: %w", err)
  180. }
  181. }
  182. return nil
  183. }
  184. func validateAccessPolicyPayload(policy *accessPolicyPayload) error {
  185. if policy == nil {
  186. return errors.New("policy is nil")
  187. }
  188. logic := strings.ToLower(strings.TrimSpace(policy.Logic))
  189. if logic == "" {
  190. logic = "and"
  191. }
  192. if logic != "and" && logic != "or" {
  193. return fmt.Errorf("unsupported logic: %s", logic)
  194. }
  195. if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
  196. return errors.New("policy requires at least one condition or group")
  197. }
  198. for index, condition := range policy.Conditions {
  199. field := strings.TrimSpace(condition.Field)
  200. if field == "" {
  201. return fmt.Errorf("condition[%d].field is required", index)
  202. }
  203. op := strings.ToLower(strings.TrimSpace(condition.Op))
  204. if _, ok := supportedAccessPolicyOps[op]; !ok {
  205. return fmt.Errorf("condition[%d].op is unsupported: %s", index, op)
  206. }
  207. if op == "in" || op == "not_in" {
  208. if _, ok := condition.Value.([]any); !ok {
  209. return fmt.Errorf("condition[%d].value must be an array for op %s", index, op)
  210. }
  211. }
  212. }
  213. for index := range policy.Groups {
  214. if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil {
  215. return fmt.Errorf("group[%d]: %w", index, err)
  216. }
  217. }
  218. return nil
  219. }