relay-gemini.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "strings"
  10. "github.com/gin-gonic/gin"
  11. )
// GeminiChatRequest is the request payload for Gemini's
// generateContent / streamGenerateContent endpoints.
// NOTE(review): the JSON tags use snake_case while Google's docs show
// lowerCamelCase (safetySettings, generationConfig); proto3 JSON mapping
// accepts both spellings — confirm before "fixing" the tags.
type GeminiChatRequest struct {
	Contents         []GeminiChatContent        `json:"contents"`
	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []GeminiChatTools          `json:"tools,omitempty"`
}

// GeminiInlineData carries base64-encoded binary content (e.g. images)
// together with its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// GeminiPart is one piece of a message: either plain text or inline data.
type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}

// GeminiChatContent is a single conversation turn; Role is "user" or "model".
type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

// GeminiChatSafetySettings configures one harm category's blocking threshold.
type GeminiChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

// GeminiChatTools exposes function declarations to the model.
type GeminiChatTools struct {
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}

// GeminiChatGenerationConfig mirrors the OpenAI sampling knobs in Gemini form.
type GeminiChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}
  45. // Setting safety to the lowest possible values since Gemini is already powerless enough
  46. func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
  47. geminiRequest := GeminiChatRequest{
  48. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  49. //SafetySettings: []GeminiChatSafetySettings{
  50. // {
  51. // Category: "HARM_CATEGORY_HARASSMENT",
  52. // Threshold: "BLOCK_ONLY_HIGH",
  53. // },
  54. // {
  55. // Category: "HARM_CATEGORY_HATE_SPEECH",
  56. // Threshold: "BLOCK_ONLY_HIGH",
  57. // },
  58. // {
  59. // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  60. // Threshold: "BLOCK_ONLY_HIGH",
  61. // },
  62. // {
  63. // Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  64. // Threshold: "BLOCK_ONLY_HIGH",
  65. // },
  66. //},
  67. GenerationConfig: GeminiChatGenerationConfig{
  68. Temperature: textRequest.Temperature,
  69. TopP: textRequest.TopP,
  70. MaxOutputTokens: textRequest.MaxTokens,
  71. },
  72. }
  73. if textRequest.Functions != nil {
  74. geminiRequest.Tools = []GeminiChatTools{
  75. {
  76. FunctionDeclarations: textRequest.Functions,
  77. },
  78. }
  79. }
  80. shouldAddDummyModelMessage := false
  81. for _, message := range textRequest.Messages {
  82. content := GeminiChatContent{
  83. Role: message.Role,
  84. Parts: []GeminiPart{
  85. {
  86. Text: string(message.Content),
  87. },
  88. },
  89. }
  90. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  91. if content.Role == "assistant" {
  92. content.Role = "model"
  93. }
  94. // Converting system prompt to prompt from user for the same reason
  95. if content.Role == "system" {
  96. content.Role = "user"
  97. shouldAddDummyModelMessage = true
  98. }
  99. geminiRequest.Contents = append(geminiRequest.Contents, content)
  100. // If a system message is the last message, we need to add a dummy model message to make gemini happy
  101. if shouldAddDummyModelMessage {
  102. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  103. Role: "model",
  104. Parts: []GeminiPart{
  105. {
  106. Text: "Okay",
  107. },
  108. },
  109. })
  110. shouldAddDummyModelMessage = false
  111. }
  112. }
  113. return &geminiRequest
  114. }
// GeminiChatResponse is the top-level body returned by generateContent.
// Candidates may be empty when the prompt itself was blocked; see
// PromptFeedback for the safety ratings in that case.
type GeminiChatResponse struct {
	Candidates     []GeminiChatCandidate    `json:"candidates"`
	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
  119. func (g *GeminiChatResponse) GetResponseText() string {
  120. if g == nil {
  121. return ""
  122. }
  123. if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  124. return g.Candidates[0].Content.Parts[0].Text
  125. }
  126. return ""
  127. }
// GeminiChatCandidate is one generated completion with its finish reason
// and per-category safety ratings.
type GeminiChatCandidate struct {
	Content       GeminiChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int64                    `json:"index"`
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

// GeminiChatSafetyRating reports the model's harm probability for one category.
type GeminiChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

// GeminiChatPromptFeedback carries safety ratings for the prompt itself,
// populated when the input (rather than the output) triggers filters.
type GeminiChatPromptFeedback struct {
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
  141. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
  142. fullTextResponse := OpenAITextResponse{
  143. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  144. Object: "chat.completion",
  145. Created: common.GetTimestamp(),
  146. Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
  147. }
  148. content, _ := json.Marshal("")
  149. for i, candidate := range response.Candidates {
  150. choice := OpenAITextResponseChoice{
  151. Index: i,
  152. Message: Message{
  153. Role: "assistant",
  154. Content: content,
  155. },
  156. FinishReason: stopFinishReason,
  157. }
  158. content, _ = json.Marshal(candidate.Content.Parts[0].Text)
  159. if len(candidate.Content.Parts) > 0 {
  160. choice.Message.Content = content
  161. }
  162. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  163. }
  164. return &fullTextResponse
  165. }
  166. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
  167. var choice ChatCompletionsStreamResponseChoice
  168. choice.Delta.Content = geminiResponse.GetResponseText()
  169. choice.FinishReason = &stopFinishReason
  170. var response ChatCompletionsStreamResponse
  171. response.Object = "chat.completion.chunk"
  172. response.Model = "gemini"
  173. response.Choices = []ChatCompletionsStreamResponseChoice{choice}
  174. return &response
  175. }
  176. func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
  177. responseText := ""
  178. dataChan := make(chan string)
  179. stopChan := make(chan bool)
  180. scanner := bufio.NewScanner(resp.Body)
  181. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  182. if atEOF && len(data) == 0 {
  183. return 0, nil, nil
  184. }
  185. if i := strings.Index(string(data), "\n"); i >= 0 {
  186. return i + 1, data[0:i], nil
  187. }
  188. if atEOF {
  189. return len(data), data, nil
  190. }
  191. return 0, nil, nil
  192. })
  193. go func() {
  194. for scanner.Scan() {
  195. data := scanner.Text()
  196. data = strings.TrimSpace(data)
  197. if !strings.HasPrefix(data, "\"text\": \"") {
  198. continue
  199. }
  200. data = strings.TrimPrefix(data, "\"text\": \"")
  201. data = strings.TrimSuffix(data, "\"")
  202. dataChan <- data
  203. }
  204. stopChan <- true
  205. }()
  206. setEventStreamHeaders(c)
  207. c.Stream(func(w io.Writer) bool {
  208. select {
  209. case data := <-dataChan:
  210. // this is used to prevent annoying \ related format bug
  211. data = fmt.Sprintf("{\"content\": \"%s\"}", data)
  212. type dummyStruct struct {
  213. Content string `json:"content"`
  214. }
  215. var dummy dummyStruct
  216. err := json.Unmarshal([]byte(data), &dummy)
  217. responseText += dummy.Content
  218. var choice ChatCompletionsStreamResponseChoice
  219. choice.Delta.Content = dummy.Content
  220. response := ChatCompletionsStreamResponse{
  221. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  222. Object: "chat.completion.chunk",
  223. Created: common.GetTimestamp(),
  224. Model: "gemini-pro",
  225. Choices: []ChatCompletionsStreamResponseChoice{choice},
  226. }
  227. jsonResponse, err := json.Marshal(response)
  228. if err != nil {
  229. common.SysError("error marshalling stream response: " + err.Error())
  230. return true
  231. }
  232. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  233. return true
  234. case <-stopChan:
  235. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  236. return false
  237. }
  238. })
  239. err := resp.Body.Close()
  240. if err != nil {
  241. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  242. }
  243. return nil, responseText
  244. }
  245. func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
  246. responseBody, err := io.ReadAll(resp.Body)
  247. if err != nil {
  248. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  249. }
  250. err = resp.Body.Close()
  251. if err != nil {
  252. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  253. }
  254. var geminiResponse GeminiChatResponse
  255. err = json.Unmarshal(responseBody, &geminiResponse)
  256. if err != nil {
  257. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  258. }
  259. if len(geminiResponse.Candidates) == 0 {
  260. return &OpenAIErrorWithStatusCode{
  261. OpenAIError: OpenAIError{
  262. Message: "No candidates returned",
  263. Type: "server_error",
  264. Param: "",
  265. Code: 500,
  266. },
  267. StatusCode: resp.StatusCode,
  268. }, nil
  269. }
  270. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  271. completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
  272. usage := Usage{
  273. PromptTokens: promptTokens,
  274. CompletionTokens: completionTokens,
  275. TotalTokens: promptTokens + completionTokens,
  276. }
  277. fullTextResponse.Usage = usage
  278. jsonResponse, err := json.Marshal(fullTextResponse)
  279. if err != nil {
  280. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  281. }
  282. c.Writer.Header().Set("Content-Type", "application/json")
  283. c.Writer.WriteHeader(resp.StatusCode)
  284. _, err = c.Writer.Write(jsonResponse)
  285. return nil, &usage
  286. }