Просмотр исходного кода

feat: support for Gemini structured output.

MartialBE 1 год назад
Родитель
Commit
43a7b59b68
3 изменённых файла: 96 добавлений и 36 удалений
  1. 38 30
      dto/openai_request.go
  2. 8 6
      relay/channel/gemini/dto.go
  3. 50 0
      relay/channel/gemini/relay-gemini.go

+ 38 - 30
dto/openai_request.go

@@ -3,39 +3,47 @@ package dto
 import "encoding/json"
 
 type ResponseFormat struct {
-	Type string `json:"type,omitempty"`
+	// Type is the OpenAI response format kind, e.g. "json_object" or "json_schema".
+	Type       string            `json:"type,omitempty"`
+	// JsonSchema carries the schema payload used when Type is "json_schema".
+	JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
+}
+
+// FormatJsonSchema mirrors the OpenAI structured-output "json_schema" object.
+type FormatJsonSchema struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	// Schema holds the raw JSON schema, decoded generically so it can be
+	// post-processed (see the Gemini additionalProperties cleanup).
+	Schema      any    `json:"schema,omitempty"`
+	// Strict is kept as any — presumably a bool; NOTE(review): confirm against callers.
+	Strict      any    `json:"strict,omitempty"`
 }
 
 type GeneralOpenAIRequest struct {
-	Model               string         `json:"model,omitempty"`
-	Messages            []Message      `json:"messages,omitempty"`
-	Prompt              any            `json:"prompt,omitempty"`
-	Stream              bool           `json:"stream,omitempty"`
-	StreamOptions       *StreamOptions `json:"stream_options,omitempty"`
-	MaxTokens           uint           `json:"max_tokens,omitempty"`
-	MaxCompletionTokens uint           `json:"max_completion_tokens,omitempty"`
-	Temperature         float64        `json:"temperature,omitempty"`
-	TopP                float64        `json:"top_p,omitempty"`
-	TopK                int            `json:"top_k,omitempty"`
-	Stop                any            `json:"stop,omitempty"`
-	N                   int            `json:"n,omitempty"`
-	Input               any            `json:"input,omitempty"`
-	Instruction         string         `json:"instruction,omitempty"`
-	Size                string         `json:"size,omitempty"`
-	Functions           any            `json:"functions,omitempty"`
-	FrequencyPenalty    float64        `json:"frequency_penalty,omitempty"`
-	PresencePenalty     float64        `json:"presence_penalty,omitempty"`
-	ResponseFormat      any            `json:"response_format,omitempty"`
-	EncodingFormat      any            `json:"encoding_format,omitempty"`
-	Seed                float64        `json:"seed,omitempty"`
-	Tools               []ToolCall     `json:"tools,omitempty"`
-	ToolChoice          any            `json:"tool_choice,omitempty"`
-	User                string         `json:"user,omitempty"`
-	LogProbs            bool           `json:"logprobs,omitempty"`
-	TopLogProbs         int            `json:"top_logprobs,omitempty"`
-	Dimensions          int            `json:"dimensions,omitempty"`
-	Modalities          any            `json:"modalities,omitempty"`
-	Audio               any            `json:"audio,omitempty"`
+	Model               string          `json:"model,omitempty"`
+	Messages            []Message       `json:"messages,omitempty"`
+	Prompt              any             `json:"prompt,omitempty"`
+	Stream              bool            `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
+	MaxTokens           uint            `json:"max_tokens,omitempty"`
+	MaxCompletionTokens uint            `json:"max_completion_tokens,omitempty"`
+	Temperature         float64         `json:"temperature,omitempty"`
+	TopP                float64         `json:"top_p,omitempty"`
+	TopK                int             `json:"top_k,omitempty"`
+	Stop                any             `json:"stop,omitempty"`
+	N                   int             `json:"n,omitempty"`
+	Input               any             `json:"input,omitempty"`
+	Instruction         string          `json:"instruction,omitempty"`
+	Size                string          `json:"size,omitempty"`
+	Functions           any             `json:"functions,omitempty"`
+	FrequencyPenalty    float64         `json:"frequency_penalty,omitempty"`
+	PresencePenalty     float64         `json:"presence_penalty,omitempty"`
+	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"` // typed (was any) so the Gemini adaptor can inspect json_schema
+	EncodingFormat      any             `json:"encoding_format,omitempty"`
+	Seed                float64         `json:"seed,omitempty"`
+	Tools               []ToolCall      `json:"tools,omitempty"`
+	ToolChoice          any             `json:"tool_choice,omitempty"`
+	User                string          `json:"user,omitempty"`
+	LogProbs            bool            `json:"logprobs,omitempty"`
+	TopLogProbs         int             `json:"top_logprobs,omitempty"`
+	Dimensions          int             `json:"dimensions,omitempty"`
+	Modalities          any             `json:"modalities,omitempty"`
+	Audio               any             `json:"audio,omitempty"`
 }
 
 type OpenAITools struct {

+ 8 - 6
relay/channel/gemini/dto.go

@@ -40,12 +40,14 @@ type GeminiChatTools struct {
 }
 
 type GeminiChatGenerationConfig struct {
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	TopK            float64  `json:"topK,omitempty"`
-	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
-	CandidateCount  int      `json:"candidateCount,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
+	Temperature      float64  `json:"temperature,omitempty"`
+	TopP             float64  `json:"topP,omitempty"`
+	TopK             float64  `json:"topK,omitempty"`
+	MaxOutputTokens  uint     `json:"maxOutputTokens,omitempty"`
+	CandidateCount   int      `json:"candidateCount,omitempty"`
+	StopSequences    []string `json:"stopSequences,omitempty"`
+	ResponseMimeType string   `json:"responseMimeType,omitempty"` // set to "application/json" for structured output
+	ResponseSchema   any      `json:"responseSchema,omitempty"`   // cleaned JSON schema (additionalProperties stripped)
 }
 
 type GeminiChatCandidate struct {

+ 50 - 0
relay/channel/gemini/relay-gemini.go

@@ -77,6 +77,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			},
 		}
 	}
+
+	// Map OpenAI structured output onto Gemini: both "json_object" and
+	// "json_schema" force a JSON response MIME type; when an explicit schema
+	// is supplied it is passed through the cleanup helper below, presumably
+	// because Gemini's schema format rejects "additionalProperties" — confirm.
+	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
+		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
+
+		if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
+			cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
+			geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
+		}
+	}
+
 	//shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
 
@@ -165,6 +175,46 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 	return &geminiRequest, nil
 }
 
+// removeAdditionalPropertiesWithDepth strips every "additionalProperties" key
+// from object schemas, recursing into "properties", "allOf"/"anyOf"/"oneOf",
+// and array "items". It mutates the input maps in place and returns them;
+// non-map values are returned unchanged.
+// NOTE(review): a schema with no top-level "type" (e.g. a bare anyOf root) or
+// with "type" given as an array (["object","null"]) is returned untouched, and
+// recursion stops at depth 5 — confirm both limits are intended.
+func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
+	// Depth guard: subtrees nested 5+ levels deep are passed through uncleaned.
+	if depth >= 5 {
+		return schema
+	}
+
+	v, ok := schema.(map[string]interface{})
+	if !ok || len(v) == 0 {
+		return schema
+	}
+
+	// If "type" is missing or is neither "object" nor "array", return as-is.
+	if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
+		return schema
+	}
+
+	switch v["type"] {
+	case "object":
+		delete(v, "additionalProperties")
+		// Clean each property's sub-schema.
+		if properties, ok := v["properties"].(map[string]interface{}); ok {
+			for key, value := range properties {
+				properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
+			}
+		}
+		// Composition keywords may nest further object/array schemas.
+		for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+			if nested, ok := v[field].([]interface{}); ok {
+				for i, item := range nested {
+					nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
+				}
+			}
+		}
+	case "array":
+		if items, ok := v["items"].(map[string]interface{}); ok {
+			v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
+		}
+	}
+
+	return v
+}
+
 func (g *GeminiChatResponse) GetResponseText() string {
 	if g == nil {
 		return ""