Browse Source

Merge remote-tracking branch 'origin/feat/o1' into feat/o1

HynoR 1 year ago
parent
commit
eac3463401

+ 1 - 0
README.en.md

@@ -81,6 +81,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
 - `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
 - `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
 - `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
+- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
 
 ## Deployment
 > [!TIP]

+ 1 - 0
README.md

@@ -87,6 +87,7 @@
 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
 - `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
 - `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
+- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
 ## 部署
 > [!TIP]
 > 最新版Docker镜像:`calciumion/new-api:latest`  

+ 2 - 0
constant/env.go

@@ -23,6 +23,8 @@ var GeminiModelMap = map[string]string{
 	"gemini-1.0-pro": "v1",
 }
 
+var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+
 func InitEnv() {
 	modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
 	if modelVersionMapStr == "" {

+ 5 - 1
relay/channel/gemini/adaptor.go

@@ -57,7 +57,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return CovertGemini2OpenAI(*request), nil
+	ai, err := CovertGemini2OpenAI(*request)
+	if err != nil {
+		return nil, err
+	}
+	return ai, nil
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

+ 0 - 4
relay/channel/gemini/constant.go

@@ -1,9 +1,5 @@
 package gemini
 
-const (
-	GeminiVisionMaxImageNum = 16
-)
-
 var ModelList = []string{
 	// stable version
 	"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",

+ 14 - 9
relay/channel/gemini/relay-gemini.go

@@ -17,7 +17,7 @@ import (
 )
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
 		SafetySettings: []GeminiChatSafetySettings{
@@ -108,9 +108,10 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 				})
 			} else if part.Type == dto.ContentTypeImageURL {
 				imageNum += 1
-				//if imageNum > GeminiVisionMaxImageNum {
-				//	continue
-				//}
+
+				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+				}
 				// 判断是否是url
 				if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
 					// 是url,获取图片的类型和base64编码的数据
@@ -124,7 +125,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 				} else {
 					_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
 					if err != nil {
-						continue
+						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
 					}
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
@@ -161,7 +162,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 		//	shouldAddDummyModelMessage = false
 		//}
 	}
-	return &geminiRequest
+	return &geminiRequest, nil
 }
 
 func (g *GeminiChatResponse) GetResponseText() string {
@@ -236,13 +237,17 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	var choice dto.ChatCompletionsStreamResponseChoice
 	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
 	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
-		respFirst := geminiResponse.Candidates[0].Content.Parts[0]
-		if respFirst.FunctionCall != nil {
+		respFirstParts := geminiResponse.Candidates[0].Content.Parts
+		if respFirstParts[0].FunctionCall != nil {
 			// function response
 			choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
 		} else {
 			// text response
-			choice.Delta.SetContentString(respFirst.Text)
+			var texts []string
+			for _, part := range respFirstParts {
+				texts = append(texts, part.Text)
+			}
+			choice.Delta.SetContentString(strings.Join(texts, "\n"))
 		}
 	}
 	var response dto.ChatCompletionsStreamResponse

+ 4 - 1
relay/channel/vertex/adaptor.go

@@ -135,7 +135,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 		c.Set("request_model", request.Model)
 		return vertexClaudeReq, nil
 	} else if a.RequestMode == RequestModeGemini {
-		geminiRequest := gemini.CovertGemini2OpenAI(*request)
+		geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
+		if err != nil {
+			return nil, err
+		}
 		c.Set("request_model", request.Model)
 		return geminiRequest, nil
 	} else if a.RequestMode == RequestModeLlama {

+ 1 - 6
web/src/pages/Setting/Operation/ModelSettingsVisualEditor.js

@@ -75,7 +75,7 @@ export default function ModelSettingsVisualEditor(props) {
           output.ModelPrice[model.name] = parseFloat(model.price)
         } else {
           if (model.ratio !== '') output.ModelRatio[model.name] = parseFloat(model.ratio);
-          if (model.completionRatio != '') output.CompletionRatio[model.name] = parseFloat(model.completionRatio);
+          if (model.completionRatio !== '') output.CompletionRatio[model.name] = parseFloat(model.completionRatio);
         }
       });
 
 
@@ -203,11 +203,6 @@ export default function ModelSettingsVisualEditor(props) {
       showError('模型名称已存在');
       return;
     }
-    // 不允许同时添加固定价格和倍率
-    if (values.price !== '' && (values.ratio !== '' || values.completionRatio !== '')) {
-      showError('固定价格和倍率不能同时存在');
-      return;
-    }
     setModels(prev => [{
       name: values.name,
       price: values.price || '',