Просмотр исходного кода

feat: support /images/edit

(cherry picked from commit 1c0a1238787d490f02dd9269b616580a16604180)
xyfacai 10 месяцев назад
Родитель
Commit
f9f32a0158

+ 1 - 1
controller/relay.go

@@ -24,7 +24,7 @@ import (
 func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
 	var err *dto.OpenAIErrorWithStatusCode
 	switch relayMode {
-	case relayconstant.RelayModeImagesGenerations:
+	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
 		err = relay.ImageHelper(c)
 	case relayconstant.RelayModeAudioSpeech:
 		fallthrough

+ 8 - 4
dto/openai_response.go

@@ -166,12 +166,16 @@ type CompletionsStreamResponse struct {
 }
 
 type Usage struct {
-	PromptTokens           int                `json:"prompt_tokens"`
-	CompletionTokens       int                `json:"completion_tokens"`
-	TotalTokens            int                `json:"total_tokens"`
-	PromptCacheHitTokens   int                `json:"prompt_cache_hit_tokens,omitempty"`
+	PromptTokens         int `json:"prompt_tokens"`
+	CompletionTokens     int `json:"completion_tokens"`
+	TotalTokens          int `json:"total_tokens"`
+	PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
+
 	PromptTokensDetails    InputTokenDetails  `json:"prompt_tokens_details"`
 	CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
+	InputTokens            int                `json:"input_tokens"`
+	OutputTokens           int                `json:"output_tokens"`
+	InputTokensDetails     *InputTokenDetails `json:"input_tokens_details"`
 }
 
 type InputTokenDetails struct {

+ 3 - 1
middleware/distributor.go

@@ -162,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
-	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}
 	if err != nil {
@@ -184,6 +184,8 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
+		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1")
 	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 		relayMode := relayconstant.RelayModeAudioSpeech

+ 57 - 4
relay/channel/openai/adaptor.go

@@ -236,11 +236,64 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	return request, nil
+	switch info.RelayMode {
+	case constant.RelayModeImagesEdits:
+		body, err := common.GetRequestBody(c)
+		if err != nil {
+			return nil, errors.New("get request body fail")
+		}
+		return bytes.NewReader(body), nil
+
+		/*var requestBody bytes.Buffer
+		writer := multipart.NewWriter(&requestBody)
+
+		writer.WriteField("model", request.Model)
+		// 获取所有表单字段
+		formData := c.Request.PostForm
+		// 遍历表单字段并打印输出
+		for key, values := range formData {
+			if key == "model" {
+				continue
+			}
+			for _, value := range values {
+				writer.WriteField(key, value)
+			}
+		}
+
+		// 添加文件字段
+		imageFiles := c.Request.MultipartForm.File["image[]"]
+		for _, file := range imageFiles {
+			part, err := writer.CreateFormFile("image[]", file.Filename)
+			if err != nil {
+				return nil, errors.New("create form file failed")
+			}
+			// 打开文件
+			src, err := file.Open()
+			if err != nil {
+				return nil, errors.New("open file failed")
+			}
+			// 将文件数据写入 form part
+			_, err = io.Copy(part, src)
+			if err != nil {
+				return nil, errors.New("copy file failed")
+			}
+			src.Close()
+		}
+
+		// 关闭 multipart 编写器以设置分界线
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return bytes.NewReader(requestBody.Bytes()), nil*/
+
+	default:
+		return request, nil
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+	if info.RelayMode == constant.RelayModeAudioTranscription ||
+		info.RelayMode == constant.RelayModeAudioTranslation ||
+		info.RelayMode == constant.RelayModeImagesEdits {
 		return channel.DoFormRequest(a, c, info, requestBody)
 	} else if info.RelayMode == constant.RelayModeRealtime {
 		return channel.DoWssRequest(a, c, info, requestBody)
@@ -259,8 +312,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		fallthrough
 	case constant.RelayModeAudioTranscription:
 		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
-	case constant.RelayModeImagesGenerations:
-		err, usage = OpenaiTTSHandler(c, resp, info)
+	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+		err, usage = OpenaiHandlerWithUsage(c, resp, info)
 	case constant.RelayModeRerank:
 		err, usage = common_handler.RerankHandler(c, info, resp)
 	default:

+ 49 - 0
relay/channel/openai/relay-openai.go

@@ -595,3 +595,52 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
 	err := service.PreWssConsumeQuota(ctx, info, usage)
 	return err
 }
+
+func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// Reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	// reset content length
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	var usageResp dto.SimpleResponse
+	err = json.Unmarshal(responseBody, &usageResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// format
+	if usageResp.InputTokens > 0 {
+		usageResp.PromptTokens += usageResp.InputTokens
+	}
+	if usageResp.OutputTokens > 0 {
+		usageResp.CompletionTokens += usageResp.OutputTokens
+	}
+	if usageResp.InputTokensDetails != nil {
+		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
+		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
+	}
+	return nil, &usageResp.Usage
+}

+ 3 - 0
relay/constant/relay_mode.go

@@ -12,6 +12,7 @@ const (
 	RelayModeEmbeddings
 	RelayModeModerations
 	RelayModeImagesGenerations
+	RelayModeImagesEdits
 	RelayModeEdits
 
 	RelayModeMidjourneyImagine
@@ -56,6 +57,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeModerations
 	} else if strings.HasPrefix(path, "/v1/images/generations") {
 		relayMode = RelayModeImagesGenerations
+	} else if strings.HasPrefix(path, "/v1/images/edits") {
+		relayMode = RelayModeImagesEdits
 	} else if strings.HasPrefix(path, "/v1/edits") {
 		relayMode = RelayModeEdits
 	} else if strings.HasPrefix(path, "/v1/audio/speech") {

+ 6 - 2
relay/helper/price.go

@@ -15,14 +15,15 @@ type PriceData struct {
 	ModelRatio             float64
 	CompletionRatio        float64
 	CacheRatio             float64
+	CacheCreationRatio     float64
+	ImageRatio             float64
 	GroupRatio             float64
 	UsePrice               bool
-	CacheCreationRatio     float64
 	ShouldPreConsumedQuota int
 }
 
 func (p PriceData) ToSetting() string {
-	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
+	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
 }
 
 func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
@@ -32,6 +33,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	var modelRatio float64
 	var completionRatio float64
 	var cacheRatio float64
+	var imageRatio float64
 	var cacheCreationRatio float64
 	if !usePrice {
 		preConsumedTokens := common.PreConsumedQuota
@@ -55,6 +57,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
 		cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
 		cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
+		imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
 		ratio := modelRatio * groupRatio
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 	} else {
@@ -68,6 +71,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		GroupRatio:             groupRatio,
 		UsePrice:               usePrice,
 		CacheRatio:             cacheRatio,
+		ImageRatio:             imageRatio,
 		CacheCreationRatio:     cacheCreationRatio,
 		ShouldPreConsumedQuota: preConsumedQuota,
 	}

+ 106 - 65
relay/relay-image.go

@@ -12,6 +12,7 @@ import (
 	"one-api/dto"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/setting"
@@ -20,13 +21,56 @@ import (
 
 func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
 	imageRequest := &dto.ImageRequest{}
-	err := common.UnmarshalBodyReusable(c, imageRequest)
-	if err != nil {
-		return nil, err
+
+	switch info.RelayMode {
+	case relayconstant.RelayModeImagesEdits:
+		_, err := c.MultipartForm()
+		if err != nil {
+			return nil, err
+		}
+		formData := c.Request.PostForm
+		imageRequest.Prompt = formData.Get("prompt")
+		imageRequest.Model = formData.Get("model")
+		imageRequest.N = common.String2Int(formData.Get("n"))
+		imageRequest.Quality = formData.Get("quality")
+		imageRequest.Size = formData.Get("size")
+
+		if imageRequest.Model == "gpt-image-1" {
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "standard"
+			}
+		}
+	default:
+		err := common.UnmarshalBodyReusable(c, imageRequest)
+		if err != nil {
+			return nil, err
+		}
+		// Not "256x256", "512x512", or "1024x1024"
+		if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+			if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+				return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
+			}
+		} else if imageRequest.Model == "dall-e-3" {
+			if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+				return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
+			}
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "standard"
+			}
+			// N should between 1 and 10
+			//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
+			//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
+			//}
+		}
 	}
+
 	if imageRequest.Prompt == "" {
 		return nil, errors.New("prompt is required")
 	}
+
+	if imageRequest.Model == "" {
+		imageRequest.Model = "dall-e-2"
+	}
 	if strings.Contains(imageRequest.Size, "×") {
 		return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
 	}
@@ -36,30 +80,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 	if imageRequest.Size == "" {
 		imageRequest.Size = "1024x1024"
 	}
-	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e-2"
-	}
 
-	// Not "256x256", "512x512", or "1024x1024"
-	if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
-		if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
-		}
-	} else if imageRequest.Model == "dall-e-3" {
-		if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-			return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
-		}
-		if imageRequest.Quality == "" {
-			imageRequest.Quality = "standard"
-		}
-		//if imageRequest.N != 1 {
-		//	return nil, errors.New("n must be 1")
-		//}
-	}
-	// N should between 1 and 10
-	//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
-	//	return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
-	//}
 	if setting.ShouldCheckPromptSensitive() {
 		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
 		if err != nil {
@@ -86,43 +107,59 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	imageRequest.Model = relayInfo.UpstreamModelName
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}
+	var preConsumedQuota int
+	var quota int
+	var userQuota int
 	if !priceData.UsePrice {
 		// modelRatio 16 = modelPrice $0.04
 		// per 1 modelRatio = $0.04 / 16
-		priceData.ModelPrice = 0.0025 * priceData.ModelRatio
-	}
-
-	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
-
-	sizeRatio := 1.0
-	// Size
-	if imageRequest.Size == "256x256" {
-		sizeRatio = 0.4
-	} else if imageRequest.Size == "512x512" {
-		sizeRatio = 0.45
-	} else if imageRequest.Size == "1024x1024" {
-		sizeRatio = 1
-	} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-		sizeRatio = 2
-	}
-
-	qualityRatio := 1.0
-	if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
-		qualityRatio = 2.0
-		if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-			qualityRatio = 1.5
+		// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
+		var openaiErr *dto.OpenAIErrorWithStatusCode
+		preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+		if openaiErr != nil {
+			return openaiErr
+		}
+		defer func() {
+			if openaiErr != nil {
+				returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+			}
+		}()
+
+	} else {
+		sizeRatio := 1.0
+		// Size
+		if imageRequest.Size == "256x256" {
+			sizeRatio = 0.4
+		} else if imageRequest.Size == "512x512" {
+			sizeRatio = 0.45
+		} else if imageRequest.Size == "1024x1024" {
+			sizeRatio = 1
+		} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+			sizeRatio = 2
 		}
-	}
 
-	priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
-	quota := int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+		qualityRatio := 1.0
+		if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
+			qualityRatio = 2.0
+			if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+				qualityRatio = 1.5
+			}
+		}
 
-	if userQuota-quota < 0 {
-		return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
+		// reset model price
+		priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
+		quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+		userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
+		}
+		if userQuota-quota < 0 {
+			return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
+		}
 	}
 
 	adaptor := GetAdaptor(relayInfo.ApiType)
@@ -137,12 +174,15 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
-
-	jsonData, err := json.Marshal(convertedRequest)
-	if err != nil {
-		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+	if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
+		requestBody = convertedRequest.(io.Reader)
+	} else {
+		jsonData, err := json.Marshal(convertedRequest)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonData)
 	}
-	requestBody = bytes.NewBuffer(jsonData)
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 
@@ -162,24 +202,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 
-	_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
 	if openaiErr != nil {
 		// reset status code 重置状态码
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
 
-	usage := &dto.Usage{
-		PromptTokens: imageRequest.N,
-		TotalTokens:  imageRequest.N,
+	if usage.(*dto.Usage).TotalTokens == 0 {
+		usage.(*dto.Usage).TotalTokens = imageRequest.N
+	}
+	if usage.(*dto.Usage).PromptTokens == 0 {
+		usage.(*dto.Usage).PromptTokens = imageRequest.N
 	}
-
 	quality := "standard"
 	if imageRequest.Quality == "hd" {
 		quality = "hd"
 	}
 
 	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
-	postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
 	return nil
 }

+ 11 - 0
relay/relay-text.go

@@ -331,12 +331,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	cacheTokens := usage.PromptTokensDetails.CachedTokens
+	imageTokens := usage.PromptTokensDetails.ImageTokens
 	completionTokens := usage.CompletionTokens
 	modelName := relayInfo.OriginModelName
 
 	tokenName := ctx.GetString("token_name")
 	completionRatio := priceData.CompletionRatio
 	cacheRatio := priceData.CacheRatio
+	imageRatio := priceData.ImageRatio
 	modelRatio := priceData.ModelRatio
 	groupRatio := priceData.GroupRatio
 	modelPrice := priceData.ModelPrice
@@ -344,9 +346,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	// Convert values to decimal for precise calculation
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
 	dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
+	dImageTokens := decimal.NewFromInt(int64(imageTokens))
 	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
 	dCompletionRatio := decimal.NewFromFloat(completionRatio)
 	dCacheRatio := decimal.NewFromFloat(cacheRatio)
+	dImageRatio := decimal.NewFromFloat(imageRatio)
 	dModelRatio := decimal.NewFromFloat(modelRatio)
 	dGroupRatio := decimal.NewFromFloat(groupRatio)
 	dModelPrice := decimal.NewFromFloat(modelPrice)
@@ -358,7 +362,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	if !priceData.UsePrice {
 		nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
 		cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
+
 		promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
+		if imageTokens > 0 {
+			nonImageTokens := dPromptTokens.Sub(dImageTokens)
+			imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
+			promptQuota = nonImageTokens.Add(imageTokensWithRatio)
+		}
+
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 
 		quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)

+ 1 - 1
router/relay-router.go

@@ -40,7 +40,7 @@ func SetRelayRouter(router *gin.Engine) {
 		httpRouter.POST("/chat/completions", controller.Relay)
 		httpRouter.POST("/edits", controller.Relay)
 		httpRouter.POST("/images/generations", controller.Relay)
-		httpRouter.POST("/images/edits", controller.RelayNotImplemented)
+		httpRouter.POST("/images/edits", controller.Relay)
 		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
 		httpRouter.POST("/embeddings", controller.Relay)
 		httpRouter.POST("/engines/:model/embeddings", controller.Relay)

+ 59 - 23
setting/operation_setting/model-ratio.go

@@ -51,26 +51,27 @@ var defaultModelRatio = map[string]float64{
 	"gpt-4o-realtime-preview-2024-12-17":      2.5,
 	"gpt-4o-mini-realtime-preview":            0.3,
 	"gpt-4o-mini-realtime-preview-2024-12-17": 0.3,
-	"o1":                         7.5,
-	"o1-2024-12-17":              7.5,
-	"o1-preview":                 7.5,
-	"o1-preview-2024-09-12":      7.5,
-	"o1-mini":                    0.55,
-	"o1-mini-2024-09-12":         0.55,
-	"o3-mini":                    0.55,
-	"o3-mini-2025-01-31":         0.55,
-	"o3-mini-high":               0.55,
-	"o3-mini-2025-01-31-high":    0.55,
-	"o3-mini-low":                0.55,
-	"o3-mini-2025-01-31-low":     0.55,
-	"o3-mini-medium":             0.55,
-	"o3-mini-2025-01-31-medium":  0.55,
-	"gpt-4o-mini":                0.075,
-	"gpt-4o-mini-2024-07-18":     0.075,
-	"gpt-4-turbo":                5, // $0.01 / 1K tokens
-	"gpt-4-turbo-2024-04-09":     5, // $0.01 / 1K tokens
-	"gpt-4.5-preview":            37.5,
-	"gpt-4.5-preview-2025-02-27": 37.5,
+	"gpt-image-1":                             2.5,
+	"o1":                                      7.5,
+	"o1-2024-12-17":                           7.5,
+	"o1-preview":                              7.5,
+	"o1-preview-2024-09-12":                   7.5,
+	"o1-mini":                                 0.55,
+	"o1-mini-2024-09-12":                      0.55,
+	"o3-mini":                                 0.55,
+	"o3-mini-2025-01-31":                      0.55,
+	"o3-mini-high":                            0.55,
+	"o3-mini-2025-01-31-high":                 0.55,
+	"o3-mini-low":                             0.55,
+	"o3-mini-2025-01-31-low":                  0.55,
+	"o3-mini-medium":                          0.55,
+	"o3-mini-2025-01-31-medium":               0.55,
+	"gpt-4o-mini":                             0.075,
+	"gpt-4o-mini-2024-07-18":                  0.075,
+	"gpt-4-turbo":                             5, // $0.01 / 1K tokens
+	"gpt-4-turbo-2024-04-09":                  5, // $0.01 / 1K tokens
+	"gpt-4.5-preview":                         37.5,
+	"gpt-4.5-preview-2025-02-27":              37.5,
 	//"gpt-3.5-turbo-0301":           0.75, //deprecated
 	"gpt-3.5-turbo":          0.25,
 	"gpt-3.5-turbo-0613":     0.75,
@@ -255,6 +256,7 @@ var defaultCompletionRatio = map[string]float64{
 	"gpt-4-gizmo-*":  2,
 	"gpt-4o-gizmo-*": 3,
 	"gpt-4-all":      2,
+	"gpt-image-1":    8,
 }
 
 // InitModelSettings initializes all model related settings maps
@@ -275,9 +277,10 @@ func InitModelSettings() {
 	CompletionRatioMutex.Unlock()
 
 	// Initialize cacheRatioMap
-	cacheRatioMapMutex.Lock()
-	cacheRatioMap = defaultCacheRatio
-	cacheRatioMapMutex.Unlock()
+	imageRatioMapMutex.Lock()
+	imageRatioMap = defaultImageRatio
+	imageRatioMapMutex.Unlock()
+
 }
 
 func GetModelPriceMap() map[string]float64 {
@@ -548,3 +551,36 @@ func ModelRatio2JSONString() string {
 	}
 	return string(jsonBytes)
 }
+
// defaultImageRatio holds the built-in per-model image token ratios used to
// seed imageRatioMap at startup.
var defaultImageRatio = map[string]float64{
	"gpt-image-1": 2,
}

// imageRatioMap is the active image ratio table; every access must hold
// imageRatioMapMutex.
var (
	imageRatioMap      map[string]float64
	imageRatioMapMutex sync.RWMutex
)
+
+func ImageRatio2JSONString() string {
+	imageRatioMapMutex.RLock()
+	defer imageRatioMapMutex.RUnlock()
+	jsonBytes, err := json.Marshal(imageRatioMap)
+	if err != nil {
+		common.SysError("error marshalling cache ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateImageRatioByJSONString(jsonStr string) error {
+	imageRatioMapMutex.Lock()
+	defer imageRatioMapMutex.Unlock()
+	imageRatioMap = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &imageRatioMap)
+}
+
+func GetImageRatio(name string) (float64, bool) {
+	imageRatioMapMutex.RLock()
+	defer imageRatioMapMutex.RUnlock()
+	ratio, ok := imageRatioMap[name]
+	if !ok {
+		return 1, false // Default to 1 if not found
+	}
+	return ratio, true
+}