Browse Source

Merge pull request #2742 from seefs001/fix/pr-2540

feat(gemini): 支持 tool_choice 参数转换,优化多个渠道错误处理
Calcium-Ion 1 month ago
parent
commit
3722c63c18

+ 4 - 0
constant/context_key.go

@@ -55,4 +55,8 @@ const (
 	ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
 
 	ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
+
+	// ContextKeyAdminRejectReason stores an admin-only reject/block reason extracted from upstream responses.
+	// It is not returned to end users, but can be persisted into consume/error logs for debugging.
+	ContextKeyAdminRejectReason ContextKey = "admin_reject_reason"
 )

+ 1 - 0
model/log.go

@@ -59,6 +59,7 @@ func formatUserLogs(logs []*Log) {
 			// Remove admin-only debug fields.
 			delete(otherMap, "admin_info")
 			delete(otherMap, "request_conversion")
+			delete(otherMap, "reject_reason")
 		}
 		logs[i].Other = common.MapToJsonStr(otherMap)
 		logs[i].Id = logs[i].Id % 1024

+ 18 - 11
relay/channel/claude/relay-claude.go

@@ -8,11 +8,13 @@ import (
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/relay/channel/openrouter"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/relay/helper"
+	"github.com/QuantumNous/new-api/relay/reasonmap"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/model_setting"
 	"github.com/QuantumNous/new-api/types"
@@ -27,17 +29,15 @@ const (
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
-	switch reason {
-	case "stop_sequence":
-		return "stop"
-	case "end_turn":
-		return "stop"
-	case "max_tokens":
-		return "length"
-	case "tool_use":
-		return "tool_calls"
-	default:
-		return reason
+	return reasonmap.ClaudeStopReasonToOpenAIFinishReason(reason)
+}
+
+// maybeMarkClaudeRefusal records an admin-only reject reason on the request
+// context when Claude reports the stop reason "refusal" (compared
+// case-insensitively). It does nothing when c is nil or for any other reason.
+func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
+	if c != nil && strings.EqualFold(stopReason, "refusal") {
+		common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "claude_stop_reason=refusal")
 	}
 }
 
@@ -644,6 +644,12 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
 		return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
 	}
+	if claudeResponse.StopReason != "" {
+		maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
+	}
+	if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
+		maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
+	}
 	if info.RelayFormat == types.RelayFormatClaude {
 		FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
 
@@ -735,6 +741,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
 		return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
 	}
+	maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
 	if requestMode == RequestModeCompletion {
 		claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
 	} else {

+ 6 - 0
relay/channel/gemini/relay-gemini-native.go

@@ -1,10 +1,12 @@
 package gemini
 
 import (
+	"fmt"
 	"io"
 	"net/http"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/logger"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -35,6 +37,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
+	if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
+		common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
+	}
+
 	// 计算使用量(基于 UsageMetadata)
 	usage := dto.Usage{
 		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,

+ 156 - 7
relay/channel/gemini/relay-gemini.go

@@ -359,6 +359,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
 			})
 		}
 		geminiRequest.SetTools(geminiTools)
+
+		// Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig.
+		// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
+		// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
+		if textRequest.ToolChoice != nil {
+			geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
+		}
 	}
 
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -1031,6 +1038,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
 				choice.FinishReason = constant.FinishReasonStop
 			case "MAX_TOKENS":
 				choice.FinishReason = constant.FinishReasonLength
+			case "SAFETY":
+				// Safety filter triggered
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "RECITATION":
+				// Recitation (citation) detected
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "BLOCKLIST":
+				// Blocklist triggered
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "PROHIBITED_CONTENT":
+				// Prohibited content detected
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "SPII":
+				// Sensitive personally identifiable information
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "OTHER":
+				// Other reasons
+				choice.FinishReason = constant.FinishReasonContentFilter
 			default:
 				choice.FinishReason = constant.FinishReasonContentFilter
 			}
@@ -1062,13 +1087,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
 		isTools := false
 		isThought := false
 		if candidate.FinishReason != nil {
-			// p := GeminiConvertFinishReason(*candidate.FinishReason)
+			// Map Gemini FinishReason to OpenAI finish_reason
 			switch *candidate.FinishReason {
 			case "STOP":
+				// Normal completion
 				choice.FinishReason = &constant.FinishReasonStop
 			case "MAX_TOKENS":
+				// Reached maximum token limit
 				choice.FinishReason = &constant.FinishReasonLength
+			case "SAFETY":
+				// Safety filter triggered
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "RECITATION":
+				// Recitation (citation) detected
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "BLOCKLIST":
+				// Blocklist triggered
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "PROHIBITED_CONTENT":
+				// Prohibited content detected
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "SPII":
+				// Sensitive personally identifiable information
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "OTHER":
+				// Other reasons
+				choice.FinishReason = &constant.FinishReasonContentFilter
 			default:
+				// Unknown reason, treat as content filter
 				choice.FinishReason = &constant.FinishReasonContentFilter
 			}
 		}
@@ -1151,6 +1197,10 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 			return false
 		}
 
+		if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
+			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
+		}
+
 		// 统计图片数量
 		for _, candidate := range geminiResponse.Candidates {
 			for _, part := range candidate.Content.Parts {
@@ -1309,12 +1359,52 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-		//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
-		//	return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
-		//} else {
-		//	return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
-		//}
+		usage := dto.Usage{
+			PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
+		}
+		usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+		for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+			if detail.Modality == "AUDIO" {
+				usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+			} else if detail.Modality == "TEXT" {
+				usage.PromptTokensDetails.TextTokens = detail.TokenCount
+			}
+		}
+		if usage.PromptTokens <= 0 {
+			usage.PromptTokens = info.GetEstimatePromptTokens()
+		}
+
+		var newAPIError *types.NewAPIError
+		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
+			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
+			newAPIError = types.NewOpenAIError(
+				errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
+				types.ErrorCodePromptBlocked,
+				http.StatusBadRequest,
+			)
+		} else {
+			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
+			newAPIError = types.NewOpenAIError(
+				errors.New("empty response from Gemini API"),
+				types.ErrorCodeEmptyResponse,
+				http.StatusInternalServerError,
+			)
+		}
+
+		service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
+
+		switch info.RelayFormat {
+		case types.RelayFormatClaude:
+			c.JSON(newAPIError.StatusCode, gin.H{
+				"type":  "error",
+				"error": newAPIError.ToClaudeError(),
+			})
+		default:
+			c.JSON(newAPIError.StatusCode, gin.H{
+				"error": newAPIError.ToOpenAIError(),
+			})
+		}
+		return &usage, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
@@ -1530,3 +1620,62 @@ func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
 
 	return allModels, nil
 }
+
+// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
+// OpenAI tool_choice values:
+//   - "auto": Let the model decide (default)
+//   - "none": Don't call any tools
+//   - "required": Must call at least one tool
+//   - {"type": "function", "function": {"name": "xxx"}}: Call specific function
+//
+// Gemini functionCallingConfig.mode values:
+//   - "AUTO": Model decides whether to call functions
+//   - "NONE": Model won't call functions
+//   - "ANY": Model must call at least one function
+func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
+	if toolChoice == nil {
+		return nil
+	}
+
+	// Handle string values: "auto", "none", "required"
+	if toolChoiceStr, ok := toolChoice.(string); ok {
+		config := &dto.ToolConfig{
+			FunctionCallingConfig: &dto.FunctionCallingConfig{},
+		}
+		switch toolChoiceStr {
+		case "auto":
+			config.FunctionCallingConfig.Mode = "AUTO"
+		case "none":
+			config.FunctionCallingConfig.Mode = "NONE"
+		case "required":
+			config.FunctionCallingConfig.Mode = "ANY"
+		default:
+			// Unknown string value, default to AUTO
+			config.FunctionCallingConfig.Mode = "AUTO"
+		}
+		return config
+	}
+
+	// Handle object value: {"type": "function", "function": {"name": "xxx"}}
+	if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+		if toolChoiceMap["type"] == "function" {
+			config := &dto.ToolConfig{
+				FunctionCallingConfig: &dto.FunctionCallingConfig{
+					Mode: "ANY",
+				},
+			}
+			// Extract function name if specified
+			if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+				if name, ok := function["name"].(string); ok && name != "" {
+					config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
+				}
+			}
+			return config
+		}
+		// Unsupported map structure (type is not "function"), return nil
+		return nil
+	}
+
+	// Unsupported type, return nil
+	return nil
+}

+ 7 - 0
relay/channel/openai/relay-openai.go

@@ -229,6 +229,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}
 
+	for _, choice := range simpleResponse.Choices {
+		if choice.FinishReason == constant.FinishReasonContentFilter {
+			common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "openai_finish_reason=content_filter")
+			break
+		}
+	}
+
 	forceFormat := false
 	if info.ChannelSetting.ForceFormat {
 		forceFormat = true

+ 6 - 0
relay/compatible_handler.go

@@ -237,6 +237,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		}
 		extraContent = append(extraContent, "上游无计费信息")
 	}
+
+	adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
+
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -461,6 +464,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	}
 	logContent := strings.Join(extraContent, ", ")
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+	if adminRejectReason != "" {
+		other["reject_reason"] = adminRejectReason
+	}
 	// For chat-based calls to the Claude model, tagging is required. Using Claude's rendering logs, the two approaches handle input rendering differently.
 	if isClaudeUsageSemantic {
 		other["claude"] = true

+ 41 - 0
relay/reasonmap/reasonmap.go

@@ -0,0 +1,41 @@
+package reasonmap
+
+import (
+	"strings"
+
+	"github.com/QuantumNous/new-api/constant"
+)
+
+// ClaudeStopReasonToOpenAIFinishReason maps a Claude stop_reason (matched after
+// lowercasing) to an OpenAI finish_reason; other values are returned unchanged.
+func ClaudeStopReasonToOpenAIFinishReason(stopReason string) string {
+	switch strings.ToLower(stopReason) {
+	case "stop_sequence", "end_turn":
+		return "stop"
+	case "max_tokens":
+		return "length"
+	case "tool_use":
+		return "tool_calls"
+	case "refusal":
+		return constant.FinishReasonContentFilter
+	default:
+		return stopReason
+	}
+}
+
+// OpenAIFinishReasonToClaudeStopReason maps an OpenAI finish_reason (matched
+// after lowercasing) to a Claude stop_reason. Values with no mapping are
+// returned exactly as they were passed in.
+func OpenAIFinishReasonToClaudeStopReason(finishReason string) string {
+	normalized := strings.ToLower(finishReason)
+	switch normalized {
+	case "stop":
+		return "end_turn"
+	case "stop_sequence":
+		return "stop_sequence"
+	case "length", "max_tokens":
+		return "max_tokens"
+	case constant.FinishReasonContentFilter:
+		return "refusal"
+	case "tool_calls":
+		return "tool_use"
+	}
+	return finishReason
+}

+ 2 - 14
service/convert.go

@@ -10,6 +10,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel/openrouter"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/relay/reasonmap"
 )
 
 func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
@@ -540,20 +541,7 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco
 }
 
 func stopReasonOpenAI2Claude(reason string) string {
-	switch reason {
-	case "stop":
-		return "end_turn"
-	case "stop_sequence":
-		return "stop_sequence"
-	case "length":
-		fallthrough
-	case "max_tokens":
-		return "max_tokens"
-	case "tool_calls":
-		return "tool_use"
-	default:
-		return reason
-	}
+	return reasonmap.OpenAIFinishReasonToClaudeStopReason(reason)
 }
 
 func toJSONString(v interface{}) string {

+ 6 - 0
web/src/hooks/usage-logs/useUsageLogsData.jsx

@@ -397,6 +397,12 @@ export const useLogsData = () => {
             value: logs[i].content,
           });
         }
+        if (isAdminUser && other?.reject_reason) {
+          expandDataLocal.push({
+            key: t('拦截原因'),
+            value: other.reject_reason,
+          });
+        }
       }
       if (logs[i].type === 2) {
         let modelMapped =