relay-gemini.go 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441
  1. package gemini
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/relay/channel/openai"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/relay/helper"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/setting/model_setting"
	"github.com/QuantumNous/new-api/setting/reasoning"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
  26. // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
  27. var geminiSupportedMimeTypes = map[string]bool{
  28. "application/pdf": true,
  29. "audio/mpeg": true,
  30. "audio/mp3": true,
  31. "audio/wav": true,
  32. "image/png": true,
  33. "image/jpeg": true,
  34. "image/jpg": true, // support old image/jpeg
  35. "image/webp": true,
  36. "text/plain": true,
  37. "video/mov": true,
  38. "video/mpeg": true,
  39. "video/mp4": true,
  40. "video/mpg": true,
  41. "video/avi": true,
  42. "video/wmv": true,
  43. "video/mpegps": true,
  44. "video/flv": true,
  45. }
  46. const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"
  47. // Gemini 允许的思考预算范围
  48. const (
  49. pro25MinBudget = 128
  50. pro25MaxBudget = 32768
  51. flash25MaxBudget = 24576
  52. flash25LiteMinBudget = 512
  53. flash25LiteMaxBudget = 24576
  54. )
  55. func isNew25ProModel(modelName string) bool {
  56. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  57. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  58. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  59. }
  60. func is25FlashLiteModel(modelName string) bool {
  61. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  62. }
  63. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  64. func clampThinkingBudget(modelName string, budget int) int {
  65. isNew25Pro := isNew25ProModel(modelName)
  66. is25FlashLite := is25FlashLiteModel(modelName)
  67. if is25FlashLite {
  68. if budget < flash25LiteMinBudget {
  69. return flash25LiteMinBudget
  70. }
  71. if budget > flash25LiteMaxBudget {
  72. return flash25LiteMaxBudget
  73. }
  74. } else if isNew25Pro {
  75. if budget < pro25MinBudget {
  76. return pro25MinBudget
  77. }
  78. if budget > pro25MaxBudget {
  79. return pro25MaxBudget
  80. }
  81. } else { // 其他模型
  82. if budget < 0 {
  83. return 0
  84. }
  85. if budget > flash25MaxBudget {
  86. return flash25MaxBudget
  87. }
  88. }
  89. return budget
  90. }
  91. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  92. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  93. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  94. // "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
  95. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  96. isNew25Pro := isNew25ProModel(modelName)
  97. is25FlashLite := is25FlashLiteModel(modelName)
  98. maxBudget := 0
  99. if is25FlashLite {
  100. maxBudget = flash25LiteMaxBudget
  101. }
  102. if isNew25Pro {
  103. maxBudget = pro25MaxBudget
  104. } else {
  105. maxBudget = flash25MaxBudget
  106. }
  107. switch effort {
  108. case "high":
  109. maxBudget = maxBudget * 80 / 100
  110. case "medium":
  111. maxBudget = maxBudget * 50 / 100
  112. case "low":
  113. maxBudget = maxBudget * 20 / 100
  114. case "minimal":
  115. maxBudget = maxBudget * 5 / 100
  116. }
  117. return clampThinkingBudget(modelName, maxBudget)
  118. }
  119. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  120. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  121. modelName := info.UpstreamModelName
  122. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  123. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  124. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  125. if strings.Contains(modelName, "-thinking-") {
  126. parts := strings.SplitN(modelName, "-thinking-", 2)
  127. if len(parts) == 2 && parts[1] != "" {
  128. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  129. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  130. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  131. ThinkingBudget: common.GetPointer(clampedBudget),
  132. IncludeThoughts: true,
  133. }
  134. }
  135. }
  136. } else if strings.HasSuffix(modelName, "-thinking") {
  137. unsupportedModels := []string{
  138. "gemini-2.5-pro-preview-05-06",
  139. "gemini-2.5-pro-preview-03-25",
  140. }
  141. isUnsupported := false
  142. for _, unsupportedModel := range unsupportedModels {
  143. if strings.HasPrefix(modelName, unsupportedModel) {
  144. isUnsupported = true
  145. break
  146. }
  147. }
  148. if isUnsupported {
  149. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  150. IncludeThoughts: true,
  151. }
  152. } else {
  153. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  154. IncludeThoughts: true,
  155. }
  156. if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  157. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
  158. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  159. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  160. } else {
  161. if len(oaiRequest) > 0 {
  162. // 如果有reasoningEffort参数,则根据其值设置思考预算
  163. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  164. }
  165. }
  166. }
  167. } else if strings.HasSuffix(modelName, "-nothinking") {
  168. if !isNew25Pro {
  169. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  170. ThinkingBudget: common.GetPointer(0),
  171. }
  172. }
  173. } else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
  174. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  175. IncludeThoughts: true,
  176. ThinkingLevel: level,
  177. }
  178. info.ReasoningEffort = level
  179. }
  180. }
  181. }
  182. // Setting safety to the lowest possible values since Gemini is already powerless enough
  183. func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
  184. geminiRequest := dto.GeminiChatRequest{
  185. Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
  186. GenerationConfig: dto.GeminiChatGenerationConfig{
  187. Temperature: textRequest.Temperature,
  188. TopP: textRequest.TopP,
  189. MaxOutputTokens: textRequest.GetMaxTokens(),
  190. Seed: int64(textRequest.Seed),
  191. },
  192. }
  193. attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
  194. info.ChannelType == constant.ChannelTypeVertexAi) &&
  195. model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled
  196. if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
  197. geminiRequest.GenerationConfig.ResponseModalities = []string{
  198. "TEXT",
  199. "IMAGE",
  200. }
  201. }
  202. adaptorWithExtraBody := false
  203. // patch extra_body
  204. if len(textRequest.ExtraBody) > 0 {
  205. if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  206. var extraBody map[string]interface{}
  207. if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
  208. return nil, fmt.Errorf("invalid extra body: %w", err)
  209. }
  210. // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
  211. if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
  212. adaptorWithExtraBody = true
  213. // check error param name like thinkingConfig, should be thinking_config
  214. if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
  215. return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
  216. }
  217. if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
  218. // check error param name like thinkingBudget, should be thinking_budget
  219. if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
  220. return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
  221. }
  222. if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
  223. budgetInt := int(budget)
  224. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  225. ThinkingBudget: common.GetPointer(budgetInt),
  226. IncludeThoughts: true,
  227. }
  228. } else {
  229. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  230. IncludeThoughts: true,
  231. }
  232. }
  233. }
  234. // check error param name like imageConfig, should be image_config
  235. if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
  236. return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
  237. }
  238. if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
  239. // check error param name like aspectRatio, should be aspect_ratio
  240. if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
  241. return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
  242. }
  243. // check error param name like imageSize, should be image_size
  244. if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
  245. return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
  246. }
  247. // convert snake_case to camelCase for Gemini API
  248. geminiImageConfig := make(map[string]interface{})
  249. if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
  250. geminiImageConfig["aspectRatio"] = aspectRatio
  251. }
  252. if imageSize, ok := imageConfig["image_size"]; ok {
  253. geminiImageConfig["imageSize"] = imageSize
  254. }
  255. if len(geminiImageConfig) > 0 {
  256. imageConfigBytes, err := common.Marshal(geminiImageConfig)
  257. if err != nil {
  258. return nil, fmt.Errorf("failed to marshal image_config: %w", err)
  259. }
  260. geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
  261. }
  262. }
  263. }
  264. }
  265. }
  266. if !adaptorWithExtraBody {
  267. ThinkingAdaptor(&geminiRequest, info, textRequest)
  268. }
  269. safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
  270. for _, category := range SafetySettingList {
  271. safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
  272. Category: category,
  273. Threshold: model_setting.GetGeminiSafetySetting(category),
  274. })
  275. }
  276. geminiRequest.SafetySettings = safetySettings
  277. // openaiContent.FuncToToolCalls()
  278. if textRequest.Tools != nil {
  279. functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
  280. googleSearch := false
  281. codeExecution := false
  282. urlContext := false
  283. for _, tool := range textRequest.Tools {
  284. if tool.Function.Name == "googleSearch" {
  285. googleSearch = true
  286. continue
  287. }
  288. if tool.Function.Name == "codeExecution" {
  289. codeExecution = true
  290. continue
  291. }
  292. if tool.Function.Name == "urlContext" {
  293. urlContext = true
  294. continue
  295. }
  296. if tool.Function.Parameters != nil {
  297. params, ok := tool.Function.Parameters.(map[string]interface{})
  298. if ok {
  299. if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
  300. if len(props) == 0 {
  301. tool.Function.Parameters = nil
  302. }
  303. }
  304. }
  305. }
  306. // Clean the parameters before appending
  307. cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
  308. tool.Function.Parameters = cleanedParams
  309. functions = append(functions, tool.Function)
  310. }
  311. geminiTools := geminiRequest.GetTools()
  312. if codeExecution {
  313. geminiTools = append(geminiTools, dto.GeminiChatTool{
  314. CodeExecution: make(map[string]string),
  315. })
  316. }
  317. if googleSearch {
  318. geminiTools = append(geminiTools, dto.GeminiChatTool{
  319. GoogleSearch: make(map[string]string),
  320. })
  321. }
  322. if urlContext {
  323. geminiTools = append(geminiTools, dto.GeminiChatTool{
  324. URLContext: make(map[string]string),
  325. })
  326. }
  327. if len(functions) > 0 {
  328. geminiTools = append(geminiTools, dto.GeminiChatTool{
  329. FunctionDeclarations: functions,
  330. })
  331. }
  332. geminiRequest.SetTools(geminiTools)
  333. }
  334. if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
  335. geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
  336. if len(textRequest.ResponseFormat.JsonSchema) > 0 {
  337. // 先将json.RawMessage解析
  338. var jsonSchema dto.FormatJsonSchema
  339. if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
  340. cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
  341. geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
  342. }
  343. }
  344. }
  345. tool_call_ids := make(map[string]string)
  346. var system_content []string
  347. //shouldAddDummyModelMessage := false
  348. for _, message := range textRequest.Messages {
  349. if message.Role == "system" || message.Role == "developer" {
  350. system_content = append(system_content, message.StringContent())
  351. continue
  352. } else if message.Role == "tool" || message.Role == "function" {
  353. if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
  354. geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
  355. Role: "user",
  356. })
  357. }
  358. var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
  359. name := ""
  360. if message.Name != nil {
  361. name = *message.Name
  362. } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
  363. name = val
  364. }
  365. var contentMap map[string]interface{}
  366. contentStr := message.StringContent()
  367. // 1. 尝试解析为 JSON 对象
  368. if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
  369. // 2. 如果失败,尝试解析为 JSON 数组
  370. var contentSlice []interface{}
  371. if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
  372. // 如果是数组,包装成对象
  373. contentMap = map[string]interface{}{"result": contentSlice}
  374. } else {
  375. // 3. 如果再次失败,作为纯文本处理
  376. contentMap = map[string]interface{}{"content": contentStr}
  377. }
  378. }
  379. functionResp := &dto.GeminiFunctionResponse{
  380. Name: name,
  381. Response: contentMap,
  382. }
  383. *parts = append(*parts, dto.GeminiPart{
  384. FunctionResponse: functionResp,
  385. })
  386. continue
  387. }
  388. var parts []dto.GeminiPart
  389. content := dto.GeminiChatContent{
  390. Role: message.Role,
  391. }
  392. shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
  393. signatureAttached := false
  394. // isToolCall := false
  395. if message.ToolCalls != nil {
  396. // message.Role = "model"
  397. // isToolCall = true
  398. for _, call := range message.ParseToolCalls() {
  399. args := map[string]interface{}{}
  400. if call.Function.Arguments != "" {
  401. if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
  402. return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
  403. }
  404. }
  405. toolCall := dto.GeminiPart{
  406. FunctionCall: &dto.FunctionCall{
  407. FunctionName: call.Function.Name,
  408. Arguments: args,
  409. },
  410. }
  411. if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
  412. toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  413. signatureAttached = true
  414. }
  415. parts = append(parts, toolCall)
  416. tool_call_ids[call.ID] = call.Function.Name
  417. }
  418. }
  419. openaiContent := message.ParseContent()
  420. imageNum := 0
  421. for _, part := range openaiContent {
  422. if part.Type == dto.ContentTypeText {
  423. if part.Text == "" {
  424. continue
  425. }
  426. // check markdown image ![image](data:image/jpeg;base64,xxxxxxxxxxxx)
  427. // 使用字符串查找而非正则,避免大文本性能问题
  428. text := part.Text
  429. hasMarkdownImage := false
  430. for {
  431. // 快速检查是否包含 markdown 图片标记
  432. startIdx := strings.Index(text, "![")
  433. if startIdx == -1 {
  434. break
  435. }
  436. // 找到 ](
  437. bracketIdx := strings.Index(text[startIdx:], "](data:")
  438. if bracketIdx == -1 {
  439. break
  440. }
  441. bracketIdx += startIdx
  442. // 找到闭合的 )
  443. closeIdx := strings.Index(text[bracketIdx+2:], ")")
  444. if closeIdx == -1 {
  445. break
  446. }
  447. closeIdx += bracketIdx + 2
  448. hasMarkdownImage = true
  449. // 添加图片前的文本
  450. if startIdx > 0 {
  451. textBefore := text[:startIdx]
  452. if textBefore != "" {
  453. parts = append(parts, dto.GeminiPart{
  454. Text: textBefore,
  455. })
  456. }
  457. }
  458. // 提取 data URL (从 "](" 后面开始,到 ")" 之前)
  459. dataUrl := text[bracketIdx+2 : closeIdx]
  460. imageNum += 1
  461. if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
  462. return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
  463. }
  464. format, base64String, err := service.DecodeBase64FileData(dataUrl)
  465. if err != nil {
  466. return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
  467. }
  468. imgPart := dto.GeminiPart{
  469. InlineData: &dto.GeminiInlineData{
  470. MimeType: format,
  471. Data: base64String,
  472. },
  473. }
  474. if shouldAttachThoughtSignature {
  475. imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  476. }
  477. parts = append(parts, imgPart)
  478. // 继续处理剩余文本
  479. text = text[closeIdx+1:]
  480. }
  481. // 添加剩余文本或原始文本(如果没有找到 markdown 图片)
  482. if !hasMarkdownImage {
  483. parts = append(parts, dto.GeminiPart{
  484. Text: part.Text,
  485. })
  486. }
  487. } else if part.Type == dto.ContentTypeImageURL {
  488. imageNum += 1
  489. if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
  490. return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
  491. }
  492. // 判断是否是url
  493. if strings.HasPrefix(part.GetImageMedia().Url, "http") {
  494. // 是url,获取文件的类型和base64编码的数据
  495. fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
  496. if err != nil {
  497. return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
  498. }
  499. // 校验 MimeType 是否在 Gemini 支持的白名单中
  500. if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
  501. url := part.GetImageMedia().Url
  502. return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
  503. }
  504. parts = append(parts, dto.GeminiPart{
  505. InlineData: &dto.GeminiInlineData{
  506. MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
  507. Data: fileData.Base64Data,
  508. },
  509. })
  510. } else {
  511. format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
  512. if err != nil {
  513. return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
  514. }
  515. parts = append(parts, dto.GeminiPart{
  516. InlineData: &dto.GeminiInlineData{
  517. MimeType: format,
  518. Data: base64String,
  519. },
  520. })
  521. }
  522. } else if part.Type == dto.ContentTypeFile {
  523. if part.GetFile().FileId != "" {
  524. return nil, fmt.Errorf("only base64 file is supported in gemini")
  525. }
  526. format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
  527. if err != nil {
  528. return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
  529. }
  530. parts = append(parts, dto.GeminiPart{
  531. InlineData: &dto.GeminiInlineData{
  532. MimeType: format,
  533. Data: base64String,
  534. },
  535. })
  536. } else if part.Type == dto.ContentTypeInputAudio {
  537. if part.GetInputAudio().Data == "" {
  538. return nil, fmt.Errorf("only base64 audio is supported in gemini")
  539. }
  540. base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
  541. if err != nil {
  542. return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
  543. }
  544. parts = append(parts, dto.GeminiPart{
  545. InlineData: &dto.GeminiInlineData{
  546. MimeType: "audio/" + part.GetInputAudio().Format,
  547. Data: base64String,
  548. },
  549. })
  550. }
  551. }
  552. // 如果需要附加签名但还没有附加(没有 tool_calls 或 tool_calls 为空),
  553. // 则在第一个文本 part 上附加 thoughtSignature
  554. if shouldAttachThoughtSignature && !signatureAttached && len(parts) > 0 {
  555. for i := range parts {
  556. if parts[i].Text != "" {
  557. parts[i].ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
  558. break
  559. }
  560. }
  561. }
  562. content.Parts = parts
  563. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  564. if content.Role == "assistant" {
  565. content.Role = "model"
  566. }
  567. if len(content.Parts) > 0 {
  568. geminiRequest.Contents = append(geminiRequest.Contents, content)
  569. }
  570. }
  571. if len(system_content) > 0 {
  572. geminiRequest.SystemInstructions = &dto.GeminiChatContent{
  573. Parts: []dto.GeminiPart{
  574. {
  575. Text: strings.Join(system_content, "\n"),
  576. },
  577. },
  578. }
  579. }
  580. return &geminiRequest, nil
  581. }
  582. func hasFunctionCallContent(call *dto.FunctionCall) bool {
  583. if call == nil {
  584. return false
  585. }
  586. if strings.TrimSpace(call.FunctionName) != "" {
  587. return true
  588. }
  589. switch v := call.Arguments.(type) {
  590. case nil:
  591. return false
  592. case string:
  593. return strings.TrimSpace(v) != ""
  594. case map[string]interface{}:
  595. return len(v) > 0
  596. case []interface{}:
  597. return len(v) > 0
  598. default:
  599. return true
  600. }
  601. }
  602. // Helper function to get a list of supported MIME types for error messages
  603. func getSupportedMimeTypesList() []string {
  604. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  605. for k := range geminiSupportedMimeTypes {
  606. keys = append(keys, k)
  607. }
  608. return keys
  609. }
  610. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  611. func cleanFunctionParameters(params interface{}) interface{} {
  612. if params == nil {
  613. return nil
  614. }
  615. switch v := params.(type) {
  616. case map[string]interface{}:
  617. // Create a copy to avoid modifying the original
  618. cleanedMap := make(map[string]interface{})
  619. for k, val := range v {
  620. cleanedMap[k] = val
  621. }
  622. // Remove unsupported root-level fields
  623. delete(cleanedMap, "default")
  624. delete(cleanedMap, "exclusiveMaximum")
  625. delete(cleanedMap, "exclusiveMinimum")
  626. delete(cleanedMap, "$schema")
  627. delete(cleanedMap, "additionalProperties")
  628. delete(cleanedMap, "propertyNames")
  629. // Check and clean 'format' for string types
  630. if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
  631. if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
  632. if formatValue != "enum" && formatValue != "date-time" {
  633. delete(cleanedMap, "format")
  634. }
  635. }
  636. }
  637. // Clean properties
  638. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  639. cleanedProps := make(map[string]interface{})
  640. for propName, propValue := range props {
  641. cleanedProps[propName] = cleanFunctionParameters(propValue)
  642. }
  643. cleanedMap["properties"] = cleanedProps
  644. }
  645. // Recursively clean items in arrays
  646. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  647. cleanedMap["items"] = cleanFunctionParameters(items)
  648. }
  649. // Also handle items if it's an array of schemas
  650. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
  651. cleanedItemsArray := make([]interface{}, len(itemsArray))
  652. for i, item := range itemsArray {
  653. cleanedItemsArray[i] = cleanFunctionParameters(item)
  654. }
  655. cleanedMap["items"] = cleanedItemsArray
  656. }
  657. // Recursively clean other schema composition keywords
  658. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  659. if nested, ok := cleanedMap[field].([]interface{}); ok {
  660. cleanedNested := make([]interface{}, len(nested))
  661. for i, item := range nested {
  662. cleanedNested[i] = cleanFunctionParameters(item)
  663. }
  664. cleanedMap[field] = cleanedNested
  665. }
  666. }
  667. // Recursively clean patternProperties
  668. if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
  669. cleanedPatternProps := make(map[string]interface{})
  670. for pattern, schema := range patternProps {
  671. cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
  672. }
  673. cleanedMap["patternProperties"] = cleanedPatternProps
  674. }
  675. // Recursively clean definitions
  676. if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
  677. cleanedDefinitions := make(map[string]interface{})
  678. for defName, defSchema := range definitions {
  679. cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
  680. }
  681. cleanedMap["definitions"] = cleanedDefinitions
  682. }
  683. // Recursively clean $defs (newer JSON Schema draft)
  684. if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
  685. cleanedDefs := make(map[string]interface{})
  686. for defName, defSchema := range defs {
  687. cleanedDefs[defName] = cleanFunctionParameters(defSchema)
  688. }
  689. cleanedMap["$defs"] = cleanedDefs
  690. }
  691. // Clean conditional keywords
  692. for _, field := range []string{"if", "then", "else", "not"} {
  693. if nested, ok := cleanedMap[field]; ok {
  694. cleanedMap[field] = cleanFunctionParameters(nested)
  695. }
  696. }
  697. return cleanedMap
  698. case []interface{}:
  699. // Handle arrays of schemas
  700. cleanedArray := make([]interface{}, len(v))
  701. for i, item := range v {
  702. cleanedArray[i] = cleanFunctionParameters(item)
  703. }
  704. return cleanedArray
  705. default:
  706. // Not a map or array, return as is (e.g., could be a primitive)
  707. return params
  708. }
  709. }
  710. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  711. if depth >= 5 {
  712. return schema
  713. }
  714. v, ok := schema.(map[string]interface{})
  715. if !ok || len(v) == 0 {
  716. return schema
  717. }
  718. // 删除所有的title字段
  719. delete(v, "title")
  720. delete(v, "$schema")
  721. // 如果type不为object和array,则直接返回
  722. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  723. return schema
  724. }
  725. switch v["type"] {
  726. case "object":
  727. delete(v, "additionalProperties")
  728. // 处理 properties
  729. if properties, ok := v["properties"].(map[string]interface{}); ok {
  730. for key, value := range properties {
  731. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  732. }
  733. }
  734. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  735. if nested, ok := v[field].([]interface{}); ok {
  736. for i, item := range nested {
  737. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  738. }
  739. }
  740. }
  741. case "array":
  742. if items, ok := v["items"].(map[string]interface{}); ok {
  743. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  744. }
  745. }
  746. return v
  747. }
  748. func unescapeString(s string) (string, error) {
  749. var result []rune
  750. escaped := false
  751. i := 0
  752. for i < len(s) {
  753. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  754. if r == utf8.RuneError {
  755. return "", fmt.Errorf("invalid UTF-8 encoding")
  756. }
  757. if escaped {
  758. // 如果是转义符后的字符,检查其类型
  759. switch r {
  760. case '"':
  761. result = append(result, '"')
  762. case '\\':
  763. result = append(result, '\\')
  764. case '/':
  765. result = append(result, '/')
  766. case 'b':
  767. result = append(result, '\b')
  768. case 'f':
  769. result = append(result, '\f')
  770. case 'n':
  771. result = append(result, '\n')
  772. case 'r':
  773. result = append(result, '\r')
  774. case 't':
  775. result = append(result, '\t')
  776. case '\'':
  777. result = append(result, '\'')
  778. default:
  779. // 如果遇到一个非法的转义字符,直接按原样输出
  780. result = append(result, '\\', r)
  781. }
  782. escaped = false
  783. } else {
  784. if r == '\\' {
  785. escaped = true // 记录反斜杠作为转义符
  786. } else {
  787. result = append(result, r)
  788. }
  789. }
  790. i += size // 移动到下一个字符
  791. }
  792. return string(result), nil
  793. }
  794. func unescapeMapOrSlice(data interface{}) interface{} {
  795. switch v := data.(type) {
  796. case map[string]interface{}:
  797. for k, val := range v {
  798. v[k] = unescapeMapOrSlice(val)
  799. }
  800. case []interface{}:
  801. for i, val := range v {
  802. v[i] = unescapeMapOrSlice(val)
  803. }
  804. case string:
  805. if unescaped, err := unescapeString(v); err != nil {
  806. return v
  807. } else {
  808. return unescaped
  809. }
  810. }
  811. return data
  812. }
  813. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  814. var argsBytes []byte
  815. var err error
  816. if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
  817. argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
  818. } else {
  819. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  820. }
  821. if err != nil {
  822. return nil
  823. }
  824. return &dto.ToolCallResponse{
  825. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  826. Type: "function",
  827. Function: dto.FunctionResponse{
  828. Arguments: string(argsBytes),
  829. Name: item.FunctionCall.FunctionName,
  830. },
  831. }
  832. }
  833. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  834. fullTextResponse := dto.OpenAITextResponse{
  835. Id: helper.GetResponseID(c),
  836. Object: "chat.completion",
  837. Created: common.GetTimestamp(),
  838. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  839. }
  840. isToolCall := false
  841. for _, candidate := range response.Candidates {
  842. choice := dto.OpenAITextResponseChoice{
  843. Index: int(candidate.Index),
  844. Message: dto.Message{
  845. Role: "assistant",
  846. Content: "",
  847. },
  848. FinishReason: constant.FinishReasonStop,
  849. }
  850. if len(candidate.Content.Parts) > 0 {
  851. var texts []string
  852. var toolCalls []dto.ToolCallResponse
  853. for _, part := range candidate.Content.Parts {
  854. if part.InlineData != nil {
  855. // 媒体内容
  856. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  857. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  858. texts = append(texts, imgText)
  859. } else {
  860. // 其他媒体类型,直接显示链接
  861. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  862. }
  863. } else if part.FunctionCall != nil {
  864. choice.FinishReason = constant.FinishReasonToolCalls
  865. if call := getResponseToolCall(&part); call != nil {
  866. toolCalls = append(toolCalls, *call)
  867. }
  868. } else if part.Thought {
  869. choice.Message.ReasoningContent = part.Text
  870. } else {
  871. if part.ExecutableCode != nil {
  872. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  873. } else if part.CodeExecutionResult != nil {
  874. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  875. } else {
  876. // 过滤掉空行
  877. if part.Text != "\n" {
  878. texts = append(texts, part.Text)
  879. }
  880. }
  881. }
  882. }
  883. if len(toolCalls) > 0 {
  884. choice.Message.SetToolCalls(toolCalls)
  885. isToolCall = true
  886. }
  887. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  888. }
  889. if candidate.FinishReason != nil {
  890. switch *candidate.FinishReason {
  891. case "STOP":
  892. choice.FinishReason = constant.FinishReasonStop
  893. case "MAX_TOKENS":
  894. choice.FinishReason = constant.FinishReasonLength
  895. default:
  896. choice.FinishReason = constant.FinishReasonContentFilter
  897. }
  898. }
  899. if isToolCall {
  900. choice.FinishReason = constant.FinishReasonToolCalls
  901. }
  902. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  903. }
  904. return &fullTextResponse
  905. }
// streamResponseGeminiChat2OpenAI converts one streamed Gemini chunk into an
// OpenAI chat.completion.chunk response. The second return value reports
// whether this chunk carried a "STOP" finish reason; the caller is expected
// to emit the explicit stop chunk itself.
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
	isStop := false
	for _, candidate := range geminiResponse.Candidates {
		// "STOP" is signalled to the caller and cleared here so the switch
		// below does not also attach a finish reason to this chunk.
		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
			isStop = true
			candidate.FinishReason = nil
		}
		choice := dto.ChatCompletionsStreamResponseChoice{
			Index: int(candidate.Index),
			Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
				//Role: "assistant",
			},
		}
		var texts []string
		isTools := false
		isThought := false
		// Only non-STOP finish reasons survive to this point (STOP was
		// cleared above).
		if candidate.FinishReason != nil {
			// p := GeminiConvertFinishReason(*candidate.FinishReason)
			switch *candidate.FinishReason {
			case "STOP":
				choice.FinishReason = &constant.FinishReasonStop
			case "MAX_TOKENS":
				choice.FinishReason = &constant.FinishReasonLength
			default:
				// SAFETY / RECITATION / other reasons map to content_filter.
				choice.FinishReason = &constant.FinishReasonContentFilter
			}
		}
		for _, part := range candidate.Content.Parts {
			if part.InlineData != nil {
				// Inline media: only images are rendered (as markdown data
				// URIs); other media types are dropped in streaming mode.
				if strings.HasPrefix(part.InlineData.MimeType, "image") {
					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
					texts = append(texts, imgText)
				}
			} else if part.FunctionCall != nil {
				isTools = true
				if call := getResponseToolCall(&part); call != nil {
					// Index the tool call by its position within this delta.
					call.SetIndex(len(choice.Delta.ToolCalls))
					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
				}
			} else if part.Thought {
				isThought = true
				texts = append(texts, part.Text)
			} else {
				if part.ExecutableCode != nil {
					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
				} else if part.CodeExecutionResult != nil {
					texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
				} else {
					// Drop newline-only filler parts.
					if part.Text != "\n" {
						texts = append(texts, part.Text)
					}
				}
			}
		}
		// NOTE(review): if a chunk mixed thought and normal text parts, all of
		// it would be routed to reasoning_content — assumes upstream never
		// mixes the two within one chunk; confirm.
		if isThought {
			choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
		} else {
			choice.Delta.SetContentString(strings.Join(texts, "\n"))
		}
		if isTools {
			choice.FinishReason = &constant.FinishReasonToolCalls
		}
		choices = append(choices, choice)
	}
	var response dto.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Choices = choices
	return &response, isStop
}
  976. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  977. streamData, err := common.Marshal(resp)
  978. if err != nil {
  979. return fmt.Errorf("failed to marshal stream response: %w", err)
  980. }
  981. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  982. if err != nil {
  983. return fmt.Errorf("failed to handle stream format: %w", err)
  984. }
  985. return nil
  986. }
  987. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  988. streamData, err := common.Marshal(resp)
  989. if err != nil {
  990. return fmt.Errorf("failed to marshal stream response: %w", err)
  991. }
  992. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  993. return nil
  994. }
// geminiStreamHandler drives the SSE scanner over a streaming Gemini
// response, accumulating usage statistics and forwarding every parsed chunk
// to callback. callback returning false aborts scanning. The returned usage
// is reconstructed from upstream metadata, with fallbacks for image-only
// streams and for streams that carried no usable metadata at all.
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
	var usage = &dto.Usage{}
	var imageCount int
	responseText := strings.Builder{}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var geminiResponse dto.GeminiChatResponse
		err := common.UnmarshalJsonStr(data, &geminiResponse)
		if err != nil {
			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
			return false
		}
		// Count inline images and collect text for the estimation fallback.
		for _, candidate := range geminiResponse.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.InlineData != nil && part.InlineData.MimeType != "" {
					imageCount++
				}
				if part.Text != "" {
					responseText.WriteString(part.Text)
				}
			}
		}
		// Refresh usage from every chunk carrying metadata; later chunks
		// overwrite earlier ones, so the final totals win.
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
				if detail.Modality == "AUDIO" {
					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
				} else if detail.Modality == "TEXT" {
					usage.PromptTokensDetails.TextTokens = detail.TokenCount
				}
			}
		}
		return callback(data, &geminiResponse)
	})
	// Flat-rate image billing when the upstream reported no completion tokens.
	if imageCount != 0 {
		if usage.CompletionTokens == 0 {
			usage.CompletionTokens = imageCount * 1400
		}
	}
	// NOTE(review): this overwrites the per-modality TextTokens recorded in
	// the loop above with the whole prompt count, even when audio tokens were
	// reported — confirm this is intended.
	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	if usage.TotalTokens > 0 {
		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
	}
	// Last resort: estimate completion tokens from the accumulated text.
	if usage.CompletionTokens <= 0 {
		str := responseText.String()
		if len(str) > 0 {
			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
		} else {
			usage = &dto.Usage{}
		}
	}
	return usage, nil
}
// GeminiChatStreamHandler relays a Gemini SSE stream to the client in OpenAI
// streaming format: it emits an initial empty "start" chunk on the first
// response (pre-announcing tool calls with empty arguments when present),
// forwards each converted chunk, emits an explicit stop chunk when upstream
// signals STOP, and closes with a final usage chunk.
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	id := helper.GetResponseID(c)
	createAt := common.GetTimestamp()
	finishReason := constant.FinishReasonStop
	usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
		response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
		response.Id = id
		response.Created = createAt
		response.Model = info.UpstreamModelName
		logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
		if info.SendResponseCount == 0 {
			// send first response
			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
			if response.IsToolCall() {
				// Announce the tool calls (names only, arguments blanked) in
				// the opening chunk, then strip them from the real chunk so
				// they are not delivered twice.
				if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
					toolCalls := response.Choices[0].Delta.ToolCalls
					copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
					for idx := range toolCalls {
						copiedToolCalls[idx] = toolCalls[idx]
						copiedToolCalls[idx].Function.Arguments = ""
					}
					emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
				}
				finishReason = constant.FinishReasonToolCalls
				err := handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
				response.ClearToolCalls()
				// Defer the finish reason to the explicit stop chunk below.
				if response.IsFinished() {
					response.Choices[0].FinishReason = nil
				}
			} else {
				err := handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
			}
		}
		err := handleStream(c, info, response)
		if err != nil {
			logger.LogError(c, err.Error())
		}
		if isStop {
			// Upstream signalled STOP: emit the explicit stop chunk now.
			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
		}
		return true
	})
	if err != nil {
		return usage, err
	}
	// Final chunk carrying the usage block.
	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
	handleErr := handleFinalStream(c, info, response)
	if handleErr != nil {
		common.SysLog("send final response failed: " + handleErr.Error())
	}
	return usage, nil
}
// GeminiChatHandler handles a non-streaming Gemini chat response: it decodes
// the upstream body, converts it to the requested relay format (OpenAI,
// Claude, or Gemini passthrough), writes it to the client, and returns usage
// derived from the upstream metadata.
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	service.CloseResponseBodyGracefully(resp)
	if common.DebugEnabled {
		println(string(responseBody))
	}
	var geminiResponse dto.GeminiChatResponse
	err = common.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	// Empty candidate lists are deliberately tolerated (strict rejection was
	// disabled below); the response is still converted and forwarded.
	if len(geminiResponse.Candidates) == 0 {
		//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
		//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
		//	return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
		//} else {
		//	return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
		//}
	}
	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
	fullTextResponse.Model = info.UpstreamModelName
	usage := dto.Usage{
		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
	}
	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
	// NOTE(review): completion tokens are recomputed as total-prompt so that
	// thought tokens are billed as completion; this goes negative when
	// TotalTokenCount is 0 — confirm intended.
	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
		if detail.Modality == "AUDIO" {
			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
		} else if detail.Modality == "TEXT" {
			usage.PromptTokensDetails.TextTokens = detail.TokenCount
		}
	}
	fullTextResponse.Usage = usage
	// Re-encode per the relay format; Gemini passthrough forwards the
	// upstream body unchanged.
	switch info.RelayFormat {
	case types.RelayFormatOpenAI:
		responseBody, err = common.Marshal(fullTextResponse)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
	case types.RelayFormatClaude:
		claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
		claudeRespStr, err := common.Marshal(claudeResp)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
		}
		responseBody = claudeRespStr
	case types.RelayFormatGemini:
		break
	}
	service.IOCopyBytesGracefully(c, resp, responseBody)
	return &usage, nil
}
  1168. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1169. defer service.CloseResponseBodyGracefully(resp)
  1170. responseBody, readErr := io.ReadAll(resp.Body)
  1171. if readErr != nil {
  1172. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1173. }
  1174. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1175. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1176. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1177. }
  1178. // convert to openai format response
  1179. openAIResponse := dto.OpenAIEmbeddingResponse{
  1180. Object: "list",
  1181. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1182. Model: info.UpstreamModelName,
  1183. }
  1184. for i, embedding := range geminiResponse.Embeddings {
  1185. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1186. Object: "embedding",
  1187. Embedding: embedding.Values,
  1188. Index: i,
  1189. })
  1190. }
  1191. // calculate usage
  1192. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1193. // Google has not yet clarified how embedding models will be billed
  1194. // refer to openai billing method to use input tokens billing
  1195. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1196. usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
  1197. openAIResponse.Usage = *usage
  1198. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1199. if jsonErr != nil {
  1200. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1201. }
  1202. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1203. return usage, nil
  1204. }
  1205. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1206. responseBody, readErr := io.ReadAll(resp.Body)
  1207. if readErr != nil {
  1208. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1209. }
  1210. _ = resp.Body.Close()
  1211. var geminiResponse dto.GeminiImageResponse
  1212. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1213. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1214. }
  1215. if len(geminiResponse.Predictions) == 0 {
  1216. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1217. }
  1218. // convert to openai format response
  1219. openAIResponse := dto.ImageResponse{
  1220. Created: common.GetTimestamp(),
  1221. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1222. }
  1223. for _, prediction := range geminiResponse.Predictions {
  1224. if prediction.RaiFilteredReason != "" {
  1225. continue // skip filtered image
  1226. }
  1227. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1228. B64Json: prediction.BytesBase64Encoded,
  1229. })
  1230. }
  1231. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1232. if jsonErr != nil {
  1233. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1234. }
  1235. c.Writer.Header().Set("Content-Type", "application/json")
  1236. c.Writer.WriteHeader(resp.StatusCode)
  1237. _, _ = c.Writer.Write(jsonResponse)
  1238. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1239. // each image has fixed 258 tokens
  1240. const imageTokens = 258
  1241. generatedImages := len(openAIResponse.Data)
  1242. usage := &dto.Usage{
  1243. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1244. CompletionTokens: 0, // image generation does not calculate completion tokens
  1245. TotalTokens: imageTokens * generatedImages,
  1246. }
  1247. return usage, nil
  1248. }
// GeminiModelsResponse is the paginated payload returned by the Gemini
// "list models" endpoint (GET /v1beta/models).
type GeminiModelsResponse struct {
	// Models holds one entry per model on this page.
	Models []dto.GeminiModel `json:"models"`
	// NextPageToken is non-empty when further pages can be requested.
	NextPageToken string `json:"nextPageToken"`
}
  1253. func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
  1254. client, err := service.GetHttpClientWithProxy(proxyURL)
  1255. if err != nil {
  1256. return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
  1257. }
  1258. allModels := make([]string, 0)
  1259. nextPageToken := ""
  1260. maxPages := 100 // Safety limit to prevent infinite loops
  1261. for page := 0; page < maxPages; page++ {
  1262. url := fmt.Sprintf("%s/v1beta/models", baseURL)
  1263. if nextPageToken != "" {
  1264. url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
  1265. }
  1266. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  1267. request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
  1268. if err != nil {
  1269. cancel()
  1270. return nil, fmt.Errorf("创建请求失败: %v", err)
  1271. }
  1272. request.Header.Set("x-goog-api-key", apiKey)
  1273. response, err := client.Do(request)
  1274. if err != nil {
  1275. cancel()
  1276. return nil, fmt.Errorf("请求失败: %v", err)
  1277. }
  1278. if response.StatusCode != http.StatusOK {
  1279. body, _ := io.ReadAll(response.Body)
  1280. response.Body.Close()
  1281. cancel()
  1282. return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
  1283. }
  1284. body, err := io.ReadAll(response.Body)
  1285. response.Body.Close()
  1286. cancel()
  1287. if err != nil {
  1288. return nil, fmt.Errorf("读取响应失败: %v", err)
  1289. }
  1290. var modelsResponse GeminiModelsResponse
  1291. if err = common.Unmarshal(body, &modelsResponse); err != nil {
  1292. return nil, fmt.Errorf("解析响应失败: %v", err)
  1293. }
  1294. for _, model := range modelsResponse.Models {
  1295. modelNameValue, ok := model.Name.(string)
  1296. if !ok {
  1297. continue
  1298. }
  1299. modelName := strings.TrimPrefix(modelNameValue, "models/")
  1300. allModels = append(allModels, modelName)
  1301. }
  1302. nextPageToken = modelsResponse.NextPageToken
  1303. if nextPageToken == "" {
  1304. break
  1305. }
  1306. }
  1307. return allModels, nil
  1308. }