relay-gemini.go 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/constant"
  10. "one-api/dto"
  11. "one-api/logger"
  12. "one-api/relay/channel/openai"
  13. relaycommon "one-api/relay/common"
  14. "one-api/relay/helper"
  15. "one-api/service"
  16. "one-api/setting/model_setting"
  17. "one-api/types"
  18. "strconv"
  19. "strings"
  20. "unicode/utf8"
  21. "github.com/gin-gonic/gin"
  22. )
// geminiSupportedMimeTypes is the whitelist of MIME types accepted for inline
// media sent to Gemini. Lookups are performed against the lower-cased type
// (see the URL-image path in CovertGemini2OpenAI).
var geminiSupportedMimeTypes = map[string]bool{
	"application/pdf": true,
	"audio/mpeg":      true,
	"audio/mp3":       true,
	"audio/wav":       true,
	"image/png":       true,
	"image/jpeg":      true,
	"text/plain":      true,
	"video/mov":       true,
	"video/mpeg":      true,
	"video/mp4":       true,
	"video/mpg":       true,
	"video/avi":       true,
	"video/wmv":       true,
	"video/mpegps":    true,
	"video/flv":       true,
}
// Thinking-budget bounds accepted by Gemini, per model family.
const (
	pro25MinBudget       = 128   // gemini-2.5-pro minimum (cannot be disabled)
	pro25MaxBudget       = 32768 // gemini-2.5-pro maximum
	flash25MaxBudget     = 24576 // gemini-2.5-flash (and other models') maximum
	flash25LiteMinBudget = 512   // gemini-2.5-flash-lite minimum (when enabled)
	flash25LiteMaxBudget = 24576 // gemini-2.5-flash-lite maximum
)
  48. func isNew25ProModel(modelName string) bool {
  49. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  50. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  51. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  52. }
  53. func is25FlashLiteModel(modelName string) bool {
  54. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  55. }
  56. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  57. func clampThinkingBudget(modelName string, budget int) int {
  58. isNew25Pro := isNew25ProModel(modelName)
  59. is25FlashLite := is25FlashLiteModel(modelName)
  60. if is25FlashLite {
  61. if budget < flash25LiteMinBudget {
  62. return flash25LiteMinBudget
  63. }
  64. if budget > flash25LiteMaxBudget {
  65. return flash25LiteMaxBudget
  66. }
  67. } else if isNew25Pro {
  68. if budget < pro25MinBudget {
  69. return pro25MinBudget
  70. }
  71. if budget > pro25MaxBudget {
  72. return pro25MaxBudget
  73. }
  74. } else { // 其他模型
  75. if budget < 0 {
  76. return 0
  77. }
  78. if budget > flash25MaxBudget {
  79. return flash25MaxBudget
  80. }
  81. }
  82. return budget
  83. }
  84. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  85. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  86. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  87. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  88. isNew25Pro := isNew25ProModel(modelName)
  89. is25FlashLite := is25FlashLiteModel(modelName)
  90. maxBudget := 0
  91. if is25FlashLite {
  92. maxBudget = flash25LiteMaxBudget
  93. }
  94. if isNew25Pro {
  95. maxBudget = pro25MaxBudget
  96. } else {
  97. maxBudget = flash25MaxBudget
  98. }
  99. switch effort {
  100. case "high":
  101. maxBudget = maxBudget * 80 / 100
  102. case "medium":
  103. maxBudget = maxBudget * 50 / 100
  104. case "low":
  105. maxBudget = maxBudget * 20 / 100
  106. }
  107. return clampThinkingBudget(modelName, maxBudget)
  108. }
  109. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  110. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  111. modelName := info.UpstreamModelName
  112. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  113. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  114. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  115. if strings.Contains(modelName, "-thinking-") {
  116. parts := strings.SplitN(modelName, "-thinking-", 2)
  117. if len(parts) == 2 && parts[1] != "" {
  118. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  119. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  120. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  121. ThinkingBudget: common.GetPointer(clampedBudget),
  122. IncludeThoughts: true,
  123. }
  124. }
  125. }
  126. } else if strings.HasSuffix(modelName, "-thinking") {
  127. unsupportedModels := []string{
  128. "gemini-2.5-pro-preview-05-06",
  129. "gemini-2.5-pro-preview-03-25",
  130. }
  131. isUnsupported := false
  132. for _, unsupportedModel := range unsupportedModels {
  133. if strings.HasPrefix(modelName, unsupportedModel) {
  134. isUnsupported = true
  135. break
  136. }
  137. }
  138. if isUnsupported {
  139. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  140. IncludeThoughts: true,
  141. }
  142. } else {
  143. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  144. IncludeThoughts: true,
  145. }
  146. if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  147. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
  148. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  149. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  150. } else {
  151. if len(oaiRequest) > 0 {
  152. // 如果有reasoningEffort参数,则根据其值设置思考预算
  153. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  154. }
  155. }
  156. }
  157. } else if strings.HasSuffix(modelName, "-nothinking") {
  158. if !isNew25Pro {
  159. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  160. ThinkingBudget: common.GetPointer(0),
  161. }
  162. }
  163. }
  164. }
  165. }
// Setting safety to the lowest possible values since Gemini is already powerless enough
//
// CovertGemini2OpenAI converts an OpenAI-style chat completion request into a
// Gemini chat request: generation parameters, thinking config (from extra_body
// or the suffix-based ThinkingAdaptor), safety settings, tools, JSON response
// format, and every message — including tool results and media parts — are
// mapped onto Gemini contents. Returns an error for invalid extra_body,
// unparseable tool-call arguments, too many images, or unsupported media.
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
	geminiRequest := dto.GeminiChatRequest{
		Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
		GenerationConfig: dto.GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.GetMaxTokens(),
			Seed:            int64(textRequest.Seed),
		},
	}
	// Image-capable models must request both TEXT and IMAGE output modalities.
	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
		geminiRequest.GenerationConfig.ResponseModalities = []string{
			"TEXT",
			"IMAGE",
		}
	}
	// A thinking config supplied explicitly via extra_body takes precedence
	// over the suffix-based ThinkingAdaptor below.
	adaptorWithExtraBody := false
	if len(textRequest.ExtraBody) > 0 {
		if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
			var extraBody map[string]interface{}
			if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
				return nil, fmt.Errorf("invalid extra body: %w", err)
			}
			// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
			if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
				adaptorWithExtraBody = true
				if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
					if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
						budgetInt := int(budget)
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							ThinkingBudget:  common.GetPointer(budgetInt),
							IncludeThoughts: true,
						}
					} else {
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							IncludeThoughts: true,
						}
					}
				}
			}
		}
	}
	if !adaptorWithExtraBody {
		ThinkingAdaptor(&geminiRequest, info, textRequest)
	}
	safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
	for _, category := range SafetySettingList {
		safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
			Category:  category,
			Threshold: model_setting.GetGeminiSafetySetting(category),
		})
	}
	geminiRequest.SafetySettings = safetySettings
	// openaiContent.FuncToToolCalls()
	if textRequest.Tools != nil {
		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
		googleSearch := false
		codeExecution := false
		for _, tool := range textRequest.Tools {
			// "googleSearch" / "codeExecution" are pseudo-functions that map to
			// Gemini built-in tools rather than function declarations.
			if tool.Function.Name == "googleSearch" {
				googleSearch = true
				continue
			}
			if tool.Function.Name == "codeExecution" {
				codeExecution = true
				continue
			}
			// Drop parameter objects with an empty "properties" map entirely.
			if tool.Function.Parameters != nil {
				params, ok := tool.Function.Parameters.(map[string]interface{})
				if ok {
					if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
						if len(props) == 0 {
							tool.Function.Parameters = nil
						}
					}
				}
			}
			// Clean the parameters before appending
			cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
			tool.Function.Parameters = cleanedParams
			functions = append(functions, tool.Function)
		}
		geminiTools := geminiRequest.GetTools()
		if codeExecution {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				CodeExecution: make(map[string]string),
			})
		}
		if googleSearch {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				GoogleSearch: make(map[string]string),
			})
		}
		if len(functions) > 0 {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				FunctionDeclarations: functions,
			})
		}
		geminiRequest.SetTools(geminiTools)
	}
	// JSON response formats map onto Gemini's responseMimeType/responseSchema.
	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
		if len(textRequest.ResponseFormat.JsonSchema) > 0 {
			// Decode the json.RawMessage first.
			var jsonSchema dto.FormatJsonSchema
			if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
				cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
				geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
			}
		}
	}
	// tool_call_ids maps tool_call_id -> function name so that later tool
	// result messages can be attributed to the right function.
	tool_call_ids := make(map[string]string)
	var system_content []string
	//shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		if message.Role == "system" {
			// System messages are collected into systemInstruction, not contents.
			system_content = append(system_content, message.StringContent())
			continue
		} else if message.Role == "tool" || message.Role == "function" {
			// Tool results must live in a "user" turn; append a new one only if
			// the previous turn was the model's (or there is none yet).
			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
				geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
					Role: "user",
				})
			}
			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
			name := ""
			if message.Name != nil {
				name = *message.Name
			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
				name = val
			}
			var contentMap map[string]interface{}
			contentStr := message.StringContent()
			// 1. Try to parse the content as a JSON object.
			if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
				// 2. On failure, try to parse it as a JSON array.
				var contentSlice []interface{}
				if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
					// Arrays are wrapped into an object.
					contentMap = map[string]interface{}{"result": contentSlice}
				} else {
					// 3. Otherwise treat it as plain text.
					contentMap = map[string]interface{}{"content": contentStr}
				}
			}
			functionResp := &dto.GeminiFunctionResponse{
				Name:     name,
				Response: contentMap,
			}
			*parts = append(*parts, dto.GeminiPart{
				FunctionResponse: functionResp,
			})
			continue
		}
		var parts []dto.GeminiPart
		content := dto.GeminiChatContent{
			Role: message.Role,
		}
		// isToolCall := false
		if message.ToolCalls != nil {
			// message.Role = "model"
			// isToolCall = true
			for _, call := range message.ParseToolCalls() {
				args := map[string]interface{}{}
				if call.Function.Arguments != "" {
					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
					}
				}
				toolCall := dto.GeminiPart{
					FunctionCall: &dto.FunctionCall{
						FunctionName: call.Function.Name,
						Arguments:    args,
					},
				}
				parts = append(parts, toolCall)
				tool_call_ids[call.ID] = call.Function.Name
			}
		}
		openaiContent := message.ParseContent()
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == dto.ContentTypeText {
				if part.Text == "" {
					continue
				}
				parts = append(parts, dto.GeminiPart{
					Text: part.Text,
				})
			} else if part.Type == dto.ContentTypeImageURL {
				imageNum += 1
				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
				}
				// Check whether the image is referenced by URL.
				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
					// It is a URL: fetch the file's MIME type and base64-encoded data.
					fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
					if err != nil {
						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
					}
					// Verify the MIME type is on Gemini's supported whitelist.
					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
						url := part.GetImageMedia().Url
						return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: fileData.MimeType, // keep the original MimeType: casing may be meaningful to the API
							Data:     fileData.Base64Data,
						},
					})
				} else {
					format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
					if err != nil {
						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					})
				}
			} else if part.Type == dto.ContentTypeFile {
				if part.GetFile().FileId != "" {
					return nil, fmt.Errorf("only base64 file is supported in gemini")
				}
				format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
				if err != nil {
					return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: format,
						Data:     base64String,
					},
				})
			} else if part.Type == dto.ContentTypeInputAudio {
				if part.GetInputAudio().Data == "" {
					return nil, fmt.Errorf("only base64 audio is supported in gemini")
				}
				base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
				if err != nil {
					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: "audio/" + part.GetInputAudio().Format,
						Data:     base64String,
					},
				})
			}
		}
		content.Parts = parts
		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		if len(content.Parts) > 0 {
			geminiRequest.Contents = append(geminiRequest.Contents, content)
		}
	}
	if len(system_content) > 0 {
		geminiRequest.SystemInstructions = &dto.GeminiChatContent{
			Parts: []dto.GeminiPart{
				{
					Text: strings.Join(system_content, "\n"),
				},
			},
		}
	}
	return &geminiRequest, nil
}
  441. // Helper function to get a list of supported MIME types for error messages
  442. func getSupportedMimeTypesList() []string {
  443. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  444. for k := range geminiSupportedMimeTypes {
  445. keys = append(keys, k)
  446. }
  447. return keys
  448. }
  449. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  450. func cleanFunctionParameters(params interface{}) interface{} {
  451. if params == nil {
  452. return nil
  453. }
  454. switch v := params.(type) {
  455. case map[string]interface{}:
  456. // Create a copy to avoid modifying the original
  457. cleanedMap := make(map[string]interface{})
  458. for k, val := range v {
  459. cleanedMap[k] = val
  460. }
  461. // Remove unsupported root-level fields
  462. delete(cleanedMap, "default")
  463. delete(cleanedMap, "exclusiveMaximum")
  464. delete(cleanedMap, "exclusiveMinimum")
  465. delete(cleanedMap, "$schema")
  466. delete(cleanedMap, "additionalProperties")
  467. // Check and clean 'format' for string types
  468. if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
  469. if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
  470. if formatValue != "enum" && formatValue != "date-time" {
  471. delete(cleanedMap, "format")
  472. }
  473. }
  474. }
  475. // Clean properties
  476. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  477. cleanedProps := make(map[string]interface{})
  478. for propName, propValue := range props {
  479. cleanedProps[propName] = cleanFunctionParameters(propValue)
  480. }
  481. cleanedMap["properties"] = cleanedProps
  482. }
  483. // Recursively clean items in arrays
  484. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  485. cleanedMap["items"] = cleanFunctionParameters(items)
  486. }
  487. // Also handle items if it's an array of schemas
  488. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
  489. cleanedItemsArray := make([]interface{}, len(itemsArray))
  490. for i, item := range itemsArray {
  491. cleanedItemsArray[i] = cleanFunctionParameters(item)
  492. }
  493. cleanedMap["items"] = cleanedItemsArray
  494. }
  495. // Recursively clean other schema composition keywords
  496. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  497. if nested, ok := cleanedMap[field].([]interface{}); ok {
  498. cleanedNested := make([]interface{}, len(nested))
  499. for i, item := range nested {
  500. cleanedNested[i] = cleanFunctionParameters(item)
  501. }
  502. cleanedMap[field] = cleanedNested
  503. }
  504. }
  505. // Recursively clean patternProperties
  506. if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
  507. cleanedPatternProps := make(map[string]interface{})
  508. for pattern, schema := range patternProps {
  509. cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
  510. }
  511. cleanedMap["patternProperties"] = cleanedPatternProps
  512. }
  513. // Recursively clean definitions
  514. if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
  515. cleanedDefinitions := make(map[string]interface{})
  516. for defName, defSchema := range definitions {
  517. cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
  518. }
  519. cleanedMap["definitions"] = cleanedDefinitions
  520. }
  521. // Recursively clean $defs (newer JSON Schema draft)
  522. if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
  523. cleanedDefs := make(map[string]interface{})
  524. for defName, defSchema := range defs {
  525. cleanedDefs[defName] = cleanFunctionParameters(defSchema)
  526. }
  527. cleanedMap["$defs"] = cleanedDefs
  528. }
  529. // Clean conditional keywords
  530. for _, field := range []string{"if", "then", "else", "not"} {
  531. if nested, ok := cleanedMap[field]; ok {
  532. cleanedMap[field] = cleanFunctionParameters(nested)
  533. }
  534. }
  535. return cleanedMap
  536. case []interface{}:
  537. // Handle arrays of schemas
  538. cleanedArray := make([]interface{}, len(v))
  539. for i, item := range v {
  540. cleanedArray[i] = cleanFunctionParameters(item)
  541. }
  542. return cleanedArray
  543. default:
  544. // Not a map or array, return as is (e.g., could be a primitive)
  545. return params
  546. }
  547. }
  548. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  549. if depth >= 5 {
  550. return schema
  551. }
  552. v, ok := schema.(map[string]interface{})
  553. if !ok || len(v) == 0 {
  554. return schema
  555. }
  556. // 删除所有的title字段
  557. delete(v, "title")
  558. delete(v, "$schema")
  559. // 如果type不为object和array,则直接返回
  560. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  561. return schema
  562. }
  563. switch v["type"] {
  564. case "object":
  565. delete(v, "additionalProperties")
  566. // 处理 properties
  567. if properties, ok := v["properties"].(map[string]interface{}); ok {
  568. for key, value := range properties {
  569. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  570. }
  571. }
  572. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  573. if nested, ok := v[field].([]interface{}); ok {
  574. for i, item := range nested {
  575. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  576. }
  577. }
  578. }
  579. case "array":
  580. if items, ok := v["items"].(map[string]interface{}); ok {
  581. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  582. }
  583. }
  584. return v
  585. }
  586. func unescapeString(s string) (string, error) {
  587. var result []rune
  588. escaped := false
  589. i := 0
  590. for i < len(s) {
  591. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  592. if r == utf8.RuneError {
  593. return "", fmt.Errorf("invalid UTF-8 encoding")
  594. }
  595. if escaped {
  596. // 如果是转义符后的字符,检查其类型
  597. switch r {
  598. case '"':
  599. result = append(result, '"')
  600. case '\\':
  601. result = append(result, '\\')
  602. case '/':
  603. result = append(result, '/')
  604. case 'b':
  605. result = append(result, '\b')
  606. case 'f':
  607. result = append(result, '\f')
  608. case 'n':
  609. result = append(result, '\n')
  610. case 'r':
  611. result = append(result, '\r')
  612. case 't':
  613. result = append(result, '\t')
  614. case '\'':
  615. result = append(result, '\'')
  616. default:
  617. // 如果遇到一个非法的转义字符,直接按原样输出
  618. result = append(result, '\\', r)
  619. }
  620. escaped = false
  621. } else {
  622. if r == '\\' {
  623. escaped = true // 记录反斜杠作为转义符
  624. } else {
  625. result = append(result, r)
  626. }
  627. }
  628. i += size // 移动到下一个字符
  629. }
  630. return string(result), nil
  631. }
  632. func unescapeMapOrSlice(data interface{}) interface{} {
  633. switch v := data.(type) {
  634. case map[string]interface{}:
  635. for k, val := range v {
  636. v[k] = unescapeMapOrSlice(val)
  637. }
  638. case []interface{}:
  639. for i, val := range v {
  640. v[i] = unescapeMapOrSlice(val)
  641. }
  642. case string:
  643. if unescaped, err := unescapeString(v); err != nil {
  644. return v
  645. } else {
  646. return unescaped
  647. }
  648. }
  649. return data
  650. }
  651. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  652. var argsBytes []byte
  653. var err error
  654. if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
  655. argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
  656. } else {
  657. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  658. }
  659. if err != nil {
  660. return nil
  661. }
  662. return &dto.ToolCallResponse{
  663. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  664. Type: "function",
  665. Function: dto.FunctionResponse{
  666. Arguments: string(argsBytes),
  667. Name: item.FunctionCall.FunctionName,
  668. },
  669. }
  670. }
  671. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  672. fullTextResponse := dto.OpenAITextResponse{
  673. Id: helper.GetResponseID(c),
  674. Object: "chat.completion",
  675. Created: common.GetTimestamp(),
  676. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  677. }
  678. isToolCall := false
  679. for _, candidate := range response.Candidates {
  680. choice := dto.OpenAITextResponseChoice{
  681. Index: int(candidate.Index),
  682. Message: dto.Message{
  683. Role: "assistant",
  684. Content: "",
  685. },
  686. FinishReason: constant.FinishReasonStop,
  687. }
  688. if len(candidate.Content.Parts) > 0 {
  689. var texts []string
  690. var toolCalls []dto.ToolCallResponse
  691. for _, part := range candidate.Content.Parts {
  692. if part.InlineData != nil {
  693. // 媒体内容
  694. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  695. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  696. texts = append(texts, imgText)
  697. } else {
  698. // 其他媒体类型,直接显示链接
  699. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  700. }
  701. } else if part.FunctionCall != nil {
  702. choice.FinishReason = constant.FinishReasonToolCalls
  703. if call := getResponseToolCall(&part); call != nil {
  704. toolCalls = append(toolCalls, *call)
  705. }
  706. } else if part.Thought {
  707. choice.Message.ReasoningContent = part.Text
  708. } else {
  709. if part.ExecutableCode != nil {
  710. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  711. } else if part.CodeExecutionResult != nil {
  712. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  713. } else {
  714. // 过滤掉空行
  715. if part.Text != "\n" {
  716. texts = append(texts, part.Text)
  717. }
  718. }
  719. }
  720. }
  721. if len(toolCalls) > 0 {
  722. choice.Message.SetToolCalls(toolCalls)
  723. isToolCall = true
  724. }
  725. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  726. }
  727. if candidate.FinishReason != nil {
  728. switch *candidate.FinishReason {
  729. case "STOP":
  730. choice.FinishReason = constant.FinishReasonStop
  731. case "MAX_TOKENS":
  732. choice.FinishReason = constant.FinishReasonLength
  733. default:
  734. choice.FinishReason = constant.FinishReasonContentFilter
  735. }
  736. }
  737. if isToolCall {
  738. choice.FinishReason = constant.FinishReasonToolCalls
  739. }
  740. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  741. }
  742. return &fullTextResponse
  743. }
// streamResponseGeminiChat2OpenAI converts one streamed Gemini chunk into an
// OpenAI chat.completion.chunk. The second return value reports whether any
// candidate finished with "STOP" so the caller can emit the final stop chunk.
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
	isStop := false
	for _, candidate := range geminiResponse.Candidates {
		// "STOP" is surfaced via isStop and cleared, so this chunk's delta
		// itself carries no finish_reason for a normal stop.
		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
			isStop = true
			candidate.FinishReason = nil
		}
		choice := dto.ChatCompletionsStreamResponseChoice{
			Index: int(candidate.Index),
			Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
				//Role: "assistant",
			},
		}
		var texts []string
		isTools := false
		isThought := false
		if candidate.FinishReason != nil {
			// p := GeminiConvertFinishReason(*candidate.FinishReason)
			// NOTE(review): the "STOP" case below is unreachable — STOP was
			// already cleared to nil above.
			switch *candidate.FinishReason {
			case "STOP":
				choice.FinishReason = &constant.FinishReasonStop
			case "MAX_TOKENS":
				choice.FinishReason = &constant.FinishReasonLength
			default:
				choice.FinishReason = &constant.FinishReasonContentFilter
			}
		}
		for _, part := range candidate.Content.Parts {
			if part.InlineData != nil {
				// Only image media is surfaced in the stream (as markdown).
				if strings.HasPrefix(part.InlineData.MimeType, "image") {
					imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
					texts = append(texts, imgText)
				}
			} else if part.FunctionCall != nil {
				isTools = true
				if call := getResponseToolCall(&part); call != nil {
					call.SetIndex(len(choice.Delta.ToolCalls))
					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
				}
			} else if part.Thought {
				// NOTE(review): if a chunk mixed thought and normal text parts,
				// the whole chunk would be emitted as reasoning content —
				// presumably chunks are homogeneous; confirm upstream.
				isThought = true
				texts = append(texts, part.Text)
			} else {
				if part.ExecutableCode != nil {
					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
				} else if part.CodeExecutionResult != nil {
					texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
				} else {
					// Drop bare-newline filler parts.
					if part.Text != "\n" {
						texts = append(texts, part.Text)
					}
				}
			}
		}
		if isThought {
			choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
		} else {
			choice.Delta.SetContentString(strings.Join(texts, "\n"))
		}
		if isTools {
			choice.FinishReason = &constant.FinishReasonToolCalls
		}
		choices = append(choices, choice)
	}
	var response dto.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Choices = choices
	return &response, isStop
}
  814. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  815. streamData, err := common.Marshal(resp)
  816. if err != nil {
  817. return fmt.Errorf("failed to marshal stream response: %w", err)
  818. }
  819. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  820. if err != nil {
  821. return fmt.Errorf("failed to handle stream format: %w", err)
  822. }
  823. return nil
  824. }
  825. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  826. streamData, err := common.Marshal(resp)
  827. if err != nil {
  828. return fmt.Errorf("failed to marshal stream response: %w", err)
  829. }
  830. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  831. return nil
  832. }
  833. func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  834. // responseText := ""
  835. id := helper.GetResponseID(c)
  836. createAt := common.GetTimestamp()
  837. responseText := strings.Builder{}
  838. var usage = &dto.Usage{}
  839. var imageCount int
  840. finishReason := constant.FinishReasonStop
  841. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  842. var geminiResponse dto.GeminiChatResponse
  843. err := common.UnmarshalJsonStr(data, &geminiResponse)
  844. if err != nil {
  845. logger.LogError(c, "error unmarshalling stream response: "+err.Error())
  846. return false
  847. }
  848. for _, candidate := range geminiResponse.Candidates {
  849. for _, part := range candidate.Content.Parts {
  850. if part.InlineData != nil && part.InlineData.MimeType != "" {
  851. imageCount++
  852. }
  853. if part.Text != "" {
  854. responseText.WriteString(part.Text)
  855. }
  856. }
  857. }
  858. response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
  859. response.Id = id
  860. response.Created = createAt
  861. response.Model = info.UpstreamModelName
  862. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  863. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  864. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
  865. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  866. usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
  867. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  868. if detail.Modality == "AUDIO" {
  869. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  870. } else if detail.Modality == "TEXT" {
  871. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  872. }
  873. }
  874. }
  875. logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
  876. if info.SendResponseCount == 0 {
  877. // send first response
  878. emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
  879. if response.IsToolCall() {
  880. emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
  881. emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
  882. emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
  883. finishReason = constant.FinishReasonToolCalls
  884. err = handleStream(c, info, emptyResponse)
  885. if err != nil {
  886. logger.LogError(c, err.Error())
  887. }
  888. response.ClearToolCalls()
  889. if response.IsFinished() {
  890. response.Choices[0].FinishReason = nil
  891. }
  892. } else {
  893. err = handleStream(c, info, emptyResponse)
  894. if err != nil {
  895. logger.LogError(c, err.Error())
  896. }
  897. }
  898. }
  899. err = handleStream(c, info, response)
  900. if err != nil {
  901. logger.LogError(c, err.Error())
  902. }
  903. if isStop {
  904. _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
  905. }
  906. return true
  907. })
  908. if info.SendResponseCount == 0 {
  909. // 空补全,报错不计费
  910. // empty response, throw an error
  911. return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
  912. }
  913. if imageCount != 0 {
  914. if usage.CompletionTokens == 0 {
  915. usage.CompletionTokens = imageCount * 258
  916. }
  917. }
  918. usage.PromptTokensDetails.TextTokens = usage.PromptTokens
  919. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  920. if usage.CompletionTokens == 0 {
  921. str := responseText.String()
  922. if len(str) > 0 {
  923. usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
  924. } else {
  925. // 空补全,不需要使用量
  926. usage = &dto.Usage{}
  927. }
  928. }
  929. response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
  930. err := handleFinalStream(c, info, response)
  931. if err != nil {
  932. common.SysLog("send final response failed: " + err.Error())
  933. }
  934. //if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  935. // helper.Done(c)
  936. //}
  937. //resp.Body.Close()
  938. return usage, nil
  939. }
  940. func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  941. responseBody, err := io.ReadAll(resp.Body)
  942. if err != nil {
  943. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  944. }
  945. service.CloseResponseBodyGracefully(resp)
  946. if common.DebugEnabled {
  947. println(string(responseBody))
  948. }
  949. var geminiResponse dto.GeminiChatResponse
  950. err = common.Unmarshal(responseBody, &geminiResponse)
  951. if err != nil {
  952. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  953. }
  954. if len(geminiResponse.Candidates) == 0 {
  955. return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  956. }
  957. fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
  958. fullTextResponse.Model = info.UpstreamModelName
  959. usage := dto.Usage{
  960. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  961. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  962. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  963. }
  964. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  965. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  966. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  967. if detail.Modality == "AUDIO" {
  968. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  969. } else if detail.Modality == "TEXT" {
  970. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  971. }
  972. }
  973. fullTextResponse.Usage = usage
  974. switch info.RelayFormat {
  975. case types.RelayFormatOpenAI:
  976. responseBody, err = common.Marshal(fullTextResponse)
  977. if err != nil {
  978. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  979. }
  980. case types.RelayFormatClaude:
  981. claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
  982. claudeRespStr, err := common.Marshal(claudeResp)
  983. if err != nil {
  984. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  985. }
  986. responseBody = claudeRespStr
  987. case types.RelayFormatGemini:
  988. break
  989. }
  990. service.IOCopyBytesGracefully(c, resp, responseBody)
  991. return &usage, nil
  992. }
  993. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  994. defer service.CloseResponseBodyGracefully(resp)
  995. responseBody, readErr := io.ReadAll(resp.Body)
  996. if readErr != nil {
  997. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  998. }
  999. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1000. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1001. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1002. }
  1003. // convert to openai format response
  1004. openAIResponse := dto.OpenAIEmbeddingResponse{
  1005. Object: "list",
  1006. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1007. Model: info.UpstreamModelName,
  1008. }
  1009. for i, embedding := range geminiResponse.Embeddings {
  1010. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1011. Object: "embedding",
  1012. Embedding: embedding.Values,
  1013. Index: i,
  1014. })
  1015. }
  1016. // calculate usage
  1017. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1018. // Google has not yet clarified how embedding models will be billed
  1019. // refer to openai billing method to use input tokens billing
  1020. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1021. usage := &dto.Usage{
  1022. PromptTokens: info.PromptTokens,
  1023. CompletionTokens: 0,
  1024. TotalTokens: info.PromptTokens,
  1025. }
  1026. openAIResponse.Usage = *usage
  1027. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1028. if jsonErr != nil {
  1029. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1030. }
  1031. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1032. return usage, nil
  1033. }
  1034. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1035. responseBody, readErr := io.ReadAll(resp.Body)
  1036. if readErr != nil {
  1037. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1038. }
  1039. _ = resp.Body.Close()
  1040. var geminiResponse dto.GeminiImageResponse
  1041. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1042. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1043. }
  1044. if len(geminiResponse.Predictions) == 0 {
  1045. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1046. }
  1047. // convert to openai format response
  1048. openAIResponse := dto.ImageResponse{
  1049. Created: common.GetTimestamp(),
  1050. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1051. }
  1052. for _, prediction := range geminiResponse.Predictions {
  1053. if prediction.RaiFilteredReason != "" {
  1054. continue // skip filtered image
  1055. }
  1056. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1057. B64Json: prediction.BytesBase64Encoded,
  1058. })
  1059. }
  1060. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1061. if jsonErr != nil {
  1062. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1063. }
  1064. c.Writer.Header().Set("Content-Type", "application/json")
  1065. c.Writer.WriteHeader(resp.StatusCode)
  1066. _, _ = c.Writer.Write(jsonResponse)
  1067. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1068. // each image has fixed 258 tokens
  1069. const imageTokens = 258
  1070. generatedImages := len(openAIResponse.Data)
  1071. usage := &dto.Usage{
  1072. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1073. CompletionTokens: 0, // image generation does not calculate completion tokens
  1074. TotalTokens: imageTokens * generatedImages,
  1075. }
  1076. return usage, nil
  1077. }