|
|
@@ -1,6 +1,7 @@
|
|
|
package gemini
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
@@ -8,6 +9,7 @@ import (
|
|
|
"net/http"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
"unicode/utf8"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
@@ -653,101 +655,84 @@ func getSupportedMimeTypesList() []string {
|
|
|
return keys
|
|
|
}
|
|
|
|
|
|
+var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
|
|
|
+ "anyOf": {},
|
|
|
+ "default": {},
|
|
|
+ "description": {},
|
|
|
+ "enum": {},
|
|
|
+ "example": {},
|
|
|
+ "format": {},
|
|
|
+ "items": {},
|
|
|
+ "maxItems": {},
|
|
|
+ "maxLength": {},
|
|
|
+ "maxProperties": {},
|
|
|
+ "maximum": {},
|
|
|
+ "minItems": {},
|
|
|
+ "minLength": {},
|
|
|
+ "minProperties": {},
|
|
|
+ "minimum": {},
|
|
|
+ "nullable": {},
|
|
|
+ "pattern": {},
|
|
|
+ "properties": {},
|
|
|
+ "propertyOrdering": {},
|
|
|
+ "required": {},
|
|
|
+ "title": {},
|
|
|
+ "type": {},
|
|
|
+}
|
|
|
+
|
|
|
+const geminiFunctionSchemaMaxDepth = 64
|
|
|
+
|
|
|
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
|
|
|
func cleanFunctionParameters(params interface{}) interface{} {
|
|
|
+ return cleanFunctionParametersWithDepth(params, 0)
|
|
|
+}
|
|
|
+
|
|
|
+func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
|
|
|
if params == nil {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+ if depth >= geminiFunctionSchemaMaxDepth {
|
|
|
+ return cleanFunctionParametersShallow(params)
|
|
|
+ }
|
|
|
+
|
|
|
switch v := params.(type) {
|
|
|
case map[string]interface{}:
|
|
|
- // Create a copy to avoid modifying the original
|
|
|
- cleanedMap := make(map[string]interface{})
|
|
|
+ // Keep only Gemini-supported OpenAPI schema subset fields (per official SDK Schema).
|
|
|
+ cleanedMap := make(map[string]interface{}, len(v))
|
|
|
for k, val := range v {
|
|
|
- cleanedMap[k] = val
|
|
|
- }
|
|
|
-
|
|
|
- // Remove unsupported root-level fields
|
|
|
- delete(cleanedMap, "default")
|
|
|
- delete(cleanedMap, "exclusiveMaximum")
|
|
|
- delete(cleanedMap, "exclusiveMinimum")
|
|
|
- delete(cleanedMap, "$schema")
|
|
|
- delete(cleanedMap, "additionalProperties")
|
|
|
-
|
|
|
- // Check and clean 'format' for string types
|
|
|
- if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
|
|
|
- if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
|
|
|
- if formatValue != "enum" && formatValue != "date-time" {
|
|
|
- delete(cleanedMap, "format")
|
|
|
- }
|
|
|
+ if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
|
|
|
+ cleanedMap[k] = val
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ normalizeGeminiSchemaTypeAndNullable(cleanedMap)
|
|
|
+
|
|
|
// Clean properties
|
|
|
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
|
|
|
cleanedProps := make(map[string]interface{})
|
|
|
for propName, propValue := range props {
|
|
|
- cleanedProps[propName] = cleanFunctionParameters(propValue)
|
|
|
+ cleanedProps[propName] = cleanFunctionParametersWithDepth(propValue, depth+1)
|
|
|
}
|
|
|
cleanedMap["properties"] = cleanedProps
|
|
|
}
|
|
|
|
|
|
// Recursively clean items in arrays
|
|
|
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
|
|
|
- cleanedMap["items"] = cleanFunctionParameters(items)
|
|
|
+ cleanedMap["items"] = cleanFunctionParametersWithDepth(items, depth+1)
|
|
|
}
|
|
|
- // Also handle items if it's an array of schemas
|
|
|
- if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
|
|
|
- cleanedItemsArray := make([]interface{}, len(itemsArray))
|
|
|
- for i, item := range itemsArray {
|
|
|
- cleanedItemsArray[i] = cleanFunctionParameters(item)
|
|
|
- }
|
|
|
- cleanedMap["items"] = cleanedItemsArray
|
|
|
- }
|
|
|
-
|
|
|
- // Recursively clean other schema composition keywords
|
|
|
- for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
|
|
- if nested, ok := cleanedMap[field].([]interface{}); ok {
|
|
|
- cleanedNested := make([]interface{}, len(nested))
|
|
|
- for i, item := range nested {
|
|
|
- cleanedNested[i] = cleanFunctionParameters(item)
|
|
|
- }
|
|
|
- cleanedMap[field] = cleanedNested
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- // Recursively clean patternProperties
|
|
|
- if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
|
|
|
- cleanedPatternProps := make(map[string]interface{})
|
|
|
- for pattern, schema := range patternProps {
|
|
|
- cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
|
|
|
- }
|
|
|
- cleanedMap["patternProperties"] = cleanedPatternProps
|
|
|
- }
|
|
|
-
|
|
|
- // Recursively clean definitions
|
|
|
- if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
|
|
|
- cleanedDefinitions := make(map[string]interface{})
|
|
|
- for defName, defSchema := range definitions {
|
|
|
- cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
|
|
|
- }
|
|
|
- cleanedMap["definitions"] = cleanedDefinitions
|
|
|
+ // OpenAPI tuple-style items is not supported by Gemini SDK Schema; keep first to avoid API rejection.
|
|
|
+ if itemsArray, ok := cleanedMap["items"].([]interface{}); ok && len(itemsArray) > 0 {
|
|
|
+ cleanedMap["items"] = cleanFunctionParametersWithDepth(itemsArray[0], depth+1)
|
|
|
}
|
|
|
|
|
|
- // Recursively clean $defs (newer JSON Schema draft)
|
|
|
- if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
|
|
|
- cleanedDefs := make(map[string]interface{})
|
|
|
- for defName, defSchema := range defs {
|
|
|
- cleanedDefs[defName] = cleanFunctionParameters(defSchema)
|
|
|
- }
|
|
|
- cleanedMap["$defs"] = cleanedDefs
|
|
|
- }
|
|
|
-
|
|
|
- // Clean conditional keywords
|
|
|
- for _, field := range []string{"if", "then", "else", "not"} {
|
|
|
- if nested, ok := cleanedMap[field]; ok {
|
|
|
- cleanedMap[field] = cleanFunctionParameters(nested)
|
|
|
+ // Recursively clean anyOf
|
|
|
+ if nested, ok := cleanedMap["anyOf"].([]interface{}); ok && nested != nil {
|
|
|
+ cleanedNested := make([]interface{}, len(nested))
|
|
|
+ for i, item := range nested {
|
|
|
+ cleanedNested[i] = cleanFunctionParametersWithDepth(item, depth+1)
|
|
|
}
|
|
|
+ cleanedMap["anyOf"] = cleanedNested
|
|
|
}
|
|
|
|
|
|
return cleanedMap
|
|
|
@@ -756,7 +741,7 @@ func cleanFunctionParameters(params interface{}) interface{} {
|
|
|
// Handle arrays of schemas
|
|
|
cleanedArray := make([]interface{}, len(v))
|
|
|
for i, item := range v {
|
|
|
- cleanedArray[i] = cleanFunctionParameters(item)
|
|
|
+ cleanedArray[i] = cleanFunctionParametersWithDepth(item, depth+1)
|
|
|
}
|
|
|
return cleanedArray
|
|
|
|
|
|
@@ -766,6 +751,91 @@ func cleanFunctionParameters(params interface{}) interface{} {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func cleanFunctionParametersShallow(params interface{}) interface{} {
|
|
|
+ switch v := params.(type) {
|
|
|
+ case map[string]interface{}:
|
|
|
+ cleanedMap := make(map[string]interface{}, len(v))
|
|
|
+ for k, val := range v {
|
|
|
+ if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
|
|
|
+ cleanedMap[k] = val
|
|
|
+ }
|
|
|
+ }
|
|
|
+ normalizeGeminiSchemaTypeAndNullable(cleanedMap)
|
|
|
+ // Stop recursion and avoid retaining huge nested structures.
|
|
|
+ delete(cleanedMap, "properties")
|
|
|
+ delete(cleanedMap, "items")
|
|
|
+ delete(cleanedMap, "anyOf")
|
|
|
+ return cleanedMap
|
|
|
+ case []interface{}:
|
|
|
+ // Prefer an empty list over deep recursion on attacker-controlled inputs.
|
|
|
+ return []interface{}{}
|
|
|
+ default:
|
|
|
+ return params
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
|
|
|
+ rawType, ok := schema["type"]
|
|
|
+ if !ok || rawType == nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ normalize := func(t string) (string, bool) {
|
|
|
+ switch strings.ToLower(strings.TrimSpace(t)) {
|
|
|
+ case "object":
|
|
|
+ return "OBJECT", false
|
|
|
+ case "array":
|
|
|
+ return "ARRAY", false
|
|
|
+ case "string":
|
|
|
+ return "STRING", false
|
|
|
+ case "integer":
|
|
|
+ return "INTEGER", false
|
|
|
+ case "number":
|
|
|
+ return "NUMBER", false
|
|
|
+ case "boolean":
|
|
|
+ return "BOOLEAN", false
|
|
|
+ case "null":
|
|
|
+ return "", true
|
|
|
+ default:
|
|
|
+ return t, false
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ switch t := rawType.(type) {
|
|
|
+ case string:
|
|
|
+ normalized, isNull := normalize(t)
|
|
|
+ if isNull {
|
|
|
+ schema["nullable"] = true
|
|
|
+ delete(schema, "type")
|
|
|
+ return
|
|
|
+ }
|
|
|
+ schema["type"] = normalized
|
|
|
+ case []interface{}:
|
|
|
+ nullable := false
|
|
|
+ var chosen string
|
|
|
+ for _, item := range t {
|
|
|
+ if s, ok := item.(string); ok {
|
|
|
+ normalized, isNull := normalize(s)
|
|
|
+ if isNull {
|
|
|
+ nullable = true
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if chosen == "" {
|
|
|
+ chosen = normalized
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if nullable {
|
|
|
+ schema["nullable"] = true
|
|
|
+ }
|
|
|
+ if chosen != "" {
|
|
|
+ schema["type"] = chosen
|
|
|
+ } else {
|
|
|
+ delete(schema, "type")
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
|
|
if depth >= 5 {
|
|
|
return schema
|
|
|
@@ -1138,6 +1208,8 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|
|
id := helper.GetResponseID(c)
|
|
|
createAt := common.GetTimestamp()
|
|
|
finishReason := constant.FinishReasonStop
|
|
|
+ toolCallIndexByChoice := make(map[int]map[string]int)
|
|
|
+ nextToolCallIndexByChoice := make(map[int]int)
|
|
|
|
|
|
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
|
|
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
|
|
|
@@ -1145,6 +1217,28 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
|
|
response.Id = id
|
|
|
response.Created = createAt
|
|
|
response.Model = info.UpstreamModelName
|
|
|
+ for choiceIdx := range response.Choices {
|
|
|
+ choiceKey := response.Choices[choiceIdx].Index
|
|
|
+ for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
|
|
|
+ tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
|
|
|
+ if tool.ID == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ m := toolCallIndexByChoice[choiceKey]
|
|
|
+ if m == nil {
|
|
|
+ m = make(map[string]int)
|
|
|
+ toolCallIndexByChoice[choiceKey] = m
|
|
|
+ }
|
|
|
+ if idx, ok := m[tool.ID]; ok {
|
|
|
+ tool.SetIndex(idx)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ idx := nextToolCallIndexByChoice[choiceKey]
|
|
|
+ nextToolCallIndexByChoice[choiceKey] = idx + 1
|
|
|
+ m[tool.ID] = idx
|
|
|
+ tool.SetIndex(idx)
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
|
|
if info.SendResponseCount == 0 {
|
|
|
@@ -1363,3 +1457,76 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
|
|
|
|
|
return usage, nil
|
|
|
}
|
|
|
+
|
|
|
+type GeminiModelsResponse struct {
|
|
|
+ Models []dto.GeminiModel `json:"models"`
|
|
|
+ NextPageToken string `json:"nextPageToken"`
|
|
|
+}
|
|
|
+
|
|
|
+func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
|
|
|
+ client, err := service.GetHttpClientWithProxy(proxyURL)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ allModels := make([]string, 0)
|
|
|
+ nextPageToken := ""
|
|
|
+ maxPages := 100 // Safety limit to prevent infinite loops
|
|
|
+
|
|
|
+ for page := 0; page < maxPages; page++ {
|
|
|
+ url := fmt.Sprintf("%s/v1beta/models", baseURL)
|
|
|
+ if nextPageToken != "" {
|
|
|
+ url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
|
+ request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
|
+ if err != nil {
|
|
|
+ cancel()
|
|
|
+ return nil, fmt.Errorf("创建请求失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ request.Header.Set("x-goog-api-key", apiKey)
|
|
|
+
|
|
|
+ response, err := client.Do(request)
|
|
|
+ if err != nil {
|
|
|
+ cancel()
|
|
|
+ return nil, fmt.Errorf("请求失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if response.StatusCode != http.StatusOK {
|
|
|
+ body, _ := io.ReadAll(response.Body)
|
|
|
+ response.Body.Close()
|
|
|
+ cancel()
|
|
|
+ return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
|
|
|
+ }
|
|
|
+
|
|
|
+ body, err := io.ReadAll(response.Body)
|
|
|
+ response.Body.Close()
|
|
|
+ cancel()
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("读取响应失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var modelsResponse GeminiModelsResponse
|
|
|
+ if err = common.Unmarshal(body, &modelsResponse); err != nil {
|
|
|
+ return nil, fmt.Errorf("解析响应失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, model := range modelsResponse.Models {
|
|
|
+ modelNameValue, ok := model.Name.(string)
|
|
|
+ if !ok {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ modelName := strings.TrimPrefix(modelNameValue, "models/")
|
|
|
+ allModels = append(allModels, modelName)
|
|
|
+ }
|
|
|
+
|
|
|
+ nextPageToken = modelsResponse.NextPageToken
|
|
|
+ if nextPageToken == "" {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return allModels, nil
|
|
|
+}
|