override.go 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552
  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/types"
  11. "github.com/samber/lo"
  12. "github.com/tidwall/gjson"
  13. "github.com/tidwall/sjson"
  14. )
  var (
  	// negativeIndexRegexp finds path segments such as ".-1"; the capture group
  	// holds the signed index text ("-1").
  	negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)

  	// errSourceHeaderNotFound is wrapped by copyHeaderInContext when the source
  	// header cannot be found; pass_headers tolerates it via errors.Is.
  	errSourceHeaderNotFound = errors.New("source header does not exist")
  )

  // Keys used inside the param-override context map.
  const (
  	paramOverrideContextRequestHeaders           = "request_headers"
  	paramOverrideContextRequestHeadersRaw        = "request_headers_raw"
  	paramOverrideContextHeaderOverride           = "header_override"
  	paramOverrideContextHeaderOverrideNormalized = "header_override_normalized"
  )
  // ConditionOperation is one predicate evaluated before a ParamOperation runs.
  type ConditionOperation struct {
  	Path           string      `json:"path"`             // JSON path; looked up in the request body first, then in the condition context
  	Mode           string      `json:"mode"`             // full, prefix, suffix, contains, gt, gte, lt, lte
  	Value          interface{} `json:"value"`            // value the looked-up field is compared against
  	Invert         bool        `json:"invert"`           // when true the comparison result is negated (not applied to a missing key)
  	PassMissingKey bool        `json:"pass_missing_key"` // outcome when the path is absent: true passes, false fails
  }

  // ParamOperation is one step of an "operations"-style param override.
  type ParamOperation struct {
  	Path string `json:"path"` // target JSON path; header name for the *_header modes
  	// Mode is one of: delete, set, move, copy, prepend, append, trim_prefix,
  	// trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper,
  	// replace, regex_replace, return_error, prune_objects, set_header,
  	// delete_header, copy_header, move_header, pass_headers, sync_fields.
  	Mode       string               `json:"mode"`
  	Value      interface{}          `json:"value"`                // operand for the mode
  	KeepOrigin bool                 `json:"keep_origin"`          // when true, an existing target value is preserved
  	From       string               `json:"from,omitempty"`       // source path/header, or search pattern for replace modes
  	To         string               `json:"to,omitempty"`         // destination path/header, or replacement for replace modes
  	Conditions []ConditionOperation `json:"conditions,omitempty"` // predicates gating this operation
  	Logic      string               `json:"logic,omitempty"`      // AND or OR; anything but AND behaves as OR
  }
  // ParamOverrideReturnError is raised by a "return_error" operation to abort the
  // request with a caller-defined error response.
  type ParamOverrideReturnError struct {
  	Message    string
  	StatusCode int
  	Code       string
  	Type       string
  	SkipRetry  bool
  }

  // Error implements error. A nil receiver or an empty Message yields a generic
  // fallback text.
  func (e *ParamOverrideReturnError) Error() string {
  	const fallback = "param override return error"
  	if e != nil && e.Message != "" {
  		return e.Message
  	}
  	return fallback
  }

  // AsParamOverrideReturnError reports whether err (or anything it wraps) is a
  // *ParamOverrideReturnError and returns it when so.
  func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
  	var target *ParamOverrideReturnError
  	if err != nil && errors.As(err, &target) {
  		return target, true
  	}
  	return nil, false
  }
  66. func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
  67. if err == nil {
  68. return types.NewError(
  69. errors.New("param override return error is nil"),
  70. types.ErrorCodeChannelParamOverrideInvalid,
  71. types.ErrOptionWithSkipRetry(),
  72. )
  73. }
  74. statusCode := err.StatusCode
  75. if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
  76. statusCode = http.StatusBadRequest
  77. }
  78. errorCode := err.Code
  79. if strings.TrimSpace(errorCode) == "" {
  80. errorCode = string(types.ErrorCodeInvalidRequest)
  81. }
  82. errorType := err.Type
  83. if strings.TrimSpace(errorType) == "" {
  84. errorType = "invalid_request_error"
  85. }
  86. message := strings.TrimSpace(err.Message)
  87. if message == "" {
  88. message = "request blocked by param override"
  89. }
  90. opts := make([]types.NewAPIErrorOptions, 0, 1)
  91. if err.SkipRetry {
  92. opts = append(opts, types.ErrOptionWithSkipRetry())
  93. }
  94. return types.WithOpenAIError(types.OpenAIError{
  95. Message: message,
  96. Type: errorType,
  97. Code: errorCode,
  98. }, statusCode, opts...)
  99. }
  100. func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
  101. if len(paramOverride) == 0 {
  102. return jsonData, nil
  103. }
  104. // 尝试断言为操作格式
  105. if operations, ok := tryParseOperations(paramOverride); ok {
  106. // 使用新方法
  107. result, err := applyOperations(string(jsonData), operations, conditionContext)
  108. return []byte(result), err
  109. }
  110. // 直接使用旧方法
  111. return applyOperationsLegacy(jsonData, paramOverride)
  112. }
  113. func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
  114. paramOverride := getParamOverrideMap(info)
  115. if len(paramOverride) == 0 {
  116. return jsonData, nil
  117. }
  118. overrideCtx := BuildParamOverrideContext(info)
  119. result, err := ApplyParamOverride(jsonData, paramOverride, overrideCtx)
  120. if err != nil {
  121. return nil, err
  122. }
  123. syncRuntimeHeaderOverrideFromContext(info, overrideCtx)
  124. return result, nil
  125. }
  126. func getParamOverrideMap(info *RelayInfo) map[string]interface{} {
  127. if info == nil || info.ChannelMeta == nil {
  128. return nil
  129. }
  130. return info.ChannelMeta.ParamOverride
  131. }
  132. func getHeaderOverrideMap(info *RelayInfo) map[string]interface{} {
  133. if info == nil || info.ChannelMeta == nil {
  134. return nil
  135. }
  136. return info.ChannelMeta.HeadersOverride
  137. }
  138. func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
  139. // 检查是否包含 "operations" 字段
  140. if opsValue, exists := paramOverride["operations"]; exists {
  141. if opsSlice, ok := opsValue.([]interface{}); ok {
  142. var operations []ParamOperation
  143. for _, op := range opsSlice {
  144. if opMap, ok := op.(map[string]interface{}); ok {
  145. operation := ParamOperation{}
  146. // 断言必要字段
  147. if path, ok := opMap["path"].(string); ok {
  148. operation.Path = path
  149. }
  150. if mode, ok := opMap["mode"].(string); ok {
  151. operation.Mode = mode
  152. } else {
  153. return nil, false // mode 是必需的
  154. }
  155. // 可选字段
  156. if value, exists := opMap["value"]; exists {
  157. operation.Value = value
  158. }
  159. if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
  160. operation.KeepOrigin = keepOrigin
  161. }
  162. if from, ok := opMap["from"].(string); ok {
  163. operation.From = from
  164. }
  165. if to, ok := opMap["to"].(string); ok {
  166. operation.To = to
  167. }
  168. if logic, ok := opMap["logic"].(string); ok {
  169. operation.Logic = logic
  170. } else {
  171. operation.Logic = "OR" // 默认为OR
  172. }
  173. // 解析条件
  174. if conditions, exists := opMap["conditions"]; exists {
  175. parsedConditions, err := parseConditionOperations(conditions)
  176. if err != nil {
  177. return nil, false
  178. }
  179. operation.Conditions = append(operation.Conditions, parsedConditions...)
  180. }
  181. operations = append(operations, operation)
  182. } else {
  183. return nil, false
  184. }
  185. }
  186. return operations, true
  187. }
  188. }
  189. return nil, false
  190. }
  191. func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
  192. if len(conditions) == 0 {
  193. return true, nil // 没有条件,直接通过
  194. }
  195. results := make([]bool, len(conditions))
  196. for i, condition := range conditions {
  197. result, err := checkSingleCondition(jsonStr, contextJSON, condition)
  198. if err != nil {
  199. return false, err
  200. }
  201. results[i] = result
  202. }
  203. if strings.ToUpper(logic) == "AND" {
  204. return lo.EveryBy(results, func(item bool) bool { return item }), nil
  205. }
  206. return lo.SomeBy(results, func(item bool) bool { return item }), nil
  207. }
  208. func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
  209. // 处理负数索引
  210. path := processNegativeIndex(jsonStr, condition.Path)
  211. value := gjson.Get(jsonStr, path)
  212. if !value.Exists() && contextJSON != "" {
  213. value = gjson.Get(contextJSON, condition.Path)
  214. }
  215. if !value.Exists() {
  216. if condition.PassMissingKey {
  217. return true, nil
  218. }
  219. return false, nil
  220. }
  221. // 利用gjson的类型解析
  222. targetBytes, err := common.Marshal(condition.Value)
  223. if err != nil {
  224. return false, fmt.Errorf("failed to marshal condition value: %v", err)
  225. }
  226. targetValue := gjson.ParseBytes(targetBytes)
  227. result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
  228. if err != nil {
  229. return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
  230. }
  231. if condition.Invert {
  232. result = !result
  233. }
  234. return result, nil
  235. }
  236. func processNegativeIndex(jsonStr string, path string) string {
  237. matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
  238. if len(matches) == 0 {
  239. return path
  240. }
  241. result := path
  242. for _, match := range matches {
  243. negIndex := match[1]
  244. index, _ := strconv.Atoi(negIndex)
  245. arrayPath := strings.Split(path, negIndex)[0]
  246. if strings.HasSuffix(arrayPath, ".") {
  247. arrayPath = arrayPath[:len(arrayPath)-1]
  248. }
  249. array := gjson.Get(jsonStr, arrayPath)
  250. if array.IsArray() {
  251. length := len(array.Array())
  252. actualIndex := length + index
  253. if actualIndex >= 0 && actualIndex < length {
  254. result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
  255. }
  256. }
  257. }
  258. return result
  259. }
  260. // compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
  261. func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
  262. switch mode {
  263. case "full":
  264. return compareEqual(jsonValue, targetValue)
  265. case "prefix":
  266. return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
  267. case "suffix":
  268. return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
  269. case "contains":
  270. return strings.Contains(jsonValue.String(), targetValue.String()), nil
  271. case "gt":
  272. return compareNumeric(jsonValue, targetValue, "gt")
  273. case "gte":
  274. return compareNumeric(jsonValue, targetValue, "gte")
  275. case "lt":
  276. return compareNumeric(jsonValue, targetValue, "lt")
  277. case "lte":
  278. return compareNumeric(jsonValue, targetValue, "lte")
  279. default:
  280. return false, fmt.Errorf("unsupported comparison mode: %s", mode)
  281. }
  282. }
  283. func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
  284. // 对null值特殊处理:两个都是null返回true,一个是null另一个不是返回false
  285. if jsonValue.Type == gjson.Null || targetValue.Type == gjson.Null {
  286. return jsonValue.Type == gjson.Null && targetValue.Type == gjson.Null, nil
  287. }
  288. // 对布尔值特殊处理
  289. if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
  290. (targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
  291. return jsonValue.Bool() == targetValue.Bool(), nil
  292. }
  293. // 如果类型不同,报错
  294. if jsonValue.Type != targetValue.Type {
  295. return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
  296. }
  297. switch jsonValue.Type {
  298. case gjson.True, gjson.False:
  299. return jsonValue.Bool() == targetValue.Bool(), nil
  300. case gjson.Number:
  301. return jsonValue.Num == targetValue.Num, nil
  302. case gjson.String:
  303. return jsonValue.String() == targetValue.String(), nil
  304. default:
  305. return jsonValue.String() == targetValue.String(), nil
  306. }
  307. }
  308. func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
  309. // 只有数字类型才支持数值比较
  310. if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
  311. return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
  312. }
  313. jsonNum := jsonValue.Num
  314. targetNum := targetValue.Num
  315. switch operator {
  316. case "gt":
  317. return jsonNum > targetNum, nil
  318. case "gte":
  319. return jsonNum >= targetNum, nil
  320. case "lt":
  321. return jsonNum < targetNum, nil
  322. case "lte":
  323. return jsonNum <= targetNum, nil
  324. default:
  325. return false, fmt.Errorf("unsupported numeric operator: %s", operator)
  326. }
  327. }
  328. // applyOperationsLegacy 原参数覆盖方法
  329. func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
  330. reqMap := make(map[string]interface{})
  331. err := common.Unmarshal(jsonData, &reqMap)
  332. if err != nil {
  333. return nil, err
  334. }
  335. for key, value := range paramOverride {
  336. reqMap[key] = value
  337. }
  338. return common.Marshal(reqMap)
  339. }
  340. func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
  341. context := ensureContextMap(conditionContext)
  342. contextJSON, err := marshalContextJSON(context)
  343. if err != nil {
  344. return "", fmt.Errorf("failed to marshal condition context: %v", err)
  345. }
  346. result := jsonStr
  347. for _, op := range operations {
  348. // 检查条件是否满足
  349. ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
  350. if err != nil {
  351. return "", err
  352. }
  353. if !ok {
  354. continue // 条件不满足,跳过当前操作
  355. }
  356. // 处理路径中的负数索引
  357. opPath := processNegativeIndex(result, op.Path)
  358. switch op.Mode {
  359. case "delete":
  360. result, err = sjson.Delete(result, opPath)
  361. case "set":
  362. if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
  363. continue
  364. }
  365. result, err = sjson.Set(result, opPath, op.Value)
  366. case "move":
  367. opFrom := processNegativeIndex(result, op.From)
  368. opTo := processNegativeIndex(result, op.To)
  369. result, err = moveValue(result, opFrom, opTo)
  370. case "copy":
  371. if op.From == "" || op.To == "" {
  372. return "", fmt.Errorf("copy from/to is required")
  373. }
  374. opFrom := processNegativeIndex(result, op.From)
  375. opTo := processNegativeIndex(result, op.To)
  376. result, err = copyValue(result, opFrom, opTo)
  377. case "prepend":
  378. result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
  379. case "append":
  380. result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
  381. case "trim_prefix":
  382. result, err = trimStringValue(result, opPath, op.Value, true)
  383. case "trim_suffix":
  384. result, err = trimStringValue(result, opPath, op.Value, false)
  385. case "ensure_prefix":
  386. result, err = ensureStringAffix(result, opPath, op.Value, true)
  387. case "ensure_suffix":
  388. result, err = ensureStringAffix(result, opPath, op.Value, false)
  389. case "trim_space":
  390. result, err = transformStringValue(result, opPath, strings.TrimSpace)
  391. case "to_lower":
  392. result, err = transformStringValue(result, opPath, strings.ToLower)
  393. case "to_upper":
  394. result, err = transformStringValue(result, opPath, strings.ToUpper)
  395. case "replace":
  396. result, err = replaceStringValue(result, opPath, op.From, op.To)
  397. case "regex_replace":
  398. result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
  399. case "return_error":
  400. returnErr, parseErr := parseParamOverrideReturnError(op.Value)
  401. if parseErr != nil {
  402. return "", parseErr
  403. }
  404. return "", returnErr
  405. case "prune_objects":
  406. result, err = pruneObjects(result, opPath, contextJSON, op.Value)
  407. case "set_header":
  408. err = setHeaderOverrideInContext(context, op.Path, op.Value, op.KeepOrigin)
  409. if err == nil {
  410. contextJSON, err = marshalContextJSON(context)
  411. }
  412. case "delete_header":
  413. err = deleteHeaderOverrideInContext(context, op.Path)
  414. if err == nil {
  415. contextJSON, err = marshalContextJSON(context)
  416. }
  417. case "copy_header":
  418. sourceHeader := strings.TrimSpace(op.From)
  419. targetHeader := strings.TrimSpace(op.To)
  420. if sourceHeader == "" {
  421. sourceHeader = strings.TrimSpace(op.Path)
  422. }
  423. if targetHeader == "" {
  424. targetHeader = strings.TrimSpace(op.Path)
  425. }
  426. err = copyHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
  427. if err == nil {
  428. contextJSON, err = marshalContextJSON(context)
  429. }
  430. case "move_header":
  431. sourceHeader := strings.TrimSpace(op.From)
  432. targetHeader := strings.TrimSpace(op.To)
  433. if sourceHeader == "" {
  434. sourceHeader = strings.TrimSpace(op.Path)
  435. }
  436. if targetHeader == "" {
  437. targetHeader = strings.TrimSpace(op.Path)
  438. }
  439. err = moveHeaderInContext(context, sourceHeader, targetHeader, op.KeepOrigin)
  440. if err == nil {
  441. contextJSON, err = marshalContextJSON(context)
  442. }
  443. case "pass_headers":
  444. headerNames, parseErr := parseHeaderPassThroughNames(op.Value)
  445. if parseErr != nil {
  446. return "", parseErr
  447. }
  448. for _, headerName := range headerNames {
  449. if err = copyHeaderInContext(context, headerName, headerName, op.KeepOrigin); err != nil {
  450. if errors.Is(err, errSourceHeaderNotFound) {
  451. err = nil
  452. continue
  453. }
  454. break
  455. }
  456. }
  457. if err == nil {
  458. contextJSON, err = marshalContextJSON(context)
  459. }
  460. case "sync_fields":
  461. result, err = syncFieldsBetweenTargets(result, context, op.From, op.To)
  462. if err == nil {
  463. contextJSON, err = marshalContextJSON(context)
  464. }
  465. default:
  466. return "", fmt.Errorf("unknown operation: %s", op.Mode)
  467. }
  468. if err != nil {
  469. return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
  470. }
  471. }
  472. return result, nil
  473. }
  474. func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
  475. result := &ParamOverrideReturnError{
  476. StatusCode: http.StatusBadRequest,
  477. Code: string(types.ErrorCodeInvalidRequest),
  478. Type: "invalid_request_error",
  479. SkipRetry: true,
  480. }
  481. switch raw := value.(type) {
  482. case nil:
  483. return nil, fmt.Errorf("return_error value is required")
  484. case string:
  485. result.Message = strings.TrimSpace(raw)
  486. case map[string]interface{}:
  487. if message, ok := raw["message"].(string); ok {
  488. result.Message = strings.TrimSpace(message)
  489. }
  490. if result.Message == "" {
  491. if message, ok := raw["msg"].(string); ok {
  492. result.Message = strings.TrimSpace(message)
  493. }
  494. }
  495. if code, exists := raw["code"]; exists {
  496. codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
  497. if codeStr != "" {
  498. result.Code = codeStr
  499. }
  500. }
  501. if errType, ok := raw["type"].(string); ok {
  502. errType = strings.TrimSpace(errType)
  503. if errType != "" {
  504. result.Type = errType
  505. }
  506. }
  507. if skipRetry, ok := raw["skip_retry"].(bool); ok {
  508. result.SkipRetry = skipRetry
  509. }
  510. if statusCodeRaw, exists := raw["status_code"]; exists {
  511. statusCode, ok := parseOverrideInt(statusCodeRaw)
  512. if !ok {
  513. return nil, fmt.Errorf("return_error status_code must be an integer")
  514. }
  515. result.StatusCode = statusCode
  516. } else if statusRaw, exists := raw["status"]; exists {
  517. statusCode, ok := parseOverrideInt(statusRaw)
  518. if !ok {
  519. return nil, fmt.Errorf("return_error status must be an integer")
  520. }
  521. result.StatusCode = statusCode
  522. }
  523. default:
  524. return nil, fmt.Errorf("return_error value must be string or object")
  525. }
  526. if result.Message == "" {
  527. return nil, fmt.Errorf("return_error message is required")
  528. }
  529. if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
  530. return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
  531. }
  532. return result, nil
  533. }
  // parseOverrideInt converts a config value to an int.
  //
  // Accepted inputs:
  //   - Go integers (int, int32, int64 that fits in int);
  //   - floats with no fractional part (JSON numbers decode as float64);
  //   - strings holding a base-10 integer, surrounding whitespace ignored.
  //
  // Anything else, including nil, reports ok=false.
  func parseOverrideInt(v interface{}) (int, bool) {
  	switch value := v.(type) {
  	case int:
  		return value, true
  	case int32:
  		return int(value), true
  	case int64:
  		if int64(int(value)) != value {
  			return 0, false // would overflow int on 32-bit platforms
  		}
  		return int(value), true
  	case float32:
  		return parseOverrideInt(float64(value))
  	case float64:
  		if value != float64(int(value)) {
  			return 0, false
  		}
  		return int(value), true
  	case string:
  		n, err := strconv.Atoi(strings.TrimSpace(value))
  		if err != nil {
  			return 0, false
  		}
  		return n, true
  	default:
  		return 0, false
  	}
  }
  // ensureContextMap returns conditionContext itself when non-nil (so mutations
  // are visible to the caller), or a fresh empty map otherwise.
  func ensureContextMap(conditionContext map[string]interface{}) map[string]interface{} {
  	if conditionContext == nil {
  		return make(map[string]interface{})
  	}
  	return conditionContext
  }
  553. func marshalContextJSON(context map[string]interface{}) (string, error) {
  554. if context == nil || len(context) == 0 {
  555. return "", nil
  556. }
  557. ctxBytes, err := common.Marshal(context)
  558. if err != nil {
  559. return "", err
  560. }
  561. return string(ctxBytes), nil
  562. }
  563. func setHeaderOverrideInContext(context map[string]interface{}, headerName string, value interface{}, keepOrigin bool) error {
  564. headerName = strings.TrimSpace(headerName)
  565. if headerName == "" {
  566. return fmt.Errorf("header name is required")
  567. }
  568. if keepOrigin {
  569. if _, exists := getHeaderValueFromContext(context, headerName); exists {
  570. return nil
  571. }
  572. }
  573. if value == nil {
  574. return fmt.Errorf("header value is required")
  575. }
  576. headerValue := strings.TrimSpace(fmt.Sprintf("%v", value))
  577. if headerValue == "" {
  578. return fmt.Errorf("header value is required")
  579. }
  580. rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
  581. rawHeaders[headerName] = headerValue
  582. normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
  583. normalizedHeaders[normalizeHeaderContextKey(headerName)] = headerValue
  584. return nil
  585. }
  586. func copyHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
  587. fromHeader = strings.TrimSpace(fromHeader)
  588. toHeader = strings.TrimSpace(toHeader)
  589. if fromHeader == "" || toHeader == "" {
  590. return fmt.Errorf("copy_header from/to is required")
  591. }
  592. value, exists := getHeaderValueFromContext(context, fromHeader)
  593. if !exists {
  594. return fmt.Errorf("%w: %s", errSourceHeaderNotFound, fromHeader)
  595. }
  596. return setHeaderOverrideInContext(context, toHeader, value, keepOrigin)
  597. }
  598. func moveHeaderInContext(context map[string]interface{}, fromHeader, toHeader string, keepOrigin bool) error {
  599. fromHeader = strings.TrimSpace(fromHeader)
  600. toHeader = strings.TrimSpace(toHeader)
  601. if fromHeader == "" || toHeader == "" {
  602. return fmt.Errorf("move_header from/to is required")
  603. }
  604. if err := copyHeaderInContext(context, fromHeader, toHeader, keepOrigin); err != nil {
  605. return err
  606. }
  607. if strings.EqualFold(fromHeader, toHeader) {
  608. return nil
  609. }
  610. return deleteHeaderOverrideInContext(context, fromHeader)
  611. }
  612. func deleteHeaderOverrideInContext(context map[string]interface{}, headerName string) error {
  613. headerName = strings.TrimSpace(headerName)
  614. if headerName == "" {
  615. return fmt.Errorf("header name is required")
  616. }
  617. rawHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverride)
  618. for key := range rawHeaders {
  619. if strings.EqualFold(strings.TrimSpace(key), headerName) {
  620. delete(rawHeaders, key)
  621. }
  622. }
  623. normalizedHeaders := ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized)
  624. delete(normalizedHeaders, normalizeHeaderContextKey(headerName))
  625. return nil
  626. }
  627. func parseHeaderPassThroughNames(value interface{}) ([]string, error) {
  628. normalizeNames := func(values []string) []string {
  629. names := lo.FilterMap(values, func(item string, _ int) (string, bool) {
  630. headerName := strings.TrimSpace(item)
  631. if headerName == "" {
  632. return "", false
  633. }
  634. return headerName, true
  635. })
  636. return lo.Uniq(names)
  637. }
  638. switch raw := value.(type) {
  639. case nil:
  640. return nil, fmt.Errorf("pass_headers value is required")
  641. case string:
  642. trimmed := strings.TrimSpace(raw)
  643. if trimmed == "" {
  644. return nil, fmt.Errorf("pass_headers value is required")
  645. }
  646. if strings.HasPrefix(trimmed, "[") || strings.HasPrefix(trimmed, "{") {
  647. var parsed interface{}
  648. if err := common.UnmarshalJsonStr(trimmed, &parsed); err == nil {
  649. return parseHeaderPassThroughNames(parsed)
  650. }
  651. }
  652. names := normalizeNames(strings.Split(trimmed, ","))
  653. if len(names) == 0 {
  654. return nil, fmt.Errorf("pass_headers value is invalid")
  655. }
  656. return names, nil
  657. case []interface{}:
  658. names := lo.FilterMap(raw, func(item interface{}, _ int) (string, bool) {
  659. headerName := strings.TrimSpace(fmt.Sprintf("%v", item))
  660. if headerName == "" {
  661. return "", false
  662. }
  663. return headerName, true
  664. })
  665. names = lo.Uniq(names)
  666. if len(names) == 0 {
  667. return nil, fmt.Errorf("pass_headers value is invalid")
  668. }
  669. return names, nil
  670. case map[string]interface{}:
  671. candidates := make([]string, 0, 8)
  672. if headersRaw, ok := raw["headers"]; ok {
  673. names, err := parseHeaderPassThroughNames(headersRaw)
  674. if err == nil {
  675. candidates = append(candidates, names...)
  676. }
  677. }
  678. if namesRaw, ok := raw["names"]; ok {
  679. names, err := parseHeaderPassThroughNames(namesRaw)
  680. if err == nil {
  681. candidates = append(candidates, names...)
  682. }
  683. }
  684. if headerRaw, ok := raw["header"]; ok {
  685. names, err := parseHeaderPassThroughNames(headerRaw)
  686. if err == nil {
  687. candidates = append(candidates, names...)
  688. }
  689. }
  690. names := normalizeNames(candidates)
  691. if len(names) == 0 {
  692. return nil, fmt.Errorf("pass_headers value is invalid")
  693. }
  694. return names, nil
  695. default:
  696. return nil, fmt.Errorf("pass_headers value must be string, array or object")
  697. }
  698. }
  // syncTarget identifies one side of a "sync_fields" operation.
  type syncTarget struct {
  	kind string // "json" (request body path) or "header" (header override)
  	key  string // JSON path or header name
  }

  // parseSyncTarget parses a "sync_fields" target spec.
  //
  // Forms are "json:<path>", "body:<path>", "header:<name>", or a bare value
  // treated as a JSON path for backward compatibility. The prefix is
  // case-insensitive, and prefix and key are trimmed. Only the first ':'
  // separates prefix and key.
  func parseSyncTarget(spec string) (syncTarget, error) {
  	raw := strings.TrimSpace(spec)
  	if raw == "" {
  		return syncTarget{}, fmt.Errorf("sync_fields target is required")
  	}

  	prefix, rest, hasPrefix := strings.Cut(raw, ":")
  	if !hasPrefix {
  		return syncTarget{kind: "json", key: raw}, nil
  	}

  	key := strings.TrimSpace(rest)
  	if key == "" {
  		return syncTarget{}, fmt.Errorf("sync_fields target key is required: %s", raw)
  	}

  	var kind string
  	switch strings.ToLower(strings.TrimSpace(prefix)) {
  	case "json", "body":
  		kind = "json"
  	case "header":
  		kind = "header"
  	default:
  		return syncTarget{}, fmt.Errorf("sync_fields target prefix is invalid: %s", raw)
  	}
  	return syncTarget{kind: kind, key: key}, nil
  }
  736. func readSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget) (interface{}, bool, error) {
  737. switch target.kind {
  738. case "json":
  739. path := processNegativeIndex(jsonStr, target.key)
  740. value := gjson.Get(jsonStr, path)
  741. if !value.Exists() || value.Type == gjson.Null {
  742. return nil, false, nil
  743. }
  744. if value.Type == gjson.String && strings.TrimSpace(value.String()) == "" {
  745. return nil, false, nil
  746. }
  747. return value.Value(), true, nil
  748. case "header":
  749. value, ok := getHeaderValueFromContext(context, target.key)
  750. if !ok || strings.TrimSpace(value) == "" {
  751. return nil, false, nil
  752. }
  753. return value, true, nil
  754. default:
  755. return nil, false, fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
  756. }
  757. }
  758. func writeSyncTargetValue(jsonStr string, context map[string]interface{}, target syncTarget, value interface{}) (string, error) {
  759. switch target.kind {
  760. case "json":
  761. path := processNegativeIndex(jsonStr, target.key)
  762. nextJSON, err := sjson.Set(jsonStr, path, value)
  763. if err != nil {
  764. return "", err
  765. }
  766. return nextJSON, nil
  767. case "header":
  768. if err := setHeaderOverrideInContext(context, target.key, value, false); err != nil {
  769. return "", err
  770. }
  771. return jsonStr, nil
  772. default:
  773. return "", fmt.Errorf("unsupported sync_fields target kind: %s", target.kind)
  774. }
  775. }
  776. func syncFieldsBetweenTargets(jsonStr string, context map[string]interface{}, fromSpec string, toSpec string) (string, error) {
  777. fromTarget, err := parseSyncTarget(fromSpec)
  778. if err != nil {
  779. return "", err
  780. }
  781. toTarget, err := parseSyncTarget(toSpec)
  782. if err != nil {
  783. return "", err
  784. }
  785. fromValue, fromExists, err := readSyncTargetValue(jsonStr, context, fromTarget)
  786. if err != nil {
  787. return "", err
  788. }
  789. toValue, toExists, err := readSyncTargetValue(jsonStr, context, toTarget)
  790. if err != nil {
  791. return "", err
  792. }
  793. // If one side exists and the other side is missing, sync the missing side.
  794. if fromExists && !toExists {
  795. return writeSyncTargetValue(jsonStr, context, toTarget, fromValue)
  796. }
  797. if toExists && !fromExists {
  798. return writeSyncTargetValue(jsonStr, context, fromTarget, toValue)
  799. }
  800. return jsonStr, nil
  801. }
  // ensureMapKeyInContext returns the map[string]interface{} stored under key
  // in context. If the key is missing, or holds a value of any other type, a
  // new empty map is stored under key (replacing that value) and returned.
  // A nil context cannot be written to, so in that case a fresh, unattached
  // empty map is returned instead.
  func ensureMapKeyInContext(context map[string]interface{}, key string) map[string]interface{} {
  	if context == nil {
  		return map[string]interface{}{}
  	}
  	if existing, isMap := context[key].(map[string]interface{}); isMap {
  		return existing
  	}
  	created := map[string]interface{}{}
  	context[key] = created
  	return created
  }
  815. func getHeaderValueFromContext(context map[string]interface{}, headerName string) (string, bool) {
  816. headerName = strings.TrimSpace(headerName)
  817. if headerName == "" {
  818. return "", false
  819. }
  820. if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverride), headerName); ok {
  821. return value, true
  822. }
  823. if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeadersRaw), headerName); ok {
  824. return value, true
  825. }
  826. normalizedName := normalizeHeaderContextKey(headerName)
  827. if normalizedName == "" {
  828. return "", false
  829. }
  830. if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextHeaderOverrideNormalized), normalizedName); ok {
  831. return value, true
  832. }
  833. if value, ok := findHeaderValueInMap(ensureMapKeyInContext(context, paramOverrideContextRequestHeaders), normalizedName); ok {
  834. return value, true
  835. }
  836. return "", false
  837. }
  // findHeaderValueInMap looks up key case-insensitively among the
  // whitespace-trimmed keys of source. It returns the matched value formatted
  // with %v and trimmed. It reports false when nothing matches, when the match
  // holds nil (which would otherwise print as the literal "<nil>"), or when the
  // formatted value is empty.
  //
  // Several keys may match case-insensitively (e.g. "X-Key" and "x-key"). The
  // choice is then deterministic: a key whose trimmed form equals key exactly
  // wins, otherwise the lexicographically smallest original key is used. Go
  // map iteration order is random, so without this tie-break repeated calls
  // could return different values.
  func findHeaderValueInMap(source map[string]interface{}, key string) (string, bool) {
  	if len(source) == 0 {
  		return "", false
  	}
  	var (
  		bestKey   string
  		bestValue interface{}
  		bestExact bool
  		found     bool
  	)
  	for candidate, candidateValue := range source {
  		trimmed := strings.TrimSpace(candidate)
  		if !strings.EqualFold(trimmed, key) {
  			continue
  		}
  		exact := trimmed == key
  		// Rank by (exact match first, then smallest key) for determinism.
  		if !found || (exact && !bestExact) || (exact == bestExact && candidate < bestKey) {
  			bestKey, bestValue, bestExact, found = candidate, candidateValue, exact, true
  		}
  	}
  	if !found || bestValue == nil {
  		return "", false
  	}
  	value := strings.TrimSpace(fmt.Sprintf("%v", bestValue))
  	if value == "" {
  		return "", false
  	}
  	return value, true
  }
  // normalizeHeaderContextKey converts a header name into the snake_case form
  // used for normalized context lookups. The name is lower-cased, each run of
  // characters outside [a-z0-9] becomes a single '_', and leading and trailing
  // separators are dropped. For example, "X-Api-Key" becomes "x_api_key".
  // A name that contains no ASCII letters or digits normalizes to "".
  func normalizeHeaderContextKey(key string) string {
  	isSeparator := func(r rune) bool {
  		isLower := r >= 'a' && r <= 'z'
  		isDigit := r >= '0' && r <= '9'
  		return !isLower && !isDigit
  	}
  	// FieldsFunc yields the maximal [a-z0-9] runs and discards empty ones, so
  	// joining them with '_' collapses separators and trims both ends at once.
  	return strings.Join(strings.FieldsFunc(strings.ToLower(key), isSeparator), "_")
  }
  881. func buildNormalizedHeaders(headers map[string]string) map[string]interface{} {
  882. if len(headers) == 0 {
  883. return map[string]interface{}{}
  884. }
  885. entries := lo.Entries(headers)
  886. normalizedEntries := lo.FilterMap(entries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
  887. normalized := normalizeHeaderContextKey(item.Key)
  888. value := strings.TrimSpace(item.Value)
  889. if normalized == "" || value == "" {
  890. return lo.Entry[string, string]{}, false
  891. }
  892. return lo.Entry[string, string]{Key: normalized, Value: value}, true
  893. })
  894. return lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
  895. return item.Key, item.Value
  896. })
  897. }
  // buildRawHeaders copies request headers into a map keyed by the
  // whitespace-trimmed original header name, with trimmed values. A header is
  // skipped when its trimmed name or trimmed value is empty. The result is
  // never nil.
  func buildRawHeaders(headers map[string]string) map[string]interface{} {
  	raw := make(map[string]interface{}, len(headers))
  	for name, value := range headers {
  		trimmedName, trimmedValue := strings.TrimSpace(name), strings.TrimSpace(value)
  		if trimmedName != "" && trimmedValue != "" {
  			raw[trimmedName] = trimmedValue
  		}
  	}
  	return raw
  }
  915. func buildHeaderOverrideContext(headers map[string]interface{}) (map[string]interface{}, map[string]interface{}) {
  916. if len(headers) == 0 {
  917. return map[string]interface{}{}, map[string]interface{}{}
  918. }
  919. entries := lo.Entries(headers)
  920. rawEntries := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, string], bool) {
  921. key := strings.TrimSpace(item.Key)
  922. value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
  923. if key == "" || value == "" {
  924. return lo.Entry[string, string]{}, false
  925. }
  926. return lo.Entry[string, string]{Key: key, Value: value}, true
  927. })
  928. raw := lo.SliceToMap(rawEntries, func(item lo.Entry[string, string]) (string, interface{}) {
  929. return item.Key, item.Value
  930. })
  931. normalizedEntries := lo.FilterMap(rawEntries, func(item lo.Entry[string, string], _ int) (lo.Entry[string, string], bool) {
  932. normalized := normalizeHeaderContextKey(item.Key)
  933. if normalized == "" {
  934. return lo.Entry[string, string]{}, false
  935. }
  936. return lo.Entry[string, string]{Key: normalized, Value: item.Value}, true
  937. })
  938. normalized := lo.SliceToMap(normalizedEntries, func(item lo.Entry[string, string]) (string, interface{}) {
  939. return item.Key, item.Value
  940. })
  941. return raw, normalized
  942. }
  943. func syncRuntimeHeaderOverrideFromContext(info *RelayInfo, context map[string]interface{}) {
  944. if info == nil || context == nil {
  945. return
  946. }
  947. raw, exists := context[paramOverrideContextHeaderOverride]
  948. if !exists {
  949. return
  950. }
  951. rawMap, ok := raw.(map[string]interface{})
  952. if !ok {
  953. return
  954. }
  955. entries := lo.Entries(rawMap)
  956. sanitized := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (lo.Entry[string, interface{}], bool) {
  957. key := strings.TrimSpace(item.Key)
  958. if key == "" {
  959. return lo.Entry[string, interface{}]{}, false
  960. }
  961. value := strings.TrimSpace(fmt.Sprintf("%v", item.Value))
  962. if value == "" {
  963. return lo.Entry[string, interface{}]{}, false
  964. }
  965. return lo.Entry[string, interface{}]{Key: key, Value: value}, true
  966. })
  967. info.RuntimeHeadersOverride = lo.SliceToMap(sanitized, func(item lo.Entry[string, interface{}]) (string, interface{}) {
  968. return item.Key, item.Value
  969. })
  970. info.UseRuntimeHeadersOverride = true
  971. }
  972. func moveValue(jsonStr, fromPath, toPath string) (string, error) {
  973. sourceValue := gjson.Get(jsonStr, fromPath)
  974. if !sourceValue.Exists() {
  975. return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
  976. }
  977. result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
  978. if err != nil {
  979. return "", err
  980. }
  981. return sjson.Delete(result, fromPath)
  982. }
  983. func copyValue(jsonStr, fromPath, toPath string) (string, error) {
  984. sourceValue := gjson.Get(jsonStr, fromPath)
  985. if !sourceValue.Exists() {
  986. return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
  987. }
  988. return sjson.Set(jsonStr, toPath, sourceValue.Value())
  989. }
  990. func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
  991. current := gjson.Get(jsonStr, path)
  992. switch {
  993. case current.IsArray():
  994. return modifyArray(jsonStr, path, value, isPrepend)
  995. case current.Type == gjson.String:
  996. return modifyString(jsonStr, path, value, isPrepend)
  997. case current.Type == gjson.JSON:
  998. return mergeObjects(jsonStr, path, value, keepOrigin)
  999. }
  1000. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1001. }
  1002. func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
  1003. current := gjson.Get(jsonStr, path)
  1004. var newArray []interface{}
  1005. // 添加新值
  1006. addValue := func() {
  1007. if arr, ok := value.([]interface{}); ok {
  1008. newArray = append(newArray, arr...)
  1009. } else {
  1010. newArray = append(newArray, value)
  1011. }
  1012. }
  1013. // 添加原值
  1014. addOriginal := func() {
  1015. current.ForEach(func(_, val gjson.Result) bool {
  1016. newArray = append(newArray, val.Value())
  1017. return true
  1018. })
  1019. }
  1020. if isPrepend {
  1021. addValue()
  1022. addOriginal()
  1023. } else {
  1024. addOriginal()
  1025. addValue()
  1026. }
  1027. return sjson.Set(jsonStr, path, newArray)
  1028. }
  1029. func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
  1030. current := gjson.Get(jsonStr, path)
  1031. valueStr := fmt.Sprintf("%v", value)
  1032. var newStr string
  1033. if isPrepend {
  1034. newStr = valueStr + current.String()
  1035. } else {
  1036. newStr = current.String() + valueStr
  1037. }
  1038. return sjson.Set(jsonStr, path, newStr)
  1039. }
  1040. func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
  1041. current := gjson.Get(jsonStr, path)
  1042. if current.Type != gjson.String {
  1043. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1044. }
  1045. if value == nil {
  1046. return jsonStr, fmt.Errorf("trim value is required")
  1047. }
  1048. valueStr := fmt.Sprintf("%v", value)
  1049. var newStr string
  1050. if isPrefix {
  1051. newStr = strings.TrimPrefix(current.String(), valueStr)
  1052. } else {
  1053. newStr = strings.TrimSuffix(current.String(), valueStr)
  1054. }
  1055. return sjson.Set(jsonStr, path, newStr)
  1056. }
  1057. func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
  1058. current := gjson.Get(jsonStr, path)
  1059. if current.Type != gjson.String {
  1060. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1061. }
  1062. if value == nil {
  1063. return jsonStr, fmt.Errorf("ensure value is required")
  1064. }
  1065. valueStr := fmt.Sprintf("%v", value)
  1066. if valueStr == "" {
  1067. return jsonStr, fmt.Errorf("ensure value is required")
  1068. }
  1069. currentStr := current.String()
  1070. if isPrefix {
  1071. if strings.HasPrefix(currentStr, valueStr) {
  1072. return jsonStr, nil
  1073. }
  1074. return sjson.Set(jsonStr, path, valueStr+currentStr)
  1075. }
  1076. if strings.HasSuffix(currentStr, valueStr) {
  1077. return jsonStr, nil
  1078. }
  1079. return sjson.Set(jsonStr, path, currentStr+valueStr)
  1080. }
  1081. func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
  1082. current := gjson.Get(jsonStr, path)
  1083. if current.Type != gjson.String {
  1084. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1085. }
  1086. return sjson.Set(jsonStr, path, transform(current.String()))
  1087. }
  1088. func replaceStringValue(jsonStr, path, from, to string) (string, error) {
  1089. current := gjson.Get(jsonStr, path)
  1090. if current.Type != gjson.String {
  1091. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1092. }
  1093. if from == "" {
  1094. return jsonStr, fmt.Errorf("replace from is required")
  1095. }
  1096. return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
  1097. }
  1098. func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
  1099. current := gjson.Get(jsonStr, path)
  1100. if current.Type != gjson.String {
  1101. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  1102. }
  1103. if pattern == "" {
  1104. return jsonStr, fmt.Errorf("regex pattern is required")
  1105. }
  1106. re, err := regexp.Compile(pattern)
  1107. if err != nil {
  1108. return jsonStr, err
  1109. }
  1110. return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
  1111. }
  1112. type pruneObjectsOptions struct {
  1113. conditions []ConditionOperation
  1114. logic string
  1115. recursive bool
  1116. }
  1117. func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
  1118. options, err := parsePruneObjectsOptions(value)
  1119. if err != nil {
  1120. return "", err
  1121. }
  1122. if path == "" {
  1123. var root interface{}
  1124. if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
  1125. return "", err
  1126. }
  1127. cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
  1128. if err != nil {
  1129. return "", err
  1130. }
  1131. cleanedBytes, err := common.Marshal(cleaned)
  1132. if err != nil {
  1133. return "", err
  1134. }
  1135. return string(cleanedBytes), nil
  1136. }
  1137. target := gjson.Get(jsonStr, path)
  1138. if !target.Exists() {
  1139. return jsonStr, nil
  1140. }
  1141. var targetNode interface{}
  1142. if target.Type == gjson.JSON {
  1143. if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
  1144. return "", err
  1145. }
  1146. } else {
  1147. targetNode = target.Value()
  1148. }
  1149. cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
  1150. if err != nil {
  1151. return "", err
  1152. }
  1153. cleanedBytes, err := common.Marshal(cleaned)
  1154. if err != nil {
  1155. return "", err
  1156. }
  1157. return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
  1158. }
  1159. func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
  1160. opts := pruneObjectsOptions{
  1161. logic: "AND",
  1162. recursive: true,
  1163. }
  1164. switch raw := value.(type) {
  1165. case nil:
  1166. return opts, fmt.Errorf("prune_objects value is required")
  1167. case string:
  1168. v := strings.TrimSpace(raw)
  1169. if v == "" {
  1170. return opts, fmt.Errorf("prune_objects value is required")
  1171. }
  1172. opts.conditions = []ConditionOperation{
  1173. {
  1174. Path: "type",
  1175. Mode: "full",
  1176. Value: v,
  1177. },
  1178. }
  1179. case map[string]interface{}:
  1180. if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
  1181. opts.logic = logic
  1182. }
  1183. if recursive, ok := raw["recursive"].(bool); ok {
  1184. opts.recursive = recursive
  1185. }
  1186. if condRaw, exists := raw["conditions"]; exists {
  1187. conditions, err := parseConditionOperations(condRaw)
  1188. if err != nil {
  1189. return opts, err
  1190. }
  1191. opts.conditions = append(opts.conditions, conditions...)
  1192. }
  1193. if whereRaw, exists := raw["where"]; exists {
  1194. whereMap, ok := whereRaw.(map[string]interface{})
  1195. if !ok {
  1196. return opts, fmt.Errorf("prune_objects where must be object")
  1197. }
  1198. for key, val := range whereMap {
  1199. key = strings.TrimSpace(key)
  1200. if key == "" {
  1201. continue
  1202. }
  1203. opts.conditions = append(opts.conditions, ConditionOperation{
  1204. Path: key,
  1205. Mode: "full",
  1206. Value: val,
  1207. })
  1208. }
  1209. }
  1210. if matchType, exists := raw["type"]; exists {
  1211. opts.conditions = append(opts.conditions, ConditionOperation{
  1212. Path: "type",
  1213. Mode: "full",
  1214. Value: matchType,
  1215. })
  1216. }
  1217. default:
  1218. return opts, fmt.Errorf("prune_objects value must be string or object")
  1219. }
  1220. if len(opts.conditions) == 0 {
  1221. return opts, fmt.Errorf("prune_objects conditions are required")
  1222. }
  1223. return opts, nil
  1224. }
  1225. func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
  1226. switch typed := raw.(type) {
  1227. case map[string]interface{}:
  1228. entries := lo.Entries(typed)
  1229. conditions := lo.FilterMap(entries, func(item lo.Entry[string, interface{}], _ int) (ConditionOperation, bool) {
  1230. path := strings.TrimSpace(item.Key)
  1231. if path == "" {
  1232. return ConditionOperation{}, false
  1233. }
  1234. return ConditionOperation{
  1235. Path: path,
  1236. Mode: "full",
  1237. Value: item.Value,
  1238. }, true
  1239. })
  1240. if len(conditions) == 0 {
  1241. return nil, fmt.Errorf("conditions object must contain at least one key")
  1242. }
  1243. return conditions, nil
  1244. case []interface{}:
  1245. items := typed
  1246. result := make([]ConditionOperation, 0, len(items))
  1247. for _, item := range items {
  1248. itemMap, ok := item.(map[string]interface{})
  1249. if !ok {
  1250. return nil, fmt.Errorf("condition must be object")
  1251. }
  1252. path, _ := itemMap["path"].(string)
  1253. mode, _ := itemMap["mode"].(string)
  1254. if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
  1255. return nil, fmt.Errorf("condition path/mode is required")
  1256. }
  1257. condition := ConditionOperation{
  1258. Path: path,
  1259. Mode: mode,
  1260. }
  1261. if value, exists := itemMap["value"]; exists {
  1262. condition.Value = value
  1263. }
  1264. if invert, ok := itemMap["invert"].(bool); ok {
  1265. condition.Invert = invert
  1266. }
  1267. if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
  1268. condition.PassMissingKey = passMissingKey
  1269. }
  1270. result = append(result, condition)
  1271. }
  1272. return result, nil
  1273. default:
  1274. return nil, fmt.Errorf("conditions must be an array or object")
  1275. }
  1276. }
  1277. func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
  1278. switch value := node.(type) {
  1279. case []interface{}:
  1280. result := make([]interface{}, 0, len(value))
  1281. for _, item := range value {
  1282. next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
  1283. if err != nil {
  1284. return nil, false, err
  1285. }
  1286. if drop {
  1287. continue
  1288. }
  1289. result = append(result, next)
  1290. }
  1291. return result, false, nil
  1292. case map[string]interface{}:
  1293. shouldDrop, err := shouldPruneObject(value, options, contextJSON)
  1294. if err != nil {
  1295. return nil, false, err
  1296. }
  1297. if shouldDrop && !isRoot {
  1298. return nil, true, nil
  1299. }
  1300. if !options.recursive {
  1301. return value, false, nil
  1302. }
  1303. for key, child := range value {
  1304. next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
  1305. if err != nil {
  1306. return nil, false, err
  1307. }
  1308. if drop {
  1309. delete(value, key)
  1310. continue
  1311. }
  1312. value[key] = next
  1313. }
  1314. return value, false, nil
  1315. default:
  1316. return node, false, nil
  1317. }
  1318. }
  1319. func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
  1320. nodeBytes, err := common.Marshal(node)
  1321. if err != nil {
  1322. return false, err
  1323. }
  1324. return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
  1325. }
  1326. func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
  1327. current := gjson.Get(jsonStr, path)
  1328. var currentMap, newMap map[string]interface{}
  1329. // 解析当前值
  1330. if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
  1331. return "", err
  1332. }
  1333. // 解析新值
  1334. switch v := value.(type) {
  1335. case map[string]interface{}:
  1336. newMap = v
  1337. default:
  1338. jsonBytes, _ := common.Marshal(v)
  1339. if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
  1340. return "", err
  1341. }
  1342. }
  1343. // 合并
  1344. result := make(map[string]interface{})
  1345. for k, v := range currentMap {
  1346. result[k] = v
  1347. }
  1348. for k, v := range newMap {
  1349. if !keepOrigin || result[k] == nil {
  1350. result[k] = v
  1351. }
  1352. }
  1353. return sjson.Set(jsonStr, path, result)
  1354. }
  1355. // BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
  1356. // 目前内置以下字段:
  1357. // - upstream_model/model:始终为通道映射后的上游模型名。
  1358. // - original_model:请求最初指定的模型名。
  1359. // - request_path:请求路径
  1360. // - is_channel_test:是否为渠道测试请求(同 is_test)。
  1361. func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
  1362. if info == nil {
  1363. return nil
  1364. }
  1365. ctx := make(map[string]interface{})
  1366. if info.ChannelMeta != nil && info.ChannelMeta.UpstreamModelName != "" {
  1367. ctx["model"] = info.ChannelMeta.UpstreamModelName
  1368. ctx["upstream_model"] = info.ChannelMeta.UpstreamModelName
  1369. }
  1370. if info.OriginModelName != "" {
  1371. ctx["original_model"] = info.OriginModelName
  1372. if _, exists := ctx["model"]; !exists {
  1373. ctx["model"] = info.OriginModelName
  1374. }
  1375. }
  1376. if info.RequestURLPath != "" {
  1377. requestPath := info.RequestURLPath
  1378. if requestPath != "" {
  1379. ctx["request_path"] = requestPath
  1380. }
  1381. }
  1382. ctx[paramOverrideContextRequestHeaders] = buildNormalizedHeaders(info.RequestHeaders)
  1383. ctx[paramOverrideContextRequestHeadersRaw] = buildRawHeaders(info.RequestHeaders)
  1384. headerOverrideSource := getHeaderOverrideMap(info)
  1385. if info.UseRuntimeHeadersOverride {
  1386. headerOverrideSource = info.RuntimeHeadersOverride
  1387. }
  1388. rawHeaderOverride, normalizedHeaderOverride := buildHeaderOverrideContext(headerOverrideSource)
  1389. ctx[paramOverrideContextHeaderOverride] = rawHeaderOverride
  1390. ctx[paramOverrideContextHeaderOverrideNormalized] = normalizedHeaderOverride
  1391. ctx["retry_index"] = info.RetryIndex
  1392. ctx["is_retry"] = info.RetryIndex > 0
  1393. ctx["retry"] = map[string]interface{}{
  1394. "index": info.RetryIndex,
  1395. "is_retry": info.RetryIndex > 0,
  1396. }
  1397. if info.LastError != nil {
  1398. code := string(info.LastError.GetErrorCode())
  1399. errorType := string(info.LastError.GetErrorType())
  1400. lastError := map[string]interface{}{
  1401. "status_code": info.LastError.StatusCode,
  1402. "message": info.LastError.Error(),
  1403. "code": code,
  1404. "error_code": code,
  1405. "type": errorType,
  1406. "error_type": errorType,
  1407. "skip_retry": types.IsSkipRetryError(info.LastError),
  1408. }
  1409. ctx["last_error"] = lastError
  1410. ctx["last_error_status_code"] = info.LastError.StatusCode
  1411. ctx["last_error_message"] = info.LastError.Error()
  1412. ctx["last_error_code"] = code
  1413. ctx["last_error_type"] = errorType
  1414. }
  1415. ctx["is_channel_test"] = info.IsChannelTest
  1416. return ctx
  1417. }