relay-gemini.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. package gemini
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/constant"
  10. "one-api/dto"
  11. relaycommon "one-api/relay/common"
  12. "one-api/service"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. // Setting safety to the lowest possible values since Gemini is already powerless enough
  17. func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
  18. geminiRequest := GeminiChatRequest{
  19. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  20. SafetySettings: []GeminiChatSafetySettings{
  21. {
  22. Category: "HARM_CATEGORY_HARASSMENT",
  23. Threshold: common.GeminiSafetySetting,
  24. },
  25. {
  26. Category: "HARM_CATEGORY_HATE_SPEECH",
  27. Threshold: common.GeminiSafetySetting,
  28. },
  29. {
  30. Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  31. Threshold: common.GeminiSafetySetting,
  32. },
  33. {
  34. Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  35. Threshold: common.GeminiSafetySetting,
  36. },
  37. {
  38. Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
  39. Threshold: common.GeminiSafetySetting,
  40. },
  41. },
  42. GenerationConfig: GeminiChatGenerationConfig{
  43. Temperature: textRequest.Temperature,
  44. TopP: textRequest.TopP,
  45. MaxOutputTokens: textRequest.MaxTokens,
  46. Seed: int64(textRequest.Seed),
  47. },
  48. }
  49. // openaiContent.FuncToToolCalls()
  50. if textRequest.Tools != nil {
  51. functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
  52. googleSearch := false
  53. codeExecution := false
  54. for _, tool := range textRequest.Tools {
  55. if tool.Function.Name == "googleSearch" {
  56. googleSearch = true
  57. continue
  58. }
  59. if tool.Function.Name == "codeExecution" {
  60. codeExecution = true
  61. continue
  62. }
  63. if tool.Function.Parameters != nil {
  64. params, ok := tool.Function.Parameters.(map[string]interface{})
  65. if ok {
  66. if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
  67. if len(props) == 0 {
  68. tool.Function.Parameters = nil
  69. }
  70. }
  71. }
  72. }
  73. functions = append(functions, tool.Function)
  74. }
  75. if codeExecution {
  76. geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
  77. CodeExecution: make(map[string]string),
  78. })
  79. }
  80. if googleSearch {
  81. geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
  82. GoogleSearch: make(map[string]string),
  83. })
  84. }
  85. if len(functions) > 0 {
  86. geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
  87. FunctionDeclarations: functions,
  88. })
  89. }
  90. // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
  91. // json_data, _ := json.Marshal(geminiRequest.Tools)
  92. // common.SysLog("tools_json: " + string(json_data))
  93. } else if textRequest.Functions != nil {
  94. geminiRequest.Tools = []GeminiChatTool{
  95. {
  96. FunctionDeclarations: textRequest.Functions,
  97. },
  98. }
  99. }
  100. if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
  101. geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
  102. if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
  103. cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
  104. geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
  105. }
  106. }
  107. tool_call_ids := make(map[string]string)
  108. var system_content []string
  109. //shouldAddDummyModelMessage := false
  110. for _, message := range textRequest.Messages {
  111. if message.Role == "system" {
  112. system_content = append(system_content, message.StringContent())
  113. continue
  114. } else if message.Role == "tool" || message.Role == "function" {
  115. if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
  116. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  117. Role: "user",
  118. })
  119. }
  120. var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
  121. name := ""
  122. if message.Name != nil {
  123. name = *message.Name
  124. } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
  125. name = val
  126. }
  127. content := common.StrToMap(message.StringContent())
  128. functionResp := &FunctionResponse{
  129. Name: name,
  130. Response: GeminiFunctionResponseContent{
  131. Name: name,
  132. Content: content,
  133. },
  134. }
  135. if content == nil {
  136. functionResp.Response.Content = message.StringContent()
  137. }
  138. *parts = append(*parts, GeminiPart{
  139. FunctionResponse: functionResp,
  140. })
  141. continue
  142. }
  143. var parts []GeminiPart
  144. content := GeminiChatContent{
  145. Role: message.Role,
  146. }
  147. // isToolCall := false
  148. if message.ToolCalls != nil {
  149. // message.Role = "model"
  150. // isToolCall = true
  151. for _, call := range message.ParseToolCalls() {
  152. args := map[string]interface{}{}
  153. if call.Function.Arguments != "" {
  154. if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
  155. return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
  156. }
  157. }
  158. toolCall := GeminiPart{
  159. FunctionCall: &FunctionCall{
  160. FunctionName: call.Function.Name,
  161. Arguments: args,
  162. },
  163. }
  164. parts = append(parts, toolCall)
  165. tool_call_ids[call.ID] = call.Function.Name
  166. }
  167. }
  168. openaiContent := message.ParseContent()
  169. imageNum := 0
  170. for _, part := range openaiContent {
  171. if part.Type == dto.ContentTypeText {
  172. if part.Text == "" {
  173. continue
  174. }
  175. parts = append(parts, GeminiPart{
  176. Text: part.Text,
  177. })
  178. } else if part.Type == dto.ContentTypeImageURL {
  179. imageNum += 1
  180. if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
  181. return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
  182. }
  183. // 判断是否是url
  184. if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
  185. // 是url,获取图片的类型和base64编码的数据
  186. mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
  187. parts = append(parts, GeminiPart{
  188. InlineData: &GeminiInlineData{
  189. MimeType: mimeType,
  190. Data: data,
  191. },
  192. })
  193. } else {
  194. _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
  195. if err != nil {
  196. return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
  197. }
  198. parts = append(parts, GeminiPart{
  199. InlineData: &GeminiInlineData{
  200. MimeType: "image/" + format,
  201. Data: base64String,
  202. },
  203. })
  204. }
  205. }
  206. }
  207. content.Parts = parts
  208. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  209. if content.Role == "assistant" {
  210. content.Role = "model"
  211. }
  212. geminiRequest.Contents = append(geminiRequest.Contents, content)
  213. }
  214. if len(system_content) > 0 {
  215. geminiRequest.SystemInstructions = &GeminiChatContent{
  216. Parts: []GeminiPart{
  217. {
  218. Text: strings.Join(system_content, "\n"),
  219. },
  220. },
  221. }
  222. }
  223. return &geminiRequest, nil
  224. }
  // removeAdditionalPropertiesWithDepth strips JSON-Schema keywords from a
// decoded schema before it is sent as a Gemini responseSchema, mutating the
// maps in place.
//
// Every visited schema map loses "title"; maps with "type": "object" also lose
// "additionalProperties". Recursion descends into "properties" (objects), a
// single-schema "items" (arrays), and the "allOf"/"anyOf"/"oneOf" combinators.
// The combinators are visited whatever the node's "type" is, because wrappers
// such as {"anyOf": [...]} often have no type of their own while their
// branches are objects. Nodes at depth >= 5, non-map values and empty maps are
// returned unchanged.
//
// The returned value is the same (mutated) map that was passed in.
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
	if depth >= 5 {
		return schema
	}
	v, ok := schema.(map[string]interface{})
	if !ok || len(v) == 0 {
		return schema
	}
	delete(v, "title")

	switch v["type"] {
	case "object":
		delete(v, "additionalProperties")
		if properties, ok := v["properties"].(map[string]interface{}); ok {
			for key, value := range properties {
				properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
			}
		}
	case "array":
		if items, ok := v["items"].(map[string]interface{}); ok {
			v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
		}
	}

	for _, field := range []string{"allOf", "anyOf", "oneOf"} {
		if nested, ok := v[field].([]interface{}); ok {
			for i, item := range nested {
				nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
			}
		}
	}
	return v
}
  262. // func (g *GeminiChatResponse) GetResponseText() string {
  263. // if g == nil {
  264. // return ""
  265. // }
  266. // if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  267. // return g.Candidates[0].Content.Parts[0].Text
  268. // }
  269. // return ""
  270. // }
  271. func getToolCall(item *GeminiPart) *dto.ToolCall {
  272. argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
  273. if err != nil {
  274. //common.SysError("getToolCall failed: " + err.Error())
  275. return nil
  276. }
  277. return &dto.ToolCall{
  278. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  279. Type: "function",
  280. Function: dto.FunctionCall{
  281. // 不好评价,得去转义一下反斜杠,Gemini 的特性好像是,Google 返回的时候本身就会转义“\”
  282. Arguments: strings.ReplaceAll(string(argsBytes), "\\\\", "\\"),
  283. Name: item.FunctionCall.FunctionName,
  284. },
  285. }
  286. }
  287. // func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
  288. // var toolCalls []dto.ToolCall
  289. // item := candidate.Content.Parts[index]
  290. // if item.FunctionCall == nil {
  291. // return toolCalls
  292. // }
  293. // argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
  294. // if err != nil {
  295. // //common.SysError("getToolCalls failed: " + err.Error())
  296. // return toolCalls
  297. // }
  298. // toolCall := dto.ToolCall{
  299. // ID: fmt.Sprintf("call_%s", common.GetUUID()),
  300. // Type: "function",
  301. // Function: dto.FunctionCall{
  302. // Arguments: string(argsBytes),
  303. // Name: item.FunctionCall.FunctionName,
  304. // },
  305. // }
  306. // toolCalls = append(toolCalls, toolCall)
  307. // return toolCalls
  308. // }
  309. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
  310. fullTextResponse := dto.OpenAITextResponse{
  311. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  312. Object: "chat.completion",
  313. Created: common.GetTimestamp(),
  314. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  315. }
  316. content, _ := json.Marshal("")
  317. is_tool_call := false
  318. for _, candidate := range response.Candidates {
  319. choice := dto.OpenAITextResponseChoice{
  320. Index: int(candidate.Index),
  321. Message: dto.Message{
  322. Role: "assistant",
  323. Content: content,
  324. },
  325. FinishReason: constant.FinishReasonStop,
  326. }
  327. if len(candidate.Content.Parts) > 0 {
  328. var texts []string
  329. var tool_calls []dto.ToolCall
  330. for _, part := range candidate.Content.Parts {
  331. if part.FunctionCall != nil {
  332. choice.FinishReason = constant.FinishReasonToolCalls
  333. if call := getToolCall(&part); call != nil {
  334. tool_calls = append(tool_calls, *call)
  335. }
  336. } else {
  337. if part.ExecutableCode != nil {
  338. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  339. } else if part.CodeExecutionResult != nil {
  340. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  341. } else {
  342. // 过滤掉空行
  343. if part.Text != "\n" {
  344. texts = append(texts, part.Text)
  345. }
  346. }
  347. }
  348. }
  349. if len(tool_calls) > 0 {
  350. choice.Message.SetToolCalls(tool_calls)
  351. is_tool_call = true
  352. }
  353. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  354. }
  355. if candidate.FinishReason != nil {
  356. switch *candidate.FinishReason {
  357. case "STOP":
  358. choice.FinishReason = constant.FinishReasonStop
  359. case "MAX_TOKENS":
  360. choice.FinishReason = constant.FinishReasonLength
  361. default:
  362. choice.FinishReason = constant.FinishReasonContentFilter
  363. }
  364. }
  365. if is_tool_call {
  366. choice.FinishReason = constant.FinishReasonToolCalls
  367. }
  368. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  369. }
  370. return &fullTextResponse
  371. }
  372. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
  373. choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
  374. is_stop := false
  375. for _, candidate := range geminiResponse.Candidates {
  376. if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
  377. is_stop = true
  378. candidate.FinishReason = nil
  379. }
  380. choice := dto.ChatCompletionsStreamResponseChoice{
  381. Index: int(candidate.Index),
  382. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  383. Role: "assistant",
  384. },
  385. }
  386. var texts []string
  387. isTools := false
  388. if candidate.FinishReason != nil {
  389. // p := GeminiConvertFinishReason(*candidate.FinishReason)
  390. switch *candidate.FinishReason {
  391. case "STOP":
  392. choice.FinishReason = &constant.FinishReasonStop
  393. case "MAX_TOKENS":
  394. choice.FinishReason = &constant.FinishReasonLength
  395. default:
  396. choice.FinishReason = &constant.FinishReasonContentFilter
  397. }
  398. }
  399. for _, part := range candidate.Content.Parts {
  400. if part.FunctionCall != nil {
  401. isTools = true
  402. if call := getToolCall(&part); call != nil {
  403. call.SetIndex(len(choice.Delta.ToolCalls))
  404. choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
  405. }
  406. } else {
  407. if part.ExecutableCode != nil {
  408. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
  409. } else if part.CodeExecutionResult != nil {
  410. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
  411. } else {
  412. if part.Text != "\n" {
  413. texts = append(texts, part.Text)
  414. }
  415. }
  416. }
  417. }
  418. choice.Delta.SetContentString(strings.Join(texts, "\n"))
  419. if isTools {
  420. choice.FinishReason = &constant.FinishReasonToolCalls
  421. }
  422. choices = append(choices, choice)
  423. }
  424. var response dto.ChatCompletionsStreamResponse
  425. response.Object = "chat.completion.chunk"
  426. response.Model = "gemini"
  427. response.Choices = choices
  428. return &response, is_stop
  429. }
  430. func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  431. // responseText := ""
  432. id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
  433. createAt := common.GetTimestamp()
  434. var usage = &dto.Usage{}
  435. scanner := bufio.NewScanner(resp.Body)
  436. scanner.Split(bufio.ScanLines)
  437. service.SetEventStreamHeaders(c)
  438. for scanner.Scan() {
  439. data := scanner.Text()
  440. info.SetFirstResponseTime()
  441. data = strings.TrimSpace(data)
  442. if !strings.HasPrefix(data, "data: ") {
  443. continue
  444. }
  445. data = strings.TrimPrefix(data, "data: ")
  446. data = strings.TrimSuffix(data, "\"")
  447. var geminiResponse GeminiChatResponse
  448. err := json.Unmarshal([]byte(data), &geminiResponse)
  449. if err != nil {
  450. common.LogError(c, "error unmarshalling stream response: "+err.Error())
  451. continue
  452. }
  453. response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
  454. response.Id = id
  455. response.Created = createAt
  456. response.Model = info.UpstreamModelName
  457. // responseText += response.Choices[0].Delta.GetContentString()
  458. if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
  459. usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
  460. usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
  461. }
  462. err = service.ObjectData(c, response)
  463. if err != nil {
  464. common.LogError(c, err.Error())
  465. }
  466. if is_stop {
  467. response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
  468. service.ObjectData(c, response)
  469. }
  470. }
  471. var response *dto.ChatCompletionsStreamResponse
  472. usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
  473. usage.PromptTokensDetails.TextTokens = usage.PromptTokens
  474. usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
  475. if info.ShouldIncludeUsage {
  476. response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
  477. err := service.ObjectData(c, response)
  478. if err != nil {
  479. common.SysError("send final response failed: " + err.Error())
  480. }
  481. }
  482. service.Done(c)
  483. resp.Body.Close()
  484. return nil, usage
  485. }
  486. func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  487. responseBody, err := io.ReadAll(resp.Body)
  488. if err != nil {
  489. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  490. }
  491. err = resp.Body.Close()
  492. if err != nil {
  493. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  494. }
  495. var geminiResponse GeminiChatResponse
  496. err = json.Unmarshal(responseBody, &geminiResponse)
  497. if err != nil {
  498. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  499. }
  500. if len(geminiResponse.Candidates) == 0 {
  501. return &dto.OpenAIErrorWithStatusCode{
  502. Error: dto.OpenAIError{
  503. Message: "No candidates returned",
  504. Type: "server_error",
  505. Param: "",
  506. Code: 500,
  507. },
  508. StatusCode: resp.StatusCode,
  509. }, nil
  510. }
  511. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  512. fullTextResponse.Model = info.UpstreamModelName
  513. usage := dto.Usage{
  514. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  515. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  516. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  517. }
  518. fullTextResponse.Usage = usage
  519. jsonResponse, err := json.Marshal(fullTextResponse)
  520. if err != nil {
  521. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  522. }
  523. c.Writer.Header().Set("Content-Type", "application/json")
  524. c.Writer.WriteHeader(resp.StatusCode)
  525. _, err = c.Writer.Write(jsonResponse)
  526. return nil, &usage
  527. }