Explorar el Código

feat: enhance file handling and logging in the application

CaIon hace 11 meses
padre
commit
cca9c0479f
Se han modificado 5 ficheros con 82 adiciones y 9 borrados
  1. 5 5
      common/init.go
  2. 46 0
      dto/openai_request.go
  3. 1 1
      main.go
  4. 28 3
      relay/channel/gemini/relay-gemini.go
  5. 2 0
      service/token_counter.go

+ 5 - 5
common/init.go

@@ -73,25 +73,25 @@ func LoadEnv() {
 	DebugEnabled = os.Getenv("DEBUG") == "true"
 	MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
 	IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
-	
+
 	// Parse requestInterval and set RequestInterval
 	requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 	RequestInterval = time.Duration(requestInterval) * time.Second
-	
+
 	// Initialize variables with GetEnvOrDefault
 	SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
 	BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
 	RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
-	
+
 	// Initialize string variables with GetEnvOrDefaultString
 	GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
 	CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
-	
+
 	// Initialize rate limit variables
 	GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
 	GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 	GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
-	
+
 	GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
 	GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 	GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))

+ 46 - 0
dto/openai_request.go

@@ -111,6 +111,7 @@ type MediaContent struct {
 	Text       string `json:"text,omitempty"`
 	ImageUrl   any    `json:"image_url,omitempty"`
 	InputAudio any    `json:"input_audio,omitempty"`
+	File       any    `json:"file,omitempty"`
 }
 
 func (m *MediaContent) GetImageMedia() *MessageImageUrl {
@@ -120,6 +121,20 @@ func (m *MediaContent) GetImageMedia() *MessageImageUrl {
 	return nil
 }
 
+func (m *MediaContent) GetInputAudio() *MessageInputAudio {
+	if m.InputAudio != nil {
+		return m.InputAudio.(*MessageInputAudio)
+	}
+	return nil
+}
+
+func (m *MediaContent) GetFile() *MessageFile {
+	if m.File != nil {
+		return m.File.(*MessageFile)
+	}
+	return nil
+}
+
 type MessageImageUrl struct {
 	Url      string `json:"url"`
 	Detail   string `json:"detail"`
@@ -135,10 +150,17 @@ type MessageInputAudio struct {
 	Format string `json:"format"`
 }
 
+type MessageFile struct {
+	FileName string `json:"filename,omitempty"`
+	FileData string `json:"file_data,omitempty"`
+	FileId   string `json:"file_id,omitempty"`
+}
+
 const (
 	ContentTypeText       = "text"
 	ContentTypeImageURL   = "image_url"
 	ContentTypeInputAudio = "input_audio"
+	ContentTypeFile       = "file"
 )
 
 func (m *Message) GetPrefix() bool {
@@ -292,6 +314,30 @@ func (m *Message) ParseContent() []MediaContent {
 						})
 					}
 				}
+			case ContentTypeFile:
+				if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+					fileId, ok3 := fileData["file_id"].(string)
+					if ok3 {
+						contentList = append(contentList, MediaContent{
+							Type: ContentTypeFile,
+							File: &MessageFile{
+								FileId: fileId,
+							},
+						})
+					} else {
+						fileName, ok1 := fileData["filename"].(string)
+						fileDataStr, ok2 := fileData["file_data"].(string)
+						if ok1 && ok2 {
+							contentList = append(contentList, MediaContent{
+								Type: ContentTypeFile,
+								File: &MessageFile{
+									FileName: fileName,
+									FileData: fileDataStr,
+								},
+							})
+						}
+					}
+				}
 			}
 		}
 	}

+ 1 - 1
main.go

@@ -34,7 +34,7 @@ var indexPage []byte
 func main() {
 	err := godotenv.Load(".env")
 	if err != nil {
-		common.SysLog("Support for .env file is disabled")
+		common.SysLog("Support for .env file is disabled: " + err.Error())
 	}
 
 	common.LoadEnv()

+ 28 - 3
relay/channel/gemini/relay-gemini.go

@@ -208,6 +208,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 						},
 					})
 				}
+			} else if part.Type == dto.ContentTypeFile {
+				if part.GetFile().FileId != "" {
+					return nil, fmt.Errorf("only base64 file is supported in gemini")
+				}
+				format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
+				if err != nil {
+					return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
+				}
+				parts = append(parts, GeminiPart{
+					InlineData: &GeminiInlineData{
+						MimeType: format,
+						Data:     base64String,
+					},
+				})
+			} else if part.Type == dto.ContentTypeInputAudio {
+				if part.GetInputAudio().Data == "" {
+					return nil, fmt.Errorf("only base64 audio is supported in gemini")
+				}
+				format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
+				if err != nil {
+					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
+				}
+				parts = append(parts, GeminiPart{
+					InlineData: &GeminiInlineData{
+						MimeType: format,
+						Data:     base64String,
+					},
+				})
 			}
 		}
 
@@ -233,7 +261,6 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 	return &geminiRequest, nil
 }
 
-
 // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
 func cleanFunctionParameters(params interface{}) interface{} {
 	if params == nil {
@@ -307,7 +334,6 @@ func cleanFunctionParameters(params interface{}) interface{} {
 		cleanedMap["items"] = cleanedItemsArray
 	}
 
-
 	// Recursively clean other schema composition keywords if necessary
 	for _, field := range []string{"allOf", "anyOf", "oneOf"} {
 		if nested, ok := cleanedMap[field].([]interface{}); ok {
@@ -322,7 +348,6 @@ func cleanFunctionParameters(params interface{}) interface{} {
 	return cleanedMap
 }
 
-
 func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
 	if depth >= 5 {
 		return schema

+ 2 - 0
service/token_counter.go

@@ -398,6 +398,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
 				} else if m.Type == dto.ContentTypeInputAudio {
 					// TODO: 音频token数量计算
 					tokenNum += 100
+				} else if m.Type == dto.ContentTypeFile {
+					tokenNum += 5000
 				} else {
 					tokenNum += getTokenNum(tokenEncoder, m.Text)
 				}