relay-gemini.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. package gemini
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/dto"
	relaycommon "one-api/relay/common"
	"one-api/service"
)
  15. // Setting safety to the lowest possible values since Gemini is already powerless enough
  16. func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
  17. geminiRequest := GeminiChatRequest{
  18. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  19. SafetySettings: []GeminiChatSafetySettings{
  20. {
  21. Category: "HARM_CATEGORY_HARASSMENT",
  22. Threshold: common.GeminiSafetySetting,
  23. },
  24. {
  25. Category: "HARM_CATEGORY_HATE_SPEECH",
  26. Threshold: common.GeminiSafetySetting,
  27. },
  28. {
  29. Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  30. Threshold: common.GeminiSafetySetting,
  31. },
  32. {
  33. Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  34. Threshold: common.GeminiSafetySetting,
  35. },
  36. },
  37. GenerationConfig: GeminiChatGenerationConfig{
  38. Temperature: textRequest.Temperature,
  39. TopP: textRequest.TopP,
  40. MaxOutputTokens: textRequest.MaxTokens,
  41. },
  42. }
  43. if textRequest.Functions != nil {
  44. geminiRequest.Tools = []GeminiChatTools{
  45. {
  46. FunctionDeclarations: textRequest.Functions,
  47. },
  48. }
  49. }
  50. shouldAddDummyModelMessage := false
  51. for _, message := range textRequest.Messages {
  52. content := GeminiChatContent{
  53. Role: message.Role,
  54. Parts: []GeminiPart{
  55. {
  56. Text: message.StringContent(),
  57. },
  58. },
  59. }
  60. openaiContent := message.ParseContent()
  61. var parts []GeminiPart
  62. imageNum := 0
  63. for _, part := range openaiContent {
  64. if part.Type == dto.ContentTypeText {
  65. parts = append(parts, GeminiPart{
  66. Text: part.Text,
  67. })
  68. } else if part.Type == dto.ContentTypeImageURL {
  69. imageNum += 1
  70. if imageNum > GeminiVisionMaxImageNum {
  71. continue
  72. }
  73. mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
  74. parts = append(parts, GeminiPart{
  75. InlineData: &GeminiInlineData{
  76. MimeType: mimeType,
  77. Data: data,
  78. },
  79. })
  80. }
  81. }
  82. content.Parts = parts
  83. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  84. if content.Role == "assistant" {
  85. content.Role = "model"
  86. }
  87. // Converting system prompt to prompt from user for the same reason
  88. if content.Role == "system" {
  89. content.Role = "user"
  90. shouldAddDummyModelMessage = true
  91. }
  92. geminiRequest.Contents = append(geminiRequest.Contents, content)
  93. // If a system message is the last message, we need to add a dummy model message to make gemini happy
  94. if shouldAddDummyModelMessage {
  95. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  96. Role: "model",
  97. Parts: []GeminiPart{
  98. {
  99. Text: "Okay",
  100. },
  101. },
  102. })
  103. shouldAddDummyModelMessage = false
  104. }
  105. }
  106. return &geminiRequest
  107. }
  108. func (g *GeminiChatResponse) GetResponseText() string {
  109. if g == nil {
  110. return ""
  111. }
  112. if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  113. return g.Candidates[0].Content.Parts[0].Text
  114. }
  115. return ""
  116. }
  117. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
  118. fullTextResponse := dto.OpenAITextResponse{
  119. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  120. Object: "chat.completion",
  121. Created: common.GetTimestamp(),
  122. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  123. }
  124. content, _ := json.Marshal("")
  125. for i, candidate := range response.Candidates {
  126. choice := dto.OpenAITextResponseChoice{
  127. Index: i,
  128. Message: dto.Message{
  129. Role: "assistant",
  130. Content: content,
  131. },
  132. FinishReason: relaycommon.StopFinishReason,
  133. }
  134. if len(candidate.Content.Parts) > 0 {
  135. content, _ = json.Marshal(candidate.Content.Parts[0].Text)
  136. choice.Message.Content = content
  137. }
  138. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  139. }
  140. return &fullTextResponse
  141. }
  142. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
  143. var choice dto.ChatCompletionsStreamResponseChoice
  144. choice.Delta.SetContentString(geminiResponse.GetResponseText())
  145. choice.FinishReason = &relaycommon.StopFinishReason
  146. var response dto.ChatCompletionsStreamResponse
  147. response.Object = "chat.completion.chunk"
  148. response.Model = "gemini"
  149. response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
  150. return &response
  151. }
  152. func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
  153. responseText := ""
  154. dataChan := make(chan string)
  155. stopChan := make(chan bool)
  156. scanner := bufio.NewScanner(resp.Body)
  157. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  158. if atEOF && len(data) == 0 {
  159. return 0, nil, nil
  160. }
  161. if i := strings.Index(string(data), "\n"); i >= 0 {
  162. return i + 1, data[0:i], nil
  163. }
  164. if atEOF {
  165. return len(data), data, nil
  166. }
  167. return 0, nil, nil
  168. })
  169. go func() {
  170. for scanner.Scan() {
  171. data := scanner.Text()
  172. data = strings.TrimSpace(data)
  173. if !strings.HasPrefix(data, "\"text\": \"") {
  174. continue
  175. }
  176. data = strings.TrimPrefix(data, "\"text\": \"")
  177. data = strings.TrimSuffix(data, "\"")
  178. dataChan <- data
  179. }
  180. stopChan <- true
  181. }()
  182. service.SetEventStreamHeaders(c)
  183. c.Stream(func(w io.Writer) bool {
  184. select {
  185. case data := <-dataChan:
  186. // this is used to prevent annoying \ related format bug
  187. data = fmt.Sprintf("{\"content\": \"%s\"}", data)
  188. type dummyStruct struct {
  189. Content string `json:"content"`
  190. }
  191. var dummy dummyStruct
  192. err := json.Unmarshal([]byte(data), &dummy)
  193. responseText += dummy.Content
  194. var choice dto.ChatCompletionsStreamResponseChoice
  195. choice.Delta.SetContentString(dummy.Content)
  196. response := dto.ChatCompletionsStreamResponse{
  197. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  198. Object: "chat.completion.chunk",
  199. Created: common.GetTimestamp(),
  200. Model: "gemini-pro",
  201. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  202. }
  203. jsonResponse, err := json.Marshal(response)
  204. if err != nil {
  205. common.SysError("error marshalling stream response: " + err.Error())
  206. return true
  207. }
  208. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  209. return true
  210. case <-stopChan:
  211. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  212. return false
  213. }
  214. })
  215. err := resp.Body.Close()
  216. if err != nil {
  217. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  218. }
  219. return nil, responseText
  220. }
  221. func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  222. responseBody, err := io.ReadAll(resp.Body)
  223. if err != nil {
  224. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  225. }
  226. err = resp.Body.Close()
  227. if err != nil {
  228. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  229. }
  230. var geminiResponse GeminiChatResponse
  231. err = json.Unmarshal(responseBody, &geminiResponse)
  232. if err != nil {
  233. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  234. }
  235. if len(geminiResponse.Candidates) == 0 {
  236. return &dto.OpenAIErrorWithStatusCode{
  237. Error: dto.OpenAIError{
  238. Message: "No candidates returned",
  239. Type: "server_error",
  240. Param: "",
  241. Code: 500,
  242. },
  243. StatusCode: resp.StatusCode,
  244. }, nil
  245. }
  246. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  247. completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
  248. usage := dto.Usage{
  249. PromptTokens: promptTokens,
  250. CompletionTokens: completionTokens,
  251. TotalTokens: promptTokens + completionTokens,
  252. }
  253. fullTextResponse.Usage = usage
  254. jsonResponse, err := json.Marshal(fullTextResponse)
  255. if err != nil {
  256. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  257. }
  258. c.Writer.Header().Set("Content-Type", "application/json")
  259. c.Writer.WriteHeader(resp.StatusCode)
  260. _, err = c.Writer.Write(jsonResponse)
  261. return nil, &usage
  262. }