فهرست منبع

feat: add retry-aware param override with return_error and prune_objects

Seefs 1 هفته پیش
والد
کامیت
ff76e75f4c

+ 8 - 1
controller/channel-test.go

@@ -366,7 +366,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 			newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
 		}
 	}
-	jsonData, err := json.Marshal(convertedRequest)
+	jsonData, err := common.Marshal(convertedRequest)
 	if err != nil {
 		return testResult{
 			context:     c,
@@ -387,6 +387,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string,
 	if len(info.ParamOverride) > 0 {
 		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 		if err != nil {
+			if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
+				return testResult{
+					context:     c,
+					localErr:    fixedErr,
+					newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
+				}
+			}
 			return testResult{
 				context:     c,
 				localErr:    err,

+ 5 - 0
controller/relay.go

@@ -182,8 +182,11 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		ModelName:  relayInfo.OriginModelName,
 		Retry:      common.GetPointer(0),
 	}
+	relayInfo.RetryIndex = 0
+	relayInfo.LastError = nil
 
 	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
+		relayInfo.RetryIndex = retryParam.GetRetry()
 		channel, channelErr := getChannel(c, relayInfo, retryParam)
 		if channelErr != nil {
 			logger.LogError(c, channelErr.Error())
@@ -216,10 +219,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 		}
 
 		if newAPIError == nil {
+			relayInfo.LastError = nil
 			return
 		}
 
 		newAPIError = service.NormalizeViolationFeeError(newAPIError)
+		relayInfo.LastError = newAPIError
 
 		processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 

+ 1 - 1
relay/chat_completions_via_responses.go

@@ -84,7 +84,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
 	if len(info.ParamOverride) > 0 {
 		chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
 		if err != nil {
-			return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+			return nil, newAPIErrorFromParamOverride(err)
 		}
 	}
 

+ 1 - 1
relay/claude_handler.go

@@ -155,7 +155,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		if len(info.ParamOverride) > 0 {
 			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+				return newAPIErrorFromParamOverride(err)
 			}
 		}
 

+ 400 - 2
relay/common/override.go

@@ -1,12 +1,15 @@
 package common
 
 import (
+	"errors"
 	"fmt"
+	"net/http"
 	"regexp"
 	"strconv"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/types"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 )
@@ -23,7 +26,7 @@ type ConditionOperation struct {
 
 type ParamOperation struct {
 	Path       string               `json:"path"`
-	Mode       string               `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
+	Mode       string               `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace, return_error, prune_objects
 	Value      interface{}          `json:"value"`
 	KeepOrigin bool                 `json:"keep_origin"`
 	From       string               `json:"from,omitempty"`
@@ -32,6 +35,76 @@ type ParamOperation struct {
 	Logic      string               `json:"logic,omitempty"`      // AND, OR (默认OR)
 }
 
+type ParamOverrideReturnError struct {
+	Message    string
+	StatusCode int
+	Code       string
+	Type       string
+	SkipRetry  bool
+}
+
+func (e *ParamOverrideReturnError) Error() string {
+	if e == nil {
+		return "param override return error"
+	}
+	if e.Message == "" {
+		return "param override return error"
+	}
+	return e.Message
+}
+
+func AsParamOverrideReturnError(err error) (*ParamOverrideReturnError, bool) {
+	if err == nil {
+		return nil, false
+	}
+	var target *ParamOverrideReturnError
+	if errors.As(err, &target) {
+		return target, true
+	}
+	return nil, false
+}
+
+func NewAPIErrorFromParamOverride(err *ParamOverrideReturnError) *types.NewAPIError {
+	if err == nil {
+		return types.NewError(
+			errors.New("param override return error is nil"),
+			types.ErrorCodeChannelParamOverrideInvalid,
+			types.ErrOptionWithSkipRetry(),
+		)
+	}
+
+	statusCode := err.StatusCode
+	if statusCode < http.StatusContinue || statusCode > http.StatusNetworkAuthenticationRequired {
+		statusCode = http.StatusBadRequest
+	}
+
+	errorCode := err.Code
+	if strings.TrimSpace(errorCode) == "" {
+		errorCode = string(types.ErrorCodeInvalidRequest)
+	}
+
+	errorType := err.Type
+	if strings.TrimSpace(errorType) == "" {
+		errorType = "invalid_request_error"
+	}
+
+	message := strings.TrimSpace(err.Message)
+	if message == "" {
+		message = "request blocked by param override"
+	}
+
+	opts := make([]types.NewAPIErrorOptions, 0, 1)
+	if err.SkipRetry {
+		opts = append(opts, types.ErrOptionWithSkipRetry())
+	}
+
+	return types.WithOpenAIError(types.OpenAIError{
+		Message: message,
+		Type:    errorType,
+		Code:    errorCode,
+	}, statusCode, opts...)
+}
+
 func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
 	if len(paramOverride) == 0 {
 		return jsonData, nil
@@ -372,16 +445,104 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
 			result, err = replaceStringValue(result, opPath, op.From, op.To)
 		case "regex_replace":
 			result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
+		case "return_error":
+			returnErr, parseErr := parseParamOverrideReturnError(op.Value)
+			if parseErr != nil {
+				return "", parseErr
+			}
+			return "", returnErr
+		case "prune_objects":
+			result, err = pruneObjects(result, opPath, contextJSON, op.Value)
 		default:
 			return "", fmt.Errorf("unknown operation: %s", op.Mode)
 		}
 		if err != nil {
-			return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
+			return "", fmt.Errorf("operation %s failed: %w", op.Mode, err)
+		}
+	}
+	return result, nil
+}
+
+func parseParamOverrideReturnError(value interface{}) (*ParamOverrideReturnError, error) {
+	result := &ParamOverrideReturnError{
+		StatusCode: http.StatusBadRequest,
+		Code:       string(types.ErrorCodeInvalidRequest),
+		Type:       "invalid_request_error",
+		SkipRetry:  true,
+	}
+
+	switch raw := value.(type) {
+	case nil:
+		return nil, fmt.Errorf("return_error value is required")
+	case string:
+		result.Message = strings.TrimSpace(raw)
+	case map[string]interface{}:
+		if message, ok := raw["message"].(string); ok {
+			result.Message = strings.TrimSpace(message)
+		}
+		if result.Message == "" {
+			if message, ok := raw["msg"].(string); ok {
+				result.Message = strings.TrimSpace(message)
+			}
+		}
+
+		if code, exists := raw["code"]; exists {
+			codeStr := strings.TrimSpace(fmt.Sprintf("%v", code))
+			if codeStr != "" {
+				result.Code = codeStr
+			}
+		}
+		if errType, ok := raw["type"].(string); ok {
+			errType = strings.TrimSpace(errType)
+			if errType != "" {
+				result.Type = errType
+			}
+		}
+		if skipRetry, ok := raw["skip_retry"].(bool); ok {
+			result.SkipRetry = skipRetry
 		}
+
+		if statusCodeRaw, exists := raw["status_code"]; exists {
+			statusCode, ok := parseOverrideInt(statusCodeRaw)
+			if !ok {
+				return nil, fmt.Errorf("return_error status_code must be an integer")
+			}
+			result.StatusCode = statusCode
+		} else if statusRaw, exists := raw["status"]; exists {
+			statusCode, ok := parseOverrideInt(statusRaw)
+			if !ok {
+				return nil, fmt.Errorf("return_error status must be an integer")
+			}
+			result.StatusCode = statusCode
+		}
+	default:
+		return nil, fmt.Errorf("return_error value must be string or object")
 	}
+
+	if result.Message == "" {
+		return nil, fmt.Errorf("return_error message is required")
+	}
+	if result.StatusCode < http.StatusContinue || result.StatusCode > http.StatusNetworkAuthenticationRequired {
+		return nil, fmt.Errorf("return_error status code out of range: %d", result.StatusCode)
+	}
+
 	return result, nil
 }
 
+func parseOverrideInt(v interface{}) (int, bool) {
+	switch value := v.(type) {
+	case int:
+		return value, true
+	case float64:
+		if value != float64(int(value)) {
+			return 0, false
+		}
+		return int(value), true
+	default:
+		return 0, false
+	}
+}
+
 func moveValue(jsonStr, fromPath, toPath string) (string, error) {
 	sourceValue := gjson.Get(jsonStr, fromPath)
 	if !sourceValue.Exists() {
@@ -537,6 +698,217 @@ func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string
 	return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
 }
 
+type pruneObjectsOptions struct {
+	conditions []ConditionOperation
+	logic      string
+	recursive  bool
+}
+
+func pruneObjects(jsonStr, path, contextJSON string, value interface{}) (string, error) {
+	options, err := parsePruneObjectsOptions(value)
+	if err != nil {
+		return "", err
+	}
+
+	if path == "" {
+		var root interface{}
+		if err := common.Unmarshal([]byte(jsonStr), &root); err != nil {
+			return "", err
+		}
+		cleaned, _, err := pruneObjectsNode(root, options, contextJSON, true)
+		if err != nil {
+			return "", err
+		}
+		cleanedBytes, err := common.Marshal(cleaned)
+		if err != nil {
+			return "", err
+		}
+		return string(cleanedBytes), nil
+	}
+
+	target := gjson.Get(jsonStr, path)
+	if !target.Exists() {
+		return jsonStr, nil
+	}
+
+	var targetNode interface{}
+	if target.Type == gjson.JSON {
+		if err := common.Unmarshal([]byte(target.Raw), &targetNode); err != nil {
+			return "", err
+		}
+	} else {
+		targetNode = target.Value()
+	}
+
+	cleaned, _, err := pruneObjectsNode(targetNode, options, contextJSON, true)
+	if err != nil {
+		return "", err
+	}
+	cleanedBytes, err := common.Marshal(cleaned)
+	if err != nil {
+		return "", err
+	}
+	return sjson.SetRaw(jsonStr, path, string(cleanedBytes))
+}
+
+func parsePruneObjectsOptions(value interface{}) (pruneObjectsOptions, error) {
+	opts := pruneObjectsOptions{
+		logic:     "AND",
+		recursive: true,
+	}
+
+	switch raw := value.(type) {
+	case nil:
+		return opts, fmt.Errorf("prune_objects value is required")
+	case string:
+		v := strings.TrimSpace(raw)
+		if v == "" {
+			return opts, fmt.Errorf("prune_objects value is required")
+		}
+		opts.conditions = []ConditionOperation{
+			{
+				Path:  "type",
+				Mode:  "full",
+				Value: v,
+			},
+		}
+	case map[string]interface{}:
+		if logic, ok := raw["logic"].(string); ok && strings.TrimSpace(logic) != "" {
+			opts.logic = logic
+		}
+		if recursive, ok := raw["recursive"].(bool); ok {
+			opts.recursive = recursive
+		}
+
+		if condRaw, exists := raw["conditions"]; exists {
+			conditions, err := parseConditionOperations(condRaw)
+			if err != nil {
+				return opts, err
+			}
+			opts.conditions = append(opts.conditions, conditions...)
+		}
+
+		if whereRaw, exists := raw["where"]; exists {
+			whereMap, ok := whereRaw.(map[string]interface{})
+			if !ok {
+				return opts, fmt.Errorf("prune_objects where must be object")
+			}
+			for key, val := range whereMap {
+				key = strings.TrimSpace(key)
+				if key == "" {
+					continue
+				}
+				opts.conditions = append(opts.conditions, ConditionOperation{
+					Path:  key,
+					Mode:  "full",
+					Value: val,
+				})
+			}
+		}
+
+		if matchType, exists := raw["type"]; exists {
+			opts.conditions = append(opts.conditions, ConditionOperation{
+				Path:  "type",
+				Mode:  "full",
+				Value: matchType,
+			})
+		}
+	default:
+		return opts, fmt.Errorf("prune_objects value must be string or object")
+	}
+
+	if len(opts.conditions) == 0 {
+		return opts, fmt.Errorf("prune_objects conditions are required")
+	}
+	return opts, nil
+}
+
+func parseConditionOperations(raw interface{}) ([]ConditionOperation, error) {
+	items, ok := raw.([]interface{})
+	if !ok {
+		return nil, fmt.Errorf("conditions must be an array")
+	}
+
+	result := make([]ConditionOperation, 0, len(items))
+	for _, item := range items {
+		itemMap, ok := item.(map[string]interface{})
+		if !ok {
+			return nil, fmt.Errorf("condition must be object")
+		}
+		path, _ := itemMap["path"].(string)
+		mode, _ := itemMap["mode"].(string)
+		if strings.TrimSpace(path) == "" || strings.TrimSpace(mode) == "" {
+			return nil, fmt.Errorf("condition path/mode is required")
+		}
+		condition := ConditionOperation{
+			Path: path,
+			Mode: mode,
+		}
+		if value, exists := itemMap["value"]; exists {
+			condition.Value = value
+		}
+		if invert, ok := itemMap["invert"].(bool); ok {
+			condition.Invert = invert
+		}
+		if passMissingKey, ok := itemMap["pass_missing_key"].(bool); ok {
+			condition.PassMissingKey = passMissingKey
+		}
+		result = append(result, condition)
+	}
+	return result, nil
+}
+
+func pruneObjectsNode(node interface{}, options pruneObjectsOptions, contextJSON string, isRoot bool) (interface{}, bool, error) {
+	switch value := node.(type) {
+	case []interface{}:
+		result := make([]interface{}, 0, len(value))
+		for _, item := range value {
+			next, drop, err := pruneObjectsNode(item, options, contextJSON, false)
+			if err != nil {
+				return nil, false, err
+			}
+			if drop {
+				continue
+			}
+			result = append(result, next)
+		}
+		return result, false, nil
+	case map[string]interface{}:
+		shouldDrop, err := shouldPruneObject(value, options, contextJSON)
+		if err != nil {
+			return nil, false, err
+		}
+		if shouldDrop && !isRoot {
+			return nil, true, nil
+		}
+		if !options.recursive {
+			return value, false, nil
+		}
+		for key, child := range value {
+			next, drop, err := pruneObjectsNode(child, options, contextJSON, false)
+			if err != nil {
+				return nil, false, err
+			}
+			if drop {
+				delete(value, key)
+				continue
+			}
+			value[key] = next
+		}
+		return value, false, nil
+	default:
+		return node, false, nil
+	}
+}
+
+func shouldPruneObject(node map[string]interface{}, options pruneObjectsOptions, contextJSON string) (bool, error) {
+	nodeBytes, err := common.Marshal(node)
+	if err != nil {
+		return false, err
+	}
+	return checkConditions(string(nodeBytes), contextJSON, options.conditions, options.logic)
+}
+
 func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
 	current := gjson.Get(jsonStr, path)
 	var currentMap, newMap map[string]interface{}
@@ -598,6 +970,32 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
 		}
 	}
 
+	ctx["retry_index"] = info.RetryIndex
+	ctx["is_retry"] = info.RetryIndex > 0
+	ctx["retry"] = map[string]interface{}{
+		"index":    info.RetryIndex,
+		"is_retry": info.RetryIndex > 0,
+	}
+
+	if info.LastError != nil {
+		code := string(info.LastError.GetErrorCode())
+		errorType := string(info.LastError.GetErrorType())
+		lastError := map[string]interface{}{
+			"status_code": info.LastError.StatusCode,
+			"message":     info.LastError.Error(),
+			"code":        code,
+			"error_code":  code,
+			"type":        errorType,
+			"error_type":  errorType,
+			"skip_retry":  types.IsSkipRetryError(info.LastError),
+		}
+		ctx["last_error"] = lastError
+		ctx["last_error_status_code"] = info.LastError.StatusCode
+		ctx["last_error_message"] = info.LastError.Error()
+		ctx["last_error_code"] = code
+		ctx["last_error_type"] = errorType
+	}
+
 	ctx["is_channel_test"] = info.IsChannelTest
 	return ctx
 }

+ 184 - 0
relay/common/override_test.go

@@ -4,6 +4,8 @@ import (
 	"encoding/json"
 	"reflect"
 	"testing"
+
+	"github.com/QuantumNous/new-api/types"
 )
 
 func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -772,6 +774,188 @@ func TestApplyParamOverrideToUpper(t *testing.T) {
 	assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
 }
 
+func TestApplyParamOverrideReturnError(t *testing.T) {
+	input := []byte(`{"model":"gemini-2.5-pro"}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode": "return_error",
+				"value": map[string]interface{}{
+					"message":     "forced bad request by param override",
+					"status_code": 422,
+					"code":        "forced_bad_request",
+					"type":        "invalid_request_error",
+					"skip_retry":  true,
+				},
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"path":  "retry.is_retry",
+						"mode":  "full",
+						"value": true,
+					},
+				},
+			},
+		},
+	}
+	ctx := map[string]interface{}{
+		"retry": map[string]interface{}{
+			"index":    1,
+			"is_retry": true,
+		},
+	}
+
+	_, err := ApplyParamOverride(input, override, ctx)
+	if err == nil {
+		t.Fatalf("expected error, got nil")
+	}
+	returnErr, ok := AsParamOverrideReturnError(err)
+	if !ok {
+		t.Fatalf("expected ParamOverrideReturnError, got %T: %v", err, err)
+	}
+	if returnErr.StatusCode != 422 {
+		t.Fatalf("expected status 422, got %d", returnErr.StatusCode)
+	}
+	if returnErr.Code != "forced_bad_request" {
+		t.Fatalf("expected code forced_bad_request, got %s", returnErr.Code)
+	}
+	if !returnErr.SkipRetry {
+		t.Fatalf("expected skip_retry true")
+	}
+}
+
+func TestApplyParamOverridePruneObjectsByTypeString(t *testing.T) {
+	input := []byte(`{
+		"messages":[
+			{"role":"assistant","content":[
+				{"type":"output_text","text":"a"},
+				{"type":"redacted_thinking","text":"secret"},
+				{"type":"tool_call","name":"tool_a"}
+			]},
+			{"role":"assistant","content":[
+				{"type":"output_text","text":"b"},
+				{"type":"wrapper","parts":[
+					{"type":"redacted_thinking","text":"secret2"},
+					{"type":"output_text","text":"c"}
+				]}
+			]}
+		]
+	}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode":  "prune_objects",
+				"value": "redacted_thinking",
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, nil)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{
+		"messages":[
+			{"role":"assistant","content":[
+				{"type":"output_text","text":"a"},
+				{"type":"tool_call","name":"tool_a"}
+			]},
+			{"role":"assistant","content":[
+				{"type":"output_text","text":"b"},
+				{"type":"wrapper","parts":[
+					{"type":"output_text","text":"c"}
+				]}
+			]}
+		]
+	}`, string(out))
+}
+
+func TestApplyParamOverridePruneObjectsWhereAndPath(t *testing.T) {
+	input := []byte(`{
+		"a":{"items":[{"type":"redacted_thinking","id":1},{"type":"output_text","id":2}]},
+		"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
+	}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path": "a",
+				"mode": "prune_objects",
+				"value": map[string]interface{}{
+					"where": map[string]interface{}{
+						"type": "redacted_thinking",
+					},
+				},
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, nil)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{
+		"a":{"items":[{"type":"output_text","id":2}]},
+		"b":{"items":[{"type":"redacted_thinking","id":3},{"type":"output_text","id":4}]}
+	}`, string(out))
+}
+
+func TestApplyParamOverrideNormalizeThinkingSignatureUnsupported(t *testing.T) {
+	input := []byte(`{"items":[{"type":"redacted_thinking"}]}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"mode": "normalize_thinking_signature",
+			},
+		},
+	}
+
+	_, err := ApplyParamOverride(input, override, nil)
+	if err == nil {
+		t.Fatalf("expected error, got nil")
+	}
+}
+
+func TestApplyParamOverrideConditionFromRetryAndLastErrorContext(t *testing.T) {
+	info := &RelayInfo{
+		RetryIndex: 1,
+		LastError: types.WithOpenAIError(types.OpenAIError{
+			Message: "invalid thinking signature",
+			Type:    "invalid_request_error",
+			Code:    "bad_thought_signature",
+		}, 400),
+	}
+	ctx := BuildParamOverrideContext(info)
+
+	input := []byte(`{"temperature":0.7}`)
+	override := map[string]interface{}{
+		"operations": []interface{}{
+			map[string]interface{}{
+				"path":  "temperature",
+				"mode":  "set",
+				"value": 0.1,
+				"logic": "AND",
+				"conditions": []interface{}{
+					map[string]interface{}{
+						"path":  "is_retry",
+						"mode":  "full",
+						"value": true,
+					},
+					map[string]interface{}{
+						"path":  "last_error.code",
+						"mode":  "contains",
+						"value": "thought_signature",
+					},
+				},
+			},
+		},
+	}
+
+	out, err := ApplyParamOverride(input, override, ctx)
+	if err != nil {
+		t.Fatalf("ApplyParamOverride returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"temperature":0.1}`, string(out))
+}
+
 func assertJSONEqual(t *testing.T, want, got string) {
 	t.Helper()
 

+ 2 - 0
relay/common/relay_info.go

@@ -140,6 +140,8 @@ type RelayInfo struct {
 	SubscriptionAmountUsedAfterPreConsume int64
 	IsClaudeBetaQuery                     bool // /v1/messages?beta=true
 	IsChannelTest                         bool // channel test request
+	RetryIndex                            int
+	LastError                             *types.NewAPIError
 
 	PriceData types.PriceData
 

+ 1 - 1
relay/compatible_handler.go

@@ -174,7 +174,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		if len(info.ParamOverride) > 0 {
 			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+				return newAPIErrorFromParamOverride(err)
 			}
 		}
 

+ 2 - 3
relay/embedding_handler.go

@@ -2,7 +2,6 @@ package relay
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"net/http"
 
@@ -46,7 +45,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 	relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
-	jsonData, err := json.Marshal(convertedRequest)
+	jsonData, err := common.Marshal(convertedRequest)
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
@@ -54,7 +53,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	if len(info.ParamOverride) > 0 {
 		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+			return newAPIErrorFromParamOverride(err)
 		}
 	}
 

+ 3 - 8
relay/gemini_handler.go

@@ -159,7 +159,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		if len(info.ParamOverride) > 0 {
 			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+				return newAPIErrorFromParamOverride(err)
 			}
 		}
 
@@ -257,14 +257,9 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 
 	// apply param override
 	if len(info.ParamOverride) > 0 {
-		reqMap := make(map[string]interface{})
-		_ = common.Unmarshal(jsonData, &reqMap)
-		for key, value := range info.ParamOverride {
-			reqMap[key] = value
-		}
-		jsonData, err = common.Marshal(reqMap)
+		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+			return newAPIErrorFromParamOverride(err)
 		}
 	}
 	logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))

+ 1 - 1
relay/image_handler.go

@@ -72,7 +72,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 			if len(info.ParamOverride) > 0 {
 				jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 				if err != nil {
-					return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+					return newAPIErrorFromParamOverride(err)
 				}
 			}
 

+ 13 - 0
relay/param_override_error.go

@@ -0,0 +1,13 @@
+package relay
+
+import (
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+)
+
+func newAPIErrorFromParamOverride(err error) *types.NewAPIError {
+	if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
+		return relaycommon.NewAPIErrorFromParamOverride(fixedErr)
+	}
+	return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+}

+ 1 - 1
relay/rerank_handler.go

@@ -63,7 +63,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		if len(info.ParamOverride) > 0 {
 			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+				return newAPIErrorFromParamOverride(err)
 			}
 		}
 

+ 1 - 1
relay/responses_handler.go

@@ -98,7 +98,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		if len(info.ParamOverride) > 0 {
 			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
-				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+				return newAPIErrorFromParamOverride(err)
 			}
 		}