// relay-gemini.go
  1. package gemini
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net/http"
  9. "one-api/common"
  10. "one-api/constant"
  11. "one-api/dto"
  12. relaycommon "one-api/relay/common"
  13. "one-api/service"
  14. "strings"
  15. "time"
  16. "github.com/gin-gonic/gin"
  17. )
  18. // Setting safety to the lowest possible values since Gemini is already powerless enough
  19. func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
  20. geminiRequest := GeminiChatRequest{
  21. Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
  22. SafetySettings: []GeminiChatSafetySettings{
  23. {
  24. Category: "HARM_CATEGORY_HARASSMENT",
  25. Threshold: common.GeminiSafetySetting,
  26. },
  27. {
  28. Category: "HARM_CATEGORY_HATE_SPEECH",
  29. Threshold: common.GeminiSafetySetting,
  30. },
  31. {
  32. Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
  33. Threshold: common.GeminiSafetySetting,
  34. },
  35. {
  36. Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
  37. Threshold: common.GeminiSafetySetting,
  38. },
  39. },
  40. GenerationConfig: GeminiChatGenerationConfig{
  41. Temperature: textRequest.Temperature,
  42. TopP: textRequest.TopP,
  43. MaxOutputTokens: textRequest.MaxTokens,
  44. },
  45. }
  46. if textRequest.Functions != nil {
  47. geminiRequest.Tools = []GeminiChatTools{
  48. {
  49. FunctionDeclarations: textRequest.Functions,
  50. },
  51. }
  52. }
  53. shouldAddDummyModelMessage := false
  54. for _, message := range textRequest.Messages {
  55. content := GeminiChatContent{
  56. Role: message.Role,
  57. Parts: []GeminiPart{
  58. {
  59. Text: message.StringContent(),
  60. },
  61. },
  62. }
  63. openaiContent := message.ParseContent()
  64. var parts []GeminiPart
  65. imageNum := 0
  66. for _, part := range openaiContent {
  67. if part.Type == dto.ContentTypeText {
  68. parts = append(parts, GeminiPart{
  69. Text: part.Text,
  70. })
  71. } else if part.Type == dto.ContentTypeImageURL {
  72. imageNum += 1
  73. if imageNum > GeminiVisionMaxImageNum {
  74. continue
  75. }
  76. mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
  77. parts = append(parts, GeminiPart{
  78. InlineData: &GeminiInlineData{
  79. MimeType: mimeType,
  80. Data: data,
  81. },
  82. })
  83. }
  84. }
  85. content.Parts = parts
  86. // there's no assistant role in gemini and API shall vomit if Role is not user or model
  87. if content.Role == "assistant" {
  88. content.Role = "model"
  89. }
  90. // Converting system prompt to prompt from user for the same reason
  91. if content.Role == "system" {
  92. content.Role = "user"
  93. shouldAddDummyModelMessage = true
  94. }
  95. geminiRequest.Contents = append(geminiRequest.Contents, content)
  96. // If a system message is the last message, we need to add a dummy model message to make gemini happy
  97. if shouldAddDummyModelMessage {
  98. geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
  99. Role: "model",
  100. Parts: []GeminiPart{
  101. {
  102. Text: "Okay",
  103. },
  104. },
  105. })
  106. shouldAddDummyModelMessage = false
  107. }
  108. }
  109. return &geminiRequest
  110. }
  111. func (g *GeminiChatResponse) GetResponseText() string {
  112. if g == nil {
  113. return ""
  114. }
  115. if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
  116. return g.Candidates[0].Content.Parts[0].Text
  117. }
  118. return ""
  119. }
  120. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
  121. fullTextResponse := dto.OpenAITextResponse{
  122. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  123. Object: "chat.completion",
  124. Created: common.GetTimestamp(),
  125. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  126. }
  127. content, _ := json.Marshal("")
  128. for i, candidate := range response.Candidates {
  129. choice := dto.OpenAITextResponseChoice{
  130. Index: i,
  131. Message: dto.Message{
  132. Role: "assistant",
  133. Content: content,
  134. },
  135. FinishReason: relaycommon.StopFinishReason,
  136. }
  137. if len(candidate.Content.Parts) > 0 {
  138. content, _ = json.Marshal(candidate.Content.Parts[0].Text)
  139. choice.Message.Content = content
  140. }
  141. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  142. }
  143. return &fullTextResponse
  144. }
  145. func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
  146. var choice dto.ChatCompletionsStreamResponseChoice
  147. choice.Delta.SetContentString(geminiResponse.GetResponseText())
  148. choice.FinishReason = &relaycommon.StopFinishReason
  149. var response dto.ChatCompletionsStreamResponse
  150. response.Object = "chat.completion.chunk"
  151. response.Model = "gemini"
  152. response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
  153. return &response
  154. }
// geminiChatStreamHandler relays a streaming Gemini response to the client as
// OpenAI-style SSE chunks and reports token usage.
//
// Rather than parsing the upstream body as JSON incrementally, a producer
// goroutine scans it line by line and forwards only lines shaped like
// `"text": "..."` (this assumes the upstream body is a pretty-printed JSON
// array — TODO confirm against the Gemini API response format). The full raw
// body is also accumulated into responseJson so usage metadata can be
// unmarshalled after the stream finishes.
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseText := ""
	responseJson := ""
	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createAt := common.GetTimestamp()
	var usage = &dto.Usage{}
	dataChan := make(chan string, 5)
	stopChan := make(chan bool, 2)
	scanner := bufio.NewScanner(resp.Body)
	// Custom split: tokens are newline-delimited (like bufio.ScanLines but
	// without \r trimming); a trailing partial line is emitted at EOF.
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})
	// Producer: extract "text" fragments from the upstream body and push them
	// to dataChan; always signals stopChan on exit so the consumer terminates.
	go func() {
		for scanner.Scan() {
			data := scanner.Text()
			responseJson += data
			data = strings.TrimSpace(data)
			if !strings.HasPrefix(data, "\"text\": \"") {
				continue
			}
			data = strings.TrimPrefix(data, "\"text\": \"")
			data = strings.TrimSuffix(data, "\"")
			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
				// send data timeout, stop the stream
				common.LogError(c, "send data timeout, stop the stream")
				break
			}
		}
		// NOTE(review): scanner.Err() is never checked, so upstream read
		// errors are silently treated as a normal end of stream.
		stopChan <- true
	}()
	isFirst := true
	service.SetEventStreamHeaders(c)
	// Consumer: re-encode each fragment as an OpenAI chat.completion.chunk
	// SSE event until the producer signals stopChan.
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			if isFirst {
				isFirst = false
				info.FirstResponseTime = time.Now()
			}
			// this is used to prevent annoying \ related format bug
			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
			type dummyStruct struct {
				Content string `json:"content"`
			}
			var dummy dummyStruct
			// NOTE(review): this Unmarshal error is discarded; a malformed
			// fragment yields an empty dummy.Content instead of failing.
			err := json.Unmarshal([]byte(data), &dummy)
			responseText += dummy.Content
			var choice dto.ChatCompletionsStreamResponseChoice
			choice.Delta.SetContentString(dummy.Content)
			response := dto.ChatCompletionsStreamResponse{
				Id:      id,
				Object:  "chat.completion.chunk",
				Created: createAt,
				Model:   info.UpstreamModelName,
				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
			}
			jsonResponse, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
			return true
		case <-stopChan:
			return false
		}
	})
	// Recover usage from the accumulated raw body; fall back to counting
	// tokens from the concatenated response text if parsing fails.
	var geminiChatResponses []GeminiChatResponse
	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
	if err != nil {
		log.Printf("cannot get gemini usage: %s", err.Error())
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
	} else {
		// Later chunks overwrite earlier ones, so the final chunk's counts win.
		for _, response := range geminiChatResponses {
			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
		}
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	}
	if info.ShouldIncludeUsage {
		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
		err := service.ObjectData(c, response)
		if err != nil {
			common.SysError("send final response failed: " + err.Error())
		}
	}
	service.Done(c)
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
	}
	return nil, usage
}
  257. func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  258. responseBody, err := io.ReadAll(resp.Body)
  259. if err != nil {
  260. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  261. }
  262. err = resp.Body.Close()
  263. if err != nil {
  264. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  265. }
  266. var geminiResponse GeminiChatResponse
  267. err = json.Unmarshal(responseBody, &geminiResponse)
  268. if err != nil {
  269. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  270. }
  271. if len(geminiResponse.Candidates) == 0 {
  272. return &dto.OpenAIErrorWithStatusCode{
  273. Error: dto.OpenAIError{
  274. Message: "No candidates returned",
  275. Type: "server_error",
  276. Param: "",
  277. Code: 500,
  278. },
  279. StatusCode: resp.StatusCode,
  280. }, nil
  281. }
  282. fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
  283. usage := dto.Usage{
  284. PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
  285. CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
  286. TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
  287. }
  288. fullTextResponse.Usage = usage
  289. jsonResponse, err := json.Marshal(fullTextResponse)
  290. if err != nil {
  291. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  292. }
  293. c.Writer.Header().Set("Content-Type", "application/json")
  294. c.Writer.WriteHeader(resp.StatusCode)
  295. _, err = c.Writer.Write(jsonResponse)
  296. return nil, &usage
  297. }