Просмотр исходного кода

feat(gemini): 支持 tool_choice 参数转换,优化错误处理

Your Name 2 месяца назад
Родитель
Commit
b6a25d9f0f
1 изменённый файл: 121 добавление и 7 удалений
  1. 121 7
      relay/channel/gemini/relay-gemini.go

+ 121 - 7
relay/channel/gemini/relay-gemini.go

@@ -356,6 +356,13 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
 			})
 		}
 		geminiRequest.SetTools(geminiTools)
+
+		// [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
+		// Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
+		// Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
+		if textRequest.ToolChoice != nil {
+			geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
+		}
 	}
 
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -960,6 +967,24 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse)
 				choice.FinishReason = constant.FinishReasonStop
 			case "MAX_TOKENS":
 				choice.FinishReason = constant.FinishReasonLength
+			case "SAFETY":
+				// Safety filter triggered
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "RECITATION":
+				// Recitation (citation) detected
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "BLOCKLIST":
+				// Blocklist triggered
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "PROHIBITED_CONTENT":
+				// Prohibited content detected
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "SPII":
+				// Sensitive personally identifiable information
+				choice.FinishReason = constant.FinishReasonContentFilter
+			case "OTHER":
+				// Other reasons
+				choice.FinishReason = constant.FinishReasonContentFilter
 			default:
 				choice.FinishReason = constant.FinishReasonContentFilter
 			}
@@ -991,13 +1016,34 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
 		isTools := false
 		isThought := false
 		if candidate.FinishReason != nil {
-			// p := GeminiConvertFinishReason(*candidate.FinishReason)
+			// Map Gemini FinishReason to OpenAI finish_reason
 			switch *candidate.FinishReason {
 			case "STOP":
+				// Normal completion
 				choice.FinishReason = &constant.FinishReasonStop
 			case "MAX_TOKENS":
+				// Reached maximum token limit
 				choice.FinishReason = &constant.FinishReasonLength
+			case "SAFETY":
+				// Safety filter triggered
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "RECITATION":
+				// Recitation (citation) detected
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "BLOCKLIST":
+				// Blocklist triggered
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "PROHIBITED_CONTENT":
+				// Prohibited content detected
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "SPII":
+				// Sensitive personally identifiable information
+				choice.FinishReason = &constant.FinishReasonContentFilter
+			case "OTHER":
+				// Other reasons
+				choice.FinishReason = &constant.FinishReasonContentFilter
 			default:
+				// Unknown reason, treat as content filter
 				choice.FinishReason = &constant.FinishReasonContentFilter
 			}
 		}
@@ -1214,12 +1260,20 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-		//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
-		//	return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
-		//} else {
-		//	return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
-		//}
+		// [FIX] Return meaningful error when Candidates is empty
+		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
+			return nil, types.NewOpenAIError(
+				errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
+				types.ErrorCodePromptBlocked,
+				http.StatusBadRequest,
+			)
+		} else {
+			return nil, types.NewOpenAIError(
+				errors.New("empty response from Gemini API"),
+				types.ErrorCodeEmptyResponse,
+				http.StatusInternalServerError,
+			)
+		}
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
@@ -1362,3 +1416,63 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 
 	return usage, nil
 }
+
+// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
+// OpenAI tool_choice values:
+//   - "auto": Let the model decide (default)
+//   - "none": Don't call any tools
+//   - "required": Must call at least one tool
+//   - {"type": "function", "function": {"name": "xxx"}}: Call specific function
+//
+// Gemini functionCallingConfig.mode values:
+//   - "AUTO": Model decides whether to call functions
+//   - "NONE": Model won't call functions
+//   - "ANY": Model must call at least one function
+func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
+	if toolChoice == nil {
+		return nil
+	}
+
+	// Handle string values: "auto", "none", "required"
+	if toolChoiceStr, ok := toolChoice.(string); ok {
+		config := &dto.ToolConfig{
+			FunctionCallingConfig: &dto.FunctionCallingConfig{},
+		}
+		switch toolChoiceStr {
+		case "auto":
+			config.FunctionCallingConfig.Mode = "AUTO"
+		case "none":
+			config.FunctionCallingConfig.Mode = "NONE"
+		case "required":
+			config.FunctionCallingConfig.Mode = "ANY"
+		default:
+			// Unknown string value, default to AUTO
+			config.FunctionCallingConfig.Mode = "AUTO"
+		}
+		return config
+	}
+
+	// Handle object value: {"type": "function", "function": {"name": "xxx"}}
+	if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+		if toolChoiceMap["type"] == "function" {
+			config := &dto.ToolConfig{
+				FunctionCallingConfig: &dto.FunctionCallingConfig{
+					Mode: "ANY",
+				},
+			}
+			// Extract function name if specified
+			if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+				if name, ok := function["name"].(string); ok && name != "" {
+					config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
+				}
+			}
+			return config
+		}
+		// Unsupported map structure (type is not "function"), return nil
+		return nil
+	}
+
+	// Unsupported type, return nil
+	return nil
+}
+