|
|
@@ -1,12 +1,15 @@
|
|
|
package common
|
|
|
|
|
|
import (
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
+ "net/http"
|
|
|
"regexp"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
+ "github.com/QuantumNous/new-api/types"
|
|
|
"github.com/tidwall/gjson"
|
|
|
"github.com/tidwall/sjson"
|
|
|
)
|
|
|
@@ -23,7 +26,7 @@ type ConditionOperation struct {
|
|
|
|
|
|
type ParamOperation struct {
|
|
|
Path string `json:"path"`
|
|
|
- Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
|
|
|
+ Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
|
|
|
Value interface{} `json:"value"`
|
|
|
KeepOrigin bool `json:"keep_origin"`
|
|
|
From string `json:"from,omitempty"`
|
|
|
@@ -32,6 +35,76 @@ type ParamOperation struct {
|
|
|
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
|
|
}
|
|
|
|
|
|
+type ParamOverrideReturnError struct {
|
|
|
+ Message string
|
|
|
+ StatusCode int
|
|
|
+ Code string
|
|
|
+ Type string
|
|
|
+ SkipRetry bool
|
|
|
+}
|
|
|
+
|
|
|
+func (e *ParamOverrideReturnError) Error() string {
|
|
|
+ if e == nil {
|
|
|
+ return "param override return error"
|
|
|
+ }
|
|
|
+ if e.Message == "" {
|
|
|
+ return "param override return error"
|
|
|
+ }
|
|
|
+ return e.Message
|
|
|
+}
|
|
|
+
|
|
|
+func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
|
|
|
+ if err == nil {
|
|
|
+ return nil, false
|
|
|
+ }
|
|
|
+ var target *ParamOverrideReturnError
|
|
|
+ if errors.As(err, &target) {
|
|
|
+ return target, true
|
|
|
+ }
|
|
|
+ return nil, false
|
|
|
+}
|
|
|
+
|
|
|
+func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
|
|
|
+ if err == nil {
|
|
|
+ return types.NewError(
|
|
|
+ errors.New("param override return error is nil"),
|
|
|
+ types.ErrorCodeChannelParamOverrideInvalid,
|
|
|
+ types.ErrOptionWithSkipRetry(),
|
|
|
+ )
|
|
|
+ }
|
|
|
+
|
|
|
+ statusCode := err.StatusCode
|
|
|
+ if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
|
|
|
+ statusCode = http.StatusBadRequest
|
|
|
+ }
|
|
|
+
|
|
|
+ errorCode := err.Code
|
|
|
+ if strings.TrimSpace(errorCode) == "" {
|
|
|
+ errorCode = string(types.ErrorCodeInvalidRequest)
|
|
|
+ }
|
|
|
+
|
|
|
+ errorType := err.Type
|
|
|
+ if strings.TrimSpace(errorType) == "" {
|
|
|
+ errorType = "invalid_request_error"
|
|
|
+ }
|
|
|
+
|
|
|
+ message := strings.TrimSpace(err.Message)
|
|
|
+ if message == "" {
|
|
|
+ message = "request blocked by param override"
|
|
|
+ }
|
|
|
+
|
|
|
+ opts := make([]types.NewAPIErrorOptions, 0, 1)
|
|
|
+ if err.SkipRetry {
|
|
|
+ opts = append(opts, types.ErrOptionWithSkipRetry())
|
|
|
+ }
|
|
|
+
|
|
|
+ return types.WithOpenAIError(types.OpenAIError{
|
|
|
+ Message: message,
|
|
|
+ Type: errorType,
|
|
|
+ Code: errorCode,
|
|
|
+ }, statusCode, opts...)
|
|
|
+}
|
|
|
+
|
|
|
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
|
|
if len(paramOverride) == 0 {
|
|
|
return jsonData, nil
|
|
|
@@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
|
|
result, err = replaceStringValue(result, opPath, op.From, op.To)
|
|
|
case "regex_replace":
|
|
|
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
|
|
|
+ case "return_error":
|
|
|
+ returnErr, parseErr := parseParamOverrideReturnError(op.Value)
|
|
|
+ if parseErr != nil {
|
|
|
+ return "", parseErr
|
|
|
+ }
|
|
|
+ return "", returnErr
|
|
|
+ case "prune_objects":
|
|
|
+ result, err = pruneObjects(result, opPath, contextJSON, op.Value)
|
|
|
default:
|
|
|
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
|
|
|
+ return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result, nil
|
|
|
+}
|
|
|
+
|
|
|
+func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
|
|
|
+ result := &ParamOverrideReturnError{
|
|
|
+ StatusCode: http.StatusBadRequest,
|
|
|
+ Code: string(types.ErrorCodeInvalidRequest),
|
|
|
+ Type: "invalid_request_error",
|
|
|
+ SkipRetry: true,
|
|
|
+ }
|
|
|
+
|
|
|
+ switch raw := value.(type) {
|
|
|
+ case nil:
|
|
|
+ return nil, fmt.Errorf("return_error value is required")
|
|
|
+ case string:
|
|
|
+ result.Message = strings.TrimSpace(raw)
|
|
|
+ case map[string]interface{}:
|
|
|
+ if message, ok := raw["message"].(string); ok {
|
|
|
+ result.Message = strings.TrimSpace(message)
|
|
|
+ }
|
|
|
+ if result.Message == "" {
|
|
|
+ if message, ok := raw["msg"].(string); ok {
|
|
|
+ result.Message = strings.TrimSpace(message)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if code, exists := raw["code"]; exists {
|
|
|
+ codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
|
|
|
+ if codeStr != "" {
|
|
|
+ result.Code = codeStr
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if errType, ok := raw["type"].(string); ok {
|
|
|
+ errType = strings.TrimSpace(errType)
|
|
|
+ if errType != "" {
|
|
|
+ result.Type = errType
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if skipRetry, ok := raw["skip_retry"].(bool); ok {
|
|
|
+ result.SkipRetry = skipRetry
|
|
|
}
|
|
|
+
|
|
|
+ if statusCodeRaw, exists := raw["status_code"]; exists {
|
|
|
+ statusCode, ok := parseOverrideInt(statusCodeRaw)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("return_error status_code must be an integer")
|
|
|
+ }
|
|
|
+ result.StatusCode = statusCode
|
|
|
+ } else if statusRaw, exists := raw["status"]; exists {
|
|
|
+ statusCode, ok := parseOverrideInt(statusRaw)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("return_error status must be an integer")
|
|
|
+ }
|
|
|
+ result.StatusCode = statusCode
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return nil, fmt.Errorf("return_error value must be string or object")
|
|
|
}
|
|
|
+
|
|
|
+ if result.Message == "" {
|
|
|
+ return nil, fmt.Errorf("return_error message is required")
|
|
|
+ }
|
|
|
+ if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
|
|
|
+ return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
|
|
|
+ }
|
|
|
+
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
+func parseOverrideInt(v interface{}) (int, bool) {
|
|
|
+ switch value := v.(type) {
|
|
|
+ case int:
|
|
|
+ return value, true
|
|
|
+ case float64:
|
|
|
+ if value != float64(int(value)) {
|
|
|
+ return 0, false
|
|
|
+ }
|
|
|
+ return int(value), true
|
|
|
+ default:
|
|
|
+ return 0, false
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
|
|
sourceValue := gjson.Get(jsonStr, fromPath)
|
|
|
if !sourceValue.Exists() {
|
|
|
@@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string
|
|
|
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
|
|
}
|
|
|
|
|
|
+type pruneObjectsOptions struct {
|
|
|
+ conditions []ConditionOperation
|
|
|
+ logic string
|
|
|
+ recursive bool
|
|
|
+}
|
|
|
+
|
|
|
+func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
|
|
|
+ options, err := parsePruneObjectsOptions(value)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+
|
|
|
+ if path == "" {
|
|
|
+ var root interface{}
|
|
|
+ if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ cleanedBytes, err := common.Marshal(cleaned)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ return string(cleanedBytes), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ target := gjson.Get(jsonStr, path)
|
|
|
+ if !target.Exists() {
|
|
|
+ return jsonStr, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ var targetNode interface{}
|
|
|
+ if target.Type == gjson.JSON {
|
|
|
+ if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ targetNode = target.Value()
|
|
|
+ }
|
|
|
+
|
|
|
+ cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ cleanedBytes, err := common.Marshal(cleaned)
|
|
|
+ if err != nil {
|
|
|
+ return "", err
|
|
|
+ }
|
|
|
+ return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
|
|
|
+}
|
|
|
+
|
|
|
+func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
|
|
|
+ opts := pruneObjectsOptions{
|
|
|
+ logic: "AND",
|
|
|
+ recursive: true,
|
|
|
+ }
|
|
|
+
|
|
|
+ switch raw := value.(type) {
|
|
|
+ case nil:
|
|
|
+ return opts, fmt.Errorf("prune_objects value is required")
|
|
|
+ case string:
|
|
|
+ v := strings.TrimSpace(raw)
|
|
|
+ if v == "" {
|
|
|
+ return opts, fmt.Errorf("prune_objects value is required")
|
|
|
+ }
|
|
|
+ opts.conditions = []ConditionOperation{
|
|
|
+ {
|
|
|
+ Path: "type",
|
|
|
+ Mode: "full",
|
|
|
+ Value: v,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ case map[string]interface{}:
|
|
|
+ if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
|
|
|
+ opts.logic = logic
|
|
|
+ }
|
|
|
+ if recursive, ok := raw["recursive"].(bool); ok {
|
|
|
+ opts.recursive = recursive
|
|
|
+ }
|
|
|
+
|
|
|
+ if condRaw, exists := raw["conditions"]; exists {
|
|
|
+ conditions, err := parseConditionOperations(condRaw)
|
|
|
+ if err != nil {
|
|
|
+ return opts, err
|
|
|
+ }
|
|
|
+ opts.conditions = append(opts.conditions, conditions...)
|
|
|
+ }
|
|
|
+
|
|
|
+ if whereRaw, exists := raw["where"]; exists {
|
|
|
+ whereMap, ok := whereRaw.(map[string]interface{})
|
|
|
+ if !ok {
|
|
|
+ return opts, fmt.Errorf("prune_objects where must be object")
|
|
|
+ }
|
|
|
+ for key, val := range whereMap {
|
|
|
+ key = strings.TrimSpace(key)
|
|
|
+ if key == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ opts.conditions = append(opts.conditions, ConditionOperation{
|
|
|
+ Path: key,
|
|
|
+ Mode: "full",
|
|
|
+ Value: val,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if matchType, exists := raw["type"]; exists {
|
|
|
+ opts.conditions = append(opts.conditions, ConditionOperation{
|
|
|
+ Path: "type",
|
|
|
+ Mode: "full",
|
|
|
+ Value: matchType,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ return opts, fmt.Errorf("prune_objects value must be string or object")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(opts.conditions) == 0 {
|
|
|
+ return opts, fmt.Errorf("prune_objects conditions are required")
|
|
|
+ }
|
|
|
+ return opts, nil
|
|
|
+}
|
|
|
+
|
|
|
+func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
|
|
|
+ items, ok := raw.([]interface{})
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("conditions must be an array")
|
|
|
+ }
|
|
|
+
|
|
|
+ result := make([]ConditionOperation, 0, len(items))
|
|
|
+ for _, item := range items {
|
|
|
+ itemMap, ok := item.(map[string]interface{})
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("condition must be object")
|
|
|
+ }
|
|
|
+ path, _ := itemMap["path"].(string)
|
|
|
+ mode, _ := itemMap["mode"].(string)
|
|
|
+ if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
|
|
|
+ return nil, fmt.Errorf("condition path/mode is required")
|
|
|
+ }
|
|
|
+ condition := ConditionOperation{
|
|
|
+ Path: path,
|
|
|
+ Mode: mode,
|
|
|
+ }
|
|
|
+ if value, exists := itemMap["value"]; exists {
|
|
|
+ condition.Value = value
|
|
|
+ }
|
|
|
+ if invert, ok := itemMap["invert"].(bool); ok {
|
|
|
+ condition.Invert = invert
|
|
|
+ }
|
|
|
+ if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
|
|
|
+ condition.PassMissingKey = passMissingKey
|
|
|
+ }
|
|
|
+ result = append(result, condition)
|
|
|
+ }
|
|
|
+ return result, nil
|
|
|
+}
|
|
|
+
|
|
|
+func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
|
|
|
+ switch value := node.(type) {
|
|
|
+ case []interface{}:
|
|
|
+ result := make([]interface{}, 0, len(value))
|
|
|
+ for _, item := range value {
|
|
|
+ next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
|
|
|
+ if err != nil {
|
|
|
+ return nil, false, err
|
|
|
+ }
|
|
|
+ if drop {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ result = append(result, next)
|
|
|
+ }
|
|
|
+ return result, false, nil
|
|
|
+ case map[string]interface{}:
|
|
|
+ shouldDrop, err := shouldPruneObject(value, options, contextJSON)
|
|
|
+ if err != nil {
|
|
|
+ return nil, false, err
|
|
|
+ }
|
|
|
+ if shouldDrop && !isRoot {
|
|
|
+ return nil, true, nil
|
|
|
+ }
|
|
|
+ if !options.recursive {
|
|
|
+ return value, false, nil
|
|
|
+ }
|
|
|
+ for key, child := range value {
|
|
|
+ next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
|
|
|
+ if err != nil {
|
|
|
+ return nil, false, err
|
|
|
+ }
|
|
|
+ if drop {
|
|
|
+ delete(value, key)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ value[key] = next
|
|
|
+ }
|
|
|
+ return value, false, nil
|
|
|
+ default:
|
|
|
+ return node, false, nil
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
|
|
|
+ nodeBytes, err := common.Marshal(node)
|
|
|
+ if err != nil {
|
|
|
+ return false, err
|
|
|
+ }
|
|
|
+ return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
|
|
|
+}
|
|
|
+
|
|
|
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
|
|
current := gjson.Get(jsonStr, path)
|
|
|
var currentMap, newMap map[string]interface{}
|
|
|
@@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ ctx["retry_index"] = info.RetryIndex
|
|
|
+ ctx["is_retry"] = info.RetryIndex > 0
|
|
|
+ ctx["retry"] = map[string]interface{}{
|
|
|
+ "index": info.RetryIndex,
|
|
|
+ "is_retry": info.RetryIndex > 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ if info.LastError != nil {
|
|
|
+ code := string(info.LastError.GetErrorCode())
|
|
|
+ errorType := string(info.LastError.GetErrorType())
|
|
|
+ lastError := map[string]interface{}{
|
|
|
+ "status_code": info.LastError.StatusCode,
|
|
|
+ "message": info.LastError.Error(),
|
|
|
+ "code": code,
|
|
|
+ "error_code": code,
|
|
|
+ "type": errorType,
|
|
|
+ "error_type": errorType,
|
|
|
+ "skip_retry": types.IsSkipRetryError(info.LastError),
|
|
|
+ }
|
|
|
+ ctx["last_error"] = lastError
|
|
|
+ ctx["last_error_status_code"] = info.LastError.StatusCode
|
|
|
+ ctx["last_error_message"] = info.LastError.Error()
|
|
|
+ ctx["last_error_code"] = code
|
|
|
+ ctx["last_error_type"] = errorType
|
|
|
+ }
|
|
|
+
|
|
|
ctx["is_channel_test"] = info.IsChannelTest
|
|
|
return ctx
|
|
|
}
|