Просмотр исходного кода

Merge pull request #1775 from QuantumNous/alpha

Alpha
Calcium-Ion 5 месяцев назад
Родитель
Сommit
e813da59cc

+ 1 - 1
controller/channel-test.go

@@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
-			err := service.RelayErrorHandler(httpResp, true)
+			err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
 			return testResult{
 				context:     c,
 				localErr:    err,

+ 9 - 10
controller/relay.go

@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 	// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
 
-	preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
 	if newAPIError != nil {
 		return
 	}
 
 	defer func() {
 		// Only return quota if downstream failed and quota was actually pre-consumed
-		if newAPIError != nil && preConsumedQuota != 0 {
-			service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+		if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
+			service.ReturnPreConsumedQuota(c, relayInfo)
 		}
 	}()
 
@@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
 
 func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
 	logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
-
-	gopool.Go(func() {
-		// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
-		// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
-		if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+		gopool.Go(func() {
 			service.DisableChannel(channelError, err.Error())
-		}
-	})
+		})
+	}
 
 	if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
 		// 保存错误日志到mysql中

+ 25 - 0
dto/openai_image.go

@@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
+// 序列化时需要重新把字段平铺
+func (r ImageRequest) MarshalJSON() ([]byte, error) {
+	// 将已定义字段转为 map
+	type Alias ImageRequest
+	alias := Alias(r)
+	base, err := common.Marshal(alias)
+	if err != nil {
+		return nil, err
+	}
+
+	var baseMap map[string]json.RawMessage
+	if err := common.Unmarshal(base, &baseMap); err != nil {
+		return nil, err
+	}
+
+	// 合并 ExtraFields
+	for k, v := range r.Extra {
+		if _, exists := baseMap[k]; !exists {
+			baseMap[k] = v
+		}
+	}
+
+	return json.Marshal(baseMap)
+}
+
 func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
 	fields := make(map[string]struct{})
 	for i := 0; i < t.NumField(); i++ {

+ 1 - 1
relay/audio_handler.go

@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError

+ 1 - 2
relay/channel/api_request.go

@@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	}
 
 	resp, err := client.Do(req)
-
 	if err != nil {
-		return nil, err
+		return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
 	}
 	if resp == nil {
 		return nil, errors.New("resp is nil")

+ 1 - 1
relay/claude_handler.go

@@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError

+ 18 - 2
relay/compatible_handler.go

@@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
-			newApiErr := service.RelayErrorHandler(httpResp, false)
+			newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 			return newApiErr
@@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	imageTokens := usage.PromptTokensDetails.ImageTokens
 	audioTokens := usage.PromptTokensDetails.AudioTokens
 	completionTokens := usage.CompletionTokens
+	cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
+
 	modelName := relayInfo.OriginModelName
 
 	tokenName := ctx.GetString("token_name")
@@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	modelRatio := relayInfo.PriceData.ModelRatio
 	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
 	modelPrice := relayInfo.PriceData.ModelPrice
+	cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
 
 	// Convert values to decimal for precise calculation
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	dImageTokens := decimal.NewFromInt(int64(imageTokens))
 	dAudioTokens := decimal.NewFromInt(int64(audioTokens))
 	dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+	dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
 	dCompletionRatio := decimal.NewFromFloat(completionRatio)
 	dCacheRatio := decimal.NewFromFloat(cacheRatio)
 	dImageRatio := decimal.NewFromFloat(imageRatio)
 	dModelRatio := decimal.NewFromFloat(modelRatio)
 	dGroupRatio := decimal.NewFromFloat(groupRatio)
 	dModelPrice := decimal.NewFromFloat(modelPrice)
+	dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
 	dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
 
 	ratio := dModelRatio.Mul(dGroupRatio)
@@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 			baseTokens = baseTokens.Sub(dCacheTokens)
 			cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
 		}
+		var dCachedCreationTokensWithRatio decimal.Decimal
+		if !dCachedCreationTokens.IsZero() {
+			baseTokens = baseTokens.Sub(dCachedCreationTokens)
+			dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
+		}
 
 		// 减去 image tokens
 		var imageTokensWithRatio decimal.Decimal
@@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 				extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
 			}
 		}
-		promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+		promptQuota := baseTokens.Add(cachedTokensWithRatio).
+			Add(imageTokensWithRatio).
+			Add(dCachedCreationTokensWithRatio)
 
 		completionQuota := dCompletionTokens.Mul(dCompletionRatio)
 
@@ -395,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 		other["image_ratio"] = imageRatio
 		other["image_output"] = imageTokens
 	}
+	if cachedCreationTokens != 0 {
+		other["cache_creation_tokens"] = cachedCreationTokens
+		other["cache_creation_ratio"] = cachedCreationRatio
+	}
 	if !dWebSearchQuota.IsZero() {
 		if relayInfo.ResponsesUsageInfo != nil {
 			if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {

+ 1 - 1
relay/embedding_handler.go

@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError

+ 2 - 2
relay/gemini_handler.go

@@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
@@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
 		}

+ 2 - 2
relay/image_handler.go

@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 		httpResp = resp.(*http.Response)
 		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError
@@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	var logContent string
 
 	if len(request.Size) > 0 {
-		logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
+		logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
 	}
 
 	postConsumeQuota(c, info, usage.(*dto.Usage), logContent)

+ 1 - 1
relay/rerank_handler.go

@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError

+ 1 - 1
relay/responses_handler.go

@@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		httpResp = resp.(*http.Response)
 
 		if httpResp.StatusCode != http.StatusOK {
-			newAPIError = service.RelayErrorHandler(httpResp, false)
+			newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
 			// reset status code 重置状态码
 			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 			return newAPIError

+ 4 - 2
service/error.go

@@ -1,12 +1,14 @@
 package service
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/types"
 	"strconv"
 	"strings"
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
 	return claudeErr
 }
 
-func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
+func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
 	newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
 
 	responseBody, err := io.ReadAll(resp.Body)
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 			newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
 		} else {
 			if common.DebugEnabled {
-				println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+				logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
 			}
 			newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
 		}

+ 11 - 11
service/pre_consume_quota.go

@@ -13,13 +13,13 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
-	if preConsumedQuota != 0 {
-		logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if relayInfo.FinalPreConsumedQuota != 0 {
+		logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
 		gopool.Go(func() {
 			relayInfoCopy := *relayInfo
 
-			err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+			err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
 			if err != nil {
 				common.SysLog("error return pre-consumed quota: " + err.Error())
 			}
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
 
 // PreConsumeQuota checks if the user has enough quota to pre-consume.
 // It returns the pre-consumed quota if successful, or an error if not.
-func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
 	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
 	if err != nil {
-		return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+		return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
 	}
 	if userQuota <= 0 {
-		return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 	}
 	if userQuota-preConsumedQuota < 0 {
-		return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 	}
 
 	trustQuota := common.GetTrustQuota()
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	if preConsumedQuota > 0 {
 		err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
 		if err != nil {
-			return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+			return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
 		}
 		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
 		if err != nil {
-			return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+			return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
 		}
 		logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
 	}
 	relayInfo.FinalPreConsumedQuota = preConsumedQuota
-	return preConsumedQuota, nil
+	return nil
 }

+ 32 - 2
types/error.go

@@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
 type NewAPIErrorOptions func(*NewAPIError)
 
 func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
+	var newErr *NewAPIError
+	// 保留深层传递的 new err
+	if errors.As(err, &newErr) {
+		for _, op := range ops {
+			op(newErr)
+		}
+		return newErr
+	}
 	e := &NewAPIError{
 		Err:        err,
 		RelayError: nil,
@@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
 }
 
 func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
-	if errorCode == ErrorCodeDoRequestFailed {
-		err = errors.New("upstream error: do request failed")
+	var newErr *NewAPIError
+	// 保留深层传递的 new err
+	if errors.As(err, &newErr) {
+		if newErr.RelayError == nil {
+			openaiError := OpenAIError{
+				Message: newErr.Error(),
+				Type:    string(errorCode),
+				Code:    errorCode,
+			}
+			newErr.RelayError = openaiError
+		}
+		for _, op := range ops {
+			op(newErr)
+		}
+		return newErr
 	}
 	openaiError := OpenAIError{
 		Message: err.Error(),
@@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
 	}
 }
 
+func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
+	return func(e *NewAPIError) {
+		if common.DebugEnabled {
+			fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
+		}
+		e.Err = errors.New(replaceStr)
+	}
+}
+
 func IsRecordErrorLog(e *NewAPIError) bool {
 	if e == nil {
 		return false