relay-gemini.go 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746
  1. package gemini
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "unicode/utf8"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/dto"
  16. "github.com/QuantumNous/new-api/logger"
  17. "github.com/QuantumNous/new-api/relay/channel/openai"
  18. relaycommon "github.com/QuantumNous/new-api/relay/common"
  19. "github.com/QuantumNous/new-api/relay/helper"
  20. "github.com/QuantumNous/new-api/service"
  21. "github.com/QuantumNous/new-api/setting/model_setting"
  22. "github.com/QuantumNous/new-api/setting/reasoning"
  23. "github.com/QuantumNous/new-api/types"
  24. "github.com/gin-gonic/gin"
  25. "github.com/samber/lo"
  26. )
  27. // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
  28. var geminiSupportedMimeTypes = map[string]bool{
  29. "application/pdf": true,
  30. "audio/mpeg": true,
  31. "audio/mp3": true,
  32. "audio/wav": true,
  33. "image/png": true,
  34. "image/jpeg": true,
  35. "image/jpg": true, // support old image/jpeg
  36. "image/webp": true,
  37. "text/plain": true,
  38. "video/mov": true,
  39. "video/mpeg": true,
  40. "video/mp4": true,
  41. "video/mpg": true,
  42. "video/avi": true,
  43. "video/wmv": true,
  44. "video/mpegps": true,
  45. "video/flv": true,
  46. }
  47. const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"
  48. // Gemini 允许的思考预算范围
  49. const (
  50. pro25MinBudget = 128
  51. pro25MaxBudget = 32768
  52. flash25MaxBudget = 24576
  53. flash25LiteMinBudget = 512
  54. flash25LiteMaxBudget = 24576
  55. )
  56. func isNew25ProModel(modelName string) bool {
  57. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  58. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  59. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  60. }
  61. func is25FlashLiteModel(modelName string) bool {
  62. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  63. }
  64. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  65. func clampThinkingBudget(modelName string, budget int) int {
  66. isNew25Pro := isNew25ProModel(modelName)
  67. is25FlashLite := is25FlashLiteModel(modelName)
  68. if is25FlashLite {
  69. if budget < flash25LiteMinBudget {
  70. return flash25LiteMinBudget
  71. }
  72. if budget > flash25LiteMaxBudget {
  73. return flash25LiteMaxBudget
  74. }
  75. } else if isNew25Pro {
  76. if budget < pro25MinBudget {
  77. return pro25MinBudget
  78. }
  79. if budget > pro25MaxBudget {
  80. return pro25MaxBudget
  81. }
  82. } else { // 其他模型
  83. if budget < 0 {
  84. return 0
  85. }
  86. if budget > flash25MaxBudget {
  87. return flash25MaxBudget
  88. }
  89. }
  90. return budget
  91. }
  92. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  93. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  94. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  95. // "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
  96. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  97. isNew25Pro := isNew25ProModel(modelName)
  98. is25FlashLite := is25FlashLiteModel(modelName)
  99. maxBudget := 0
  100. if is25FlashLite {
  101. maxBudget = flash25LiteMaxBudget
  102. }
  103. if isNew25Pro {
  104. maxBudget = pro25MaxBudget
  105. } else {
  106. maxBudget = flash25MaxBudget
  107. }
  108. switch effort {
  109. case "high":
  110. maxBudget = maxBudget * 80 / 100
  111. case "medium":
  112. maxBudget = maxBudget * 50 / 100
  113. case "low":
  114. maxBudget = maxBudget * 20 / 100
  115. case "minimal":
  116. maxBudget = maxBudget * 5 / 100
  117. }
  118. return clampThinkingBudget(modelName, maxBudget)
  119. }
  120. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  121. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  122. modelName := info.UpstreamModelName
  123. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  124. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  125. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  126. if strings.Contains(modelName, "-thinking-") {
  127. parts := strings.SplitN(modelName, "-thinking-", 2)
  128. if len(parts) == 2 && parts[1] != "" {
  129. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  130. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  131. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  132. ThinkingBudget: common.GetPointer(clampedBudget),
  133. IncludeThoughts: true,
  134. }
  135. }
  136. }
  137. } else if strings.HasSuffix(modelName, "-thinking") {
  138. unsupportedModels := []string{
  139. "gemini-2.5-pro-preview-05-06",
  140. "gemini-2.5-pro-preview-03-25",
  141. }
  142. isUnsupported := false
  143. for _, unsupportedModel := range unsupportedModels {
  144. if strings.HasPrefix(modelName, unsupportedModel) {
  145. isUnsupported = true
  146. break
  147. }
  148. }
  149. if isUnsupported {
  150. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  151. IncludeThoughts: true,
  152. }
  153. } else {
  154. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  155. IncludeThoughts: true,
  156. }
  157. if geminiRequest.GenerationConfig.MaxOutputTokens != nil && *geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  158. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(*geminiRequest.GenerationConfig.MaxOutputTokens)
  159. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  160. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  161. } else {
  162. if len(oaiRequest) > 0 {
  163. // 如果有reasoningEffort参数,则根据其值设置思考预算
  164. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  165. }
  166. }
  167. }
  168. } else if strings.HasSuffix(modelName, "-nothinking") {
  169. if !isNew25Pro {
  170. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  171. ThinkingBudget: common.GetPointer(0),
  172. }
  173. }
  174. } else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
  175. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  176. IncludeThoughts: true,
  177. ThinkingLevel: level,
  178. }
  179. info.ReasoningEffort = level
  180. }
  181. }
  182. }
  183. // Setting safety to the lowest possible values since Gemini is already powerless enough
  184. func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
  185. geminiRequest := dto.GeminiChatRequest{
  186. Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
  187. GenerationConfig: dto.GeminiChatGenerationConfig{
  188. Temperature: textRequest.Temperature,
  189. },
  190. }
  191. if textRequest.TopP != nil && *textRequest.TopP > 0 {
  192. geminiRequest.GenerationConfig.TopP = common.GetPointer(*textRequest.TopP)
  193. }
  194. if maxTokens := textRequest.GetMaxTokens(); maxTokens > 0 {
  195. geminiRequest.GenerationConfig.MaxOutputTokens = common.GetPointer(maxTokens)
  196. }
  197. if textRequest.Seed != nil && *textRequest.Seed != 0 {
  198. geminiSeed := int64(lo.FromPtr(textRequest.Seed))
  199. geminiRequest.GenerationConfig.Seed = common.GetPointer(geminiSeed)
  200. }
  201. attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
  202. info.ChannelType == constant.ChannelTypeVertexAi) &&
  203. model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled
  204. if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
  205. geminiRequest.GenerationConfig.ResponseModalities = []string{
  206. "TEXT",
  207. "IMAGE",
  208. }
  209. }
  210. if stopSequences := parseStopSequences(textRequest.Stop); len(stopSequences) > 0 {
  211. // Gemini supports up to 5 stop sequences
  212. if len(stopSequences) > 5 {
  213. stopSequences = stopSequences[:5]
  214. }
  215. geminiRequest.GenerationConfig.StopSequences = stopSequences
  216. }
  217. adaptorWithExtraBody := false
  218. // patch extra_body
  219. if len(textRequest.ExtraBody) > 0 {
  220. var extraBody map[string]interface{}
  221. if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
  222. return nil, fmt.Errorf("invalid extra body: %w", err)
  223. }
  224. // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
  225. if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
  226. if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  227. adaptorWithExtraBody = true
  228. // check error param name like thinkingConfig, should be thinking_config
  229. if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
  230. return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
  231. }
  232. if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
  233. // check error param name like thinkingBudget, should be thinking_budget
  234. if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
  235. return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
  236. }
  237. var hasThinkingConfig bool
  238. var tempThinkingConfig dto.GeminiThinkingConfig
  239. if thinkingBudget, exists := thinkingConfig["thinking_budget"]; exists {
  240. switch v := thinkingBudget.(type) {
  241. case float64:
  242. budgetInt := int(v)
  243. tempThinkingConfig.ThinkingBudget = common.GetPointer(budgetInt)
  244. if budgetInt > 0 {
  245. // 有正数预算
  246. tempThinkingConfig.IncludeThoughts = true
  247. } else {
  248. // 存在但为0或负数,禁用思考
  249. tempThinkingConfig.IncludeThoughts = false
  250. }
  251. hasThinkingConfig = true
  252. default:
  253. return nil, errors.New("extra_body.google.thinking_config.thinking_budget must be an integer")
  254. }
  255. }
  256. if includeThoughts, exists := thinkingConfig["include_thoughts"]; exists {
  257. if v, ok := includeThoughts.(bool); ok {
  258. tempThinkingConfig.IncludeThoughts = v
  259. hasThinkingConfig = true
  260. } else {
  261. return nil, errors.New("extra_body.google.thinking_config.include_thoughts must be a boolean")
  262. }
  263. }
  264. if thinkingLevel, exists := thinkingConfig["thinking_level"]; exists {
  265. if v, ok := thinkingLevel.(string); ok {
  266. tempThinkingConfig.ThinkingLevel = v
  267. hasThinkingConfig = true
  268. } else {
  269. return nil, errors.New("extra_body.google.thinking_config.thinking_level must be a string")
  270. }
  271. }
  272. if hasThinkingConfig {
  273. // 避免 panic: 仅在获得配置时分配,防止后续赋值时空指针
  274. if geminiRequest.GenerationConfig.ThinkingConfig == nil {
  275. geminiRequest.GenerationConfig.ThinkingConfig = &tempThinkingConfig
  276. } else {
  277. // 如果已分配,则合并内容
  278. if tempThinkingConfig.ThinkingBudget != nil {
  279. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = tempThinkingConfig.ThinkingBudget
  280. }
  281. geminiRequest.GenerationConfig.ThinkingConfig.IncludeThoughts = tempThinkingConfig.IncludeThoughts
  282. if tempThinkingConfig.ThinkingLevel != "" {
  283. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingLevel = tempThinkingConfig.ThinkingLevel
  284. }
  285. }
  286. }
  287. }
  288. }
  289. // check error param name like imageConfig, should be image_config
  290. if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
  291. return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
  292. }
  293. if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
  294. // check error param name like aspectRatio, should be aspect_ratio
  295. if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
  296. return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
  297. }
  298. // check error param name like imageSize, should be image_size
  299. if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
  300. return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
  301. }
  302. // convert snake_case to camelCase for Gemini API
  303. geminiImageConfig := make(map[string]interface{})
  304. if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
  305. geminiImageConfig["aspectRatio"] = aspectRatio
  306. }
  307. if imageSize, ok := imageConfig["image_size"]; ok {
  308. geminiImageConfig["imageSize"] = imageSize
  309. }
  310. if len(geminiImageConfig) > 0 {
  311. imageConfigBytes, err := common.Marshal(geminiImageConfig)
  312. if err != nil {
  313. return nil, fmt.Errorf("failed to marshal image_config: %w", err)
  314. }
  315. geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
  316. }
  317. }
  318. }
  319. }
  320. if !adaptorWithExtraBody {
  321. ThinkingAdaptor(&geminiRequest, info, textRequest)
  322. }
  323. safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
  324. for _, category := range SafetySettingList {
  325. safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
  326. Category: category,
  327. Threshold: model_setting.GetGeminiSafetySetting(category),
  328. })
  329. }
  330. geminiRequest.SafetySettings = safetySettings
  331. // openaiContent.FuncToToolCalls()
  332. if textRequest.Tools != nil {
  333. functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
  334. googleSearch := false
  335. codeExecution := false
  336. urlContext := false
  337. for _, tool := range textRequest.Tools {
  338. if tool.Function.Name == "googleSearch" {
  339. googleSearch = true
  340. continue
  341. }
  342. if tool.Function.Name == "codeExecution" {
  343. codeExecution = true
  344. continue
  345. }
  346. if tool.Function.Name == "urlContext" {
  347. urlContext = true
  348. continue
  349. }
  350. if tool.Function.Parameters != nil {
  351. params, ok := tool.Function.Parameters.(map[string]interface{})
  352. if ok {
  353. if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
  354. if len(props) == 0 {
  355. tool.Function.Parameters = nil
  356. }
  357. }
  358. }
  359. }
  360. // Clean the parameters before appending
  361. cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
  362. tool.Function.Parameters = cleanedParams
  363. functions = append(functions, tool.Function)
  364. }
  365. geminiTools := geminiRequest.GetTools()
  366. if codeExecution {
  367. geminiTools = append(geminiTools, dto.GeminiChatTool{
  368. CodeExecution: make(map[string]string),
  369. })
  370. }
  371. if googleSearch {
  372. geminiTools = append(geminiTools, dto.GeminiChatTool{
  373. GoogleSearch: make(map[string]string),
  374. })
  375. }
  376. if urlContext {
  377. geminiTools = append(geminiTools, dto.GeminiChatTool{
  378. URLContext: make(map[string]string),
  379. })
  380. }
  381. if len(functions) > 0 {
  382. geminiTools = append(geminiTools, dto.GeminiChatTool{
  383. FunctionDeclarations: functions,
  384. })
  385. }
  386. geminiRequest.SetTools(geminiTools)
  387. // [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
  388. // Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
  389. // Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
  390. if textRequest.ToolChoice != nil {
  391. geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
  392. }
  393. }
  394. if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
  395. geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
  396. if len(textRequest.ResponseFormat.JsonSchema) > 0 {
  397. // 先将json.RawMessage解析
  398. var jsonSchema dto.FormatJsonSchema
  399. if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
  400. cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
  401. geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
  402. }
  403. }
  404. }
  405. tool_call_ids := make(map[string]string)
  406. var system_content []string
  407. //shouldAddDummyModelMessage := false
  408. for _, message := range textRequest.Messages {
  409. if message.Role == "system" || message.Role == "developer" {
  410. system_content = append(system_content, message.StringContent())
  411. continue
  412. } else if message.Role == "tool" || message.Role == "function" {
  413. if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
  414. geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
  415. Role: "user",
  416. })
  417. }
  418. var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
  419. name := ""
  420. if message.Name != nil {
  421. name = *message.Name
  422. } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
  423. name = val
  424. }
  425. var contentMap map[string]interface{}
  426. contentStr := message.StringContent()
  427. // 1. 尝试解析为 JSON 对象
  428. if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
  429. // 2. 如果失败,尝试解析为 JSON 数组
  430. var contentSlice []interface{}
  431. if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
  432. // 如果是数组,包装成对象
  433. contentMap = map[string]interface{}{"result": contentSlice}
  434. } else {
  435. // 3. 如果再次失败,作为纯文本处理
  436. contentMap = map[string]interface{}{"content": contentStr}
  437. }
  438. }
  439. functionResp := &dto.GeminiFunctionResponse{
  440. Name: name,
  441. Response: contentMap,
  442. }
  443. *parts = append(*parts, dto.GeminiPart{
  444. FunctionResponse: functionResp,
  445. })
  446. continue
  447. }
  448. var parts []dto.GeminiPart
  449. content := dto.GeminiChatContent{
  450. Role: message.Role,
  451. }
  452. shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
  453. signatureAttached := false
  454. // isToolCall := false
  455. if message.ToolCalls != nil {
  456. // message.Role = "model"
  457. // isToolCall = true
  458. for _, call := range message.ParseToolCalls() {
  459. args := map[string]interface{}{}
  460. if call.Function.Arguments != "" {
  461. if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
  462. return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
  463. }
  464. }
  465. toolCall := dto.GeminiPart{
  466. FunctionCall: &dto.FunctionCall{
  467. FunctionName: call.Function.Name,
  468. Arguments: args,
  469. },
  470. }
  471. if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
  472. toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  473. signatureAttached = true
  474. }
  475. parts = append(parts, toolCall)
  476. tool_call_ids[call.ID] = call.Function.Name
  477. }
  478. }
  479. openaiContent := message.ParseContent()
  480. for _, part := range openaiContent {
  481. if part.Type == dto.ContentTypeText {
  482. if part.Text == "" {
  483. continue
  484. }
  485. // check markdown image ![image](data:image/jpeg;base64,xxxxxxxxxxxx)
  486. // 使用字符串查找而非正则,避免大文本性能问题
  487. text := part.Text
  488. hasMarkdownImage := false
  489. for {
  490. // 快速检查是否包含 markdown 图片标记
  491. startIdx := strings.Index(text, "![")
  492. if startIdx == -1 {
  493. break
  494. }
  495. // 找到 ](
  496. bracketIdx := strings.Index(text[startIdx:], "](data:")
  497. if bracketIdx == -1 {
  498. break
  499. }
  500. bracketIdx += startIdx
  501. // 找到闭合的 )
  502. closeIdx := strings.Index(text[bracketIdx+2:], ")")
  503. if closeIdx == -1 {
  504. break
  505. }
  506. closeIdx += bracketIdx + 2
  507. hasMarkdownImage = true
  508. // 添加图片前的文本
  509. if startIdx > 0 {
  510. textBefore := text[:startIdx]
  511. if textBefore != "" {
  512. parts = append(parts, dto.GeminiPart{
  513. Text: textBefore,
  514. })
  515. }
  516. }
  517. // 提取 data URL (从 "](" 后面开始,到 ")" 之前)
  518. dataUrl := text[bracketIdx+2 : closeIdx]
  519. format, base64String, err := service.DecodeBase64FileData(dataUrl)
  520. if err != nil {
  521. return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
  522. }
  523. imgPart := dto.GeminiPart{
  524. InlineData: &dto.GeminiInlineData{
  525. MimeType: format,
  526. Data: base64String,
  527. },
  528. }
  529. if shouldAttachThoughtSignature {
  530. imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  531. }
  532. parts = append(parts, imgPart)
  533. // 继续处理剩余文本
  534. text = text[closeIdx+1:]
  535. }
  536. // 添加剩余文本或原始文本(如果没有找到 markdown 图片)
  537. if !hasMarkdownImage {
  538. parts = append(parts, dto.GeminiPart{
  539. Text: part.Text,
  540. })
  541. }
  542. } else if part.Type == dto.ContentTypeImageURL {
  543. // 使用统一的文件服务获取图片数据
  544. var source *types.FileSource
  545. imageUrl := part.GetImageMedia().Url
  546. if strings.HasPrefix(imageUrl, "http") {
  547. source = types.NewURLFileSource(imageUrl)
  548. } else {
  549. source = types.NewBase64FileSource(imageUrl, "")
  550. }
  551. base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
  552. if err != nil {
  553. return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err)
  554. }
  555. // 校验 MimeType 是否在 Gemini 支持的白名单中
  556. if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok {
  557. return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
  558. }
  559. parts = append(parts, dto.GeminiPart{
  560. InlineData: &dto.GeminiInlineData{
  561. MimeType: mimeType,
  562. Data: base64Data,
  563. },
  564. })
  565. } else if part.Type == dto.ContentTypeFile {
  566. if part.GetFile().FileId != "" {
  567. return nil, fmt.Errorf("only base64 file is supported in gemini")
  568. }
  569. fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
  570. base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
  571. if err != nil {
  572. return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
  573. }
  574. parts = append(parts, dto.GeminiPart{
  575. InlineData: &dto.GeminiInlineData{
  576. MimeType: mimeType,
  577. Data: base64Data,
  578. },
  579. })
  580. } else if part.Type == dto.ContentTypeInputAudio {
  581. if part.GetInputAudio().Data == "" {
  582. return nil, fmt.Errorf("only base64 audio is supported in gemini")
  583. }
  584. audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
  585. base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
  586. if err != nil {
  587. return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
  588. }
  589. parts = append(parts, dto.GeminiPart{
  590. InlineData: &dto.GeminiInlineData{
  591. MimeType: mimeType,
  592. Data: base64Data,
  593. },
  594. })
  595. }
  596. }
  597. // 如果需要附加签名但还没有附加(没有 tool_calls 或 tool_calls 为空),
  598. // 则在第一个文本 part 上附加 thoughtSignature
  599. if shouldAttachThoughtSignature && !signatureAttached && len(parts) > 0 {
  600. for i := range parts {
  601. if parts[i].Text != "" {
  602. parts[i].ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  603. break
  604. }
  605. }
  606. }
  607. content.Parts = parts
  608. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  609. if content.Role == "assistant" {
  610. content.Role = "model"
  611. }
  612. if len(content.Parts) > 0 {
  613. geminiRequest.Contents = append(geminiRequest.Contents, content)
  614. }
  615. }
  616. if len(system_content) > 0 {
  617. geminiRequest.SystemInstructions = &dto.GeminiChatContent{
  618. Parts: []dto.GeminiPart{
  619. {
  620. Text: strings.Join(system_content, "\n"),
  621. },
  622. },
  623. }
  624. }
  625. return &geminiRequest, nil
  626. }
  627. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  628. func parseStopSequences(stop any) []string {
  629. if stop == nil {
  630. return nil
  631. }
  632. switch v := stop.(type) {
  633. case string:
  634. if v != "" {
  635. return []string{v}
  636. }
  637. case []string:
  638. return v
  639. case []interface{}:
  640. sequences := make([]string, 0, len(v))
  641. for _, item := range v {
  642. if str, ok := item.(string); ok && str != "" {
  643. sequences = append(sequences, str)
  644. }
  645. }
  646. return sequences
  647. }
  648. return nil
  649. }
  650. func hasFunctionCallContent(call *dto.FunctionCall) bool {
  651. if call == nil {
  652. return false
  653. }
  654. if strings.TrimSpace(call.FunctionName) != "" {
  655. return true
  656. }
  657. switch v := call.Arguments.(type) {
  658. case nil:
  659. return false
  660. case string:
  661. return strings.TrimSpace(v) != ""
  662. case map[string]interface{}:
  663. return len(v) > 0
  664. case []interface{}:
  665. return len(v) > 0
  666. default:
  667. return true
  668. }
  669. }
  670. // Helper function to get a list of supported MIME types for error messages
  671. func getSupportedMimeTypesList() []string {
  672. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  673. for k := range geminiSupportedMimeTypes {
  674. keys = append(keys, k)
  675. }
  676. return keys
  677. }
// geminiOpenAPISchemaAllowedFields is the allow-list of JSON-schema keys kept when
// function parameters are cleaned for Gemini (cleanFunctionParametersWithDepth and
// cleanFunctionParametersShallow). Any key not listed here is dropped from each
// schema map during cleaning.
var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
	"anyOf":            {},
	"default":          {},
	"description":      {},
	"enum":             {},
	"example":          {},
	"format":           {},
	"items":            {},
	"maxItems":         {},
	"maxLength":        {},
	"maxProperties":    {},
	"maximum":          {},
	"minItems":         {},
	"minLength":        {},
	"minProperties":    {},
	"minimum":          {},
	"nullable":         {},
	"pattern":          {},
	"properties":       {},
	"propertyOrdering": {},
	"required":         {},
	"title":            {},
	"type":             {},
}

// geminiFunctionSchemaMaxDepth bounds the recursion in cleanFunctionParametersWithDepth.
// Once this depth is reached, the remaining value is cleaned by
// cleanFunctionParametersShallow, which drops its nested schemas instead of recursing.
const geminiFunctionSchemaMaxDepth = 64
  703. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  704. func cleanFunctionParameters(params interface{}) interface{} {
  705. return cleanFunctionParametersWithDepth(params, 0)
  706. }
  707. func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
  708. if params == nil {
  709. return nil
  710. }
  711. if depth >= geminiFunctionSchemaMaxDepth {
  712. return cleanFunctionParametersShallow(params)
  713. }
  714. switch v := params.(type) {
  715. case map[string]interface{}:
  716. // Keep only Gemini-supported OpenAPI schema subset fields (per official SDK Schema).
  717. cleanedMap := make(map[string]interface{}, len(v))
  718. for k, val := range v {
  719. if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
  720. cleanedMap[k] = val
  721. }
  722. }
  723. normalizeGeminiSchemaTypeAndNullable(cleanedMap)
  724. // Clean properties
  725. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  726. cleanedProps := make(map[string]interface{})
  727. for propName, propValue := range props {
  728. cleanedProps[propName] = cleanFunctionParametersWithDepth(propValue, depth+1)
  729. }
  730. cleanedMap["properties"] = cleanedProps
  731. }
  732. // Recursively clean items in arrays
  733. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  734. cleanedMap["items"] = cleanFunctionParametersWithDepth(items, depth+1)
  735. }
  736. // OpenAPI tuple-style items is not supported by Gemini SDK Schema; keep first to avoid API rejection.
  737. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok && len(itemsArray) > 0 {
  738. cleanedMap["items"] = cleanFunctionParametersWithDepth(itemsArray[0], depth+1)
  739. }
  740. // Recursively clean anyOf
  741. if nested, ok := cleanedMap["anyOf"].([]interface{}); ok && nested != nil {
  742. cleanedNested := make([]interface{}, len(nested))
  743. for i, item := range nested {
  744. cleanedNested[i] = cleanFunctionParametersWithDepth(item, depth+1)
  745. }
  746. cleanedMap["anyOf"] = cleanedNested
  747. }
  748. return cleanedMap
  749. case []interface{}:
  750. // Handle arrays of schemas
  751. cleanedArray := make([]interface{}, len(v))
  752. for i, item := range v {
  753. cleanedArray[i] = cleanFunctionParametersWithDepth(item, depth+1)
  754. }
  755. return cleanedArray
  756. default:
  757. // Not a map or array, return as is (e.g., could be a primitive)
  758. return params
  759. }
  760. }
  761. func cleanFunctionParametersShallow(params interface{}) interface{} {
  762. switch v := params.(type) {
  763. case map[string]interface{}:
  764. cleanedMap := make(map[string]interface{}, len(v))
  765. for k, val := range v {
  766. if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
  767. cleanedMap[k] = val
  768. }
  769. }
  770. normalizeGeminiSchemaTypeAndNullable(cleanedMap)
  771. // Stop recursion and avoid retaining huge nested structures.
  772. delete(cleanedMap, "properties")
  773. delete(cleanedMap, "items")
  774. delete(cleanedMap, "anyOf")
  775. return cleanedMap
  776. case []interface{}:
  777. // Prefer an empty list over deep recursion on attacker-controlled inputs.
  778. return []interface{}{}
  779. default:
  780. return params
  781. }
  782. }
  783. func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
  784. rawType, ok := schema["type"]
  785. if !ok || rawType == nil {
  786. return
  787. }
  788. normalize := func(t string) (string, bool) {
  789. switch strings.ToLower(strings.TrimSpace(t)) {
  790. case "object":
  791. return "OBJECT", false
  792. case "array":
  793. return "ARRAY", false
  794. case "string":
  795. return "STRING", false
  796. case "integer":
  797. return "INTEGER", false
  798. case "number":
  799. return "NUMBER", false
  800. case "boolean":
  801. return "BOOLEAN", false
  802. case "null":
  803. return "", true
  804. default:
  805. return t, false
  806. }
  807. }
  808. switch t := rawType.(type) {
  809. case string:
  810. normalized, isNull := normalize(t)
  811. if isNull {
  812. schema["nullable"] = true
  813. delete(schema, "type")
  814. return
  815. }
  816. schema["type"] = normalized
  817. case []interface{}:
  818. nullable := false
  819. var chosen string
  820. for _, item := range t {
  821. if s, ok := item.(string); ok {
  822. normalized, isNull := normalize(s)
  823. if isNull {
  824. nullable = true
  825. continue
  826. }
  827. if chosen == "" {
  828. chosen = normalized
  829. }
  830. }
  831. }
  832. if nullable {
  833. schema["nullable"] = true
  834. }
  835. if chosen != "" {
  836. schema["type"] = chosen
  837. } else {
  838. delete(schema, "type")
  839. }
  840. }
  841. }
  842. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  843. if depth >= 5 {
  844. return schema
  845. }
  846. v, ok := schema.(map[string]interface{})
  847. if !ok || len(v) == 0 {
  848. return schema
  849. }
  850. // 删除所有的title字段
  851. delete(v, "title")
  852. delete(v, "$schema")
  853. // 如果type不为object和array,则直接返回
  854. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  855. return schema
  856. }
  857. switch v["type"] {
  858. case "object":
  859. delete(v, "additionalProperties")
  860. // 处理 properties
  861. if properties, ok := v["properties"].(map[string]interface{}); ok {
  862. for key, value := range properties {
  863. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  864. }
  865. }
  866. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  867. if nested, ok := v[field].([]interface{}); ok {
  868. for i, item := range nested {
  869. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  870. }
  871. }
  872. }
  873. case "array":
  874. if items, ok := v["items"].(map[string]interface{}); ok {
  875. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  876. }
  877. }
  878. return v
  879. }
  880. func unescapeString(s string) (string, error) {
  881. var result []rune
  882. escaped := false
  883. i := 0
  884. for i < len(s) {
  885. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  886. if r == utf8.RuneError {
  887. return "", fmt.Errorf("invalid UTF-8 encoding")
  888. }
  889. if escaped {
  890. // 如果是转义符后的字符,检查其类型
  891. switch r {
  892. case '"':
  893. result = append(result, '"')
  894. case '\\':
  895. result = append(result, '\\')
  896. case '/':
  897. result = append(result, '/')
  898. case 'b':
  899. result = append(result, '\b')
  900. case 'f':
  901. result = append(result, '\f')
  902. case 'n':
  903. result = append(result, '\n')
  904. case 'r':
  905. result = append(result, '\r')
  906. case 't':
  907. result = append(result, '\t')
  908. case '\'':
  909. result = append(result, '\'')
  910. default:
  911. // 如果遇到一个非法的转义字符,直接按原样输出
  912. result = append(result, '\\', r)
  913. }
  914. escaped = false
  915. } else {
  916. if r == '\\' {
  917. escaped = true // 记录反斜杠作为转义符
  918. } else {
  919. result = append(result, r)
  920. }
  921. }
  922. i += size // 移动到下一个字符
  923. }
  924. return string(result), nil
  925. }
  926. func unescapeMapOrSlice(data interface{}) interface{} {
  927. switch v := data.(type) {
  928. case map[string]interface{}:
  929. for k, val := range v {
  930. v[k] = unescapeMapOrSlice(val)
  931. }
  932. case []interface{}:
  933. for i, val := range v {
  934. v[i] = unescapeMapOrSlice(val)
  935. }
  936. case string:
  937. if unescaped, err := unescapeString(v); err != nil {
  938. return v
  939. } else {
  940. return unescaped
  941. }
  942. }
  943. return data
  944. }
  945. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  946. var argsBytes []byte
  947. var err error
  948. // 移除 unescapeMapOrSlice 调用,直接使用 json.Marshal
  949. // JSON 序列化/反序列化已经正确处理了转义字符
  950. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  951. if err != nil {
  952. return nil
  953. }
  954. return &dto.ToolCallResponse{
  955. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  956. Type: "function",
  957. Function: dto.FunctionResponse{
  958. Arguments: string(argsBytes),
  959. Name: item.FunctionCall.FunctionName,
  960. },
  961. }
  962. }
  963. func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
  964. promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
  965. if promptTokens <= 0 && fallbackPromptTokens > 0 {
  966. promptTokens = fallbackPromptTokens
  967. }
  968. usage := dto.Usage{
  969. PromptTokens: promptTokens,
  970. CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
  971. TotalTokens: metadata.TotalTokenCount,
  972. }
  973. usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
  974. usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
  975. for _, detail := range metadata.PromptTokensDetails {
  976. if detail.Modality == "AUDIO" {
  977. usage.PromptTokensDetails.AudioTokens += detail.TokenCount
  978. } else if detail.Modality == "TEXT" {
  979. usage.PromptTokensDetails.TextTokens += detail.TokenCount
  980. }
  981. }
  982. for _, detail := range metadata.ToolUsePromptTokensDetails {
  983. if detail.Modality == "AUDIO" {
  984. usage.PromptTokensDetails.AudioTokens += detail.TokenCount
  985. } else if detail.Modality == "TEXT" {
  986. usage.PromptTokensDetails.TextTokens += detail.TokenCount
  987. }
  988. }
  989. if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
  990. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  991. }
  992. if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
  993. usage.PromptTokensDetails.TextTokens = usage.PromptTokens
  994. }
  995. return usage
  996. }
  997. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  998. fullTextResponse := dto.OpenAITextResponse{
  999. Id: helper.GetResponseID(c),
  1000. Object: "chat.completion",
  1001. Created: common.GetTimestamp(),
  1002. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  1003. }
  1004. isToolCall := false
  1005. for _, candidate := range response.Candidates {
  1006. choice := dto.OpenAITextResponseChoice{
  1007. Index: int(candidate.Index),
  1008. Message: dto.Message{
  1009. Role: "assistant",
  1010. Content: "",
  1011. },
  1012. FinishReason: constant.FinishReasonStop,
  1013. }
  1014. if len(candidate.Content.Parts) > 0 {
  1015. var texts []string
  1016. var toolCalls []dto.ToolCallResponse
  1017. for _, part := range candidate.Content.Parts {
  1018. if part.InlineData != nil {
  1019. // 媒体内容
  1020. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  1021. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  1022. texts = append(texts, imgText)
  1023. } else {
  1024. // 其他媒体类型,直接显示链接
  1025. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  1026. }
  1027. } else if part.FunctionCall != nil {
  1028. choice.FinishReason = constant.FinishReasonToolCalls
  1029. if call := getResponseToolCall(&part); call != nil {
  1030. toolCalls = append(toolCalls, *call)
  1031. }
  1032. } else if part.Thought {
  1033. choice.Message.ReasoningContent = part.Text
  1034. } else {
  1035. if part.ExecutableCode != nil {
  1036. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  1037. } else if part.CodeExecutionResult != nil {
  1038. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  1039. } else {
  1040. // 过滤掉空行
  1041. if part.Text != "\n" {
  1042. texts = append(texts, part.Text)
  1043. }
  1044. }
  1045. }
  1046. }
  1047. if len(toolCalls) > 0 {
  1048. choice.Message.SetToolCalls(toolCalls)
  1049. isToolCall = true
  1050. }
  1051. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  1052. }
  1053. if candidate.FinishReason != nil {
  1054. switch *candidate.FinishReason {
  1055. case "STOP":
  1056. choice.FinishReason = constant.FinishReasonStop
  1057. case "MAX_TOKENS":
  1058. choice.FinishReason = constant.FinishReasonLength
  1059. case "SAFETY":
  1060. // Safety filter triggered
  1061. choice.FinishReason = constant.FinishReasonContentFilter
  1062. case "RECITATION":
  1063. // Recitation (citation) detected
  1064. choice.FinishReason = constant.FinishReasonContentFilter
  1065. case "BLOCKLIST":
  1066. // Blocklist triggered
  1067. choice.FinishReason = constant.FinishReasonContentFilter
  1068. case "PROHIBITED_CONTENT":
  1069. // Prohibited content detected
  1070. choice.FinishReason = constant.FinishReasonContentFilter
  1071. case "SPII":
  1072. // Sensitive personally identifiable information
  1073. choice.FinishReason = constant.FinishReasonContentFilter
  1074. case "OTHER":
  1075. // Other reasons
  1076. choice.FinishReason = constant.FinishReasonContentFilter
  1077. default:
  1078. choice.FinishReason = constant.FinishReasonContentFilter
  1079. }
  1080. }
  1081. if isToolCall {
  1082. choice.FinishReason = constant.FinishReasonToolCalls
  1083. }
  1084. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  1085. }
  1086. return &fullTextResponse
  1087. }
  1088. func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
  1089. choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
  1090. isStop := false
  1091. for _, candidate := range geminiResponse.Candidates {
  1092. if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
  1093. isStop = true
  1094. candidate.FinishReason = nil
  1095. }
  1096. choice := dto.ChatCompletionsStreamResponseChoice{
  1097. Index: int(candidate.Index),
  1098. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  1099. //Role: "assistant",
  1100. },
  1101. }
  1102. var texts []string
  1103. isTools := false
  1104. isThought := false
  1105. if candidate.FinishReason != nil {
  1106. // Map Gemini FinishReason to OpenAI finish_reason
  1107. switch *candidate.FinishReason {
  1108. case "STOP":
  1109. // Normal completion
  1110. choice.FinishReason = &constant.FinishReasonStop
  1111. case "MAX_TOKENS":
  1112. // Reached maximum token limit
  1113. choice.FinishReason = &constant.FinishReasonLength
  1114. case "SAFETY":
  1115. // Safety filter triggered
  1116. choice.FinishReason = &constant.FinishReasonContentFilter
  1117. case "RECITATION":
  1118. // Recitation (citation) detected
  1119. choice.FinishReason = &constant.FinishReasonContentFilter
  1120. case "BLOCKLIST":
  1121. // Blocklist triggered
  1122. choice.FinishReason = &constant.FinishReasonContentFilter
  1123. case "PROHIBITED_CONTENT":
  1124. // Prohibited content detected
  1125. choice.FinishReason = &constant.FinishReasonContentFilter
  1126. case "SPII":
  1127. // Sensitive personally identifiable information
  1128. choice.FinishReason = &constant.FinishReasonContentFilter
  1129. case "OTHER":
  1130. // Other reasons
  1131. choice.FinishReason = &constant.FinishReasonContentFilter
  1132. default:
  1133. // Unknown reason, treat as content filter
  1134. choice.FinishReason = &constant.FinishReasonContentFilter
  1135. }
  1136. }
  1137. for _, part := range candidate.Content.Parts {
  1138. if part.InlineData != nil {
  1139. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  1140. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  1141. texts = append(texts, imgText)
  1142. }
  1143. } else if part.FunctionCall != nil {
  1144. isTools = true
  1145. if call := getResponseToolCall(&part); call != nil {
  1146. call.SetIndex(len(choice.Delta.ToolCalls))
  1147. choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
  1148. }
  1149. } else if part.Thought {
  1150. isThought = true
  1151. texts = append(texts, part.Text)
  1152. } else {
  1153. if part.ExecutableCode != nil {
  1154. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
  1155. } else if part.CodeExecutionResult != nil {
  1156. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
  1157. } else {
  1158. if part.Text != "\n" {
  1159. texts = append(texts, part.Text)
  1160. }
  1161. }
  1162. }
  1163. }
  1164. if isThought {
  1165. choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
  1166. } else {
  1167. choice.Delta.SetContentString(strings.Join(texts, "\n"))
  1168. }
  1169. if isTools {
  1170. choice.FinishReason = &constant.FinishReasonToolCalls
  1171. }
  1172. choices = append(choices, choice)
  1173. }
  1174. var response dto.ChatCompletionsStreamResponse
  1175. response.Object = "chat.completion.chunk"
  1176. response.Choices = choices
  1177. return &response, isStop
  1178. }
  1179. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  1180. streamData, err := common.Marshal(resp)
  1181. if err != nil {
  1182. return fmt.Errorf("failed to marshal stream response: %w", err)
  1183. }
  1184. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  1185. if err != nil {
  1186. return fmt.Errorf("failed to handle stream format: %w", err)
  1187. }
  1188. return nil
  1189. }
  1190. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  1191. streamData, err := common.Marshal(resp)
  1192. if err != nil {
  1193. return fmt.Errorf("failed to marshal stream response: %w", err)
  1194. }
  1195. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  1196. return nil
  1197. }
  1198. func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
  1199. var usage = &dto.Usage{}
  1200. var imageCount int
  1201. responseText := strings.Builder{}
  1202. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  1203. var geminiResponse dto.GeminiChatResponse
  1204. err := common.UnmarshalJsonStr(data, &geminiResponse)
  1205. if err != nil {
  1206. logger.LogError(c, "error unmarshalling stream response: "+err.Error())
  1207. return false
  1208. }
  1209. if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  1210. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
  1211. }
  1212. // 统计图片数量
  1213. for _, candidate := range geminiResponse.Candidates {
  1214. for _, part := range candidate.Content.Parts {
  1215. if part.InlineData != nil && part.InlineData.MimeType != "" {
  1216. imageCount++
  1217. }
  1218. if part.Text != "" {
  1219. responseText.WriteString(part.Text)
  1220. }
  1221. }
  1222. }
  1223. // 更新使用量统计
  1224. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  1225. mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1226. *usage = mappedUsage
  1227. }
  1228. return callback(data, &geminiResponse)
  1229. })
  1230. if imageCount != 0 {
  1231. if usage.CompletionTokens == 0 {
  1232. usage.CompletionTokens = imageCount * 1400
  1233. }
  1234. }
  1235. if usage.CompletionTokens <= 0 {
  1236. if info.ReceivedResponseCount > 0 {
  1237. usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
  1238. } else {
  1239. usage = &dto.Usage{}
  1240. }
  1241. }
  1242. return usage, nil
  1243. }
  1244. func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1245. id := helper.GetResponseID(c)
  1246. createAt := common.GetTimestamp()
  1247. finishReason := constant.FinishReasonStop
  1248. toolCallIndexByChoice := make(map[int]map[string]int)
  1249. nextToolCallIndexByChoice := make(map[int]int)
  1250. usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
  1251. response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
  1252. response.Id = id
  1253. response.Created = createAt
  1254. response.Model = info.UpstreamModelName
  1255. for choiceIdx := range response.Choices {
  1256. choiceKey := response.Choices[choiceIdx].Index
  1257. for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
  1258. tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
  1259. if tool.ID == "" {
  1260. continue
  1261. }
  1262. m := toolCallIndexByChoice[choiceKey]
  1263. if m == nil {
  1264. m = make(map[string]int)
  1265. toolCallIndexByChoice[choiceKey] = m
  1266. }
  1267. if idx, ok := m[tool.ID]; ok {
  1268. tool.SetIndex(idx)
  1269. continue
  1270. }
  1271. idx := nextToolCallIndexByChoice[choiceKey]
  1272. nextToolCallIndexByChoice[choiceKey] = idx + 1
  1273. m[tool.ID] = idx
  1274. tool.SetIndex(idx)
  1275. }
  1276. }
  1277. logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
  1278. if info.SendResponseCount == 0 {
  1279. // send first response
  1280. emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
  1281. if response.IsToolCall() {
  1282. if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
  1283. toolCalls := response.Choices[0].Delta.ToolCalls
  1284. copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
  1285. for idx := range toolCalls {
  1286. copiedToolCalls[idx] = toolCalls[idx]
  1287. copiedToolCalls[idx].Function.Arguments = ""
  1288. }
  1289. emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
  1290. }
  1291. finishReason = constant.FinishReasonToolCalls
  1292. err := handleStream(c, info, emptyResponse)
  1293. if err != nil {
  1294. logger.LogError(c, err.Error())
  1295. }
  1296. response.ClearToolCalls()
  1297. if response.IsFinished() {
  1298. response.Choices[0].FinishReason = nil
  1299. }
  1300. } else {
  1301. err := handleStream(c, info, emptyResponse)
  1302. if err != nil {
  1303. logger.LogError(c, err.Error())
  1304. }
  1305. }
  1306. }
  1307. err := handleStream(c, info, response)
  1308. if err != nil {
  1309. logger.LogError(c, err.Error())
  1310. }
  1311. if isStop {
  1312. _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
  1313. }
  1314. return true
  1315. })
  1316. if err != nil {
  1317. return usage, err
  1318. }
  1319. response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
  1320. handleErr := handleFinalStream(c, info, response)
  1321. if handleErr != nil {
  1322. common.SysLog("send final response failed: " + handleErr.Error())
  1323. }
  1324. return usage, nil
  1325. }
  1326. func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1327. responseBody, err := io.ReadAll(resp.Body)
  1328. if err != nil {
  1329. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1330. }
  1331. service.CloseResponseBodyGracefully(resp)
  1332. if common.DebugEnabled {
  1333. println(string(responseBody))
  1334. }
  1335. var geminiResponse dto.GeminiChatResponse
  1336. err = common.Unmarshal(responseBody, &geminiResponse)
  1337. if err != nil {
  1338. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1339. }
  1340. if len(geminiResponse.Candidates) == 0 {
  1341. usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1342. var newAPIError *types.NewAPIError
  1343. if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  1344. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
  1345. newAPIError = types.NewOpenAIError(
  1346. errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
  1347. types.ErrorCodePromptBlocked,
  1348. http.StatusBadRequest,
  1349. )
  1350. } else {
  1351. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
  1352. newAPIError = types.NewOpenAIError(
  1353. errors.New("empty response from Gemini API"),
  1354. types.ErrorCodeEmptyResponse,
  1355. http.StatusInternalServerError,
  1356. )
  1357. }
  1358. service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
  1359. switch info.RelayFormat {
  1360. case types.RelayFormatClaude:
  1361. c.JSON(newAPIError.StatusCode, gin.H{
  1362. "type": "error",
  1363. "error": newAPIError.ToClaudeError(),
  1364. })
  1365. default:
  1366. c.JSON(newAPIError.StatusCode, gin.H{
  1367. "error": newAPIError.ToOpenAIError(),
  1368. })
  1369. }
  1370. return &usage, nil
  1371. }
  1372. fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
  1373. fullTextResponse.Model = info.UpstreamModelName
  1374. usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1375. fullTextResponse.Usage = usage
  1376. switch info.RelayFormat {
  1377. case types.RelayFormatOpenAI:
  1378. responseBody, err = common.Marshal(fullTextResponse)
  1379. if err != nil {
  1380. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1381. }
  1382. case types.RelayFormatClaude:
  1383. claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
  1384. claudeRespStr, err := common.Marshal(claudeResp)
  1385. if err != nil {
  1386. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1387. }
  1388. responseBody = claudeRespStr
  1389. case types.RelayFormatGemini:
  1390. break
  1391. }
  1392. service.IOCopyBytesGracefully(c, resp, responseBody)
  1393. return &usage, nil
  1394. }
  1395. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1396. defer service.CloseResponseBodyGracefully(resp)
  1397. responseBody, readErr := io.ReadAll(resp.Body)
  1398. if readErr != nil {
  1399. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1400. }
  1401. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1402. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1403. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1404. }
  1405. // convert to openai format response
  1406. openAIResponse := dto.OpenAIEmbeddingResponse{
  1407. Object: "list",
  1408. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1409. Model: info.UpstreamModelName,
  1410. }
  1411. for i, embedding := range geminiResponse.Embeddings {
  1412. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1413. Object: "embedding",
  1414. Embedding: embedding.Values,
  1415. Index: i,
  1416. })
  1417. }
  1418. // calculate usage
  1419. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1420. // Google has not yet clarified how embedding models will be billed
  1421. // refer to openai billing method to use input tokens billing
  1422. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1423. usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
  1424. openAIResponse.Usage = *usage
  1425. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1426. if jsonErr != nil {
  1427. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1428. }
  1429. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1430. return usage, nil
  1431. }
  1432. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1433. responseBody, readErr := io.ReadAll(resp.Body)
  1434. if readErr != nil {
  1435. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1436. }
  1437. _ = resp.Body.Close()
  1438. var geminiResponse dto.GeminiImageResponse
  1439. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1440. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1441. }
  1442. if len(geminiResponse.Predictions) == 0 {
  1443. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1444. }
  1445. // convert to openai format response
  1446. openAIResponse := dto.ImageResponse{
  1447. Created: common.GetTimestamp(),
  1448. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1449. }
  1450. for _, prediction := range geminiResponse.Predictions {
  1451. if prediction.RaiFilteredReason != "" {
  1452. continue // skip filtered image
  1453. }
  1454. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1455. B64Json: prediction.BytesBase64Encoded,
  1456. })
  1457. }
  1458. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1459. if jsonErr != nil {
  1460. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1461. }
  1462. c.Writer.Header().Set("Content-Type", "application/json")
  1463. c.Writer.WriteHeader(resp.StatusCode)
  1464. _, _ = c.Writer.Write(jsonResponse)
  1465. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1466. // each image has fixed 258 tokens
  1467. const imageTokens = 258
  1468. generatedImages := len(openAIResponse.Data)
  1469. usage := &dto.Usage{
  1470. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1471. CompletionTokens: 0, // image generation does not calculate completion tokens
  1472. TotalTokens: imageTokens * generatedImages,
  1473. }
  1474. return usage, nil
  1475. }
// GeminiModelsResponse is the JSON body decoded from one page of the
// GET {baseURL}/v1beta/models listing performed by FetchGeminiModels.
type GeminiModelsResponse struct {
	// Models holds the model entries returned on this page.
	Models []dto.GeminiModel `json:"models"`
	// NextPageToken is sent back as the pageToken query parameter to request
	// the following page; FetchGeminiModels stops paging when it is empty.
	NextPageToken string `json:"nextPageToken"`
}
  1480. func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
  1481. client, err := service.GetHttpClientWithProxy(proxyURL)
  1482. if err != nil {
  1483. return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
  1484. }
  1485. allModels := make([]string, 0)
  1486. nextPageToken := ""
  1487. maxPages := 100 // Safety limit to prevent infinite loops
  1488. for page := 0; page < maxPages; page++ {
  1489. url := fmt.Sprintf("%s/v1beta/models", baseURL)
  1490. if nextPageToken != "" {
  1491. url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
  1492. }
  1493. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  1494. request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  1495. if err != nil {
  1496. cancel()
  1497. return nil, fmt.Errorf("创建请求失败: %v", err)
  1498. }
  1499. request.Header.Set("x-goog-api-key", apiKey)
  1500. response, err := client.Do(request)
  1501. if err != nil {
  1502. cancel()
  1503. return nil, fmt.Errorf("请求失败: %v", err)
  1504. }
  1505. if response.StatusCode != http.StatusOK {
  1506. body, _ := io.ReadAll(response.Body)
  1507. response.Body.Close()
  1508. cancel()
  1509. return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
  1510. }
  1511. body, err := io.ReadAll(response.Body)
  1512. response.Body.Close()
  1513. cancel()
  1514. if err != nil {
  1515. return nil, fmt.Errorf("读取响应失败: %v", err)
  1516. }
  1517. var modelsResponse GeminiModelsResponse
  1518. if err = common.Unmarshal(body, &modelsResponse); err != nil {
  1519. return nil, fmt.Errorf("解析响应失败: %v", err)
  1520. }
  1521. for _, model := range modelsResponse.Models {
  1522. modelNameValue, ok := model.Name.(string)
  1523. if !ok {
  1524. continue
  1525. }
  1526. modelName := strings.TrimPrefix(modelNameValue, "models/")
  1527. allModels = append(allModels, modelName)
  1528. }
  1529. nextPageToken = modelsResponse.NextPageToken
  1530. if nextPageToken == "" {
  1531. break
  1532. }
  1533. }
  1534. return allModels, nil
  1535. }
  1536. // convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
  1537. // OpenAI tool_choice values:
  1538. // - "auto": Let the model decide (default)
  1539. // - "none": Don't call any tools
  1540. // - "required": Must call at least one tool
  1541. // - {"type": "function", "function": {"name": "xxx"}}: Call specific function
  1542. //
  1543. // Gemini functionCallingConfig.mode values:
  1544. // - "AUTO": Model decides whether to call functions
  1545. // - "NONE": Model won't call functions
  1546. // - "ANY": Model must call at least one function
  1547. func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
  1548. if toolChoice == nil {
  1549. return nil
  1550. }
  1551. // Handle string values: "auto", "none", "required"
  1552. if toolChoiceStr, ok := toolChoice.(string); ok {
  1553. config := &dto.ToolConfig{
  1554. FunctionCallingConfig: &dto.FunctionCallingConfig{},
  1555. }
  1556. switch toolChoiceStr {
  1557. case "auto":
  1558. config.FunctionCallingConfig.Mode = "AUTO"
  1559. case "none":
  1560. config.FunctionCallingConfig.Mode = "NONE"
  1561. case "required":
  1562. config.FunctionCallingConfig.Mode = "ANY"
  1563. default:
  1564. // Unknown string value, default to AUTO
  1565. config.FunctionCallingConfig.Mode = "AUTO"
  1566. }
  1567. return config
  1568. }
  1569. // Handle object value: {"type": "function", "function": {"name": "xxx"}}
  1570. if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
  1571. if toolChoiceMap["type"] == "function" {
  1572. config := &dto.ToolConfig{
  1573. FunctionCallingConfig: &dto.FunctionCallingConfig{
  1574. Mode: "ANY",
  1575. },
  1576. }
  1577. // Extract function name if specified
  1578. if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
  1579. if name, ok := function["name"].(string); ok && name != "" {
  1580. config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
  1581. }
  1582. }
  1583. return config
  1584. }
  1585. // Unsupported map structure (type is not "function"), return nil
  1586. return nil
  1587. }
  1588. // Unsupported type, return nil
  1589. return nil
  1590. }