package service import ( "fmt" "strings" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/model_setting" "github.com/QuantumNous/new-api/types" "github.com/shopspring/decimal" "github.com/gin-gonic/gin" ) const ( ViolationFeeCodePrefix = "violation_fee." CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE" ContentViolatesUsageMarker = "Content violates usage guidelines" ) func IsViolationFeeCode(code types.ErrorCode) bool { return strings.HasPrefix(string(code), ViolationFeeCodePrefix) } func HasCSAMViolationMarker(err *types.NewAPIError) bool { if err == nil { return false } if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) { return true } msg := err.ToOpenAIError().Message return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) } func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError { if err == nil { return nil } oai := err.ToOpenAIError() oai.Type = string(types.ErrorCodeViolationFeeGrokCSAM) oai.Code = string(types.ErrorCodeViolationFeeGrokCSAM) return types.WithOpenAIError(oai, err.StatusCode, types.ErrOptionWithSkipRetry()) } // NormalizeViolationFeeError ensures: // - if the CSAM marker is present, error.code is set to a stable violation-fee code and skip-retry is enabled. // - if error.code already has the violation-fee prefix, skip-retry is enabled. // // It must be called before retry decision logic. func NormalizeViolationFeeError(err *types.NewAPIError) *types.NewAPIError { if err == nil { return nil } if HasCSAMViolationMarker(err) { return WrapAsViolationFeeGrokCSAM(err) } if IsViolationFeeCode(err.GetErrorCode()) { oai := err.ToOpenAIError() return types.WithOpenAIError(oai, err.StatusCode, types.ErrOptionWithSkipRetry()) } return err } func shouldChargeViolationFee(err *types.NewAPIError) bool { if err == nil { return false } if err.GetErrorCode() == types.ErrorCodeViolationFeeGrokCSAM { return true } // In case some callers didn't normalize, keep a safety net. return HasCSAMViolationMarker(err) } func calcViolationFeeQuota(amount, groupRatio float64) int { if amount <= 0 { return 0 } if groupRatio <= 0 { return 0 } quota := decimal.NewFromFloat(amount). Mul(decimal.NewFromFloat(common.QuotaPerUnit)). Mul(decimal.NewFromFloat(groupRatio)). Round(0). IntPart() if quota <= 0 { return 0 } return int(quota) } // ChargeViolationFeeIfNeeded charges an additional fee after the normal flow finishes (including refund). // It uses Grok fee settings as the fee policy. func ChargeViolationFeeIfNeeded(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, apiErr *types.NewAPIError) bool { if ctx == nil || relayInfo == nil || apiErr == nil { return false } //if relayInfo.IsPlayground { // return false //} if !shouldChargeViolationFee(apiErr) { return false } settings := model_setting.GetGrokSettings() if settings == nil || !settings.ViolationDeductionEnabled { return false } groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio feeQuota := calcViolationFeeQuota(settings.ViolationDeductionAmount, groupRatio) if feeQuota <= 0 { return false } if err := PostConsumeQuota(relayInfo, feeQuota, 0, true); err != nil { logger.LogError(ctx, fmt.Sprintf("failed to charge violation fee: %s", err.Error())) return false } model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, feeQuota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, feeQuota) useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() tokenName := ctx.GetString("token_name") oai := apiErr.ToOpenAIError() other := map[string]any{ "violation_fee": true, "violation_fee_code": string(types.ErrorCodeViolationFeeGrokCSAM), "fee_quota": feeQuota, "base_amount": settings.ViolationDeductionAmount, "group_ratio": groupRatio, "status_code": apiErr.StatusCode, "upstream_error_type": oai.Type, "upstream_error_code": fmt.Sprintf("%v", oai.Code), "violation_fee_marker": CSAMViolationMarker, } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, ModelName: relayInfo.OriginModelName, TokenName: tokenName, Quota: feeQuota, Content: "Violation fee charged", TokenId: relayInfo.TokenId, UseTimeSeconds: int(useTimeSeconds), IsStream: relayInfo.IsStream, Group: relayInfo.UsingGroup, Other: other, }) return true }