relay-gemini.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package gemini
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/dto"
  11. relaycommon "one-api/relay/common"
  12. "one-api/service"
  13. "strings"
  14. )
  15. // Setting safety to the lowest possible values since Gemini is already powerless enough
  16. func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
  17. geminiRequest := GeminiChatRequest{
  18. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  19. SafetySettings: []GeminiChatSafetySettings{
  20. {
  21. Category: "HARM_CATEGORY_HARASSMENT",
  22. Threshold: common.GeminiSafetySetting,
  23. },
  24. {
  25. Category: "HARM_CATEGORY_HATE_SPEECH",
  26. Threshold: common.GeminiSafetySetting,
  27. },
  28. {
  29. Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  30. Threshold: common.GeminiSafetySetting,
  31. },
  32. {
  33. Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  34. Threshold: common.GeminiSafetySetting,
  35. },
  36. },
  37. GenerationConfig: GeminiChatGenerationConfig{
  38. Temperature: textRequest.Temperature,
  39. TopP: textRequest.TopP,
  40. MaxOutputTokens: textRequest.MaxTokens,
  41. },
  42. }
  43. if textRequest.Tools != nil {
  44. functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
  45. for _, tool := range textRequest.Tools {
  46. functions = append(functions, tool.Function)
  47. }
  48. geminiRequest.Tools = []GeminiChatTools{
  49. {
  50. FunctionDeclarations: functions,
  51. },
  52. }
  53. } else if textRequest.Functions != nil {
  54. geminiRequest.Tools = []GeminiChatTools{
  55. {
  56. FunctionDeclarations: textRequest.Functions,
  57. },
  58. }
  59. }
  60. shouldAddDummyModelMessage := false
  61. for _, message := range textRequest.Messages {
  62. content := GeminiChatContent{
  63. Role: message.Role,
  64. Parts: []GeminiPart{
  65. {
  66. Text: message.StringContent(),
  67. },
  68. },
  69. }
  70. openaiContent := message.ParseContent()
  71. var parts []GeminiPart
  72. imageNum := 0
  73. for _, part := range openaiContent {
  74. if part.Type == dto.ContentTypeText {
  75. parts = append(parts, GeminiPart{
  76. Text: part.Text,
  77. })
  78. } else if part.Type == dto.ContentTypeImageURL {
  79. imageNum += 1
  80. if imageNum > GeminiVisionMaxImageNum {
  81. continue
  82. }
  83. // 判断是否是url
  84. if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
  85. // 是url,获取图片的类型和base64编码的数据
  86. mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
  87. parts = append(parts, GeminiPart{
  88. InlineData: &GeminiInlineData{
  89. MimeType: mimeType,
  90. Data: data,
  91. },
  92. })
  93. } else {
  94. _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
  95. if err != nil {
  96. continue
  97. }
  98. parts = append(parts, GeminiPart{
  99. InlineData: &GeminiInlineData{
  100. MimeType: "image/" + format,
  101. Data: base64String,
  102. },
  103. })
  104. }
  105. }
  106. }
  107. content.Parts = parts
  108. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  109. if content.Role == "assistant" {
  110. content.Role = "model"
  111. }
  112. // Converting system prompt to prompt from user for the same reason
  113. if content.Role == "system" {
  114. content.Role = "user"
  115. shouldAddDummyModelMessage = true
  116. }
  117. geminiRequest.Contents = append(geminiRequest.Contents, content)
  118. // If a system message is the last message, we need to add a dummy model message to make gemini happy
  119. if shouldAddDummyModelMessage {
  120. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  121. Role: "model",
  122. Parts: []GeminiPart{
  123. {
  124. Text: "Okay",
  125. },
  126. },
  127. })
  128. shouldAddDummyModelMessage = false
  129. }
  130. }
  131. return &geminiRequest
  132. }
  133. func (g *GeminiChatResponse) GetResponseText() string {
  134. if g == nil {
  135. return ""
  136. }
  137. if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  138. return g.Candidates[0].Content.Parts[0].Text
  139. }
  140. return ""
  141. }
  142. func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
  143. var toolCalls []dto.ToolCall
  144. item := candidate.Content.Parts[0]
  145. if item.FunctionCall == nil {
  146. return toolCalls
  147. }
  148. argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
  149. if err != nil {
  150. //common.SysError("getToolCalls failed: " + err.Error())
  151. return toolCalls
  152. }
  153. toolCall := dto.ToolCall{
  154. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  155. Type: "function",
  156. Function: dto.FunctionCall{
  157. Arguments: string(argsBytes),
  158. Name: item.FunctionCall.FunctionName,
  159. },
  160. }
  161. toolCalls = append(toolCalls, toolCall)
  162. return toolCalls
  163. }
  164. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
  165. fullTextResponse := dto.OpenAITextResponse{
  166. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  167. Object: "chat.completion",
  168. Created: common.GetTimestamp(),
  169. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  170. }
  171. content, _ := json.Marshal("")
  172. for i, candidate := range response.Candidates {
  173. choice := dto.OpenAITextResponseChoice{
  174. Index: i,
  175. Message: dto.Message{
  176. Role: "assistant",
  177. Content: content,
  178. },
  179. FinishReason: relaycommon.StopFinishReason,
  180. }
  181. if len(candidate.Content.Parts) > 0 {
  182. if candidate.Content.Parts[0].FunctionCall != nil {
  183. choice.Message.ToolCalls = getToolCalls(&candidate)
  184. } else {
  185. choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
  186. }
  187. }
  188. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  189. }
  190. return &fullTextResponse
  191. }
  192. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
  193. var choice dto.ChatCompletionsStreamResponseChoice
  194. //choice.Delta.SetContentString(geminiResponse.GetResponseText())
  195. if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
  196. respFirst := geminiResponse.Candidates[0].Content.Parts[0]
  197. if respFirst.FunctionCall != nil {
  198. // function response
  199. choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
  200. } else {
  201. // text response
  202. choice.Delta.SetContentString(respFirst.Text)
  203. }
  204. }
  205. var response dto.ChatCompletionsStreamResponse
  206. response.Object = "chat.completion.chunk"
  207. response.Model = "gemini"
  208. response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
  209. return &response
  210. }
// GeminiChatStreamHandler relays a Gemini SSE stream to the client as
// OpenAI-style chat.completion.chunk events, then emits a final stop chunk
// (and an optional usage chunk) before closing the stream.
// Returns the accumulated token usage; the error result is always nil here.
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	// NOTE(review): responseText is accumulated but never read in this
	// function — possibly kept for token counting elsewhere; confirm.
	responseText := ""
	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createAt := common.GetTimestamp()
	var usage = &dto.Usage{}
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)
	service.SetEventStreamHeaders(c)
	for scanner.Scan() {
		data := scanner.Text()
		// Record time-to-first-response for relay statistics.
		info.SetFirstResponseTime()
		data = strings.TrimSpace(data)
		// Only SSE "data:" lines carry payloads; skip comments/blank lines.
		if !strings.HasPrefix(data, "data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		data = strings.TrimSuffix(data, "\"")
		var geminiResponse GeminiChatResponse
		err := json.Unmarshal([]byte(data), &geminiResponse)
		if err != nil {
			// Skip malformed chunks rather than aborting the whole stream.
			common.LogError(c, "error unmarshalling stream response: "+err.Error())
			continue
		}
		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
		if response == nil {
			continue
		}
		response.Id = id
		response.Created = createAt
		responseText += response.Choices[0].Delta.GetContentString()
		// Gemini reports usage cumulatively; keep the latest non-zero totals.
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
		}
		err = service.ObjectData(c, response)
		if err != nil {
			common.LogError(c, err.Error())
		}
	}
	// Emit the final stop chunk expected by OpenAI-compatible clients.
	response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason)
	service.ObjectData(c, response)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	if info.ShouldIncludeUsage {
		// Optional trailing chunk carrying the full usage object.
		response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
		err := service.ObjectData(c, response)
		if err != nil {
			common.SysError("send final response failed: " + err.Error())
		}
	}
	// Terminate the SSE stream with "[DONE]" and release the upstream body.
	service.Done(c)
	resp.Body.Close()
	return nil, usage
}
  264. func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  265. responseBody, err := io.ReadAll(resp.Body)
  266. if err != nil {
  267. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  268. }
  269. err = resp.Body.Close()
  270. if err != nil {
  271. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  272. }
  273. var geminiResponse GeminiChatResponse
  274. err = json.Unmarshal(responseBody, &geminiResponse)
  275. if err != nil {
  276. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  277. }
  278. if len(geminiResponse.Candidates) == 0 {
  279. return &dto.OpenAIErrorWithStatusCode{
  280. Error: dto.OpenAIError{
  281. Message: "No candidates returned",
  282. Type: "server_error",
  283. Param: "",
  284. Code: 500,
  285. },
  286. StatusCode: resp.StatusCode,
  287. }, nil
  288. }
  289. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  290. usage := dto.Usage{
  291. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  292. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  293. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  294. }
  295. fullTextResponse.Usage = usage
  296. jsonResponse, err := json.Marshal(fullTextResponse)
  297. if err != nil {
  298. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  299. }
  300. c.Writer.Header().Set("Content-Type", "application/json")
  301. c.Writer.WriteHeader(resp.StatusCode)
  302. _, err = c.Writer.Write(jsonResponse)
  303. return nil, &usage
  304. }