Przeglądaj źródła

Merge branch 'main' into ssrf

# Conflicts:
#	service/cf_worker.go
CaIon 5 miesięcy temu
rodzic
commit
2ea7634549

+ 39 - 0
controller/channel-test.go

@@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		requestPath = "/v1/embeddings" // 修改请求路径
 	}
 
+	// VolcEngine 图像生成模型
+	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+		requestPath = "/v1/images/generations"
+	}
+
 	c.Request = &http.Request{
 		Method: "POST",
 		URL:    &url.URL{Path: requestPath}, // 使用动态路径
@@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		}
 	}
 
+	// 重新检查模型类型并更新请求路径
+	if strings.Contains(strings.ToLower(testModel), "embedding") ||
+		strings.HasPrefix(testModel, "m3e") ||
+		strings.Contains(testModel, "bge-") ||
+		strings.Contains(testModel, "embed") ||
+		channel.Type == constant.ChannelTypeMokaAI {
+		requestPath = "/v1/embeddings"
+		c.Request.URL.Path = requestPath
+	}
+
+	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+		requestPath = "/v1/images/generations"
+		c.Request.URL.Path = requestPath
+	}
+
 	cache, err := model.GetUserCache(1)
 	if err != nil {
 		return testResult{
@@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	if c.Request.URL.Path == "/v1/embeddings" {
 		relayFormat = types.RelayFormatEmbedding
 	}
+	if c.Request.URL.Path == "/v1/images/generations" {
+		relayFormat = types.RelayFormatOpenAIImage
+	}
 
 	info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
 
@@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		}
 		// 调用专门用于 Embedding 的转换函数
 		convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+	} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+		// 创建一个 ImageRequest
+		prompt := "cat"
+		if request.Prompt != nil {
+			if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
+				prompt = promptStr
+			}
+		}
+		imageRequest := dto.ImageRequest{
+			Prompt: prompt,
+			Model:  request.Model,
+			N:      uint(request.N),
+			Size:   request.Size,
+		}
+		// 调用专门用于图像生成的转换函数
+		convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
 	} else {
 		// 对其他所有请求类型(如 Chat),保持原有逻辑
 		convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)

+ 10 - 10
controller/setup.go

@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
 func PostSetup(c *gin.Context) {
 	// Check if setup is already completed
 	if constant.Setup {
-		c.JSON(400, gin.H{
+		c.JSON(200, gin.H{
 			"success": false,
 			"message": "系统已经初始化完成",
 		})
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
 	var req SetupRequest
 	err := c.ShouldBindJSON(&req)
 	if err != nil {
-		c.JSON(400, gin.H{
+		c.JSON(200, gin.H{
 			"success": false,
 			"message": "请求参数有误",
 		})
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
 	if !rootExists {
 		// Validate username length: max 12 characters to align with model.User validation
 		if len(req.Username) > 12 {
-			c.JSON(400, gin.H{
+			c.JSON(200, gin.H{
 				"success": false,
 				"message": "用户名长度不能超过12个字符",
 			})
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
 		}
 		// Validate password
 		if req.Password != req.ConfirmPassword {
-			c.JSON(400, gin.H{
+			c.JSON(200, gin.H{
 				"success": false,
 				"message": "两次输入的密码不一致",
 			})
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
 		}
 
 		if len(req.Password) < 8 {
-			c.JSON(400, gin.H{
+			c.JSON(200, gin.H{
 				"success": false,
 				"message": "密码长度至少为8个字符",
 			})
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
 		// Create root user
 		hashedPassword, err := common.Password2Hash(req.Password)
 		if err != nil {
-			c.JSON(500, gin.H{
+			c.JSON(200, gin.H{
 				"success": false,
 				"message": "系统错误: " + err.Error(),
 			})
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
 		}
 		err = model.DB.Create(&rootUser).Error
 		if err != nil {
-			c.JSON(500, gin.H{
+			c.JSON(200, gin.H{
 				"success": false,
 				"message": "创建管理员账号失败: " + err.Error(),
 			})
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
 	// Save operation modes to database for persistence
 	err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
 	if err != nil {
-		c.JSON(500, gin.H{
+		c.JSON(200, gin.H{
 			"success": false,
 			"message": "保存自用模式设置失败: " + err.Error(),
 		})
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
 
 	err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
 	if err != nil {
-		c.JSON(500, gin.H{
+		c.JSON(200, gin.H{
 			"success": false,
 			"message": "保存演示站点模式设置失败: " + err.Error(),
 		})
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
 	}
 	err = model.DB.Create(&setup).Error
 	if err != nil {
-		c.JSON(500, gin.H{
+		c.JSON(200, gin.H{
 			"success": false,
 			"message": "系统初始化失败: " + err.Error(),
 		})

+ 1 - 1
controller/topup_stripe.go

@@ -217,7 +217,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
 
 	params := &stripe.CheckoutSessionParams{
 		ClientReferenceID: stripe.String(referenceId),
-		SuccessURL:        stripe.String(system_setting.ServerAddress + "/log"),
+		SuccessURL:        stripe.String(system_setting.ServerAddress + "/console/log"),
 		CancelURL:         stripe.String(system_setting.ServerAddress + "/topup"),
 		LineItems: []*stripe.CheckoutSessionLineItemParams{
 			{

+ 42 - 0
dto/openai_response.go

@@ -6,6 +6,10 @@ import (
 	"one-api/types"
 )
 
+const (
+	ResponsesOutputTypeImageGenerationCall = "image_generation_call"
+)
+
 type SimpleResponse struct {
 	Usage `json:"usage"`
 	Error any `json:"error"`
@@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
 	return GetOpenAIError(o.Error)
 }
 
+func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
+	if len(o.Output) == 0 {
+		return false
+	}
+	for _, output := range o.Output {
+		if output.Type == ResponsesOutputTypeImageGenerationCall {
+			return true
+		}
+	}
+	return false
+}
+
+func (o *OpenAIResponsesResponse) GetQuality() string {
+	if len(o.Output) == 0 {
+		return ""
+	}
+	for _, output := range o.Output {
+		if output.Type == ResponsesOutputTypeImageGenerationCall {
+			return output.Quality
+		}
+	}
+	return ""
+}
+
+func (o *OpenAIResponsesResponse) GetSize() string {
+	if len(o.Output) == 0 {
+		return ""
+	}
+	for _, output := range o.Output {
+		if output.Type == ResponsesOutputTypeImageGenerationCall {
+			return output.Size
+		}
+	}
+	return ""
+}
+
 type IncompleteDetails struct {
 	Reasoning string `json:"reasoning"`
 }
@@ -283,6 +323,8 @@ type ResponsesOutput struct {
 	Status  string                   `json:"status"`
 	Role    string                   `json:"role"`
 	Content []ResponsesOutputContent `json:"content"`
+	Quality string                   `json:"quality"`
+	Size    string                   `json:"size"`
 }
 
 type ResponsesOutputContent struct {

+ 24 - 11
relay/channel/openai/relay_responses.go

@@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}
 
+	if responsesResponse.HasImageGenerationCall() {
+		c.Set("image_generation_call", true)
+		c.Set("image_generation_call_quality", responsesResponse.GetQuality())
+		c.Set("image_generation_call_size", responsesResponse.GetSize())
+	}
+
 	// 写入新的 response body
 	service.IOCopyBytesGracefully(c, resp, responseBody)
 
@@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 			sendResponsesStreamData(c, streamResponse, data)
 			switch streamResponse.Type {
 			case "response.completed":
-				if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
-					if streamResponse.Response.Usage.InputTokens != 0 {
-						usage.PromptTokens = streamResponse.Response.Usage.InputTokens
-					}
-					if streamResponse.Response.Usage.OutputTokens != 0 {
-						usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
-					}
-					if streamResponse.Response.Usage.TotalTokens != 0 {
-						usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+				if streamResponse.Response != nil {
+					if streamResponse.Response.Usage != nil {
+						if streamResponse.Response.Usage.InputTokens != 0 {
+							usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+						}
+						if streamResponse.Response.Usage.OutputTokens != 0 {
+							usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+						}
+						if streamResponse.Response.Usage.TotalTokens != 0 {
+							usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+						}
+						if streamResponse.Response.Usage.InputTokensDetails != nil {
+							usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+						}
 					}
-					if streamResponse.Response.Usage.InputTokensDetails != nil {
-						usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+					if streamResponse.Response.HasImageGenerationCall() {
+						c.Set("image_generation_call", true)
+						c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
+						c.Set("image_generation_call_size", streamResponse.Response.GetSize())
 					}
 				}
 			case "response.output_text.delta":

+ 26 - 4
relay/channel/task/jimeng/adaptor.go

@@ -36,6 +36,7 @@ type requestPayload struct {
 	Prompt           string   `json:"prompt,omitempty"`
 	Seed             int64    `json:"seed"`
 	AspectRatio      string   `json:"aspect_ratio"`
+	Frames           int      `json:"frames,omitempty"`
 }
 
 type responsePayload struct {
@@ -311,10 +312,15 @@ func hmacSHA256(key []byte, data []byte) []byte {
 
 func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
 	r := requestPayload{
-		ReqKey:      "jimeng_vgfm_i2v_l20",
-		Prompt:      req.Prompt,
-		AspectRatio: "16:9", // Default aspect ratio
-		Seed:        -1,     // Default to random
+		ReqKey: req.Model,
+		Prompt: req.Prompt,
+	}
+
+	switch req.Duration {
+	case 10:
+		r.Frames = 241 // 24*10+1 = 241
+	default:
+		r.Frames = 121 // 24*5+1 = 121
 	}
 
 	// Handle one-of image_urls or binary_data_base64
@@ -334,6 +340,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	if err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
+
+	// 即梦视频3.0 ReqKey转换
+	// https://www.volcengine.com/docs/85621/1792707
+	if strings.Contains(r.ReqKey, "jimeng_v30") {
+		if len(r.ImageUrls) > 1 {
+			// 多张图片:首尾帧生成
+			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
+		} else if len(r.ImageUrls) == 1 {
+			// 单张图片:图生视频
+			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
+		} else {
+			// 无图片:文生视频
+			r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
+		}
+	}
+
 	return &r, nil
 }
 

+ 2 - 0
relay/channel/volcengine/adaptor.go

@@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	switch info.RelayMode {
+	case constant.RelayModeImagesGenerations:
+		return request, nil
 	case constant.RelayModeImagesEdits:
 
 		var requestBody bytes.Buffer

+ 1 - 0
relay/channel/volcengine/constants.go

@@ -8,6 +8,7 @@ var ModelList = []string{
 	"Doubao-lite-32k",
 	"Doubao-lite-4k",
 	"Doubao-embedding",
+	"doubao-seedream-4-0-250828",
 }
 
 var ChannelName = "volcengine"

+ 13 - 0
relay/compatible_handler.go

@@ -276,6 +276,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 				fileSearchTool.CallCount, dFileSearchQuota.String())
 		}
 	}
+	var dImageGenerationCallQuota decimal.Decimal
+	var imageGenerationCallPrice float64
+	if ctx.GetBool("image_generation_call") {
+		imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
+		dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+		extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
+	}
 
 	var quotaCalculateDecimal decimal.Decimal
 
@@ -331,6 +338,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
 	// 添加 audio input 独立计费
 	quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
+	// 添加 image generation call 计费
+	quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
 
 	quota := int(quotaCalculateDecimal.Round(0).IntPart())
 	totalTokens := promptTokens + completionTokens
@@ -429,6 +438,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		other["audio_input_token_count"] = audioTokens
 		other["audio_input_price"] = audioInputPrice
 	}
+	if !dImageGenerationCallQuota.IsZero() {
+		other["image_generation_call"] = true
+		other["image_generation_call_price"] = imageGenerationCallPrice
+	}
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 		ChannelId:        relayInfo.ChannelId,
 		PromptTokens:     promptTokens,

+ 40 - 0
setting/operation_setting/tools.go

@@ -10,6 +10,18 @@ const (
 	FileSearchPrice = 2.5
 )
 
+const (
+	GPTImage1Low1024x1024    = 0.011
+	GPTImage1Low1024x1536    = 0.016
+	GPTImage1Low1536x1024    = 0.016
+	GPTImage1Medium1024x1024 = 0.042
+	GPTImage1Medium1024x1536 = 0.063
+	GPTImage1Medium1536x1024 = 0.063
+	GPTImage1High1024x1024   = 0.167
+	GPTImage1High1024x1536   = 0.25
+	GPTImage1High1536x1024   = 0.25
+)
+
 const (
 	// Gemini Audio Input Price
 	Gemini25FlashPreviewInputAudioPrice     = 1.00
@@ -65,3 +77,31 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
 	}
 	return 0
 }
+
+func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
+	prices := map[string]map[string]float64{
+		"low": {
+			"1024x1024": GPTImage1Low1024x1024,
+			"1024x1536": GPTImage1Low1024x1536,
+			"1536x1024": GPTImage1Low1536x1024,
+		},
+		"medium": {
+			"1024x1024": GPTImage1Medium1024x1024,
+			"1024x1536": GPTImage1Medium1024x1536,
+			"1536x1024": GPTImage1Medium1536x1024,
+		},
+		"high": {
+			"1024x1024": GPTImage1High1024x1024,
+			"1024x1536": GPTImage1High1024x1536,
+			"1536x1024": GPTImage1High1536x1024,
+		},
+	}
+
+	if qualityMap, exists := prices[quality]; exists {
+		if price, exists := qualityMap[size]; exists {
+			return price
+		}
+	}
+
+	return GPTImage1High1024x1024
+}

+ 21 - 2
web/src/helpers/render.jsx

@@ -1027,6 +1027,8 @@ export function renderModelPrice(
   audioInputSeperatePrice = false,
   audioInputTokens = 0,
   audioInputPrice = 0,
+  imageGenerationCall = false,
+  imageGenerationCallPrice = 0,
 ) {
   const { ratio: effectiveGroupRatio, label: ratioLabel } = getEffectiveRatio(
     groupRatio,
@@ -1069,7 +1071,8 @@ export function renderModelPrice(
       (audioInputTokens / 1000000) * audioInputPrice * groupRatio +
       (completionTokens / 1000000) * completionRatioPrice * groupRatio +
       (webSearchCallCount / 1000) * webSearchPrice * groupRatio +
-      (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio;
+      (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio +
+      (imageGenerationCallPrice * groupRatio);
 
     return (
       <>
@@ -1131,7 +1134,13 @@ export function renderModelPrice(
               })}
             </p>
           )}
-          <p></p>
+          {imageGenerationCall && imageGenerationCallPrice > 0 && (
+            <p>
+              {i18next.t('图片生成调用:${{price}} / 1次', {
+                price: imageGenerationCallPrice,
+              })}
+            </p>
+          )}
           <p>
             {(() => {
               // 构建输入部分描述
@@ -1211,6 +1220,16 @@ export function renderModelPrice(
                       },
                     )
                   : '',
+                imageGenerationCall && imageGenerationCallPrice > 0
+                  ? i18next.t(
+                      ' + 图片生成调用 ${{price}} / 1次 * {{ratioType}} {{ratio}}',
+                      {
+                        price: imageGenerationCallPrice,
+                        ratio: groupRatio,
+                        ratioType: ratioLabel,
+                      },
+                    )
+                  : '',
               ].join('');
 
               return i18next.t(

+ 2 - 0
web/src/hooks/usage-logs/useUsageLogsData.jsx

@@ -447,6 +447,8 @@ export const useLogsData = () => {
             other?.audio_input_seperate_price || false,
             other?.audio_input_token_count || 0,
             other?.audio_input_price || 0,
+            other?.image_generation_call || false,
+            other?.image_generation_call_price || 0,
           );
         }
         expandDataLocal.push({