gemini.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. package dto
  2. import (
  3. "encoding/json"
  4. "github.com/gin-gonic/gin"
  5. "one-api/common"
  6. "one-api/logger"
  7. "one-api/types"
  8. "strings"
  9. )
  10. type GeminiChatRequest struct {
  11. Contents []GeminiChatContent `json:"contents"`
  12. SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
  13. GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
  14. Tools json.RawMessage `json:"tools,omitempty"`
  15. ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
  16. SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
  17. CachedContent string `json:"cachedContent,omitempty"`
  18. }
  19. type ToolConfig struct {
  20. FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
  21. RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
  22. }
  23. type FunctionCallingConfig struct {
  24. Mode FunctionCallingConfigMode `json:"mode,omitempty"`
  25. AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
  26. }
  27. type FunctionCallingConfigMode string
  28. type RetrievalConfig struct {
  29. LatLng *LatLng `json:"latLng,omitempty"`
  30. LanguageCode string `json:"languageCode,omitempty"`
  31. }
  32. type LatLng struct {
  33. Latitude *float64 `json:"latitude,omitempty"`
  34. Longitude *float64 `json:"longitude,omitempty"`
  35. }
  36. func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
  37. var files []*types.FileMeta = make([]*types.FileMeta, 0)
  38. var maxTokens int
  39. if r.GenerationConfig.MaxOutputTokens > 0 {
  40. maxTokens = int(r.GenerationConfig.MaxOutputTokens)
  41. }
  42. var inputTexts []string
  43. for _, content := range r.Contents {
  44. for _, part := range content.Parts {
  45. if part.Text != "" {
  46. inputTexts = append(inputTexts, part.Text)
  47. }
  48. if part.InlineData != nil && part.InlineData.Data != "" {
  49. if strings.HasPrefix(part.InlineData.MimeType, "image/") {
  50. files = append(files, &types.FileMeta{
  51. FileType: types.FileTypeImage,
  52. OriginData: part.InlineData.Data,
  53. })
  54. } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
  55. files = append(files, &types.FileMeta{
  56. FileType: types.FileTypeAudio,
  57. OriginData: part.InlineData.Data,
  58. })
  59. } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
  60. files = append(files, &types.FileMeta{
  61. FileType: types.FileTypeVideo,
  62. OriginData: part.InlineData.Data,
  63. })
  64. } else {
  65. files = append(files, &types.FileMeta{
  66. FileType: types.FileTypeFile,
  67. OriginData: part.InlineData.Data,
  68. })
  69. }
  70. }
  71. }
  72. }
  73. inputText := strings.Join(inputTexts, "\n")
  74. return &types.TokenCountMeta{
  75. CombineText: inputText,
  76. Files: files,
  77. MaxTokens: maxTokens,
  78. }
  79. }
  80. func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
  81. if c.Query("alt") == "sse" {
  82. return true
  83. }
  84. return false
  85. }
  86. func (r *GeminiChatRequest) SetModelName(modelName string) {
  87. // GeminiChatRequest does not have a model field, so this method does nothing.
  88. }
  89. func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
  90. var tools []GeminiChatTool
  91. if strings.HasSuffix(string(r.Tools), "[") {
  92. // is array
  93. if err := common.Unmarshal(r.Tools, &tools); err != nil {
  94. logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
  95. return nil
  96. }
  97. } else if strings.HasPrefix(string(r.Tools), "{") {
  98. // is object
  99. singleTool := GeminiChatTool{}
  100. if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
  101. logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
  102. return nil
  103. }
  104. tools = []GeminiChatTool{singleTool}
  105. }
  106. return tools
  107. }
  108. func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
  109. if len(tools) == 0 {
  110. r.Tools = json.RawMessage("[]")
  111. return
  112. }
  113. // Marshal the tools to JSON
  114. data, err := common.Marshal(tools)
  115. if err != nil {
  116. logger.LogError(nil, "error_marshalling_tools: "+err.Error())
  117. return
  118. }
  119. r.Tools = data
  120. }
  121. type GeminiThinkingConfig struct {
  122. IncludeThoughts bool `json:"includeThoughts,omitempty"`
  123. ThinkingBudget *int `json:"thinkingBudget,omitempty"`
  124. }
  125. func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
  126. c.ThinkingBudget = &budget
  127. }
  128. type GeminiInlineData struct {
  129. MimeType string `json:"mimeType"`
  130. Data string `json:"data"`
  131. }
  132. // UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
  133. func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
  134. type Alias GeminiInlineData // Use type alias to avoid recursion
  135. var aux struct {
  136. Alias
  137. MimeTypeSnake string `json:"mime_type"`
  138. }
  139. if err := common.Unmarshal(data, &aux); err != nil {
  140. return err
  141. }
  142. *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
  143. // Prioritize snake_case if present
  144. if aux.MimeTypeSnake != "" {
  145. g.MimeType = aux.MimeTypeSnake
  146. } else if aux.MimeType != "" { // Fallback to camelCase from Alias
  147. g.MimeType = aux.MimeType
  148. }
  149. // g.Data would be populated by aux.Alias.Data
  150. return nil
  151. }
  152. type FunctionCall struct {
  153. FunctionName string `json:"name"`
  154. Arguments any `json:"args"`
  155. }
  156. type GeminiFunctionResponse struct {
  157. Name string `json:"name"`
  158. Response map[string]interface{} `json:"response"`
  159. }
  160. type GeminiPartExecutableCode struct {
  161. Language string `json:"language,omitempty"`
  162. Code string `json:"code,omitempty"`
  163. }
  164. type GeminiPartCodeExecutionResult struct {
  165. Outcome string `json:"outcome,omitempty"`
  166. Output string `json:"output,omitempty"`
  167. }
  168. type GeminiFileData struct {
  169. MimeType string `json:"mimeType,omitempty"`
  170. FileUri string `json:"fileUri,omitempty"`
  171. }
  172. type GeminiPart struct {
  173. Text string `json:"text,omitempty"`
  174. Thought bool `json:"thought,omitempty"`
  175. InlineData *GeminiInlineData `json:"inlineData,omitempty"`
  176. FunctionCall *FunctionCall `json:"functionCall,omitempty"`
  177. FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
  178. FileData *GeminiFileData `json:"fileData,omitempty"`
  179. ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
  180. CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
  181. }
  182. // UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
  183. func (p *GeminiPart) UnmarshalJSON(data []byte) error {
  184. // Alias to avoid recursion during unmarshalling
  185. type Alias GeminiPart
  186. var aux struct {
  187. Alias
  188. InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
  189. }
  190. if err := common.Unmarshal(data, &aux); err != nil {
  191. return err
  192. }
  193. // Assign fields from alias
  194. *p = GeminiPart(aux.Alias)
  195. // Prioritize snake_case for InlineData if present
  196. if aux.InlineDataSnake != nil {
  197. p.InlineData = aux.InlineDataSnake
  198. } else if aux.InlineData != nil { // Fallback to camelCase from Alias
  199. p.InlineData = aux.InlineData
  200. }
  201. // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
  202. return nil
  203. }
  204. type GeminiChatContent struct {
  205. Role string `json:"role,omitempty"`
  206. Parts []GeminiPart `json:"parts"`
  207. }
  208. type GeminiChatSafetySettings struct {
  209. Category string `json:"category"`
  210. Threshold string `json:"threshold"`
  211. }
  212. type GeminiChatTool struct {
  213. GoogleSearch any `json:"googleSearch,omitempty"`
  214. GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
  215. CodeExecution any `json:"codeExecution,omitempty"`
  216. FunctionDeclarations any `json:"functionDeclarations,omitempty"`
  217. }
  218. type GeminiChatGenerationConfig struct {
  219. Temperature *float64 `json:"temperature,omitempty"`
  220. TopP float64 `json:"topP,omitempty"`
  221. TopK float64 `json:"topK,omitempty"`
  222. MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
  223. CandidateCount int `json:"candidateCount,omitempty"`
  224. StopSequences []string `json:"stopSequences,omitempty"`
  225. ResponseMimeType string `json:"responseMimeType,omitempty"`
  226. ResponseSchema any `json:"responseSchema,omitempty"`
  227. ResponseJsonSchema any `json:"responseJsonSchema,omitempty"`
  228. PresencePenalty *float32 `json:"presencePenalty,omitempty"`
  229. FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
  230. ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
  231. Logprobs *int32 `json:"logprobs,omitempty"`
  232. MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
  233. Seed int64 `json:"seed,omitempty"`
  234. ResponseModalities []string `json:"responseModalities,omitempty"`
  235. ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
  236. SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
  237. }
  238. type MediaResolution string
  239. type GeminiChatCandidate struct {
  240. Content GeminiChatContent `json:"content"`
  241. FinishReason *string `json:"finishReason"`
  242. Index int64 `json:"index"`
  243. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  244. }
  245. type GeminiChatSafetyRating struct {
  246. Category string `json:"category"`
  247. Probability string `json:"probability"`
  248. }
  249. type GeminiChatPromptFeedback struct {
  250. SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
  251. }
  252. type GeminiChatResponse struct {
  253. Candidates []GeminiChatCandidate `json:"candidates"`
  254. PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
  255. UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
  256. }
  257. type GeminiUsageMetadata struct {
  258. PromptTokenCount int `json:"promptTokenCount"`
  259. CandidatesTokenCount int `json:"candidatesTokenCount"`
  260. TotalTokenCount int `json:"totalTokenCount"`
  261. ThoughtsTokenCount int `json:"thoughtsTokenCount"`
  262. PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
  263. }
  264. type GeminiPromptTokensDetails struct {
  265. Modality string `json:"modality"`
  266. TokenCount int `json:"tokenCount"`
  267. }
  268. // Imagen related structs
  269. type GeminiImageRequest struct {
  270. Instances []GeminiImageInstance `json:"instances"`
  271. Parameters GeminiImageParameters `json:"parameters"`
  272. }
  273. type GeminiImageInstance struct {
  274. Prompt string `json:"prompt"`
  275. }
  276. type GeminiImageParameters struct {
  277. SampleCount int `json:"sampleCount,omitempty"`
  278. AspectRatio string `json:"aspectRatio,omitempty"`
  279. PersonGeneration string `json:"personGeneration,omitempty"`
  280. }
  281. type GeminiImageResponse struct {
  282. Predictions []GeminiImagePrediction `json:"predictions"`
  283. }
  284. type GeminiImagePrediction struct {
  285. MimeType string `json:"mimeType"`
  286. BytesBase64Encoded string `json:"bytesBase64Encoded"`
  287. RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
  288. SafetyAttributes any `json:"safetyAttributes,omitempty"`
  289. }
  290. // Embedding related structs
  291. type GeminiEmbeddingRequest struct {
  292. Model string `json:"model,omitempty"`
  293. Content GeminiChatContent `json:"content"`
  294. TaskType string `json:"taskType,omitempty"`
  295. Title string `json:"title,omitempty"`
  296. OutputDimensionality int `json:"outputDimensionality,omitempty"`
  297. }
  298. func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
  299. // Gemini embedding requests are not streamed
  300. return false
  301. }
  302. func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  303. var inputTexts []string
  304. for _, part := range r.Content.Parts {
  305. if part.Text != "" {
  306. inputTexts = append(inputTexts, part.Text)
  307. }
  308. }
  309. inputText := strings.Join(inputTexts, "\n")
  310. return &types.TokenCountMeta{
  311. CombineText: inputText,
  312. }
  313. }
  314. func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
  315. if modelName != "" {
  316. r.Model = modelName
  317. }
  318. }
  319. type GeminiBatchEmbeddingRequest struct {
  320. Requests []*GeminiEmbeddingRequest `json:"requests"`
  321. }
  322. func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
  323. // Gemini batch embedding requests are not streamed
  324. return false
  325. }
  326. func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
  327. var inputTexts []string
  328. for _, request := range r.Requests {
  329. meta := request.GetTokenCountMeta()
  330. if meta != nil && meta.CombineText != "" {
  331. inputTexts = append(inputTexts, meta.CombineText)
  332. }
  333. }
  334. inputText := strings.Join(inputTexts, "\n")
  335. return &types.TokenCountMeta{
  336. CombineText: inputText,
  337. }
  338. }
  339. func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
  340. if modelName != "" {
  341. for _, req := range r.Requests {
  342. req.SetModelName(modelName)
  343. }
  344. }
  345. }
  346. type GeminiEmbeddingResponse struct {
  347. Embedding ContentEmbedding `json:"embedding"`
  348. }
  349. type GeminiBatchEmbeddingResponse struct {
  350. Embeddings []*ContentEmbedding `json:"embeddings"`
  351. }
  352. type ContentEmbedding struct {
  353. Values []float64 `json:"values"`
  354. }