relay-gemini.go 55 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735
  1. package gemini
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "unicode/utf8"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/dto"
  16. "github.com/QuantumNous/new-api/logger"
  17. "github.com/QuantumNous/new-api/relay/channel/openai"
  18. relaycommon "github.com/QuantumNous/new-api/relay/common"
  19. "github.com/QuantumNous/new-api/relay/helper"
  20. "github.com/QuantumNous/new-api/service"
  21. "github.com/QuantumNous/new-api/setting/model_setting"
  22. "github.com/QuantumNous/new-api/setting/reasoning"
  23. "github.com/QuantumNous/new-api/types"
  24. "github.com/gin-gonic/gin"
  25. )
  26. // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
  27. var geminiSupportedMimeTypes = map[string]bool{
  28. "application/pdf": true,
  29. "audio/mpeg": true,
  30. "audio/mp3": true,
  31. "audio/wav": true,
  32. "image/png": true,
  33. "image/jpeg": true,
  34. "image/jpg": true, // support old image/jpeg
  35. "image/webp": true,
  36. "text/plain": true,
  37. "video/mov": true,
  38. "video/mpeg": true,
  39. "video/mp4": true,
  40. "video/mpg": true,
  41. "video/avi": true,
  42. "video/wmv": true,
  43. "video/mpegps": true,
  44. "video/flv": true,
  45. }
  46. const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"
  47. // Gemini 允许的思考预算范围
  48. const (
  49. pro25MinBudget = 128
  50. pro25MaxBudget = 32768
  51. flash25MaxBudget = 24576
  52. flash25LiteMinBudget = 512
  53. flash25LiteMaxBudget = 24576
  54. )
  55. func isNew25ProModel(modelName string) bool {
  56. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  57. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  58. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  59. }
  60. func is25FlashLiteModel(modelName string) bool {
  61. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  62. }
  63. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  64. func clampThinkingBudget(modelName string, budget int) int {
  65. isNew25Pro := isNew25ProModel(modelName)
  66. is25FlashLite := is25FlashLiteModel(modelName)
  67. if is25FlashLite {
  68. if budget < flash25LiteMinBudget {
  69. return flash25LiteMinBudget
  70. }
  71. if budget > flash25LiteMaxBudget {
  72. return flash25LiteMaxBudget
  73. }
  74. } else if isNew25Pro {
  75. if budget < pro25MinBudget {
  76. return pro25MinBudget
  77. }
  78. if budget > pro25MaxBudget {
  79. return pro25MaxBudget
  80. }
  81. } else { // 其他模型
  82. if budget < 0 {
  83. return 0
  84. }
  85. if budget > flash25MaxBudget {
  86. return flash25MaxBudget
  87. }
  88. }
  89. return budget
  90. }
  91. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  92. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  93. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  94. // "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
  95. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  96. isNew25Pro := isNew25ProModel(modelName)
  97. is25FlashLite := is25FlashLiteModel(modelName)
  98. maxBudget := 0
  99. if is25FlashLite {
  100. maxBudget = flash25LiteMaxBudget
  101. }
  102. if isNew25Pro {
  103. maxBudget = pro25MaxBudget
  104. } else {
  105. maxBudget = flash25MaxBudget
  106. }
  107. switch effort {
  108. case "high":
  109. maxBudget = maxBudget * 80 / 100
  110. case "medium":
  111. maxBudget = maxBudget * 50 / 100
  112. case "low":
  113. maxBudget = maxBudget * 20 / 100
  114. case "minimal":
  115. maxBudget = maxBudget * 5 / 100
  116. }
  117. return clampThinkingBudget(modelName, maxBudget)
  118. }
  119. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  120. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  121. modelName := info.UpstreamModelName
  122. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  123. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  124. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  125. if strings.Contains(modelName, "-thinking-") {
  126. parts := strings.SplitN(modelName, "-thinking-", 2)
  127. if len(parts) == 2 && parts[1] != "" {
  128. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  129. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  130. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  131. ThinkingBudget: common.GetPointer(clampedBudget),
  132. IncludeThoughts: true,
  133. }
  134. }
  135. }
  136. } else if strings.HasSuffix(modelName, "-thinking") {
  137. unsupportedModels := []string{
  138. "gemini-2.5-pro-preview-05-06",
  139. "gemini-2.5-pro-preview-03-25",
  140. }
  141. isUnsupported := false
  142. for _, unsupportedModel := range unsupportedModels {
  143. if strings.HasPrefix(modelName, unsupportedModel) {
  144. isUnsupported = true
  145. break
  146. }
  147. }
  148. if isUnsupported {
  149. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  150. IncludeThoughts: true,
  151. }
  152. } else {
  153. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  154. IncludeThoughts: true,
  155. }
  156. if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  157. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
  158. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  159. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  160. } else {
  161. if len(oaiRequest) > 0 {
  162. // 如果有reasoningEffort参数,则根据其值设置思考预算
  163. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  164. }
  165. }
  166. }
  167. } else if strings.HasSuffix(modelName, "-nothinking") {
  168. if !isNew25Pro {
  169. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  170. ThinkingBudget: common.GetPointer(0),
  171. }
  172. }
  173. } else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
  174. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  175. IncludeThoughts: true,
  176. ThinkingLevel: level,
  177. }
  178. info.ReasoningEffort = level
  179. }
  180. }
  181. }
  182. // Setting safety to the lowest possible values since Gemini is already powerless enough
  183. func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
  184. geminiRequest := dto.GeminiChatRequest{
  185. Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
  186. GenerationConfig: dto.GeminiChatGenerationConfig{
  187. Temperature: textRequest.Temperature,
  188. TopP: textRequest.TopP,
  189. MaxOutputTokens: textRequest.GetMaxTokens(),
  190. Seed: int64(textRequest.Seed),
  191. },
  192. }
  193. attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
  194. info.ChannelType == constant.ChannelTypeVertexAi) &&
  195. model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled
  196. if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
  197. geminiRequest.GenerationConfig.ResponseModalities = []string{
  198. "TEXT",
  199. "IMAGE",
  200. }
  201. }
  202. if stopSequences := parseStopSequences(textRequest.Stop); len(stopSequences) > 0 {
  203. // Gemini supports up to 5 stop sequences
  204. if len(stopSequences) > 5 {
  205. stopSequences = stopSequences[:5]
  206. }
  207. geminiRequest.GenerationConfig.StopSequences = stopSequences
  208. }
  209. adaptorWithExtraBody := false
  210. // patch extra_body
  211. if len(textRequest.ExtraBody) > 0 {
  212. var extraBody map[string]interface{}
  213. if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
  214. return nil, fmt.Errorf("invalid extra body: %w", err)
  215. }
  216. // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
  217. if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
  218. if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  219. adaptorWithExtraBody = true
  220. // check error param name like thinkingConfig, should be thinking_config
  221. if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
  222. return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
  223. }
  224. if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
  225. // check error param name like thinkingBudget, should be thinking_budget
  226. if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
  227. return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
  228. }
  229. var hasThinkingConfig bool
  230. var tempThinkingConfig dto.GeminiThinkingConfig
  231. if thinkingBudget, exists := thinkingConfig["thinking_budget"]; exists {
  232. switch v := thinkingBudget.(type) {
  233. case float64:
  234. budgetInt := int(v)
  235. tempThinkingConfig.ThinkingBudget = common.GetPointer(budgetInt)
  236. if budgetInt > 0 {
  237. // 有正数预算
  238. tempThinkingConfig.IncludeThoughts = true
  239. } else {
  240. // 存在但为0或负数,禁用思考
  241. tempThinkingConfig.IncludeThoughts = false
  242. }
  243. hasThinkingConfig = true
  244. default:
  245. return nil, errors.New("extra_body.google.thinking_config.thinking_budget must be an integer")
  246. }
  247. }
  248. if includeThoughts, exists := thinkingConfig["include_thoughts"]; exists {
  249. if v, ok := includeThoughts.(bool); ok {
  250. tempThinkingConfig.IncludeThoughts = v
  251. hasThinkingConfig = true
  252. } else {
  253. return nil, errors.New("extra_body.google.thinking_config.include_thoughts must be a boolean")
  254. }
  255. }
  256. if thinkingLevel, exists := thinkingConfig["thinking_level"]; exists {
  257. if v, ok := thinkingLevel.(string); ok {
  258. tempThinkingConfig.ThinkingLevel = v
  259. hasThinkingConfig = true
  260. } else {
  261. return nil, errors.New("extra_body.google.thinking_config.thinking_level must be a string")
  262. }
  263. }
  264. if hasThinkingConfig {
  265. // 避免 panic: 仅在获得配置时分配,防止后续赋值时空指针
  266. if geminiRequest.GenerationConfig.ThinkingConfig == nil {
  267. geminiRequest.GenerationConfig.ThinkingConfig = &tempThinkingConfig
  268. } else {
  269. // 如果已分配,则合并内容
  270. if tempThinkingConfig.ThinkingBudget != nil {
  271. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = tempThinkingConfig.ThinkingBudget
  272. }
  273. geminiRequest.GenerationConfig.ThinkingConfig.IncludeThoughts = tempThinkingConfig.IncludeThoughts
  274. if tempThinkingConfig.ThinkingLevel != "" {
  275. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingLevel = tempThinkingConfig.ThinkingLevel
  276. }
  277. }
  278. }
  279. }
  280. }
  281. // check error param name like imageConfig, should be image_config
  282. if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
  283. return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
  284. }
  285. if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
  286. // check error param name like aspectRatio, should be aspect_ratio
  287. if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
  288. return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
  289. }
  290. // check error param name like imageSize, should be image_size
  291. if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
  292. return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
  293. }
  294. // convert snake_case to camelCase for Gemini API
  295. geminiImageConfig := make(map[string]interface{})
  296. if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
  297. geminiImageConfig["aspectRatio"] = aspectRatio
  298. }
  299. if imageSize, ok := imageConfig["image_size"]; ok {
  300. geminiImageConfig["imageSize"] = imageSize
  301. }
  302. if len(geminiImageConfig) > 0 {
  303. imageConfigBytes, err := common.Marshal(geminiImageConfig)
  304. if err != nil {
  305. return nil, fmt.Errorf("failed to marshal image_config: %w", err)
  306. }
  307. geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
  308. }
  309. }
  310. }
  311. }
  312. if !adaptorWithExtraBody {
  313. ThinkingAdaptor(&geminiRequest, info, textRequest)
  314. }
  315. safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
  316. for _, category := range SafetySettingList {
  317. safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
  318. Category: category,
  319. Threshold: model_setting.GetGeminiSafetySetting(category),
  320. })
  321. }
  322. geminiRequest.SafetySettings = safetySettings
  323. // openaiContent.FuncToToolCalls()
  324. if textRequest.Tools != nil {
  325. functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
  326. googleSearch := false
  327. codeExecution := false
  328. urlContext := false
  329. for _, tool := range textRequest.Tools {
  330. if tool.Function.Name == "googleSearch" {
  331. googleSearch = true
  332. continue
  333. }
  334. if tool.Function.Name == "codeExecution" {
  335. codeExecution = true
  336. continue
  337. }
  338. if tool.Function.Name == "urlContext" {
  339. urlContext = true
  340. continue
  341. }
  342. if tool.Function.Parameters != nil {
  343. params, ok := tool.Function.Parameters.(map[string]interface{})
  344. if ok {
  345. if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
  346. if len(props) == 0 {
  347. tool.Function.Parameters = nil
  348. }
  349. }
  350. }
  351. }
  352. // Clean the parameters before appending
  353. cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
  354. tool.Function.Parameters = cleanedParams
  355. functions = append(functions, tool.Function)
  356. }
  357. geminiTools := geminiRequest.GetTools()
  358. if codeExecution {
  359. geminiTools = append(geminiTools, dto.GeminiChatTool{
  360. CodeExecution: make(map[string]string),
  361. })
  362. }
  363. if googleSearch {
  364. geminiTools = append(geminiTools, dto.GeminiChatTool{
  365. GoogleSearch: make(map[string]string),
  366. })
  367. }
  368. if urlContext {
  369. geminiTools = append(geminiTools, dto.GeminiChatTool{
  370. URLContext: make(map[string]string),
  371. })
  372. }
  373. if len(functions) > 0 {
  374. geminiTools = append(geminiTools, dto.GeminiChatTool{
  375. FunctionDeclarations: functions,
  376. })
  377. }
  378. geminiRequest.SetTools(geminiTools)
  379. // [NEW] Convert OpenAI tool_choice to Gemini toolConfig.functionCallingConfig
  380. // Mapping: "auto" -> "AUTO", "none" -> "NONE", "required" -> "ANY"
  381. // Object format: {"type": "function", "function": {"name": "xxx"}} -> "ANY" + allowedFunctionNames
  382. if textRequest.ToolChoice != nil {
  383. geminiRequest.ToolConfig = convertToolChoiceToGeminiConfig(textRequest.ToolChoice)
  384. }
  385. }
  386. if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
  387. geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
  388. if len(textRequest.ResponseFormat.JsonSchema) > 0 {
  389. // 先将json.RawMessage解析
  390. var jsonSchema dto.FormatJsonSchema
  391. if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
  392. cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
  393. geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
  394. }
  395. }
  396. }
  397. tool_call_ids := make(map[string]string)
  398. var system_content []string
  399. //shouldAddDummyModelMessage := false
  400. for _, message := range textRequest.Messages {
  401. if message.Role == "system" || message.Role == "developer" {
  402. system_content = append(system_content, message.StringContent())
  403. continue
  404. } else if message.Role == "tool" || message.Role == "function" {
  405. if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
  406. geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
  407. Role: "user",
  408. })
  409. }
  410. var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
  411. name := ""
  412. if message.Name != nil {
  413. name = *message.Name
  414. } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
  415. name = val
  416. }
  417. var contentMap map[string]interface{}
  418. contentStr := message.StringContent()
  419. // 1. 尝试解析为 JSON 对象
  420. if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
  421. // 2. 如果失败,尝试解析为 JSON 数组
  422. var contentSlice []interface{}
  423. if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
  424. // 如果是数组,包装成对象
  425. contentMap = map[string]interface{}{"result": contentSlice}
  426. } else {
  427. // 3. 如果再次失败,作为纯文本处理
  428. contentMap = map[string]interface{}{"content": contentStr}
  429. }
  430. }
  431. functionResp := &dto.GeminiFunctionResponse{
  432. Name: name,
  433. Response: contentMap,
  434. }
  435. *parts = append(*parts, dto.GeminiPart{
  436. FunctionResponse: functionResp,
  437. })
  438. continue
  439. }
  440. var parts []dto.GeminiPart
  441. content := dto.GeminiChatContent{
  442. Role: message.Role,
  443. }
  444. shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
  445. signatureAttached := false
  446. // isToolCall := false
  447. if message.ToolCalls != nil {
  448. // message.Role = "model"
  449. // isToolCall = true
  450. for _, call := range message.ParseToolCalls() {
  451. args := map[string]interface{}{}
  452. if call.Function.Arguments != "" {
  453. if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
  454. return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
  455. }
  456. }
  457. toolCall := dto.GeminiPart{
  458. FunctionCall: &dto.FunctionCall{
  459. FunctionName: call.Function.Name,
  460. Arguments: args,
  461. },
  462. }
  463. if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
  464. toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  465. signatureAttached = true
  466. }
  467. parts = append(parts, toolCall)
  468. tool_call_ids[call.ID] = call.Function.Name
  469. }
  470. }
  471. openaiContent := message.ParseContent()
  472. for _, part := range openaiContent {
  473. if part.Type == dto.ContentTypeText {
  474. if part.Text == "" {
  475. continue
  476. }
  477. // check markdown image ![image](data:image/jpeg;base64,xxxxxxxxxxxx)
  478. // 使用字符串查找而非正则,避免大文本性能问题
  479. text := part.Text
  480. hasMarkdownImage := false
  481. for {
  482. // 快速检查是否包含 markdown 图片标记
  483. startIdx := strings.Index(text, "![")
  484. if startIdx == -1 {
  485. break
  486. }
  487. // 找到 ](
  488. bracketIdx := strings.Index(text[startIdx:], "](data:")
  489. if bracketIdx == -1 {
  490. break
  491. }
  492. bracketIdx += startIdx
  493. // 找到闭合的 )
  494. closeIdx := strings.Index(text[bracketIdx+2:], ")")
  495. if closeIdx == -1 {
  496. break
  497. }
  498. closeIdx += bracketIdx + 2
  499. hasMarkdownImage = true
  500. // 添加图片前的文本
  501. if startIdx > 0 {
  502. textBefore := text[:startIdx]
  503. if textBefore != "" {
  504. parts = append(parts, dto.GeminiPart{
  505. Text: textBefore,
  506. })
  507. }
  508. }
  509. // 提取 data URL (从 "](" 后面开始,到 ")" 之前)
  510. dataUrl := text[bracketIdx+2 : closeIdx]
  511. format, base64String, err := service.DecodeBase64FileData(dataUrl)
  512. if err != nil {
  513. return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
  514. }
  515. imgPart := dto.GeminiPart{
  516. InlineData: &dto.GeminiInlineData{
  517. MimeType: format,
  518. Data: base64String,
  519. },
  520. }
  521. if shouldAttachThoughtSignature {
  522. imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  523. }
  524. parts = append(parts, imgPart)
  525. // 继续处理剩余文本
  526. text = text[closeIdx+1:]
  527. }
  528. // 添加剩余文本或原始文本(如果没有找到 markdown 图片)
  529. if !hasMarkdownImage {
  530. parts = append(parts, dto.GeminiPart{
  531. Text: part.Text,
  532. })
  533. }
  534. } else if part.Type == dto.ContentTypeImageURL {
  535. // 使用统一的文件服务获取图片数据
  536. var source *types.FileSource
  537. imageUrl := part.GetImageMedia().Url
  538. if strings.HasPrefix(imageUrl, "http") {
  539. source = types.NewURLFileSource(imageUrl)
  540. } else {
  541. source = types.NewBase64FileSource(imageUrl, "")
  542. }
  543. base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
  544. if err != nil {
  545. return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err)
  546. }
  547. // 校验 MimeType 是否在 Gemini 支持的白名单中
  548. if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok {
  549. return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
  550. }
  551. parts = append(parts, dto.GeminiPart{
  552. InlineData: &dto.GeminiInlineData{
  553. MimeType: mimeType,
  554. Data: base64Data,
  555. },
  556. })
  557. } else if part.Type == dto.ContentTypeFile {
  558. if part.GetFile().FileId != "" {
  559. return nil, fmt.Errorf("only base64 file is supported in gemini")
  560. }
  561. fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
  562. base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
  563. if err != nil {
  564. return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
  565. }
  566. parts = append(parts, dto.GeminiPart{
  567. InlineData: &dto.GeminiInlineData{
  568. MimeType: mimeType,
  569. Data: base64Data,
  570. },
  571. })
  572. } else if part.Type == dto.ContentTypeInputAudio {
  573. if part.GetInputAudio().Data == "" {
  574. return nil, fmt.Errorf("only base64 audio is supported in gemini")
  575. }
  576. audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
  577. base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
  578. if err != nil {
  579. return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
  580. }
  581. parts = append(parts, dto.GeminiPart{
  582. InlineData: &dto.GeminiInlineData{
  583. MimeType: mimeType,
  584. Data: base64Data,
  585. },
  586. })
  587. }
  588. }
  589. // 如果需要附加签名但还没有附加(没有 tool_calls 或 tool_calls 为空),
  590. // 则在第一个文本 part 上附加 thoughtSignature
  591. if shouldAttachThoughtSignature && !signatureAttached && len(parts) > 0 {
  592. for i := range parts {
  593. if parts[i].Text != "" {
  594. parts[i].ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  595. break
  596. }
  597. }
  598. }
  599. content.Parts = parts
  600. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  601. if content.Role == "assistant" {
  602. content.Role = "model"
  603. }
  604. if len(content.Parts) > 0 {
  605. geminiRequest.Contents = append(geminiRequest.Contents, content)
  606. }
  607. }
  608. if len(system_content) > 0 {
  609. geminiRequest.SystemInstructions = &dto.GeminiChatContent{
  610. Parts: []dto.GeminiPart{
  611. {
  612. Text: strings.Join(system_content, "\n"),
  613. },
  614. },
  615. }
  616. }
  617. return &geminiRequest, nil
  618. }
  619. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  620. func parseStopSequences(stop any) []string {
  621. if stop == nil {
  622. return nil
  623. }
  624. switch v := stop.(type) {
  625. case string:
  626. if v != "" {
  627. return []string{v}
  628. }
  629. case []string:
  630. return v
  631. case []interface{}:
  632. sequences := make([]string, 0, len(v))
  633. for _, item := range v {
  634. if str, ok := item.(string); ok && str != "" {
  635. sequences = append(sequences, str)
  636. }
  637. }
  638. return sequences
  639. }
  640. return nil
  641. }
  642. func hasFunctionCallContent(call *dto.FunctionCall) bool {
  643. if call == nil {
  644. return false
  645. }
  646. if strings.TrimSpace(call.FunctionName) != "" {
  647. return true
  648. }
  649. switch v := call.Arguments.(type) {
  650. case nil:
  651. return false
  652. case string:
  653. return strings.TrimSpace(v) != ""
  654. case map[string]interface{}:
  655. return len(v) > 0
  656. case []interface{}:
  657. return len(v) > 0
  658. default:
  659. return true
  660. }
  661. }
  662. // Helper function to get a list of supported MIME types for error messages
  663. func getSupportedMimeTypesList() []string {
  664. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  665. for k := range geminiSupportedMimeTypes {
  666. keys = append(keys, k)
  667. }
  668. return keys
  669. }
  670. var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
  671. "anyOf": {},
  672. "default": {},
  673. "description": {},
  674. "enum": {},
  675. "example": {},
  676. "format": {},
  677. "items": {},
  678. "maxItems": {},
  679. "maxLength": {},
  680. "maxProperties": {},
  681. "maximum": {},
  682. "minItems": {},
  683. "minLength": {},
  684. "minProperties": {},
  685. "minimum": {},
  686. "nullable": {},
  687. "pattern": {},
  688. "properties": {},
  689. "propertyOrdering": {},
  690. "required": {},
  691. "title": {},
  692. "type": {},
  693. }
  694. const geminiFunctionSchemaMaxDepth = 64
  695. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  696. func cleanFunctionParameters(params interface{}) interface{} {
  697. return cleanFunctionParametersWithDepth(params, 0)
  698. }
  699. func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
  700. if params == nil {
  701. return nil
  702. }
  703. if depth >= geminiFunctionSchemaMaxDepth {
  704. return cleanFunctionParametersShallow(params)
  705. }
  706. switch v := params.(type) {
  707. case map[string]interface{}:
  708. // Keep only Gemini-supported OpenAPI schema subset fields (per official SDK Schema).
  709. cleanedMap := make(map[string]interface{}, len(v))
  710. for k, val := range v {
  711. if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
  712. cleanedMap[k] = val
  713. }
  714. }
  715. normalizeGeminiSchemaTypeAndNullable(cleanedMap)
  716. // Clean properties
  717. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  718. cleanedProps := make(map[string]interface{})
  719. for propName, propValue := range props {
  720. cleanedProps[propName] = cleanFunctionParametersWithDepth(propValue, depth+1)
  721. }
  722. cleanedMap["properties"] = cleanedProps
  723. }
  724. // Recursively clean items in arrays
  725. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  726. cleanedMap["items"] = cleanFunctionParametersWithDepth(items, depth+1)
  727. }
  728. // OpenAPI tuple-style items is not supported by Gemini SDK Schema; keep first to avoid API rejection.
  729. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok && len(itemsArray) > 0 {
  730. cleanedMap["items"] = cleanFunctionParametersWithDepth(itemsArray[0], depth+1)
  731. }
  732. // Recursively clean anyOf
  733. if nested, ok := cleanedMap["anyOf"].([]interface{}); ok && nested != nil {
  734. cleanedNested := make([]interface{}, len(nested))
  735. for i, item := range nested {
  736. cleanedNested[i] = cleanFunctionParametersWithDepth(item, depth+1)
  737. }
  738. cleanedMap["anyOf"] = cleanedNested
  739. }
  740. return cleanedMap
  741. case []interface{}:
  742. // Handle arrays of schemas
  743. cleanedArray := make([]interface{}, len(v))
  744. for i, item := range v {
  745. cleanedArray[i] = cleanFunctionParametersWithDepth(item, depth+1)
  746. }
  747. return cleanedArray
  748. default:
  749. // Not a map or array, return as is (e.g., could be a primitive)
  750. return params
  751. }
  752. }
  753. func cleanFunctionParametersShallow(params interface{}) interface{} {
  754. switch v := params.(type) {
  755. case map[string]interface{}:
  756. cleanedMap := make(map[string]interface{}, len(v))
  757. for k, val := range v {
  758. if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
  759. cleanedMap[k] = val
  760. }
  761. }
  762. normalizeGeminiSchemaTypeAndNullable(cleanedMap)
  763. // Stop recursion and avoid retaining huge nested structures.
  764. delete(cleanedMap, "properties")
  765. delete(cleanedMap, "items")
  766. delete(cleanedMap, "anyOf")
  767. return cleanedMap
  768. case []interface{}:
  769. // Prefer an empty list over deep recursion on attacker-controlled inputs.
  770. return []interface{}{}
  771. default:
  772. return params
  773. }
  774. }
  775. func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
  776. rawType, ok := schema["type"]
  777. if !ok || rawType == nil {
  778. return
  779. }
  780. normalize := func(t string) (string, bool) {
  781. switch strings.ToLower(strings.TrimSpace(t)) {
  782. case "object":
  783. return "OBJECT", false
  784. case "array":
  785. return "ARRAY", false
  786. case "string":
  787. return "STRING", false
  788. case "integer":
  789. return "INTEGER", false
  790. case "number":
  791. return "NUMBER", false
  792. case "boolean":
  793. return "BOOLEAN", false
  794. case "null":
  795. return "", true
  796. default:
  797. return t, false
  798. }
  799. }
  800. switch t := rawType.(type) {
  801. case string:
  802. normalized, isNull := normalize(t)
  803. if isNull {
  804. schema["nullable"] = true
  805. delete(schema, "type")
  806. return
  807. }
  808. schema["type"] = normalized
  809. case []interface{}:
  810. nullable := false
  811. var chosen string
  812. for _, item := range t {
  813. if s, ok := item.(string); ok {
  814. normalized, isNull := normalize(s)
  815. if isNull {
  816. nullable = true
  817. continue
  818. }
  819. if chosen == "" {
  820. chosen = normalized
  821. }
  822. }
  823. }
  824. if nullable {
  825. schema["nullable"] = true
  826. }
  827. if chosen != "" {
  828. schema["type"] = chosen
  829. } else {
  830. delete(schema, "type")
  831. }
  832. }
  833. }
  834. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  835. if depth >= 5 {
  836. return schema
  837. }
  838. v, ok := schema.(map[string]interface{})
  839. if !ok || len(v) == 0 {
  840. return schema
  841. }
  842. // 删除所有的title字段
  843. delete(v, "title")
  844. delete(v, "$schema")
  845. // 如果type不为object和array,则直接返回
  846. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  847. return schema
  848. }
  849. switch v["type"] {
  850. case "object":
  851. delete(v, "additionalProperties")
  852. // 处理 properties
  853. if properties, ok := v["properties"].(map[string]interface{}); ok {
  854. for key, value := range properties {
  855. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  856. }
  857. }
  858. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  859. if nested, ok := v[field].([]interface{}); ok {
  860. for i, item := range nested {
  861. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  862. }
  863. }
  864. }
  865. case "array":
  866. if items, ok := v["items"].(map[string]interface{}); ok {
  867. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  868. }
  869. }
  870. return v
  871. }
  872. func unescapeString(s string) (string, error) {
  873. var result []rune
  874. escaped := false
  875. i := 0
  876. for i < len(s) {
  877. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  878. if r == utf8.RuneError {
  879. return "", fmt.Errorf("invalid UTF-8 encoding")
  880. }
  881. if escaped {
  882. // 如果是转义符后的字符,检查其类型
  883. switch r {
  884. case '"':
  885. result = append(result, '"')
  886. case '\\':
  887. result = append(result, '\\')
  888. case '/':
  889. result = append(result, '/')
  890. case 'b':
  891. result = append(result, '\b')
  892. case 'f':
  893. result = append(result, '\f')
  894. case 'n':
  895. result = append(result, '\n')
  896. case 'r':
  897. result = append(result, '\r')
  898. case 't':
  899. result = append(result, '\t')
  900. case '\'':
  901. result = append(result, '\'')
  902. default:
  903. // 如果遇到一个非法的转义字符,直接按原样输出
  904. result = append(result, '\\', r)
  905. }
  906. escaped = false
  907. } else {
  908. if r == '\\' {
  909. escaped = true // 记录反斜杠作为转义符
  910. } else {
  911. result = append(result, r)
  912. }
  913. }
  914. i += size // 移动到下一个字符
  915. }
  916. return string(result), nil
  917. }
  918. func unescapeMapOrSlice(data interface{}) interface{} {
  919. switch v := data.(type) {
  920. case map[string]interface{}:
  921. for k, val := range v {
  922. v[k] = unescapeMapOrSlice(val)
  923. }
  924. case []interface{}:
  925. for i, val := range v {
  926. v[i] = unescapeMapOrSlice(val)
  927. }
  928. case string:
  929. if unescaped, err := unescapeString(v); err != nil {
  930. return v
  931. } else {
  932. return unescaped
  933. }
  934. }
  935. return data
  936. }
  937. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  938. var argsBytes []byte
  939. var err error
  940. // 移除 unescapeMapOrSlice 调用,直接使用 json.Marshal
  941. // JSON 序列化/反序列化已经正确处理了转义字符
  942. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  943. if err != nil {
  944. return nil
  945. }
  946. return &dto.ToolCallResponse{
  947. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  948. Type: "function",
  949. Function: dto.FunctionResponse{
  950. Arguments: string(argsBytes),
  951. Name: item.FunctionCall.FunctionName,
  952. },
  953. }
  954. }
  955. func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
  956. promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
  957. if promptTokens <= 0 && fallbackPromptTokens > 0 {
  958. promptTokens = fallbackPromptTokens
  959. }
  960. usage := dto.Usage{
  961. PromptTokens: promptTokens,
  962. CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
  963. TotalTokens: metadata.TotalTokenCount,
  964. }
  965. usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
  966. usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
  967. for _, detail := range metadata.PromptTokensDetails {
  968. if detail.Modality == "AUDIO" {
  969. usage.PromptTokensDetails.AudioTokens += detail.TokenCount
  970. } else if detail.Modality == "TEXT" {
  971. usage.PromptTokensDetails.TextTokens += detail.TokenCount
  972. }
  973. }
  974. for _, detail := range metadata.ToolUsePromptTokensDetails {
  975. if detail.Modality == "AUDIO" {
  976. usage.PromptTokensDetails.AudioTokens += detail.TokenCount
  977. } else if detail.Modality == "TEXT" {
  978. usage.PromptTokensDetails.TextTokens += detail.TokenCount
  979. }
  980. }
  981. if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
  982. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  983. }
  984. if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
  985. usage.PromptTokensDetails.TextTokens = usage.PromptTokens
  986. }
  987. return usage
  988. }
  989. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  990. fullTextResponse := dto.OpenAITextResponse{
  991. Id: helper.GetResponseID(c),
  992. Object: "chat.completion",
  993. Created: common.GetTimestamp(),
  994. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  995. }
  996. isToolCall := false
  997. for _, candidate := range response.Candidates {
  998. choice := dto.OpenAITextResponseChoice{
  999. Index: int(candidate.Index),
  1000. Message: dto.Message{
  1001. Role: "assistant",
  1002. Content: "",
  1003. },
  1004. FinishReason: constant.FinishReasonStop,
  1005. }
  1006. if len(candidate.Content.Parts) > 0 {
  1007. var texts []string
  1008. var toolCalls []dto.ToolCallResponse
  1009. for _, part := range candidate.Content.Parts {
  1010. if part.InlineData != nil {
  1011. // 媒体内容
  1012. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  1013. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  1014. texts = append(texts, imgText)
  1015. } else {
  1016. // 其他媒体类型,直接显示链接
  1017. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  1018. }
  1019. } else if part.FunctionCall != nil {
  1020. choice.FinishReason = constant.FinishReasonToolCalls
  1021. if call := getResponseToolCall(&part); call != nil {
  1022. toolCalls = append(toolCalls, *call)
  1023. }
  1024. } else if part.Thought {
  1025. choice.Message.ReasoningContent = part.Text
  1026. } else {
  1027. if part.ExecutableCode != nil {
  1028. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  1029. } else if part.CodeExecutionResult != nil {
  1030. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  1031. } else {
  1032. // 过滤掉空行
  1033. if part.Text != "\n" {
  1034. texts = append(texts, part.Text)
  1035. }
  1036. }
  1037. }
  1038. }
  1039. if len(toolCalls) > 0 {
  1040. choice.Message.SetToolCalls(toolCalls)
  1041. isToolCall = true
  1042. }
  1043. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  1044. }
  1045. if candidate.FinishReason != nil {
  1046. switch *candidate.FinishReason {
  1047. case "STOP":
  1048. choice.FinishReason = constant.FinishReasonStop
  1049. case "MAX_TOKENS":
  1050. choice.FinishReason = constant.FinishReasonLength
  1051. case "SAFETY":
  1052. // Safety filter triggered
  1053. choice.FinishReason = constant.FinishReasonContentFilter
  1054. case "RECITATION":
  1055. // Recitation (citation) detected
  1056. choice.FinishReason = constant.FinishReasonContentFilter
  1057. case "BLOCKLIST":
  1058. // Blocklist triggered
  1059. choice.FinishReason = constant.FinishReasonContentFilter
  1060. case "PROHIBITED_CONTENT":
  1061. // Prohibited content detected
  1062. choice.FinishReason = constant.FinishReasonContentFilter
  1063. case "SPII":
  1064. // Sensitive personally identifiable information
  1065. choice.FinishReason = constant.FinishReasonContentFilter
  1066. case "OTHER":
  1067. // Other reasons
  1068. choice.FinishReason = constant.FinishReasonContentFilter
  1069. default:
  1070. choice.FinishReason = constant.FinishReasonContentFilter
  1071. }
  1072. }
  1073. if isToolCall {
  1074. choice.FinishReason = constant.FinishReasonToolCalls
  1075. }
  1076. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  1077. }
  1078. return &fullTextResponse
  1079. }
  1080. func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
  1081. choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
  1082. isStop := false
  1083. for _, candidate := range geminiResponse.Candidates {
  1084. if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
  1085. isStop = true
  1086. candidate.FinishReason = nil
  1087. }
  1088. choice := dto.ChatCompletionsStreamResponseChoice{
  1089. Index: int(candidate.Index),
  1090. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  1091. //Role: "assistant",
  1092. },
  1093. }
  1094. var texts []string
  1095. isTools := false
  1096. isThought := false
  1097. if candidate.FinishReason != nil {
  1098. // Map Gemini FinishReason to OpenAI finish_reason
  1099. switch *candidate.FinishReason {
  1100. case "STOP":
  1101. // Normal completion
  1102. choice.FinishReason = &constant.FinishReasonStop
  1103. case "MAX_TOKENS":
  1104. // Reached maximum token limit
  1105. choice.FinishReason = &constant.FinishReasonLength
  1106. case "SAFETY":
  1107. // Safety filter triggered
  1108. choice.FinishReason = &constant.FinishReasonContentFilter
  1109. case "RECITATION":
  1110. // Recitation (citation) detected
  1111. choice.FinishReason = &constant.FinishReasonContentFilter
  1112. case "BLOCKLIST":
  1113. // Blocklist triggered
  1114. choice.FinishReason = &constant.FinishReasonContentFilter
  1115. case "PROHIBITED_CONTENT":
  1116. // Prohibited content detected
  1117. choice.FinishReason = &constant.FinishReasonContentFilter
  1118. case "SPII":
  1119. // Sensitive personally identifiable information
  1120. choice.FinishReason = &constant.FinishReasonContentFilter
  1121. case "OTHER":
  1122. // Other reasons
  1123. choice.FinishReason = &constant.FinishReasonContentFilter
  1124. default:
  1125. // Unknown reason, treat as content filter
  1126. choice.FinishReason = &constant.FinishReasonContentFilter
  1127. }
  1128. }
  1129. for _, part := range candidate.Content.Parts {
  1130. if part.InlineData != nil {
  1131. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  1132. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  1133. texts = append(texts, imgText)
  1134. }
  1135. } else if part.FunctionCall != nil {
  1136. isTools = true
  1137. if call := getResponseToolCall(&part); call != nil {
  1138. call.SetIndex(len(choice.Delta.ToolCalls))
  1139. choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
  1140. }
  1141. } else if part.Thought {
  1142. isThought = true
  1143. texts = append(texts, part.Text)
  1144. } else {
  1145. if part.ExecutableCode != nil {
  1146. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
  1147. } else if part.CodeExecutionResult != nil {
  1148. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
  1149. } else {
  1150. if part.Text != "\n" {
  1151. texts = append(texts, part.Text)
  1152. }
  1153. }
  1154. }
  1155. }
  1156. if isThought {
  1157. choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
  1158. } else {
  1159. choice.Delta.SetContentString(strings.Join(texts, "\n"))
  1160. }
  1161. if isTools {
  1162. choice.FinishReason = &constant.FinishReasonToolCalls
  1163. }
  1164. choices = append(choices, choice)
  1165. }
  1166. var response dto.ChatCompletionsStreamResponse
  1167. response.Object = "chat.completion.chunk"
  1168. response.Choices = choices
  1169. return &response, isStop
  1170. }
  1171. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  1172. streamData, err := common.Marshal(resp)
  1173. if err != nil {
  1174. return fmt.Errorf("failed to marshal stream response: %w", err)
  1175. }
  1176. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  1177. if err != nil {
  1178. return fmt.Errorf("failed to handle stream format: %w", err)
  1179. }
  1180. return nil
  1181. }
  1182. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  1183. streamData, err := common.Marshal(resp)
  1184. if err != nil {
  1185. return fmt.Errorf("failed to marshal stream response: %w", err)
  1186. }
  1187. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  1188. return nil
  1189. }
  1190. func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
  1191. var usage = &dto.Usage{}
  1192. var imageCount int
  1193. responseText := strings.Builder{}
  1194. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  1195. var geminiResponse dto.GeminiChatResponse
  1196. err := common.UnmarshalJsonStr(data, &geminiResponse)
  1197. if err != nil {
  1198. logger.LogError(c, "error unmarshalling stream response: "+err.Error())
  1199. return false
  1200. }
  1201. if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  1202. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
  1203. }
  1204. // 统计图片数量
  1205. for _, candidate := range geminiResponse.Candidates {
  1206. for _, part := range candidate.Content.Parts {
  1207. if part.InlineData != nil && part.InlineData.MimeType != "" {
  1208. imageCount++
  1209. }
  1210. if part.Text != "" {
  1211. responseText.WriteString(part.Text)
  1212. }
  1213. }
  1214. }
  1215. // 更新使用量统计
  1216. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  1217. mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1218. *usage = mappedUsage
  1219. }
  1220. return callback(data, &geminiResponse)
  1221. })
  1222. if imageCount != 0 {
  1223. if usage.CompletionTokens == 0 {
  1224. usage.CompletionTokens = imageCount * 1400
  1225. }
  1226. }
  1227. if usage.CompletionTokens <= 0 {
  1228. if info.ReceivedResponseCount > 0 {
  1229. usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
  1230. } else {
  1231. usage = &dto.Usage{}
  1232. }
  1233. }
  1234. return usage, nil
  1235. }
  1236. func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1237. id := helper.GetResponseID(c)
  1238. createAt := common.GetTimestamp()
  1239. finishReason := constant.FinishReasonStop
  1240. toolCallIndexByChoice := make(map[int]map[string]int)
  1241. nextToolCallIndexByChoice := make(map[int]int)
  1242. usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
  1243. response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
  1244. response.Id = id
  1245. response.Created = createAt
  1246. response.Model = info.UpstreamModelName
  1247. for choiceIdx := range response.Choices {
  1248. choiceKey := response.Choices[choiceIdx].Index
  1249. for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
  1250. tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
  1251. if tool.ID == "" {
  1252. continue
  1253. }
  1254. m := toolCallIndexByChoice[choiceKey]
  1255. if m == nil {
  1256. m = make(map[string]int)
  1257. toolCallIndexByChoice[choiceKey] = m
  1258. }
  1259. if idx, ok := m[tool.ID]; ok {
  1260. tool.SetIndex(idx)
  1261. continue
  1262. }
  1263. idx := nextToolCallIndexByChoice[choiceKey]
  1264. nextToolCallIndexByChoice[choiceKey] = idx + 1
  1265. m[tool.ID] = idx
  1266. tool.SetIndex(idx)
  1267. }
  1268. }
  1269. logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
  1270. if info.SendResponseCount == 0 {
  1271. // send first response
  1272. emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
  1273. if response.IsToolCall() {
  1274. if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
  1275. toolCalls := response.Choices[0].Delta.ToolCalls
  1276. copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
  1277. for idx := range toolCalls {
  1278. copiedToolCalls[idx] = toolCalls[idx]
  1279. copiedToolCalls[idx].Function.Arguments = ""
  1280. }
  1281. emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
  1282. }
  1283. finishReason = constant.FinishReasonToolCalls
  1284. err := handleStream(c, info, emptyResponse)
  1285. if err != nil {
  1286. logger.LogError(c, err.Error())
  1287. }
  1288. response.ClearToolCalls()
  1289. if response.IsFinished() {
  1290. response.Choices[0].FinishReason = nil
  1291. }
  1292. } else {
  1293. err := handleStream(c, info, emptyResponse)
  1294. if err != nil {
  1295. logger.LogError(c, err.Error())
  1296. }
  1297. }
  1298. }
  1299. err := handleStream(c, info, response)
  1300. if err != nil {
  1301. logger.LogError(c, err.Error())
  1302. }
  1303. if isStop {
  1304. _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
  1305. }
  1306. return true
  1307. })
  1308. if err != nil {
  1309. return usage, err
  1310. }
  1311. response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
  1312. handleErr := handleFinalStream(c, info, response)
  1313. if handleErr != nil {
  1314. common.SysLog("send final response failed: " + handleErr.Error())
  1315. }
  1316. return usage, nil
  1317. }
  1318. func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1319. responseBody, err := io.ReadAll(resp.Body)
  1320. if err != nil {
  1321. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1322. }
  1323. service.CloseResponseBodyGracefully(resp)
  1324. if common.DebugEnabled {
  1325. println(string(responseBody))
  1326. }
  1327. var geminiResponse dto.GeminiChatResponse
  1328. err = common.Unmarshal(responseBody, &geminiResponse)
  1329. if err != nil {
  1330. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1331. }
  1332. if len(geminiResponse.Candidates) == 0 {
  1333. usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1334. var newAPIError *types.NewAPIError
  1335. if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  1336. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
  1337. newAPIError = types.NewOpenAIError(
  1338. errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason),
  1339. types.ErrorCodePromptBlocked,
  1340. http.StatusBadRequest,
  1341. )
  1342. } else {
  1343. common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "gemini_empty_candidates")
  1344. newAPIError = types.NewOpenAIError(
  1345. errors.New("empty response from Gemini API"),
  1346. types.ErrorCodeEmptyResponse,
  1347. http.StatusInternalServerError,
  1348. )
  1349. }
  1350. service.ResetStatusCode(newAPIError, c.GetString("status_code_mapping"))
  1351. switch info.RelayFormat {
  1352. case types.RelayFormatClaude:
  1353. c.JSON(newAPIError.StatusCode, gin.H{
  1354. "type": "error",
  1355. "error": newAPIError.ToClaudeError(),
  1356. })
  1357. default:
  1358. c.JSON(newAPIError.StatusCode, gin.H{
  1359. "error": newAPIError.ToOpenAIError(),
  1360. })
  1361. }
  1362. return &usage, nil
  1363. }
  1364. fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
  1365. fullTextResponse.Model = info.UpstreamModelName
  1366. usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
  1367. fullTextResponse.Usage = usage
  1368. switch info.RelayFormat {
  1369. case types.RelayFormatOpenAI:
  1370. responseBody, err = common.Marshal(fullTextResponse)
  1371. if err != nil {
  1372. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1373. }
  1374. case types.RelayFormatClaude:
  1375. claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
  1376. claudeRespStr, err := common.Marshal(claudeResp)
  1377. if err != nil {
  1378. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1379. }
  1380. responseBody = claudeRespStr
  1381. case types.RelayFormatGemini:
  1382. break
  1383. }
  1384. service.IOCopyBytesGracefully(c, resp, responseBody)
  1385. return &usage, nil
  1386. }
  1387. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1388. defer service.CloseResponseBodyGracefully(resp)
  1389. responseBody, readErr := io.ReadAll(resp.Body)
  1390. if readErr != nil {
  1391. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1392. }
  1393. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1394. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1395. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1396. }
  1397. // convert to openai format response
  1398. openAIResponse := dto.OpenAIEmbeddingResponse{
  1399. Object: "list",
  1400. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1401. Model: info.UpstreamModelName,
  1402. }
  1403. for i, embedding := range geminiResponse.Embeddings {
  1404. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1405. Object: "embedding",
  1406. Embedding: embedding.Values,
  1407. Index: i,
  1408. })
  1409. }
  1410. // calculate usage
  1411. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1412. // Google has not yet clarified how embedding models will be billed
  1413. // refer to openai billing method to use input tokens billing
  1414. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1415. usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
  1416. openAIResponse.Usage = *usage
  1417. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1418. if jsonErr != nil {
  1419. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1420. }
  1421. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1422. return usage, nil
  1423. }
  1424. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1425. responseBody, readErr := io.ReadAll(resp.Body)
  1426. if readErr != nil {
  1427. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1428. }
  1429. _ = resp.Body.Close()
  1430. var geminiResponse dto.GeminiImageResponse
  1431. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1432. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1433. }
  1434. if len(geminiResponse.Predictions) == 0 {
  1435. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1436. }
  1437. // convert to openai format response
  1438. openAIResponse := dto.ImageResponse{
  1439. Created: common.GetTimestamp(),
  1440. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1441. }
  1442. for _, prediction := range geminiResponse.Predictions {
  1443. if prediction.RaiFilteredReason != "" {
  1444. continue // skip filtered image
  1445. }
  1446. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1447. B64Json: prediction.BytesBase64Encoded,
  1448. })
  1449. }
  1450. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1451. if jsonErr != nil {
  1452. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1453. }
  1454. c.Writer.Header().Set("Content-Type", "application/json")
  1455. c.Writer.WriteHeader(resp.StatusCode)
  1456. _, _ = c.Writer.Write(jsonResponse)
  1457. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1458. // each image has fixed 258 tokens
  1459. const imageTokens = 258
  1460. generatedImages := len(openAIResponse.Data)
  1461. usage := &dto.Usage{
  1462. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1463. CompletionTokens: 0, // image generation does not calculate completion tokens
  1464. TotalTokens: imageTokens * generatedImages,
  1465. }
  1466. return usage, nil
  1467. }
  1468. type GeminiModelsResponse struct {
  1469. Models []dto.GeminiModel `json:"models"`
  1470. NextPageToken string `json:"nextPageToken"`
  1471. }
  1472. func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
  1473. client, err := service.GetHttpClientWithProxy(proxyURL)
  1474. if err != nil {
  1475. return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
  1476. }
  1477. allModels := make([]string, 0)
  1478. nextPageToken := ""
  1479. maxPages := 100 // Safety limit to prevent infinite loops
  1480. for page := 0; page < maxPages; page++ {
  1481. url := fmt.Sprintf("%s/v1beta/models", baseURL)
  1482. if nextPageToken != "" {
  1483. url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
  1484. }
  1485. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  1486. request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  1487. if err != nil {
  1488. cancel()
  1489. return nil, fmt.Errorf("创建请求失败: %v", err)
  1490. }
  1491. request.Header.Set("x-goog-api-key", apiKey)
  1492. response, err := client.Do(request)
  1493. if err != nil {
  1494. cancel()
  1495. return nil, fmt.Errorf("请求失败: %v", err)
  1496. }
  1497. if response.StatusCode != http.StatusOK {
  1498. body, _ := io.ReadAll(response.Body)
  1499. response.Body.Close()
  1500. cancel()
  1501. return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
  1502. }
  1503. body, err := io.ReadAll(response.Body)
  1504. response.Body.Close()
  1505. cancel()
  1506. if err != nil {
  1507. return nil, fmt.Errorf("读取响应失败: %v", err)
  1508. }
  1509. var modelsResponse GeminiModelsResponse
  1510. if err = common.Unmarshal(body, &modelsResponse); err != nil {
  1511. return nil, fmt.Errorf("解析响应失败: %v", err)
  1512. }
  1513. for _, model := range modelsResponse.Models {
  1514. modelNameValue, ok := model.Name.(string)
  1515. if !ok {
  1516. continue
  1517. }
  1518. modelName := strings.TrimPrefix(modelNameValue, "models/")
  1519. allModels = append(allModels, modelName)
  1520. }
  1521. nextPageToken = modelsResponse.NextPageToken
  1522. if nextPageToken == "" {
  1523. break
  1524. }
  1525. }
  1526. return allModels, nil
  1527. }
  1528. // convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
  1529. // OpenAI tool_choice values:
  1530. // - "auto": Let the model decide (default)
  1531. // - "none": Don't call any tools
  1532. // - "required": Must call at least one tool
  1533. // - {"type": "function", "function": {"name": "xxx"}}: Call specific function
  1534. //
  1535. // Gemini functionCallingConfig.mode values:
  1536. // - "AUTO": Model decides whether to call functions
  1537. // - "NONE": Model won't call functions
  1538. // - "ANY": Model must call at least one function
  1539. func convertToolChoiceToGeminiConfig(toolChoice any) *dto.ToolConfig {
  1540. if toolChoice == nil {
  1541. return nil
  1542. }
  1543. // Handle string values: "auto", "none", "required"
  1544. if toolChoiceStr, ok := toolChoice.(string); ok {
  1545. config := &dto.ToolConfig{
  1546. FunctionCallingConfig: &dto.FunctionCallingConfig{},
  1547. }
  1548. switch toolChoiceStr {
  1549. case "auto":
  1550. config.FunctionCallingConfig.Mode = "AUTO"
  1551. case "none":
  1552. config.FunctionCallingConfig.Mode = "NONE"
  1553. case "required":
  1554. config.FunctionCallingConfig.Mode = "ANY"
  1555. default:
  1556. // Unknown string value, default to AUTO
  1557. config.FunctionCallingConfig.Mode = "AUTO"
  1558. }
  1559. return config
  1560. }
  1561. // Handle object value: {"type": "function", "function": {"name": "xxx"}}
  1562. if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
  1563. if toolChoiceMap["type"] == "function" {
  1564. config := &dto.ToolConfig{
  1565. FunctionCallingConfig: &dto.FunctionCallingConfig{
  1566. Mode: "ANY",
  1567. },
  1568. }
  1569. // Extract function name if specified
  1570. if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
  1571. if name, ok := function["name"].(string); ok && name != "" {
  1572. config.FunctionCallingConfig.AllowedFunctionNames = []string{name}
  1573. }
  1574. }
  1575. return config
  1576. }
  1577. // Unsupported map structure (type is not "function"), return nil
  1578. return nil
  1579. }
  1580. // Unsupported type, return nil
  1581. return nil
  1582. }