relay-gemini.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. package controller
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"one-api/common"
)
const (
	// GeminiVisionMaxImageNum caps how many image parts are forwarded to
	// Gemini per request; extra images are silently dropped by the converter.
	GeminiVisionMaxImageNum = 16
)
// GeminiChatRequest is the request body sent to Gemini's generateContent API.
// NOTE(review): the tags below are snake_case while Gemini's REST docs show
// camelCase (safetySettings, generationConfig); proto3 JSON parsers accept
// both spellings, but confirm against the live API before relying on it.
type GeminiChatRequest struct {
	Contents         []GeminiChatContent        `json:"contents"`
	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []GeminiChatTools          `json:"tools,omitempty"`
}
// GeminiInlineData carries inline media (e.g. an image) as base64 data plus
// its MIME type, per Gemini's inlineData part schema.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}
// GeminiPart is one element of a content's parts list: either plain text or
// inline media — exactly one of the two fields is expected to be set.
type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
// GeminiChatContent is a single turn in the conversation. Gemini only accepts
// the roles "user" and "model" (see the role remapping in requestOpenAI2Gemini).
type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}
// GeminiChatSafetySettings sets the blocking threshold for one harm category.
type GeminiChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}
// GeminiChatTools forwards OpenAI-style function definitions to Gemini as
// functionDeclarations; the payload is passed through opaquely (any).
type GeminiChatTools struct {
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
// GeminiChatGenerationConfig mirrors Gemini's generationConfig object.
// NOTE(review): omitempty drops zero values, so an explicit temperature/topP
// of 0 cannot be expressed — confirm this is acceptable for callers.
// NOTE(review): TopK is float64 here although Gemini documents topK as an
// integer — verify against the API before changing.
type GeminiChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens uint     `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}
  48. // Setting safety to the lowest possible values since Gemini is already powerless enough
  49. func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
  50. geminiRequest := GeminiChatRequest{
  51. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  52. SafetySettings: []GeminiChatSafetySettings{
  53. {
  54. Category: "HARM_CATEGORY_HARASSMENT",
  55. Threshold: common.GeminiSafetySetting,
  56. },
  57. {
  58. Category: "HARM_CATEGORY_HATE_SPEECH",
  59. Threshold: common.GeminiSafetySetting,
  60. },
  61. {
  62. Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  63. Threshold: common.GeminiSafetySetting,
  64. },
  65. {
  66. Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  67. Threshold: common.GeminiSafetySetting,
  68. },
  69. },
  70. GenerationConfig: GeminiChatGenerationConfig{
  71. Temperature: textRequest.Temperature,
  72. TopP: textRequest.TopP,
  73. MaxOutputTokens: textRequest.MaxTokens,
  74. },
  75. }
  76. if textRequest.Functions != nil {
  77. geminiRequest.Tools = []GeminiChatTools{
  78. {
  79. FunctionDeclarations: textRequest.Functions,
  80. },
  81. }
  82. }
  83. shouldAddDummyModelMessage := false
  84. for _, message := range textRequest.Messages {
  85. content := GeminiChatContent{
  86. Role: message.Role,
  87. Parts: []GeminiPart{
  88. {
  89. Text: message.StringContent(),
  90. },
  91. },
  92. }
  93. openaiContent := message.ParseContent()
  94. var parts []GeminiPart
  95. imageNum := 0
  96. for _, part := range openaiContent {
  97. if part.Type == ContentTypeText {
  98. parts = append(parts, GeminiPart{
  99. Text: part.Text,
  100. })
  101. } else if part.Type == ContentTypeImageURL {
  102. imageNum += 1
  103. if imageNum > GeminiVisionMaxImageNum {
  104. continue
  105. }
  106. mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
  107. parts = append(parts, GeminiPart{
  108. InlineData: &GeminiInlineData{
  109. MimeType: mimeType,
  110. Data: data,
  111. },
  112. })
  113. }
  114. }
  115. content.Parts = parts
  116. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  117. if content.Role == "assistant" {
  118. content.Role = "model"
  119. }
  120. // Converting system prompt to prompt from user for the same reason
  121. if content.Role == "system" {
  122. content.Role = "user"
  123. shouldAddDummyModelMessage = true
  124. }
  125. geminiRequest.Contents = append(geminiRequest.Contents, content)
  126. // If a system message is the last message, we need to add a dummy model message to make gemini happy
  127. if shouldAddDummyModelMessage {
  128. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  129. Role: "model",
  130. Parts: []GeminiPart{
  131. {
  132. Text: "Okay",
  133. },
  134. },
  135. })
  136. shouldAddDummyModelMessage = false
  137. }
  138. }
  139. return &geminiRequest
  140. }
// GeminiChatResponse is the top-level generateContent response from Gemini.
// Candidates may be empty, e.g. when the prompt itself was blocked (the
// reason then appears in PromptFeedback).
type GeminiChatResponse struct {
	Candidates     []GeminiChatCandidate    `json:"candidates"`
	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
  145. func (g *GeminiChatResponse) GetResponseText() string {
  146. if g == nil {
  147. return ""
  148. }
  149. if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  150. return g.Candidates[0].Content.Parts[0].Text
  151. }
  152. return ""
  153. }
// GeminiChatCandidate is one generated completion together with the reason
// generation stopped and its per-category safety ratings.
type GeminiChatCandidate struct {
	Content       GeminiChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int64                    `json:"index"`
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
// GeminiChatSafetyRating reports the assessed harm probability for one category.
type GeminiChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}
// GeminiChatPromptFeedback carries safety ratings for the prompt itself,
// populated e.g. when the prompt was blocked and no candidates were produced.
type GeminiChatPromptFeedback struct {
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
  167. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
  168. fullTextResponse := OpenAITextResponse{
  169. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  170. Object: "chat.completion",
  171. Created: common.GetTimestamp(),
  172. Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
  173. }
  174. content, _ := json.Marshal("")
  175. for i, candidate := range response.Candidates {
  176. choice := OpenAITextResponseChoice{
  177. Index: i,
  178. Message: Message{
  179. Role: "assistant",
  180. Content: content,
  181. },
  182. FinishReason: stopFinishReason,
  183. }
  184. content, _ = json.Marshal(candidate.Content.Parts[0].Text)
  185. if len(candidate.Content.Parts) > 0 {
  186. choice.Message.Content = content
  187. }
  188. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  189. }
  190. return &fullTextResponse
  191. }
  192. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
  193. var choice ChatCompletionsStreamResponseChoice
  194. choice.Delta.Content = geminiResponse.GetResponseText()
  195. choice.FinishReason = &stopFinishReason
  196. var response ChatCompletionsStreamResponse
  197. response.Object = "chat.completion.chunk"
  198. response.Model = "gemini"
  199. response.Choices = []ChatCompletionsStreamResponseChoice{choice}
  200. return &response
  201. }
  202. func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
  203. responseText := ""
  204. dataChan := make(chan string)
  205. stopChan := make(chan bool)
  206. scanner := bufio.NewScanner(resp.Body)
  207. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  208. if atEOF && len(data) == 0 {
  209. return 0, nil, nil
  210. }
  211. if i := strings.Index(string(data), "\n"); i >= 0 {
  212. return i + 1, data[0:i], nil
  213. }
  214. if atEOF {
  215. return len(data), data, nil
  216. }
  217. return 0, nil, nil
  218. })
  219. go func() {
  220. for scanner.Scan() {
  221. data := scanner.Text()
  222. data = strings.TrimSpace(data)
  223. if !strings.HasPrefix(data, "\"text\": \"") {
  224. continue
  225. }
  226. data = strings.TrimPrefix(data, "\"text\": \"")
  227. data = strings.TrimSuffix(data, "\"")
  228. dataChan <- data
  229. }
  230. stopChan <- true
  231. }()
  232. setEventStreamHeaders(c)
  233. c.Stream(func(w io.Writer) bool {
  234. select {
  235. case data := <-dataChan:
  236. // this is used to prevent annoying \ related format bug
  237. data = fmt.Sprintf("{\"content\": \"%s\"}", data)
  238. type dummyStruct struct {
  239. Content string `json:"content"`
  240. }
  241. var dummy dummyStruct
  242. err := json.Unmarshal([]byte(data), &dummy)
  243. responseText += dummy.Content
  244. var choice ChatCompletionsStreamResponseChoice
  245. choice.Delta.Content = dummy.Content
  246. response := ChatCompletionsStreamResponse{
  247. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  248. Object: "chat.completion.chunk",
  249. Created: common.GetTimestamp(),
  250. Model: "gemini-pro",
  251. Choices: []ChatCompletionsStreamResponseChoice{choice},
  252. }
  253. jsonResponse, err := json.Marshal(response)
  254. if err != nil {
  255. common.SysError("error marshalling stream response: " + err.Error())
  256. return true
  257. }
  258. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  259. return true
  260. case <-stopChan:
  261. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  262. return false
  263. }
  264. })
  265. err := resp.Body.Close()
  266. if err != nil {
  267. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  268. }
  269. return nil, responseText
  270. }
  271. func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  272. responseBody, err := io.ReadAll(resp.Body)
  273. if err != nil {
  274. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  275. }
  276. err = resp.Body.Close()
  277. if err != nil {
  278. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  279. }
  280. var geminiResponse GeminiChatResponse
  281. err = json.Unmarshal(responseBody, &geminiResponse)
  282. if err != nil {
  283. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  284. }
  285. if len(geminiResponse.Candidates) == 0 {
  286. return &OpenAIErrorWithStatusCode{
  287. OpenAIError: OpenAIError{
  288. Message: "No candidates returned",
  289. Type: "server_error",
  290. Param: "",
  291. Code: 500,
  292. },
  293. StatusCode: resp.StatusCode,
  294. }, nil
  295. }
  296. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  297. completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
  298. usage := Usage{
  299. PromptTokens: promptTokens,
  300. CompletionTokens: completionTokens,
  301. TotalTokens: promptTokens + completionTokens,
  302. }
  303. fullTextResponse.Usage = usage
  304. jsonResponse, err := json.Marshal(fullTextResponse)
  305. if err != nil {
  306. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  307. }
  308. c.Writer.Header().Set("Content-Type", "application/json")
  309. c.Writer.WriteHeader(resp.StatusCode)
  310. _, err = c.Writer.Write(jsonResponse)
  311. return nil, &usage
  312. }