generic.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668
  1. package oauth
  2. import (
  3. "context"
  4. "encoding/base64"
  5. stdjson "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "net/url"
  11. "regexp"
  12. "strconv"
  13. "strings"
  14. "time"
  15. "github.com/QuantumNous/new-api/common"
  16. "github.com/QuantumNous/new-api/i18n"
  17. "github.com/QuantumNous/new-api/logger"
  18. "github.com/QuantumNous/new-api/model"
  19. "github.com/QuantumNous/new-api/setting/system_setting"
  20. "github.com/gin-gonic/gin"
  21. "github.com/samber/lo"
  22. "github.com/tidwall/gjson"
  23. )
  24. // AuthStyle defines how to send client credentials
  25. const (
  26. AuthStyleAutoDetect = 0 // Auto-detect based on server response
  27. AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters
  28. AuthStyleInHeader = 2 // Send as Basic Auth header
  29. )
  30. // GenericOAuthProvider implements OAuth for custom/generic OAuth providers
  31. type GenericOAuthProvider struct {
  32. config *model.CustomOAuthProvider
  33. }
  34. type accessPolicy struct {
  35. Logic string `json:"logic"`
  36. Conditions []accessCondition `json:"conditions"`
  37. Groups []accessPolicy `json:"groups"`
  38. }
  39. type accessCondition struct {
  40. Field string `json:"field"`
  41. Op string `json:"op"`
  42. Value any `json:"value"`
  43. }
  44. type accessPolicyFailure struct {
  45. Field string
  46. Op string
  47. Expected any
  48. Current any
  49. }
  50. var supportedAccessPolicyOps = []string{
  51. "eq",
  52. "ne",
  53. "gt",
  54. "gte",
  55. "lt",
  56. "lte",
  57. "in",
  58. "not_in",
  59. "contains",
  60. "not_contains",
  61. "exists",
  62. "not_exists",
  63. }
  64. // NewGenericOAuthProvider creates a new generic OAuth provider from config
  65. func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
  66. return &GenericOAuthProvider{config: config}
  67. }
  68. func (p *GenericOAuthProvider) GetName() string {
  69. return p.config.Name
  70. }
  71. func (p *GenericOAuthProvider) IsEnabled() bool {
  72. return p.config.Enabled
  73. }
  74. func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
  75. return p.config
  76. }
  77. func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
  78. if code == "" {
  79. return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
  80. }
  81. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
  82. redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
  83. values := url.Values{}
  84. values.Set("grant_type", "authorization_code")
  85. values.Set("code", code)
  86. values.Set("redirect_uri", redirectUri)
  87. // Determine auth style
  88. authStyle := p.config.AuthStyle
  89. if authStyle == AuthStyleAutoDetect {
  90. // Default to params style for most OAuth servers
  91. authStyle = AuthStyleInParams
  92. }
  93. var req *http.Request
  94. var err error
  95. if authStyle == AuthStyleInParams {
  96. values.Set("client_id", p.config.ClientId)
  97. values.Set("client_secret", p.config.ClientSecret)
  98. }
  99. req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
  100. if err != nil {
  101. return nil, err
  102. }
  103. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  104. req.Header.Set("Accept", "application/json")
  105. if authStyle == AuthStyleInHeader {
  106. // Basic Auth
  107. credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
  108. req.Header.Set("Authorization", "Basic "+credentials)
  109. }
  110. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
  111. p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
  112. client := http.Client{
  113. Timeout: 20 * time.Second,
  114. }
  115. res, err := client.Do(req)
  116. if err != nil {
  117. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
  118. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  119. }
  120. defer res.Body.Close()
  121. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
  122. body, err := io.ReadAll(res.Body)
  123. if err != nil {
  124. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
  125. return nil, err
  126. }
  127. bodyStr := string(body)
  128. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  129. // Try to parse as JSON first
  130. var tokenResponse struct {
  131. AccessToken string `json:"access_token"`
  132. TokenType string `json:"token_type"`
  133. RefreshToken string `json:"refresh_token"`
  134. ExpiresIn int `json:"expires_in"`
  135. Scope string `json:"scope"`
  136. IDToken string `json:"id_token"`
  137. Error string `json:"error"`
  138. ErrorDesc string `json:"error_description"`
  139. }
  140. if err := common.Unmarshal(body, &tokenResponse); err != nil {
  141. // Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
  142. parsedValues, parseErr := url.ParseQuery(bodyStr)
  143. if parseErr != nil {
  144. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
  145. return nil, err
  146. }
  147. tokenResponse.AccessToken = parsedValues.Get("access_token")
  148. tokenResponse.TokenType = parsedValues.Get("token_type")
  149. tokenResponse.Scope = parsedValues.Get("scope")
  150. }
  151. if tokenResponse.Error != "" {
  152. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
  153. p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
  154. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
  155. }
  156. if tokenResponse.AccessToken == "" {
  157. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
  158. return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
  159. }
  160. logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
  161. return &OAuthToken{
  162. AccessToken: tokenResponse.AccessToken,
  163. TokenType: tokenResponse.TokenType,
  164. RefreshToken: tokenResponse.RefreshToken,
  165. ExpiresIn: tokenResponse.ExpiresIn,
  166. Scope: tokenResponse.Scope,
  167. IDToken: tokenResponse.IDToken,
  168. }, nil
  169. }
  170. func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
  171. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
  172. req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
  173. if err != nil {
  174. return nil, err
  175. }
  176. // Set authorization header
  177. tokenType := token.TokenType
  178. if tokenType == "" {
  179. tokenType = "Bearer"
  180. }
  181. req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
  182. req.Header.Set("Accept", "application/json")
  183. client := http.Client{
  184. Timeout: 20 * time.Second,
  185. }
  186. res, err := client.Do(req)
  187. if err != nil {
  188. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
  189. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
  190. }
  191. defer res.Body.Close()
  192. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
  193. if res.StatusCode != http.StatusOK {
  194. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
  195. return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
  196. }
  197. body, err := io.ReadAll(res.Body)
  198. if err != nil {
  199. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
  200. return nil, err
  201. }
  202. bodyStr := string(body)
  203. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
  204. // Extract fields using gjson (supports JSONPath-like syntax)
  205. userId := gjson.Get(bodyStr, p.config.UserIdField).String()
  206. username := gjson.Get(bodyStr, p.config.UsernameField).String()
  207. displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
  208. email := gjson.Get(bodyStr, p.config.EmailField).String()
  209. // If user ID field returns a number, convert it
  210. if userId == "" {
  211. // Try to get as number
  212. userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
  213. if userIdNum.Exists() {
  214. userId = userIdNum.Raw
  215. // Remove quotes if present
  216. userId = strings.Trim(userId, "\"")
  217. }
  218. }
  219. if userId == "" {
  220. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
  221. return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
  222. }
  223. logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
  224. p.config.Slug, userId, username, displayName, email)
  225. policyRaw := strings.TrimSpace(p.config.AccessPolicy)
  226. if policyRaw != "" {
  227. policy, err := parseAccessPolicy(policyRaw)
  228. if err != nil {
  229. logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
  230. return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
  231. }
  232. allowed, failure := evaluateAccessPolicy(bodyStr, policy)
  233. if !allowed {
  234. message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
  235. logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
  236. p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
  237. return nil, &AccessDeniedError{Message: message}
  238. }
  239. }
  240. return &OAuthUser{
  241. ProviderUserID: userId,
  242. Username: username,
  243. DisplayName: displayName,
  244. Email: email,
  245. Extra: map[string]any{
  246. "provider": p.config.Slug,
  247. },
  248. }, nil
  249. }
  250. func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
  251. return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
  252. }
  253. func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
  254. foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
  255. if err != nil {
  256. return err
  257. }
  258. *user = *foundUser
  259. return nil
  260. }
  261. func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
  262. // For generic providers, we store the binding in user_oauth_bindings table
  263. // This is handled separately in the OAuth controller
  264. }
  265. func (p *GenericOAuthProvider) GetProviderPrefix() string {
  266. return p.config.Slug + "_"
  267. }
  268. // GetProviderId returns the provider ID for binding purposes
  269. func (p *GenericOAuthProvider) GetProviderId() int {
  270. return p.config.Id
  271. }
  272. // IsGenericProvider returns true for generic providers
  273. func (p *GenericOAuthProvider) IsGenericProvider() bool {
  274. return true
  275. }
  276. func parseAccessPolicy(raw string) (*accessPolicy, error) {
  277. var policy accessPolicy
  278. if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
  279. return nil, err
  280. }
  281. if err := validateAccessPolicy(&policy); err != nil {
  282. return nil, err
  283. }
  284. return &policy, nil
  285. }
  286. func validateAccessPolicy(policy *accessPolicy) error {
  287. if policy == nil {
  288. return errors.New("policy is nil")
  289. }
  290. logic := strings.ToLower(strings.TrimSpace(policy.Logic))
  291. if logic == "" {
  292. logic = "and"
  293. }
  294. if !lo.Contains([]string{"and", "or"}, logic) {
  295. return fmt.Errorf("unsupported policy logic: %s", logic)
  296. }
  297. policy.Logic = logic
  298. if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
  299. return errors.New("policy requires at least one condition or group")
  300. }
  301. for index := range policy.Conditions {
  302. if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
  303. return err
  304. }
  305. }
  306. for index := range policy.Groups {
  307. if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
  308. return fmt.Errorf("invalid policy group[%d]: %w", index, err)
  309. }
  310. }
  311. return nil
  312. }
  313. func validateAccessCondition(condition *accessCondition, index int) error {
  314. if condition == nil {
  315. return fmt.Errorf("condition[%d] is nil", index)
  316. }
  317. condition.Field = strings.TrimSpace(condition.Field)
  318. if condition.Field == "" {
  319. return fmt.Errorf("condition[%d].field is required", index)
  320. }
  321. condition.Op = normalizePolicyOp(condition.Op)
  322. if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
  323. return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
  324. }
  325. if lo.Contains([]string{"in", "not_in"}, condition.Op) {
  326. if _, ok := condition.Value.([]any); !ok {
  327. return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
  328. }
  329. }
  330. return nil
  331. }
  332. func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
  333. if policy == nil {
  334. return true, nil
  335. }
  336. logic := strings.ToLower(strings.TrimSpace(policy.Logic))
  337. if logic == "" {
  338. logic = "and"
  339. }
  340. hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
  341. if !hasAny {
  342. return true, nil
  343. }
  344. if logic == "or" {
  345. var firstFailure *accessPolicyFailure
  346. for _, cond := range policy.Conditions {
  347. ok, failure := evaluateAccessCondition(body, cond)
  348. if ok {
  349. return true, nil
  350. }
  351. if firstFailure == nil {
  352. firstFailure = failure
  353. }
  354. }
  355. for _, group := range policy.Groups {
  356. ok, failure := evaluateAccessPolicy(body, &group)
  357. if ok {
  358. return true, nil
  359. }
  360. if firstFailure == nil {
  361. firstFailure = failure
  362. }
  363. }
  364. return false, firstFailure
  365. }
  366. for _, cond := range policy.Conditions {
  367. ok, failure := evaluateAccessCondition(body, cond)
  368. if !ok {
  369. return false, failure
  370. }
  371. }
  372. for _, group := range policy.Groups {
  373. ok, failure := evaluateAccessPolicy(body, &group)
  374. if !ok {
  375. return false, failure
  376. }
  377. }
  378. return true, nil
  379. }
  380. func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
  381. path := cond.Field
  382. op := cond.Op
  383. result := gjson.Get(body, path)
  384. current := gjsonResultToValue(result)
  385. failure := &accessPolicyFailure{
  386. Field: path,
  387. Op: op,
  388. Expected: cond.Value,
  389. Current: current,
  390. }
  391. switch op {
  392. case "exists":
  393. return result.Exists(), failure
  394. case "not_exists":
  395. return !result.Exists(), failure
  396. case "eq":
  397. return compareAny(current, cond.Value) == 0, failure
  398. case "ne":
  399. return compareAny(current, cond.Value) != 0, failure
  400. case "gt":
  401. return compareAny(current, cond.Value) > 0, failure
  402. case "gte":
  403. return compareAny(current, cond.Value) >= 0, failure
  404. case "lt":
  405. return compareAny(current, cond.Value) < 0, failure
  406. case "lte":
  407. return compareAny(current, cond.Value) <= 0, failure
  408. case "in":
  409. return valueInSlice(current, cond.Value), failure
  410. case "not_in":
  411. return !valueInSlice(current, cond.Value), failure
  412. case "contains":
  413. return containsValue(current, cond.Value), failure
  414. case "not_contains":
  415. return !containsValue(current, cond.Value), failure
  416. default:
  417. return false, failure
  418. }
  419. }
  420. func normalizePolicyOp(op string) string {
  421. return strings.ToLower(strings.TrimSpace(op))
  422. }
  423. func gjsonResultToValue(result gjson.Result) any {
  424. if !result.Exists() {
  425. return nil
  426. }
  427. if result.IsArray() {
  428. arr := result.Array()
  429. values := make([]any, 0, len(arr))
  430. for _, item := range arr {
  431. values = append(values, gjsonResultToValue(item))
  432. }
  433. return values
  434. }
  435. switch result.Type {
  436. case gjson.Null:
  437. return nil
  438. case gjson.True:
  439. return true
  440. case gjson.False:
  441. return false
  442. case gjson.Number:
  443. return result.Num
  444. case gjson.String:
  445. return result.String()
  446. case gjson.JSON:
  447. var data any
  448. if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
  449. return data
  450. }
  451. return result.Raw
  452. default:
  453. return result.Value()
  454. }
  455. }
  456. func compareAny(left any, right any) int {
  457. if lf, ok := toFloat(left); ok {
  458. if rf, ok2 := toFloat(right); ok2 {
  459. switch {
  460. case lf < rf:
  461. return -1
  462. case lf > rf:
  463. return 1
  464. default:
  465. return 0
  466. }
  467. }
  468. }
  469. ls := strings.TrimSpace(fmt.Sprint(left))
  470. rs := strings.TrimSpace(fmt.Sprint(right))
  471. switch {
  472. case ls < rs:
  473. return -1
  474. case ls > rs:
  475. return 1
  476. default:
  477. return 0
  478. }
  479. }
  480. func toFloat(v any) (float64, bool) {
  481. switch value := v.(type) {
  482. case float64:
  483. return value, true
  484. case float32:
  485. return float64(value), true
  486. case int:
  487. return float64(value), true
  488. case int8:
  489. return float64(value), true
  490. case int16:
  491. return float64(value), true
  492. case int32:
  493. return float64(value), true
  494. case int64:
  495. return float64(value), true
  496. case uint:
  497. return float64(value), true
  498. case uint8:
  499. return float64(value), true
  500. case uint16:
  501. return float64(value), true
  502. case uint32:
  503. return float64(value), true
  504. case uint64:
  505. return float64(value), true
  506. case stdjson.Number:
  507. n, err := value.Float64()
  508. if err == nil {
  509. return n, true
  510. }
  511. case string:
  512. n, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
  513. if err == nil {
  514. return n, true
  515. }
  516. }
  517. return 0, false
  518. }
  519. func valueInSlice(current any, expected any) bool {
  520. list, ok := expected.([]any)
  521. if !ok {
  522. return false
  523. }
  524. return lo.ContainsBy(list, func(item any) bool {
  525. return compareAny(current, item) == 0
  526. })
  527. }
  528. func containsValue(current any, expected any) bool {
  529. switch value := current.(type) {
  530. case string:
  531. target := strings.TrimSpace(fmt.Sprint(expected))
  532. return strings.Contains(value, target)
  533. case []any:
  534. return lo.ContainsBy(value, func(item any) bool {
  535. return compareAny(item, expected) == 0
  536. })
  537. }
  538. return false
  539. }
  540. func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
  541. defaultMessage := "Access denied: your account does not meet this provider's access requirements."
  542. message := strings.TrimSpace(template)
  543. if message == "" {
  544. return defaultMessage
  545. }
  546. if failure == nil {
  547. failure = &accessPolicyFailure{}
  548. }
  549. replacements := map[string]string{
  550. "{{provider}}": providerName,
  551. "{{field}}": failure.Field,
  552. "{{op}}": failure.Op,
  553. "{{required}}": fmt.Sprint(failure.Expected),
  554. "{{current}}": fmt.Sprint(failure.Current),
  555. }
  556. for key, value := range replacements {
  557. message = strings.ReplaceAll(message, key, value)
  558. }
  559. currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
  560. message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
  561. match := currentPattern.FindStringSubmatch(token)
  562. if len(match) != 2 {
  563. return ""
  564. }
  565. path := strings.TrimSpace(match[1])
  566. if path == "" {
  567. return ""
  568. }
  569. return strings.TrimSpace(gjson.Get(body, path).String())
  570. })
  571. requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
  572. message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
  573. match := requiredPattern.FindStringSubmatch(token)
  574. if len(match) != 2 {
  575. return ""
  576. }
  577. path := strings.TrimSpace(match[1])
  578. if failure.Field == path {
  579. return fmt.Sprint(failure.Expected)
  580. }
  581. return ""
  582. })
  583. return strings.TrimSpace(message)
  584. }