Jelajahi Sumber

feat: add minimax image generation relay support (#4103)

forsakenyang 2 bulan lalu
induk
melakukan
c734db34e8

+ 7 - 1
relay/channel/minimax/adaptor.go

@@ -78,7 +78,10 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	return request, nil
+	// Only the image-generation relay mode is translated; any other relay mode
+	// is rejected with an error instead of being passed through unchanged.
+	if info.RelayMode != constant.RelayModeImagesGenerations {
+		return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
+	}
+	// Re-shape the OpenAI-style image request into the MiniMax request struct.
+	return oaiImage2MiniMaxImageRequest(request), nil
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,6 +124,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.RelayMode == constant.RelayModeAudioSpeech {
 		return handleTTSResponse(c, resp, info)
 	}
+	if info.RelayMode == constant.RelayModeImagesGenerations {
+		return miniMaxImageHandler(c, resp, info)
+	}
 
 	switch info.RelayFormat {
 	case types.RelayFormatClaude:

+ 137 - 0
relay/channel/minimax/adaptor_test.go

@@ -0,0 +1,137 @@
+package minimax
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	relayconstant "github.com/QuantumNous/new-api/relay/constant"
+
+	"github.com/gin-gonic/gin"
+)
+
+// TestGetRequestURLForImageGeneration checks that the image-generation relay
+// mode resolves to the /v1/image_generation path under the channel base URL.
+func TestGetRequestURLForImageGeneration(t *testing.T) {
+	t.Parallel()
+
+	const want = "https://api.minimax.chat/v1/image_generation"
+
+	relayInfo := &relaycommon.RelayInfo{
+		RelayMode:   relayconstant.RelayModeImagesGenerations,
+		ChannelMeta: &relaycommon.ChannelMeta{ChannelBaseUrl: "https://api.minimax.chat"},
+	}
+
+	got, err := GetRequestURL(relayInfo)
+	switch {
+	case err != nil:
+		t.Fatalf("GetRequestURL returned error: %v", err)
+	case got != want:
+		t.Fatalf("GetRequestURL() = %q, want %q", got, want)
+	}
+}
+
+// TestConvertImageRequest runs an OpenAI-style image request through the
+// adaptor and inspects the JSON form of the converted payload field by field.
+func TestConvertImageRequest(t *testing.T) {
+	t.Parallel()
+
+	request := dto.ImageRequest{
+		Model:          "image-01",
+		Prompt:         "a red fox in snowfall",
+		Size:           "1536x1024",
+		ResponseFormat: "url",
+		N:              uintPtr(2),
+	}
+	relayInfo := &relaycommon.RelayInfo{
+		RelayMode:       relayconstant.RelayModeImagesGenerations,
+		OriginModelName: "image-01",
+	}
+	ctx := gin.CreateTestContextOnly(httptest.NewRecorder(), gin.New())
+
+	converted, err := (&Adaptor{}).ConvertImageRequest(ctx, relayInfo, request)
+	if err != nil {
+		t.Fatalf("ConvertImageRequest returned error: %v", err)
+	}
+
+	// Round-trip through JSON so the assertions see the wire-level field names.
+	encoded, err := json.Marshal(converted)
+	if err != nil {
+		t.Fatalf("json.Marshal returned error: %v", err)
+	}
+	payload := map[string]any{}
+	if err = json.Unmarshal(encoded, &payload); err != nil {
+		t.Fatalf("json.Unmarshal returned error: %v", err)
+	}
+
+	if model := payload["model"]; model != "image-01" {
+		t.Fatalf("model = %#v, want %q", model, "image-01")
+	}
+	if prompt := payload["prompt"]; prompt != request.Prompt {
+		t.Fatalf("prompt = %#v, want %q", prompt, request.Prompt)
+	}
+	// encoding/json decodes every JSON number into float64 for map[string]any.
+	if n := payload["n"]; n != float64(2) {
+		t.Fatalf("n = %#v, want 2", n)
+	}
+	if ratio := payload["aspect_ratio"]; ratio != "3:2" {
+		t.Fatalf("aspect_ratio = %#v, want %q", ratio, "3:2")
+	}
+	if format := payload["response_format"]; format != "url" {
+		t.Fatalf("response_format = %#v, want %q", format, "url")
+	}
+}
+
+// TestDoResponseForImageGeneration feeds a MiniMax-style image payload through
+// Adaptor.DoResponse and checks that the body written to the client uses the
+// OpenAI image shape ("url" entries) instead of the raw MiniMax "image_urls".
+func TestDoResponseForImageGeneration(t *testing.T) {
+	t.Parallel()
+
+	// Build the context the same way TestConvertImageRequest does. Calling
+	// gin.SetMode here would change process-wide gin state from a parallel test.
+	recorder := httptest.NewRecorder()
+	c := gin.CreateTestContextOnly(recorder, gin.New())
+
+	info := &relaycommon.RelayInfo{
+		RelayMode: relayconstant.RelayModeImagesGenerations,
+		StartTime: time.Unix(1700000000, 0),
+	}
+	// The upstream body is set once, directly to the canned MiniMax payload.
+	resp := &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     make(http.Header),
+		Body:       ioNopCloser(`{"data":{"image_urls":["https://example.com/minimax.png"]}}`),
+	}
+
+	adaptor := &Adaptor{}
+	usage, err := adaptor.DoResponse(c, resp, info)
+	if err != nil {
+		t.Fatalf("DoResponse returned error: %v", err)
+	}
+	if usage == nil {
+		t.Fatalf("DoResponse returned nil usage")
+	}
+
+	body := recorder.Body.String()
+	if !strings.Contains(body, `"url":"https://example.com/minimax.png"`) {
+		t.Fatalf("response body = %s, want OpenAI image response with image URL", body)
+	}
+	if strings.Contains(body, `"image_urls"`) {
+		t.Fatalf("response body = %s, should not expose raw MiniMax image_urls payload", body)
+	}
+}
+
+// nopReadCloser pairs a *strings.Reader with a Close method that does nothing,
+// so an in-memory string can be used where a closable reader is required.
+type nopReadCloser struct {
+	*strings.Reader
+}
+
+// Close is a no-op and always reports success.
+func (nopReadCloser) Close() error { return nil }
+
+// ioNopCloser wraps body in a nopReadCloser positioned at the start of the string.
+func ioNopCloser(body string) nopReadCloser {
+	reader := strings.NewReader(body)
+	return nopReadCloser{reader}
+}
+
+// uintPtr returns a pointer to a freshly allocated copy of v.
+func uintPtr(v uint) *uint {
+	p := new(uint)
+	*p = v
+	return p
+}

+ 4 - 0
relay/channel/minimax/constants.go

@@ -8,6 +8,8 @@ var ModelList = []string{
 	"abab6-chat",
 	"abab5.5-chat",
 	"abab5.5s-chat",
+	"MiniMax-M2.7",
+	"MiniMax-M2.7-highspeed",
 	"speech-2.5-hd-preview",
 	"speech-2.5-turbo-preview",
 	"speech-02-hd",
@@ -19,6 +21,8 @@ var ModelList = []string{
 	"MiniMax-M2",
 	"MiniMax-M2.5",
 	"MiniMax-M2.5-highspeed",
+	"image-01",
+	"image-01-live",
 }
 
 var ChannelName = "minimax"

+ 213 - 0
relay/channel/minimax/image.go

@@ -0,0 +1,213 @@
+package minimax
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
+// MiniMaxImageRequest is the JSON payload that oaiImage2MiniMaxImageRequest
+// builds from an OpenAI-style image request. Apart from model and prompt,
+// every field is dropped from the encoded JSON when it holds its zero value.
+type MiniMaxImageRequest struct {
+	// Model is the image model name; the converter falls back to "image-01".
+	Model           string `json:"model"`
+	// Prompt is copied verbatim from the incoming request.
+	Prompt          string `json:"prompt"`
+	// AspectRatio is a "W:H" string such as "3:2".
+	AspectRatio     string `json:"aspect_ratio,omitempty"`
+	// ResponseFormat is the normalised output format ("url" or "base64" for
+	// the values the converter recognises).
+	ResponseFormat  string `json:"response_format,omitempty"`
+	// N is the number of images requested; the converter defaults it to 1.
+	N               int    `json:"n,omitempty"`
+	// PromptOptimizer is a pointer so an unset value is omitted entirely.
+	PromptOptimizer *bool  `json:"prompt_optimizer,omitempty"`
+	// AigcWatermark is taken from the incoming request's Watermark field.
+	AigcWatermark   *bool  `json:"aigc_watermark,omitempty"`
+}
+
+// MiniMaxImageResponse is the shape miniMaxImageHandler decodes the upstream
+// image response body into.
+type MiniMaxImageResponse struct {
+	ID   string `json:"id"`
+	// Data carries the generated images, as URLs and/or base64 strings.
+	Data struct {
+		ImageURLs   []string `json:"image_urls"`
+		ImageBase64 []string `json:"image_base64"`
+	} `json:"data"`
+	// Metadata is kept as a free-form map and re-encoded when converting to
+	// the OpenAI-style response.
+	Metadata map[string]any `json:"metadata"`
+	// BaseResp carries the upstream status; miniMaxImageHandler treats any
+	// non-zero StatusCode as an error and surfaces StatusMsg as the message.
+	BaseResp struct {
+		StatusCode int    `json:"status_code"`
+		StatusMsg  string `json:"status_msg"`
+	} `json:"base_resp"`
+}
+
+// oaiImage2MiniMaxImageRequest maps an OpenAI-style image request onto the
+// MiniMax request struct.
+//
+// Defaults: an empty model becomes "image-01" and a missing or zero N becomes
+// 1. The aspect ratio comes from aspectRatioFromImageRequest, and the response
+// format is normalised by normalizeMiniMaxResponseFormat. A "prompt_optimizer"
+// entry in request.Extra is forwarded only when it decodes as a boolean;
+// anything else is silently ignored.
+func oaiImage2MiniMaxImageRequest(request dto.ImageRequest) MiniMaxImageRequest {
+	model := request.Model
+	if model == "" {
+		model = "image-01"
+	}
+
+	count := 1
+	if n := request.N; n != nil && *n > 0 {
+		count = int(*n)
+	}
+
+	converted := MiniMaxImageRequest{
+		Model:          model,
+		Prompt:         request.Prompt,
+		AspectRatio:    aspectRatioFromImageRequest(request),
+		ResponseFormat: normalizeMiniMaxResponseFormat(request.ResponseFormat),
+		N:              count,
+		AigcWatermark:  request.Watermark,
+	}
+
+	if raw, found := request.Extra["prompt_optimizer"]; found {
+		enabled := false
+		if common.Unmarshal(raw, &enabled) == nil {
+			converted.PromptOptimizer = &enabled
+		}
+	}
+
+	return converted
+}
+
+// aspectRatioFromImageRequest picks the aspect-ratio string for a request.
+//
+// Resolution order:
+//  1. a non-empty string under "aspect_ratio" in request.Extra, used as-is;
+//  2. a fixed table of well-known pixel sizes;
+//  3. the reduced width:height of request.Size, but only when that reduced
+//     ratio is one of the accepted values.
+//
+// It returns "" when none of these yields a ratio.
+func aspectRatioFromImageRequest(request dto.ImageRequest) string {
+	if override, present := request.Extra["aspect_ratio"]; present {
+		var explicit string
+		if common.Unmarshal(override, &explicit) == nil && explicit != "" {
+			return explicit
+		}
+	}
+
+	// Well-known sizes; some entries do not reduce to an accepted ratio by
+	// arithmetic alone (1344x576 reduces to 7:3), so they are listed explicitly.
+	preset := ""
+	switch request.Size {
+	case "1024x1024":
+		preset = "1:1"
+	case "1792x1024":
+		preset = "16:9"
+	case "1024x1792":
+		preset = "9:16"
+	case "1536x1024", "1248x832":
+		preset = "3:2"
+	case "1024x1536", "832x1248":
+		preset = "2:3"
+	case "1152x864":
+		preset = "4:3"
+	case "864x1152":
+		preset = "3:4"
+	case "1344x576":
+		preset = "21:9"
+	}
+	if preset != "" {
+		return preset
+	}
+
+	w, h, valid := parseImageSize(request.Size)
+	if !valid {
+		return ""
+	}
+	switch reduced := reduceAspectRatio(w, h); reduced {
+	case "1:1", "16:9", "4:3", "3:2", "2:3", "3:4", "9:16", "21:9":
+		return reduced
+	}
+	return ""
+}
+
+// parseImageSize splits a "<width>x<height>" string into its two integers.
+//
+// The final result is false, with zero dimensions, when the string does not
+// contain exactly one lowercase "x", when either side is not an integer, or
+// when either dimension is zero or negative.
+func parseImageSize(size string) (int, int, bool) {
+	rawWidth, rawHeight, found := strings.Cut(size, "x")
+	if !found {
+		return 0, 0, false
+	}
+
+	// A second "x" ends up inside rawHeight and makes Atoi fail, so inputs
+	// such as "1x2x3" are still rejected.
+	width, widthErr := strconv.Atoi(rawWidth)
+	height, heightErr := strconv.Atoi(rawHeight)
+	if widthErr != nil || heightErr != nil || width <= 0 || height <= 0 {
+		return 0, 0, false
+	}
+	return width, height, true
+}
+
+// reduceAspectRatio formats width:height in lowest terms, e.g. 1536 and 1024
+// become "3:2". Both values are divided by their greatest common divisor.
+func reduceAspectRatio(width, height int) string {
+	g := gcd(width, height)
+	w, h := width/g, height/g
+	return fmt.Sprintf("%d:%d", w, h)
+}
+
+// gcd computes the greatest common divisor of a and b with the Euclidean
+// algorithm. It returns 1 instead of 0 when both inputs are zero, so the
+// result is always safe to divide by.
+func gcd(a, b int) int {
+	for b != 0 {
+		remainder := a % b
+		a = b
+		b = remainder
+	}
+	if a != 0 {
+		return a
+	}
+	return 1
+}
+
+// normalizeMiniMaxResponseFormat maps an OpenAI-style response_format onto the
+// value used in the MiniMax request. Matching is case-insensitive: an empty
+// value or "url" yields "url"; "b64_json" or "base64" yields "base64". Any
+// other value is returned exactly as it was given.
+func normalizeMiniMaxResponseFormat(responseFormat string) string {
+	lowered := strings.ToLower(responseFormat)
+	if lowered == "" || lowered == "url" {
+		return "url"
+	}
+	if lowered == "b64_json" || lowered == "base64" {
+		return "base64"
+	}
+	return responseFormat
+}
+
+// responseMiniMax2OpenAIImage converts a decoded MiniMax image response into
+// the OpenAI-style image response.
+//
+// Created is taken from info.StartTime. URL results are appended first,
+// followed by base64 results. Metadata is re-encoded and attached only when
+// the upstream response carried any; an encoding failure is returned as the
+// error with a nil response.
+func responseMiniMax2OpenAIImage(response *MiniMaxImageResponse, info *relaycommon.RelayInfo) (*dto.ImageResponse, error) {
+	converted := &dto.ImageResponse{Created: info.StartTime.Unix()}
+
+	urls := response.Data.ImageURLs
+	for i := range urls {
+		converted.Data = append(converted.Data, dto.ImageData{Url: urls[i]})
+	}
+	encodedImages := response.Data.ImageBase64
+	for i := range encodedImages {
+		converted.Data = append(converted.Data, dto.ImageData{B64Json: encodedImages[i]})
+	}
+
+	if len(response.Metadata) == 0 {
+		return converted, nil
+	}
+	metadata, err := common.Marshal(response.Metadata)
+	if err != nil {
+		return nil, err
+	}
+	converted.Metadata = metadata
+	return converted, nil
+}
+
+// miniMaxImageHandler reads the upstream MiniMax image response, converts it
+// to the OpenAI-style image response and writes that JSON to the client with
+// the upstream HTTP status code.
+//
+// A non-zero base_resp.status_code is reported as an error carrying the
+// upstream status message. On success it returns an empty usage value.
+func miniMaxImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	// Deferred so the upstream body is released on every path, including the
+	// early return taken when reading the body fails.
+	defer service.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+	}
+
+	var minimaxResponse MiniMaxImageResponse
+	if err := common.Unmarshal(responseBody, &minimaxResponse); err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+	if minimaxResponse.BaseResp.StatusCode != 0 {
+		// NOTE(review): the error reuses the upstream HTTP status, which may be
+		// a 2xx even though base_resp signals a failure; confirm callers cope.
+		return nil, types.WithOpenAIError(types.OpenAIError{
+			Message: minimaxResponse.BaseResp.StatusMsg,
+			Type:    "minimax_image_error",
+			Code:    fmt.Sprintf("%d", minimaxResponse.BaseResp.StatusCode),
+		}, resp.StatusCode)
+	}
+
+	openAIResponse, err := responseMiniMax2OpenAIImage(&minimaxResponse, info)
+	if err != nil {
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+	}
+	jsonResponse, err := common.Marshal(openAIResponse)
+	if err != nil {
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	if _, err := c.Writer.Write(jsonResponse); err != nil {
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+	}
+
+	return &dto.Usage{}, nil
+}

+ 2 - 0
relay/channel/minimax/relay-minimax.go

@@ -21,6 +21,8 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		switch info.RelayMode {
 		case constant.RelayModeChatCompletions:
 			return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
+		case constant.RelayModeImagesGenerations:
+			return fmt.Sprintf("%s/v1/image_generation", baseUrl), nil
 		case constant.RelayModeAudioSpeech:
 			return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
 		default: