// relay-gemini.go
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "unicode/utf8"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/logger"
  15. "github.com/QuantumNous/new-api/relay/channel/openai"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. "github.com/QuantumNous/new-api/relay/helper"
  18. "github.com/QuantumNous/new-api/service"
  19. "github.com/QuantumNous/new-api/setting/model_setting"
  20. "github.com/QuantumNous/new-api/types"
  21. "github.com/gin-gonic/gin"
  22. )
// geminiSupportedMimeTypes is the whitelist of inline-data MIME types the
// Gemini API accepts; URLs fetched for image parts are validated against it
// (lower-cased) before being forwarded.
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
var geminiSupportedMimeTypes = map[string]bool{
	"application/pdf": true,
	"audio/mpeg":      true,
	"audio/mp3":       true,
	"audio/wav":       true,
	"image/png":       true,
	"image/jpeg":      true,
	"image/webp":      true,
	"text/plain":      true,
	"video/mov":       true,
	"video/mpeg":      true,
	"video/mp4":       true,
	"video/mpg":       true,
	"video/avi":       true,
	"video/wmv":       true,
	"video/mpegps":    true,
	"video/flv":       true,
}
// Allowed thinking-budget token ranges per Gemini model family
// (original note: "Gemini 允许的思考预算范围").
const (
	pro25MinBudget       = 128   // gemini-2.5-pro lower bound
	pro25MaxBudget       = 32768 // gemini-2.5-pro upper bound
	flash25MaxBudget     = 24576 // gemini-2.5-flash (and default) upper bound
	flash25LiteMinBudget = 512   // gemini-2.5-flash-lite lower bound
	flash25LiteMaxBudget = 24576 // gemini-2.5-flash-lite upper bound
)
  50. func isNew25ProModel(modelName string) bool {
  51. return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  52. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  53. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  54. }
  55. func is25FlashLiteModel(modelName string) bool {
  56. return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
  57. }
  58. // clampThinkingBudget 根据模型名称将预算限制在允许的范围内
  59. func clampThinkingBudget(modelName string, budget int) int {
  60. isNew25Pro := isNew25ProModel(modelName)
  61. is25FlashLite := is25FlashLiteModel(modelName)
  62. if is25FlashLite {
  63. if budget < flash25LiteMinBudget {
  64. return flash25LiteMinBudget
  65. }
  66. if budget > flash25LiteMaxBudget {
  67. return flash25LiteMaxBudget
  68. }
  69. } else if isNew25Pro {
  70. if budget < pro25MinBudget {
  71. return pro25MinBudget
  72. }
  73. if budget > pro25MaxBudget {
  74. return pro25MaxBudget
  75. }
  76. } else { // 其他模型
  77. if budget < 0 {
  78. return 0
  79. }
  80. if budget > flash25MaxBudget {
  81. return flash25MaxBudget
  82. }
  83. }
  84. return budget
  85. }
  86. // "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
  87. // "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
  88. // "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
  89. func clampThinkingBudgetByEffort(modelName string, effort string) int {
  90. isNew25Pro := isNew25ProModel(modelName)
  91. is25FlashLite := is25FlashLiteModel(modelName)
  92. maxBudget := 0
  93. if is25FlashLite {
  94. maxBudget = flash25LiteMaxBudget
  95. }
  96. if isNew25Pro {
  97. maxBudget = pro25MaxBudget
  98. } else {
  99. maxBudget = flash25MaxBudget
  100. }
  101. switch effort {
  102. case "high":
  103. maxBudget = maxBudget * 80 / 100
  104. case "medium":
  105. maxBudget = maxBudget * 50 / 100
  106. case "low":
  107. maxBudget = maxBudget * 20 / 100
  108. }
  109. return clampThinkingBudget(modelName, maxBudget)
  110. }
  111. func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
  112. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  113. modelName := info.UpstreamModelName
  114. isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
  115. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
  116. !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
  117. if strings.Contains(modelName, "-thinking-") {
  118. parts := strings.SplitN(modelName, "-thinking-", 2)
  119. if len(parts) == 2 && parts[1] != "" {
  120. if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
  121. clampedBudget := clampThinkingBudget(modelName, budgetTokens)
  122. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  123. ThinkingBudget: common.GetPointer(clampedBudget),
  124. IncludeThoughts: true,
  125. }
  126. }
  127. }
  128. } else if strings.HasSuffix(modelName, "-thinking") {
  129. unsupportedModels := []string{
  130. "gemini-2.5-pro-preview-05-06",
  131. "gemini-2.5-pro-preview-03-25",
  132. }
  133. isUnsupported := false
  134. for _, unsupportedModel := range unsupportedModels {
  135. if strings.HasPrefix(modelName, unsupportedModel) {
  136. isUnsupported = true
  137. break
  138. }
  139. }
  140. if isUnsupported {
  141. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  142. IncludeThoughts: true,
  143. }
  144. } else {
  145. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  146. IncludeThoughts: true,
  147. }
  148. if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  149. budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
  150. clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
  151. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
  152. } else {
  153. if len(oaiRequest) > 0 {
  154. // 如果有reasoningEffort参数,则根据其值设置思考预算
  155. geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
  156. }
  157. }
  158. }
  159. } else if strings.HasSuffix(modelName, "-nothinking") {
  160. if !isNew25Pro {
  161. geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
  162. ThinkingBudget: common.GetPointer(0),
  163. }
  164. }
  165. }
  166. }
  167. }
// Setting safety to the lowest possible values since Gemini is already powerless enough
//
// CovertGemini2OpenAI converts an OpenAI chat-completion request into a Gemini
// chat request: generation parameters, thinking configuration (from
// extra_body.google.thinking_config or the model-name thinking adapter),
// safety settings, tools (including Gemini built-ins), JSON response formats,
// and every message (text, images, files, audio, tool calls, tool results).
// NOTE(review): "Covert" looks like a typo for "Convert", but the exported
// name is part of the package API and is left unchanged here.
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
	geminiRequest := dto.GeminiChatRequest{
		Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
		GenerationConfig: dto.GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.GetMaxTokens(),
			Seed:            int64(textRequest.Seed),
		},
	}
	// Image-capable models must declare both output modalities explicitly.
	if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
		geminiRequest.GenerationConfig.ResponseModalities = []string{
			"TEXT",
			"IMAGE",
		}
	}
	// extra_body.google takes precedence over the model-name thinking adapter,
	// except for "-nothinking" models where the suffix always wins.
	adaptorWithExtraBody := false
	if len(textRequest.ExtraBody) > 0 {
		if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
			var extraBody map[string]interface{}
			if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
				return nil, fmt.Errorf("invalid extra body: %w", err)
			}
			// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
			if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
				adaptorWithExtraBody = true
				// Reject the wrong camelCase key with a helpful error:
				// thinkingConfig should be thinking_config.
				if _, hasErrorParam := googleBody["thinkingConfig"]; hasErrorParam {
					return nil, errors.New("extra_body.google.thinkingConfig is not supported, use extra_body.google.thinking_config instead")
				}
				if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
					// Likewise thinkingBudget should be thinking_budget.
					if _, hasErrorParam := thinkingConfig["thinkingBudget"]; hasErrorParam {
						return nil, errors.New("extra_body.google.thinking_config.thinkingBudget is not supported, use extra_body.google.thinking_config.thinking_budget instead")
					}
					// JSON numbers decode as float64.
					if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
						budgetInt := int(budget)
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							ThinkingBudget:  common.GetPointer(budgetInt),
							IncludeThoughts: true,
						}
					} else {
						geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
							IncludeThoughts: true,
						}
					}
				}
			}
		}
	}
	if !adaptorWithExtraBody {
		ThinkingAdaptor(&geminiRequest, info, textRequest)
	}
	// Apply the configured safety threshold for each category.
	safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
	for _, category := range SafetySettingList {
		safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
			Category:  category,
			Threshold: model_setting.GetGeminiSafetySetting(category),
		})
	}
	geminiRequest.SafetySettings = safetySettings
	// openaiContent.FuncToToolCalls()
	if textRequest.Tools != nil {
		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
		googleSearch := false
		codeExecution := false
		urlContext := false
		for _, tool := range textRequest.Tools {
			// These pseudo function names select Gemini's built-in tools
			// instead of being forwarded as function declarations.
			if tool.Function.Name == "googleSearch" {
				googleSearch = true
				continue
			}
			if tool.Function.Name == "codeExecution" {
				codeExecution = true
				continue
			}
			if tool.Function.Name == "urlContext" {
				urlContext = true
				continue
			}
			// Gemini rejects an empty "properties" object; drop parameters
			// entirely in that case.
			if tool.Function.Parameters != nil {
				params, ok := tool.Function.Parameters.(map[string]interface{})
				if ok {
					if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
						if len(props) == 0 {
							tool.Function.Parameters = nil
						}
					}
				}
			}
			// Clean the parameters before appending
			cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
			tool.Function.Parameters = cleanedParams
			functions = append(functions, tool.Function)
		}
		geminiTools := geminiRequest.GetTools()
		if codeExecution {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				CodeExecution: make(map[string]string),
			})
		}
		if googleSearch {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				GoogleSearch: make(map[string]string),
			})
		}
		if urlContext {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				URLContext: make(map[string]string),
			})
		}
		if len(functions) > 0 {
			geminiTools = append(geminiTools, dto.GeminiChatTool{
				FunctionDeclarations: functions,
			})
		}
		geminiRequest.SetTools(geminiTools)
	}
	// JSON response formats map onto responseMimeType / responseSchema.
	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
		if len(textRequest.ResponseFormat.JsonSchema) > 0 {
			// Parse the raw JSON schema first, then strip fields Gemini rejects.
			var jsonSchema dto.FormatJsonSchema
			if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
				cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
				geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
			}
		}
	}
	// tool_call_ids maps tool-call id -> function name so that later
	// "tool"/"function" role messages can be matched to the call they answer.
	tool_call_ids := make(map[string]string)
	var system_content []string
	//shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		if message.Role == "system" {
			// System messages are merged into systemInstruction at the end.
			system_content = append(system_content, message.StringContent())
			continue
		} else if message.Role == "tool" || message.Role == "function" {
			// Tool results become functionResponse parts on a "user" content;
			// reuse the trailing user content when possible.
			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
				geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
					Role: "user",
				})
			}
			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
			name := ""
			if message.Name != nil {
				name = *message.Name
			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
				name = val
			}
			var contentMap map[string]interface{}
			contentStr := message.StringContent()
			// 1. Try to parse the content as a JSON object.
			if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
				// 2. On failure, try a JSON array...
				var contentSlice []interface{}
				if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
					// ...and wrap the array in an object.
					contentMap = map[string]interface{}{"result": contentSlice}
				} else {
					// 3. Otherwise treat the content as plain text.
					contentMap = map[string]interface{}{"content": contentStr}
				}
			}
			functionResp := &dto.GeminiFunctionResponse{
				Name:     name,
				Response: contentMap,
			}
			*parts = append(*parts, dto.GeminiPart{
				FunctionResponse: functionResp,
			})
			continue
		}
		var parts []dto.GeminiPart
		content := dto.GeminiChatContent{
			Role: message.Role,
		}
		// isToolCall := false
		if message.ToolCalls != nil {
			// message.Role = "model"
			// isToolCall = true
			for _, call := range message.ParseToolCalls() {
				args := map[string]interface{}{}
				if call.Function.Arguments != "" {
					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
					}
				}
				toolCall := dto.GeminiPart{
					FunctionCall: &dto.FunctionCall{
						FunctionName: call.Function.Name,
						Arguments:    args,
					},
				}
				parts = append(parts, toolCall)
				// Remember the id so the matching tool-result message can
				// recover the function name.
				tool_call_ids[call.ID] = call.Function.Name
			}
		}
		openaiContent := message.ParseContent()
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == dto.ContentTypeText {
				if part.Text == "" {
					continue
				}
				parts = append(parts, dto.GeminiPart{
					Text: part.Text,
				})
			} else if part.Type == dto.ContentTypeImageURL {
				imageNum += 1
				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
				}
				// Decide whether the image reference is a remote URL.
				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
					// Remote URL: fetch the file's type and base64 data.
					fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
					if err != nil {
						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
					}
					// Validate the MIME type against the Gemini whitelist.
					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
						url := part.GetImageMedia().Url
						return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: fileData.MimeType, // keep original casing; it may be significant to the API
							Data:     fileData.Base64Data,
						},
					})
				} else {
					// Data URL: decode the base64 payload directly.
					format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
					if err != nil {
						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
					}
					parts = append(parts, dto.GeminiPart{
						InlineData: &dto.GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					})
				}
			} else if part.Type == dto.ContentTypeFile {
				// Only inline base64 file data is supported, not file ids.
				if part.GetFile().FileId != "" {
					return nil, fmt.Errorf("only base64 file is supported in gemini")
				}
				format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
				if err != nil {
					return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: format,
						Data:     base64String,
					},
				})
			} else if part.Type == dto.ContentTypeInputAudio {
				if part.GetInputAudio().Data == "" {
					return nil, fmt.Errorf("only base64 audio is supported in gemini")
				}
				base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
				if err != nil {
					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
				}
				parts = append(parts, dto.GeminiPart{
					InlineData: &dto.GeminiInlineData{
						MimeType: "audio/" + part.GetInputAudio().Format,
						Data:     base64String,
					},
				})
			}
		}
		content.Parts = parts
		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Drop contents with no parts; Gemini rejects empty contents.
		if len(content.Parts) > 0 {
			geminiRequest.Contents = append(geminiRequest.Contents, content)
		}
	}
	if len(system_content) > 0 {
		geminiRequest.SystemInstructions = &dto.GeminiChatContent{
			Parts: []dto.GeminiPart{
				{
					Text: strings.Join(system_content, "\n"),
				},
			},
		}
	}
	return &geminiRequest, nil
}
  461. // Helper function to get a list of supported MIME types for error messages
  462. func getSupportedMimeTypesList() []string {
  463. keys := make([]string, 0, len(geminiSupportedMimeTypes))
  464. for k := range geminiSupportedMimeTypes {
  465. keys = append(keys, k)
  466. }
  467. return keys
  468. }
  469. // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
  470. func cleanFunctionParameters(params interface{}) interface{} {
  471. if params == nil {
  472. return nil
  473. }
  474. switch v := params.(type) {
  475. case map[string]interface{}:
  476. // Create a copy to avoid modifying the original
  477. cleanedMap := make(map[string]interface{})
  478. for k, val := range v {
  479. cleanedMap[k] = val
  480. }
  481. // Remove unsupported root-level fields
  482. delete(cleanedMap, "default")
  483. delete(cleanedMap, "exclusiveMaximum")
  484. delete(cleanedMap, "exclusiveMinimum")
  485. delete(cleanedMap, "$schema")
  486. delete(cleanedMap, "additionalProperties")
  487. // Check and clean 'format' for string types
  488. if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
  489. if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
  490. if formatValue != "enum" && formatValue != "date-time" {
  491. delete(cleanedMap, "format")
  492. }
  493. }
  494. }
  495. // Clean properties
  496. if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
  497. cleanedProps := make(map[string]interface{})
  498. for propName, propValue := range props {
  499. cleanedProps[propName] = cleanFunctionParameters(propValue)
  500. }
  501. cleanedMap["properties"] = cleanedProps
  502. }
  503. // Recursively clean items in arrays
  504. if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
  505. cleanedMap["items"] = cleanFunctionParameters(items)
  506. }
  507. // Also handle items if it's an array of schemas
  508. if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
  509. cleanedItemsArray := make([]interface{}, len(itemsArray))
  510. for i, item := range itemsArray {
  511. cleanedItemsArray[i] = cleanFunctionParameters(item)
  512. }
  513. cleanedMap["items"] = cleanedItemsArray
  514. }
  515. // Recursively clean other schema composition keywords
  516. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  517. if nested, ok := cleanedMap[field].([]interface{}); ok {
  518. cleanedNested := make([]interface{}, len(nested))
  519. for i, item := range nested {
  520. cleanedNested[i] = cleanFunctionParameters(item)
  521. }
  522. cleanedMap[field] = cleanedNested
  523. }
  524. }
  525. // Recursively clean patternProperties
  526. if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
  527. cleanedPatternProps := make(map[string]interface{})
  528. for pattern, schema := range patternProps {
  529. cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
  530. }
  531. cleanedMap["patternProperties"] = cleanedPatternProps
  532. }
  533. // Recursively clean definitions
  534. if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
  535. cleanedDefinitions := make(map[string]interface{})
  536. for defName, defSchema := range definitions {
  537. cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
  538. }
  539. cleanedMap["definitions"] = cleanedDefinitions
  540. }
  541. // Recursively clean $defs (newer JSON Schema draft)
  542. if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
  543. cleanedDefs := make(map[string]interface{})
  544. for defName, defSchema := range defs {
  545. cleanedDefs[defName] = cleanFunctionParameters(defSchema)
  546. }
  547. cleanedMap["$defs"] = cleanedDefs
  548. }
  549. // Clean conditional keywords
  550. for _, field := range []string{"if", "then", "else", "not"} {
  551. if nested, ok := cleanedMap[field]; ok {
  552. cleanedMap[field] = cleanFunctionParameters(nested)
  553. }
  554. }
  555. return cleanedMap
  556. case []interface{}:
  557. // Handle arrays of schemas
  558. cleanedArray := make([]interface{}, len(v))
  559. for i, item := range v {
  560. cleanedArray[i] = cleanFunctionParameters(item)
  561. }
  562. return cleanedArray
  563. default:
  564. // Not a map or array, return as is (e.g., could be a primitive)
  565. return params
  566. }
  567. }
  568. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  569. if depth >= 5 {
  570. return schema
  571. }
  572. v, ok := schema.(map[string]interface{})
  573. if !ok || len(v) == 0 {
  574. return schema
  575. }
  576. // 删除所有的title字段
  577. delete(v, "title")
  578. delete(v, "$schema")
  579. // 如果type不为object和array,则直接返回
  580. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  581. return schema
  582. }
  583. switch v["type"] {
  584. case "object":
  585. delete(v, "additionalProperties")
  586. // 处理 properties
  587. if properties, ok := v["properties"].(map[string]interface{}); ok {
  588. for key, value := range properties {
  589. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  590. }
  591. }
  592. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  593. if nested, ok := v[field].([]interface{}); ok {
  594. for i, item := range nested {
  595. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  596. }
  597. }
  598. }
  599. case "array":
  600. if items, ok := v["items"].(map[string]interface{}); ok {
  601. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  602. }
  603. }
  604. return v
  605. }
  606. func unescapeString(s string) (string, error) {
  607. var result []rune
  608. escaped := false
  609. i := 0
  610. for i < len(s) {
  611. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  612. if r == utf8.RuneError {
  613. return "", fmt.Errorf("invalid UTF-8 encoding")
  614. }
  615. if escaped {
  616. // 如果是转义符后的字符,检查其类型
  617. switch r {
  618. case '"':
  619. result = append(result, '"')
  620. case '\\':
  621. result = append(result, '\\')
  622. case '/':
  623. result = append(result, '/')
  624. case 'b':
  625. result = append(result, '\b')
  626. case 'f':
  627. result = append(result, '\f')
  628. case 'n':
  629. result = append(result, '\n')
  630. case 'r':
  631. result = append(result, '\r')
  632. case 't':
  633. result = append(result, '\t')
  634. case '\'':
  635. result = append(result, '\'')
  636. default:
  637. // 如果遇到一个非法的转义字符,直接按原样输出
  638. result = append(result, '\\', r)
  639. }
  640. escaped = false
  641. } else {
  642. if r == '\\' {
  643. escaped = true // 记录反斜杠作为转义符
  644. } else {
  645. result = append(result, r)
  646. }
  647. }
  648. i += size // 移动到下一个字符
  649. }
  650. return string(result), nil
  651. }
  652. func unescapeMapOrSlice(data interface{}) interface{} {
  653. switch v := data.(type) {
  654. case map[string]interface{}:
  655. for k, val := range v {
  656. v[k] = unescapeMapOrSlice(val)
  657. }
  658. case []interface{}:
  659. for i, val := range v {
  660. v[i] = unescapeMapOrSlice(val)
  661. }
  662. case string:
  663. if unescaped, err := unescapeString(v); err != nil {
  664. return v
  665. } else {
  666. return unescaped
  667. }
  668. }
  669. return data
  670. }
  671. func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
  672. var argsBytes []byte
  673. var err error
  674. if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
  675. argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
  676. } else {
  677. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  678. }
  679. if err != nil {
  680. return nil
  681. }
  682. return &dto.ToolCallResponse{
  683. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  684. Type: "function",
  685. Function: dto.FunctionResponse{
  686. Arguments: string(argsBytes),
  687. Name: item.FunctionCall.FunctionName,
  688. },
  689. }
  690. }
  691. func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
  692. fullTextResponse := dto.OpenAITextResponse{
  693. Id: helper.GetResponseID(c),
  694. Object: "chat.completion",
  695. Created: common.GetTimestamp(),
  696. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  697. }
  698. isToolCall := false
  699. for _, candidate := range response.Candidates {
  700. choice := dto.OpenAITextResponseChoice{
  701. Index: int(candidate.Index),
  702. Message: dto.Message{
  703. Role: "assistant",
  704. Content: "",
  705. },
  706. FinishReason: constant.FinishReasonStop,
  707. }
  708. if len(candidate.Content.Parts) > 0 {
  709. var texts []string
  710. var toolCalls []dto.ToolCallResponse
  711. for _, part := range candidate.Content.Parts {
  712. if part.InlineData != nil {
  713. // 媒体内容
  714. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  715. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  716. texts = append(texts, imgText)
  717. } else {
  718. // 其他媒体类型,直接显示链接
  719. texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
  720. }
  721. } else if part.FunctionCall != nil {
  722. choice.FinishReason = constant.FinishReasonToolCalls
  723. if call := getResponseToolCall(&part); call != nil {
  724. toolCalls = append(toolCalls, *call)
  725. }
  726. } else if part.Thought {
  727. choice.Message.ReasoningContent = part.Text
  728. } else {
  729. if part.ExecutableCode != nil {
  730. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  731. } else if part.CodeExecutionResult != nil {
  732. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  733. } else {
  734. // 过滤掉空行
  735. if part.Text != "\n" {
  736. texts = append(texts, part.Text)
  737. }
  738. }
  739. }
  740. }
  741. if len(toolCalls) > 0 {
  742. choice.Message.SetToolCalls(toolCalls)
  743. isToolCall = true
  744. }
  745. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  746. }
  747. if candidate.FinishReason != nil {
  748. switch *candidate.FinishReason {
  749. case "STOP":
  750. choice.FinishReason = constant.FinishReasonStop
  751. case "MAX_TOKENS":
  752. choice.FinishReason = constant.FinishReasonLength
  753. default:
  754. choice.FinishReason = constant.FinishReasonContentFilter
  755. }
  756. }
  757. if isToolCall {
  758. choice.FinishReason = constant.FinishReasonToolCalls
  759. }
  760. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  761. }
  762. return &fullTextResponse
  763. }
  764. func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
  765. choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
  766. isStop := false
  767. for _, candidate := range geminiResponse.Candidates {
  768. if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
  769. isStop = true
  770. candidate.FinishReason = nil
  771. }
  772. choice := dto.ChatCompletionsStreamResponseChoice{
  773. Index: int(candidate.Index),
  774. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  775. //Role: "assistant",
  776. },
  777. }
  778. var texts []string
  779. isTools := false
  780. isThought := false
  781. if candidate.FinishReason != nil {
  782. // p := GeminiConvertFinishReason(*candidate.FinishReason)
  783. switch *candidate.FinishReason {
  784. case "STOP":
  785. choice.FinishReason = &constant.FinishReasonStop
  786. case "MAX_TOKENS":
  787. choice.FinishReason = &constant.FinishReasonLength
  788. default:
  789. choice.FinishReason = &constant.FinishReasonContentFilter
  790. }
  791. }
  792. for _, part := range candidate.Content.Parts {
  793. if part.InlineData != nil {
  794. if strings.HasPrefix(part.InlineData.MimeType, "image") {
  795. imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
  796. texts = append(texts, imgText)
  797. }
  798. } else if part.FunctionCall != nil {
  799. isTools = true
  800. if call := getResponseToolCall(&part); call != nil {
  801. call.SetIndex(len(choice.Delta.ToolCalls))
  802. choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
  803. }
  804. } else if part.Thought {
  805. isThought = true
  806. texts = append(texts, part.Text)
  807. } else {
  808. if part.ExecutableCode != nil {
  809. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
  810. } else if part.CodeExecutionResult != nil {
  811. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
  812. } else {
  813. if part.Text != "\n" {
  814. texts = append(texts, part.Text)
  815. }
  816. }
  817. }
  818. }
  819. if isThought {
  820. choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
  821. } else {
  822. choice.Delta.SetContentString(strings.Join(texts, "\n"))
  823. }
  824. if isTools {
  825. choice.FinishReason = &constant.FinishReasonToolCalls
  826. }
  827. choices = append(choices, choice)
  828. }
  829. var response dto.ChatCompletionsStreamResponse
  830. response.Object = "chat.completion.chunk"
  831. response.Choices = choices
  832. return &response, isStop
  833. }
  834. func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  835. streamData, err := common.Marshal(resp)
  836. if err != nil {
  837. return fmt.Errorf("failed to marshal stream response: %w", err)
  838. }
  839. err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
  840. if err != nil {
  841. return fmt.Errorf("failed to handle stream format: %w", err)
  842. }
  843. return nil
  844. }
  845. func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
  846. streamData, err := common.Marshal(resp)
  847. if err != nil {
  848. return fmt.Errorf("failed to marshal stream response: %w", err)
  849. }
  850. openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
  851. return nil
  852. }
// GeminiChatStreamHandler consumes a Gemini streaming chat response and
// re-emits it to the client as OpenAI-style chat.completion.chunk events,
// accumulating usage along the way. It returns the final usage, or an error
// when the upstream produced no response at all (in which case nothing is
// billed).
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	// responseText := ""
	id := helper.GetResponseID(c)
	createAt := common.GetTimestamp()
	// Raw text accumulator, used below as a fallback for token counting when
	// the upstream never reports usage metadata.
	responseText := strings.Builder{}
	var usage = &dto.Usage{}
	var imageCount int
	finishReason := constant.FinishReasonStop
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var geminiResponse dto.GeminiChatResponse
		err := common.UnmarshalJsonStr(data, &geminiResponse)
		if err != nil {
			// Returning false stops the scanner for this malformed chunk.
			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
			return false
		}
		// Count inline-data (image) parts and collect raw text for the
		// fallback token estimate after the stream ends.
		for _, candidate := range geminiResponse.Candidates {
			for _, part := range candidate.Content.Parts {
				if part.InlineData != nil && part.InlineData.MimeType != "" {
					imageCount++
				}
				if part.Text != "" {
					responseText.WriteString(part.Text)
				}
			}
		}
		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
		response.Id = id
		response.Created = createAt
		response.Model = info.UpstreamModelName
		// Usage metadata may arrive on any chunk; the latest non-zero totals win.
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
				if detail.Modality == "AUDIO" {
					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
				} else if detail.Modality == "TEXT" {
					usage.PromptTokensDetails.TextTokens = detail.TokenCount
				}
			}
		}
		logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
		if info.SendResponseCount == 0 {
			// send first response
			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
			if response.IsToolCall() {
				// Prime the first chunk with the tool-call headers (name etc.)
				// but with empty arguments; the real chunk that follows then
				// carries only the arguments.
				if len(emptyResponse.Choices) > 0 && len(response.Choices) > 0 {
					toolCalls := response.Choices[0].Delta.ToolCalls
					copiedToolCalls := make([]dto.ToolCallResponse, len(toolCalls))
					for idx := range toolCalls {
						copiedToolCalls[idx] = toolCalls[idx]
						copiedToolCalls[idx].Function.Arguments = ""
					}
					emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
				}
				finishReason = constant.FinishReasonToolCalls
				err = handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
				response.ClearToolCalls()
				if response.IsFinished() {
					// Defer the finish reason to the dedicated stop chunk below.
					response.Choices[0].FinishReason = nil
				}
			} else {
				err = handleStream(c, info, emptyResponse)
				if err != nil {
					logger.LogError(c, err.Error())
				}
			}
		}
		err = handleStream(c, info, response)
		if err != nil {
			logger.LogError(c, err.Error())
		}
		if isStop {
			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
		}
		return true
	})
	if info.SendResponseCount == 0 {
		// Empty completion: return an error and do not bill.
		return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
	}
	if imageCount != 0 {
		if usage.CompletionTokens == 0 {
			// 258 tokens per generated image when no completion count was reported.
			usage.CompletionTokens = imageCount * 258
		}
	}
	// NOTE(review): this overwrites the per-modality TextTokens collected from
	// PromptTokensDetails above — confirm that is intended.
	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	// NOTE(review): this recompute also clobbers the imageCount*258 fallback
	// set just above whenever TotalTokens is zero or equals PromptTokens —
	// verify the ordering of these two fixups.
	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
	if usage.CompletionTokens == 0 {
		str := responseText.String()
		if len(str) > 0 {
			// No usable usage metadata: estimate from the accumulated text.
			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
		} else {
			// Empty completion: no usage needed.
			usage = &dto.Usage{}
		}
	}
	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
	err := handleFinalStream(c, info, response)
	if err != nil {
		common.SysLog("send final response failed: " + err.Error())
	}
	//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
	//	helper.Done(c)
	//}
	//resp.Body.Close()
	return usage, nil
}
  966. func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  967. responseBody, err := io.ReadAll(resp.Body)
  968. if err != nil {
  969. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  970. }
  971. service.CloseResponseBodyGracefully(resp)
  972. if common.DebugEnabled {
  973. println(string(responseBody))
  974. }
  975. var geminiResponse dto.GeminiChatResponse
  976. err = common.Unmarshal(responseBody, &geminiResponse)
  977. if err != nil {
  978. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  979. }
  980. if len(geminiResponse.Candidates) == 0 {
  981. //return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  982. //if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
  983. // return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
  984. //} else {
  985. // return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
  986. //}
  987. }
  988. fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
  989. fullTextResponse.Model = info.UpstreamModelName
  990. usage := dto.Usage{
  991. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  992. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  993. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  994. }
  995. usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
  996. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  997. for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
  998. if detail.Modality == "AUDIO" {
  999. usage.PromptTokensDetails.AudioTokens = detail.TokenCount
  1000. } else if detail.Modality == "TEXT" {
  1001. usage.PromptTokensDetails.TextTokens = detail.TokenCount
  1002. }
  1003. }
  1004. fullTextResponse.Usage = usage
  1005. switch info.RelayFormat {
  1006. case types.RelayFormatOpenAI:
  1007. responseBody, err = common.Marshal(fullTextResponse)
  1008. if err != nil {
  1009. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1010. }
  1011. case types.RelayFormatClaude:
  1012. claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
  1013. claudeRespStr, err := common.Marshal(claudeResp)
  1014. if err != nil {
  1015. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  1016. }
  1017. responseBody = claudeRespStr
  1018. case types.RelayFormatGemini:
  1019. break
  1020. }
  1021. service.IOCopyBytesGracefully(c, resp, responseBody)
  1022. return &usage, nil
  1023. }
  1024. func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1025. defer service.CloseResponseBodyGracefully(resp)
  1026. responseBody, readErr := io.ReadAll(resp.Body)
  1027. if readErr != nil {
  1028. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1029. }
  1030. var geminiResponse dto.GeminiBatchEmbeddingResponse
  1031. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1032. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1033. }
  1034. // convert to openai format response
  1035. openAIResponse := dto.OpenAIEmbeddingResponse{
  1036. Object: "list",
  1037. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
  1038. Model: info.UpstreamModelName,
  1039. }
  1040. for i, embedding := range geminiResponse.Embeddings {
  1041. openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
  1042. Object: "embedding",
  1043. Embedding: embedding.Values,
  1044. Index: i,
  1045. })
  1046. }
  1047. // calculate usage
  1048. // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
  1049. // Google has not yet clarified how embedding models will be billed
  1050. // refer to openai billing method to use input tokens billing
  1051. // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
  1052. usage := &dto.Usage{
  1053. PromptTokens: info.PromptTokens,
  1054. CompletionTokens: 0,
  1055. TotalTokens: info.PromptTokens,
  1056. }
  1057. openAIResponse.Usage = *usage
  1058. jsonResponse, jsonErr := common.Marshal(openAIResponse)
  1059. if jsonErr != nil {
  1060. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1061. }
  1062. service.IOCopyBytesGracefully(c, resp, jsonResponse)
  1063. return usage, nil
  1064. }
  1065. func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  1066. responseBody, readErr := io.ReadAll(resp.Body)
  1067. if readErr != nil {
  1068. return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1069. }
  1070. _ = resp.Body.Close()
  1071. var geminiResponse dto.GeminiImageResponse
  1072. if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
  1073. return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1074. }
  1075. if len(geminiResponse.Predictions) == 0 {
  1076. return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  1077. }
  1078. // convert to openai format response
  1079. openAIResponse := dto.ImageResponse{
  1080. Created: common.GetTimestamp(),
  1081. Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
  1082. }
  1083. for _, prediction := range geminiResponse.Predictions {
  1084. if prediction.RaiFilteredReason != "" {
  1085. continue // skip filtered image
  1086. }
  1087. openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
  1088. B64Json: prediction.BytesBase64Encoded,
  1089. })
  1090. }
  1091. jsonResponse, jsonErr := json.Marshal(openAIResponse)
  1092. if jsonErr != nil {
  1093. return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
  1094. }
  1095. c.Writer.Header().Set("Content-Type", "application/json")
  1096. c.Writer.WriteHeader(resp.StatusCode)
  1097. _, _ = c.Writer.Write(jsonResponse)
  1098. // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
  1099. // each image has fixed 258 tokens
  1100. const imageTokens = 258
  1101. generatedImages := len(openAIResponse.Data)
  1102. usage := &dto.Usage{
  1103. PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
  1104. CompletionTokens: 0, // image generation does not calculate completion tokens
  1105. TotalTokens: imageTokens * generatedImages,
  1106. }
  1107. return usage, nil
  1108. }