relay-gemini.go 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "unicode/utf8"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/logger"
  15. "github.com/QuantumNous/new-api/relay/channel/openai"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. "github.com/QuantumNous/new-api/relay/helper"
  18. "github.com/QuantumNous/new-api/service"
  19. "github.com/QuantumNous/new-api/setting/model_setting"
  20. "github.com/QuantumNous/new-api/setting/reasoning"
  21. "github.com/QuantumNous/new-api/types"
  22. "github.com/gin-gonic/gin"
  23. )
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
// geminiSupportedMimeTypes whitelists the inline-data MIME types this adapter
// will forward to Gemini. URL-fetched media whose detected MIME type is not in
// this set is rejected with an error before the request is sent.
var geminiSupportedMimeTypes = map[string]bool{
	"application/pdf": true,
	"audio/mpeg":      true,
	"audio/mp3":       true,
	"audio/wav":       true,
	"image/png":       true,
	"image/jpeg":      true,
	"image/jpg":       true, // support old image/jpeg
	"image/webp":      true,
	"text/plain":      true,
	"video/mov":       true,
	"video/mpeg":      true,
	"video/mp4":       true,
	"video/mpg":       true,
	"video/avi":       true,
	"video/wmv":       true,
	"video/mpegps":    true,
	"video/flv":       true,
}
// thoughtSignatureBypassValue is the placeholder thoughtSignature attached to
// replayed assistant parts (function calls, inline images, first text part)
// when the channel has FunctionCallThoughtSignatureEnabled — presumably to
// satisfy Gemini's signature validation on multi-turn tool history; confirm
// against the upstream API behavior.
const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"

// Thinking-budget ranges allowed by Gemini, per model family.
const (
	pro25MinBudget       = 128
	pro25MaxBudget       = 32768
	flash25MaxBudget     = 24576
	flash25LiteMinBudget = 512
	flash25LiteMaxBudget = 24576
)
  53. func isNew25ProModel(modelName string) bool {
  54. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  55. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  56. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  57. }
  58. func is25FlashLiteModel(modelName string) bool {
  59. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  60. }
  61. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  62. func clampThinkingBudget(modelName string, budget int) int {
  63. isNew25Pro := isNew25ProModel(modelName)
  64. is25FlashLite := is25FlashLiteModel(modelName)
  65. if is25FlashLite {
  66. if budget < flash25LiteMinBudget {
  67. return flash25LiteMinBudget
  68. }
  69. if budget > flash25LiteMaxBudget {
  70. return flash25LiteMaxBudget
  71. }
  72. } else if isNew25Pro {
  73. if budget < pro25MinBudget {
  74. return pro25MinBudget
  75. }
  76. if budget > pro25MaxBudget {
  77. return pro25MaxBudget
  78. }
  79. } else { // 其他模型
  80. if budget < 0 {
  81. return 0
  82. }
  83. if budget > flash25MaxBudget {
  84. return flash25MaxBudget
  85. }
  86. }
  87. return budget
  88. }
  89. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  90. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  91. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  92. // "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
  93. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  94. isNew25Pro := isNew25ProModel(modelName)
  95. is25FlashLite := is25FlashLiteModel(modelName)
  96. maxBudget := 0
  97. if is25FlashLite {
  98. maxBudget = flash25LiteMaxBudget
  99. }
  100. if isNew25Pro {
  101. maxBudget = pro25MaxBudget
  102. } else {
  103. maxBudget = flash25MaxBudget
  104. }
  105. switch effort {
  106. case "high":
  107. maxBudget = maxBudget * 80 / 100
  108. case "medium":
  109. maxBudget = maxBudget * 50 / 100
  110. case "low":
  111. maxBudget = maxBudget * 20 / 100
  112. case "minimal":
  113. maxBudget = maxBudget * 5 / 100
  114. }
  115. return clampThinkingBudget(modelName, maxBudget)
  116. }
  117. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  118. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  119. modelName := info.UpstreamModelName
  120. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  121. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  122. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  123. if strings.Contains(modelName, "-thinking-") {
  124. parts := strings.SplitN(modelName, "-thinking-", 2)
  125. if len(parts) == 2 && parts[1] != "" {
  126. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  127. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  128. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  129. ThinkingBudget: common.GetPointer(clampedBudget),
  130. IncludeThoughts: true,
  131. }
  132. }
  133. }
  134. } else if strings.HasSuffix(modelName, "-thinking") {
  135. unsupportedModels := []string{
  136. "gemini-2.5-pro-preview-05-06",
  137. "gemini-2.5-pro-preview-03-25",
  138. }
  139. isUnsupported := false
  140. for _, unsupportedModel := range unsupportedModels {
  141. if strings.HasPrefix(modelName, unsupportedModel) {
  142. isUnsupported = true
  143. break
  144. }
  145. }
  146. if isUnsupported {
  147. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  148. IncludeThoughts: true,
  149. }
  150. } else {
  151. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  152. IncludeThoughts: true,
  153. }
  154. if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  155. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
  156. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  157. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  158. } else {
  159. if len(oaiRequest) > 0 {
  160. // 如果有reasoningEffort参数,则根据其值设置思考预算
  161. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  162. }
  163. }
  164. }
  165. } else if strings.HasSuffix(modelName, "-nothinking") {
  166. if !isNew25Pro {
  167. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  168. ThinkingBudget: common.GetPointer(0),
  169. }
  170. }
  171. } else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
  172. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  173. IncludeThoughts: true,
  174. ThinkingLevel: level,
  175. }
  176. info.ReasoningEffort = level
  177. }
  178. }
  179. }
// Setting safety to the lowest possible values since Gemini is already powerless enough
//
// CovertOpenAI2Gemini converts an OpenAI-style chat completion request into a
// Gemini chat request: generation config, thinking config (taken from
// extra_body.google when present, otherwise via ThinkingAdaptor), safety
// settings, tools, response schema, and the full message history including
// system/tool messages, tool calls, inline images, files and audio.
// Returns an error for unmappable parameters or media.
func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
	geminiRequest := dto.GeminiChatRequest{
		Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
		GenerationConfig: dto.GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.GetMaxTokens(),
			Seed:            int64(textRequest.Seed),
		},
	}
	// Thought signatures are only attached on Gemini / Vertex AI channels and
	// only when enabled in the Gemini settings.
	attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
		info.ChannelType == constant.ChannelTypeVertexAi) &&
		model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled
	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
		geminiRequest.GenerationConfig.ResponseModalities = []string{
			"TEXT",
			"IMAGE",
		}
	}
	adaptorWithExtraBody := false
	// patch extra_body
	if len(textRequest.ExtraBody) > 0 {
		if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
			var extraBody map[string]interface{}
			if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
				return nil, fmt.Errorf("invalid extra body: %w", err)
			}
			// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
			if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
				adaptorWithExtraBody = true
				// check error param name like thinkingConfig, should be thinking_config
				if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
					return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
				}
				if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
					// check error param name like thinkingBudget, should be thinking_budget
					if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
						return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
					}
					// JSON numbers decode as float64; truncate to int for the budget.
					if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
						budgetInt := int(budget)
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							ThinkingBudget:  common.GetPointer(budgetInt),
							IncludeThoughts: true,
						}
					} else {
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							IncludeThoughts: true,
						}
					}
				}
				// check error param name like imageConfig, should be image_config
				if _, hasErrorParam := googleBody["imageConfig"]; hasErrorParam {
					return nil, errors.New("extra_body.google.imageConfig is not supported, use extra_body.google.image_config instead")
				}
				if imageConfig, ok := googleBody["image_config"].(map[string]interface{}); ok {
					// check error param name like aspectRatio, should be aspect_ratio
					if _, hasErrorParam := imageConfig["aspectRatio"]; hasErrorParam {
						return nil, errors.New("extra_body.google.image_config.aspectRatio is not supported, use extra_body.google.image_config.aspect_ratio instead")
					}
					// check error param name like imageSize, should be image_size
					if _, hasErrorParam := imageConfig["imageSize"]; hasErrorParam {
						return nil, errors.New("extra_body.google.image_config.imageSize is not supported, use extra_body.google.image_config.image_size instead")
					}
					// convert snake_case to camelCase for Gemini API
					geminiImageConfig := make(map[string]interface{})
					if aspectRatio, ok := imageConfig["aspect_ratio"]; ok {
						geminiImageConfig["aspectRatio"] = aspectRatio
					}
					if imageSize, ok := imageConfig["image_size"]; ok {
						geminiImageConfig["imageSize"] = imageSize
					}
					if len(geminiImageConfig) > 0 {
						imageConfigBytes, err := common.Marshal(geminiImageConfig)
						if err != nil {
							return nil, fmt.Errorf("failed to marshal image_config: %w", err)
						}
						geminiRequest.GenerationConfig.ImageConfig = imageConfigBytes
					}
				}
			}
		}
	}
	// Only fall back to the suffix-based thinking adapter when extra_body did
	// not already configure thinking explicitly.
	if !adaptorWithExtraBody {
		ThinkingAdaptor(&geminiRequest, info, textRequest)
	}
	safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
	for _, category := range SafetySettingList {
		safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
			Category:  category,
			Threshold: model_setting.GetGeminiSafetySetting(category),
		})
	}
	geminiRequest.SafetySettings = safetySettings
	// openaiContent.FuncToToolCalls()
	if textRequest.Tools != nil {
		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
		googleSearch := false
		codeExecution := false
		urlContext := false
		for _, tool := range textRequest.Tools {
			// The three pseudo-tools below map to built-in Gemini tools rather
			// than function declarations.
			if tool.Function.Name == "googleSearch" {
				googleSearch = true
				continue
			}
			if tool.Function.Name == "codeExecution" {
				codeExecution = true
				continue
			}
			if tool.Function.Name == "urlContext" {
				urlContext = true
				continue
			}
			// Drop a parameters object whose "properties" is empty.
			if tool.Function.Parameters != nil {
				params, ok := tool.Function.Parameters.(map[string]interface{})
				if ok {
					if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
						if len(props) == 0 {
							tool.Function.Parameters = nil
						}
					}
				}
			}
			// Clean the parameters before appending
			cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
			tool.Function.Parameters = cleanedParams
			functions = append(functions, tool.Function)
		}
		geminiTools := geminiRequest.GetTools()
		if codeExecution {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				CodeExecution: make(map[string]string),
			})
		}
		if googleSearch {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				GoogleSearch: make(map[string]string),
			})
		}
		if urlContext {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				URLContext: make(map[string]string),
			})
		}
		if len(functions) > 0 {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				FunctionDeclarations: functions,
			})
		}
		geminiRequest.SetTools(geminiTools)
	}
	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
		if len(textRequest.ResponseFormat.JsonSchema) > 0 {
			// Parse the json.RawMessage first.
			var jsonSchema dto.FormatJsonSchema
			if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
				cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
				geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
			}
		}
	}
	// Maps tool_call IDs to function names so later "tool" messages can be
	// matched back to the function that produced them.
	tool_call_ids := make(map[string]string)
	var system_content []string
	//shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		if message.Role == "system" || message.Role == "developer" {
			// System/developer messages are collected and emitted once as
			// systemInstructions after the loop.
			system_content = append(system_content, message.StringContent())
			continue
		} else if message.Role == "tool" || message.Role == "function" {
			// Tool results become functionResponse parts on the trailing
			// "user" content, creating one if the last content was "model".
			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
				geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
					Role: "user",
				})
			}
			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
			name := ""
			if message.Name != nil {
				name = *message.Name
			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
				name = val
			}
			var contentMap map[string]interface{}
			contentStr := message.StringContent()
			// 1. Try to parse the content as a JSON object
			if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
				// 2. On failure, try to parse it as a JSON array
				var contentSlice []interface{}
				if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
					// Wrap the array in an object
					contentMap = map[string]interface{}{"result": contentSlice}
				} else {
					// 3. Otherwise treat it as plain text
					contentMap = map[string]interface{}{"content": contentStr}
				}
			}
			functionResp := &dto.GeminiFunctionResponse{
				Name:     name,
				Response: contentMap,
			}
			*parts = append(*parts, dto.GeminiPart{
				FunctionResponse: functionResp,
			})
			continue
		}
		var parts []dto.GeminiPart
		content := dto.GeminiChatContent{
			Role: message.Role,
		}
		shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
		signatureAttached := false
		// isToolCall := false
		if message.ToolCalls != nil {
			// message.Role = "model"
			// isToolCall = true
			for _, call := range message.ParseToolCalls() {
				args := map[string]interface{}{}
				if call.Function.Arguments != "" {
					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
					}
				}
				toolCall := dto.GeminiPart{
					FunctionCall: &dto.FunctionCall{
						FunctionName: call.Function.Name,
						Arguments:    args,
					},
				}
				// Attach the bypass signature to at most one function call per message.
				if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
					toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					signatureAttached = true
				}
				parts = append(parts, toolCall)
				tool_call_ids[call.ID] = call.Function.Name
			}
		}
		openaiContent := message.ParseContent()
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == dto.ContentTypeText {
				if part.Text == "" {
					continue
				}
				// check markdown image ![image](data:image/jpeg;base64,xxxxxxxxxxxx)
				// Use string scanning instead of a regexp to avoid performance issues on large text.
				text := part.Text
				hasMarkdownImage := false
				for {
					// Quick check for a markdown image marker
					startIdx := strings.Index(text, "![")
					if startIdx == -1 {
						break
					}
					// Find "](
					bracketIdx := strings.Index(text[startIdx:], "](data:")
					if bracketIdx == -1 {
						break
					}
					bracketIdx += startIdx
					// Find the closing )
					closeIdx := strings.Index(text[bracketIdx+2:], ")")
					if closeIdx == -1 {
						break
					}
					closeIdx += bracketIdx + 2
					hasMarkdownImage = true
					// Append the text preceding the image
					if startIdx > 0 {
						textBefore := text[:startIdx]
						if textBefore != "" {
							parts = append(parts, dto.GeminiPart{
								Text: textBefore,
							})
						}
					}
					// Extract the data URL (after "](" and before ")")
					dataUrl := text[bracketIdx+2 : closeIdx]
					imageNum += 1
					if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
						return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
					}
					format, base64String, err := service.DecodeBase64FileData(dataUrl)
					if err != nil {
						return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
					}
					imgPart := dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					}
					if shouldAttachThoughtSignature {
						imgPart.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					}
					parts = append(parts, imgPart)
					// Continue with the remaining text.
					// NOTE(review): if the remaining tail contains no further "![",
					// the next iteration breaks and the tail text is never appended
					// (hasMarkdownImage is already true) — confirm this drop is intended.
					text = text[closeIdx+1:]
				}
				// Append the original text when no markdown image was found
				if !hasMarkdownImage {
					parts = append(parts, dto.GeminiPart{
						Text: part.Text,
					})
				}
			} else if part.Type == dto.ContentTypeImageURL {
				imageNum += 1
				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
				}
				// Determine whether this is a remote URL or inline base64 data.
				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
					// Remote URL: fetch the file's type and base64-encoded data.
					fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
					if err != nil {
						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
					}
					// Validate the MimeType against the Gemini whitelist.
					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
						url := part.GetImageMedia().Url
						return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: fileData.MimeType, // keep the original MimeType; case may matter to the API
							Data:     fileData.Base64Data,
						},
					})
				} else {
					format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
					if err != nil {
						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					})
				}
			} else if part.Type == dto.ContentTypeFile {
				if part.GetFile().FileId != "" {
					return nil, fmt.Errorf("only base64 file is supported in gemini")
				}
				format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
				if err != nil {
					return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: format,
						Data:     base64String,
					},
				})
			} else if part.Type == dto.ContentTypeInputAudio {
				if part.GetInputAudio().Data == "" {
					return nil, fmt.Errorf("only base64 audio is supported in gemini")
				}
				base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
				if err != nil {
					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: "audio/" + part.GetInputAudio().Format,
						Data:     base64String,
					},
				})
			}
		}
		// If a signature is required but none was attached yet (no tool_calls,
		// or tool_calls were empty), attach the thoughtSignature to the first
		// text part.
		if shouldAttachThoughtSignature && !signatureAttached && len(parts) > 0 {
			for i := range parts {
				if parts[i].Text != "" {
					parts[i].ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
					break
				}
			}
		}
		content.Parts = parts
		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		if len(content.Parts) > 0 {
			geminiRequest.Contents = append(geminiRequest.Contents, content)
		}
	}
	if len(system_content) > 0 {
		geminiRequest.SystemInstructions = &dto.GeminiChatContent{
			Parts: []dto.GeminiPart{
				{
					Text: strings.Join(system_content, "\n"),
				},
			},
		}
	}
	return &geminiRequest, nil
}
  580. func hasFunctionCallContent(call *dto.FunctionCall) bool {
  581. if call == nil {
  582. return false
  583. }
  584. if strings.TrimSpace(call.FunctionName) != "" {
  585. return true
  586. }
  587. switch v := call.Arguments.(type) {
  588. case nil:
  589. return false
  590. case string:
  591. return strings.TrimSpace(v) != ""
  592. case map[string]interface{}:
  593. return len(v) > 0
  594. case []interface{}:
  595. return len(v) > 0
  596. default:
  597. return true
  598. }
  599. }
  600. // Helper function to get a list of supported MIME types for error messages
  601. func getSupportedMimeTypesList() []string {
  602. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  603. for k := range geminiSupportedMimeTypes {
  604. keys = append(keys, k)
  605. }
  606. return keys
  607. }
  608. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  609. func cleanFunctionParameters(params interface{}) interface{} {
  610. if params == nil {
  611. return nil
  612. }
  613. switch v := params.(type) {
  614. case map[string]interface{}:
  615. // Create a copy to avoid modifying the original
  616. cleanedMap := make(map[string]interface{})
  617. for k, val := range v {
  618. cleanedMap[k] = val
  619. }
  620. // Remove unsupported root-level fields
  621. delete(cleanedMap, "default")
  622. delete(cleanedMap, "exclusiveMaximum")
  623. delete(cleanedMap, "exclusiveMinimum")
  624. delete(cleanedMap, "$schema")
  625. delete(cleanedMap, "additionalProperties")
  626. // Check and clean 'format' for string types
  627. if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
  628. if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
  629. if formatValue != "enum" && formatValue != "date-time" {
  630. delete(cleanedMap, "format")
  631. }
  632. }
  633. }
  634. // Clean properties
  635. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  636. cleanedProps := make(map[string]interface{})
  637. for propName, propValue := range props {
  638. cleanedProps[propName] = cleanFunctionParameters(propValue)
  639. }
  640. cleanedMap["properties"] = cleanedProps
  641. }
  642. // Recursively clean items in arrays
  643. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  644. cleanedMap["items"] = cleanFunctionParameters(items)
  645. }
  646. // Also handle items if it's an array of schemas
  647. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
  648. cleanedItemsArray := make([]interface{}, len(itemsArray))
  649. for i, item := range itemsArray {
  650. cleanedItemsArray[i] = cleanFunctionParameters(item)
  651. }
  652. cleanedMap["items"] = cleanedItemsArray
  653. }
  654. // Recursively clean other schema composition keywords
  655. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  656. if nested, ok := cleanedMap[field].([]interface{}); ok {
  657. cleanedNested := make([]interface{}, len(nested))
  658. for i, item := range nested {
  659. cleanedNested[i] = cleanFunctionParameters(item)
  660. }
  661. cleanedMap[field] = cleanedNested
  662. }
  663. }
  664. // Recursively clean patternProperties
  665. if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
  666. cleanedPatternProps := make(map[string]interface{})
  667. for pattern, schema := range patternProps {
  668. cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
  669. }
  670. cleanedMap["patternProperties"] = cleanedPatternProps
  671. }
  672. // Recursively clean definitions
  673. if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
  674. cleanedDefinitions := make(map[string]interface{})
  675. for defName, defSchema := range definitions {
  676. cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
  677. }
  678. cleanedMap["definitions"] = cleanedDefinitions
  679. }
  680. // Recursively clean $defs (newer JSON Schema draft)
  681. if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
  682. cleanedDefs := make(map[string]interface{})
  683. for defName, defSchema := range defs {
  684. cleanedDefs[defName] = cleanFunctionParameters(defSchema)
  685. }
  686. cleanedMap["$defs"] = cleanedDefs
  687. }
  688. // Clean conditional keywords
  689. for _, field := range []string{"if", "then", "else", "not"} {
  690. if nested, ok := cleanedMap[field]; ok {
  691. cleanedMap[field] = cleanFunctionParameters(nested)
  692. }
  693. }
  694. return cleanedMap
  695. case []interface{}:
  696. // Handle arrays of schemas
  697. cleanedArray := make([]interface{}, len(v))
  698. for i, item := range v {
  699. cleanedArray[i] = cleanFunctionParameters(item)
  700. }
  701. return cleanedArray
  702. default:
  703. // Not a map or array, return as is (e.g., could be a primitive)
  704. return params
  705. }
  706. }
  707. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  708. if depth >= 5 {
  709. return schema
  710. }
  711. v, ok := schema.(map[string]interface{})
  712. if !ok || len(v) == 0 {
  713. return schema
  714. }
  715. // 删除所有的title字段
  716. delete(v, "title")
  717. delete(v, "$schema")
  718. // 如果type不为object和array,则直接返回
  719. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  720. return schema
  721. }
  722. switch v["type"] {
  723. case "object":
  724. delete(v, "additionalProperties")
  725. // 处理 properties
  726. if properties, ok := v["properties"].(map[string]interface{}); ok {
  727. for key, value := range properties {
  728. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  729. }
  730. }
  731. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  732. if nested, ok := v[field].([]interface{}); ok {
  733. for i, item := range nested {
  734. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  735. }
  736. }
  737. }
  738. case "array":
  739. if items, ok := v["items"].(map[string]interface{}); ok {
  740. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  741. }
  742. }
  743. return v
  744. }
  745. func unescapeString(s string) (string, error) {
  746. var result []rune
  747. escaped := false
  748. i := 0
  749. for i < len(s) {
  750. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  751. if r == utf8.RuneError {
  752. return "", fmt.Errorf("invalid UTF-8 encoding")
  753. }
  754. if escaped {
  755. // 如果是转义符后的字符,检查其类型
  756. switch r {
  757. case '"':
  758. result = append(result, '"')
  759. case '\\':
  760. result = append(result, '\\')
  761. case '/':
  762. result = append(result, '/')
  763. case 'b':
  764. result = append(result, '\b')
  765. case 'f':
  766. result = append(result, '\f')
  767. case 'n':
  768. result = append(result, '\n')
  769. case 'r':
  770. result = append(result, '\r')
  771. case 't':
  772. result = append(result, '\t')
  773. case '\'':
  774. result = append(result, '\'')
  775. default:
  776. // 如果遇到一个非法的转义字符,直接按原样输出
  777. result = append(result, '\\', r)
  778. }
  779. escaped = false
  780. } else {
  781. if r == '\\' {
  782. escaped = true // 记录反斜杠作为转义符
  783. } else {
  784. result = append(result, r)
  785. }
  786. }
  787. i += size // 移动到下一个字符
  788. }
  789. return string(result), nil
  790. }
  791. func unescapeMapOrSlice(data interface{}) interface{} {
  792. switch v := data.(type) {
  793. case map[string]interface{}:
  794. for k, val := range v {
  795. v[k] = unescapeMapOrSlice(val)
  796. }
  797. case []interface{}:
  798. for i, val := range v {
  799. v[i] = unescapeMapOrSlice(val)
  800. }
  801. case string:
  802. if unescaped, err := unescapeString(v); err != nil {
  803. return v
  804. } else {
  805. return unescaped
  806. }
  807. }
  808. return data
  809. }
  810. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  811. var argsBytes []byte
  812. var err error
  813. if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
  814. argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
  815. } else {
  816. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  817. }
  818. if err != nil {
  819. return nil
  820. }
  821. return &dto.ToolCallResponse{
  822. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  823. Type: "function",
  824. Function: dto.FunctionResponse{
  825. Arguments: string(argsBytes),
  826. Name: item.FunctionCall.FunctionName,
  827. },
  828. }
  829. }
  830. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  831. fullTextResponse := dto.OpenAITextResponse{
  832. Id: helper.GetResponseID(c),
  833. Object: "chat.completion",
  834. Created: common.GetTimestamp(),
  835. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  836. }
  837. isToolCall := false
  838. for _, candidate := range response.Candidates {
  839. choice := dto.OpenAITextResponseChoice{
  840. Index: int(candidate.Index),
  841. Message: dto.Message{
  842. Role: "assistant",
  843. Content: "",
  844. },
  845. FinishReason: constant.FinishReasonStop,
  846. }
  847. if len(candidate.Content.Parts) > 0 {
  848. var texts []string
  849. var toolCalls []dto.ToolCallResponse
  850. for _, part := range candidate.Content.Parts {
  851. if part.InlineData != nil {
  852. // 媒体内容
  853. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  854. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  855. texts = append(texts, imgText)
  856. } else {
  857. // 其他媒体类型,直接显示链接
  858. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  859. }
  860. } else if part.FunctionCall != nil {
  861. choice.FinishReason = constant.FinishReasonToolCalls
  862. if call := getResponseToolCall(&part); call != nil {
  863. toolCalls = append(toolCalls, *call)
  864. }
  865. } else if part.Thought {
  866. choice.Message.ReasoningContent = part.Text
  867. } else {
  868. if part.ExecutableCode != nil {
  869. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  870. } else if part.CodeExecutionResult != nil {
  871. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  872. } else {
  873. // 过滤掉空行
  874. if part.Text != "\n" {
  875. texts = append(texts, part.Text)
  876. }
  877. }
  878. }
  879. }
  880. if len(toolCalls) > 0 {
  881. choice.Message.SetToolCalls(toolCalls)
  882. isToolCall = true
  883. }
  884. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  885. }
  886. if candidate.FinishReason != nil {
  887. switch *candidate.FinishReason {
  888. case "STOP":
  889. choice.FinishReason = constant.FinishReasonStop
  890. case "MAX_TOKENS":
  891. choice.FinishReason = constant.FinishReasonLength
  892. default:
  893. choice.FinishReason = constant.FinishReasonContentFilter
  894. }
  895. }
  896. if isToolCall {
  897. choice.FinishReason = constant.FinishReasonToolCalls
  898. }
  899. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  900. }
  901. return &fullTextResponse
  902. }
  903. func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
  904. choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
  905. isStop := false
  906. for _, candidate := range geminiResponse.Candidates {
  907. if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
  908. isStop = true
  909. candidate.FinishReason = nil
  910. }
  911. choice := dto.ChatCompletionsStreamResponseChoice{
  912. Index: int(candidate.Index),
  913. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  914. //Role: "assistant",
  915. },
  916. }
  917. var texts []string
  918. isTools := false
  919. isThought := false
  920. if candidate.FinishReason != nil {
  921. // p := GeminiConvertFinishReason(*candidate.FinishReason)
  922. switch *candidate.FinishReason {
  923. case "STOP":
  924. choice.FinishReason = &constant.FinishReasonStop
  925. case "MAX_TOKENS":
  926. choice.FinishReason = &constant.FinishReasonLength
  927. default:
  928. choice.FinishReason = &constant.FinishReasonContentFilter
  929. }
  930. }
  931. for _, part := range candidate.Content.Parts {
  932. if part.InlineData != nil {
  933. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  934. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  935. texts = append(texts, imgText)
  936. }
  937. } else if part.FunctionCall != nil {
  938. isTools = true
  939. if call := getResponseToolCall(&part); call != nil {
  940. call.SetIndex(len(choice.Delta.ToolCalls))
  941. choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
  942. }
  943. } else if part.Thought {
  944. isThought = true
  945. texts = append(texts, part.Text)
  946. } else {
  947. if part.ExecutableCode != nil {
  948. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
  949. } else if part.CodeExecutionResult != nil {
  950. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
  951. } else {
  952. if part.Text != "\n" {
  953. texts = append(texts, part.Text)
  954. }
  955. }
  956. }
  957. }
  958. if isThought {
  959. choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
  960. } else {
  961. choice.Delta.SetContentString(strings.Join(texts, "\n"))
  962. }
  963. if isTools {
  964. choice.FinishReason = &constant.FinishReasonToolCalls
  965. }
  966. choices = append(choices, choice)
  967. }
  968. var response dto.ChatCompletionsStreamResponse
  969. response.Object = "chat.completion.chunk"
  970. response.Choices = choices
  971. return &response, isStop
  972. }
  973. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  974. streamData, err := common.Marshal(resp)
  975. if err != nil {
  976. return fmt.Errorf("failed to marshal stream response: %w", err)
  977. }
  978. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  979. if err != nil {
  980. return fmt.Errorf("failed to handle stream format: %w", err)
  981. }
  982. return nil
  983. }
  984. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  985. streamData, err := common.Marshal(resp)
  986. if err != nil {
  987. return fmt.Errorf("failed to marshal stream response: %w", err)
  988. }
  989. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  990. return nil
  991. }
// geminiStreamHandler drives the scan loop over a Gemini streaming response:
// each chunk is unmarshalled, folded into usage statistics and a plain-text
// accumulator, then handed to callback. The callback returns false to stop
// the stream. The returned usage is best-effort, with several fallbacks for
// upstream responses that omit token counts.
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
	var usage = &dto.Usage{}
	var imageCount int
	responseText := strings.Builder{}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var geminiResponse dto.GeminiChatResponse
		err := common.UnmarshalJsonStr(data, &geminiResponse)
		if err != nil {
			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
			return false
		}
		// Count generated images and collect text for fallback token estimation.
		for _, candidate := range geminiResponse.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.InlineData != nil && part.InlineData.MimeType != "" {
					imageCount++
				}
				if part.Text != "" {
					responseText.WriteString(part.Text)
				}
			}
		}
		// Update usage from the latest chunk that carries usage metadata
		// (later chunks overwrite earlier ones).
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
				if detail.Modality == "AUDIO" {
					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
				} else if detail.Modality == "TEXT" {
					usage.PromptTokensDetails.TextTokens = detail.TokenCount
				}
			}
		}
		return callback(data, &geminiResponse)
	})
	// Fallback: when upstream reported no completion tokens, bill generated
	// images at a flat 1400 tokens each.
	if imageCount != 0 {
		if usage.CompletionTokens == 0 {
			usage.CompletionTokens = imageCount * 1400
		}
	}
	// NOTE(review): this unconditionally overwrites the AUDIO/TEXT modality
	// split recorded in the loop above, treating the entire prompt as text —
	// confirm this is intentional.
	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	// Prefer deriving completion tokens from the reported total so thought
	// tokens are included.
	if usage.TotalTokens > 0 {
		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
	}
	// Last resort: estimate usage from the accumulated response text.
	if usage.CompletionTokens <= 0 {
		str := responseText.String()
		if len(str) > 0 {
			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
		} else {
			usage = &dto.Usage{}
		}
	}
	return usage, nil
}
// GeminiChatStreamHandler relays a Gemini streaming chat response to the
// client as OpenAI-style chunks: an initial empty (role) chunk on the first
// send, per-chunk deltas, an explicit stop chunk when upstream signals STOP,
// and a final usage chunk.
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	id := helper.GetResponseID(c)
	createAt := common.GetTimestamp()
	finishReason := constant.FinishReasonStop
	usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
		response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
		response.Id = id
		response.Created = createAt
		response.Model = info.UpstreamModelName
		logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
		if info.SendResponseCount == 0 {
			// send first response
			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
			if response.IsToolCall() {
				// Announce the tool calls in the opening chunk with their
				// arguments blanked out (copies, so the originals are kept).
				if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
					toolCalls := response.Choices[0].Delta.ToolCalls
					copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
					for idx := range toolCalls {
						copiedToolCalls[idx] = toolCalls[idx]
						copiedToolCalls[idx].Function.Arguments = ""
					}
					emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
				}
				finishReason = constant.FinishReasonToolCalls
				err := handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
				// NOTE(review): tool calls are cleared from the data chunk
				// after being announced — confirm the arguments are delivered
				// through subsequent chunks.
				response.ClearToolCalls()
				if response.IsFinished() {
					response.Choices[0].FinishReason = nil
				}
			} else {
				err := handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
			}
		}
		err := handleStream(c, info, response)
		if err != nil {
			logger.LogError(c, err.Error())
		}
		if isStop {
			// Upstream signalled STOP: emit an explicit stop chunk.
			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
		}
		return true
	})
	if err != nil {
		return usage, err
	}
	// Always terminate the stream with a usage chunk so clients receive
	// token accounting.
	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
	handleErr := handleFinalStream(c, info, response)
	if handleErr != nil {
		common.SysLog("send final response failed: " + handleErr.Error())
	}
	return usage, nil
}
  1107. func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1108. responseBody, err := io.ReadAll(resp.Body)
  1109. if err != nil {
  1110. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1111. }
  1112. service.CloseResponseBodyGracefully(resp)
  1113. if common.DebugEnabled {
  1114. println(string(responseBody))
  1115. }
  1116. var geminiResponse dto.GeminiChatResponse
  1117. err = common.Unmarshal(responseBody, &geminiResponse)
  1118. if err != nil {
  1119. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1120. }
  1121. if len(geminiResponse.Candidates) == 0 {
  1122. //return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1123. //if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  1124. // return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
  1125. //} else {
  1126. // return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
  1127. //}
  1128. }
  1129. fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
  1130. fullTextResponse.Model = info.UpstreamModelName
  1131. usage := dto.Usage{
  1132. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  1133. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  1134. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  1135. }
  1136. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  1137. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  1138. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  1139. if detail.Modality == "AUDIO" {
  1140. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  1141. } else if detail.Modality == "TEXT" {
  1142. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  1143. }
  1144. }
  1145. fullTextResponse.Usage = usage
  1146. switch info.RelayFormat {
  1147. case types.RelayFormatOpenAI:
  1148. responseBody, err = common.Marshal(fullTextResponse)
  1149. if err != nil {
  1150. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1151. }
  1152. case types.RelayFormatClaude:
  1153. claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
  1154. claudeRespStr, err := common.Marshal(claudeResp)
  1155. if err != nil {
  1156. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1157. }
  1158. responseBody = claudeRespStr
  1159. case types.RelayFormatGemini:
  1160. break
  1161. }
  1162. service.IOCopyBytesGracefully(c, resp, responseBody)
  1163. return &usage, nil
  1164. }
  1165. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1166. defer service.CloseResponseBodyGracefully(resp)
  1167. responseBody, readErr := io.ReadAll(resp.Body)
  1168. if readErr != nil {
  1169. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1170. }
  1171. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1172. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1173. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1174. }
  1175. // convert to openai format response
  1176. openAIResponse := dto.OpenAIEmbeddingResponse{
  1177. Object: "list",
  1178. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1179. Model: info.UpstreamModelName,
  1180. }
  1181. for i, embedding := range geminiResponse.Embeddings {
  1182. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1183. Object: "embedding",
  1184. Embedding: embedding.Values,
  1185. Index: i,
  1186. })
  1187. }
  1188. // calculate usage
  1189. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1190. // Google has not yet clarified how embedding models will be billed
  1191. // refer to openai billing method to use input tokens billing
  1192. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1193. usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
  1194. openAIResponse.Usage = *usage
  1195. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1196. if jsonErr != nil {
  1197. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1198. }
  1199. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1200. return usage, nil
  1201. }
  1202. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1203. responseBody, readErr := io.ReadAll(resp.Body)
  1204. if readErr != nil {
  1205. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1206. }
  1207. _ = resp.Body.Close()
  1208. var geminiResponse dto.GeminiImageResponse
  1209. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1210. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1211. }
  1212. if len(geminiResponse.Predictions) == 0 {
  1213. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1214. }
  1215. // convert to openai format response
  1216. openAIResponse := dto.ImageResponse{
  1217. Created: common.GetTimestamp(),
  1218. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1219. }
  1220. for _, prediction := range geminiResponse.Predictions {
  1221. if prediction.RaiFilteredReason != "" {
  1222. continue // skip filtered image
  1223. }
  1224. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1225. B64Json: prediction.BytesBase64Encoded,
  1226. })
  1227. }
  1228. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1229. if jsonErr != nil {
  1230. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1231. }
  1232. c.Writer.Header().Set("Content-Type", "application/json")
  1233. c.Writer.WriteHeader(resp.StatusCode)
  1234. _, _ = c.Writer.Write(jsonResponse)
  1235. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1236. // each image has fixed 258 tokens
  1237. const imageTokens = 258
  1238. generatedImages := len(openAIResponse.Data)
  1239. usage := &dto.Usage{
  1240. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1241. CompletionTokens: 0, // image generation does not calculate completion tokens
  1242. TotalTokens: imageTokens * generatedImages,
  1243. }
  1244. return usage, nil
  1245. }