| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631 |
- package gemini
- import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting/model_setting"
- "strings"
- "unicode/utf8"
- "github.com/gin-gonic/gin"
- )
- // Setting safety to the lowest possible values since Gemini is already powerless enough
- func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
- geminiRequest := GeminiChatRequest{
- Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
- //SafetySettings: []GeminiChatSafetySettings{},
- GenerationConfig: GeminiChatGenerationConfig{
- Temperature: textRequest.Temperature,
- TopP: textRequest.TopP,
- MaxOutputTokens: textRequest.MaxTokens,
- Seed: int64(textRequest.Seed),
- },
- }
- safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
- for _, category := range SafetySettingList {
- safetySettings = append(safetySettings, GeminiChatSafetySettings{
- Category: category,
- Threshold: model_setting.GetGeminiSafetySetting(category),
- })
- }
- geminiRequest.SafetySettings = safetySettings
- // openaiContent.FuncToToolCalls()
- if textRequest.Tools != nil {
- functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
- googleSearch := false
- codeExecution := false
- for _, tool := range textRequest.Tools {
- if tool.Function.Name == "googleSearch" {
- googleSearch = true
- continue
- }
- if tool.Function.Name == "codeExecution" {
- codeExecution = true
- continue
- }
- if tool.Function.Parameters != nil {
- params, ok := tool.Function.Parameters.(map[string]interface{})
- if ok {
- if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
- if len(props) == 0 {
- tool.Function.Parameters = nil
- }
- }
- }
- }
- functions = append(functions, tool.Function)
- }
- if codeExecution {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
- CodeExecution: make(map[string]string),
- })
- }
- if googleSearch {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
- GoogleSearch: make(map[string]string),
- })
- }
- if len(functions) > 0 {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
- FunctionDeclarations: functions,
- })
- }
- // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
- // json_data, _ := json.Marshal(geminiRequest.Tools)
- // common.SysLog("tools_json: " + string(json_data))
- } else if textRequest.Functions != nil {
- geminiRequest.Tools = []GeminiChatTool{
- {
- FunctionDeclarations: textRequest.Functions,
- },
- }
- }
- if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
- geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
- if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
- cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
- geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
- }
- }
- tool_call_ids := make(map[string]string)
- var system_content []string
- //shouldAddDummyModelMessage := false
- for _, message := range textRequest.Messages {
- if message.Role == "system" {
- system_content = append(system_content, message.StringContent())
- continue
- } else if message.Role == "tool" || message.Role == "function" {
- if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
- geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
- Role: "user",
- })
- }
- var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
- name := ""
- if message.Name != nil {
- name = *message.Name
- } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
- name = val
- }
- content := common.StrToMap(message.StringContent())
- functionResp := &FunctionResponse{
- Name: name,
- Response: GeminiFunctionResponseContent{
- Name: name,
- Content: content,
- },
- }
- if content == nil {
- functionResp.Response.Content = message.StringContent()
- }
- *parts = append(*parts, GeminiPart{
- FunctionResponse: functionResp,
- })
- continue
- }
- var parts []GeminiPart
- content := GeminiChatContent{
- Role: message.Role,
- }
- // isToolCall := false
- if message.ToolCalls != nil {
- // message.Role = "model"
- // isToolCall = true
- for _, call := range message.ParseToolCalls() {
- args := map[string]interface{}{}
- if call.Function.Arguments != "" {
- if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
- return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
- }
- }
- toolCall := GeminiPart{
- FunctionCall: &FunctionCall{
- FunctionName: call.Function.Name,
- Arguments: args,
- },
- }
- parts = append(parts, toolCall)
- tool_call_ids[call.ID] = call.Function.Name
- }
- }
- openaiContent := message.ParseContent()
- imageNum := 0
- for _, part := range openaiContent {
- if part.Type == dto.ContentTypeText {
- if part.Text == "" {
- continue
- }
- parts = append(parts, GeminiPart{
- Text: part.Text,
- })
- } else if part.Type == dto.ContentTypeImageURL {
- imageNum += 1
- if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
- return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
- }
- // 判断是否是url
- if strings.HasPrefix(part.GetImageMedia().Url, "http") {
- // 是url,获取图片的类型和base64编码的数据
- fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
- if err != nil {
- return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
- }
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
- MimeType: fileData.MimeType,
- Data: fileData.Base64Data,
- },
- })
- } else {
- format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
- if err != nil {
- return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
- }
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
- MimeType: format,
- Data: base64String,
- },
- })
- }
- }
- }
- content.Parts = parts
- // there's no assistant role in gemini and API shall vomit if Role is not user or model
- if content.Role == "assistant" {
- content.Role = "model"
- }
- geminiRequest.Contents = append(geminiRequest.Contents, content)
- }
- if len(system_content) > 0 {
- geminiRequest.SystemInstructions = &GeminiChatContent{
- Parts: []GeminiPart{
- {
- Text: strings.Join(system_content, "\n"),
- },
- },
- }
- }
- return &geminiRequest, nil
- }
// removeAdditionalPropertiesWithDepth strips keys from a decoded JSON schema,
// mutating the given maps in place and returning the same value.
//
// For every schema map it visits it deletes "title". For maps whose "type" is
// "object" it also deletes "additionalProperties" and recurses into each entry
// of "properties" and each element of "allOf"/"anyOf"/"oneOf". For maps whose
// "type" is "array" it recurses into "items" when that is a map. Maps with any
// other (or no) "type" are returned right after the "title" removal.
//
// Non-map values and empty maps are returned untouched, and recursion stops
// once depth reaches 5.
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
	const maxDepth = 5
	if depth >= maxDepth {
		return schema
	}

	node, isMap := schema.(map[string]interface{})
	if !isMap || len(node) == 0 {
		return schema
	}

	// "title" is removed from every visited schema node.
	delete(node, "title")

	next := depth + 1
	switch node["type"] {
	case "object":
		delete(node, "additionalProperties")

		// Clean each declared property schema.
		if props, found := node["properties"].(map[string]interface{}); found {
			for name := range props {
				props[name] = removeAdditionalPropertiesWithDepth(props[name], next)
			}
		}

		// Clean every alternative listed under the composition keywords.
		for _, keyword := range [...]string{"allOf", "anyOf", "oneOf"} {
			variants, found := node[keyword].([]interface{})
			if !found {
				continue
			}
			for idx := range variants {
				variants[idx] = removeAdditionalPropertiesWithDepth(variants[idx], next)
			}
		}
	case "array":
		if elem, found := node["items"].(map[string]interface{}); found {
			node["items"] = removeAdditionalPropertiesWithDepth(elem, next)
		}
	}
	// Any other type (or a missing type) needs no further processing.
	return node
}
// unescapeString resolves backslash escape sequences in s.
//
// The recognised sequences are \" \\ \/ \b \f \n \r \t and \'. Any other
// escaped character is kept verbatim together with its backslash, and a lone
// backslash at the very end of the input is kept as well.
//
// An error is returned only when s contains bytes that are not valid UTF-8;
// a correctly encoded U+FFFD replacement character is ordinary input.
func unescapeString(s string) (string, error) {
	var out strings.Builder
	out.Grow(len(s))

	escaped := false
	for i := 0; i < len(s); {
		r, size := utf8.DecodeRuneInString(s[i:])
		// DecodeRuneInString reports malformed input as (RuneError, 1); a
		// genuine U+FFFD in the input decodes with size 3 and must pass.
		if r == utf8.RuneError && size == 1 {
			return "", fmt.Errorf("invalid UTF-8 encoding")
		}
		i += size

		if !escaped {
			if r == '\\' {
				// Remember the backslash and decide on the next rune.
				escaped = true
			} else {
				out.WriteRune(r)
			}
			continue
		}

		escaped = false
		switch r {
		case '"', '\\', '/', '\'':
			out.WriteRune(r)
		case 'b':
			out.WriteByte('\b')
		case 'f':
			out.WriteByte('\f')
		case 'n':
			out.WriteByte('\n')
		case 'r':
			out.WriteByte('\r')
		case 't':
			out.WriteByte('\t')
		default:
			// Unknown escape: emit it exactly as it appeared.
			out.WriteByte('\\')
			out.WriteRune(r)
		}
	}

	// A dangling backslash has nothing to escape; keep it rather than lose data.
	if escaped {
		out.WriteByte('\\')
	}
	return out.String(), nil
}
- func unescapeMapOrSlice(data interface{}) interface{} {
- switch v := data.(type) {
- case map[string]interface{}:
- for k, val := range v {
- v[k] = unescapeMapOrSlice(val)
- }
- case []interface{}:
- for i, val := range v {
- v[i] = unescapeMapOrSlice(val)
- }
- case string:
- if unescaped, err := unescapeString(v); err != nil {
- return v
- } else {
- return unescaped
- }
- }
- return data
- }
- func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
- var argsBytes []byte
- var err error
- if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
- argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
- } else {
- argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
- }
- if err != nil {
- return nil
- }
- return &dto.ToolCallResponse{
- ID: fmt.Sprintf("call_%s", common.GetUUID()),
- Type: "function",
- Function: dto.FunctionResponse{
- Arguments: string(argsBytes),
- Name: item.FunctionCall.FunctionName,
- },
- }
- }
- func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
- fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
- Object: "chat.completion",
- Created: common.GetTimestamp(),
- Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
- }
- content, _ := json.Marshal("")
- isToolCall := false
- for _, candidate := range response.Candidates {
- choice := dto.OpenAITextResponseChoice{
- Index: int(candidate.Index),
- Message: dto.Message{
- Role: "assistant",
- Content: content,
- },
- FinishReason: constant.FinishReasonStop,
- }
- if len(candidate.Content.Parts) > 0 {
- var texts []string
- var toolCalls []dto.ToolCallResponse
- for _, part := range candidate.Content.Parts {
- if part.FunctionCall != nil {
- choice.FinishReason = constant.FinishReasonToolCalls
- if call := getResponseToolCall(&part); call != nil {
- toolCalls = append(toolCalls, *call)
- }
- } else {
- if part.ExecutableCode != nil {
- texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
- } else if part.CodeExecutionResult != nil {
- texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
- } else {
- // 过滤掉空行
- if part.Text != "\n" {
- texts = append(texts, part.Text)
- }
- }
- }
- }
- if len(toolCalls) > 0 {
- choice.Message.SetToolCalls(toolCalls)
- isToolCall = true
- }
- choice.Message.SetStringContent(strings.Join(texts, "\n"))
- }
- if candidate.FinishReason != nil {
- switch *candidate.FinishReason {
- case "STOP":
- choice.FinishReason = constant.FinishReasonStop
- case "MAX_TOKENS":
- choice.FinishReason = constant.FinishReasonLength
- default:
- choice.FinishReason = constant.FinishReasonContentFilter
- }
- }
- if isToolCall {
- choice.FinishReason = constant.FinishReasonToolCalls
- }
- fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
- }
- return &fullTextResponse
- }
- func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
- choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
- isStop := false
- for _, candidate := range geminiResponse.Candidates {
- if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
- isStop = true
- candidate.FinishReason = nil
- }
- choice := dto.ChatCompletionsStreamResponseChoice{
- Index: int(candidate.Index),
- Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
- Role: "assistant",
- },
- }
- var texts []string
- isTools := false
- if candidate.FinishReason != nil {
- // p := GeminiConvertFinishReason(*candidate.FinishReason)
- switch *candidate.FinishReason {
- case "STOP":
- choice.FinishReason = &constant.FinishReasonStop
- case "MAX_TOKENS":
- choice.FinishReason = &constant.FinishReasonLength
- default:
- choice.FinishReason = &constant.FinishReasonContentFilter
- }
- }
- for _, part := range candidate.Content.Parts {
- if part.FunctionCall != nil {
- isTools = true
- if call := getResponseToolCall(&part); call != nil {
- call.SetIndex(len(choice.Delta.ToolCalls))
- choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
- }
- } else {
- if part.ExecutableCode != nil {
- texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
- } else if part.CodeExecutionResult != nil {
- texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
- } else {
- if part.Text != "\n" {
- texts = append(texts, part.Text)
- }
- }
- }
- }
- choice.Delta.SetContentString(strings.Join(texts, "\n"))
- if isTools {
- choice.FinishReason = &constant.FinishReasonToolCalls
- }
- choices = append(choices, choice)
- }
- var response dto.ChatCompletionsStreamResponse
- response.Object = "chat.completion.chunk"
- response.Choices = choices
- return &response, isStop
- }
- func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- // responseText := ""
- id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
- createAt := common.GetTimestamp()
- var usage = &dto.Usage{}
- helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var geminiResponse GeminiChatResponse
- err := json.Unmarshal([]byte(data), &geminiResponse)
- if err != nil {
- common.LogError(c, "error unmarshalling stream response: "+err.Error())
- return false
- }
- response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
- response.Id = id
- response.Created = createAt
- response.Model = info.UpstreamModelName
- // responseText += response.Choices[0].Delta.GetContentString()
- if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
- usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
- usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
- }
- err = helper.ObjectData(c, response)
- if err != nil {
- common.LogError(c, err.Error())
- }
- if isStop {
- response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
- helper.ObjectData(c, response)
- }
- return true
- })
- var response *dto.ChatCompletionsStreamResponse
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- usage.PromptTokensDetails.TextTokens = usage.PromptTokens
- usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
- if info.ShouldIncludeUsage {
- response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
- err := helper.ObjectData(c, response)
- if err != nil {
- common.SysError("send final response failed: " + err.Error())
- }
- }
- helper.Done(c)
- //resp.Body.Close()
- return nil, usage
- }
- func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- var geminiResponse GeminiChatResponse
- err = json.Unmarshal(responseBody, &geminiResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if len(geminiResponse.Candidates) == 0 {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: "No candidates returned",
- Type: "server_error",
- Param: "",
- Code: 500,
- },
- StatusCode: resp.StatusCode,
- }, nil
- }
- fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
- fullTextResponse.Model = info.UpstreamModelName
- usage := dto.Usage{
- PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
- TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
- }
- fullTextResponse.Usage = usage
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &usage
- }
- func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- responseBody, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
- }
- _ = resp.Body.Close()
- var geminiResponse GeminiEmbeddingResponse
- if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
- // convert to openai format response
- openAIResponse := dto.OpenAIEmbeddingResponse{
- Object: "list",
- Data: []dto.OpenAIEmbeddingResponseItem{
- {
- Object: "embedding",
- Embedding: geminiResponse.Embedding.Values,
- Index: 0,
- },
- },
- Model: info.UpstreamModelName,
- }
- // calculate usage
- // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
- // Google has not yet clarified how embedding models will be billed
- // refer to openai billing method to use input tokens billing
- // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
- usage = &dto.Usage{
- PromptTokens: info.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: info.PromptTokens,
- }
- openAIResponse.Usage = *usage.(*dto.Usage)
- jsonResponse, jsonErr := json.Marshal(openAIResponse)
- if jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
- return usage, nil
- }
|