Explorar o código

fix(adaptor): optimize multipart form handling and resource management

CaIon hai 6 meses
pai
achega
e77effaf8b
Modificáronse 1 ficheiros con 28 adicións e 23 borrados
  1. 28 23
      relay/channel/openai/adaptor.go

+ 28 - 23
relay/channel/openai/adaptor.go

@@ -359,40 +359,42 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		writer := multipart.NewWriter(&requestBody)
 
 		writer.WriteField("model", request.Model)
-		// 获取所有表单字段
-		formData := c.Request.PostForm
-		// 遍历表单字段并打印输出
-		for key, values := range formData {
-			if key == "model" {
-				continue
-			}
-			for _, value := range values {
-				writer.WriteField(key, value)
+		// 使用已解析的 multipart 表单,避免重复解析
+		mf := c.Request.MultipartForm
+		if mf == nil {
+			if _, err := c.MultipartForm(); err != nil {
+				return nil, errors.New("failed to parse multipart form")
 			}
+			mf = c.Request.MultipartForm
 		}
 
-		// Parse the multipart form to handle both single image and multiple images
-		if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
-			return nil, errors.New("failed to parse multipart form")
+		// 写入所有非文件字段
+		if mf != nil {
+			for key, values := range mf.Value {
+				if key == "model" {
+					continue
+				}
+				for _, value := range values {
+					writer.WriteField(key, value)
+				}
+			}
 		}
 
-		if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+		if mf != nil && mf.File != nil {
 			// Check if "image" field exists in any form, including array notation
 			var imageFiles []*multipart.FileHeader
 			var exists bool
 
 			// First check for standard "image" field
-			if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+			if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
 				// If not found, check for "image[]" field
-				if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+				if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
 					// If still not found, iterate through all fields to find any that start with "image["
 					foundArrayImages := false
-					for fieldName, files := range c.Request.MultipartForm.File {
+					for fieldName, files := range mf.File {
 						if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
 							foundArrayImages = true
-							for _, file := range files {
-								imageFiles = append(imageFiles, file)
-							}
+							imageFiles = append(imageFiles, files...)
 						}
 					}
 
@@ -409,7 +411,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 				if err != nil {
 					return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
 				}
-				defer file.Close()
 
 				// If multiple images, use image[] as the field name
 				fieldName := "image"
@@ -433,15 +434,18 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 				if _, err := io.Copy(part, file); err != nil {
 					return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
 				}
+
+				// 复制完立即关闭,避免在循环内使用 defer 占用资源
+				_ = file.Close()
 			}
 
 			// Handle mask file if present
-			if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+			if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 {
 				maskFile, err := maskFiles[0].Open()
 				if err != nil {
 					return nil, errors.New("failed to open mask file")
 				}
-				defer maskFile.Close()
+				// 复制完立即关闭,避免在循环内使用 defer 占用资源
 
 				// Determine MIME type for mask file
 				mimeType := detectImageMimeType(maskFiles[0].Filename)
@@ -459,6 +463,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 				if _, err := io.Copy(maskPart, maskFile); err != nil {
 					return nil, errors.New("copy mask file failed")
 				}
+				_ = maskFile.Close()
 			}
 		} else {
 			return nil, errors.New("no multipart form data found")
@@ -467,7 +472,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 		// 关闭 multipart 编写器以设置分界线
 		writer.Close()
 		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
-		return bytes.NewReader(requestBody.Bytes()), nil
+		return &requestBody, nil
 
 	default:
 		return request, nil