Преглед изворни кода

refactor: Introduce pre-consume quota and unify relay handlers

This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:
- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.

- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.

- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.

- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
CaIon пре 6 месеци
родитељ
комит
e2037ad756
100 измењених фајлова са 2977 додато и 2520 уклоњено
  1. 2 2
      common/limiter/limiter.go
  2. 0 99
      common/logger.go
  3. 2 0
      constant/context_key.go
  4. 3 2
      controller/channel-billing.go
  5. 6 5
      controller/channel-test.go
  6. 91 91
      controller/console_migrate.go
  7. 3 2
      controller/github.go
  8. 19 18
      controller/midjourney.go
  9. 6 5
      controller/oidc.go
  10. 453 453
      controller/ratio_sync.go
  11. 156 166
      controller/relay.go
  12. 20 19
      controller/task.go
  13. 10 10
      controller/task_video.go
  14. 2 1
      controller/token.go
  15. 2 1
      controller/topup.go
  16. 7 6
      controller/twofa.go
  17. 5 4
      controller/user.go
  18. 18 0
      dto/audio.go
  19. 127 1
      dto/claude.go
  20. 26 2
      dto/embedding.go
  21. 62 3
      dto/gemini.go
  22. 44 2
      dto/openai_image.go
  23. 271 24
      dto/openai_request.go
  24. 11 0
      dto/request_common.go
  25. 27 0
      dto/rerank.go
  26. 115 0
      logger/logger.go
  27. 19 18
      main.go
  28. 3 3
      middleware/recover.go
  29. 3 2
      middleware/turnstile-check.go
  30. 3 2
      middleware/utils.go
  31. 5 4
      model/ability.go
  32. 14 13
      model/channel.go
  33. 3 2
      model/channel_cache.go
  34. 6 6
      model/log.go
  35. 16 15
      model/main.go
  36. 3 2
      model/option.go
  37. 2 1
      model/pricing.go
  38. 2 1
      model/redemption.go
  39. 10 9
      model/token.go
  40. 2 1
      model/topup.go
  41. 5 4
      model/twofa.go
  42. 5 4
      model/usedata.go
  43. 17 16
      model/user.go
  44. 3 2
      model/user_cache.go
  45. 5 4
      model/utils.go
  46. 13 80
      relay/audio_handler.go
  47. 6 6
      relay/channel/ali/image.go
  48. 2 2
      relay/channel/ali/rerank.go
  49. 7 5
      relay/channel/ali/text.go
  50. 2 1
      relay/channel/api_request.go
  51. 6 5
      relay/channel/baidu/relay-baidu.go
  52. 8 7
      relay/channel/claude/relay-claude.go
  53. 8 8
      relay/channel/cloudflare/relay_cloudflare.go
  54. 5 4
      relay/channel/cohere/relay-cohere.go
  55. 7 6
      relay/channel/coze/relay-coze.go
  56. 13 12
      relay/channel/dify/relay-dify.go
  57. 1 1
      relay/channel/gemini/adaptor.go
  58. 7 6
      relay/channel/gemini/relay-gemini-native.go
  59. 9 8
      relay/channel/gemini/relay-gemini.go
  60. 2 2
      relay/channel/jimeng/image.go
  61. 2 2
      relay/channel/jimeng/sign.go
  62. 3 2
      relay/channel/mokaai/relay-mokaai.go
  63. 2 2
      relay/channel/ollama/relay-ollama.go
  64. 9 8
      relay/channel/openai/helper.go
  65. 21 20
      relay/channel/openai/relay-openai.go
  66. 4 3
      relay/channel/openai/relay_responses.go
  67. 8 7
      relay/channel/palm/relay-palm.go
  68. 3 3
      relay/channel/siliconflow/relay-siliconflow.go
  69. 2 1
      relay/channel/task/suno/adaptor.go
  70. 7 6
      relay/channel/tencent/relay-tencent.go
  71. 6 5
      relay/channel/xai/text.go
  72. 6 5
      relay/channel/xunfei/relay-xunfei.go
  73. 8 6
      relay/channel/zhipu/relay-zhipu.go
  74. 5 4
      relay/chat_handler.go
  75. 19 72
      relay/claude_handler.go
  76. 177 144
      relay/common/relay_info.go
  77. 2 1
      relay/common_handler/rerank.go
  78. 13 56
      relay/embedding_handler.go
  79. 44 159
      relay/gemini_handler.go
  80. 3 2
      relay/helper/common.go
  81. 9 8
      relay/helper/model_mapped.go
  82. 33 57
      relay/helper/price.go
  83. 12 11
      relay/helper/stream_scanner.go
  84. 301 0
      relay/helper/valid_request.go
  85. 27 167
      relay/image_handler.go
  86. 44 256
      relay/relay-text.go
  87. 2 1
      relay/relay_task.go
  88. 15 44
      relay/rerank_handler.go
  89. 17 77
      relay/responses_handler.go
  90. 0 7
      relay/websocket.go
  91. 2 1
      router/main.go
  92. 75 19
      router/relay-router.go
  93. 3 3
      service/cf_worker.go
  94. 4 3
      service/error.go
  95. 5 3
      service/http.go
  96. 5 5
      service/image.go
  97. 3 2
      service/midjourney.go
  98. 72 0
      service/pre_consume_quota.go
  99. 43 47
      service/quota.go
  100. 251 123
      service/token_counter.go

+ 2 - 2
common/limiter/limiter.go

@@ -5,7 +5,7 @@ import (
 	_ "embed"
 	_ "embed"
 	"fmt"
 	"fmt"
 	"github.com/go-redis/redis/v8"
 	"github.com/go-redis/redis/v8"
-	"one-api/common"
+	"one-api/logger"
 	"sync"
 	"sync"
 )
 )
 
 
@@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter {
 		// 预加载脚本
 		// 预加载脚本
 		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
 		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
 		if err != nil {
 		if err != nil {
-			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
+			logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
 		}
 		}
 		instance = &RedisLimiter{
 		instance = &RedisLimiter{
 			client:         r,
 			client:         r,

+ 0 - 99
common/logger.go

@@ -1,52 +1,12 @@
 package common
 package common
 
 
 import (
 import (
-	"context"
-	"encoding/json"
 	"fmt"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
-	"io"
-	"log"
 	"os"
 	"os"
-	"path/filepath"
-	"sync"
 	"time"
 	"time"
 )
 )
 
 
-const (
-	loggerINFO  = "INFO"
-	loggerWarn  = "WARN"
-	loggerError = "ERR"
-)
-
-const maxLogCount = 1000000
-
-var logCount int
-var setupLogLock sync.Mutex
-var setupLogWorking bool
-
-func SetupLogger() {
-	if *LogDir != "" {
-		ok := setupLogLock.TryLock()
-		if !ok {
-			log.Println("setup log is already working")
-			return
-		}
-		defer func() {
-			setupLogLock.Unlock()
-			setupLogWorking = false
-		}()
-		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
-		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-		if err != nil {
-			log.Fatal("failed to open log file")
-		}
-		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
-		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
-	}
-}
-
 func SysLog(s string) {
 func SysLog(s string) {
 	t := time.Now()
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 	_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
@@ -57,67 +17,8 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 }
 
 
-func LogInfo(ctx context.Context, msg string) {
-	logHelper(ctx, loggerINFO, msg)
-}
-
-func LogWarn(ctx context.Context, msg string) {
-	logHelper(ctx, loggerWarn, msg)
-}
-
-func LogError(ctx context.Context, msg string) {
-	logHelper(ctx, loggerError, msg)
-}
-
-func logHelper(ctx context.Context, level string, msg string) {
-	writer := gin.DefaultErrorWriter
-	if level == loggerINFO {
-		writer = gin.DefaultWriter
-	}
-	id := ctx.Value(RequestIdKey)
-	if id == nil {
-		id = "SYSTEM"
-	}
-	now := time.Now()
-	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
-	logCount++ // we don't need accurate count, so no lock here
-	if logCount > maxLogCount && !setupLogWorking {
-		logCount = 0
-		setupLogWorking = true
-		gopool.Go(func() {
-			SetupLogger()
-		})
-	}
-}
-
 func FatalLog(v ...any) {
 func FatalLog(v ...any) {
 	t := time.Now()
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
 	os.Exit(1)
 	os.Exit(1)
 }
 }
-
-func LogQuota(quota int) string {
-	if DisplayInCurrencyEnabled {
-		return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
-	} else {
-		return fmt.Sprintf("%d 点额度", quota)
-	}
-}
-
-func FormatQuota(quota int) string {
-	if DisplayInCurrencyEnabled {
-		return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
-	} else {
-		return fmt.Sprintf("%d", quota)
-	}
-}
-
-// LogJson 仅供测试使用 only for test
-func LogJson(ctx context.Context, msg string, obj any) {
-	jsonStr, err := json.Marshal(obj)
-	if err != nil {
-		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
-		return
-	}
-	LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
-}

+ 2 - 0
constant/context_key.go

@@ -3,6 +3,8 @@ package constant
 type ContextKey string
 type ContextKey string
 
 
 const (
 const (
+	ContextKeyPromptTokens ContextKey = "prompt_tokens"
+
 	ContextKeyOriginalModel    ContextKey = "original_model"
 	ContextKeyOriginalModel    ContextKey = "original_model"
 	ContextKeyRequestStartTime ContextKey = "request_start_time"
 	ContextKeyRequestStartTime ContextKey = "request_start_time"
 
 

+ 3 - 2
controller/channel-billing.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting"
@@ -485,8 +486,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
 func AutomaticallyUpdateChannels(frequency int) {
 func AutomaticallyUpdateChannels(frequency int) {
 	for {
 	for {
 		time.Sleep(time.Duration(frequency) * time.Minute)
 		time.Sleep(time.Duration(frequency) * time.Minute)
-		common.SysLog("updating all channels")
+		logger.SysLog("updating all channels")
 		_ = updateAllChannelsBalance()
 		_ = updateAllChannelsBalance()
-		common.SysLog("channels update done")
+		logger.SysLog("channels update done")
 	}
 	}
 }
 }

+ 6 - 5
controller/channel-test.go

@@ -13,6 +13,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/middleware"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay"
@@ -159,7 +160,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	// 创建一个用于日志的 info 副本,移除 ApiKey
 	// 创建一个用于日志的 info 副本,移除 ApiKey
 	logInfo := *info
 	logInfo := *info
 	logInfo.ApiKey = ""
 	logInfo.ApiKey = ""
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
+	logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
 
 
 	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
 	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
 	if err != nil {
 	if err != nil {
@@ -279,7 +280,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		Group:            info.UsingGroup,
 		Group:            info.UsingGroup,
 		Other:            other,
 		Other:            other,
 	})
 	})
-	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return testResult{
 	return testResult{
 		context:     c,
 		context:     c,
 		localErr:    nil,
 		localErr:    nil,
@@ -461,13 +462,13 @@ func TestAllChannels(c *gin.Context) {
 
 
 func AutomaticallyTestChannels(frequency int) {
 func AutomaticallyTestChannels(frequency int) {
 	if frequency <= 0 {
 	if frequency <= 0 {
-		common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+		logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
 		return
 		return
 	}
 	}
 	for {
 	for {
 		time.Sleep(time.Duration(frequency) * time.Minute)
 		time.Sleep(time.Duration(frequency) * time.Minute)
-		common.SysLog("testing all channels")
+		logger.SysLog("testing all channels")
 		_ = testAllChannels(false)
 		_ = testAllChannels(false)
-		common.SysLog("channel test finished")
+		logger.SysLog("channel test finished")
 	}
 	}
 }
 }

+ 91 - 91
controller/console_migrate.go

@@ -3,101 +3,101 @@
 package controller
 package controller
 
 
 import (
 import (
-    "encoding/json"
-    "net/http"
-    "one-api/common"
-    "one-api/model"
-    "github.com/gin-gonic/gin"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/logger"
+	"one-api/model"
 )
 )
 
 
 // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
 // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
 func MigrateConsoleSetting(c *gin.Context) {
 func MigrateConsoleSetting(c *gin.Context) {
-    // 读取全部 option
-    opts, err := model.AllOption()
-    if err != nil {
-        c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
-        return
-    }
-    // 建立 map
-    valMap := map[string]string{}
-    for _, o := range opts {
-        valMap[o.Key] = o.Value
-    }
+	// 读取全部 option
+	opts, err := model.AllOption()
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+		return
+	}
+	// 建立 map
+	valMap := map[string]string{}
+	for _, o := range opts {
+		valMap[o.Key] = o.Value
+	}
 
 
-    // 处理 APIInfo
-    if v := valMap["ApiInfo"]; v != "" {
-        var arr []map[string]interface{}
-        if err := json.Unmarshal([]byte(v), &arr); err == nil {
-            if len(arr) > 50 {
-                arr = arr[:50]
-            }
-            bytes, _ := json.Marshal(arr)
-            model.UpdateOption("console_setting.api_info", string(bytes))
-        }
-        model.UpdateOption("ApiInfo", "")
-    }
-    // Announcements 直接搬
-    if v := valMap["Announcements"]; v != "" {
-        model.UpdateOption("console_setting.announcements", v)
-        model.UpdateOption("Announcements", "")
-    }
-    // FAQ 转换
-    if v := valMap["FAQ"]; v != "" {
-        var arr []map[string]interface{}
-        if err := json.Unmarshal([]byte(v), &arr); err == nil {
-            out := []map[string]interface{}{}
-            for _, item := range arr {
-                q, _ := item["question"].(string)
-                if q == "" {
-                    q, _ = item["title"].(string)
-                }
-                a, _ := item["answer"].(string)
-                if a == "" {
-                    a, _ = item["content"].(string)
-                }
-                if q != "" && a != "" {
-                    out = append(out, map[string]interface{}{"question": q, "answer": a})
-                }
-            }
-            if len(out) > 50 {
-                out = out[:50]
-            }
-            bytes, _ := json.Marshal(out)
-            model.UpdateOption("console_setting.faq", string(bytes))
-        }
-        model.UpdateOption("FAQ", "")
-    }
-    // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
-    url := valMap["UptimeKumaUrl"]
-    slug := valMap["UptimeKumaSlug"]
-    if url != "" && slug != "" {
-        // 仅当同时存在 URL 与 Slug 时才进行迁移
-        groups := []map[string]interface{}{
-            {
-                "id":           1,
-                "categoryName": "old",
-                "url":          url,
-                "slug":         slug,
-                "description":  "",
-            },
-        }
-        bytes, _ := json.Marshal(groups)
-        model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
-    }
-    // 清空旧键内容
-    if url != "" {
-        model.UpdateOption("UptimeKumaUrl", "")
-    }
-    if slug != "" {
-        model.UpdateOption("UptimeKumaSlug", "")
-    }
+	// 处理 APIInfo
+	if v := valMap["ApiInfo"]; v != "" {
+		var arr []map[string]interface{}
+		if err := json.Unmarshal([]byte(v), &arr); err == nil {
+			if len(arr) > 50 {
+				arr = arr[:50]
+			}
+			bytes, _ := json.Marshal(arr)
+			model.UpdateOption("console_setting.api_info", string(bytes))
+		}
+		model.UpdateOption("ApiInfo", "")
+	}
+	// Announcements 直接搬
+	if v := valMap["Announcements"]; v != "" {
+		model.UpdateOption("console_setting.announcements", v)
+		model.UpdateOption("Announcements", "")
+	}
+	// FAQ 转换
+	if v := valMap["FAQ"]; v != "" {
+		var arr []map[string]interface{}
+		if err := json.Unmarshal([]byte(v), &arr); err == nil {
+			out := []map[string]interface{}{}
+			for _, item := range arr {
+				q, _ := item["question"].(string)
+				if q == "" {
+					q, _ = item["title"].(string)
+				}
+				a, _ := item["answer"].(string)
+				if a == "" {
+					a, _ = item["content"].(string)
+				}
+				if q != "" && a != "" {
+					out = append(out, map[string]interface{}{"question": q, "answer": a})
+				}
+			}
+			if len(out) > 50 {
+				out = out[:50]
+			}
+			bytes, _ := json.Marshal(out)
+			model.UpdateOption("console_setting.faq", string(bytes))
+		}
+		model.UpdateOption("FAQ", "")
+	}
+	// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+	url := valMap["UptimeKumaUrl"]
+	slug := valMap["UptimeKumaSlug"]
+	if url != "" && slug != "" {
+		// 仅当同时存在 URL 与 Slug 时才进行迁移
+		groups := []map[string]interface{}{
+			{
+				"id":           1,
+				"categoryName": "old",
+				"url":          url,
+				"slug":         slug,
+				"description":  "",
+			},
+		}
+		bytes, _ := json.Marshal(groups)
+		model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+	}
+	// 清空旧键内容
+	if url != "" {
+		model.UpdateOption("UptimeKumaUrl", "")
+	}
+	if slug != "" {
+		model.UpdateOption("UptimeKumaSlug", "")
+	}
 
 
-    // 删除旧键记录
-    oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
-    model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+	// 删除旧键记录
+	oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+	model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
 
 
-    // 重新加载 OptionMap
-    model.InitOptionMap()
-    common.SysLog("console setting migrated")
-    c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
-} 
+	// 重新加载 OptionMap
+	model.InitOptionMap()
+	logger.SysLog("console setting migrated")
+	c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}

+ 3 - 2
controller/github.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 	"time"
 	"time"
@@ -47,7 +48,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 	}
 	}
 	res, err := client.Do(req)
 	res, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysLog(err.Error())
+		logger.SysLog(err.Error())
 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
 	}
 	}
 	defer res.Body.Close()
 	defer res.Body.Close()
@@ -63,7 +64,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
 	res2, err := client.Do(req)
 	res2, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysLog(err.Error())
+		logger.SysLog(err.Error())
 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
 		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
 	}
 	}
 	defer res2.Body.Close()
 	defer res2.Body.Close()

+ 19 - 18
controller/midjourney.go

@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting"
@@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
 			continue
 			continue
 		}
 		}
 
 
-		common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+		logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
 		taskChannelM := make(map[int][]string)
 		taskChannelM := make(map[int][]string)
 		taskM := make(map[string]*model.Midjourney)
 		taskM := make(map[string]*model.Midjourney)
 		nullTaskIds := make([]int, 0)
 		nullTaskIds := make([]int, 0)
@@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
 				"progress": "100%",
 				"progress": "100%",
 			})
 			})
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
 			} else {
 			} else {
-				common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+				logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
 			}
 			}
 		}
 		}
 		if len(taskChannelM) == 0 {
 		if len(taskChannelM) == 0 {
@@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
 		}
 		}
 
 
 		for channelId, taskIds := range taskChannelM {
 		for channelId, taskIds := range taskChannelM {
-			common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+			logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
 			if len(taskIds) == 0 {
 			if len(taskIds) == 0 {
 				continue
 				continue
 			}
 			}
 			midjourneyChannel, err := model.CacheGetChannel(channelId)
 			midjourneyChannel, err := model.CacheGetChannel(channelId)
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
 				err := model.MjBulkUpdate(taskIds, map[string]any{
 				err := model.MjBulkUpdate(taskIds, map[string]any{
 					"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
 					"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
 					"status":      "FAILURE",
 					"status":      "FAILURE",
 					"progress":    "100%",
 					"progress":    "100%",
 				})
 				})
 				if err != nil {
 				if err != nil {
-					common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+					logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
 				}
 				}
 				continue
 				continue
 			}
 			}
@@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
 			})
 			})
 			req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
 			req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
 				continue
 				continue
 			}
 			}
 			// 设置超时时间
 			// 设置超时时间
@@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
 			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 			req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 			resp, err := service.GetHttpClient().Do(req)
 			resp, err := service.GetHttpClient().Do(req)
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 				continue
 				continue
 			}
 			}
 			if resp.StatusCode != http.StatusOK {
 			if resp.StatusCode != http.StatusOK {
-				common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+				logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 				continue
 				continue
 			}
 			}
 			responseBody, err := io.ReadAll(resp.Body)
 			responseBody, err := io.ReadAll(resp.Body)
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
 				continue
 				continue
 			}
 			}
 			var responseItems []dto.MidjourneyDto
 			var responseItems []dto.MidjourneyDto
 			err = json.Unmarshal(responseBody, &responseItems)
 			err = json.Unmarshal(responseBody, &responseItems)
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+				logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
 				continue
 				continue
 			}
 			}
 			resp.Body.Close()
 			resp.Body.Close()
@@ -147,12 +148,12 @@ func UpdateMidjourneyTaskBulk() {
 				}
 				}
 				// 映射 VideoUrl
 				// 映射 VideoUrl
 				task.VideoUrl = responseItem.VideoUrl
 				task.VideoUrl = responseItem.VideoUrl
-				
+
 				// 映射 VideoUrls - 将数组序列化为 JSON 字符串
 				// 映射 VideoUrls - 将数组序列化为 JSON 字符串
 				if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
 				if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
 					videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
 					videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
 					if err != nil {
 					if err != nil {
-						common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
+						logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
 						task.VideoUrls = "[]" // 失败时设置为空数组
 						task.VideoUrls = "[]" // 失败时设置为空数组
 					} else {
 					} else {
 						task.VideoUrls = string(videoUrlsStr)
 						task.VideoUrls = string(videoUrlsStr)
@@ -160,10 +161,10 @@ func UpdateMidjourneyTaskBulk() {
 				} else {
 				} else {
 					task.VideoUrls = "" // 空值时清空字段
 					task.VideoUrls = "" // 空值时清空字段
 				}
 				}
-				
+
 				shouldReturnQuota := false
 				shouldReturnQuota := false
 				if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
 				if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
-					common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+					logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
 					task.Progress = "100%"
 					task.Progress = "100%"
 					if task.Quota != 0 {
 					if task.Quota != 0 {
 						shouldReturnQuota = true
 						shouldReturnQuota = true
@@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
 				}
 				}
 				err = task.Update()
 				err = task.Update()
 				if err != nil {
 				if err != nil {
-					common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+					logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
 				} else {
 				} else {
 					if shouldReturnQuota {
 					if shouldReturnQuota {
 						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
 						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
 						if err != nil {
 						if err != nil {
-							common.LogError(ctx, "fail to increase user quota: "+err.Error())
+							logger.LogError(ctx, "fail to increase user quota: "+err.Error())
 						}
 						}
-						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
 						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 					}
 					}
 				}
 				}

+ 6 - 5
controller/oidc.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/system_setting"
 	"one-api/setting/system_setting"
@@ -58,7 +59,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 	}
 	}
 	res, err := client.Do(req)
 	res, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysLog(err.Error())
+		logger.SysLog(err.Error())
 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 	}
 	}
 	defer res.Body.Close()
 	defer res.Body.Close()
@@ -69,7 +70,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 	}
 	}
 
 
 	if oidcResponse.AccessToken == "" {
 	if oidcResponse.AccessToken == "" {
-		common.SysError("OIDC 获取 Token 失败,请检查设置!")
+		logger.SysError("OIDC 获取 Token 失败,请检查设置!")
 		return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
 		return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
 	}
 	}
 
 
@@ -80,12 +81,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
 	req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
 	res2, err := client.Do(req)
 	res2, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysLog(err.Error())
+		logger.SysLog(err.Error())
 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 		return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
 	}
 	}
 	defer res2.Body.Close()
 	defer res2.Body.Close()
 	if res2.StatusCode != http.StatusOK {
 	if res2.StatusCode != http.StatusOK {
-		common.SysError("OIDC 获取用户信息失败!请检查设置!")
+		logger.SysError("OIDC 获取用户信息失败!请检查设置!")
 		return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
 		return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
 	}
 	}
 
 
@@ -95,7 +96,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 	if oidcUser.OpenID == "" || oidcUser.Email == "" {
 	if oidcUser.OpenID == "" || oidcUser.Email == "" {
-		common.SysError("OIDC 获取用户信息为空!请检查设置!")
+		logger.SysError("OIDC 获取用户信息为空!请检查设置!")
 		return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
 		return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
 	}
 	}
 	return &oidcUser, nil
 	return &oidcUser, nil

+ 453 - 453
controller/ratio_sync.go

@@ -1,474 +1,474 @@
 package controller
 package controller
 
 
 import (
 import (
-    "context"
-    "encoding/json"
-    "fmt"
-    "net/http"
-    "strings"
-    "sync"
-    "time"
-
-    "one-api/common"
-    "one-api/dto"
-    "one-api/model"
-    "one-api/setting/ratio_setting"
-
-    "github.com/gin-gonic/gin"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"one-api/logger"
+	"strings"
+	"sync"
+	"time"
+
+	"one-api/dto"
+	"one-api/model"
+	"one-api/setting/ratio_setting"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 const (
 const (
-    defaultTimeoutSeconds  = 10
-    defaultEndpoint        = "/api/ratio_config"
-    maxConcurrentFetches   = 8
+	defaultTimeoutSeconds = 10
+	defaultEndpoint       = "/api/ratio_config"
+	maxConcurrentFetches  = 8
 )
 )
 
 
 var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
 var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
 
 
 type upstreamResult struct {
 type upstreamResult struct {
-    Name string                 `json:"name"`
-    Data map[string]any         `json:"data,omitempty"`
-    Err  string                 `json:"err,omitempty"`
+	Name string         `json:"name"`
+	Data map[string]any `json:"data,omitempty"`
+	Err  string         `json:"err,omitempty"`
 }
 }
 
 
 func FetchUpstreamRatios(c *gin.Context) {
 func FetchUpstreamRatios(c *gin.Context) {
-    var req dto.UpstreamRequest
-    if err := c.ShouldBindJSON(&req); err != nil {
-        c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
-        return
-    }
-
-    if req.Timeout <= 0 {
-        req.Timeout = defaultTimeoutSeconds
-    }
-
-    var upstreams []dto.UpstreamDTO
-
-    if len(req.Upstreams) > 0 {
-        for _, u := range req.Upstreams {
-            if strings.HasPrefix(u.BaseURL, "http") {
-                if u.Endpoint == "" {
-                    u.Endpoint = defaultEndpoint
-                }
-                u.BaseURL = strings.TrimRight(u.BaseURL, "/")
-                upstreams = append(upstreams, u)
-            }
-        }
-    } else if len(req.ChannelIDs) > 0 {
-        intIds := make([]int, 0, len(req.ChannelIDs))
-        for _, id64 := range req.ChannelIDs {
-            intIds = append(intIds, int(id64))
-        }
-        dbChannels, err := model.GetChannelsByIds(intIds)
-        if err != nil {
-            common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
-            c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
-            return
-        }
-        for _, ch := range dbChannels {
-            if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
-                upstreams = append(upstreams, dto.UpstreamDTO{
-                    ID:       ch.Id,
-                    Name:     ch.Name,
-                    BaseURL:  strings.TrimRight(base, "/"),
-                    Endpoint: "",
-                })
-            }
-        }
-    }
-
-    if len(upstreams) == 0 {
-        c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
-        return
-    }
-
-    var wg sync.WaitGroup
-    ch := make(chan upstreamResult, len(upstreams))
-
-    sem := make(chan struct{}, maxConcurrentFetches)
-
-    client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
-
-    for _, chn := range upstreams {
-        wg.Add(1)
-        go func(chItem dto.UpstreamDTO) {
-            defer wg.Done()
-
-            sem <- struct{}{}
-            defer func() { <-sem }()
-
-            endpoint := chItem.Endpoint
-            if endpoint == "" {
-                endpoint = defaultEndpoint
-            } else if !strings.HasPrefix(endpoint, "/") {
-                endpoint = "/" + endpoint
-            }
-            fullURL := chItem.BaseURL + endpoint
-
-            uniqueName := chItem.Name
-            if chItem.ID != 0 {
-                uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
-            }
-
-            ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
-            defer cancel()
-
-            httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
-            if err != nil {
-                common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
-                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
-                return
-            }
-
-            resp, err := client.Do(httpReq)
-            if err != nil {
-                common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
-                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
-                return
-            }
-            defer resp.Body.Close()
-            if resp.StatusCode != http.StatusOK {
-                common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
-                ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
-                return
-            }
-            // 兼容两种上游接口格式:
-            //  type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
-            //  type2: /api/pricing      -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
-            var body struct {
-                Success bool            `json:"success"`
-                Data    json.RawMessage `json:"data"`
-                Message string          `json:"message"`
-            }
-
-            if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
-                common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
-                ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
-                return
-            }
-
-            if !body.Success {
-                ch <- upstreamResult{Name: uniqueName, Err: body.Message}
-                return
-            }
-
-            // 尝试按 type1 解析
-            var type1Data map[string]any
-            if err := json.Unmarshal(body.Data, &type1Data); err == nil {
-                // 如果包含至少一个 ratioTypes 字段,则认为是 type1
-                isType1 := false
-                for _, rt := range ratioTypes {
-                    if _, ok := type1Data[rt]; ok {
-                        isType1 = true
-                        break
-                    }
-                }
-                if isType1 {
-                    ch <- upstreamResult{Name: uniqueName, Data: type1Data}
-                    return
-                }
-            }
-
-            // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
-            var pricingItems []struct {
-                ModelName       string  `json:"model_name"`
-                QuotaType       int     `json:"quota_type"`
-                ModelRatio      float64 `json:"model_ratio"`
-                ModelPrice      float64 `json:"model_price"`
-                CompletionRatio float64 `json:"completion_ratio"`
-            }
-            if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
-                common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
-                ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
-                return
-            }
-
-            modelRatioMap := make(map[string]float64)
-            completionRatioMap := make(map[string]float64)
-            modelPriceMap := make(map[string]float64)
-
-            for _, item := range pricingItems {
-                if item.QuotaType == 1 {
-                    modelPriceMap[item.ModelName] = item.ModelPrice
-                } else {
-                    modelRatioMap[item.ModelName] = item.ModelRatio
-                    // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
-                    completionRatioMap[item.ModelName] = item.CompletionRatio
-                }
-            }
-
-            converted := make(map[string]any)
-
-            if len(modelRatioMap) > 0 {
-                ratioAny := make(map[string]any, len(modelRatioMap))
-                for k, v := range modelRatioMap {
-                    ratioAny[k] = v
-                }
-                converted["model_ratio"] = ratioAny
-            }
-
-            if len(completionRatioMap) > 0 {
-                compAny := make(map[string]any, len(completionRatioMap))
-                for k, v := range completionRatioMap {
-                    compAny[k] = v
-                }
-                converted["completion_ratio"] = compAny
-            }
-
-            if len(modelPriceMap) > 0 {
-                priceAny := make(map[string]any, len(modelPriceMap))
-                for k, v := range modelPriceMap {
-                    priceAny[k] = v
-                }
-                converted["model_price"] = priceAny
-            }
-
-            ch <- upstreamResult{Name: uniqueName, Data: converted}
-        }(chn)
-    }
-
-    wg.Wait()
-    close(ch)
-
-    localData := ratio_setting.GetExposedData()
-
-    var testResults []dto.TestResult
-    var successfulChannels []struct {
-        name string
-        data map[string]any
-    }
-
-    for r := range ch {
-        if r.Err != "" {
-            testResults = append(testResults, dto.TestResult{
-                Name:   r.Name,
-                Status: "error",
-                Error:  r.Err,
-            })
-        } else {
-            testResults = append(testResults, dto.TestResult{
-                Name:   r.Name,
-                Status: "success",
-            })
-            successfulChannels = append(successfulChannels, struct {
-                name string
-                data map[string]any
-            }{name: r.Name, data: r.Data})
-        }
-    }
-
-    differences := buildDifferences(localData, successfulChannels)
-
-    c.JSON(http.StatusOK, gin.H{
-        "success": true,
-        "data": gin.H{
-            "differences":  differences,
-            "test_results": testResults,
-        },
-    })
+	var req dto.UpstreamRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+		return
+	}
+
+	if req.Timeout <= 0 {
+		req.Timeout = defaultTimeoutSeconds
+	}
+
+	var upstreams []dto.UpstreamDTO
+
+	if len(req.Upstreams) > 0 {
+		for _, u := range req.Upstreams {
+			if strings.HasPrefix(u.BaseURL, "http") {
+				if u.Endpoint == "" {
+					u.Endpoint = defaultEndpoint
+				}
+				u.BaseURL = strings.TrimRight(u.BaseURL, "/")
+				upstreams = append(upstreams, u)
+			}
+		}
+	} else if len(req.ChannelIDs) > 0 {
+		intIds := make([]int, 0, len(req.ChannelIDs))
+		for _, id64 := range req.ChannelIDs {
+			intIds = append(intIds, int(id64))
+		}
+		dbChannels, err := model.GetChannelsByIds(intIds)
+		if err != nil {
+			logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+			c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+			return
+		}
+		for _, ch := range dbChannels {
+			if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+				upstreams = append(upstreams, dto.UpstreamDTO{
+					ID:       ch.Id,
+					Name:     ch.Name,
+					BaseURL:  strings.TrimRight(base, "/"),
+					Endpoint: "",
+				})
+			}
+		}
+	}
+
+	if len(upstreams) == 0 {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+		return
+	}
+
+	var wg sync.WaitGroup
+	ch := make(chan upstreamResult, len(upstreams))
+
+	sem := make(chan struct{}, maxConcurrentFetches)
+
+	client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+	for _, chn := range upstreams {
+		wg.Add(1)
+		go func(chItem dto.UpstreamDTO) {
+			defer wg.Done()
+
+			sem <- struct{}{}
+			defer func() { <-sem }()
+
+			endpoint := chItem.Endpoint
+			if endpoint == "" {
+				endpoint = defaultEndpoint
+			} else if !strings.HasPrefix(endpoint, "/") {
+				endpoint = "/" + endpoint
+			}
+			fullURL := chItem.BaseURL + endpoint
+
+			uniqueName := chItem.Name
+			if chItem.ID != 0 {
+				uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
+			}
+
+			ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+			defer cancel()
+
+			httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+			if err != nil {
+				logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+				return
+			}
+
+			resp, err := client.Do(httpReq)
+			if err != nil {
+				logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+				return
+			}
+			defer resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+				ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
+				return
+			}
+			// 兼容两种上游接口格式:
+			//  type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
+			//  type2: /api/pricing      -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
+			var body struct {
+				Success bool            `json:"success"`
+				Data    json.RawMessage `json:"data"`
+				Message string          `json:"message"`
+			}
+
+			if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+				logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+				return
+			}
+
+			if !body.Success {
+				ch <- upstreamResult{Name: uniqueName, Err: body.Message}
+				return
+			}
+
+			// 尝试按 type1 解析
+			var type1Data map[string]any
+			if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+				// 如果包含至少一个 ratioTypes 字段,则认为是 type1
+				isType1 := false
+				for _, rt := range ratioTypes {
+					if _, ok := type1Data[rt]; ok {
+						isType1 = true
+						break
+					}
+				}
+				if isType1 {
+					ch <- upstreamResult{Name: uniqueName, Data: type1Data}
+					return
+				}
+			}
+
+			// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
+			var pricingItems []struct {
+				ModelName       string  `json:"model_name"`
+				QuotaType       int     `json:"quota_type"`
+				ModelRatio      float64 `json:"model_ratio"`
+				ModelPrice      float64 `json:"model_price"`
+				CompletionRatio float64 `json:"completion_ratio"`
+			}
+			if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+				logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
+				ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
+				return
+			}
+
+			modelRatioMap := make(map[string]float64)
+			completionRatioMap := make(map[string]float64)
+			modelPriceMap := make(map[string]float64)
+
+			for _, item := range pricingItems {
+				if item.QuotaType == 1 {
+					modelPriceMap[item.ModelName] = item.ModelPrice
+				} else {
+					modelRatioMap[item.ModelName] = item.ModelRatio
+					// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
+					completionRatioMap[item.ModelName] = item.CompletionRatio
+				}
+			}
+
+			converted := make(map[string]any)
+
+			if len(modelRatioMap) > 0 {
+				ratioAny := make(map[string]any, len(modelRatioMap))
+				for k, v := range modelRatioMap {
+					ratioAny[k] = v
+				}
+				converted["model_ratio"] = ratioAny
+			}
+
+			if len(completionRatioMap) > 0 {
+				compAny := make(map[string]any, len(completionRatioMap))
+				for k, v := range completionRatioMap {
+					compAny[k] = v
+				}
+				converted["completion_ratio"] = compAny
+			}
+
+			if len(modelPriceMap) > 0 {
+				priceAny := make(map[string]any, len(modelPriceMap))
+				for k, v := range modelPriceMap {
+					priceAny[k] = v
+				}
+				converted["model_price"] = priceAny
+			}
+
+			ch <- upstreamResult{Name: uniqueName, Data: converted}
+		}(chn)
+	}
+
+	wg.Wait()
+	close(ch)
+
+	localData := ratio_setting.GetExposedData()
+
+	var testResults []dto.TestResult
+	var successfulChannels []struct {
+		name string
+		data map[string]any
+	}
+
+	for r := range ch {
+		if r.Err != "" {
+			testResults = append(testResults, dto.TestResult{
+				Name:   r.Name,
+				Status: "error",
+				Error:  r.Err,
+			})
+		} else {
+			testResults = append(testResults, dto.TestResult{
+				Name:   r.Name,
+				Status: "success",
+			})
+			successfulChannels = append(successfulChannels, struct {
+				name string
+				data map[string]any
+			}{name: r.Name, data: r.Data})
+		}
+	}
+
+	differences := buildDifferences(localData, successfulChannels)
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"data": gin.H{
+			"differences":  differences,
+			"test_results": testResults,
+		},
+	})
 }
 }
 
 
 func buildDifferences(localData map[string]any, successfulChannels []struct {
 func buildDifferences(localData map[string]any, successfulChannels []struct {
-    name string
-    data map[string]any
+	name string
+	data map[string]any
 }) map[string]map[string]dto.DifferenceItem {
 }) map[string]map[string]dto.DifferenceItem {
-    differences := make(map[string]map[string]dto.DifferenceItem)
-
-    allModels := make(map[string]struct{})
-    
-    for _, ratioType := range ratioTypes {
-        if localRatioAny, ok := localData[ratioType]; ok {
-            if localRatio, ok := localRatioAny.(map[string]float64); ok {
-                for modelName := range localRatio {
-                    allModels[modelName] = struct{}{}
-                }
-            }
-        }
-    }
-    
-    for _, channel := range successfulChannels {
-        for _, ratioType := range ratioTypes {
-            if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
-                for modelName := range upstreamRatio {
-                    allModels[modelName] = struct{}{}
-                }
-            }
-        }
-    }
-
-    confidenceMap := make(map[string]map[string]bool)
-    
-    // 预处理阶段:检查pricing接口的可信度
-    for _, channel := range successfulChannels {
-        confidenceMap[channel.name] = make(map[string]bool)
-        
-        modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
-        completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
-        
-        if hasModelRatio && hasCompletionRatio {
-            // 遍历所有模型,检查是否满足不可信条件
-            for modelName := range allModels {
-                // 默认为可信
-                confidenceMap[channel.name][modelName] = true
-                
-                // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
-                if modelRatioVal, ok := modelRatios[modelName]; ok {
-                    if completionRatioVal, ok := completionRatios[modelName]; ok {
-                        // 转换为float64进行比较
-                        if modelRatioFloat, ok := modelRatioVal.(float64); ok {
-                            if completionRatioFloat, ok := completionRatioVal.(float64); ok {
-                                if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
-                                    confidenceMap[channel.name][modelName] = false
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        } else {
-            // 如果不是从pricing接口获取的数据,则全部标记为可信
-            for modelName := range allModels {
-                confidenceMap[channel.name][modelName] = true
-            }
-        }
-    }
-
-    for modelName := range allModels {
-        for _, ratioType := range ratioTypes {
-            var localValue interface{} = nil
-            if localRatioAny, ok := localData[ratioType]; ok {
-                if localRatio, ok := localRatioAny.(map[string]float64); ok {
-                    if val, exists := localRatio[modelName]; exists {
-                        localValue = val
-                    }
-                }
-            }
-
-            upstreamValues := make(map[string]interface{})
-            confidenceValues := make(map[string]bool)
-            hasUpstreamValue := false
-            hasDifference := false
-
-            for _, channel := range successfulChannels {
-                var upstreamValue interface{} = nil
-                
-                if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
-                    if val, exists := upstreamRatio[modelName]; exists {
-                        upstreamValue = val
-                        hasUpstreamValue = true
-                        
-                        if localValue != nil && localValue != val {
-                            hasDifference = true
-                        } else if localValue == val {
-                            upstreamValue = "same"
-                        }
-                    }
-                }
-                if upstreamValue == nil && localValue == nil {
-                    upstreamValue = "same"
-                }
-                
-                if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
-                    hasDifference = true
-                }
-                
-                upstreamValues[channel.name] = upstreamValue
-                
-                confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
-            }
-
-            shouldInclude := false
-            
-            if localValue != nil {
-                if hasDifference {
-                    shouldInclude = true
-                }
-            } else {
-                if hasUpstreamValue {
-                    shouldInclude = true
-                }
-            }
-
-            if shouldInclude {
-                if differences[modelName] == nil {
-                    differences[modelName] = make(map[string]dto.DifferenceItem)
-                }
-                differences[modelName][ratioType] = dto.DifferenceItem{
-                    Current:   localValue,
-                    Upstreams: upstreamValues,
-                    Confidence: confidenceValues,
-                }
-            }
-        }
-    }
-
-    channelHasDiff := make(map[string]bool)
-    for _, ratioMap := range differences {
-        for _, item := range ratioMap {
-            for chName, val := range item.Upstreams {
-                if val != nil && val != "same" {
-                    channelHasDiff[chName] = true
-                }
-            }
-        }
-    }
-
-    for modelName, ratioMap := range differences {
-        for ratioType, item := range ratioMap {
-            for chName := range item.Upstreams {
-                if !channelHasDiff[chName] {
-                    delete(item.Upstreams, chName)
-                    delete(item.Confidence, chName)
-                }
-            }
-
-            allSame := true
-            for _, v := range item.Upstreams {
-                if v != "same" {
-                    allSame = false
-                    break
-                }
-            }
-            if len(item.Upstreams) == 0 || allSame {
-                delete(ratioMap, ratioType)
-            } else {
-                differences[modelName][ratioType] = item
-            }
-        }
-
-        if len(ratioMap) == 0 {
-            delete(differences, modelName)
-        }
-    }
-
-    return differences
+	differences := make(map[string]map[string]dto.DifferenceItem)
+
+	allModels := make(map[string]struct{})
+
+	for _, ratioType := range ratioTypes {
+		if localRatioAny, ok := localData[ratioType]; ok {
+			if localRatio, ok := localRatioAny.(map[string]float64); ok {
+				for modelName := range localRatio {
+					allModels[modelName] = struct{}{}
+				}
+			}
+		}
+	}
+
+	for _, channel := range successfulChannels {
+		for _, ratioType := range ratioTypes {
+			if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+				for modelName := range upstreamRatio {
+					allModels[modelName] = struct{}{}
+				}
+			}
+		}
+	}
+
+	confidenceMap := make(map[string]map[string]bool)
+
+	// 预处理阶段:检查pricing接口的可信度
+	for _, channel := range successfulChannels {
+		confidenceMap[channel.name] = make(map[string]bool)
+
+		modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
+		completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
+
+		if hasModelRatio && hasCompletionRatio {
+			// 遍历所有模型,检查是否满足不可信条件
+			for modelName := range allModels {
+				// 默认为可信
+				confidenceMap[channel.name][modelName] = true
+
+				// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
+				if modelRatioVal, ok := modelRatios[modelName]; ok {
+					if completionRatioVal, ok := completionRatios[modelName]; ok {
+						// 转换为float64进行比较
+						if modelRatioFloat, ok := modelRatioVal.(float64); ok {
+							if completionRatioFloat, ok := completionRatioVal.(float64); ok {
+								if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
+									confidenceMap[channel.name][modelName] = false
+								}
+							}
+						}
+					}
+				}
+			}
+		} else {
+			// 如果不是从pricing接口获取的数据,则全部标记为可信
+			for modelName := range allModels {
+				confidenceMap[channel.name][modelName] = true
+			}
+		}
+	}
+
+	for modelName := range allModels {
+		for _, ratioType := range ratioTypes {
+			var localValue interface{} = nil
+			if localRatioAny, ok := localData[ratioType]; ok {
+				if localRatio, ok := localRatioAny.(map[string]float64); ok {
+					if val, exists := localRatio[modelName]; exists {
+						localValue = val
+					}
+				}
+			}
+
+			upstreamValues := make(map[string]interface{})
+			confidenceValues := make(map[string]bool)
+			hasUpstreamValue := false
+			hasDifference := false
+
+			for _, channel := range successfulChannels {
+				var upstreamValue interface{} = nil
+
+				if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+					if val, exists := upstreamRatio[modelName]; exists {
+						upstreamValue = val
+						hasUpstreamValue = true
+
+						if localValue != nil && localValue != val {
+							hasDifference = true
+						} else if localValue == val {
+							upstreamValue = "same"
+						}
+					}
+				}
+				if upstreamValue == nil && localValue == nil {
+					upstreamValue = "same"
+				}
+
+				if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+					hasDifference = true
+				}
+
+				upstreamValues[channel.name] = upstreamValue
+
+				confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
+			}
+
+			shouldInclude := false
+
+			if localValue != nil {
+				if hasDifference {
+					shouldInclude = true
+				}
+			} else {
+				if hasUpstreamValue {
+					shouldInclude = true
+				}
+			}
+
+			if shouldInclude {
+				if differences[modelName] == nil {
+					differences[modelName] = make(map[string]dto.DifferenceItem)
+				}
+				differences[modelName][ratioType] = dto.DifferenceItem{
+					Current:    localValue,
+					Upstreams:  upstreamValues,
+					Confidence: confidenceValues,
+				}
+			}
+		}
+	}
+
+	channelHasDiff := make(map[string]bool)
+	for _, ratioMap := range differences {
+		for _, item := range ratioMap {
+			for chName, val := range item.Upstreams {
+				if val != nil && val != "same" {
+					channelHasDiff[chName] = true
+				}
+			}
+		}
+	}
+
+	for modelName, ratioMap := range differences {
+		for ratioType, item := range ratioMap {
+			for chName := range item.Upstreams {
+				if !channelHasDiff[chName] {
+					delete(item.Upstreams, chName)
+					delete(item.Confidence, chName)
+				}
+			}
+
+			allSame := true
+			for _, v := range item.Upstreams {
+				if v != "same" {
+					allSame = false
+					break
+				}
+			}
+			if len(item.Upstreams) == 0 || allSame {
+				delete(ratioMap, ratioType)
+			} else {
+				differences[modelName][ratioType] = item
+			}
+		}
+
+		if len(ratioMap) == 0 {
+			delete(differences, modelName)
+		}
+	}
+
+	return differences
 }
 }
 
 
 func GetSyncableChannels(c *gin.Context) {
 func GetSyncableChannels(c *gin.Context) {
-    channels, err := model.GetAllChannels(0, 0, true, false)
-    if err != nil {
-        c.JSON(http.StatusOK, gin.H{
-            "success": false,
-            "message": err.Error(),
-        })
-        return
-    }
-
-    var syncableChannels []dto.SyncableChannel
-    for _, channel := range channels {
-        if channel.GetBaseURL() != "" {
-            syncableChannels = append(syncableChannels, dto.SyncableChannel{
-                ID:      channel.Id,
-                Name:    channel.Name,
-                BaseURL: channel.GetBaseURL(),
-                Status:  channel.Status,
-            })
-        }
-    }
-
-    c.JSON(http.StatusOK, gin.H{
-        "success": true,
-        "message": "",
-        "data":    syncableChannels,
-    })
-} 
+	channels, err := model.GetAllChannels(0, 0, true, false)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var syncableChannels []dto.SyncableChannel
+	for _, channel := range channels {
+		if channel.GetBaseURL() != "" {
+			syncableChannels = append(syncableChannels, dto.SyncableChannel{
+				ID:      channel.Id,
+				Name:    channel.Name,
+				BaseURL: channel.GetBaseURL(),
+				Status:  channel.Status,
+			})
+		}
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    syncableChannels,
+	})
+}

+ 156 - 166
controller/relay.go

@@ -2,21 +2,22 @@ package controller
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
-	constant2 "one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/middleware"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay"
+	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
+	"one-api/setting"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
 
 
@@ -24,186 +25,196 @@ import (
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 )
 )
 
 
-func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
+func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
 	var err *types.NewAPIError
 	var err *types.NewAPIError
-	switch relayMode {
+	switch info.RelayMode {
 	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
 	case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
-		err = relay.ImageHelper(c)
+		err = relay.ImageHelper(c, info)
 	case relayconstant.RelayModeAudioSpeech:
 	case relayconstant.RelayModeAudioSpeech:
 		fallthrough
 		fallthrough
 	case relayconstant.RelayModeAudioTranslation:
 	case relayconstant.RelayModeAudioTranslation:
 		fallthrough
 		fallthrough
 	case relayconstant.RelayModeAudioTranscription:
 	case relayconstant.RelayModeAudioTranscription:
-		err = relay.AudioHelper(c)
+		err = relay.AudioHelper(c, info)
 	case relayconstant.RelayModeRerank:
 	case relayconstant.RelayModeRerank:
-		err = relay.RerankHelper(c, relayMode)
+		err = relay.RerankHelper(c, info)
 	case relayconstant.RelayModeEmbeddings:
 	case relayconstant.RelayModeEmbeddings:
-		err = relay.EmbeddingHelper(c)
+		err = relay.EmbeddingHelper(c, info)
 	case relayconstant.RelayModeResponses:
 	case relayconstant.RelayModeResponses:
-		err = relay.ResponsesHelper(c)
-	case relayconstant.RelayModeGemini:
-		if strings.Contains(c.Request.URL.Path, "embed") {
-			err = relay.GeminiEmbeddingHandler(c)
-		} else {
-			err = relay.GeminiHelper(c)
-		}
+		err = relay.ResponsesHelper(c, info)
 	default:
 	default:
-		err = relay.TextHelper(c)
+		err = relay.TextHelper(c, info)
 	}
 	}
+	return err
+}
 
 
-	if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
-		// 保存错误日志到mysql中
-		userId := c.GetInt("id")
-		tokenName := c.GetString("token_name")
-		modelName := c.GetString("original_model")
-		tokenId := c.GetInt("token_id")
-		userGroup := c.GetString("group")
-		channelId := c.GetInt("channel_id")
-		other := make(map[string]interface{})
-		other["error_type"] = err.GetErrorType()
-		other["error_code"] = err.GetErrorCode()
-		other["status_code"] = err.StatusCode
-		other["channel_id"] = channelId
-		other["channel_name"] = c.GetString("channel_name")
-		other["channel_type"] = c.GetInt("channel_type")
-		adminInfo := make(map[string]interface{})
-		adminInfo["use_channel"] = c.GetStringSlice("use_channel")
-		isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
-		if isMultiKey {
-			adminInfo["is_multi_key"] = true
-			adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
-		}
-		other["admin_info"] = adminInfo
-		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
+func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
+	var err *types.NewAPIError
+	if strings.Contains(c.Request.URL.Path, "embed") {
+		err = relay.GeminiEmbeddingHandler(c, info)
+	} else {
+		err = relay.GeminiHelper(c, info)
 	}
 	}
-
 	return err
 	return err
 }
 }
 
 
-func Relay(c *gin.Context) {
-	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+func Relay(c *gin.Context, relayFormat types.RelayFormat) {
+
 	requestId := c.GetString(common.RequestIdKey)
 	requestId := c.GetString(common.RequestIdKey)
 	group := c.GetString("group")
 	group := c.GetString("group")
 	originalModel := c.GetString("original_model")
 	originalModel := c.GetString("original_model")
-	var newAPIError *types.NewAPIError
-
-	for i := 0; i <= common.RetryTimes; i++ {
-		channel, err := getChannel(c, group, originalModel, i)
-		if err != nil {
-			common.LogError(c, err.Error())
-			newAPIError = err
-			break
-		}
 
 
-		newAPIError = relayRequest(c, relayMode, channel)
+	var (
+		newAPIError *types.NewAPIError
+		ws          *websocket.Conn
+	)
 
 
-		if newAPIError == nil {
-			return // 成功处理请求,直接返回
+	if relayFormat == types.RelayFormatOpenAIRealtime {
+		var err error
+		ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
+		if err != nil {
+			helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
+			return
 		}
 		}
+		defer ws.Close()
+	}
 
 
-		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
-
-		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
-			break
+	defer func() {
+		if newAPIError != nil {
+			newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+			switch relayFormat {
+			case types.RelayFormatOpenAIRealtime:
+				helper.WssError(c, ws, newAPIError.ToOpenAIError())
+			case types.RelayFormatClaude:
+				c.JSON(newAPIError.StatusCode, gin.H{
+					"type":  "error",
+					"error": newAPIError.ToClaudeError(),
+				})
+			default:
+				c.JSON(newAPIError.StatusCode, gin.H{
+					"error": newAPIError.ToOpenAIError(),
+				})
+			}
 		}
 		}
-	}
-	useChannel := c.GetStringSlice("use_channel")
-	if len(useChannel) > 1 {
-		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c, retryLogStr)
-	}
+	}()
 
 
-	if newAPIError != nil {
-		//if newAPIError.StatusCode == http.StatusTooManyRequests {
-		//	common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
-		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
-		//}
-		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
-		c.JSON(newAPIError.StatusCode, gin.H{
-			"error": newAPIError.ToOpenAIError(),
-		})
+	request, err := helper.GetAndValidateRequest(c, relayFormat)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+		return
 	}
 	}
-}
-
-var upgrader = websocket.Upgrader{
-	Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
-	CheckOrigin: func(r *http.Request) bool {
-		return true // 允许跨域
-	},
-}
-
-func WssRelay(c *gin.Context) {
-	// 将 HTTP 连接升级为 WebSocket 连接
-
-	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
-	defer ws.Close()
 
 
+	//includeUsage := true
+	//// 判断用户是否需要返回使用情况
+	//if textRequest.StreamOptions != nil {
+	//	includeUsage = textRequest.StreamOptions.IncludeUsage
+	//}
+	//
+	//// 如果不支持StreamOptions,将StreamOptions设置为nil
+	//if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+	//	textRequest.StreamOptions = nil
+	//} else {
+	//	// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
+	//	if constant.ForceStreamOption {
+	//		textRequest.StreamOptions = &dto.StreamOptions{
+	//			IncludeUsage: true,
+	//		}
+	//	}
+	//}
+	//
+	//relayInfo.ShouldIncludeUsage = includeUsage
+
+	relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
 	if err != nil {
 	if err != nil {
-		helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
+		newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
 		return
 		return
 	}
 	}
 
 
-	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
-	requestId := c.GetString(common.RequestIdKey)
-	group := c.GetString("group")
-	//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
-	originalModel := c.GetString("original_model")
-	var newAPIError *types.NewAPIError
+	meta := request.GetTokenCountMeta()
 
 
-	for i := 0; i <= common.RetryTimes; i++ {
-		channel, err := getChannel(c, group, originalModel, i)
+	if setting.ShouldCheckPromptSensitive() {
+		words, err := service.CheckSensitiveText(meta.CombineText)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, err.Error())
-			newAPIError = err
-			break
-		}
-
-		newAPIError = wssRequest(c, ws, relayMode, channel)
-
-		if newAPIError == nil {
-			return // 成功处理请求,直接返回
+			logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
+			newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+			return
 		}
 		}
+	}
 
 
-		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
-
-		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
-			break
-		}
+	tokens, err := service.CountRequestToken(c, meta, relayInfo)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
+		return
 	}
 	}
-	useChannel := c.GetStringSlice("use_channel")
-	if len(useChannel) > 1 {
-		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c, retryLogStr)
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
+	if err != nil {
+		newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
+		return
 	}
 	}
 
 
-	if newAPIError != nil {
-		//if newAPIError.StatusCode == http.StatusTooManyRequests {
-		//	newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
-		//}
-		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
-		helper.WssError(c, ws, newAPIError.ToOpenAIError())
+	preConsumedQuota, newApiErr := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if newApiErr != nil {
+		return
 	}
 	}
-}
 
 
-func RelayClaude(c *gin.Context) {
-	//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
-	requestId := c.GetString(common.RequestIdKey)
-	group := c.GetString("group")
-	originalModel := c.GetString("original_model")
-	var newAPIError *types.NewAPIError
+	defer func() {
+		if newApiErr != nil {
+			service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+		}
+	}()
 
 
 	for i := 0; i <= common.RetryTimes; i++ {
 	for i := 0; i <= common.RetryTimes; i++ {
 		channel, err := getChannel(c, group, originalModel, i)
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, err.Error())
+			logger.LogError(c, err.Error())
 			newAPIError = err
 			newAPIError = err
 			break
 			break
 		}
 		}
 
 
-		newAPIError = claudeRequest(c, channel)
+		addUsedChannel(c, channel.Id)
+		requestBody, _ := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+
+		switch relayFormat {
+		case types.RelayFormatOpenAIRealtime:
+			newAPIError = relay.WssHelper(c, ws)
+		case types.RelayFormatClaude:
+			newAPIError = relay.ClaudeHelper(c, relayInfo)
+		case types.RelayFormatGemini:
+			newAPIError = geminiRelayHandler(c, relayInfo)
+		default:
+			newAPIError = relayHandler(c, relayInfo)
+		}
 
 
 		if newAPIError == nil {
 		if newAPIError == nil {
-			return // 成功处理请求,直接返回
+			return
+		} else {
+			if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
+				// 保存错误日志到mysql中
+				userId := c.GetInt("id")
+				tokenName := c.GetString("token_name")
+				modelName := c.GetString("original_model")
+				tokenId := c.GetInt("token_id")
+				userGroup := c.GetString("group")
+				channelId := c.GetInt("channel_id")
+				other := make(map[string]interface{})
+				other["error_type"] = newAPIError.GetErrorType()
+				other["error_code"] = newAPIError.GetErrorCode()
+				other["status_code"] = newAPIError.StatusCode
+				other["channel_id"] = channelId
+				other["channel_name"] = c.GetString("channel_name")
+				other["channel_type"] = c.GetInt("channel_type")
+				adminInfo := make(map[string]interface{})
+				adminInfo["use_channel"] = c.GetStringSlice("use_channel")
+				isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
+				if isMultiKey {
+					adminInfo["is_multi_key"] = true
+					adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
+				}
+				other["admin_info"] = adminInfo
+				model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
+			}
 		}
 		}
 
 
 		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
@@ -212,40 +223,19 @@ func RelayClaude(c *gin.Context) {
 			break
 			break
 		}
 		}
 	}
 	}
+
 	useChannel := c.GetStringSlice("use_channel")
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c, retryLogStr)
+		logger.LogInfo(c, retryLogStr)
 	}
 	}
-
-	if newAPIError != nil {
-		newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
-		c.JSON(newAPIError.StatusCode, gin.H{
-			"type":  "error",
-			"error": newAPIError.ToClaudeError(),
-		})
-	}
-}
-
-func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
-	addUsedChannel(c, channel.Id)
-	requestBody, _ := common.GetRequestBody(c)
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-	return relayHandler(c, relayMode)
-}
-
-func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
-	addUsedChannel(c, channel.Id)
-	requestBody, _ := common.GetRequestBody(c)
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-	return relay.WssHelper(c, ws)
 }
 }
 
 
-func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
-	addUsedChannel(c, channel.Id)
-	requestBody, _ := common.GetRequestBody(c)
-	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-	return relay.ClaudeHelper(c)
+var upgrader = websocket.Upgrader{
+	Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
+	CheckOrigin: func(r *http.Request) bool {
+		return true // 允许跨域
+	},
 }
 }
 
 
 func addUsedChannel(c *gin.Context, channelId int) {
 func addUsedChannel(c *gin.Context, channelId int) {
@@ -270,10 +260,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 	}
 	}
 	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	if err != nil {
 	if err != nil {
-		return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+		return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	}
 	if channel == nil {
 	if channel == nil {
-		return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+		return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
 	}
 	}
 	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 	if newAPIError != nil {
 	if newAPIError != nil {
@@ -327,7 +317,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
 func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
 func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
 	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
 	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
 	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
 	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
-	common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+	logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
 	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
 	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
 		service.DisableChannel(channelError, err.Error())
 		service.DisableChannel(channelError, err.Error())
 	}
 	}
@@ -362,7 +352,7 @@ func RelayMidjourney(c *gin.Context) {
 			"code":        err.Code,
 			"code":        err.Code,
 		})
 		})
 		channelId := c.GetInt("channel_id")
 		channelId := c.GetInt("channel_id")
-		common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
+		logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
 	}
 	}
 }
 }
 
 
@@ -404,7 +394,7 @@ func RelayTask(c *gin.Context) {
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
 		channel, newAPIError := getChannel(c, group, originalModel, i)
 		channel, newAPIError := getChannel(c, group, originalModel, i)
 		if newAPIError != nil {
 		if newAPIError != nil {
-			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+			logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
 			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
 			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
 			break
 			break
 		}
 		}
@@ -412,7 +402,7 @@ func RelayTask(c *gin.Context) {
 		useChannel := c.GetStringSlice("use_channel")
 		useChannel := c.GetStringSlice("use_channel")
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
 		c.Set("use_channel", useChannel)
 		c.Set("use_channel", useChannel)
-		common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
 
 
 		requestBody, _ := common.GetRequestBody(c)
 		requestBody, _ := common.GetRequestBody(c)
@@ -422,7 +412,7 @@ func RelayTask(c *gin.Context) {
 	useChannel := c.GetStringSlice("use_channel")
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-		common.LogInfo(c, retryLogStr)
+		logger.LogInfo(c, retryLogStr)
 	}
 	}
 	if taskErr != nil {
 	if taskErr != nil {
 		if taskErr.StatusCode == http.StatusTooManyRequests {
 		if taskErr.StatusCode == http.StatusTooManyRequests {

+ 20 - 19
controller/task.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay"
 	"sort"
 	"sort"
@@ -25,7 +26,7 @@ func UpdateTaskBulk() {
 	//imageModel := "midjourney"
 	//imageModel := "midjourney"
 	for {
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
 		time.Sleep(time.Duration(15) * time.Second)
-		common.SysLog("任务进度轮询开始")
+		logger.SysLog("任务进度轮询开始")
 		ctx := context.TODO()
 		ctx := context.TODO()
 		allTasks := model.GetAllUnFinishSyncTasks(500)
 		allTasks := model.GetAllUnFinishSyncTasks(500)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
 					"progress": "100%",
 					"progress": "100%",
 				})
 				})
 				if err != nil {
 				if err != nil {
-					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
 				} else {
 				} else {
-					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
 				}
 				}
 			}
 			}
 			if len(taskChannelM) == 0 {
 			if len(taskChannelM) == 0 {
@@ -65,7 +66,7 @@ func UpdateTaskBulk() {
 
 
 			UpdateTaskByPlatform(platform, taskChannelM, taskM)
 			UpdateTaskByPlatform(platform, taskChannelM, taskM)
 		}
 		}
-		common.SysLog("任务进度轮询完成")
+		logger.SysLog("任务进度轮询完成")
 	}
 	}
 }
 }
 
 
@@ -77,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
 		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
 		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
 	default:
 	default:
 		if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
 		if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
+			logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
 		}
 		}
 	}
 	}
 }
 }
@@ -86,27 +87,27 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
 	for channelId, taskIds := range taskChannelM {
 	for channelId, taskIds := range taskChannelM {
 		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
 		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
 		if err != nil {
 		if err != nil {
-			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
 		}
 		}
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
 func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
 func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
 	if len(taskIds) == 0 {
 	if len(taskIds) == 0 {
 		return nil
 		return nil
 	}
 	}
 	channel, err := model.CacheGetChannel(channelId)
 	channel, err := model.CacheGetChannel(channelId)
 	if err != nil {
 	if err != nil {
-		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
 		err = model.TaskBulkUpdate(taskIds, map[string]any{
 		err = model.TaskBulkUpdate(taskIds, map[string]any{
 			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
 			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
 			"status":      "FAILURE",
 			"status":      "FAILURE",
 			"progress":    "100%",
 			"progress":    "100%",
 		})
 		})
 		if err != nil {
 		if err != nil {
-			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+			logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
 		}
 		}
 		return err
 		return err
 	}
 	}
@@ -118,27 +119,27 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 		"ids": taskIds,
 		"ids": taskIds,
 	})
 	})
 	if err != nil {
 	if err != nil {
-		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+		logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
 		return err
 		return err
 	}
 	}
 	if resp.StatusCode != http.StatusOK {
 	if resp.StatusCode != http.StatusOK {
-		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
-		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+		logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
 		return err
 		return err
 	}
 	}
 	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
 	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
 	err = json.Unmarshal(responseBody, &responseItems)
 	err = json.Unmarshal(responseBody, &responseItems)
 	if err != nil {
 	if err != nil {
-		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
 		return err
 		return err
 	}
 	}
 	if !responseItems.IsSuccess() {
 	if !responseItems.IsSuccess() {
-		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+		logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
 		return err
 		return err
 	}
 	}
 
 
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
 		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
 		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
 		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
 		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
 		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
-			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
 			task.Progress = "100%"
 			task.Progress = "100%"
 			//err = model.CacheUpdateUserQuota(task.UserId) ?
 			//err = model.CacheUpdateUserQuota(task.UserId) ?
 			if err != nil {
 			if err != nil {
-				common.LogError(ctx, "error update user quota cache: "+err.Error())
+				logger.LogError(ctx, "error update user quota cache: "+err.Error())
 			} else {
 			} else {
 				quota := task.Quota
 				quota := task.Quota
 				if quota != 0 {
 				if quota != 0 {
 					err = model.IncreaseUserQuota(task.UserId, quota, false)
 					err = model.IncreaseUserQuota(task.UserId, quota, false)
 					if err != nil {
 					if err != nil {
-						common.LogError(ctx, "fail to increase user quota: "+err.Error())
+						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
 					}
 					}
-					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
 					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 				}
 				}
 			}
 			}
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
 
 
 		err = task.Update()
 		err = task.Update()
 		if err != nil {
 		if err != nil {
-			common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+			logger.SysError("UpdateMidjourneyTask task error: " + err.Error())
 		}
 		}
 	}
 	}
 	return nil
 	return nil

+ 10 - 10
controller/task_video.go

@@ -5,9 +5,9 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/relay"
 	"one-api/relay"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
@@ -18,14 +18,14 @@ import (
 func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
 func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
 	for channelId, taskIds := range taskChannelM {
 	for channelId, taskIds := range taskChannelM {
 		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
 		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
-			common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
 		}
 		}
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
 func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
 func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
 	if len(taskIds) == 0 {
 	if len(taskIds) == 0 {
 		return nil
 		return nil
 	}
 	}
@@ -37,7 +37,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
 			"progress":    "100%",
 			"progress":    "100%",
 		})
 		})
 		if errUpdate != nil {
 		if errUpdate != nil {
-			common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+			logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
 		}
 		}
 		return fmt.Errorf("CacheGetChannel failed: %w", err)
 		return fmt.Errorf("CacheGetChannel failed: %w", err)
 	}
 	}
@@ -47,7 +47,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
 	}
 	}
 	for _, taskId := range taskIds {
 	for _, taskId := range taskIds {
 		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
 		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
-			common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
 		}
 		}
 	}
 	}
 	return nil
 	return nil
@@ -61,7 +61,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 
 
 	task := taskM[taskId]
 	task := taskM[taskId]
 	if task == nil {
 	if task == nil {
-		common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
 		return fmt.Errorf("task %s not found", taskId)
 		return fmt.Errorf("task %s not found", taskId)
 	}
 	}
 	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
 	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
@@ -124,13 +124,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 			task.FinishTime = now
 			task.FinishTime = now
 		}
 		}
 		task.FailReason = taskResult.Reason
 		task.FailReason = taskResult.Reason
-		common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
 		quota := task.Quota
 		quota := task.Quota
 		if quota != 0 {
 		if quota != 0 {
 			if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
 			if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
-				common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+				logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
 			}
 			}
-			logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+			logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
 			model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 			model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
 		}
 		}
 	default:
 	default:
@@ -140,7 +140,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
 		task.Progress = taskResult.Progress
 		task.Progress = taskResult.Progress
 	}
 	}
 	if err := task.Update(); err != nil {
 	if err := task.Update(); err != nil {
-		common.SysError("UpdateVideoTask task error: " + err.Error())
+		logger.SysError("UpdateVideoTask task error: " + err.Error())
 	}
 	}
 
 
 	return nil
 	return nil

+ 2 - 1
controller/token.go

@@ -3,6 +3,7 @@ package controller
 import (
 import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 
 
@@ -102,7 +103,7 @@ func AddToken(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "生成令牌失败",
 			"message": "生成令牌失败",
 		})
 		})
-		common.SysError("failed to generate token key: " + err.Error())
+		logger.SysError("failed to generate token key: " + err.Error())
 		return
 		return
 	}
 	}
 	cleanToken := model.Token{
 	cleanToken := model.Token{

+ 2 - 1
controller/topup.go

@@ -5,6 +5,7 @@ import (
 	"log"
 	"log"
 	"net/url"
 	"net/url"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting"
@@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) {
 				return
 				return
 			}
 			}
 			log.Printf("易支付回调更新用户成功 %v", topUp)
 			log.Printf("易支付回调更新用户成功 %v", topUp)
-			model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
+			model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
 		}
 		}
 	} else {
 	} else {
 		log.Printf("易支付异常回调: %v", verifyInfo)
 		log.Printf("易支付异常回调: %v", verifyInfo)

+ 7 - 6
controller/twofa.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 
 
@@ -70,7 +71,7 @@ func Setup2FA(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "生成2FA密钥失败",
 			"message": "生成2FA密钥失败",
 		})
 		})
-		common.SysError("生成TOTP密钥失败: " + err.Error())
+		logger.SysError("生成TOTP密钥失败: " + err.Error())
 		return
 		return
 	}
 	}
 
 
@@ -81,7 +82,7 @@ func Setup2FA(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "生成备用码失败",
 			"message": "生成备用码失败",
 		})
 		})
-		common.SysError("生成备用码失败: " + err.Error())
+		logger.SysError("生成备用码失败: " + err.Error())
 		return
 		return
 	}
 	}
 
 
@@ -115,7 +116,7 @@ func Setup2FA(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "保存备用码失败",
 			"message": "保存备用码失败",
 		})
 		})
-		common.SysError("保存备用码失败: " + err.Error())
+		logger.SysError("保存备用码失败: " + err.Error())
 		return
 		return
 	}
 	}
 
 
@@ -294,7 +295,7 @@ func Get2FAStatus(c *gin.Context) {
 			// 获取剩余备用码数量
 			// 获取剩余备用码数量
 			backupCount, err := model.GetUnusedBackupCodeCount(userId)
 			backupCount, err := model.GetUnusedBackupCodeCount(userId)
 			if err != nil {
 			if err != nil {
-				common.SysError("获取备用码数量失败: " + err.Error())
+				logger.SysError("获取备用码数量失败: " + err.Error())
 			} else {
 			} else {
 				status["backup_codes_remaining"] = backupCount
 				status["backup_codes_remaining"] = backupCount
 			}
 			}
@@ -368,7 +369,7 @@ func RegenerateBackupCodes(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "生成备用码失败",
 			"message": "生成备用码失败",
 		})
 		})
-		common.SysError("生成备用码失败: " + err.Error())
+		logger.SysError("生成备用码失败: " + err.Error())
 		return
 		return
 	}
 	}
 
 
@@ -378,7 +379,7 @@ func RegenerateBackupCodes(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "保存备用码失败",
 			"message": "保存备用码失败",
 		})
 		})
-		common.SysError("保存备用码失败: " + err.Error())
+		logger.SysError("保存备用码失败: " + err.Error())
 		return
 		return
 	}
 	}
 
 

+ 5 - 4
controller/user.go

@@ -7,6 +7,7 @@ import (
 	"net/url"
 	"net/url"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
 	"strconv"
 	"strconv"
@@ -192,7 +193,7 @@ func Register(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "数据库错误,请稍后重试",
 			"message": "数据库错误,请稍后重试",
 		})
 		})
-		common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
+		logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
 		return
 		return
 	}
 	}
 	if exist {
 	if exist {
@@ -235,7 +236,7 @@ func Register(c *gin.Context) {
 				"success": false,
 				"success": false,
 				"message": "生成默认令牌失败",
 				"message": "生成默认令牌失败",
 			})
 			})
-			common.SysError("failed to generate token key: " + err.Error())
+			logger.SysError("failed to generate token key: " + err.Error())
 			return
 			return
 		}
 		}
 		// 生成默认令牌
 		// 生成默认令牌
@@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
 			"success": false,
 			"success": false,
 			"message": "生成失败",
 			"message": "生成失败",
 		})
 		})
-		common.SysError("failed to generate key: " + err.Error())
+		logger.SysError("failed to generate key: " + err.Error())
 		return
 		return
 	}
 	}
 	user.SetAccessToken(key)
 	user.SetAccessToken(key)
@@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) {
 		return
 		return
 	}
 	}
 	if originUser.Quota != updatedUser.Quota {
 	if originUser.Quota != updatedUser.Quota {
-		model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+		model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
 	}
 	}
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,

+ 18 - 0
dto/audio.go

@@ -1,5 +1,11 @@
 package dto
 package dto
 
 
+import (
+	"one-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
 type AudioRequest struct {
 type AudioRequest struct {
 	Model          string  `json:"model"`
 	Model          string  `json:"model"`
 	Input          string  `json:"input"`
 	Input          string  `json:"input"`
@@ -8,6 +14,18 @@ type AudioRequest struct {
 	ResponseFormat string  `json:"response_format,omitempty"`
 	ResponseFormat string  `json:"response_format,omitempty"`
 }
 }
 
 
+func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	meta := &types.TokenCountMeta{
+		CombineText: r.Input,
+		TokenType:   types.TokenTypeTextNumber,
+	}
+	return meta
+}
+
+func (r *AudioRequest) IsStream(c *gin.Context) bool {
+	return false
+}
+
 type AudioResponse struct {
 type AudioResponse struct {
 	Text string `json:"text"`
 	Text string `json:"text"`
 }
 }

+ 127 - 1
dto/claude.go

@@ -5,6 +5,9 @@ import (
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
 	"one-api/types"
 	"one-api/types"
+	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 type ClaudeMetadata struct {
 type ClaudeMetadata struct {
@@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
 }
 }
 
 
 func (c *ClaudeMediaMessage) GetJsonRowString() string {
 func (c *ClaudeMediaMessage) GetJsonRowString() string {
-	jsonContent, _ := json.Marshal(c)
+	jsonContent, _ := common.Marshal(c)
 	return string(jsonContent)
 	return string(jsonContent)
 }
 }
 
 
@@ -199,6 +202,129 @@ type ClaudeRequest struct {
 	Thinking   *Thinking `json:"thinking,omitempty"`
 	Thinking   *Thinking `json:"thinking,omitempty"`
 }
 }
 
 
+func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var tokenCountMeta = types.TokenCountMeta{
+		TokenType: types.TokenTypeTextNumber,
+		MaxTokens: int(c.MaxTokens),
+	}
+
+	var texts = make([]string, 0)
+	var fileMeta = make([]*types.FileMeta, 0)
+
+	// system
+	if c.System != nil {
+		if c.IsStringSystem() {
+			sys := c.GetStringSystem()
+			if sys != "" {
+				texts = append(texts, sys)
+			}
+		} else {
+			systemMedia := c.ParseSystem()
+			for _, media := range systemMedia {
+				switch media.Type {
+				case "text":
+					texts = append(texts, media.GetText())
+				case "image":
+					if media.Source != nil {
+						data := media.Source.Url
+						if data == "" {
+							data = common.Interface2String(media.Source.Data)
+						}
+						if data != "" {
+							fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+						}
+					}
+				}
+			}
+		}
+	}
+
+	// messages
+	for _, message := range c.Messages {
+		tokenCountMeta.MessagesCount++
+		texts = append(texts, message.Role)
+		if message.IsStringContent() {
+			content := message.GetStringContent()
+			if content != "" {
+				texts = append(texts, content)
+			}
+			continue
+		}
+
+		content, _ := message.ParseContent()
+		for _, media := range content {
+			switch media.Type {
+			case "text":
+				texts = append(texts, media.GetText())
+			case "image":
+				if media.Source != nil {
+					data := media.Source.Url
+					if data == "" {
+						data = common.Interface2String(media.Source.Data)
+					}
+					if data != "" {
+						fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
+					}
+				}
+			case "tool_use":
+				if media.Name != "" {
+					texts = append(texts, media.Name)
+				}
+				if media.Input != nil {
+					b, _ := common.Marshal(media.Input)
+					texts = append(texts, string(b))
+				}
+			case "tool_result":
+				if media.Content != nil {
+					b, _ := common.Marshal(media.Content)
+					texts = append(texts, string(b))
+				}
+			}
+		}
+	}
+
+	// tools
+	if c.Tools != nil {
+		tools := c.GetTools()
+		normalTools, webSearchTools := ProcessTools(tools)
+		if normalTools != nil {
+			for _, t := range normalTools {
+				tokenCountMeta.ToolsCount++
+				if t.Name != "" {
+					texts = append(texts, t.Name)
+				}
+				if t.Description != "" {
+					texts = append(texts, t.Description)
+				}
+				if t.InputSchema != nil {
+					b, _ := common.Marshal(t.InputSchema)
+					texts = append(texts, string(b))
+				}
+			}
+		}
+		if webSearchTools != nil {
+			for _, t := range webSearchTools {
+				tokenCountMeta.ToolsCount++
+				if t.Name != "" {
+					texts = append(texts, t.Name)
+				}
+				if t.UserLocation != nil {
+					b, _ := common.Marshal(t.UserLocation)
+					texts = append(texts, string(b))
+				}
+			}
+		}
+	}
+
+	tokenCountMeta.CombineText = strings.Join(texts, "\n")
+	tokenCountMeta.Files = fileMeta
+	return &tokenCountMeta
+}
+
+func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
+	return claudeRequest.Stream
+}
+
 func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
 func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
 	for _, message := range c.Messages {
 	for _, message := range c.Messages {
 		content, _ := message.ParseContent()
 		content, _ := message.ParseContent()

+ 26 - 2
dto/embedding.go

@@ -1,5 +1,12 @@
 package dto
 package dto
 
 
+import (
+	"one-api/types"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
 type EmbeddingOptions struct {
 type EmbeddingOptions struct {
 	Seed             int      `json:"seed,omitempty"`
 	Seed             int      `json:"seed,omitempty"`
 	Temperature      *float64 `json:"temperature,omitempty"`
 	Temperature      *float64 `json:"temperature,omitempty"`
@@ -24,9 +31,26 @@ type EmbeddingRequest struct {
 	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
 	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
 }
 }
 
 
-func (r EmbeddingRequest) ParseInput() []string {
+func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var texts = make([]string, 0)
+
+	inputs := r.ParseInput()
+	for _, input := range inputs {
+		texts = append(texts, input)
+	}
+
+	return &types.TokenCountMeta{
+		CombineText: strings.Join(texts, "\n"),
+	}
+}
+
+func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
+	return false
+}
+
+func (r *EmbeddingRequest) ParseInput() []string {
 	if r.Input == nil {
 	if r.Input == nil {
-		return nil
+		return make([]string, 0)
 	}
 	}
 	var input []string
 	var input []string
 	switch r.Input.(type) {
 	switch r.Input.(type) {

+ 62 - 3
dto/gemini.go

@@ -2,7 +2,10 @@ package dto
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
+	"one-api/types"
 	"strings"
 	"strings"
 )
 )
 
 
@@ -14,19 +17,75 @@ type GeminiChatRequest struct {
 	SystemInstructions *GeminiChatContent         `json:"systemInstruction,omitempty"`
 	SystemInstructions *GeminiChatContent         `json:"systemInstruction,omitempty"`
 }
 }
 
 
+func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var files []*types.FileMeta = make([]*types.FileMeta, 0)
+
+	var maxTokens int
+
+	if r.GenerationConfig.MaxOutputTokens > 0 {
+		maxTokens = int(r.GenerationConfig.MaxOutputTokens)
+	}
+
+	var inputTexts []string
+	for _, content := range r.Contents {
+		for _, part := range content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+			if part.InlineData != nil && part.InlineData.Data != "" {
+				if strings.HasPrefix(part.InlineData.MimeType, "image/") {
+					files = append(files, &types.FileMeta{
+						FileType: types.FileTypeImage,
+						Data:     part.InlineData.Data,
+					})
+				} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
+					files = append(files, &types.FileMeta{
+						FileType: types.FileTypeAudio,
+						Data:     part.InlineData.Data,
+					})
+				} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
+					files = append(files, &types.FileMeta{
+						FileType: types.FileTypeVideo,
+						Data:     part.InlineData.Data,
+					})
+				} else {
+					files = append(files, &types.FileMeta{
+						FileType: types.FileTypeFile,
+						Data:     part.InlineData.Data,
+					})
+				}
+			}
+		}
+	}
+
+	inputText := strings.Join(inputTexts, "\n")
+	return &types.TokenCountMeta{
+		CombineText: inputText,
+		Files:       files,
+		MaxTokens:   maxTokens,
+	}
+}
+
+func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
+	if c.Query("alt") == "sse" {
+		return true
+	}
+	return false
+}
+
 func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
 func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
 	var tools []GeminiChatTool
 	var tools []GeminiChatTool
 	if strings.HasSuffix(string(r.Tools), "[") {
 	if strings.HasSuffix(string(r.Tools), "[") {
 		// is array
 		// is array
 		if err := common.Unmarshal(r.Tools, &tools); err != nil {
 		if err := common.Unmarshal(r.Tools, &tools); err != nil {
-			common.LogError(nil, "error_unmarshalling_tools: "+err.Error())
+			logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
 			return nil
 			return nil
 		}
 		}
 	} else if strings.HasPrefix(string(r.Tools), "{") {
 	} else if strings.HasPrefix(string(r.Tools), "{") {
 		// is object
 		// is object
 		singleTool := GeminiChatTool{}
 		singleTool := GeminiChatTool{}
 		if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
 		if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
-			common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
+			logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
 			return nil
 			return nil
 		}
 		}
 		tools = []GeminiChatTool{singleTool}
 		tools = []GeminiChatTool{singleTool}
@@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
 	// Marshal the tools to JSON
 	// Marshal the tools to JSON
 	data, err := common.Marshal(tools)
 	data, err := common.Marshal(tools)
 	if err != nil {
 	if err != nil {
-		common.LogError(nil, "error_marshalling_tools: "+err.Error())
+		logger.LogError(nil, "error_marshalling_tools: "+err.Error())
 		return
 		return
 	}
 	}
 	r.Tools = data
 	r.Tools = data

+ 44 - 2
dto/dalle.go → dto/openai_image.go

@@ -1,11 +1,17 @@
 package dto
 package dto
 
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"one-api/types"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
 
 
 type ImageRequest struct {
 type ImageRequest struct {
 	Model          string          `json:"model"`
 	Model          string          `json:"model"`
 	Prompt         string          `json:"prompt" binding:"required"`
 	Prompt         string          `json:"prompt" binding:"required"`
-	N              int             `json:"n,omitempty"`
+	N              uint            `json:"n,omitempty"`
 	Size           string          `json:"size,omitempty"`
 	Size           string          `json:"size,omitempty"`
 	Quality        string          `json:"quality,omitempty"`
 	Quality        string          `json:"quality,omitempty"`
 	ResponseFormat string          `json:"response_format,omitempty"`
 	ResponseFormat string          `json:"response_format,omitempty"`
@@ -18,6 +24,42 @@ type ImageRequest struct {
 	Watermark      *bool           `json:"watermark,omitempty"`
 	Watermark      *bool           `json:"watermark,omitempty"`
 }
 }
 
 
+func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var sizeRatio = 1.0
+	var qualityRatio = 1.0
+
+	if strings.HasPrefix(i.Model, "dall-e") {
+		// Size
+		if i.Size == "256x256" {
+			sizeRatio = 0.4
+		} else if i.Size == "512x512" {
+			sizeRatio = 0.45
+		} else if i.Size == "1024x1024" {
+			sizeRatio = 1
+		} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
+			sizeRatio = 2
+		}
+
+		if i.Model == "dall-e-3" && i.Quality == "hd" {
+			qualityRatio = 2.0
+			if i.Size == "1024x1792" || i.Size == "1792x1024" {
+				qualityRatio = 1.5
+			}
+		}
+	}
+
+	// not support token count for dalle
+	return &types.TokenCountMeta{
+		CombineText:     i.Prompt,
+		MaxTokens:       1584,
+		ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
+	}
+}
+
+func (i *ImageRequest) IsStream(c *gin.Context) bool {
+	return false
+}
+
 type ImageResponse struct {
 type ImageResponse struct {
 	Data    []ImageData `json:"data"`
 	Data    []ImageData `json:"data"`
 	Created int64       `json:"created"`
 	Created int64       `json:"created"`

+ 271 - 24
dto/openai_request.go

@@ -2,8 +2,12 @@ package dto
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/types"
 	"strings"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 type ResponseFormat struct {
 type ResponseFormat struct {
@@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct {
 	Extra map[string]json.RawMessage `json:"-"`
 	Extra map[string]json.RawMessage `json:"-"`
 }
 }
 
 
+func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var tokenCountMeta types.TokenCountMeta
+	var texts = make([]string, 0)
+	var fileMeta = make([]*types.FileMeta, 0)
+
+	if r.Prompt != nil {
+		switch v := r.Prompt.(type) {
+		case string:
+			texts = append(texts, v)
+		case []any:
+			for _, item := range v {
+				if str, ok := item.(string); ok {
+					texts = append(texts, str)
+				}
+			}
+		default:
+			texts = append(texts, fmt.Sprintf("%v", r.Prompt))
+		}
+	}
+
+	if r.Input != nil {
+		inputs := r.ParseInput()
+		texts = append(texts, inputs...)
+	}
+
+	if r.MaxCompletionTokens > r.MaxTokens {
+		tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
+	} else {
+		tokenCountMeta.MaxTokens = int(r.MaxTokens)
+	}
+
+	for _, message := range r.Messages {
+		tokenCountMeta.MessagesCount++
+		texts = append(texts, message.Role)
+		if message.Content != nil {
+			if message.Name != nil {
+				tokenCountMeta.NameCount++
+				texts = append(texts, *message.Name)
+			}
+			arrayContent := message.ParseContent()
+			for _, m := range arrayContent {
+				if m.Type == ContentTypeImageURL {
+					imageUrl := m.GetImageMedia()
+					if imageUrl != nil {
+						meta := &types.FileMeta{
+							FileType: types.FileTypeImage,
+						}
+						meta.Data = imageUrl.Url
+						meta.Detail = imageUrl.Detail
+						fileMeta = append(fileMeta, meta)
+					}
+				} else if m.Type == ContentTypeInputAudio {
+					inputAudio := m.GetInputAudio()
+					if inputAudio != nil {
+						meta := &types.FileMeta{
+							FileType: types.FileTypeAudio,
+						}
+						meta.Data = inputAudio.Data
+						fileMeta = append(fileMeta, meta)
+					}
+				} else if m.Type == ContentTypeFile {
+					file := m.GetFile()
+					if file != nil {
+						meta := &types.FileMeta{
+							FileType: types.FileTypeFile,
+						}
+						meta.Data = file.FileData
+						fileMeta = append(fileMeta, meta)
+					}
+				} else if m.Type == ContentTypeVideoUrl {
+					videoUrl := m.GetVideoUrl()
+					if videoUrl != nil {
+						meta := &types.FileMeta{
+							FileType: types.FileTypeVideo,
+						}
+						meta.Data = videoUrl.Url
+						fileMeta = append(fileMeta, meta)
+					}
+				} else {
+					texts = append(texts, m.Text)
+				}
+			}
+		}
+	}
+
+	if r.Tools != nil {
+		openaiTools := r.Tools
+		for _, tool := range openaiTools {
+			tokenCountMeta.ToolsCount++
+			texts = append(texts, tool.Function.Name)
+			if tool.Function.Description != "" {
+				texts = append(texts, tool.Function.Description)
+			}
+			if tool.Function.Parameters != nil {
+				texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
+			}
+		}
+		//toolTokens := CountTokenInput(countStr, request.Model)
+		//tkm += 8
+		//tkm += toolTokens
+	}
+	tokenCountMeta.CombineText = strings.Join(texts, "\n")
+	tokenCountMeta.Files = fileMeta
+	return &tokenCountMeta
+}
+
+func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
+	return r.Stream
+}
+
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 	result := make(map[string]any)
 	result := make(map[string]any)
 	data, _ := common.Marshal(r)
 	data, _ := common.Marshal(r)
@@ -202,10 +316,25 @@ func (m *MediaContent) GetFile() *MessageFile {
 	return nil
 	return nil
 }
 }
 
 
+func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
+	if m.VideoUrl != nil {
+		if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
+			return m.VideoUrl.(*MessageVideoUrl)
+		}
+		if itemMap, ok := m.VideoUrl.(map[string]any); ok {
+			out := &MessageVideoUrl{
+				Url: common.Interface2String(itemMap["url"]),
+			}
+			return out
+		}
+	}
+	return nil
+}
+
 type MessageImageUrl struct {
 type MessageImageUrl struct {
-	Url      string `json:"url"`
-	Detail   string `json:"detail"`
-	MimeType string
+	Url    string `json:"url"`
+	Detail string `json:"detail"`
+	//MimeType string
 }
 }
 
 
 func (m *MessageImageUrl) IsRemoteImage() bool {
 func (m *MessageImageUrl) IsRemoteImage() bool {
@@ -233,6 +362,7 @@ const (
 	ContentTypeInputAudio = "input_audio"
 	ContentTypeInputAudio = "input_audio"
 	ContentTypeFile       = "file"
 	ContentTypeFile       = "file"
 	ContentTypeVideoUrl   = "video_url" // 阿里百炼视频识别
 	ContentTypeVideoUrl   = "video_url" // 阿里百炼视频识别
+	//ContentTypeAudioUrl   = "audio_url"
 )
 )
 
 
 func (m *Message) GetPrefix() bool {
 func (m *Message) GetPrefix() bool {
@@ -623,7 +753,7 @@ type WebSearchOptions struct {
 // https://platform.openai.com/docs/api-reference/responses/create
 // https://platform.openai.com/docs/api-reference/responses/create
 type OpenAIResponsesRequest struct {
 type OpenAIResponsesRequest struct {
 	Model              string           `json:"model"`
 	Model              string           `json:"model"`
-	Input              json.RawMessage  `json:"input,omitempty"`
+	Input              any              `json:"input,omitempty"`
 	Include            json.RawMessage  `json:"include,omitempty"`
 	Include            json.RawMessage  `json:"include,omitempty"`
 	Instructions       json.RawMessage  `json:"instructions,omitempty"`
 	Instructions       json.RawMessage  `json:"instructions,omitempty"`
 	MaxOutputTokens    uint             `json:"max_output_tokens,omitempty"`
 	MaxOutputTokens    uint             `json:"max_output_tokens,omitempty"`
@@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct {
 	Prompt             json.RawMessage  `json:"prompt,omitempty"`
 	Prompt             json.RawMessage  `json:"prompt,omitempty"`
 }
 }
 
 
+func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var fileMeta = make([]*types.FileMeta, 0)
+	var texts = make([]string, 0)
+
+	if r.Input != nil {
+		inputs := r.ParseInput()
+		for _, input := range inputs {
+			if input.Type == "input_image" {
+				fileMeta = append(fileMeta, &types.FileMeta{
+					FileType: types.FileTypeImage,
+					Data:     input.ImageUrl,
+					Detail:   input.Detail,
+				})
+			} else if input.Type == "input_file" {
+				fileMeta = append(fileMeta, &types.FileMeta{
+					FileType: types.FileTypeFile,
+					Data:     input.FileUrl,
+				})
+			} else {
+				texts = append(texts, input.Text)
+			}
+		}
+	}
+
+	if len(r.Instructions) > 0 {
+		texts = append(texts, string(r.Instructions))
+	}
+
+	if len(r.Metadata) > 0 {
+		texts = append(texts, string(r.Metadata))
+	}
+
+	if len(r.Text) > 0 {
+		texts = append(texts, string(r.Text))
+	}
+
+	if len(r.ToolChoice) > 0 {
+		texts = append(texts, string(r.ToolChoice))
+	}
+
+	if len(r.Prompt) > 0 {
+		texts = append(texts, string(r.Prompt))
+	}
+
+	if len(r.Tools) > 0 {
+		toolStr, _ := common.Marshal(r.Tools)
+		texts = append(texts, string(toolStr))
+	}
+
+	return &types.TokenCountMeta{
+		CombineText: strings.Join(texts, "\n"),
+		Files:       fileMeta,
+		MaxTokens:   int(r.MaxOutputTokens),
+	}
+}
+
+func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
+	return r.Stream
+}
+
 type Reasoning struct {
 type Reasoning struct {
 	Effort  string `json:"effort,omitempty"`
 	Effort  string `json:"effort,omitempty"`
 	Summary string `json:"summary,omitempty"`
 	Summary string `json:"summary,omitempty"`
 }
 }
 
 
-//type ResponsesToolsCall struct {
-//	Type string `json:"type"`
-//	// Web Search
-//	UserLocation      json.RawMessage `json:"user_location,omitempty"`
-//	SearchContextSize string          `json:"search_context_size,omitempty"`
-//	// File Search
-//	VectorStoreIds []string        `json:"vector_store_ids,omitempty"`
-//	MaxNumResults  uint            `json:"max_num_results,omitempty"`
-//	Filters        json.RawMessage `json:"filters,omitempty"`
-//	// Computer Use
-//	DisplayWidth  uint   `json:"display_width,omitempty"`
-//	DisplayHeight uint   `json:"display_height,omitempty"`
-//	Environment   string `json:"environment,omitempty"`
-//	// Function
-//	Name        string          `json:"name,omitempty"`
-//	Description string          `json:"description,omitempty"`
-//	Parameters  json.RawMessage `json:"parameters,omitempty"`
-//	Function    json.RawMessage `json:"function,omitempty"`
-//	Container   json.RawMessage `json:"container,omitempty"`
-//}
+type MediaInput struct {
+	Type     string `json:"type"`
+	Text     string `json:"text,omitempty"`
+	FileUrl  string `json:"file_url,omitempty"`
+	ImageUrl string `json:"image_url,omitempty"`
+	Detail   string `json:"detail,omitempty"` // 仅 input_image 有效
+}
+
+// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
+// Reference implementation mirrors Message.ParseContent:
+//   - input can be a string, treated as an input_text item
+//   - input can be an array of objects with a `type` field
+//     supported types: input_text, input_image, input_file
+func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
+	if r.Input == nil {
+		return nil
+	}
+
+	var inputs []MediaInput
+
+	// Try string first
+	if str, ok := r.Input.(string); ok {
+		inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
+		return inputs
+	}
+
+	// Try array of parts
+	if array, ok := r.Input.([]any); ok {
+		for _, itemAny := range array {
+			// Already parsed MediaInput
+			if media, ok := itemAny.(MediaInput); ok {
+				inputs = append(inputs, media)
+				continue
+			}
+			// Generic map
+			item, ok := itemAny.(map[string]any)
+			if !ok {
+				continue
+			}
+			typeVal, ok := item["type"].(string)
+			if !ok {
+				continue
+			}
+			switch typeVal {
+			case "input_text":
+				text, _ := item["text"].(string)
+				inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
+			case "input_image":
+				// image_url may be string or object with url field
+				var imageUrl string
+				switch v := item["image_url"].(type) {
+				case string:
+					imageUrl = v
+				case map[string]any:
+					if url, ok := v["url"].(string); ok {
+						imageUrl = url
+					}
+				}
+				inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
+			case "input_file":
+				// file_url may be string or object with url field
+				var fileUrl string
+				switch v := item["file_url"].(type) {
+				case string:
+					fileUrl = v
+				case map[string]any:
+					if url, ok := v["url"].(string); ok {
+						fileUrl = url
+					}
+				}
+				inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
+			}
+		}
+	}
+
+	return inputs
+}

+ 11 - 0
dto/request_common.go

@@ -0,0 +1,11 @@
+package dto
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/types"
+)
+
+type Request interface {
+	GetTokenCountMeta() *types.TokenCountMeta
+	IsStream(c *gin.Context) bool
+}

+ 27 - 0
dto/rerank.go

@@ -1,5 +1,12 @@
 package dto
 package dto
 
 
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/types"
+	"strings"
+)
+
 type RerankRequest struct {
 type RerankRequest struct {
 	Documents       []any  `json:"documents"`
 	Documents       []any  `json:"documents"`
 	Query           string `json:"query"`
 	Query           string `json:"query"`
@@ -10,6 +17,26 @@ type RerankRequest struct {
 	OverLapTokens   int    `json:"overlap_tokens,omitempty"`
 	OverLapTokens   int    `json:"overlap_tokens,omitempty"`
 }
 }
 
 
+func (r *RerankRequest) IsStream(c *gin.Context) bool {
+	return false
+}
+
+func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
+	var texts = make([]string, 0)
+
+	for _, document := range r.Documents {
+		texts = append(texts, fmt.Sprintf("%v", document))
+	}
+
+	if r.Query != "" {
+		texts = append(texts, r.Query)
+	}
+
+	return &types.TokenCountMeta{
+		CombineText: strings.Join(texts, "\n"),
+	}
+}
+
 func (r *RerankRequest) GetReturnDocuments() bool {
 func (r *RerankRequest) GetReturnDocuments() bool {
 	if r.ReturnDocuments == nil {
 	if r.ReturnDocuments == nil {
 		return false
 		return false

+ 115 - 0
logger/logger.go

@@ -0,0 +1,115 @@
+package logger
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"io"
+	"log"
+	"one-api/common"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+)
+
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+	loggerDebug = "DEBUG"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
+	if *common.LogDir != "" {
+		ok := setupLogLock.TryLock()
+		if !ok {
+			log.Println("setup log is already working")
+			return
+		}
+		defer func() {
+			setupLogLock.Unlock()
+			setupLogWorking = false
+		}()
+		logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		if err != nil {
+			log.Fatal("failed to open log file")
+		}
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
+	}
+}
+
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+	logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+	logHelper(ctx, loggerError, msg)
+}
+
+func LogDebug(ctx context.Context, msg string) {
+	if common.DebugEnabled {
+		logHelper(ctx, loggerDebug, msg)
+	}
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+	writer := gin.DefaultErrorWriter
+	if level == loggerINFO {
+		writer = gin.DefaultWriter
+	}
+	id := ctx.Value(common.RequestIdKey)
+	if id == nil {
+		id = "SYSTEM"
+	}
+	now := time.Now()
+	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+	logCount++ // we don't need accurate count, so no lock here
+	if logCount > maxLogCount && !setupLogWorking {
+		logCount = 0
+		setupLogWorking = true
+		gopool.Go(func() {
+			SetupLogger()
+		})
+	}
+}
+
+func LogQuota(quota int) string {
+	if common.DisplayInCurrencyEnabled {
+		return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
+	} else {
+		return fmt.Sprintf("%d 点额度", quota)
+	}
+}
+
+func FormatQuota(quota int) string {
+	if common.DisplayInCurrencyEnabled {
+		return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
+	} else {
+		return fmt.Sprintf("%d", quota)
+	}
+}
+
+// LogJson 仅供测试使用 only for test
+func LogJson(ctx context.Context, msg string, obj any) {
+	jsonStr, err := json.Marshal(obj)
+	if err != nil {
+		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
+		return
+	}
+	LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
+}

+ 19 - 18
main.go

@@ -8,6 +8,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/controller"
 	"one-api/controller"
+	"one-api/logger"
 	"one-api/middleware"
 	"one-api/middleware"
 	"one-api/model"
 	"one-api/model"
 	"one-api/router"
 	"one-api/router"
@@ -35,22 +36,22 @@ func main() {
 
 
 	err := InitResources()
 	err := InitResources()
 	if err != nil {
 	if err != nil {
-		common.FatalLog("failed to initialize resources: " + err.Error())
+		logger.FatalLog("failed to initialize resources: " + err.Error())
 		return
 		return
 	}
 	}
 
 
-	common.SysLog("New API " + common.Version + " started")
+	logger.SysLog("New API " + common.Version + " started")
 	if os.Getenv("GIN_MODE") != "debug" {
 	if os.Getenv("GIN_MODE") != "debug" {
 		gin.SetMode(gin.ReleaseMode)
 		gin.SetMode(gin.ReleaseMode)
 	}
 	}
 	if common.DebugEnabled {
 	if common.DebugEnabled {
-		common.SysLog("running in debug mode")
+		logger.SysLog("running in debug mode")
 	}
 	}
 
 
 	defer func() {
 	defer func() {
 		err := model.CloseDB()
 		err := model.CloseDB()
 		if err != nil {
 		if err != nil {
-			common.FatalLog("failed to close database: " + err.Error())
+			logger.FatalLog("failed to close database: " + err.Error())
 		}
 		}
 	}()
 	}()
 
 
@@ -59,18 +60,18 @@ func main() {
 		common.MemoryCacheEnabled = true
 		common.MemoryCacheEnabled = true
 	}
 	}
 	if common.MemoryCacheEnabled {
 	if common.MemoryCacheEnabled {
-		common.SysLog("memory cache enabled")
-		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+		logger.SysLog("memory cache enabled")
+		logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
 
 
 		// Add panic recovery and retry for InitChannelCache
 		// Add panic recovery and retry for InitChannelCache
 		func() {
 		func() {
 			defer func() {
 			defer func() {
 				if r := recover(); r != nil {
 				if r := recover(); r != nil {
-					common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+					logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
 					// Retry once
 					// Retry once
 					_, _, fixErr := model.FixAbility()
 					_, _, fixErr := model.FixAbility()
 					if fixErr != nil {
 					if fixErr != nil {
-						common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+						logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
 					}
 					}
 				}
 				}
 			}()
 			}()
@@ -89,14 +90,14 @@ func main() {
 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
 	if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
 		if err != nil {
 		if err != nil {
-			common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+			logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
 		}
 		}
 		go controller.AutomaticallyUpdateChannels(frequency)
 		go controller.AutomaticallyUpdateChannels(frequency)
 	}
 	}
 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
 	if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
 		frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
 		if err != nil {
 		if err != nil {
-			common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+			logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
 		}
 		}
 		go controller.AutomaticallyTestChannels(frequency)
 		go controller.AutomaticallyTestChannels(frequency)
 	}
 	}
@@ -110,7 +111,7 @@ func main() {
 	}
 	}
 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 		common.BatchUpdateEnabled = true
 		common.BatchUpdateEnabled = true
-		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+		logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 		model.InitBatchUpdater()
 		model.InitBatchUpdater()
 	}
 	}
 
 
@@ -119,13 +120,13 @@ func main() {
 			log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
 			log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
 		})
 		})
 		go common.Monitor()
 		go common.Monitor()
-		common.SysLog("pprof enabled")
+		logger.SysLog("pprof enabled")
 	}
 	}
 
 
 	// Initialize HTTP server
 	// Initialize HTTP server
 	server := gin.New()
 	server := gin.New()
 	server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
 	server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
-		common.SysError(fmt.Sprintf("panic detected: %v", err))
+		logger.SysError(fmt.Sprintf("panic detected: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{
 		c.JSON(http.StatusInternalServerError, gin.H{
 			"error": gin.H{
 			"error": gin.H{
 				"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
 				"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
@@ -155,7 +156,7 @@ func main() {
 	}
 	}
 	err = server.Run(":" + port)
 	err = server.Run(":" + port)
 	if err != nil {
 	if err != nil {
-		common.FatalLog("failed to start HTTP server: " + err.Error())
+		logger.FatalLog("failed to start HTTP server: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -164,14 +165,14 @@ func InitResources() error {
 	// This is a placeholder function for future resource initialization
 	// This is a placeholder function for future resource initialization
 	err := godotenv.Load(".env")
 	err := godotenv.Load(".env")
 	if err != nil {
 	if err != nil {
-		common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
-		common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+		logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
+		logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
 	}
 	}
 
 
 	// 加载环境变量
 	// 加载环境变量
 	common.InitEnv()
 	common.InitEnv()
 
 
-	common.SetupLogger()
+	logger.SetupLogger()
 
 
 	// Initialize model settings
 	// Initialize model settings
 	ratio_setting.InitRatioSettings()
 	ratio_setting.InitRatioSettings()
@@ -183,7 +184,7 @@ func InitResources() error {
 	// Initialize SQL Database
 	// Initialize SQL Database
 	err = model.InitDB()
 	err = model.InitDB()
 	if err != nil {
 	if err != nil {
-		common.FatalLog("failed to initialize database: " + err.Error())
+		logger.FatalLog("failed to initialize database: " + err.Error())
 		return err
 		return err
 	}
 	}
 
 

+ 3 - 3
middleware/recover.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
-	"one-api/common"
+	"one-api/logger"
 	"runtime/debug"
 	"runtime/debug"
 )
 )
 
 
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
 	return func(c *gin.Context) {
 	return func(c *gin.Context) {
 		defer func() {
 		defer func() {
 			if err := recover(); err != nil {
 			if err := recover(); err != nil {
-				common.SysError(fmt.Sprintf("panic detected: %v", err))
-				common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+				logger.SysError(fmt.Sprintf("panic detected: %v", err))
+				logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
 				c.JSON(http.StatusInternalServerError, gin.H{
 				c.JSON(http.StatusInternalServerError, gin.H{
 					"error": gin.H{
 					"error": gin.H{
 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
 						"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),

+ 3 - 2
middleware/turnstile-check.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 )
 )
 
 
 type turnstileCheckResponse struct {
 type turnstileCheckResponse struct {
@@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc {
 				"remoteip": {c.ClientIP()},
 				"remoteip": {c.ClientIP()},
 			})
 			})
 			if err != nil {
 			if err != nil {
-				common.SysError(err.Error())
+				logger.SysError(err.Error())
 				c.JSON(http.StatusOK, gin.H{
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
 					"success": false,
 					"message": err.Error(),
 					"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
 			var res turnstileCheckResponse
 			var res turnstileCheckResponse
 			err = json.NewDecoder(rawRes.Body).Decode(&res)
 			err = json.NewDecoder(rawRes.Body).Decode(&res)
 			if err != nil {
 			if err != nil {
-				common.SysError(err.Error())
+				logger.SysError(err.Error())
 				c.JSON(http.StatusOK, gin.H{
 				c.JSON(http.StatusOK, gin.H{
 					"success": false,
 					"success": false,
 					"message": err.Error(),
 					"message": err.Error(),

+ 3 - 2
middleware/utils.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 )
 )
 
 
 func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
 func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
@@ -15,7 +16,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
 		},
 		},
 	})
 	})
 	c.Abort()
 	c.Abort()
-	common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
+	logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
 }
 }
 
 
 func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
 func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
@@ -25,5 +26,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri
 		"code":        code,
 		"code":        code,
 	})
 	})
 	c.Abort()
 	c.Abort()
-	common.LogError(c.Request.Context(), description)
+	logger.LogError(c.Request.Context(), description)
 }
 }

+ 5 - 4
model/ability.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
@@ -294,13 +295,13 @@ func FixAbility() (int, int, error) {
 	if common.UsingSQLite {
 	if common.UsingSQLite {
 		err := DB.Exec("DELETE FROM abilities").Error
 		err := DB.Exec("DELETE FROM abilities").Error
 		if err != nil {
 		if err != nil {
-			common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+			logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
 			return 0, 0, err
 			return 0, 0, err
 		}
 		}
 	} else {
 	} else {
 		err := DB.Exec("TRUNCATE TABLE abilities").Error
 		err := DB.Exec("TRUNCATE TABLE abilities").Error
 		if err != nil {
 		if err != nil {
-			common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
+			logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
 			return 0, 0, err
 			return 0, 0, err
 		}
 		}
 	}
 	}
@@ -320,7 +321,7 @@ func FixAbility() (int, int, error) {
 		// Delete all abilities of this channel
 		// Delete all abilities of this channel
 		err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
 		err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
 		if err != nil {
 		if err != nil {
-			common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+			logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
 			failCount += len(chunk)
 			failCount += len(chunk)
 			continue
 			continue
 		}
 		}
@@ -328,7 +329,7 @@ func FixAbility() (int, int, error) {
 		for _, channel := range chunk {
 		for _, channel := range chunk {
 			err = channel.AddAbilities(nil)
 			err = channel.AddAbilities(nil)
 			if err != nil {
 			if err != nil {
-				common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+				logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
 				failCount++
 				failCount++
 			} else {
 			} else {
 				successCount++
 				successCount++

+ 14 - 13
model/channel.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -209,7 +210,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	if channel.OtherInfo != "" {
 	if channel.OtherInfo != "" {
 		err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
 		err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal other info: " + err.Error())
+			logger.SysError("failed to unmarshal other info: " + err.Error())
 		}
 		}
 	}
 	}
 	return otherInfo
 	return otherInfo
@@ -218,7 +219,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
 func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
 func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
 	otherInfoBytes, err := json.Marshal(otherInfo)
 	otherInfoBytes, err := json.Marshal(otherInfo)
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to marshal other info: " + err.Error())
+		logger.SysError("failed to marshal other info: " + err.Error())
 		return
 		return
 	}
 	}
 	channel.OtherInfo = string(otherInfoBytes)
 	channel.OtherInfo = string(otherInfoBytes)
@@ -488,7 +489,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
 		ResponseTime: int(responseTime),
 		ResponseTime: int(responseTime),
 	}).Error
 	}).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update response time: " + err.Error())
+		logger.SysError("failed to update response time: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -498,7 +499,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
 		Balance:            balance,
 		Balance:            balance,
 	}).Error
 	}).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update balance: " + err.Error())
+		logger.SysError("failed to update balance: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -614,7 +615,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
 		if shouldUpdateAbilities {
 		if shouldUpdateAbilities {
 			err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
 			err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
 			if err != nil {
 			if err != nil {
-				common.SysError("failed to update ability status: " + err.Error())
+				logger.SysError("failed to update ability status: " + err.Error())
 			}
 			}
 		}
 		}
 	}()
 	}()
@@ -642,7 +643,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
 		}
 		}
 		err = channel.Save()
 		err = channel.Save()
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to update channel status: " + err.Error())
+			logger.SysError("failed to update channel status: " + err.Error())
 			return false
 			return false
 		}
 		}
 	}
 	}
@@ -704,7 +705,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
 			for _, channel := range channels {
 			for _, channel := range channels {
 				err = channel.UpdateAbilities(nil)
 				err = channel.UpdateAbilities(nil)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to update abilities: " + err.Error())
+					logger.SysError("failed to update abilities: " + err.Error())
 				}
 				}
 			}
 			}
 		}
 		}
@@ -728,7 +729,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
 func updateChannelUsedQuota(id int, quota int) {
 func updateChannelUsedQuota(id int, quota int) {
 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update channel used quota: " + err.Error())
+		logger.SysError("failed to update channel used quota: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -821,7 +822,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
 	if channel.Setting != nil && *channel.Setting != "" {
 	if channel.Setting != nil && *channel.Setting != "" {
 		err := common.Unmarshal([]byte(*channel.Setting), &setting)
 		err := common.Unmarshal([]byte(*channel.Setting), &setting)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal setting: " + err.Error())
+			logger.SysError("failed to unmarshal setting: " + err.Error())
 			channel.Setting = nil // 清空设置以避免后续错误
 			channel.Setting = nil // 清空设置以避免后续错误
 			_ = channel.Save()    // 保存修改
 			_ = channel.Save()    // 保存修改
 		}
 		}
@@ -832,7 +833,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
 func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
 func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
 	settingBytes, err := common.Marshal(setting)
 	settingBytes, err := common.Marshal(setting)
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to marshal setting: " + err.Error())
+		logger.SysError("failed to marshal setting: " + err.Error())
 		return
 		return
 	}
 	}
 	channel.Setting = common.GetPointer[string](string(settingBytes))
 	channel.Setting = common.GetPointer[string](string(settingBytes))
@@ -843,7 +844,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
 	if channel.OtherSettings != "" {
 	if channel.OtherSettings != "" {
 		err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
 		err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal setting: " + err.Error())
+			logger.SysError("failed to unmarshal setting: " + err.Error())
 			channel.OtherSettings = "{}" // 清空设置以避免后续错误
 			channel.OtherSettings = "{}" // 清空设置以避免后续错误
 			_ = channel.Save()           // 保存修改
 			_ = channel.Save()           // 保存修改
 		}
 		}
@@ -854,7 +855,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
 func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
 func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
 	settingBytes, err := common.Marshal(setting)
 	settingBytes, err := common.Marshal(setting)
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to marshal setting: " + err.Error())
+		logger.SysError("failed to marshal setting: " + err.Error())
 		return
 		return
 	}
 	}
 	channel.OtherSettings = string(settingBytes)
 	channel.OtherSettings = string(settingBytes)
@@ -865,7 +866,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
 	if channel.ParamOverride != nil && *channel.ParamOverride != "" {
 	if channel.ParamOverride != nil && *channel.ParamOverride != "" {
 		err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
 		err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal param override: " + err.Error())
+			logger.SysError("failed to unmarshal param override: " + err.Error())
 		}
 		}
 	}
 	}
 	return paramOverride
 	return paramOverride

+ 3 - 2
model/channel_cache.go

@@ -6,6 +6,7 @@ import (
 	"math/rand"
 	"math/rand"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/logger"
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/ratio_setting"
 	"one-api/setting/ratio_setting"
 	"sort"
 	"sort"
@@ -84,13 +85,13 @@ func InitChannelCache() {
 	}
 	}
 	channelsIDM = newChannelId2channel
 	channelsIDM = newChannelId2channel
 	channelSyncLock.Unlock()
 	channelSyncLock.Unlock()
-	common.SysLog("channels synced from database")
+	logger.SysLog("channels synced from database")
 }
 }
 
 
 func SyncChannelCache(frequency int) {
 func SyncChannelCache(frequency int) {
 	for {
 	for {
 		time.Sleep(time.Duration(frequency) * time.Second)
 		time.Sleep(time.Duration(frequency) * time.Second)
-		common.SysLog("syncing channels from database")
+		logger.SysLog("syncing channels from database")
 		InitChannelCache()
 		InitChannelCache()
 	}
 	}
 }
 }

+ 6 - 6
model/log.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) {
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to record log: " + err.Error())
+		logger.SysError("failed to record log: " + err.Error())
 	}
 	}
 }
 }
 
 
 func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
 func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
 	isStream bool, group string, other map[string]interface{}) {
 	isStream bool, group string, other map[string]interface{}) {
-	common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
+	logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
 	username := c.GetString("username")
 	username := c.GetString("username")
 	otherStr := common.MapToJsonStr(other)
 	otherStr := common.MapToJsonStr(other)
 	// 判断是否需要记录 IP
 	// 判断是否需要记录 IP
@@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
-		common.LogError(c, "failed to record log: "+err.Error())
+		logger.LogError(c, "failed to record log: "+err.Error())
 	}
 	}
 }
 }
 
 
@@ -142,7 +143,6 @@ type RecordConsumeLogParams struct {
 	Quota            int                    `json:"quota"`
 	Quota            int                    `json:"quota"`
 	Content          string                 `json:"content"`
 	Content          string                 `json:"content"`
 	TokenId          int                    `json:"token_id"`
 	TokenId          int                    `json:"token_id"`
-	UserQuota        int                    `json:"user_quota"`
 	UseTimeSeconds   int                    `json:"use_time_seconds"`
 	UseTimeSeconds   int                    `json:"use_time_seconds"`
 	IsStream         bool                   `json:"is_stream"`
 	IsStream         bool                   `json:"is_stream"`
 	Group            string                 `json:"group"`
 	Group            string                 `json:"group"`
@@ -150,7 +150,7 @@ type RecordConsumeLogParams struct {
 }
 }
 
 
 func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
 func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
-	common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
+	logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
 	if !common.LogConsumeEnabled {
 	if !common.LogConsumeEnabled {
 		return
 		return
 	}
 	}
@@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
-		common.LogError(c, "failed to record log: "+err.Error())
+		logger.LogError(c, "failed to record log: "+err.Error())
 	}
 	}
 	if common.DataExportEnabled {
 	if common.DataExportEnabled {
 		gopool.Go(func() {
 		gopool.Go(func() {

+ 16 - 15
model/main.go

@@ -5,6 +5,7 @@ import (
 	"log"
 	"log"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/logger"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -84,7 +85,7 @@ func createRootAccountIfNeed() error {
 	var user User
 	var user User
 	//if user.Status != common.UserStatusEnabled {
 	//if user.Status != common.UserStatusEnabled {
 	if err := DB.First(&user).Error; err != nil {
 	if err := DB.First(&user).Error; err != nil {
-		common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+		logger.SysLog("no user exists, create a root user for you: username is root, password is 123456")
 		hashedPassword, err := common.Password2Hash("123456")
 		hashedPassword, err := common.Password2Hash("123456")
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -108,7 +109,7 @@ func CheckSetup() {
 	if setup == nil {
 	if setup == nil {
 		// No setup record exists, check if we have a root user
 		// No setup record exists, check if we have a root user
 		if RootUserExists() {
 		if RootUserExists() {
-			common.SysLog("system is not initialized, but root user exists")
+			logger.SysLog("system is not initialized, but root user exists")
 			// Create setup record
 			// Create setup record
 			newSetup := Setup{
 			newSetup := Setup{
 				Version:       common.Version,
 				Version:       common.Version,
@@ -116,16 +117,16 @@ func CheckSetup() {
 			}
 			}
 			err := DB.Create(&newSetup).Error
 			err := DB.Create(&newSetup).Error
 			if err != nil {
 			if err != nil {
-				common.SysLog("failed to create setup record: " + err.Error())
+				logger.SysLog("failed to create setup record: " + err.Error())
 			}
 			}
 			constant.Setup = true
 			constant.Setup = true
 		} else {
 		} else {
-			common.SysLog("system is not initialized and no root user exists")
+			logger.SysLog("system is not initialized and no root user exists")
 			constant.Setup = false
 			constant.Setup = false
 		}
 		}
 	} else {
 	} else {
 		// Setup record exists, system is initialized
 		// Setup record exists, system is initialized
-		common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
+		logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
 		constant.Setup = true
 		constant.Setup = true
 	}
 	}
 }
 }
@@ -138,7 +139,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 	if dsn != "" {
 	if dsn != "" {
 		if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 		if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 			// Use PostgreSQL
 			// Use PostgreSQL
-			common.SysLog("using PostgreSQL as database")
+			logger.SysLog("using PostgreSQL as database")
 			if !isLog {
 			if !isLog {
 				common.UsingPostgreSQL = true
 				common.UsingPostgreSQL = true
 			} else {
 			} else {
@@ -152,7 +153,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 			})
 			})
 		}
 		}
 		if strings.HasPrefix(dsn, "local") {
 		if strings.HasPrefix(dsn, "local") {
-			common.SysLog("SQL_DSN not set, using SQLite as database")
+			logger.SysLog("SQL_DSN not set, using SQLite as database")
 			if !isLog {
 			if !isLog {
 				common.UsingSQLite = true
 				common.UsingSQLite = true
 			} else {
 			} else {
@@ -163,7 +164,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 			})
 			})
 		}
 		}
 		// Use MySQL
 		// Use MySQL
-		common.SysLog("using MySQL as database")
+		logger.SysLog("using MySQL as database")
 		// check parseTime
 		// check parseTime
 		if !strings.Contains(dsn, "parseTime") {
 		if !strings.Contains(dsn, "parseTime") {
 			if strings.Contains(dsn, "?") {
 			if strings.Contains(dsn, "?") {
@@ -182,7 +183,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 		})
 		})
 	}
 	}
 	// Use SQLite
 	// Use SQLite
-	common.SysLog("SQL_DSN not set, using SQLite as database")
+	logger.SysLog("SQL_DSN not set, using SQLite as database")
 	common.UsingSQLite = true
 	common.UsingSQLite = true
 	return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 	return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 		PrepareStmt: true, // precompile SQL
 		PrepareStmt: true, // precompile SQL
@@ -216,11 +217,11 @@ func InitDB() (err error) {
 		if common.UsingMySQL {
 		if common.UsingMySQL {
 			//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
 			//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
 		}
 		}
-		common.SysLog("database migration started")
+		logger.SysLog("database migration started")
 		err = migrateDB()
 		err = migrateDB()
 		return err
 		return err
 	} else {
 	} else {
-		common.FatalLog(err)
+		logger.FatalLog(err)
 	}
 	}
 	return err
 	return err
 }
 }
@@ -253,11 +254,11 @@ func InitLogDB() (err error) {
 		if !common.IsMasterNode {
 		if !common.IsMasterNode {
 			return nil
 			return nil
 		}
 		}
-		common.SysLog("database migration started")
+		logger.SysLog("database migration started")
 		err = migrateLOGDB()
 		err = migrateLOGDB()
 		return err
 		return err
 	} else {
 	} else {
-		common.FatalLog(err)
+		logger.FatalLog(err)
 	}
 	}
 	return err
 	return err
 }
 }
@@ -354,7 +355,7 @@ func migrateDBFast() error {
 			return err
 			return err
 		}
 		}
 	}
 	}
-	common.SysLog("database migrated")
+	logger.SysLog("database migrated")
 	return nil
 	return nil
 }
 }
 
 
@@ -503,6 +504,6 @@ func PingDB() error {
 	}
 	}
 
 
 	lastPingTime = time.Now()
 	lastPingTime = time.Now()
-	common.SysLog("Database pinged successfully")
+	logger.SysLog("Database pinged successfully")
 	return nil
 	return nil
 }
 }

+ 3 - 2
model/option.go

@@ -2,6 +2,7 @@ package model
 
 
 import (
 import (
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/config"
 	"one-api/setting/config"
 	"one-api/setting/operation_setting"
 	"one-api/setting/operation_setting"
@@ -150,7 +151,7 @@ func loadOptionsFromDatabase() {
 	for _, option := range options {
 	for _, option := range options {
 		err := updateOptionMap(option.Key, option.Value)
 		err := updateOptionMap(option.Key, option.Value)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to update option map: " + err.Error())
+			logger.SysError("failed to update option map: " + err.Error())
 		}
 		}
 	}
 	}
 }
 }
@@ -158,7 +159,7 @@ func loadOptionsFromDatabase() {
 func SyncOptions(frequency int) {
 func SyncOptions(frequency int) {
 	for {
 	for {
 		time.Sleep(time.Duration(frequency) * time.Second)
 		time.Sleep(time.Duration(frequency) * time.Second)
-		common.SysLog("syncing options from database")
+		logger.SysLog("syncing options from database")
 		loadOptionsFromDatabase()
 		loadOptionsFromDatabase()
 	}
 	}
 }
 }

+ 2 - 1
model/pricing.go

@@ -3,6 +3,7 @@ package model
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
+	"one-api/logger"
 	"strings"
 	"strings"
 
 
 	"one-api/common"
 	"one-api/common"
@@ -92,7 +93,7 @@ func updatePricing() {
 	//modelRatios := common.GetModelRatios()
 	//modelRatios := common.GetModelRatios()
 	enableAbilities, err := GetAllEnableAbilityWithChannels()
 	enableAbilities, err := GetAllEnableAbilityWithChannels()
 	if err != nil {
 	if err != nil {
-		common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
+		logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
 		return
 		return
 	}
 	}
 	// 预加载模型元数据与供应商一次,避免循环查询
 	// 预加载模型元数据与供应商一次,避免循环查询

+ 2 - 1
model/redemption.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"strconv"
 	"strconv"
 
 
 	"gorm.io/gorm"
 	"gorm.io/gorm"
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
 	if err != nil {
 	if err != nil {
 		return 0, errors.New("兑换失败," + err.Error())
 		return 0, errors.New("兑换失败," + err.Error())
 	}
 	}
-	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
+	RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
 	return redemption.Quota, nil
 	return redemption.Quota, nil
 }
 }
 
 

+ 10 - 9
model/token.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"strings"
 	"strings"
 
 
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/bytedance/gopkg/util/gopool"
@@ -91,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
 				token.Status = common.TokenStatusExpired
 				token.Status = common.TokenStatusExpired
 				err := token.SelectUpdate()
 				err := token.SelectUpdate()
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to update token status" + err.Error())
+					logger.SysError("failed to update token status" + err.Error())
 				}
 				}
 			}
 			}
 			return token, errors.New("该令牌已过期")
 			return token, errors.New("该令牌已过期")
@@ -102,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
 				token.Status = common.TokenStatusExhausted
 				token.Status = common.TokenStatusExhausted
 				err := token.SelectUpdate()
 				err := token.SelectUpdate()
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to update token status" + err.Error())
+					logger.SysError("failed to update token status" + err.Error())
 				}
 				}
 			}
 			}
 			keyPrefix := key[:3]
 			keyPrefix := key[:3]
@@ -134,7 +135,7 @@ func GetTokenById(id int) (*Token, error) {
 	if shouldUpdateRedis(true, err) {
 	if shouldUpdateRedis(true, err) {
 		gopool.Go(func() {
 		gopool.Go(func() {
 			if err := cacheSetToken(token); err != nil {
 			if err := cacheSetToken(token); err != nil {
-				common.SysError("failed to update user status cache: " + err.Error())
+				logger.SysError("failed to update user status cache: " + err.Error())
 			}
 			}
 		})
 		})
 	}
 	}
@@ -147,7 +148,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
 		if shouldUpdateRedis(fromDB, err) && token != nil {
 		if shouldUpdateRedis(fromDB, err) && token != nil {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := cacheSetToken(*token); err != nil {
 				if err := cacheSetToken(*token); err != nil {
-					common.SysError("failed to update user status cache: " + err.Error())
+					logger.SysError("failed to update user status cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -178,7 +179,7 @@ func (token *Token) Update() (err error) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				err := cacheSetToken(*token)
 				err := cacheSetToken(*token)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to update token cache: " + err.Error())
+					logger.SysError("failed to update token cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -194,7 +195,7 @@ func (token *Token) SelectUpdate() (err error) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				err := cacheSetToken(*token)
 				err := cacheSetToken(*token)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to update token cache: " + err.Error())
+					logger.SysError("failed to update token cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -209,7 +210,7 @@ func (token *Token) Delete() (err error) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				err := cacheDeleteToken(token.Key)
 				err := cacheDeleteToken(token.Key)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to delete token cache: " + err.Error())
+					logger.SysError("failed to delete token cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -269,7 +270,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
 		gopool.Go(func() {
 		gopool.Go(func() {
 			err := cacheIncrTokenQuota(key, int64(quota))
 			err := cacheIncrTokenQuota(key, int64(quota))
 			if err != nil {
 			if err != nil {
-				common.SysError("failed to increase token quota: " + err.Error())
+				logger.SysError("failed to increase token quota: " + err.Error())
 			}
 			}
 		})
 		})
 	}
 	}
@@ -299,7 +300,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
 		gopool.Go(func() {
 		gopool.Go(func() {
 			err := cacheDecrTokenQuota(key, int64(quota))
 			err := cacheDecrTokenQuota(key, int64(quota))
 			if err != nil {
 			if err != nil {
-				common.SysError("failed to decrease token quota: " + err.Error())
+				logger.SysError("failed to decrease token quota: " + err.Error())
 			}
 			}
 		})
 		})
 	}
 	}

+ 2 - 1
model/topup.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 
 
 	"gorm.io/gorm"
 	"gorm.io/gorm"
 )
 )
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
 		return errors.New("充值失败," + err.Error())
 		return errors.New("充值失败," + err.Error())
 	}
 	}
 
 
-	RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
+	RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
 
 
 	return nil
 	return nil
 }
 }

+ 5 - 4
model/twofa.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"time"
 	"time"
 
 
 	"gorm.io/gorm"
 	"gorm.io/gorm"
@@ -243,7 +244,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
 	if !common.ValidateTOTPCode(t.Secret, code) {
 	if !common.ValidateTOTPCode(t.Secret, code) {
 		// 增加失败次数
 		// 增加失败次数
 		if err := t.IncrementFailedAttempts(); err != nil {
 		if err := t.IncrementFailedAttempts(); err != nil {
-			common.SysError("更新2FA失败次数失败: " + err.Error())
+			logger.SysError("更新2FA失败次数失败: " + err.Error())
 		}
 		}
 		return false, nil
 		return false, nil
 	}
 	}
@@ -255,7 +256,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
 	t.LastUsedAt = &now
 	t.LastUsedAt = &now
 
 
 	if err := t.Update(); err != nil {
 	if err := t.Update(); err != nil {
-		common.SysError("更新2FA使用记录失败: " + err.Error())
+		logger.SysError("更新2FA使用记录失败: " + err.Error())
 	}
 	}
 
 
 	return true, nil
 	return true, nil
@@ -277,7 +278,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
 	if !valid {
 	if !valid {
 		// 增加失败次数
 		// 增加失败次数
 		if err := t.IncrementFailedAttempts(); err != nil {
 		if err := t.IncrementFailedAttempts(); err != nil {
-			common.SysError("更新2FA失败次数失败: " + err.Error())
+			logger.SysError("更新2FA失败次数失败: " + err.Error())
 		}
 		}
 		return false, nil
 		return false, nil
 	}
 	}
@@ -289,7 +290,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
 	t.LastUsedAt = &now
 	t.LastUsedAt = &now
 
 
 	if err := t.Update(); err != nil {
 	if err := t.Update(); err != nil {
-		common.SysError("更新2FA使用记录失败: " + err.Error())
+		logger.SysError("更新2FA使用记录失败: " + err.Error())
 	}
 	}
 
 
 	return true, nil
 	return true, nil

+ 5 - 4
model/usedata.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"fmt"
 	"gorm.io/gorm"
 	"gorm.io/gorm"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -24,12 +25,12 @@ func UpdateQuotaData() {
 	// recover
 	// recover
 	defer func() {
 	defer func() {
 		if r := recover(); r != nil {
 		if r := recover(); r != nil {
-			common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
+			logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
 		}
 		}
 	}()
 	}()
 	for {
 	for {
 		if common.DataExportEnabled {
 		if common.DataExportEnabled {
-			common.SysLog("正在更新数据看板数据...")
+			logger.SysLog("正在更新数据看板数据...")
 			SaveQuotaDataCache()
 			SaveQuotaDataCache()
 		}
 		}
 		time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
 		time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
@@ -91,7 +92,7 @@ func SaveQuotaDataCache() {
 		}
 		}
 	}
 	}
 	CacheQuotaData = make(map[string]*QuotaData)
 	CacheQuotaData = make(map[string]*QuotaData)
-	common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
+	logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
 }
 }
 
 
 func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
 func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
@@ -102,7 +103,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int,
 		"token_used": gorm.Expr("token_used + ?", tokenUsed),
 		"token_used": gorm.Expr("token_used + ?", tokenUsed),
 	}).Error
 	}).Error
 	if err != nil {
 	if err != nil {
-		common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
+		logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
 	}
 	}
 }
 }
 
 

+ 17 - 16
model/user.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 
 
@@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
 	if user.Setting != "" {
 	if user.Setting != "" {
 		err := json.Unmarshal([]byte(user.Setting), &setting)
 		err := json.Unmarshal([]byte(user.Setting), &setting)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal setting: " + err.Error())
+			logger.SysError("failed to unmarshal setting: " + err.Error())
 		}
 		}
 	}
 	}
 	return setting
 	return setting
@@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
 func (user *User) SetSetting(setting dto.UserSetting) {
 func (user *User) SetSetting(setting dto.UserSetting) {
 	settingBytes, err := json.Marshal(setting)
 	settingBytes, err := json.Marshal(setting)
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to marshal setting: " + err.Error())
+		logger.SysError("failed to marshal setting: " + err.Error())
 		return
 		return
 	}
 	}
 	user.Setting = string(settingBytes)
 	user.Setting = string(settingBytes)
@@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
 func (user *User) TransferAffQuotaToQuota(quota int) error {
 func (user *User) TransferAffQuotaToQuota(quota int) error {
 	// 检查quota是否小于最小额度
 	// 检查quota是否小于最小额度
 	if float64(quota) < common.QuotaPerUnit {
 	if float64(quota) < common.QuotaPerUnit {
-		return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
+		return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
 	}
 	}
 
 
 	// 开始数据库事务
 	// 开始数据库事务
@@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
 		return result.Error
 		return result.Error
 	}
 	}
 	if common.QuotaForNewUser > 0 {
 	if common.QuotaForNewUser > 0 {
-		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
 	}
 	}
 	if inviterId != 0 {
 	if inviterId != 0 {
 		if common.QuotaForInvitee > 0 {
 		if common.QuotaForInvitee > 0 {
 			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
 			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
-			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
 		}
 		}
 		if common.QuotaForInviter > 0 {
 		if common.QuotaForInviter > 0 {
 			//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
 			//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
-			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
 			_ = inviteUser(inviterId)
 			_ = inviteUser(inviterId)
 		}
 		}
 	}
 	}
@@ -517,7 +518,7 @@ func IsAdmin(userId int) bool {
 	var user User
 	var user User
 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
 	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("no such user " + err.Error())
+		logger.SysError("no such user " + err.Error())
 		return false
 		return false
 	}
 	}
 	return user.Role >= common.RoleAdminUser
 	return user.Role >= common.RoleAdminUser
@@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
 		if shouldUpdateRedis(fromDB, err) {
 		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := updateUserQuotaCache(id, quota); err != nil {
 				if err := updateUserQuotaCache(id, quota); err != nil {
-					common.SysError("failed to update user quota cache: " + err.Error())
+					logger.SysError("failed to update user quota cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 		if shouldUpdateRedis(fromDB, err) {
 		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := updateUserGroupCache(id, group); err != nil {
 				if err := updateUserGroupCache(id, group); err != nil {
-					common.SysError("failed to update user group cache: " + err.Error())
+					logger.SysError("failed to update user group cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
 		if shouldUpdateRedis(fromDB, err) {
 		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := updateUserSettingCache(id, setting); err != nil {
 				if err := updateUserSettingCache(id, setting); err != nil {
-					common.SysError("failed to update user setting cache: " + err.Error())
+					logger.SysError("failed to update user setting cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}
@@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
 	gopool.Go(func() {
 	gopool.Go(func() {
 		err := cacheIncrUserQuota(id, int64(quota))
 		err := cacheIncrUserQuota(id, int64(quota))
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to increase user quota: " + err.Error())
+			logger.SysError("failed to increase user quota: " + err.Error())
 		}
 		}
 	})
 	})
 	if !db && common.BatchUpdateEnabled {
 	if !db && common.BatchUpdateEnabled {
@@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 	gopool.Go(func() {
 	gopool.Go(func() {
 		err := cacheDecrUserQuota(id, int64(quota))
 		err := cacheDecrUserQuota(id, int64(quota))
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to decrease user quota: " + err.Error())
+			logger.SysError("failed to decrease user quota: " + err.Error())
 		}
 		}
 	})
 	})
 	if common.BatchUpdateEnabled {
 	if common.BatchUpdateEnabled {
@@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 		},
 		},
 	).Error
 	).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update user used quota and request count: " + err.Error())
+		logger.SysError("failed to update user used quota and request count: " + err.Error())
 		return
 		return
 	}
 	}
 
 
@@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
 		},
 		},
 	).Error
 	).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update user used quota: " + err.Error())
+		logger.SysError("failed to update user used quota: " + err.Error())
 	}
 	}
 }
 }
 
 
 func updateUserRequestCount(id int, count int) {
 func updateUserRequestCount(id int, count int) {
 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
 	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
 	if err != nil {
 	if err != nil {
-		common.SysError("failed to update user request count: " + err.Error())
+		logger.SysError("failed to update user request count: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
 		if shouldUpdateRedis(fromDB, err) {
 		if shouldUpdateRedis(fromDB, err) {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := updateUserNameCache(id, username); err != nil {
 				if err := updateUserNameCache(id, username); err != nil {
-					common.SysError("failed to update user name cache: " + err.Error())
+					logger.SysError("failed to update user name cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}

+ 3 - 2
model/user_cache.go

@@ -5,6 +5,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"time"
 	"time"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -37,7 +38,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
 	if user.Setting != "" {
 	if user.Setting != "" {
 		err := common.Unmarshal([]byte(user.Setting), &setting)
 		err := common.Unmarshal([]byte(user.Setting), &setting)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to unmarshal setting: " + err.Error())
+			logger.SysError("failed to unmarshal setting: " + err.Error())
 		}
 		}
 	}
 	}
 	return setting
 	return setting
@@ -78,7 +79,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
 		if shouldUpdateRedis(fromDB, err) && user != nil {
 		if shouldUpdateRedis(fromDB, err) && user != nil {
 			gopool.Go(func() {
 			gopool.Go(func() {
 				if err := updateUserCache(*user); err != nil {
 				if err := updateUserCache(*user); err != nil {
-					common.SysError("failed to update user status cache: " + err.Error())
+					logger.SysError("failed to update user status cache: " + err.Error())
 				}
 				}
 			})
 			})
 		}
 		}

+ 5 - 4
model/utils.go

@@ -3,6 +3,7 @@ package model
 import (
 import (
 	"errors"
 	"errors"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
@@ -65,7 +66,7 @@ func batchUpdate() {
 		return
 		return
 	}
 	}
 
 
-	common.SysLog("batch update started")
+	logger.SysLog("batch update started")
 	for i := 0; i < BatchUpdateTypeCount; i++ {
 	for i := 0; i < BatchUpdateTypeCount; i++ {
 		batchUpdateLocks[i].Lock()
 		batchUpdateLocks[i].Lock()
 		store := batchUpdateStores[i]
 		store := batchUpdateStores[i]
@@ -77,12 +78,12 @@ func batchUpdate() {
 			case BatchUpdateTypeUserQuota:
 			case BatchUpdateTypeUserQuota:
 				err := increaseUserQuota(key, value)
 				err := increaseUserQuota(key, value)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to batch update user quota: " + err.Error())
+					logger.SysError("failed to batch update user quota: " + err.Error())
 				}
 				}
 			case BatchUpdateTypeTokenQuota:
 			case BatchUpdateTypeTokenQuota:
 				err := increaseTokenQuota(key, value)
 				err := increaseTokenQuota(key, value)
 				if err != nil {
 				if err != nil {
-					common.SysError("failed to batch update token quota: " + err.Error())
+					logger.SysError("failed to batch update token quota: " + err.Error())
 				}
 				}
 			case BatchUpdateTypeUsedQuota:
 			case BatchUpdateTypeUsedQuota:
 				updateUserUsedQuota(key, value)
 				updateUserUsedQuota(key, value)
@@ -93,7 +94,7 @@ func batchUpdate() {
 			}
 			}
 		}
 		}
 	}
 	}
-	common.SysLog("batch update finished")
+	logger.SysLog("batch update finished")
 }
 }
 
 
 func RecordExist(err error) (bool, error) {
 func RecordExist(err error) (bool, error) {

+ 13 - 80
relay/audio_handler.go

@@ -4,107 +4,40 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
-	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/types"
 	"one-api/types"
-	"strings"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
-	audioRequest := &dto.AudioRequest{}
-	err := common.UnmarshalBodyReusable(c, audioRequest)
-	if err != nil {
-		return nil, err
-	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeAudioSpeech:
-		if audioRequest.Model == "" {
-			return nil, errors.New("model is required")
-		}
-		if setting.ShouldCheckPromptSensitive() {
-			words, err := service.CheckSensitiveInput(audioRequest.Input)
-			if err != nil {
-				common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
-				return nil, err
-			}
-		}
-	default:
-		err = c.Request.ParseForm()
-		if err != nil {
-			return nil, err
-		}
-		formData := c.Request.PostForm
-		if audioRequest.Model == "" {
-			audioRequest.Model = formData.Get("model")
-		}
+func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+	info.InitChannelMeta(c)
 
 
-		if audioRequest.Model == "" {
-			return nil, errors.New("model is required")
-		}
-		audioRequest.ResponseFormat = formData.Get("response_format")
-		if audioRequest.ResponseFormat == "" {
-			audioRequest.ResponseFormat = "json"
-		}
+	audioRequest, ok := info.Request.(*dto.AudioRequest)
+	if !ok {
+		return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
 	}
 	}
-	return audioRequest, nil
-}
-
-func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
-	relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
-	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
-
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
-
-	promptTokens := 0
-	preConsumedTokens := common.PreConsumedQuota
-	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
-		preConsumedTokens = promptTokens
-		relayInfo.PromptTokens = promptTokens
-	}
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-
-	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if openaiErr != nil {
-		return openaiErr
-	}
-	defer func() {
-		if openaiErr != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
 
 
-	err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
+	err := helper.ModelMappedHelper(c, info, audioRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
-	ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
+	ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
+	resp, err := adaptor.DoRequest(c, info, ioReader)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeDoRequestFailed)
 		return types.NewError(err, types.ErrorCodeDoRequestFailed)
 	}
 	}
@@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 		return newAPIError
 	}
 	}
 
 
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
 
 
 	return nil
 	return nil
 }
 }

+ 6 - 6
relay/channel/ali/image.go

@@ -6,8 +6,8 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"one-api/service"
 	"one-api/types"
 	"one-api/types"
@@ -43,7 +43,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
 	client := &http.Client{}
 	client := &http.Client{}
 	resp, err := client.Do(req)
 	resp, err := client.Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysError("updateTask client.Do err: " + err.Error())
+		logger.SysError("updateTask client.Do err: " + err.Error())
 		return &aliResponse, err, nil
 		return &aliResponse, err, nil
 	}
 	}
 	defer resp.Body.Close()
 	defer resp.Body.Close()
@@ -53,7 +53,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
 	var response AliResponse
 	var response AliResponse
 	err = json.Unmarshal(responseBody, &response)
 	err = json.Unmarshal(responseBody, &response)
 	if err != nil {
 	if err != nil {
-		common.SysError("updateTask NewDecoder err: " + err.Error())
+		logger.SysError("updateTask NewDecoder err: " + err.Error())
 		return &aliResponse, err, nil
 		return &aliResponse, err, nil
 	}
 	}
 
 
@@ -109,7 +109,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
 		if responseFormat == "b64_json" {
 		if responseFormat == "b64_json" {
 			_, b64, err := service.GetImageFromUrl(data.Url)
 			_, b64, err := service.GetImageFromUrl(data.Url)
 			if err != nil {
 			if err != nil {
-				common.LogError(c, "get_image_data_failed: "+err.Error())
+				logger.LogError(c, "get_image_data_failed: "+err.Error())
 				continue
 				continue
 			}
 			}
 			b64Json = b64
 			b64Json = b64
@@ -134,14 +134,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &aliTaskResponse)
 	err = json.Unmarshal(responseBody, &aliTaskResponse)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
 	}
 	}
 
 
 	if aliTaskResponse.Message != "" {
 	if aliTaskResponse.Message != "" {
-		common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+		logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
 		return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
 		return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
 	}
 	}
 
 

+ 2 - 2
relay/channel/ali/rerank.go

@@ -4,9 +4,9 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 
 
 	var aliResponse AliRerankResponse
 	var aliResponse AliRerankResponse
 	err = json.Unmarshal(responseBody, &aliResponse)
 	err = json.Unmarshal(responseBody, &aliResponse)

+ 7 - 5
relay/channel/ali/text.go

@@ -7,7 +7,9 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
+	"one-api/service"
 	"strings"
 	"strings"
 
 
 	"one-api/types"
 	"one-api/types"
@@ -46,7 +48,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
 	}
 	}
 
 
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 
 
 	model := c.GetString("model")
 	model := c.GetString("model")
 	if model == "" {
 	if model == "" {
@@ -148,7 +150,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 			var aliResponse AliResponse
 			var aliResponse AliResponse
 			err := json.Unmarshal([]byte(data), &aliResponse)
 			err := json.Unmarshal([]byte(data), &aliResponse)
 			if err != nil {
 			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
+				logger.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			if aliResponse.Usage.OutputTokens != 0 {
 			if aliResponse.Usage.OutputTokens != 0 {
@@ -161,7 +163,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 			lastResponseText = aliResponse.Output.Text
 			lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
+				logger.SysError("error marshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -171,7 +173,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 			return false
 			return false
 		}
 		}
 	})
 	})
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	return nil, &usage
 	return nil, &usage
 }
 }
 
 
@@ -181,7 +183,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil

+ 2 - 1
relay/channel/api_request.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	common2 "one-api/common"
 	common2 "one-api/common"
+	"one-api/logger"
 	"one-api/relay/common"
 	"one-api/relay/common"
 	"one-api/relay/constant"
 	"one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
@@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
 
 
 		err := helper.PingData(c)
 		err := helper.PingData(c)
 		if err != nil {
 		if err != nil {
-			common2.LogError(c, "SSE ping error: "+err.Error())
+			logger.LogError(c, "SSE ping error: "+err.Error())
 			done <- err
 			done <- err
 			return
 			return
 		}
 		}

+ 6 - 5
relay/channel/baidu/relay-baidu.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -118,7 +119,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 		var baiduResponse BaiduChatStreamResponse
 		var baiduResponse BaiduChatStreamResponse
 		err := common.Unmarshal([]byte(data), &baiduResponse)
 		err := common.Unmarshal([]byte(data), &baiduResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			return true
 			return true
 		}
 		}
 		if baiduResponse.Usage.TotalTokens != 0 {
 		if baiduResponse.Usage.TotalTokens != 0 {
@@ -129,11 +130,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 		response := streamResponseBaidu2OpenAI(&baiduResponse)
 		response := streamResponseBaidu2OpenAI(&baiduResponse)
 		err = helper.ObjectData(c, response)
 		err = helper.ObjectData(c, response)
 		if err != nil {
 		if err != nil {
-			common.SysError("error sending stream response: " + err.Error())
+			logger.SysError("error sending stream response: " + err.Error())
 		}
 		}
 		return true
 		return true
 	})
 	})
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	return nil, usage
 	return nil, usage
 }
 }
 
 
@@ -143,7 +144,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
@@ -168,7 +169,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil

+ 8 - 7
relay/channel/claude/relay-claude.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/channel/openrouter"
 	"one-api/relay/channel/openrouter"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
@@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 					for _, toolCall := range message.ParseToolCalls() {
 					for _, toolCall := range message.ParseToolCalls() {
 						inputObj := make(map[string]any)
 						inputObj := make(map[string]any)
 						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
 						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
-							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+							logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
 							continue
 							continue
 						}
 						}
 						claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
 						claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -609,7 +610,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	var claudeResponse dto.ClaudeResponse
 	var claudeResponse dto.ClaudeResponse
 	err := common.UnmarshalJsonStr(data, &claudeResponse)
 	err := common.UnmarshalJsonStr(data, &claudeResponse)
 	if err != nil {
 	if err != nil {
-		common.SysError("error unmarshalling stream response: " + err.Error())
+		logger.SysError("error unmarshalling stream response: " + err.Error())
 		return types.NewError(err, types.ErrorCodeBadResponseBody)
 		return types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
 	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
 	if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
@@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 
 
 		err = helper.ObjectData(c, response)
 		err = helper.ObjectData(c, response)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "send_stream_response_failed: "+err.Error())
+			logger.LogError(c, "send_stream_response_failed: "+err.Error())
 		}
 		}
 	}
 	}
 	return nil
 	return nil
@@ -653,7 +654,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 		}
 		}
 		if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
 		if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
 			if common.DebugEnabled {
 			if common.DebugEnabled {
-				common.SysError("claude response usage is not complete, maybe upstream error")
+				logger.SysError("claude response usage is not complete, maybe upstream error")
 			}
 			}
 			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 		}
@@ -667,7 +668,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
 			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 			err := helper.ObjectData(c, response)
 			err := helper.ObjectData(c, response)
 			if err != nil {
 			if err != nil {
-				common.SysError("send final response failed: " + err.Error())
+				logger.SysError("send final response failed: " + err.Error())
 			}
 			}
 		}
 		}
 		helper.Done(c)
 		helper.Done(c)
@@ -736,12 +737,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
 		c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, nil, responseData)
+	service.IOCopyBytesGracefully(c, nil, responseData)
 	return nil
 	return nil
 }
 }
 
 
 func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
 func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	claudeInfo := &ClaudeResponseInfo{
 	claudeInfo := &ClaudeResponseInfo{
 		ResponseId:   helper.GetResponseID(c),
 		ResponseId:   helper.GetResponseID(c),

+ 8 - 8
relay/channel/cloudflare/relay_cloudflare.go

@@ -5,8 +5,8 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 		var response dto.ChatCompletionsStreamResponse
 		var response dto.ChatCompletionsStreamResponse
 		err := json.Unmarshal([]byte(data), &response)
 		err := json.Unmarshal([]byte(data), &response)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+			logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
 			continue
 			continue
 		}
 		}
 		for _, choice := range response.Choices {
 		for _, choice := range response.Choices {
@@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 			info.FirstResponseTime = time.Now()
 			info.FirstResponseTime = time.Now()
 		}
 		}
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "error_rendering_stream_response: "+err.Error())
+			logger.LogError(c, "error_rendering_stream_response: "+err.Error())
 		}
 		}
 	}
 	}
 
 
 	if err := scanner.Err(); err != nil {
 	if err := scanner.Err(); err != nil {
-		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+		logger.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
 	}
 	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
 		err := helper.ObjectData(c, response)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+			logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
 		}
 		}
 	}
 	}
 	helper.Done(c)
 	helper.Done(c)
 
 
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 
 
 	return nil, usage
 	return nil, usage
 }
 }
@@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	var response dto.TextResponse
 	var response dto.TextResponse
 	err = json.Unmarshal(responseBody, &response)
 	err = json.Unmarshal(responseBody, &response)
 	if err != nil {
 	if err != nil {
@@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &cfResp)
 	err = json.Unmarshal(responseBody, &cfResp)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
 		return types.NewError(err, types.ErrorCodeBadResponseBody), nil

+ 5 - 4
relay/channel/cohere/relay-cohere.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -118,7 +119,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 			var cohereResp CohereResponse
 			var cohereResp CohereResponse
 			err := json.Unmarshal([]byte(data), &cohereResp)
 			err := json.Unmarshal([]byte(data), &cohereResp)
 			if err != nil {
 			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
+				logger.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			var openaiResp dto.ChatCompletionsStreamResponse
 			var openaiResp dto.ChatCompletionsStreamResponse
@@ -153,7 +154,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 			}
 			}
 			jsonStr, err := json.Marshal(openaiResp)
 			jsonStr, err := json.Marshal(openaiResp)
 			if err != nil {
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
+				logger.SysError("error marshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -175,7 +176,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereResponseResult
 	var cohereResp CohereResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {
 	if err != nil {
@@ -216,7 +217,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereRerankResponseResult
 	var cohereResp CohereRerankResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {
 	if err != nil {

+ 7 - 6
relay/channel/coze/relay-coze.go

@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -49,7 +50,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	// convert coze response to openai response
 	// convert coze response to openai response
 	var response dto.TextResponse
 	var response dto.TextResponse
 	var cozeResponse CozeChatDetailResponse
 	var cozeResponse CozeChatDetailResponse
@@ -154,7 +155,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
 		var chatData CozeChatResponseData
 		var chatData CozeChatResponseData
 		err := json.Unmarshal([]byte(data), &chatData)
 		err := json.Unmarshal([]byte(data), &chatData)
 		if err != nil {
 		if err != nil {
-			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			logger.SysError("error_unmarshalling_stream_response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
@@ -171,14 +172,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
 		var messageData CozeChatV3MessageDetail
 		var messageData CozeChatV3MessageDetail
 		err := json.Unmarshal([]byte(data), &messageData)
 		err := json.Unmarshal([]byte(data), &messageData)
 		if err != nil {
 		if err != nil {
-			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			logger.SysError("error_unmarshalling_stream_response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
 		var content string
 		var content string
 		err = json.Unmarshal(messageData.Content, &content)
 		err = json.Unmarshal(messageData.Content, &content)
 		if err != nil {
 		if err != nil {
-			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			logger.SysError("error_unmarshalling_stream_response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
@@ -203,11 +204,11 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
 		var errorData CozeError
 		var errorData CozeError
 		err := json.Unmarshal([]byte(data), &errorData)
 		err := json.Unmarshal([]byte(data), &errorData)
 		if err != nil {
 		if err != nil {
-			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			logger.SysError("error_unmarshalling_stream_response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
-		common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+		logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
 	}
 	}
 }
 }
 
 

+ 13 - 12
relay/channel/dify/relay-dify.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -36,14 +37,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 		// Decode base64 string
 		// Decode base64 string
 		decodedData, err := base64.StdEncoding.DecodeString(base64Data)
 		decodedData, err := base64.StdEncoding.DecodeString(base64Data)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to decode base64: " + err.Error())
+			logger.SysError("failed to decode base64: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
 		// Create temporary file
 		// Create temporary file
 		tempFile, err := os.CreateTemp("", "dify-upload-*")
 		tempFile, err := os.CreateTemp("", "dify-upload-*")
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to create temp file: " + err.Error())
+			logger.SysError("failed to create temp file: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 		defer tempFile.Close()
 		defer tempFile.Close()
@@ -51,7 +52,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 
 
 		// Write decoded data to temp file
 		// Write decoded data to temp file
 		if _, err := tempFile.Write(decodedData); err != nil {
 		if _, err := tempFile.Write(decodedData); err != nil {
-			common.SysError("failed to write to temp file: " + err.Error())
+			logger.SysError("failed to write to temp file: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
@@ -61,7 +62,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 
 
 		// Add user field
 		// Add user field
 		if err := writer.WriteField("user", user); err != nil {
 		if err := writer.WriteField("user", user); err != nil {
-			common.SysError("failed to add user field: " + err.Error())
+			logger.SysError("failed to add user field: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
@@ -74,13 +75,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 		// Create form file
 		// Create form file
 		part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
 		part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to create form file: " + err.Error())
+			logger.SysError("failed to create form file: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
 		// Copy file content to form
 		// Copy file content to form
 		if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
 		if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
-			common.SysError("failed to copy file content: " + err.Error())
+			logger.SysError("failed to copy file content: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 		writer.Close()
 		writer.Close()
@@ -88,7 +89,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 		// Create HTTP request
 		// Create HTTP request
 		req, err := http.NewRequest("POST", uploadUrl, body)
 		req, err := http.NewRequest("POST", uploadUrl, body)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to create request: " + err.Error())
+			logger.SysError("failed to create request: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
@@ -99,7 +100,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 		client := service.GetHttpClient()
 		client := service.GetHttpClient()
 		resp, err := client.Do(req)
 		resp, err := client.Do(req)
 		if err != nil {
 		if err != nil {
-			common.SysError("failed to send request: " + err.Error())
+			logger.SysError("failed to send request: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 		defer resp.Body.Close()
 		defer resp.Body.Close()
@@ -109,7 +110,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
 			Id string `json:"id"`
 			Id string `json:"id"`
 		}
 		}
 		if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
 		if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
-			common.SysError("failed to decode response: " + err.Error())
+			logger.SysError("failed to decode response: " + err.Error())
 			return nil
 			return nil
 		}
 		}
 
 
@@ -219,7 +220,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		var difyResponse DifyChunkChatCompletionResponse
 		var difyResponse DifyChunkChatCompletionResponse
 		err := json.Unmarshal([]byte(data), &difyResponse)
 		err := json.Unmarshal([]byte(data), &difyResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			return true
 			return true
 		}
 		}
 		var openaiResponse dto.ChatCompletionsStreamResponse
 		var openaiResponse dto.ChatCompletionsStreamResponse
@@ -239,7 +240,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		}
 		}
 		err = helper.ObjectData(c, openaiResponse)
 		err = helper.ObjectData(c, openaiResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError(err.Error())
+			logger.SysError(err.Error())
 		}
 		}
 		return true
 		return true
 	})
 	})
@@ -258,7 +259,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &difyResponse)
 	err = json.Unmarshal(responseBody, &difyResponse)
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)

+ 1 - 1
relay/channel/gemini/adaptor.go

@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 			},
 			},
 		},
 		},
 		Parameters: dto.GeminiImageParameters{
 		Parameters: dto.GeminiImageParameters{
-			SampleCount:      request.N,
+			SampleCount:      int(request.N),
 			AspectRatio:      aspectRatio,
 			AspectRatio:      aspectRatio,
 			PersonGeneration: "allow_adult", // default allow adult
 			PersonGeneration: "allow_adult", // default allow adult
 		},
 		},

+ 7 - 6
relay/channel/gemini/relay-gemini-native.go

@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -17,7 +18,7 @@ import (
 )
 )
 
 
 func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	// 读取响应体
 	// 读取响应体
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
@@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 		}
 		}
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	return &usage, nil
 	return &usage, nil
 }
 }
 
 
 func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
 func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
@@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
 		}
 		}
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	return usage, nil
 	return usage, nil
 }
 }
@@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
 		var geminiResponse dto.GeminiChatResponse
 		var geminiResponse dto.GeminiChatResponse
 		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 			return false
 		}
 		}
 
 
@@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
 		// 直接发送 GeminiChatResponse 响应
 		// 直接发送 GeminiChatResponse 响应
 		err = helper.StringData(c, data)
 		err = helper.StringData(c, data)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, err.Error())
+			logger.LogError(c, err.Error())
 		}
 		}
 		info.SendResponseCount++
 		info.SendResponseCount++
 		return true
 		return true

+ 9 - 8
relay/channel/gemini/relay-gemini.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
@@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 		var geminiResponse dto.GeminiChatResponse
 		var geminiResponse dto.GeminiChatResponse
 		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			logger.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 			return false
 		}
 		}
 
 
@@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 				finishReason = constant.FinishReasonToolCalls
 				finishReason = constant.FinishReasonToolCalls
 				err = handleStream(c, info, emptyResponse)
 				err = handleStream(c, info, emptyResponse)
 				if err != nil {
 				if err != nil {
-					common.LogError(c, err.Error())
+					logger.LogError(c, err.Error())
 				}
 				}
 
 
 				response.ClearToolCalls()
 				response.ClearToolCalls()
@@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 
 
 		err = handleStream(c, info, response)
 		err = handleStream(c, info, response)
 		if err != nil {
 		if err != nil {
-			common.LogError(c, err.Error())
+			logger.LogError(c, err.Error())
 		}
 		}
 		if isStop {
 		if isStop {
 			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
 			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
@@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 	response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 	err := handleFinalStream(c, info, response)
 	err := handleFinalStream(c, info, response)
 	if err != nil {
 	if err != nil {
-		common.SysError("send final response failed: " + err.Error())
+		logger.SysError("send final response failed: " + err.Error())
 	}
 	}
 	//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
 	//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
 	//	helper.Done(c)
 	//	helper.Done(c)
@@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println(string(responseBody))
 		println(string(responseBody))
 	}
 	}
@@ -1057,13 +1058,13 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		break
 		break
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	return &usage, nil
 	return &usage, nil
 }
 }
 
 
 func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	responseBody, readErr := io.ReadAll(resp.Body)
 	responseBody, readErr := io.ReadAll(resp.Body)
 	if readErr != nil {
 	if readErr != nil {
@@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, jsonResponse)
+	service.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return usage, nil
 	return usage, nil
 }
 }
 
 

+ 2 - 2
relay/channel/jimeng/image.go

@@ -5,9 +5,9 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 
 
 	err = json.Unmarshal(responseBody, &jimengResponse)
 	err = json.Unmarshal(responseBody, &jimengResponse)
 	if err != nil {
 	if err != nil {

+ 2 - 2
relay/channel/jimeng/sign.go

@@ -12,7 +12,7 @@ import (
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
-	"one-api/common"
+	"one-api/logger"
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+	logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
 	payloadHash := sha256.Sum256(body)
 	payloadHash := sha256.Sum256(body)
 	hexPayloadHash := hex.EncodeToString(payloadHash[:])
 	hexPayloadHash := hex.EncodeToString(payloadHash[:])
 	c.Set(HexPayloadHashKey, hexPayloadHash)
 	c.Set(HexPayloadHashKey, hexPayloadHash)

+ 3 - 2
relay/channel/mokaai/relay-mokaai.go

@@ -7,6 +7,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	err = json.Unmarshal(responseBody, &baiduResponse)
 	if err != nil {
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
@@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 	}
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	common.IOCopyBytesGracefully(c, resp, jsonResponse)
+	service.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return &fullTextResponse.Usage, nil
 	return &fullTextResponse.Usage, nil
 }
 }

+ 2 - 2
relay/channel/ollama/relay-ollama.go

@@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	}
-	common.IOCopyBytesGracefully(c, resp, doResponseBody)
+	service.IOCopyBytesGracefully(c, resp, doResponseBody)
 	return usage, nil
 	return usage, nil
 }
 }
 
 

+ 9 - 8
relay/channel/openai/helper.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
@@ -50,7 +51,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
 func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
 	var streamResponse dto.ChatCompletionsStreamResponse
 	var streamResponse dto.ChatCompletionsStreamResponse
 	if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
 	if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
-		common.LogError(c, "failed to unmarshal stream response: "+err.Error())
+		logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
 		return err
 		return err
 	}
 	}
 
 
@@ -63,7 +64,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 
 
 	geminiResponseStr, err := common.Marshal(geminiResponse)
 	geminiResponseStr, err := common.Marshal(geminiResponse)
 	if err != nil {
 	if err != nil {
-		common.LogError(c, "failed to marshal gemini response: "+err.Error())
+		logger.LogError(c, "failed to marshal gemini response: "+err.Error())
 		return err
 		return err
 	}
 	}
 
 
@@ -110,14 +111,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
 	var streamResponses []dto.ChatCompletionsStreamResponse
 	var streamResponses []dto.ChatCompletionsStreamResponse
 	if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
 	if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
 		// 一次性解析失败,逐个解析
 		// 一次性解析失败,逐个解析
-		common.SysError("error unmarshalling stream response: " + err.Error())
+		logger.SysError("error unmarshalling stream response: " + err.Error())
 		for _, item := range streamItems {
 		for _, item := range streamItems {
 			var streamResponse dto.ChatCompletionsStreamResponse
 			var streamResponse dto.ChatCompletionsStreamResponse
 			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
 			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
 				return err
 				return err
 			}
 			}
 			if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
 			if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
-				common.SysError("error processing stream response: " + err.Error())
+				logger.SysError("error processing stream response: " + err.Error())
 			}
 			}
 		}
 		}
 		return nil
 		return nil
@@ -146,7 +147,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
 	var streamResponses []dto.CompletionsStreamResponse
 	var streamResponses []dto.CompletionsStreamResponse
 	if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
 	if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
 		// 一次性解析失败,逐个解析
 		// 一次性解析失败,逐个解析
-		common.SysError("error unmarshalling stream response: " + err.Error())
+		logger.SysError("error unmarshalling stream response: " + err.Error())
 		for _, item := range streamItems {
 		for _, item := range streamItems {
 			var streamResponse dto.CompletionsStreamResponse
 			var streamResponse dto.CompletionsStreamResponse
 			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
 			if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -213,7 +214,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 		info.ClaudeConvertInfo.Done = true
 		info.ClaudeConvertInfo.Done = true
 		var streamResponse dto.ChatCompletionsStreamResponse
 		var streamResponse dto.ChatCompletionsStreamResponse
 		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
 		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
@@ -227,7 +228,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 	case relaycommon.RelayFormatGemini:
 	case relaycommon.RelayFormatGemini:
 		var streamResponse dto.ChatCompletionsStreamResponse
 		var streamResponse dto.ChatCompletionsStreamResponse
 		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
 		if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
@@ -245,7 +246,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 
 
 		geminiResponseStr, err := common.Marshal(geminiResponse)
 		geminiResponseStr, err := common.Marshal(geminiResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error marshalling gemini response: " + err.Error())
+			logger.SysError("error marshalling gemini response: " + err.Error())
 			return
 			return
 		}
 		}
 
 

+ 21 - 20
relay/channel/openai/relay-openai.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 
 
 func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	if resp == nil || resp.Body == nil {
 	if resp == nil || resp.Body == nil {
-		common.LogError(c, "invalid response or response body")
+		logger.LogError(c, "invalid response or response body")
 		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
 	}
 	}
 
 
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	model := info.UpstreamModelName
 	model := info.UpstreamModelName
 	var responseId string
 	var responseId string
@@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 		if lastStreamData != "" {
 		if lastStreamData != "" {
 			err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
 			err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
 			if err != nil {
 			if err != nil {
-				common.SysError("error handling stream format: " + err.Error())
+				logger.SysError("error handling stream format: " + err.Error())
 			}
 			}
 		}
 		}
 		if len(data) > 0 {
 		if len(data) > 0 {
@@ -143,7 +144,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	shouldSendLastResp := true
 	shouldSendLastResp := true
 	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
 	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
 		&containStreamUsage, info, &shouldSendLastResp); err != nil {
 		&containStreamUsage, info, &shouldSendLastResp); err != nil {
-		common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
+		logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
 	}
 	}
 
 
 	if info.RelayFormat == relaycommon.RelayFormatOpenAI {
 	if info.RelayFormat == relaycommon.RelayFormatOpenAI {
@@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 
 
 	// 处理token计算
 	// 处理token计算
 	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
 	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
-		common.LogError(c, "error processing tokens: "+err.Error())
+		logger.LogError(c, "error processing tokens: "+err.Error())
 	}
 	}
 
 
 	if !containStreamUsage {
 	if !containStreamUsage {
@@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 }
 }
 
 
 func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	var simpleResponse dto.OpenAITextResponse
 	var simpleResponse dto.OpenAITextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
@@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		responseBody = geminiRespStr
 		responseBody = geminiRespStr
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	return &simpleResponse.Usage, nil
 	return &simpleResponse.Usage, nil
 }
 }
@@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	// if the upstream returns a specific status code, once the upstream has already written the header,
 	// if the upstream returns a specific status code, once the upstream has already written the header,
 	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// the subsequent failure of the response body should be regarded as a non-recoverable error,
 	// and can be terminated directly.
 	// and can be terminated directly.
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.PromptTokens = info.PromptTokens
 	usage.TotalTokens = info.PromptTokens
 	usage.TotalTokens = info.PromptTokens
@@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	c.Writer.WriteHeaderNow()
 	c.Writer.WriteHeaderNow()
 	_, err := io.Copy(c.Writer, resp.Body)
 	_, err := io.Copy(c.Writer, resp.Body)
 	if err != nil {
 	if err != nil {
-		common.LogError(c, err.Error())
+		logger.LogError(c, err.Error())
 	}
 	}
 	return usage
 	return usage
 }
 }
 
 
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	// count tokens by audio file duration
 	// count tokens by audio file duration
 	audioTokens, err := countAudioTokens(c)
 	audioTokens, err := countAudioTokens(c)
@@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
 	}
 	}
 	// 写入新的 response body
 	// 写入新的 response body
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = audioTokens
 	usage.PromptTokens = audioTokens
@@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
 					errChan <- fmt.Errorf("error counting text token: %v", err)
 					errChan <- fmt.Errorf("error counting text token: %v", err)
 					return
 					return
 				}
 				}
-				common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+				logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
 				localUsage.TotalTokens += textToken + audioToken
 				localUsage.TotalTokens += textToken + audioToken
 				localUsage.InputTokens += textToken + audioToken
 				localUsage.InputTokens += textToken + audioToken
 				localUsage.InputTokenDetails.TextTokens += textToken
 				localUsage.InputTokenDetails.TextTokens += textToken
@@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
 							errChan <- fmt.Errorf("error counting text token: %v", err)
 							errChan <- fmt.Errorf("error counting text token: %v", err)
 							return
 							return
 						}
 						}
-						common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+						logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
 						localUsage.TotalTokens += textToken + audioToken
 						localUsage.TotalTokens += textToken + audioToken
 						info.IsFirstRequest = false
 						info.IsFirstRequest = false
 						localUsage.InputTokens += textToken + audioToken
 						localUsage.InputTokens += textToken + audioToken
@@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
 						localUsage = &dto.RealtimeUsage{}
 						localUsage = &dto.RealtimeUsage{}
 						// print now usage
 						// print now usage
 					}
 					}
-					common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
-					common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
-					common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+					logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+					logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+					logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
 
 
 				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
 				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
 					realtimeSession := realtimeEvent.Session
 					realtimeSession := realtimeEvent.Session
@@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
 						errChan <- fmt.Errorf("error counting text token: %v", err)
 						errChan <- fmt.Errorf("error counting text token: %v", err)
 						return
 						return
 					}
 					}
-					common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+					logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
 					localUsage.TotalTokens += textToken + audioToken
 					localUsage.TotalTokens += textToken + audioToken
 					localUsage.OutputTokens += textToken + audioToken
 					localUsage.OutputTokens += textToken + audioToken
 					localUsage.OutputTokenDetails.TextTokens += textToken
 					localUsage.OutputTokenDetails.TextTokens += textToken
@@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
 	case <-targetClosed:
 	case <-targetClosed:
 	case err := <-errChan:
 	case err := <-errChan:
 		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
 		//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
-		common.LogError(c, "realtime error: "+err.Error())
+		logger.LogError(c, "realtime error: "+err.Error())
 	case <-c.Done():
 	case <-c.Done():
 	}
 	}
 
 
@@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
 }
 }
 
 
 func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
@@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 	}
 	}
 
 
 	// 写入新的 response body
 	// 写入新的 response body
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	// Once we've written to the client, we should not return errors anymore
 	// Once we've written to the client, we should not return errors anymore
 	// because the upstream has already consumed resources and returned content
 	// because the upstream has already consumed resources and returned content

+ 4 - 3
relay/channel/openai/relay_responses.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -16,7 +17,7 @@ import (
 )
 )
 
 
 func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	// read response body
 	// read response body
 	var responsesResponse dto.OpenAIResponsesResponse
 	var responsesResponse dto.OpenAIResponsesResponse
@@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 	}
 	}
 
 
 	// 写入新的 response body
 	// 写入新的 response body
-	common.IOCopyBytesGracefully(c, resp, responseBody)
+	service.IOCopyBytesGracefully(c, resp, responseBody)
 
 
 	// compute usage
 	// compute usage
 	usage := dto.Usage{}
 	usage := dto.Usage{}
@@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 
 
 func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	if resp == nil || resp.Body == nil {
 	if resp == nil || resp.Body == nil {
-		common.LogError(c, "invalid response or response body")
+		logger.LogError(c, "invalid response or response body")
 		return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
 		return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
 	}
 	}
 
 

+ 8 - 7
relay/channel/palm/relay-palm.go

@@ -7,6 +7,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -58,15 +59,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 	go func() {
 	go func() {
 		responseBody, err := io.ReadAll(resp.Body)
 		responseBody, err := io.ReadAll(resp.Body)
 		if err != nil {
 		if err != nil {
-			common.SysError("error reading stream response: " + err.Error())
+			logger.SysError("error reading stream response: " + err.Error())
 			stopChan <- true
 			stopChan <- true
 			return
 			return
 		}
 		}
-		common.CloseResponseBodyGracefully(resp)
+		service.CloseResponseBodyGracefully(resp)
 		var palmResponse PaLMChatResponse
 		var palmResponse PaLMChatResponse
 		err = json.Unmarshal(responseBody, &palmResponse)
 		err = json.Unmarshal(responseBody, &palmResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			stopChan <- true
 			stopChan <- true
 			return
 			return
 		}
 		}
@@ -78,7 +79,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 		}
 		}
 		jsonResponse, err := json.Marshal(fullTextResponse)
 		jsonResponse, err := json.Marshal(fullTextResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error marshalling stream response: " + err.Error())
+			logger.SysError("error marshalling stream response: " + err.Error())
 			stopChan <- true
 			stopChan <- true
 			return
 			return
 		}
 		}
@@ -96,7 +97,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
 			return false
 			return false
 		}
 		}
 	})
 	})
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	return nil, responseText
 	return nil, responseText
 }
 }
 
 
@@ -105,7 +106,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	var palmResponse PaLMChatResponse
 	var palmResponse PaLMChatResponse
 	err = json.Unmarshal(responseBody, &palmResponse)
 	err = json.Unmarshal(responseBody, &palmResponse)
 	if err != nil {
 	if err != nil {
@@ -133,6 +134,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
 	}
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	common.IOCopyBytesGracefully(c, resp, jsonResponse)
+	service.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return &usage, nil
 	return &usage, nil
 }
 }

+ 3 - 3
relay/channel/siliconflow/relay-siliconflow.go

@@ -4,9 +4,9 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	var siliconflowResp SFRerankResponse
 	var siliconflowResp SFRerankResponse
 	err = json.Unmarshal(responseBody, &siliconflowResp)
 	err = json.Unmarshal(responseBody, &siliconflowResp)
 	if err != nil {
 	if err != nil {
@@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 	}
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	common.IOCopyBytesGracefully(c, resp, jsonResponse)
+	service.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return usage, nil
 	return usage, nil
 }
 }

+ 2 - 1
relay/channel/task/suno/adaptor.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"one-api/service"
@@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
 
 
 	req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
 	req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
 	if err != nil {
 	if err != nil {
-		common.SysError(fmt.Sprintf("Get Task error: %v", err))
+		logger.SysError(fmt.Sprintf("Get Task error: %v", err))
 		return nil, err
 		return nil, err
 	}
 	}
 	defer req.Body.Close()
 	defer req.Body.Close()

+ 7 - 6
relay/channel/tencent/relay-tencent.go

@@ -13,6 +13,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
@@ -106,7 +107,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 		var tencentResponse TencentChatResponse
 		var tencentResponse TencentChatResponse
 		err := json.Unmarshal([]byte(data), &tencentResponse)
 		err := json.Unmarshal([]byte(data), &tencentResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			continue
 			continue
 		}
 		}
 
 
@@ -117,17 +118,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
 
 
 		err = helper.ObjectData(c, response)
 		err = helper.ObjectData(c, response)
 		if err != nil {
 		if err != nil {
-			common.SysError(err.Error())
+			logger.SysError(err.Error())
 		}
 		}
 	}
 	}
 
 
 	if err := scanner.Err(); err != nil {
 	if err := scanner.Err(); err != nil {
-		common.SysError("error reading stream: " + err.Error())
+		logger.SysError("error reading stream: " + err.Error())
 	}
 	}
 
 
 	helper.Done(c)
 	helper.Done(c)
 
 
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 
 
 	return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
 	return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
 }
 }
@@ -138,7 +139,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &tencentSb)
 	err = json.Unmarshal(responseBody, &tencentSb)
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -156,7 +157,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
 	}
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	common.IOCopyBytesGracefully(c, resp, jsonResponse)
+	service.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return &fullTextResponse.Usage, nil
 	return &fullTextResponse.Usage, nil
 }
 }
 
 

+ 6 - 5
relay/channel/xai/text.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
@@ -47,7 +48,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 		var xAIResp *dto.ChatCompletionsStreamResponse
 		var xAIResp *dto.ChatCompletionsStreamResponse
 		err := json.Unmarshal([]byte(data), &xAIResp)
 		err := json.Unmarshal([]byte(data), &xAIResp)
 		if err != nil {
 		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
+			logger.SysError("error unmarshalling stream response: " + err.Error())
 			return true
 			return true
 		}
 		}
 
 
@@ -63,7 +64,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 		_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
 		_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
 		err = helper.ObjectData(c, openaiResponse)
 		err = helper.ObjectData(c, openaiResponse)
 		if err != nil {
 		if err != nil {
-			common.SysError(err.Error())
+			logger.SysError(err.Error())
 		}
 		}
 		return true
 		return true
 	})
 	})
@@ -74,12 +75,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	}
 	}
 
 
 	helper.Done(c)
 	helper.Done(c)
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	return usage, nil
 	return usage, nil
 }
 }
 
 
 func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	defer common.CloseResponseBodyGracefully(resp)
+	defer service.CloseResponseBodyGracefully(resp)
 
 
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
@@ -101,7 +102,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	}
 
 
-	common.IOCopyBytesGracefully(c, resp, encodeJson)
+	service.IOCopyBytesGracefully(c, resp, encodeJson)
 
 
 	return xaiResponse.Usage, nil
 	return xaiResponse.Usage, nil
 }
 }

+ 6 - 5
relay/channel/xunfei/relay-xunfei.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
@@ -143,7 +144,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
 			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
 			jsonResponse, err := json.Marshal(response)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
+				logger.SysError("error marshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -218,20 +219,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
 		for {
 		for {
 			_, msg, err := conn.ReadMessage()
 			_, msg, err := conn.ReadMessage()
 			if err != nil {
 			if err != nil {
-				common.SysError("error reading stream response: " + err.Error())
+				logger.SysError("error reading stream response: " + err.Error())
 				break
 				break
 			}
 			}
 			var response XunfeiChatResponse
 			var response XunfeiChatResponse
 			err = json.Unmarshal(msg, &response)
 			err = json.Unmarshal(msg, &response)
 			if err != nil {
 			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
+				logger.SysError("error unmarshalling stream response: " + err.Error())
 				break
 				break
 			}
 			}
 			dataChan <- response
 			dataChan <- response
 			if response.Payload.Choices.Status == 2 {
 			if response.Payload.Choices.Status == 2 {
 				err := conn.Close()
 				err := conn.Close()
 				if err != nil {
 				if err != nil {
-					common.SysError("error closing websocket connection: " + err.Error())
+					logger.SysError("error closing websocket connection: " + err.Error())
 				}
 				}
 				break
 				break
 			}
 			}
@@ -282,6 +283,6 @@ func getAPIVersion(c *gin.Context, modelName string) string {
 		return apiVersion
 		return apiVersion
 	}
 	}
 	apiVersion = "v1.1"
 	apiVersion = "v1.1"
-	common.SysLog("api_version not found, using default: " + apiVersion)
+	logger.SysLog("api_version not found, using default: " + apiVersion)
 	return apiVersion
 	return apiVersion
 }
 }

+ 8 - 6
relay/channel/zhipu/relay-zhipu.go

@@ -8,8 +8,10 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -38,7 +40,7 @@ func getZhipuToken(apikey string) string {
 
 
 	split := strings.Split(apikey, ".")
 	split := strings.Split(apikey, ".")
 	if len(split) != 2 {
 	if len(split) != 2 {
-		common.SysError("invalid zhipu key: " + apikey)
+		logger.SysError("invalid zhipu key: " + apikey)
 		return ""
 		return ""
 	}
 	}
 
 
@@ -186,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 			response := streamResponseZhipu2OpenAI(data)
 			response := streamResponseZhipu2OpenAI(data)
 			jsonResponse, err := json.Marshal(response)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
+				logger.SysError("error marshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -195,13 +197,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 			var zhipuResponse ZhipuStreamMetaResponse
 			var zhipuResponse ZhipuStreamMetaResponse
 			err := json.Unmarshal([]byte(data), &zhipuResponse)
 			err := json.Unmarshal([]byte(data), &zhipuResponse)
 			if err != nil {
 			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
+				logger.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 			jsonResponse, err := json.Marshal(response)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
+				logger.SysError("error marshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
 			usage = zhipuUsage
 			usage = zhipuUsage
@@ -212,7 +214,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
 			return false
 			return false
 		}
 		}
 	})
 	})
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	return usage, nil
 	return usage, nil
 }
 }
 
 
@@ -222,7 +224,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	err = json.Unmarshal(responseBody, &zhipuResponse)
 	err = json.Unmarshal(responseBody, &zhipuResponse)
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)

+ 5 - 4
relay/relay-mj.go → relay/chat_handler.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
@@ -214,7 +215,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
 		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
 		if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
 			err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
 			err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
 			if err != nil {
 			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
+				logger.SysError("error consuming token remain quota: " + err.Error())
 			}
 			}
 
 
 			tokenName := c.GetString("token_name")
 			tokenName := c.GetString("token_name")
@@ -300,7 +301,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
 	if err != nil {
 	if err != nil {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
 	}
 	}
-	common.IOCopyBytesGracefully(c, nil, respBody)
+	service.IOCopyBytesGracefully(c, nil, respBody)
 	return nil
 	return nil
 }
 }
 
 
@@ -521,7 +522,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 		if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
 			err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
 			err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
 			if err != nil {
 			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
+				logger.SysError("error consuming token remain quota: " + err.Error())
 			}
 			}
 			tokenName := c.GetString("token_name")
 			tokenName := c.GetString("token_name")
 			logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
 			logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
@@ -572,7 +573,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		//无实例账号自动禁用渠道(No available account instance)
 		//无实例账号自动禁用渠道(No available account instance)
 		channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
 		channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
 		if err != nil {
 		if err != nil {
-			common.SysError("get_channel_null: " + err.Error())
+			logger.SysError("get_channel_null: " + err.Error())
 		}
 		}
 		if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
 		if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
 			model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
 			model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")

+ 19 - 72
relay/claude_handler.go

@@ -2,7 +2,6 @@ package relay
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -18,68 +17,26 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
-	textRequest = &dto.ClaudeRequest{}
-	err = c.ShouldBindJSON(textRequest)
-	if err != nil {
-		return nil, err
-	}
-	if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
-		return nil, errors.New("field messages is required")
-	}
-	if textRequest.Model == "" {
-		return nil, errors.New("field model is required")
-	}
-	return textRequest, nil
-}
-
-func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 
 
-	relayInfo := relaycommon.GenRelayInfoClaude(c)
+	info.InitChannelMeta(c)
 
 
-	// get & validate textRequest 获取并验证文本请求
-	textRequest, err := getAndValidateClaudeRequest(c)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
+	textRequest, ok := info.Request.(*dto.ClaudeRequest)
 
 
-	if textRequest.Stream {
-		relayInfo.IsStream = true
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
 	}
 	}
 
 
-	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+	err := helper.ModelMappedHelper(c, info, textRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
-	// count messages token error 计算promptTokens错误
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
-	}
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-
-	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
 	if textRequest.MaxTokens == 0 {
 	if textRequest.MaxTokens == 0 {
 		textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
 		textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
@@ -104,18 +61,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			textRequest.Temperature = common.GetPointer[float64](1.0)
 			textRequest.Temperature = common.GetPointer[float64](1.0)
 		}
 		}
 		textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
 		textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
-		relayInfo.UpstreamModelName = textRequest.Model
+		info.UpstreamModelName = textRequest.Model
 	}
 	}
 
 
 	var requestBody io.Reader
 	var requestBody io.Reader
-	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
 		if err != nil {
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		}
 		requestBody = bytes.NewBuffer(body)
 		requestBody = bytes.NewBuffer(body)
 	} else {
 	} else {
-		convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
+		convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, textRequest)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
@@ -125,10 +82,10 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 
 
 		// apply param override
 		// apply param override
-		if len(relayInfo.ParamOverride) > 0 {
+		if len(info.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			reqMap := make(map[string]interface{})
 			_ = common.Unmarshal(jsonData, &reqMap)
 			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range relayInfo.ParamOverride {
+			for key, value := range info.ParamOverride {
 				reqMap[key] = value
 				reqMap[key] = value
 			}
 			}
 			jsonData, err = common.Marshal(reqMap)
 			jsonData, err = common.Marshal(reqMap)
@@ -145,14 +102,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	var httpResp *http.Response
 	var httpResp *http.Response
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
 
 
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
@@ -161,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	//log.Printf("usage: %v", usage)
 	//log.Printf("usage: %v", usage)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 		return newAPIError
 	}
 	}
-	service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
-	return nil
-}
 
 
-func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
-	var promptTokens int
-	var err error
-	switch info.RelayMode {
-	default:
-		promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
-	}
-	info.PromptTokens = promptTokens
-	return promptTokens, err
+	service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
+	return nil
 }
 }

+ 177 - 144
relay/common/relay_info.go

@@ -1,10 +1,12 @@
 package common
 package common
 
 
 import (
 import (
+	"errors"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
+	"one-api/types"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -33,17 +35,6 @@ type ClaudeConvertInfo struct {
 	Done             bool
 	Done             bool
 }
 }
 
 
-const (
-	RelayFormatOpenAI          = "openai"
-	RelayFormatClaude          = "claude"
-	RelayFormatGemini          = "gemini"
-	RelayFormatOpenAIResponses = "openai_responses"
-	RelayFormatOpenAIAudio     = "openai_audio"
-	RelayFormatOpenAIImage     = "openai_image"
-	RelayFormatRerank          = "rerank"
-	RelayFormatEmbedding       = "embedding"
-)
-
 type RerankerInfo struct {
 type RerankerInfo struct {
 	Documents       []any
 	Documents       []any
 	ReturnDocuments bool
 	ReturnDocuments bool
@@ -59,61 +50,103 @@ type ResponsesUsageInfo struct {
 	BuiltInTools map[string]*BuildInToolInfo
 	BuiltInTools map[string]*BuildInToolInfo
 }
 }
 
 
-type RelayInfo struct {
+type ChannelMeta struct {
 	ChannelType          int
 	ChannelType          int
 	ChannelId            int
 	ChannelId            int
-	ChannelIsMultiKey    bool // 是否多密钥
-	ChannelMultiKeyIndex int  // 多密钥索引
-	TokenId              int
-	TokenKey             string
-	UserId               int
-	UsingGroup           string // 使用的分组
-	UserGroup            string // 用户所在分组
-	TokenUnlimited       bool
-	StartTime            time.Time
-	FirstResponseTime    time.Time
-	isFirstResponse      bool
+	ChannelIsMultiKey    bool
+	ChannelMultiKeyIndex int
+	ChannelBaseUrl       string
+	ApiType              int
+	ApiVersion           string
+	ApiKey               string
+	Organization         string
+	ChannelCreateTime    int64
+	ParamOverride        map[string]interface{}
+	ChannelSetting       dto.ChannelSettings
+	ChannelOtherSettings dto.ChannelOtherSettings
+	UpstreamModelName    string
+	IsModelMapped        bool
+}
+
+type RelayInfo struct {
+	TokenId           int
+	TokenKey          string
+	UserId            int
+	UsingGroup        string // 使用的分组
+	UserGroup         string // 用户所在分组
+	TokenUnlimited    bool
+	StartTime         time.Time
+	FirstResponseTime time.Time
+	isFirstResponse   bool
 	//SendLastReasoningResponse bool
 	//SendLastReasoningResponse bool
-	ApiType                int
 	IsStream               bool
 	IsStream               bool
 	IsGeminiBatchEmbedding bool
 	IsGeminiBatchEmbedding bool
 	IsPlayground           bool
 	IsPlayground           bool
 	UsePrice               bool
 	UsePrice               bool
 	RelayMode              int
 	RelayMode              int
-	UpstreamModelName      string
 	OriginModelName        string
 	OriginModelName        string
 	//RecodeModelName      string
 	//RecodeModelName      string
-	RequestURLPath       string
-	ApiVersion           string
-	PromptTokens         int
-	ApiKey               string
-	Organization         string
-	BaseUrl              string
-	SupportStreamOptions bool
-	ShouldIncludeUsage   bool
-	DisablePing          bool // 是否禁止向下游发送自定义 Ping
-	IsModelMapped        bool
-	ClientWs             *websocket.Conn
-	TargetWs             *websocket.Conn
-	InputAudioFormat     string
-	OutputAudioFormat    string
-	RealtimeTools        []dto.RealTimeTool
-	IsFirstRequest       bool
-	AudioUsage           bool
-	ReasoningEffort      string
-	ChannelSetting       dto.ChannelSettings
-	ChannelOtherSettings dto.ChannelOtherSettings
-	ParamOverride        map[string]interface{}
-	UserSetting          dto.UserSetting
-	UserEmail            string
-	UserQuota            int
-	RelayFormat          string
-	SendResponseCount    int
-	ChannelCreateTime    int64
+	RequestURLPath        string
+	PromptTokens          int
+	SupportStreamOptions  bool
+	ShouldIncludeUsage    bool
+	DisablePing           bool // 是否禁止向下游发送自定义 Ping
+	ClientWs              *websocket.Conn
+	TargetWs              *websocket.Conn
+	InputAudioFormat      string
+	OutputAudioFormat     string
+	RealtimeTools         []dto.RealTimeTool
+	IsFirstRequest        bool
+	AudioUsage            bool
+	ReasoningEffort       string
+	UserSetting           dto.UserSetting
+	UserEmail             string
+	UserQuota             int
+	RelayFormat           types.RelayFormat
+	SendResponseCount     int
+	FinalPreConsumedQuota int // 最终预消耗的配额
+
+	PriceData types.PriceData
+
+	Request dto.Request
+
 	ThinkingContentInfo
 	ThinkingContentInfo
 	*ClaudeConvertInfo
 	*ClaudeConvertInfo
 	*RerankerInfo
 	*RerankerInfo
 	*ResponsesUsageInfo
 	*ResponsesUsageInfo
+	*ChannelMeta
+}
+
+func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
+	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+	apiType, _ := common.ChannelType2APIType(channelType)
+	channelMeta := &ChannelMeta{
+		ChannelType:          channelType,
+		ChannelId:            common.GetContextKeyInt(c, constant.ContextKeyChannelId),
+		ChannelIsMultiKey:    common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
+		ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
+		ChannelBaseUrl:       common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
+		ApiType:              apiType,
+		ApiVersion:           c.GetString("api_version"),
+		ApiKey:               common.GetContextKeyString(c, constant.ContextKeyChannelKey),
+		Organization:         c.GetString("channel_organization"),
+		ChannelCreateTime:    c.GetInt64("channel_create_time"),
+		ParamOverride:        paramOverride,
+		UpstreamModelName:    common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+		IsModelMapped:        false,
+	}
+
+	channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
+	if ok {
+		channelMeta.ChannelSetting = channelSetting
+	}
+
+	channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
+	if ok {
+		channelMeta.ChannelOtherSettings = channelOtherSettings
+	}
+	info.ChannelMeta = channelMeta
 }
 }
 
 
 // 定义支持流式选项的通道类型
 // 定义支持流式选项的通道类型
@@ -132,7 +165,8 @@ var streamSupportedChannels = map[int]bool{
 }
 }
 
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
-	info := GenRelayInfo(c)
+	info := genBaseRelayInfo(c, nil)
+	info.RelayFormat = types.RelayFormatOpenAIRealtime
 	info.ClientWs = ws
 	info.ClientWs = ws
 	info.InputAudioFormat = "pcm16"
 	info.InputAudioFormat = "pcm16"
 	info.OutputAudioFormat = "pcm16"
 	info.OutputAudioFormat = "pcm16"
@@ -140,9 +174,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
-	info := GenRelayInfo(c)
-	info.RelayFormat = RelayFormatClaude
+func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatClaude
 	info.ShouldIncludeUsage = false
 	info.ShouldIncludeUsage = false
 	info.ClaudeConvertInfo = &ClaudeConvertInfo{
 	info.ClaudeConvertInfo = &ClaudeConvertInfo{
 		LastMessagesType: LastMessageTypeNone,
 		LastMessagesType: LastMessageTypeNone,
@@ -150,41 +184,41 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
-	info := GenRelayInfo(c)
+func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
 	info.RelayMode = relayconstant.RelayModeRerank
 	info.RelayMode = relayconstant.RelayModeRerank
-	info.RelayFormat = RelayFormatRerank
+	info.RelayFormat = types.RelayFormatRerank
 	info.RerankerInfo = &RerankerInfo{
 	info.RerankerInfo = &RerankerInfo{
-		Documents:       req.Documents,
-		ReturnDocuments: req.GetReturnDocuments(),
+		Documents:       request.Documents,
+		ReturnDocuments: request.GetReturnDocuments(),
 	}
 	}
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
-	info := GenRelayInfo(c)
-	info.RelayFormat = RelayFormatOpenAIAudio
+func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatOpenAIAudio
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
-	info := GenRelayInfo(c)
-	info.RelayFormat = RelayFormatEmbedding
+func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatEmbedding
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
-	info := GenRelayInfo(c)
+func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
 	info.RelayMode = relayconstant.RelayModeResponses
 	info.RelayMode = relayconstant.RelayModeResponses
-	info.RelayFormat = RelayFormatOpenAIResponses
+	info.RelayFormat = types.RelayFormatOpenAIResponses
 
 
 	info.SupportStreamOptions = false
 	info.SupportStreamOptions = false
 
 
 	info.ResponsesUsageInfo = &ResponsesUsageInfo{
 	info.ResponsesUsageInfo = &ResponsesUsageInfo{
 		BuiltInTools: make(map[string]*BuildInToolInfo),
 		BuiltInTools: make(map[string]*BuildInToolInfo),
 	}
 	}
-	if len(req.Tools) > 0 {
-		for _, tool := range req.Tools {
+	if len(request.Tools) > 0 {
+		for _, tool := range request.Tools {
 			toolType := common.Interface2String(tool["type"])
 			toolType := common.Interface2String(tool["type"])
 			info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
 			info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
 				ToolName:  toolType,
 				ToolName:  toolType,
@@ -200,104 +234,76 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
 			}
 			}
 		}
 		}
 	}
 	}
-	info.IsStream = req.Stream
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
-	info := GenRelayInfo(c)
-	info.RelayFormat = RelayFormatGemini
+func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatGemini
 	info.ShouldIncludeUsage = false
 	info.ShouldIncludeUsage = false
+
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfoImage(c *gin.Context) *RelayInfo {
-	info := GenRelayInfo(c)
-	info.RelayFormat = RelayFormatOpenAIImage
+func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatOpenAIImage
 	return info
 	return info
 }
 }
 
 
-func GenRelayInfo(c *gin.Context) *RelayInfo {
-	channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
-	channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
-	paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
+	info := genBaseRelayInfo(c, request)
+	info.RelayFormat = types.RelayFormatOpenAI
+	return info
+}
+
+func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
+
+	//channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+	//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+	//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
 
 
-	tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
-	tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
-	userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
-	tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
 	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
 	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
 	if startTime.IsZero() {
 	if startTime.IsZero() {
 		startTime = time.Now()
 		startTime = time.Now()
 	}
 	}
-	// firstResponseTime = time.Now() - 1 second
 
 
-	apiType, _ := common.ChannelType2APIType(channelType)
+	// firstResponseTime = time.Now() - 1 second
 
 
 	info := &RelayInfo{
 	info := &RelayInfo{
-		UserQuota:         common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
-		UserEmail:         common.GetContextKeyString(c, constant.ContextKeyUserEmail),
-		isFirstResponse:   true,
-		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:           common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
-		RequestURLPath:    c.Request.URL.String(),
-		ChannelType:       channelType,
-		ChannelId:         channelId,
-		TokenId:           tokenId,
-		TokenKey:          tokenKey,
-		UserId:            userId,
-		UsingGroup:        common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
-		UserGroup:         common.GetContextKeyString(c, constant.ContextKeyUserGroup),
-		TokenUnlimited:    tokenUnlimited,
+		Request: request,
+
+		UserId:     common.GetContextKeyInt(c, constant.ContextKeyUserId),
+		UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+		UserGroup:  common.GetContextKeyString(c, constant.ContextKeyUserGroup),
+		UserQuota:  common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+		UserEmail:  common.GetContextKeyString(c, constant.ContextKeyUserEmail),
+
+		OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+		PromptTokens:    common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
+
+		TokenId:        common.GetContextKeyInt(c, constant.ContextKeyTokenId),
+		TokenKey:       common.GetContextKeyString(c, constant.ContextKeyTokenKey),
+		TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
+
+		isFirstResponse: true,
+		RelayMode:       relayconstant.Path2RelayMode(c.Request.URL.Path),
+		RequestURLPath:  c.Request.URL.String(),
+		IsStream:        request.IsStream(c),
+
 		StartTime:         startTime,
 		StartTime:         startTime,
 		FirstResponseTime: startTime.Add(-time.Second),
 		FirstResponseTime: startTime.Add(-time.Second),
-		OriginModelName:   common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
-		UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
-		//RecodeModelName:   c.GetString("original_model"),
-		IsModelMapped: false,
-		ApiType:       apiType,
-		ApiVersion:    c.GetString("api_version"),
-		ApiKey:        common.GetContextKeyString(c, constant.ContextKeyChannelKey),
-		Organization:  c.GetString("channel_organization"),
-
-		ChannelCreateTime: c.GetInt64("channel_create_time"),
-		ParamOverride:     paramOverride,
-		RelayFormat:       RelayFormatOpenAI,
 		ThinkingContentInfo: ThinkingContentInfo{
 		ThinkingContentInfo: ThinkingContentInfo{
 			IsFirstThinkingContent:  true,
 			IsFirstThinkingContent:  true,
 			SendLastThinkingContent: false,
 			SendLastThinkingContent: false,
 		},
 		},
-
-		ChannelIsMultiKey:    common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
-		ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
 	}
 	}
+
 	if strings.HasPrefix(c.Request.URL.Path, "/pg") {
 	if strings.HasPrefix(c.Request.URL.Path, "/pg") {
 		info.IsPlayground = true
 		info.IsPlayground = true
 		info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
 		info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
 		info.RequestURLPath = "/v1" + info.RequestURLPath
 		info.RequestURLPath = "/v1" + info.RequestURLPath
 	}
 	}
-	if info.BaseUrl == "" {
-		info.BaseUrl = constant.ChannelBaseURLs[channelType]
-	}
-	if info.ChannelType == constant.ChannelTypeAzure {
-		info.ApiVersion = GetAPIVersion(c)
-	}
-	if info.ChannelType == constant.ChannelTypeVertexAi {
-		info.ApiVersion = c.GetString("region")
-	}
-	if streamSupportedChannels[info.ChannelType] {
-		info.SupportStreamOptions = true
-	}
-
-	channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
-	if ok {
-		info.ChannelSetting = channelSetting
-	}
-
-	channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
-	if ok {
-		info.ChannelOtherSettings = channelOtherSettings
-	}
 
 
 	userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
 	userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
 	if ok {
 	if ok {
@@ -307,12 +313,39 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	return info
 	return info
 }
 }
 
 
-func (info *RelayInfo) SetPromptTokens(promptTokens int) {
-	info.PromptTokens = promptTokens
+func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
+	switch relayFormat {
+	case types.RelayFormatOpenAI:
+		return GenRelayInfoOpenAI(c, request), nil
+	case types.RelayFormatOpenAIAudio:
+		return GenRelayInfoOpenAIAudio(c, request), nil
+	case types.RelayFormatOpenAIImage:
+		return GenRelayInfoImage(c, request), nil
+	case types.RelayFormatOpenAIRealtime:
+		return GenRelayInfoWs(c, ws), nil
+	case types.RelayFormatClaude:
+		return GenRelayInfoClaude(c, request), nil
+	case types.RelayFormatRerank:
+		if request, ok := request.(*dto.RerankRequest); ok {
+			return GenRelayInfoRerank(c, request), nil
+		}
+		return nil, errors.New("request is not a RerankRequest")
+	case types.RelayFormatGemini:
+		return GenRelayInfoGemini(c, request), nil
+	case types.RelayFormatEmbedding:
+		return GenRelayInfoEmbedding(c, request), nil
+	case types.RelayFormatOpenAIResponses:
+		if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
+			return GenRelayInfoResponses(c, request), nil
+		}
+		return nil, errors.New("request is not a OpenAIResponsesRequest")
+	default:
+		return nil, errors.New("invalid relay format")
+	}
 }
 }
 
 
-func (info *RelayInfo) SetIsStream(isStream bool) {
-	info.IsStream = isStream
+func (info *RelayInfo) SetPromptTokens(promptTokens int) {
+	info.PromptTokens = promptTokens
 }
 }
 
 
 func (info *RelayInfo) SetFirstResponseTime() {
 func (info *RelayInfo) SetFirstResponseTime() {

+ 2 - 1
relay/common_handler/rerank.go

@@ -8,6 +8,7 @@ import (
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel/xinference"
 	"one-api/relay/channel/xinference"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -18,7 +19,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if err != nil {
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	service.CloseResponseBodyGracefully(resp)
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println("reranker response body: ", string(responseBody))
 		println("reranker response body: ", string(responseBody))
 	}
 	}

+ 13 - 56
relay/embedding_handler.go

@@ -8,7 +8,6 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
-	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
 	"one-api/types"
 	"one-api/types"
@@ -16,69 +15,27 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
-	token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
-	return token
-}
-
-func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
-	if embeddingRequest.Input == nil {
-		return fmt.Errorf("input is empty")
-	}
-	if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
-		embeddingRequest.Model = "omni-moderation-latest"
-	}
-	if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
-		embeddingRequest.Model = c.Param("model")
-	}
-	return nil
-}
+func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 
 
-func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
-	relayInfo := relaycommon.GenRelayInfoEmbedding(c)
+	info.InitChannelMeta(c)
 
 
-	var embeddingRequest *dto.EmbeddingRequest
-	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	embeddingRequest, ok := info.Request.(*dto.EmbeddingRequest)
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ClaudeRequest, got %T", info.Request))
 	}
 	}
 
 
-	err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
-
-	err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
+	err := helper.ModelMappedHelper(c, info, embeddingRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	promptToken := getEmbeddingPromptToken(*embeddingRequest)
-	relayInfo.PromptTokens = promptToken
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
-	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
+	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *embeddingRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 	}
@@ -88,7 +45,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}
 	}
 	requestBody := bytes.NewBuffer(jsonData)
 	requestBody := bytes.NewBuffer(jsonData)
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
@@ -104,12 +61,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 		return newAPIError
 	}
 	}
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	return nil
 	return nil
 }
 }

+ 44 - 159
relay/gemini_handler.go

@@ -2,17 +2,16 @@ package relay
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/gemini"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
@@ -20,64 +19,6 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
-	request := &dto.GeminiChatRequest{}
-	err := common.UnmarshalBodyReusable(c, request)
-	if err != nil {
-		return nil, err
-	}
-	if len(request.Contents) == 0 {
-		return nil, errors.New("contents is required")
-	}
-	return request, nil
-}
-
-// 流模式
-// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
-func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
-	if c.Query("alt") == "sse" {
-		relayInfo.IsStream = true
-	}
-
-	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
-	// 	relayInfo.IsStream = true
-	// }
-}
-
-func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
-	var inputTexts []string
-	for _, content := range textRequest.Contents {
-		for _, part := range content.Parts {
-			if part.Text != "" {
-				inputTexts = append(inputTexts, part.Text)
-			}
-		}
-	}
-	if len(inputTexts) == 0 {
-		return nil, nil
-	}
-
-	sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
-	return sensitiveWords, err
-}
-
-func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
-	// 计算输入 token 数量
-	var inputTexts []string
-	for _, content := range req.Contents {
-		for _, part := range content.Parts {
-			if part.Text != "" {
-				inputTexts = append(inputTexts, part.Text)
-			}
-		}
-	}
-
-	inputText := strings.Join(inputTexts, "\n")
-	inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
-	info.PromptTokens = inputTokens
-	return inputTokens
-}
-
 func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
 func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
 	if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
 	if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
 		configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
 		configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
@@ -109,97 +50,61 @@ func trimModelThinking(modelName string) string {
 	return modelName
 	return modelName
 }
 }
 
 
-func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
-	req, err := getAndValidateGeminiRequest(c)
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
+func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+	info.InitChannelMeta(c)
 
 
-	relayInfo := relaycommon.GenRelayInfoGemini(c)
-
-	// 检查 Gemini 流式模式
-	checkGeminiStreamMode(c, relayInfo)
-
-	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkGeminiInputSensitive(req)
-		if err != nil {
-			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
-		}
+	request, ok := info.Request.(*dto.GeminiChatRequest)
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.GeminiChatRequest, got %T", info.Request))
 	}
 	}
 
 
 	// model mapped 模型映射
 	// model mapped 模型映射
-	err = helper.ModelMappedHelper(c, relayInfo, req)
+	err := helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	if value, exists := c.Get("prompt_tokens"); exists {
-		promptTokens := value.(int)
-		relayInfo.SetPromptTokens(promptTokens)
-	} else {
-		promptTokens := getGeminiInputTokens(req, relayInfo)
-		c.Set("prompt_tokens", promptTokens)
-	}
-
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		if isNoThinkingRequest(req) {
+		if isNoThinkingRequest(request) {
 			// check is thinking
 			// check is thinking
-			if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
+			if !strings.Contains(info.OriginModelName, "-nothinking") {
 				// try to get no thinking model price
 				// try to get no thinking model price
-				noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
+				noThinkingModelName := info.OriginModelName + "-nothinking"
 				containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
 				containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
 				if containPrice {
 				if containPrice {
-					relayInfo.OriginModelName = noThinkingModelName
-					relayInfo.UpstreamModelName = noThinkingModelName
+					info.OriginModelName = noThinkingModelName
+					info.UpstreamModelName = noThinkingModelName
 				}
 				}
 			}
 			}
 		}
 		}
-		if req.GenerationConfig.ThinkingConfig == nil {
-			gemini.ThinkingAdaptor(req, relayInfo)
+		if request.GenerationConfig.ThinkingConfig == nil {
+			gemini.ThinkingAdaptor(request, info)
 		}
 		}
 	}
 	}
 
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-
-	// pre consume quota
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
 	// Clean up empty system instruction
 	// Clean up empty system instruction
-	if req.SystemInstructions != nil {
+	if request.SystemInstructions != nil {
 		hasContent := false
 		hasContent := false
-		for _, part := range req.SystemInstructions.Parts {
+		for _, part := range request.SystemInstructions.Parts {
 			if part.Text != "" {
 			if part.Text != "" {
 				hasContent = true
 				hasContent = true
 				break
 				break
 			}
 			}
 		}
 		}
 		if !hasContent {
 		if !hasContent {
-			req.SystemInstructions = nil
+			request.SystemInstructions = nil
 		}
 		}
 	}
 	}
 
 
 	var requestBody io.Reader
 	var requestBody io.Reader
-	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
 		if err != nil {
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -207,7 +112,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		requestBody = bytes.NewReader(body)
 		requestBody = bytes.NewReader(body)
 	} else {
 	} else {
 		// 使用 ConvertGeminiRequest 转换请求格式
 		// 使用 ConvertGeminiRequest 转换请求格式
-		convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
+		convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
@@ -217,10 +122,10 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 
 
 		// apply param override
 		// apply param override
-		if len(relayInfo.ParamOverride) > 0 {
+		if len(info.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			reqMap := make(map[string]interface{})
 			_ = common.Unmarshal(jsonData, &reqMap)
 			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range relayInfo.ParamOverride {
+			for key, value := range info.ParamOverride {
 				reqMap[key] = value
 				reqMap[key] = value
 			}
 			}
 			jsonData, err = common.Marshal(reqMap)
 			jsonData, err = common.Marshal(reqMap)
@@ -229,15 +134,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			}
 		}
 		}
 
 
-		if common.DebugEnabled {
-			println("Gemini request body: %s", string(jsonData))
-		}
+		logger.LogDebug(c, "Gemini request body: "+string(jsonData))
+
 		requestBody = bytes.NewReader(jsonData)
 		requestBody = bytes.NewReader(jsonData)
 	}
 	}
 
 
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
-		common.LogError(c, "Do gemini request failed: "+err.Error())
+		logger.LogError(c, "Do gemini request failed: "+err.Error())
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
 
 
@@ -246,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	var httpResp *http.Response
 	var httpResp *http.Response
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
@@ -255,23 +159,22 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
 	if openaiErr != nil {
 	if openaiErr != nil {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr
 	}
 	}
 
 
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	return nil
 	return nil
 }
 }
 
 
-func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
-	relayInfo := relaycommon.GenRelayInfoGemini(c)
+func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+	info.InitChannelMeta(c)
 
 
 	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
 	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
-	relayInfo.IsGeminiBatchEmbedding = isBatch
+	info.IsGeminiBatchEmbedding = isBatch
 
 
-	var promptTokens int
 	var req any
 	var req any
 	var err error
 	var err error
 	var inputTexts []string
 	var inputTexts []string
@@ -303,35 +206,17 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			}
 		}
 		}
 	}
 	}
-	promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
-	relayInfo.SetPromptTokens(promptTokens)
-	c.Set("prompt_tokens", promptTokens)
 
 
-	err = helper.ModelMappedHelper(c, relayInfo, req)
+	err = helper.ModelMappedHelper(c, info, req)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
 	var requestBody io.Reader
 	var requestBody io.Reader
 	jsonData, err := common.Marshal(req)
 	jsonData, err := common.Marshal(req)
@@ -340,10 +225,10 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}
 	}
 
 
 	// apply param override
 	// apply param override
-	if len(relayInfo.ParamOverride) > 0 {
+	if len(info.ParamOverride) > 0 {
 		reqMap := make(map[string]interface{})
 		reqMap := make(map[string]interface{})
 		_ = common.Unmarshal(jsonData, &reqMap)
 		_ = common.Unmarshal(jsonData, &reqMap)
-		for key, value := range relayInfo.ParamOverride {
+		for key, value := range info.ParamOverride {
 			reqMap[key] = value
 			reqMap[key] = value
 		}
 		}
 		jsonData, err = common.Marshal(reqMap)
 		jsonData, err = common.Marshal(reqMap)
@@ -353,9 +238,9 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}
 	}
 	requestBody = bytes.NewReader(jsonData)
 	requestBody = bytes.NewReader(jsonData)
 
 
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
-		common.LogError(c, "Do gemini request failed: "+err.Error())
+		logger.LogError(c, "Do gemini request failed: "+err.Error())
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
 
 
@@ -370,12 +255,12 @@ func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
 	if openaiErr != nil {
 	if openaiErr != nil {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 		return openaiErr
 	}
 	}
 
 
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	return nil
 	return nil
 }
 }

+ 3 - 2
relay/helper/common.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/types"
 	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -100,7 +101,7 @@ func Done(c *gin.Context) {
 
 
 func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 	if ws == nil {
 	if ws == nil {
-		common.LogError(c, "websocket connection is nil")
+		logger.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
 	}
 	//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
 	//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
@@ -113,7 +114,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
 		return fmt.Errorf("error marshalling object: %w", err)
 		return fmt.Errorf("error marshalling object: %w", err)
 	}
 	}
 	if ws == nil {
 	if ws == nil {
-		common.LogError(c, "websocket connection is nil")
+		logger.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
 	}
 	//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
 	//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))

+ 9 - 8
relay/helper/model_mapped.go

@@ -4,9 +4,10 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	common2 "one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	common2 "one-api/logger"
 	"one-api/relay/common"
 	"one-api/relay/common"
+	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
@@ -54,29 +55,29 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
 	}
 	}
 	if request != nil {
 	if request != nil {
 		switch info.RelayFormat {
 		switch info.RelayFormat {
-		case common.RelayFormatGemini:
+		case types.RelayFormatGemini:
 			// Gemini 模型映射
 			// Gemini 模型映射
-		case common.RelayFormatClaude:
+		case types.RelayFormatClaude:
 			if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
 			if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
 				claudeRequest.Model = info.UpstreamModelName
 				claudeRequest.Model = info.UpstreamModelName
 			}
 			}
-		case common.RelayFormatOpenAIResponses:
+		case types.RelayFormatOpenAIResponses:
 			if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
 			if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
 				openAIResponsesRequest.Model = info.UpstreamModelName
 				openAIResponsesRequest.Model = info.UpstreamModelName
 			}
 			}
-		case common.RelayFormatOpenAIAudio:
+		case types.RelayFormatOpenAIAudio:
 			if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
 			if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
 				openAIAudioRequest.Model = info.UpstreamModelName
 				openAIAudioRequest.Model = info.UpstreamModelName
 			}
 			}
-		case common.RelayFormatOpenAIImage:
+		case types.RelayFormatOpenAIImage:
 			if imageRequest, ok := request.(*dto.ImageRequest); ok {
 			if imageRequest, ok := request.(*dto.ImageRequest); ok {
 				imageRequest.Model = info.UpstreamModelName
 				imageRequest.Model = info.UpstreamModelName
 			}
 			}
-		case common.RelayFormatRerank:
+		case types.RelayFormatRerank:
 			if rerankRequest, ok := request.(*dto.RerankRequest); ok {
 			if rerankRequest, ok := request.(*dto.RerankRequest); ok {
 				rerankRequest.Model = info.UpstreamModelName
 				rerankRequest.Model = info.UpstreamModelName
 			}
 			}
-		case common.RelayFormatEmbedding:
+		case types.RelayFormatEmbedding:
 			if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
 			if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
 				embeddingRequest.Model = info.UpstreamModelName
 				embeddingRequest.Model = info.UpstreamModelName
 			}
 			}

+ 33 - 57
relay/helper/price.go

@@ -5,35 +5,14 @@ import (
 	"one-api/common"
 	"one-api/common"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/setting/ratio_setting"
 	"one-api/setting/ratio_setting"
+	"one-api/types"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-type GroupRatioInfo struct {
-	GroupRatio        float64
-	GroupSpecialRatio float64
-	HasSpecialRatio   bool
-}
-
-type PriceData struct {
-	ModelPrice             float64
-	ModelRatio             float64
-	CompletionRatio        float64
-	CacheRatio             float64
-	CacheCreationRatio     float64
-	ImageRatio             float64
-	UsePrice               bool
-	ShouldPreConsumedQuota int
-	GroupRatioInfo         GroupRatioInfo
-}
-
-func (p PriceData) ToSetting() string {
-	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
-}
-
 // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
 // HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
-func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
-	groupRatioInfo := GroupRatioInfo{
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
+	groupRatioInfo := types.GroupRatioInfo{
 		GroupRatio:        1.0, // default ratio
 		GroupRatio:        1.0, // default ratio
 		GroupSpecialRatio: -1,
 		GroupSpecialRatio: -1,
 	}
 	}
@@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR
 	return groupRatioInfo
 	return groupRatioInfo
 }
 }
 
 
-func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
 	modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
 	modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
 
 
 	groupRatioInfo := HandleGroupRatio(c, info)
 	groupRatioInfo := HandleGroupRatio(c, info)
@@ -75,8 +54,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	var cacheCreationRatio float64
 	var cacheCreationRatio float64
 	if !usePrice {
 	if !usePrice {
 		preConsumedTokens := common.PreConsumedQuota
 		preConsumedTokens := common.PreConsumedQuota
-		if maxTokens != 0 {
-			preConsumedTokens = promptTokens + maxTokens
+		if meta.MaxTokens != 0 {
+			preConsumedTokens = promptTokens + meta.MaxTokens
 		}
 		}
 		var success bool
 		var success bool
 		var matchName string
 		var matchName string
@@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 				acceptUnsetRatio = true
 				acceptUnsetRatio = true
 			}
 			}
 			if !acceptUnsetRatio {
 			if !acceptUnsetRatio {
-				return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
+				return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
 			}
 			}
 		}
 		}
 		completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
 		completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
@@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 		ratio := modelRatio * groupRatioInfo.GroupRatio
 		ratio := modelRatio * groupRatioInfo.GroupRatio
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 		preConsumedQuota = int(float64(preConsumedTokens) * ratio)
 	} else {
 	} else {
+		if meta.ImagePriceRatio != 0 {
+			modelPrice = modelPrice * meta.ImagePriceRatio
+		}
 		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
 		preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
 	}
 	}
 
 
-	priceData := PriceData{
+	priceData := types.PriceData{
 		ModelPrice:             modelPrice,
 		ModelPrice:             modelPrice,
 		ModelRatio:             modelRatio,
 		ModelRatio:             modelRatio,
 		CompletionRatio:        completionRatio,
 		CompletionRatio:        completionRatio,
@@ -115,38 +97,32 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
 		println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
 	}
 	}
-
+	info.PriceData = priceData
 	return priceData, nil
 	return priceData, nil
 }
 }
 
 
-type PerCallPriceData struct {
-	ModelPrice     float64
-	Quota          int
-	GroupRatioInfo GroupRatioInfo
-}
-
 // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
 // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
-	groupRatioInfo := HandleGroupRatio(c, info)
-
-	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
-	// 如果没有配置价格,则使用默认价格
-	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
-		if !ok {
-			modelPrice = 0.1
-		} else {
-			modelPrice = defaultPrice
-		}
-	}
-	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
-	priceData := PerCallPriceData{
-		ModelPrice:     modelPrice,
-		Quota:          quota,
-		GroupRatioInfo: groupRatioInfo,
-	}
-	return priceData
-}
+//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+//	groupRatioInfo := HandleGroupRatio(c, info)
+//
+//	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
+//	// 如果没有配置价格,则使用默认价格
+//	if !success {
+//		defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
+//		if !ok {
+//			modelPrice = 0.1
+//		} else {
+//			modelPrice = defaultPrice
+//		}
+//	}
+//	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
+//	priceData := types.PerCallPriceData{
+//		ModelPrice:     modelPrice,
+//		Quota:          quota,
+//		GroupRatioInfo: groupRatioInfo,
+//	}
+//	return priceData
+//}
 
 
 func ContainPriceOrRatio(modelName string) bool {
 func ContainPriceOrRatio(modelName string) bool {
 	_, ok := ratio_setting.GetModelPrice(modelName, false)
 	_, ok := ratio_setting.GetModelPrice(modelName, false)

+ 12 - 11
relay/helper/stream_scanner.go

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/setting/operation_setting"
 	"one-api/setting/operation_setting"
 	"strings"
 	"strings"
@@ -87,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		select {
 		select {
 		case <-done:
 		case <-done:
 		case <-time.After(5 * time.Second):
 		case <-time.After(5 * time.Second):
-			common.LogError(c, "timeout waiting for goroutines to exit")
+			logger.LogError(c, "timeout waiting for goroutines to exit")
 		}
 		}
 
 
 		close(stopChan)
 		close(stopChan)
@@ -109,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 			defer func() {
 			defer func() {
 				wg.Done()
 				wg.Done()
 				if r := recover(); r != nil {
 				if r := recover(); r != nil {
-					common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
+					logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
 					common.SafeSendBool(stopChan, true)
 					common.SafeSendBool(stopChan, true)
 				}
 				}
 				if common.DebugEnabled {
 				if common.DebugEnabled {
@@ -136,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 					select {
 					select {
 					case err := <-done:
 					case err := <-done:
 						if err != nil {
 						if err != nil {
-							common.LogError(c, "ping data error: "+err.Error())
+							logger.LogError(c, "ping data error: "+err.Error())
 							return
 							return
 						}
 						}
 						if common.DebugEnabled {
 						if common.DebugEnabled {
 							println("ping data sent")
 							println("ping data sent")
 						}
 						}
 					case <-time.After(10 * time.Second):
 					case <-time.After(10 * time.Second):
-						common.LogError(c, "ping data send timeout")
+						logger.LogError(c, "ping data send timeout")
 						return
 						return
 					case <-ctx.Done():
 					case <-ctx.Done():
 						return
 						return
@@ -158,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 					// 监听客户端断开连接
 					// 监听客户端断开连接
 					return
 					return
 				case <-pingTimeout.C:
 				case <-pingTimeout.C:
-					common.LogError(c, "ping goroutine max duration reached")
+					logger.LogError(c, "ping goroutine max duration reached")
 					return
 					return
 				}
 				}
 			}
 			}
@@ -171,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		defer func() {
 		defer func() {
 			wg.Done()
 			wg.Done()
 			if r := recover(); r != nil {
 			if r := recover(); r != nil {
-				common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
+				logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
 			}
 			}
 			common.SafeSendBool(stopChan, true)
 			common.SafeSendBool(stopChan, true)
 			if common.DebugEnabled {
 			if common.DebugEnabled {
@@ -223,7 +224,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 						return
 						return
 					}
 					}
 				case <-time.After(10 * time.Second):
 				case <-time.After(10 * time.Second):
-					common.LogError(c, "data handler timeout")
+					logger.LogError(c, "data handler timeout")
 					return
 					return
 				case <-ctx.Done():
 				case <-ctx.Done():
 					return
 					return
@@ -241,7 +242,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 
 
 		if err := scanner.Err(); err != nil {
 		if err := scanner.Err(); err != nil {
 			if err != io.EOF {
 			if err != io.EOF {
-				common.LogError(c, "scanner error: "+err.Error())
+				logger.LogError(c, "scanner error: "+err.Error())
 			}
 			}
 		}
 		}
 	})
 	})
@@ -250,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 	select {
 	select {
 	case <-ticker.C:
 	case <-ticker.C:
 		// 超时处理逻辑
 		// 超时处理逻辑
-		common.LogError(c, "streaming timeout")
+		logger.LogError(c, "streaming timeout")
 	case <-stopChan:
 	case <-stopChan:
 		// 正常结束
 		// 正常结束
-		common.LogInfo(c, "streaming finished")
+		logger.LogInfo(c, "streaming finished")
 	case <-c.Request.Context().Done():
 	case <-c.Request.Context().Done():
 		// 客户端断开连接
 		// 客户端断开连接
-		common.LogInfo(c, "client disconnected")
+		logger.LogInfo(c, "client disconnected")
 	}
 	}
 }
 }

+ 301 - 0
relay/helper/valid_request.go

@@ -0,0 +1,301 @@
+package helper
+
+import (
+	"errors"
+	"fmt"
+	"math"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/logger"
+	relayconstant "one-api/relay/constant"
+	"one-api/types"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
+	relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+
+	switch format {
+	case types.RelayFormatOpenAI:
+		request, err = GetAndValidateTextRequest(c, relayMode)
+	case types.RelayFormatGemini:
+		request, err = GetAndValidateGeminiRequest(c)
+	case types.RelayFormatClaude:
+		request, err = GetAndValidateClaudeRequest(c)
+	case types.RelayFormatOpenAIResponses:
+		request, err = GetAndValidateResponsesRequest(c)
+
+	case types.RelayFormatOpenAIImage:
+		request, err = GetAndValidOpenAIImageRequest(c, relayMode)
+	case types.RelayFormatEmbedding:
+		request, err = GetAndValidateEmbeddingRequest(c, relayMode)
+	case types.RelayFormatRerank:
+		request, err = GetAndValidateRerankRequest(c)
+	case types.RelayFormatOpenAIAudio:
+		request, err = GetAndValidAudioRequest(c, relayMode)
+	case types.RelayFormatOpenAIRealtime:
+	// nothing to do, no request body
+	default:
+		return nil, fmt.Errorf("unsupported relay format: %s", format)
+	}
+	return request, err
+}
+
+func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
+	audioRequest := &dto.AudioRequest{}
+	err := common.UnmarshalBodyReusable(c, audioRequest)
+	if err != nil {
+		return nil, err
+	}
+	switch relayMode {
+	case relayconstant.RelayModeAudioSpeech:
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
+		}
+	default:
+		err = c.Request.ParseForm()
+		if err != nil {
+			return nil, err
+		}
+		formData := c.Request.PostForm
+		if audioRequest.Model == "" {
+			audioRequest.Model = formData.Get("model")
+		}
+
+		if audioRequest.Model == "" {
+			return nil, errors.New("model is required")
+		}
+		audioRequest.ResponseFormat = formData.Get("response_format")
+		if audioRequest.ResponseFormat == "" {
+			audioRequest.ResponseFormat = "json"
+		}
+	}
+	return audioRequest, nil
+}
+
+func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
+	var rerankRequest *dto.RerankRequest
+	err := common.UnmarshalBodyReusable(c, &rerankRequest)
+	if err != nil {
+		logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	if rerankRequest.Query == "" {
+		return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+	if len(rerankRequest.Documents) == 0 {
+		return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+	return rerankRequest, nil
+}
+
+func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
+	var embeddingRequest *dto.EmbeddingRequest
+	err := common.UnmarshalBodyReusable(c, &embeddingRequest)
+	if err != nil {
+		logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+		return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	if embeddingRequest.Input == nil {
+		return nil, fmt.Errorf("input is empty")
+	}
+	if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+		embeddingRequest.Model = "omni-moderation-latest"
+	}
+	if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+		embeddingRequest.Model = c.Param("model")
+	}
+	return embeddingRequest, nil
+}
+
+func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
+	request := &dto.OpenAIResponsesRequest{}
+	err := common.UnmarshalBodyReusable(c, request)
+	if err != nil {
+		return nil, err
+	}
+	if request.Model == "" {
+		return nil, errors.New("model is required")
+	}
+	if request.Input == nil {
+		return nil, errors.New("input is required")
+	}
+	return request, nil
+}
+
+func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
+	imageRequest := &dto.ImageRequest{}
+
+	switch relayMode {
+	case relayconstant.RelayModeImagesEdits:
+		_, err := c.MultipartForm()
+		if err != nil {
+			return nil, err
+		}
+		formData := c.Request.PostForm
+		imageRequest.Prompt = formData.Get("prompt")
+		imageRequest.Model = formData.Get("model")
+		imageRequest.N = uint(common.String2Int(formData.Get("n")))
+		imageRequest.Quality = formData.Get("quality")
+		imageRequest.Size = formData.Get("size")
+
+		if imageRequest.Model == "gpt-image-1" {
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "standard"
+			}
+		}
+		if imageRequest.N == 0 {
+			imageRequest.N = 1
+		}
+
+		watermark := formData.Has("watermark")
+		if watermark {
+			imageRequest.Watermark = &watermark
+		}
+	default:
+		err := common.UnmarshalBodyReusable(c, imageRequest)
+		if err != nil {
+			return nil, err
+		}
+
+		if imageRequest.Model == "" {
+			imageRequest.Model = "dall-e-3"
+		}
+
+		if strings.Contains(imageRequest.Size, "×") {
+			return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+		}
+
+		// Not "256x256", "512x512", or "1024x1024"
+		if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+			if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+				return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
+			}
+			if imageRequest.Size == "" {
+				imageRequest.Size = "1024x1024"
+			}
+		} else if imageRequest.Model == "dall-e-3" {
+			if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+				return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
+			}
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "standard"
+			}
+			if imageRequest.Size == "" {
+				imageRequest.Size = "1024x1024"
+			}
+		} else if imageRequest.Model == "gpt-image-1" {
+			if imageRequest.Quality == "" {
+				imageRequest.Quality = "auto"
+			}
+		}
+
+		if imageRequest.Prompt == "" {
+			return nil, errors.New("prompt is required")
+		}
+
+		if imageRequest.N == 0 {
+			imageRequest.N = 1
+		}
+	}
+
+	return imageRequest, nil
+}
+
+func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
+	textRequest = &dto.ClaudeRequest{}
+	err = c.ShouldBindJSON(textRequest)
+	if err != nil {
+		return nil, err
+	}
+	if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+		return nil, errors.New("field messages is required")
+	}
+	if textRequest.Model == "" {
+		return nil, errors.New("field model is required")
+	}
+
+	//if textRequest.Stream {
+	//	relayInfo.IsStream = true
+	//}
+
+	return textRequest, nil
+}
+
+func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
+	textRequest := &dto.GeneralOpenAIRequest{}
+	err := common.UnmarshalBodyReusable(c, textRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
+		textRequest.Model = "text-moderation-latest"
+	}
+	if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
+		textRequest.Model = c.Param("model")
+	}
+
+	if textRequest.MaxTokens > math.MaxInt32/2 {
+		return nil, errors.New("max_tokens is invalid")
+	}
+	if textRequest.Model == "" {
+		return nil, errors.New("model is required")
+	}
+	if textRequest.WebSearchOptions != nil {
+		if textRequest.WebSearchOptions.SearchContextSize != "" {
+			validSizes := map[string]bool{
+				"high":   true,
+				"medium": true,
+				"low":    true,
+			}
+			if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+				return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+			}
+		} else {
+			textRequest.WebSearchOptions.SearchContextSize = "medium"
+		}
+	}
+	switch relayMode {
+	case relayconstant.RelayModeCompletions:
+		if textRequest.Prompt == "" {
+			return nil, errors.New("field prompt is required")
+		}
+	case relayconstant.RelayModeChatCompletions:
+		if len(textRequest.Messages) == 0 {
+			return nil, errors.New("field messages is required")
+		}
+	case relayconstant.RelayModeEmbeddings:
+	case relayconstant.RelayModeModerations:
+		if textRequest.Input == nil || textRequest.Input == "" {
+			return nil, errors.New("field input is required")
+		}
+	case relayconstant.RelayModeEdits:
+		if textRequest.Instruction == "" {
+			return nil, errors.New("field instruction is required")
+		}
+	}
+	return textRequest, nil
+}
+
+func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
+
+	request := &dto.GeminiChatRequest{}
+	err := common.UnmarshalBodyReusable(c, request)
+	if err != nil {
+		return nil, err
+	}
+	if len(request.Contents) == 0 {
+		return nil, errors.New("contents is required")
+	}
+
+	//if c.Query("alt") == "sse" {
+	//	relayInfo.IsStream = true
+	//}
+
+	return request, nil
+}

+ 27 - 167
relay/image_handler.go

@@ -3,19 +3,15 @@ package relay
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
-	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
@@ -23,183 +19,41 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
-	imageRequest := &dto.ImageRequest{}
+func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 
 
-	switch info.RelayMode {
-	case relayconstant.RelayModeImagesEdits:
-		_, err := c.MultipartForm()
-		if err != nil {
-			return nil, err
-		}
-		formData := c.Request.PostForm
-		imageRequest.Prompt = formData.Get("prompt")
-		imageRequest.Model = formData.Get("model")
-		imageRequest.N = common.String2Int(formData.Get("n"))
-		imageRequest.Quality = formData.Get("quality")
-		imageRequest.Size = formData.Get("size")
-
-		if imageRequest.Model == "gpt-image-1" {
-			if imageRequest.Quality == "" {
-				imageRequest.Quality = "standard"
-			}
-		}
-		if imageRequest.N == 0 {
-			imageRequest.N = 1
-		}
-
-		if info.ApiType == constant.APITypeVolcEngine {
-			watermark := formData.Has("watermark")
-			imageRequest.Watermark = &watermark
-		}
-	default:
-		err := common.UnmarshalBodyReusable(c, imageRequest)
-		if err != nil {
-			return nil, err
-		}
-
-		if imageRequest.Model == "" {
-			imageRequest.Model = "dall-e-3"
-		}
-
-		if strings.Contains(imageRequest.Size, "×") {
-			return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
-		}
-
-		// Not "256x256", "512x512", or "1024x1024"
-		if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
-			if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
-				return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
-			}
-			if imageRequest.Size == "" {
-				imageRequest.Size = "1024x1024"
-			}
-		} else if imageRequest.Model == "dall-e-3" {
-			if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
-				return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
-			}
-			if imageRequest.Quality == "" {
-				imageRequest.Quality = "standard"
-			}
-			if imageRequest.Size == "" {
-				imageRequest.Size = "1024x1024"
-			}
-		} else if imageRequest.Model == "gpt-image-1" {
-			if imageRequest.Quality == "" {
-				imageRequest.Quality = "auto"
-			}
-		}
-
-		if imageRequest.Prompt == "" {
-			return nil, errors.New("prompt is required")
-		}
-
-		if imageRequest.N == 0 {
-			imageRequest.N = 1
-		}
-	}
-
-	if setting.ShouldCheckPromptSensitive() {
-		words, err := service.CheckSensitiveInput(imageRequest.Prompt)
-		if err != nil {
-			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
-			return nil, err
-		}
-	}
-	return imageRequest, nil
-}
+	info.InitChannelMeta(c)
 
 
-func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
-	relayInfo := relaycommon.GenRelayInfoImage(c)
+	imageRequest, ok := info.Request.(*dto.ImageRequest)
 
 
-	imageRequest, err := getAndValidImageRequest(c, relayInfo)
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.ImageRequest, got %T", info.Request))
 	}
 	}
 
 
-	err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
+	err := helper.ModelMappedHelper(c, info, imageRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-	var preConsumedQuota int
-	var quota int
-	var userQuota int
-	if !priceData.UsePrice {
-		// modelRatio 16 = modelPrice $0.04
-		// per 1 modelRatio = $0.04 / 16
-		// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
-		preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-		if newAPIError != nil {
-			return newAPIError
-		}
-		defer func() {
-			if newAPIError != nil {
-				returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-			}
-		}()
-
-	} else {
-		sizeRatio := 1.0
-		qualityRatio := 1.0
-
-		if strings.HasPrefix(imageRequest.Model, "dall-e") {
-			// Size
-			if imageRequest.Size == "256x256" {
-				sizeRatio = 0.4
-			} else if imageRequest.Size == "512x512" {
-				sizeRatio = 0.45
-			} else if imageRequest.Size == "1024x1024" {
-				sizeRatio = 1
-			} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-				sizeRatio = 2
-			}
-
-			if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
-				qualityRatio = 2.0
-				if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
-					qualityRatio = 1.5
-				}
-			}
-		}
-
-		// reset model price
-		priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
-		quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
-		userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
-		if err != nil {
-			return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
-		}
-		if userQuota-quota < 0 {
-			return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
-		}
-	}
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
 	var requestBody io.Reader
 	var requestBody io.Reader
 
 
-	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
 		if err != nil {
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		}
 		requestBody = bytes.NewBuffer(body)
 		requestBody = bytes.NewBuffer(body)
 	} else {
 	} else {
-		convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
+		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
-		if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
+		if info.RelayMode == relayconstant.RelayModeImagesEdits {
 			requestBody = convertedRequest.(io.Reader)
 			requestBody = convertedRequest.(io.Reader)
 		} else {
 		} else {
 			jsonData, err := json.Marshal(convertedRequest)
 			jsonData, err := json.Marshal(convertedRequest)
@@ -208,10 +62,10 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			}
 
 
 			// apply param override
 			// apply param override
-			if len(relayInfo.ParamOverride) > 0 {
+			if len(info.ParamOverride) > 0 {
 				reqMap := make(map[string]interface{})
 				reqMap := make(map[string]interface{})
 				_ = common.Unmarshal(jsonData, &reqMap)
 				_ = common.Unmarshal(jsonData, &reqMap)
-				for key, value := range relayInfo.ParamOverride {
+				for key, value := range info.ParamOverride {
 					reqMap[key] = value
 					reqMap[key] = value
 				}
 				}
 				jsonData, err = common.Marshal(reqMap)
 				jsonData, err = common.Marshal(reqMap)
@@ -229,14 +83,14 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 	statusCodeMappingStr := c.GetString("status_code_mapping")
 
 
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
 	var httpResp *http.Response
 	var httpResp *http.Response
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			newAPIError = service.RelayErrorHandler(httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
@@ -245,7 +99,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
@@ -253,17 +107,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}
 	}
 
 
 	if usage.(*dto.Usage).TotalTokens == 0 {
 	if usage.(*dto.Usage).TotalTokens == 0 {
-		usage.(*dto.Usage).TotalTokens = imageRequest.N
+		usage.(*dto.Usage).TotalTokens = int(imageRequest.N)
 	}
 	}
 	if usage.(*dto.Usage).PromptTokens == 0 {
 	if usage.(*dto.Usage).PromptTokens == 0 {
-		usage.(*dto.Usage).PromptTokens = imageRequest.N
+		usage.(*dto.Usage).PromptTokens = int(imageRequest.N)
 	}
 	}
+
 	quality := "standard"
 	quality := "standard"
 	if imageRequest.Quality == "hd" {
 	if imageRequest.Quality == "hd" {
 		quality = "hd"
 		quality = "hd"
 	}
 	}
 
 
-	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
+	var logContent string
+
+	if len(imageRequest.Size) > 0 {
+		logContent = fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+	}
+
+	postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
 	return nil
 	return nil
 }
 }

+ 44 - 256
relay/relay-text.go

@@ -2,172 +2,56 @@ package relay
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"math"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
-	relayconstant "one-api/relay/constant"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"one-api/setting/operation_setting"
 	"one-api/setting/operation_setting"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
-	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/shopspring/decimal"
 	"github.com/shopspring/decimal"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
-	textRequest := &dto.GeneralOpenAIRequest{}
-	err := common.UnmarshalBodyReusable(c, textRequest)
-	if err != nil {
-		return nil, err
-	}
-	if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
-		textRequest.Model = "text-moderation-latest"
-	}
-	if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
-		textRequest.Model = c.Param("model")
-	}
+func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 
 
-	if textRequest.MaxTokens > math.MaxInt32/2 {
-		return nil, errors.New("max_tokens is invalid")
-	}
-	if textRequest.Model == "" {
-		return nil, errors.New("model is required")
-	}
-	if textRequest.WebSearchOptions != nil {
-		if textRequest.WebSearchOptions.SearchContextSize != "" {
-			validSizes := map[string]bool{
-				"high":   true,
-				"medium": true,
-				"low":    true,
-			}
-			if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
-				return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
-			}
-		} else {
-			textRequest.WebSearchOptions.SearchContextSize = "medium"
-		}
-	}
-	switch relayInfo.RelayMode {
-	case relayconstant.RelayModeCompletions:
-		if textRequest.Prompt == "" {
-			return nil, errors.New("field prompt is required")
-		}
-	case relayconstant.RelayModeChatCompletions:
-		if len(textRequest.Messages) == 0 {
-			return nil, errors.New("field messages is required")
-		}
-	case relayconstant.RelayModeEmbeddings:
-	case relayconstant.RelayModeModerations:
-		if textRequest.Input == nil || textRequest.Input == "" {
-			return nil, errors.New("field input is required")
-		}
-	case relayconstant.RelayModeEdits:
-		if textRequest.Instruction == "" {
-			return nil, errors.New("field instruction is required")
-		}
-	}
-	relayInfo.IsStream = textRequest.Stream
-	return textRequest, nil
-}
+	info.InitChannelMeta(c)
 
 
-func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+	textRequest, ok := info.Request.(*dto.GeneralOpenAIRequest)
 
 
-	relayInfo := relaycommon.GenRelayInfo(c)
-
-	// get & validate textRequest 获取并验证文本请求
-	textRequest, err := getAndValidateTextRequest(c, relayInfo)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	if !ok {
+		//return types.NewErrorWithStatusCode(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+		common.FatalLog("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request)
 	}
 	}
 
 
 	if textRequest.WebSearchOptions != nil {
 	if textRequest.WebSearchOptions != nil {
 		c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
 		c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
 	}
 	}
 
 
-	if setting.ShouldCheckPromptSensitive() {
-		words, err := checkRequestSensitive(textRequest, relayInfo)
-		if err != nil {
-			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
-		}
-	}
-
-	err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+	err := helper.ModelMappedHelper(c, info, textRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	// 获取 promptTokens,如果上下文中已经存在,则直接使用
-	var promptTokens int
-	if value, exists := c.Get("prompt_tokens"); exists {
-		promptTokens = value.(int)
-		relayInfo.PromptTokens = promptTokens
-	} else {
-		promptTokens, err = getPromptTokens(textRequest, relayInfo)
-		// count messages token error 计算promptTokens错误
-		if err != nil {
-			return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
-		}
-		c.Set("prompt_tokens", promptTokens)
-	}
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-
-	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newApiErr != nil {
-		return newApiErr
-	}
-	defer func() {
-		if newApiErr != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-	includeUsage := true
-	// 判断用户是否需要返回使用情况
-	if textRequest.StreamOptions != nil {
-		includeUsage = textRequest.StreamOptions.IncludeUsage
-	}
-
-	// 如果不支持StreamOptions,将StreamOptions设置为nil
-	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
-		textRequest.StreamOptions = nil
-	} else {
-		// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
-		if constant.ForceStreamOption {
-			textRequest.StreamOptions = &dto.StreamOptions{
-				IncludeUsage: true,
-			}
-		}
-	}
-
-	relayInfo.ShouldIncludeUsage = includeUsage
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 	var requestBody io.Reader
 	var requestBody io.Reader
 
 
-	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
 		if err != nil {
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
@@ -177,12 +61,12 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 		requestBody = bytes.NewBuffer(body)
 		requestBody = bytes.NewBuffer(body)
 	} else {
 	} else {
-		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
+		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, textRequest)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
 
 
-		if relayInfo.ChannelSetting.SystemPrompt != "" {
+		if info.ChannelSetting.SystemPrompt != "" {
 			// 如果有系统提示,则将其添加到请求中
 			// 如果有系统提示,则将其添加到请求中
 			request := convertedRequest.(*dto.GeneralOpenAIRequest)
 			request := convertedRequest.(*dto.GeneralOpenAIRequest)
 			containSystemPrompt := false
 			containSystemPrompt := false
@@ -196,22 +80,22 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 				// 如果没有系统提示,则添加系统提示
 				// 如果没有系统提示,则添加系统提示
 				systemMessage := dto.Message{
 				systemMessage := dto.Message{
 					Role:    request.GetSystemRoleName(),
 					Role:    request.GetSystemRoleName(),
-					Content: relayInfo.ChannelSetting.SystemPrompt,
+					Content: info.ChannelSetting.SystemPrompt,
 				}
 				}
 				request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
 				request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
-			} else if relayInfo.ChannelSetting.SystemPromptOverride {
+			} else if info.ChannelSetting.SystemPromptOverride {
 				common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
 				common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
 				// 如果有系统提示,且允许覆盖,则拼接到前面
 				// 如果有系统提示,且允许覆盖,则拼接到前面
 				for i, message := range request.Messages {
 				for i, message := range request.Messages {
 					if message.Role == request.GetSystemRoleName() {
 					if message.Role == request.GetSystemRoleName() {
 						if message.IsStringContent() {
 						if message.IsStringContent() {
-							request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
+							request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
 						} else {
 						} else {
 							contents := message.ParseContent()
 							contents := message.ParseContent()
 							contents = append([]dto.MediaContent{
 							contents = append([]dto.MediaContent{
 								{
 								{
 									Type: dto.ContentTypeText,
 									Type: dto.ContentTypeText,
-									Text: relayInfo.ChannelSetting.SystemPrompt,
+									Text: info.ChannelSetting.SystemPrompt,
 								},
 								},
 							}, contents...)
 							}, contents...)
 							request.Messages[i].Content = contents
 							request.Messages[i].Content = contents
@@ -228,10 +112,10 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 
 
 		// apply param override
 		// apply param override
-		if len(relayInfo.ParamOverride) > 0 {
+		if len(info.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			reqMap := make(map[string]interface{})
 			_ = common.Unmarshal(jsonData, &reqMap)
 			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range relayInfo.ParamOverride {
+			for key, value := range info.ParamOverride {
 				reqMap[key] = value
 				reqMap[key] = value
 			}
 			}
 			jsonData, err = common.Marshal(reqMap)
 			jsonData, err = common.Marshal(reqMap)
@@ -240,14 +124,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			}
 			}
 		}
 		}
 
 
-		if common.DebugEnabled {
-			println("requestBody: ", string(jsonData))
-		}
+		logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
+
 		requestBody = bytes.NewBuffer(jsonData)
 		requestBody = bytes.NewBuffer(jsonData)
 	}
 	}
 
 
 	var httpResp *http.Response
 	var httpResp *http.Response
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
@@ -256,125 +139,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 
 
 	if resp != nil {
 	if resp != nil {
 		httpResp = resp.(*http.Response)
 		httpResp = resp.(*http.Response)
-		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
 		if httpResp.StatusCode != http.StatusOK {
 		if httpResp.StatusCode != http.StatusOK {
-			newApiErr = service.RelayErrorHandler(httpResp, false)
+			newApiErr := service.RelayErrorHandler(httpResp, false)
 			// reset status code 重置状态码
 			// reset status code 重置状态码
 			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 			return newApiErr
 			return newApiErr
 		}
 		}
 	}
 	}
 
 
-	usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
 	if newApiErr != nil {
 	if newApiErr != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 		service.ResetStatusCode(newApiErr, statusCodeMappingStr)
 		return newApiErr
 		return newApiErr
 	}
 	}
 
 
-	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
-		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
 	} else {
-		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+		postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
-	var promptTokens int
-	var err error
-	switch info.RelayMode {
-	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
-	case relayconstant.RelayModeCompletions:
-		promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
-	case relayconstant.RelayModeModerations:
-		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
-	case relayconstant.RelayModeEmbeddings:
-		promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
-	default:
-		err = errors.New("unknown relay mode")
-		promptTokens = 0
-	}
-	info.PromptTokens = promptTokens
-	return promptTokens, err
-}
-
-func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
-	var err error
-	var words []string
-	switch info.RelayMode {
-	case relayconstant.RelayModeChatCompletions:
-		words, err = service.CheckSensitiveMessages(textRequest.Messages)
-	case relayconstant.RelayModeCompletions:
-		words, err = service.CheckSensitiveInput(textRequest.Prompt)
-	case relayconstant.RelayModeModerations:
-		words, err = service.CheckSensitiveInput(textRequest.Input)
-	case relayconstant.RelayModeEmbeddings:
-		words, err = service.CheckSensitiveInput(textRequest.Input)
-	}
-	return words, err
-}
-
-// 预扣费并返回用户剩余配额
-func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
-	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
-	if err != nil {
-		return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
-	}
-	if userQuota <= 0 {
-		return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-	}
-	if userQuota-preConsumedQuota < 0 {
-		return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-	}
-	relayInfo.UserQuota = userQuota
-	if userQuota > 100*preConsumedQuota {
-		// 用户额度充足,判断令牌额度是否充足
-		if !relayInfo.TokenUnlimited {
-			// 非无限令牌,判断令牌额度是否充足
-			tokenQuota := c.GetInt("token_quota")
-			if tokenQuota > 100*preConsumedQuota {
-				// 令牌额度充足,信任令牌
-				preConsumedQuota = 0
-				common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
-			}
-		} else {
-			// in this case, we do not pre-consume quota
-			// because the user has enough quota
-			preConsumedQuota = 0
-			common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
-		}
-	}
-
-	if preConsumedQuota > 0 {
-		err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
-		if err != nil {
-			return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
-		}
-		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
-		if err != nil {
-			return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
-		}
-	}
-	return preConsumedQuota, userQuota, nil
-}
-
-func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
-	if preConsumedQuota != 0 {
-		gopool.Go(func() {
-			relayInfoCopy := *relayInfo
-
-			err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
-			if err != nil {
-				common.SysError("error return pre-consumed quota: " + err.Error())
-			}
-		})
-	}
-}
-
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
-	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
 	if usage == nil {
 	if usage == nil {
 		usage = &dto.Usage{
 		usage = &dto.Usage{
 			PromptTokens:     relayInfo.PromptTokens,
 			PromptTokens:     relayInfo.PromptTokens,
@@ -392,12 +181,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	modelName := relayInfo.OriginModelName
 	modelName := relayInfo.OriginModelName
 
 
 	tokenName := ctx.GetString("token_name")
 	tokenName := ctx.GetString("token_name")
-	completionRatio := priceData.CompletionRatio
-	cacheRatio := priceData.CacheRatio
-	imageRatio := priceData.ImageRatio
-	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatioInfo.GroupRatio
-	modelPrice := priceData.ModelPrice
+	completionRatio := relayInfo.PriceData.CompletionRatio
+	cacheRatio := relayInfo.PriceData.CacheRatio
+	imageRatio := relayInfo.PriceData.ImageRatio
+	modelRatio := relayInfo.PriceData.ModelRatio
+	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+	modelPrice := relayInfo.PriceData.ModelPrice
 
 
 	// Convert values to decimal for precise calculation
 	// Convert values to decimal for precise calculation
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
 	dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -470,7 +259,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 
 
 	var audioInputQuota decimal.Decimal
 	var audioInputQuota decimal.Decimal
 	var audioInputPrice float64
 	var audioInputPrice float64
-	if !priceData.UsePrice {
+	if !relayInfo.PriceData.UsePrice {
 		baseTokens := dPromptTokens
 		baseTokens := dPromptTokens
 		// 减去 cached tokens
 		// 减去 cached tokens
 		var cachedTokensWithRatio decimal.Decimal
 		var cachedTokensWithRatio decimal.Decimal
@@ -518,7 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	totalTokens := promptTokens + completionTokens
 	totalTokens := promptTokens + completionTokens
 
 
 	var logContent string
 	var logContent string
-	if !priceData.UsePrice {
+	if !relayInfo.PriceData.UsePrice {
 		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
 		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
 	} else {
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -530,8 +319,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		// we cannot just return, because we may have to return the pre-consumed quota
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
 	} else {
 	} else {
 		if !ratio.IsZero() && quota == 0 {
 		if !ratio.IsZero() && quota == 0 {
 			quota = 1
 			quota = 1
@@ -540,11 +329,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 	}
 
 
-	quotaDelta := quota - preConsumedQuota
+	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
 	if quotaDelta != 0 {
 	if quotaDelta != 0 {
-		err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
 		if err != nil {
 		if err != nil {
-			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
 		}
 		}
 	}
 	}
 
 
@@ -560,7 +349,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	if extraContent != "" {
 	if extraContent != "" {
 		logContent += ", " + extraContent
 		logContent += ", " + extraContent
 	}
 	}
-	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
 	if imageTokens != 0 {
 	if imageTokens != 0 {
 		other["image"] = true
 		other["image"] = true
 		other["image_ratio"] = imageRatio
 		other["image_ratio"] = imageRatio
@@ -604,7 +393,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		Quota:            quota,
 		Quota:            quota,
 		Content:          logContent,
 		Content:          logContent,
 		TokenId:          relayInfo.TokenId,
 		TokenId:          relayInfo.TokenId,
-		UserQuota:        userQuota,
 		UseTimeSeconds:   int(useTimeSeconds),
 		UseTimeSeconds:   int(useTimeSeconds),
 		IsStream:         relayInfo.IsStream,
 		IsStream:         relayInfo.IsStream,
 		Group:            relayInfo.UsingGroup,
 		Group:            relayInfo.UsingGroup,

+ 2 - 1
relay/relay_task.go

@@ -10,6 +10,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
@@ -127,7 +128,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
 
 
 			err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
 			err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
 			if err != nil {
 			if err != nil {
-				common.SysError("error consuming token remain quota: " + err.Error())
+				logger.SysError("error consuming token remain quota: " + err.Error())
 			}
 			}
 			if quota != 0 {
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				tokenName := c.GetString("token_name")

+ 15 - 44
relay/rerank_handler.go

@@ -25,62 +25,33 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
 	return token
 	return token
 }
 }
 
 
-func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
+func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
 
 
-	var rerankRequest *dto.RerankRequest
-	err := common.UnmarshalBodyReusable(c, &rerankRequest)
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	rerankRequest, ok := info.Request.(*dto.RerankRequest)
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.RerankRequest, got %T", info.Request))
 	}
 	}
 
 
-	relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
-
-	if rerankRequest.Query == "" {
-		return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
-	if len(rerankRequest.Documents) == 0 {
-		return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
-	}
-
-	err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
+	err := helper.ModelMappedHelper(c, info, rerankRequest)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	promptToken := getRerankPromptToken(*rerankRequest)
-	relayInfo.PromptTokens = promptToken
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-	// pre-consume quota 预消耗配额
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 
 
 	var requestBody io.Reader
 	var requestBody io.Reader
-	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
 		if err != nil {
 		if err != nil {
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
 		}
 		}
 		requestBody = bytes.NewBuffer(body)
 		requestBody = bytes.NewBuffer(body)
 	} else {
 	} else {
-		convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+		convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankRequest)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
@@ -90,10 +61,10 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 		}
 		}
 
 
 		// apply param override
 		// apply param override
-		if len(relayInfo.ParamOverride) > 0 {
+		if len(info.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			reqMap := make(map[string]interface{})
 			_ = common.Unmarshal(jsonData, &reqMap)
 			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range relayInfo.ParamOverride {
+			for key, value := range info.ParamOverride {
 				reqMap[key] = value
 				reqMap[key] = value
 			}
 			}
 			jsonData, err = common.Marshal(reqMap)
 			jsonData, err = common.Marshal(reqMap)
@@ -108,7 +79,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 		requestBody = bytes.NewBuffer(jsonData)
 		requestBody = bytes.NewBuffer(jsonData)
 	}
 	}
 
 
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
@@ -125,12 +96,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 		return newAPIError
 	}
 	}
-	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	return nil
 	return nil
 }
 }

+ 17 - 77
relay/responses_handler.go

@@ -3,7 +3,6 @@ package relay
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -12,7 +11,6 @@ import (
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"one-api/types"
 	"one-api/types"
 	"strings"
 	"strings"
@@ -20,82 +18,24 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
-func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
-	request := &dto.OpenAIResponsesRequest{}
-	err := common.UnmarshalBodyReusable(c, request)
-	if err != nil {
-		return nil, err
-	}
-	if request.Model == "" {
-		return nil, errors.New("model is required")
-	}
-	if len(request.Input) == 0 {
-		return nil, errors.New("input is required")
-	}
-	return request, nil
+func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+	info.InitChannelMeta(c)
 
 
-}
-
-func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
-	sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
-	return sensitiveWords, err
-}
-
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
-	inputTokens := service.CountTokenInput(req.Input, req.Model)
-	info.PromptTokens = inputTokens
-	return inputTokens
-}
-
-func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
-	req, err := getAndValidateResponsesRequest(c)
-	if err != nil {
-		common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
-		return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	request, ok := info.Request.(*dto.OpenAIResponsesRequest)
+	if !ok {
+		common.FatalLog(fmt.Sprintf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request))
 	}
 	}
 
 
-	relayInfo := relaycommon.GenRelayInfoResponses(c, req)
-
-	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkInputSensitive(req, relayInfo)
-		if err != nil {
-			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
-			return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
-		}
-	}
-
-	err = helper.ModelMappedHelper(c, relayInfo, req)
+	err := helper.ModelMappedHelper(c, info, request)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 	}
 	}
 
 
-	if value, exists := c.Get("prompt_tokens"); exists {
-		promptTokens := value.(int)
-		relayInfo.SetPromptTokens(promptTokens)
-	} else {
-		promptTokens := getInputTokens(req, relayInfo)
-		c.Set("prompt_tokens", promptTokens)
-	}
-
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
-	if err != nil {
-		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
-	}
-	// pre consume quota
-	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-	if newAPIError != nil {
-		return newAPIError
-	}
-	defer func() {
-		if newAPIError != nil {
-			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
-		}
-	}()
-	adaptor := GetAdaptor(relayInfo.ApiType)
+	adaptor := GetAdaptor(info.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
-		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
 	}
 	}
-	adaptor.Init(relayInfo)
+	adaptor.Init(info)
 	var requestBody io.Reader
 	var requestBody io.Reader
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
 	if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
 		body, err := common.GetRequestBody(c)
 		body, err := common.GetRequestBody(c)
@@ -104,7 +44,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 		requestBody = bytes.NewBuffer(body)
 		requestBody = bytes.NewBuffer(body)
 	} else {
 	} else {
-		convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
+		convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
 		if err != nil {
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
@@ -113,13 +53,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
 		}
 		// apply param override
 		// apply param override
-		if len(relayInfo.ParamOverride) > 0 {
+		if len(info.ParamOverride) > 0 {
 			reqMap := make(map[string]interface{})
 			reqMap := make(map[string]interface{})
 			err = json.Unmarshal(jsonData, &reqMap)
 			err = json.Unmarshal(jsonData, &reqMap)
 			if err != nil {
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
 			}
-			for key, value := range relayInfo.ParamOverride {
+			for key, value := range info.ParamOverride {
 				reqMap[key] = value
 				reqMap[key] = value
 			}
 			}
 			jsonData, err = json.Marshal(reqMap)
 			jsonData, err = json.Marshal(reqMap)
@@ -135,7 +75,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	}
 	}
 
 
 	var httpResp *http.Response
 	var httpResp *http.Response
-	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
 	if err != nil {
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
 	}
 	}
@@ -153,17 +93,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 		}
 		}
 	}
 	}
 
 
-	usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
 	if newAPIError != nil {
 	if newAPIError != nil {
 		// reset status code 重置状态码
 		// reset status code 重置状态码
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
 		return newAPIError
 		return newAPIError
 	}
 	}
 
 
-	if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
-		service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
 	} else {
 	} else {
-		postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+		postConsumeQuota(c, info, usage.(*dto.Usage), "")
 	}
 	}
 	return nil
 	return nil
 }
 }

+ 0 - 7
relay/websocket.go

@@ -15,13 +15,6 @@ import (
 func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
 func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
 	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 	relayInfo := relaycommon.GenRelayInfoWs(c, ws)
 
 
-	// get & validate textRequest 获取并验证文本请求
-	//realtimeEvent, err := getAndValidateWssRequest(c, ws)
-	//if err != nil {
-	//	common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
-	//	return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
-	//}
-
 	err := helper.ModelMappedHelper(c, relayInfo, nil)
 	err := helper.ModelMappedHelper(c, relayInfo, nil)
 	if err != nil {
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
 		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())

+ 2 - 1
router/main.go

@@ -6,6 +6,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
+	"one-api/logger"
 	"os"
 	"os"
 	"strings"
 	"strings"
 )
 )
@@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 	frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
 	if common.IsMasterNode && frontendBaseUrl != "" {
 	if common.IsMasterNode && frontendBaseUrl != "" {
 		frontendBaseUrl = ""
 		frontendBaseUrl = ""
-		common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+		logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
 	}
 	}
 	if frontendBaseUrl == "" {
 	if frontendBaseUrl == "" {
 		SetWebRouter(router, buildFS, indexPage)
 		SetWebRouter(router, buildFS, indexPage)

+ 75 - 19
router/relay-router.go

@@ -1,11 +1,13 @@
 package router
 package router
 
 
 import (
 import (
-	"github.com/gin-gonic/gin"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/controller"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/middleware"
 	"one-api/relay"
 	"one-api/relay"
+	"one-api/types"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func SetRelayRouter(router *gin.Engine) {
 func SetRelayRouter(router *gin.Engine) {
@@ -62,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) {
 	relayV1Router.Use(middleware.TokenAuth())
 	relayV1Router.Use(middleware.TokenAuth())
 	relayV1Router.Use(middleware.ModelRequestRateLimit())
 	relayV1Router.Use(middleware.ModelRequestRateLimit())
 	{
 	{
-		// WebSocket 路由
+		// WebSocket 路由(统一到 Relay)
 		wsRouter := relayV1Router.Group("")
 		wsRouter := relayV1Router.Group("")
 		wsRouter.Use(middleware.Distribute())
 		wsRouter.Use(middleware.Distribute())
-		wsRouter.GET("/realtime", controller.WssRelay)
+		wsRouter.GET("/realtime", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIRealtime)
+		})
 	}
 	}
 	{
 	{
 		//http router
 		//http router
 		httpRouter := relayV1Router.Group("")
 		httpRouter := relayV1Router.Group("")
 		httpRouter.Use(middleware.Distribute())
 		httpRouter.Use(middleware.Distribute())
-		httpRouter.POST("/messages", controller.RelayClaude)
-		httpRouter.POST("/completions", controller.Relay)
-		httpRouter.POST("/chat/completions", controller.Relay)
-		httpRouter.POST("/edits", controller.Relay)
-		httpRouter.POST("/images/generations", controller.Relay)
-		httpRouter.POST("/images/edits", controller.Relay)
+
+		// claude related routes
+		httpRouter.POST("/messages", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatClaude)
+		})
+
+		// chat related routes
+		httpRouter.POST("/completions", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAI)
+		})
+		httpRouter.POST("/chat/completions", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAI)
+		})
+
+		// response related routes
+		httpRouter.POST("/responses", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIResponses)
+		})
+
+		// image related routes
+		httpRouter.POST("/edits", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIImage)
+		})
+		httpRouter.POST("/images/generations", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIImage)
+		})
+		httpRouter.POST("/images/edits", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIImage)
+		})
+
+		// embedding related routes
+		httpRouter.POST("/embeddings", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatEmbedding)
+		})
+
+		// audio related routes
+		httpRouter.POST("/audio/transcriptions", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIAudio)
+		})
+		httpRouter.POST("/audio/translations", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIAudio)
+		})
+		httpRouter.POST("/audio/speech", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAIAudio)
+		})
+
+		// rerank related routes
+		httpRouter.POST("/rerank", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatRerank)
+		})
+
+		// gemini relay routes
+		httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatGemini)
+		})
+		httpRouter.POST("/models/*path", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatGemini)
+		})
+
+		// other relay routes
+		httpRouter.POST("/moderations", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatOpenAI)
+		})
+
+		// not implemented
 		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
 		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
-		httpRouter.POST("/embeddings", controller.Relay)
-		httpRouter.POST("/engines/:model/embeddings", controller.Relay)
-		httpRouter.POST("/audio/transcriptions", controller.Relay)
-		httpRouter.POST("/audio/translations", controller.Relay)
-		httpRouter.POST("/audio/speech", controller.Relay)
-		httpRouter.POST("/responses", controller.Relay)
 		httpRouter.GET("/files", controller.RelayNotImplemented)
 		httpRouter.GET("/files", controller.RelayNotImplemented)
 		httpRouter.POST("/files", controller.RelayNotImplemented)
 		httpRouter.POST("/files", controller.RelayNotImplemented)
 		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
 		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
@@ -95,9 +152,6 @@ func SetRelayRouter(router *gin.Engine) {
 		httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
 		httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
 		httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
 		httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
-		httpRouter.POST("/moderations", controller.Relay)
-		httpRouter.POST("/rerank", controller.Relay)
-		httpRouter.POST("/models/*path", controller.Relay)
 	}
 	}
 
 
 	relayMjRouter := router.Group("/mj")
 	relayMjRouter := router.Group("/mj")
@@ -121,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) {
 	relayGeminiRouter.Use(middleware.Distribute())
 	relayGeminiRouter.Use(middleware.Distribute())
 	{
 	{
 		// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
 		// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
-		relayGeminiRouter.POST("/models/*path", controller.Relay)
+		relayGeminiRouter.POST("/models/*path", func(c *gin.Context) {
+			controller.Relay(c, types.RelayFormatGemini)
+		})
 	}
 	}
 }
 }
 
 

+ 3 - 3
service/cf_worker.go

@@ -5,7 +5,7 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
-	"one-api/common"
+	"one-api/logger"
 	"one-api/setting"
 	"one-api/setting"
 	"strings"
 	"strings"
 )
 )
@@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
 
 
 func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
 func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
 	if setting.EnableWorker() {
 	if setting.EnableWorker() {
-		common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+		logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
 		req := &WorkerRequest{
 		req := &WorkerRequest{
 			URL: originUrl,
 			URL: originUrl,
 			Key: setting.WorkerValidKey,
 			Key: setting.WorkerValidKey,
 		}
 		}
 		return DoWorkerRequest(req)
 		return DoWorkerRequest(req)
 	} else {
 	} else {
-		common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
+		logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
 		return http.Get(originUrl)
 		return http.Get(originUrl)
 	}
 	}
 }
 }

+ 4 - 3
service/error.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/types"
 	"one-api/types"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -58,7 +59,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
 	lowerText := strings.ToLower(text)
 	lowerText := strings.ToLower(text)
 	if !strings.HasPrefix(lowerText, "get file base64 from url") {
 	if !strings.HasPrefix(lowerText, "get file base64 from url") {
 		if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
 		if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
-			common.SysLog(fmt.Sprintf("error: %s", text))
+			logger.SysLog(fmt.Sprintf("error: %s", text))
 			text = "请求上游地址失败"
 			text = "请求上游地址失败"
 		}
 		}
 	}
 	}
@@ -85,7 +86,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 	if err != nil {
 	if err != nil {
 		return
 		return
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	CloseResponseBodyGracefully(resp)
 	var errResponse dto.GeneralErrorResponse
 	var errResponse dto.GeneralErrorResponse
 
 
 	err = common.Unmarshal(responseBody, &errResponse)
 	err = common.Unmarshal(responseBody, &errResponse)
@@ -138,7 +139,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
 	text := err.Error()
 	text := err.Error()
 	lowerText := strings.ToLower(text)
 	lowerText := strings.ToLower(text)
 	if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
 	if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
-		common.SysLog(fmt.Sprintf("error: %s", text))
+		logger.SysLog(fmt.Sprintf("error: %s", text))
 		text = "请求上游地址失败"
 		text = "请求上游地址失败"
 	}
 	}
 	//避免暴露内部错误
 	//避免暴露内部错误

+ 5 - 3
common/http.go → service/http.go

@@ -1,10 +1,12 @@
-package common
+package service
 
 
 import (
 import (
 	"bytes"
 	"bytes"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
+	"one-api/common"
+	"one-api/logger"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
@@ -15,7 +17,7 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) {
 	}
 	}
 	err := httpResponse.Body.Close()
 	err := httpResponse.Body.Close()
 	if err != nil {
 	if err != nil {
-		SysError("failed to close response body: " + err.Error())
+		common.SysError("failed to close response body: " + err.Error())
 	}
 	}
 }
 }
 
 
@@ -52,6 +54,6 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
 
 
 	_, err := io.Copy(c.Writer, body)
 	_, err := io.Copy(c.Writer, body)
 	if err != nil {
 	if err != nil {
-		LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
+		logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
 	}
 	}
 }
 }

+ 5 - 5
service/image.go

@@ -8,8 +8,8 @@ import (
 	"image"
 	"image"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
-	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/logger"
 	"strings"
 	"strings"
 
 
 	"golang.org/x/image/webp"
 	"golang.org/x/image/webp"
@@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
 func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 	response, err := DoDownloadRequest(imageUrl)
 	response, err := DoDownloadRequest(imageUrl)
 	if err != nil {
 	if err != nil {
-		common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+		logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
 		return image.Config{}, "", err
 		return image.Config{}, "", err
 	}
 	}
 	defer response.Body.Close()
 	defer response.Body.Close()
@@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
 
 
 	var readData []byte
 	var readData []byte
 	for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
 	for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
-		common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
+		logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
 
 
 		// 从response.Body读取更多的数据直到达到当前的限制
 		// 从response.Body读取更多的数据直到达到当前的限制
 		additionalData := make([]byte, limit-int64(len(readData)))
 		additionalData := make([]byte, limit-int64(len(readData)))
@@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
 	config, format, err := image.DecodeConfig(reader)
 	config, format, err := image.DecodeConfig(reader)
 	if err != nil {
 	if err != nil {
 		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
 		err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
-		common.SysLog(err.Error())
+		logger.SysLog(err.Error())
 		config, err = webp.DecodeConfig(reader)
 		config, err = webp.DecodeConfig(reader)
 		if err != nil {
 		if err != nil {
 			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
 			err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
-			common.SysLog(err.Error())
+			logger.SysLog(err.Error())
 		}
 		}
 		format = "webp"
 		format = "webp"
 	}
 	}

+ 3 - 2
service/midjourney.go

@@ -9,6 +9,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/setting"
 	"one-api/setting"
 	"strconv"
 	"strconv"
@@ -212,7 +213,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
 	defer cancel()
 	defer cancel()
 	resp, err := GetHttpClient().Do(req)
 	resp, err := GetHttpClient().Do(req)
 	if err != nil {
 	if err != nil {
-		common.SysError("do request failed: " + err.Error())
+		logger.SysError("do request failed: " + err.Error())
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
 	}
 	}
 	statusCode := resp.StatusCode
 	statusCode := resp.StatusCode
@@ -233,7 +234,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
 	if err != nil {
 	if err != nil {
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
 		return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
+	CloseResponseBodyGracefully(resp)
 	respStr := string(responseBody)
 	respStr := string(responseBody)
 	log.Printf("respStr: %s", respStr)
 	log.Printf("respStr: %s", respStr)
 	if respStr == "" {
 	if respStr == "" {

+ 72 - 0
service/pre_consume_quota.go

@@ -0,0 +1,72 @@
+package service
+
+import (
+	"errors"
+	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/logger"
+	"one-api/model"
+	relaycommon "one-api/relay/common"
+	"one-api/types"
+)
+
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
+	if preConsumedQuota != 0 {
+		gopool.Go(func() {
+			relayInfoCopy := *relayInfo
+
+			err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+			if err != nil {
+				logger.SysError("error return pre-consumed quota: " + err.Error())
+			}
+		})
+	}
+}
+
+// PreConsumeQuota checks if the user has enough quota to pre-consume.
+// It returns the pre-consumed quota if successful, or an error if not.
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+	userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+	if err != nil {
+		return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+	}
+	if userQuota <= 0 {
+		return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+	}
+	if userQuota-preConsumedQuota < 0 {
+		return 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+	}
+	relayInfo.UserQuota = userQuota
+	if userQuota > 100*preConsumedQuota {
+		// 用户额度充足,判断令牌额度是否充足
+		if !relayInfo.TokenUnlimited {
+			// 非无限令牌,判断令牌额度是否充足
+			tokenQuota := c.GetInt("token_quota")
+			if tokenQuota > 100*preConsumedQuota {
+				// 令牌额度充足,信任令牌
+				preConsumedQuota = 0
+				logger.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
+			}
+		} else {
+			// in this case, we do not pre-consume quota
+			// because the user has enough quota
+			preConsumedQuota = 0
+			logger.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, logger.FormatQuota(userQuota)))
+		}
+	}
+
+	if preConsumedQuota > 0 {
+		err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+		if err != nil {
+			return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+		}
+		err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+		if err != nil {
+			return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+		}
+	}
+	relayInfo.FinalPreConsumedQuota = preConsumedQuota
+	return preConsumedQuota, nil
+}

+ 43 - 47
service/quota.go

@@ -8,11 +8,12 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	"one-api/model"
 	"one-api/model"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
-	"one-api/relay/helper"
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/ratio_setting"
 	"one-api/setting/ratio_setting"
+	"one-api/types"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -129,23 +130,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
 	quota := calculateAudioQuota(quotaInfo)
 	quota := calculateAudioQuota(quotaInfo)
 
 
 	if userQuota < quota {
 	if userQuota < quota {
-		return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
+		return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
 	}
 	}
 
 
 	if !token.UnlimitedQuota && token.RemainQuota < quota {
 	if !token.UnlimitedQuota && token.RemainQuota < quota {
-		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
 	}
 	}
 
 
 	err = PostConsumeQuota(relayInfo, quota, 0, false)
 	err = PostConsumeQuota(relayInfo, quota, 0, false)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
+	logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
 	return nil
 	return nil
 }
 }
 
 
 func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
-	usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+	usage *dto.RealtimeUsage, extraContent string) {
 
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	textInputTokens := usage.InputTokenDetails.TextTokens
 	textInputTokens := usage.InputTokenDetails.TextTokens
@@ -159,10 +160,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
 	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
 	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
 	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
 
 
-	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatioInfo.GroupRatio
-	modelPrice := priceData.ModelPrice
-	usePrice := priceData.UsePrice
+	modelRatio := relayInfo.PriceData.ModelRatio
+	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+	modelPrice := relayInfo.PriceData.ModelPrice
+	usePrice := relayInfo.PriceData.UsePrice
 
 
 	quotaInfo := QuotaInfo{
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
 		InputDetails: TokenDetails{
@@ -196,8 +197,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		// we cannot just return, because we may have to return the pre-consumed quota
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
 	} else {
 	} else {
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
@@ -208,7 +209,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		logContent += ", " + extraContent
 		logContent += ", " + extraContent
 	}
 	}
 	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
 	other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
-		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 		ChannelId:        relayInfo.ChannelId,
 		ChannelId:        relayInfo.ChannelId,
 		PromptTokens:     usage.InputTokens,
 		PromptTokens:     usage.InputTokens,
@@ -218,7 +219,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 		Quota:            quota,
 		Quota:            quota,
 		Content:          logContent,
 		Content:          logContent,
 		TokenId:          relayInfo.TokenId,
 		TokenId:          relayInfo.TokenId,
-		UserQuota:        userQuota,
 		UseTimeSeconds:   int(useTimeSeconds),
 		UseTimeSeconds:   int(useTimeSeconds),
 		IsStream:         relayInfo.IsStream,
 		IsStream:         relayInfo.IsStream,
 		Group:            relayInfo.UsingGroup,
 		Group:            relayInfo.UsingGroup,
@@ -226,8 +226,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	})
 	})
 }
 }
 
 
-func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
-	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
 
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	promptTokens := usage.PromptTokens
@@ -235,20 +234,20 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	modelName := relayInfo.OriginModelName
 	modelName := relayInfo.OriginModelName
 
 
 	tokenName := ctx.GetString("token_name")
 	tokenName := ctx.GetString("token_name")
-	completionRatio := priceData.CompletionRatio
-	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatioInfo.GroupRatio
-	modelPrice := priceData.ModelPrice
-	cacheRatio := priceData.CacheRatio
+	completionRatio := relayInfo.PriceData.CompletionRatio
+	modelRatio := relayInfo.PriceData.ModelRatio
+	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+	modelPrice := relayInfo.PriceData.ModelPrice
+	cacheRatio := relayInfo.PriceData.CacheRatio
 	cacheTokens := usage.PromptTokensDetails.CachedTokens
 	cacheTokens := usage.PromptTokensDetails.CachedTokens
 
 
-	cacheCreationRatio := priceData.CacheCreationRatio
+	cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
 	cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
 	cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
 
 
 	if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
 	if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
 		promptTokens -= cacheTokens
 		promptTokens -= cacheTokens
-		if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
-			maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
+		if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 {
+			maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
 			if promptTokens >= maybeCacheCreationTokens {
 			if promptTokens >= maybeCacheCreationTokens {
 				cacheCreationTokens = maybeCacheCreationTokens
 				cacheCreationTokens = maybeCacheCreationTokens
 			}
 			}
@@ -257,7 +256,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	}
 	}
 
 
 	calculateQuota := 0.0
 	calculateQuota := 0.0
-	if !priceData.UsePrice {
+	if !relayInfo.PriceData.UsePrice {
 		calculateQuota = float64(promptTokens)
 		calculateQuota = float64(promptTokens)
 		calculateQuota += float64(cacheTokens) * cacheRatio
 		calculateQuota += float64(cacheTokens) * cacheRatio
 		calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
 		calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
@@ -282,23 +281,23 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		// we cannot just return, because we may have to return the pre-consumed quota
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游出错)")
 		logContent += fmt.Sprintf("(可能是上游出错)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
 	} else {
 	} else {
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 	}
 
 
-	quotaDelta := quota - preConsumedQuota
+	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
 	if quotaDelta != 0 {
 	if quotaDelta != 0 {
-		err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
 		if err != nil {
 		if err != nil {
-			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
 		}
 		}
 	}
 	}
 
 
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
 	other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
-		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+		cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 		ChannelId:        relayInfo.ChannelId,
 		ChannelId:        relayInfo.ChannelId,
 		PromptTokens:     promptTokens,
 		PromptTokens:     promptTokens,
@@ -308,7 +307,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		Quota:            quota,
 		Quota:            quota,
 		Content:          logContent,
 		Content:          logContent,
 		TokenId:          relayInfo.TokenId,
 		TokenId:          relayInfo.TokenId,
-		UserQuota:        userQuota,
 		UseTimeSeconds:   int(useTimeSeconds),
 		UseTimeSeconds:   int(useTimeSeconds),
 		IsStream:         relayInfo.IsStream,
 		IsStream:         relayInfo.IsStream,
 		Group:            relayInfo.UsingGroup,
 		Group:            relayInfo.UsingGroup,
@@ -317,7 +315,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 
 
 }
 }
 
 
-func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
+func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
 	if priceData.CacheCreationRatio == 1 {
 	if priceData.CacheCreationRatio == 1 {
 		return 0
 		return 0
 	}
 	}
@@ -338,8 +336,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData
 		(promptCacheCreatePrice - quotaPrice)))
 		(promptCacheCreatePrice - quotaPrice)))
 }
 }
 
 
-func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
-	usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
 
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	textInputTokens := usage.PromptTokensDetails.TextTokens
 	textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -353,10 +350,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
 	audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
 	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
 	audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
 
 
-	modelRatio := priceData.ModelRatio
-	groupRatio := priceData.GroupRatioInfo.GroupRatio
-	modelPrice := priceData.ModelPrice
-	usePrice := priceData.UsePrice
+	modelRatio := relayInfo.PriceData.ModelRatio
+	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+	modelPrice := relayInfo.PriceData.ModelPrice
+	usePrice := relayInfo.PriceData.UsePrice
 
 
 	quotaInfo := QuotaInfo{
 	quotaInfo := QuotaInfo{
 		InputDetails: TokenDetails{
 		InputDetails: TokenDetails{
@@ -390,18 +387,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		// we cannot just return, because we may have to return the pre-consumed quota
 		// we cannot just return, because we may have to return the pre-consumed quota
 		quota = 0
 		quota = 0
 		logContent += fmt.Sprintf("(可能是上游超时)")
 		logContent += fmt.Sprintf("(可能是上游超时)")
-		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
-			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
+		logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+			"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
 	} else {
 	} else {
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 		model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
 	}
 	}
 
 
-	quotaDelta := quota - preConsumedQuota
+	quotaDelta := quota - relayInfo.FinalPreConsumedQuota
 	if quotaDelta != 0 {
 	if quotaDelta != 0 {
-		err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+		err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
 		if err != nil {
 		if err != nil {
-			common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+			logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
 		}
 		}
 	}
 	}
 
 
@@ -410,7 +407,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		logContent += ", " + extraContent
 		logContent += ", " + extraContent
 	}
 	}
 	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
 	other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
-		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+		completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
 		ChannelId:        relayInfo.ChannelId,
 		ChannelId:        relayInfo.ChannelId,
 		PromptTokens:     usage.PromptTokens,
 		PromptTokens:     usage.PromptTokens,
@@ -420,7 +417,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		Quota:            quota,
 		Quota:            quota,
 		Content:          logContent,
 		Content:          logContent,
 		TokenId:          relayInfo.TokenId,
 		TokenId:          relayInfo.TokenId,
-		UserQuota:        userQuota,
 		UseTimeSeconds:   int(useTimeSeconds),
 		UseTimeSeconds:   int(useTimeSeconds),
 		IsStream:         relayInfo.IsStream,
 		IsStream:         relayInfo.IsStream,
 		Group:            relayInfo.UsingGroup,
 		Group:            relayInfo.UsingGroup,
@@ -443,7 +439,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
 		return err
 		return err
 	}
 	}
 	if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
 	if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
-		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+		return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
 	}
 	}
 	err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
 	err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
 	if err != nil {
 	if err != nil {
@@ -501,7 +497,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
 			prompt := "您的额度即将用尽"
 			prompt := "您的额度即将用尽"
 			topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
 			topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
 			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
 			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
-			err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
+			err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
 			if err != nil {
 			if err != nil {
 				common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
 				common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
 			}
 			}

+ 251 - 123
service/token_counter.go

@@ -4,18 +4,22 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/tiktoken-go/tokenizer"
-	"github.com/tiktoken-go/tokenizer/codec"
 	"image"
 	"image"
 	"log"
 	"log"
 	"math"
 	"math"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"one-api/dto"
 	"one-api/dto"
+	"one-api/logger"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/types"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"unicode/utf8"
 	"unicode/utf8"
+
+	"github.com/gin-gonic/gin"
+	"github.com/tiktoken-go/tokenizer"
+	"github.com/tiktoken-go/tokenizer/codec"
 )
 )
 
 
 // tokenEncoderMap won't grow after initialization
 // tokenEncoderMap won't grow after initialization
@@ -28,9 +32,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec)
 var tokenEncoderMutex sync.RWMutex
 var tokenEncoderMutex sync.RWMutex
 
 
 func InitTokenEncoders() {
 func InitTokenEncoders() {
-	common.SysLog("initializing token encoders")
+	logger.SysLog("initializing token encoders")
 	defaultTokenEncoder = codec.NewCl100kBase()
 	defaultTokenEncoder = codec.NewCl100kBase()
-	common.SysLog("token encoders initialized")
+	logger.SysLog("token encoders initialized")
 }
 }
 
 
 func getTokenEncoder(model string) tokenizer.Codec {
 func getTokenEncoder(model string) tokenizer.Codec {
@@ -72,52 +76,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
 	return tkm
 	return tkm
 }
 }
 
 
-func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
-	if imageUrl == nil {
+func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
+	if fileMeta == nil {
 		return 0, fmt.Errorf("image_url_is_nil")
 		return 0, fmt.Errorf("image_url_is_nil")
 	}
 	}
+
+	// Defaults for 4o/4.1/4.5 family unless overridden below
 	baseTokens := 85
 	baseTokens := 85
-	if model == "glm-4v" {
+	tileTokens := 170
+
+	// Model classification
+	lowerModel := strings.ToLower(model)
+
+	// Special cases from existing behavior
+	if strings.HasPrefix(lowerModel, "glm-4") {
 		return 1047, nil
 		return 1047, nil
 	}
 	}
-	if imageUrl.Detail == "low" {
+
+	// Patch-based models (32x32 patches, capped at 1536, with multiplier)
+	isPatchBased := false
+	multiplier := 1.0
+	switch {
+	case strings.Contains(lowerModel, "gpt-4.1-mini"):
+		isPatchBased = true
+		multiplier = 1.62
+	case strings.Contains(lowerModel, "gpt-4.1-nano"):
+		isPatchBased = true
+		multiplier = 2.46
+	case strings.HasPrefix(lowerModel, "o4-mini"):
+		isPatchBased = true
+		multiplier = 1.72
+	case strings.HasPrefix(lowerModel, "gpt-5-mini"):
+		isPatchBased = true
+		multiplier = 1.62
+	case strings.HasPrefix(lowerModel, "gpt-5-nano"):
+		isPatchBased = true
+		multiplier = 2.46
+	}
+
+	// Tile-based model tokens and bases per doc
+	if !isPatchBased {
+		if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
+			baseTokens = 2833
+			tileTokens = 5667
+		} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
+			baseTokens = 70
+			tileTokens = 140
+		} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
+			baseTokens = 75
+			tileTokens = 150
+		} else if strings.Contains(lowerModel, "computer-use-preview") {
+			baseTokens = 65
+			tileTokens = 129
+		} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
+			baseTokens = 85
+			tileTokens = 170
+		}
+	}
+
+	// Respect existing feature flags/short-circuits
+	if fileMeta.Detail == "low" && !isPatchBased {
 		return baseTokens, nil
 		return baseTokens, nil
 	}
 	}
 	if !constant.GetMediaTokenNotStream && !stream {
 	if !constant.GetMediaTokenNotStream && !stream {
 		return 3 * baseTokens, nil
 		return 3 * baseTokens, nil
 	}
 	}
-
-	// 同步One API的图片计费逻辑
-	if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
-		imageUrl.Detail = "high"
-	}
-
-	tileTokens := 170
-	if strings.HasPrefix(model, "gpt-4o-mini") {
-		tileTokens = 5667
-		baseTokens = 2833
+	// Normalize detail
+	if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
+		fileMeta.Detail = "high"
 	}
 	}
-	// 是否统计图片token
+	// Whether to count image tokens at all
 	if !constant.GetMediaToken {
 	if !constant.GetMediaToken {
 		return 3 * baseTokens, nil
 		return 3 * baseTokens, nil
 	}
 	}
-	if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
-		return 3 * baseTokens, nil
-	}
+
+	// Decode image to get dimensions
 	var config image.Config
 	var config image.Config
 	var err error
 	var err error
 	var format string
 	var format string
 	var b64str string
 	var b64str string
-	if strings.HasPrefix(imageUrl.Url, "http") {
-		config, format, err = DecodeUrlImageData(imageUrl.Url)
+	if strings.HasPrefix(fileMeta.Data, "http") {
+		config, format, err = DecodeUrlImageData(fileMeta.Data)
 	} else {
 	} else {
-		common.SysLog(fmt.Sprintf("decoding image"))
-		config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
+		logger.SysLog(fmt.Sprintf("decoding image"))
+		config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
 	}
 	}
 	if err != nil {
 	if err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
-	imageUrl.MimeType = format
+	fileMeta.MimeType = format
 
 
 	if config.Width == 0 || config.Height == 0 {
 	if config.Width == 0 || config.Height == 0 {
 		// not an image
 		// not an image
@@ -125,60 +172,144 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
 			// file type
 			// file type
 			return 3 * baseTokens, nil
 			return 3 * baseTokens, nil
 		}
 		}
-		return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
-	}
-
-	shortSide := config.Width
-	otherSide := config.Height
-	log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
-	// 缩放倍数
-	scale := 1.0
-	if config.Height < shortSide {
-		shortSide = config.Height
-		otherSide = config.Width
-	}
-
-	// 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
-	if shortSide > 768 {
-		scale = float64(shortSide) / 768
-		shortSide = 768
-	}
-	// 将另一边按照相同的比例缩小,向上取整
-	otherSide = int(math.Ceil(float64(otherSide) / scale))
-	log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
-	// 计算图片的token数量(边的长度除以512,向上取整)
-	tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
-	log.Printf("tiles: %d", tiles)
+		return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
+	}
+
+	width := config.Width
+	height := config.Height
+	log.Printf("format: %s, width: %d, height: %d", format, width, height)
+
+	if isPatchBased {
+		// 32x32 patch-based calculation with 1536 cap and model multiplier
+		ceilDiv := func(a, b int) int { return (a + b - 1) / b }
+		rawPatchesW := ceilDiv(width, 32)
+		rawPatchesH := ceilDiv(height, 32)
+		rawPatches := rawPatchesW * rawPatchesH
+		if rawPatches > 1536 {
+			// scale down
+			area := float64(width * height)
+			r := math.Sqrt(float64(32*32*1536) / area)
+			wScaled := float64(width) * r
+			hScaled := float64(height) * r
+			// adjust to fit whole number of patches after scaling
+			adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
+			adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
+			adj := math.Min(adjW, adjH)
+			if !math.IsNaN(adj) && adj > 0 {
+				r = r * adj
+			}
+			wScaled = float64(width) * r
+			hScaled = float64(height) * r
+			patchesW := math.Ceil(wScaled / 32.0)
+			patchesH := math.Ceil(hScaled / 32.0)
+			imageTokens := int(patchesW * patchesH)
+			if imageTokens > 1536 {
+				imageTokens = 1536
+			}
+			return int(math.Round(float64(imageTokens) * multiplier)), nil
+		}
+		// below cap
+		imageTokens := rawPatches
+		return int(math.Round(float64(imageTokens) * multiplier)), nil
+	}
+
+	// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
+	// Step 1: fit within 2048x2048 square
+	maxSide := math.Max(float64(width), float64(height))
+	fitScale := 1.0
+	if maxSide > 2048 {
+		fitScale = maxSide / 2048.0
+	}
+	fitW := int(math.Round(float64(width) / fitScale))
+	fitH := int(math.Round(float64(height) / fitScale))
+
+	// Step 2: scale so that shortest side is exactly 768
+	minSide := math.Min(float64(fitW), float64(fitH))
+	if minSide == 0 {
+		return baseTokens, nil
+	}
+	shortScale := 768.0 / minSide
+	finalW := int(math.Round(float64(fitW) * shortScale))
+	finalH := int(math.Round(float64(fitH) * shortScale))
+
+	// Count 512px tiles
+	tilesW := (finalW + 512 - 1) / 512
+	tilesH := (finalH + 512 - 1) / 512
+	tiles := tilesW * tilesH
+
+	if common.DebugEnabled {
+		log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
+	}
+
 	return tiles*tileTokens + baseTokens, nil
 	return tiles*tileTokens + baseTokens, nil
 }
 }
 
 
-func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
-	tkm := 0
-	msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
-	if err != nil {
-		return 0, err
-	}
-	tkm += msgTokens
-	if request.Tools != nil {
-		openaiTools := request.Tools
-		countStr := ""
-		for _, tool := range openaiTools {
-			countStr = tool.Function.Name
-			if tool.Function.Description != "" {
-				countStr += tool.Function.Description
-			}
-			if tool.Function.Parameters != nil {
-				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+	if meta == nil {
+		return 0, errors.New("token count meta is nil")
+	}
+	model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
+	tkm := CountTextToken(meta.CombineText, model)
+
+	if info.RelayFormat == types.RelayFormatOpenAI {
+		tkm += meta.ToolsCount * 8
+		tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
+		tkm += meta.NameCount * 3
+		tkm += 3
+	}
+
+	for _, file := range meta.Files {
+		switch file.FileType {
+		case types.FileTypeImage:
+			if info.RelayFormat == types.RelayFormatGemini {
+				tkm += 240
+			} else {
+				token, err := getImageToken(file, model, info.IsStream)
+				if err != nil {
+					return 0, fmt.Errorf("error counting image token: %v", err)
+				}
+				tkm += token
 			}
 			}
+		case types.FileTypeAudio:
+			tkm += 100
+		case types.FileTypeVideo:
+			tkm += 5000
+		case types.FileTypeFile:
+			tkm += 5000
 		}
 		}
-		toolTokens := CountTokenInput(countStr, request.Model)
-		tkm += 8
-		tkm += toolTokens
 	}
 	}
 
 
+	common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
 	return tkm, nil
 	return tkm, nil
 }
 }
 
 
+//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
+//	tkm := 0
+//	msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
+//	if err != nil {
+//		return 0, err
+//	}
+//	tkm += msgTokens
+//	if request.Tools != nil {
+//		openaiTools := request.Tools
+//		countStr := ""
+//		for _, tool := range openaiTools {
+//			countStr = tool.Function.Name
+//			if tool.Function.Description != "" {
+//				countStr += tool.Function.Description
+//			}
+//			if tool.Function.Parameters != nil {
+//				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+//			}
+//		}
+//		toolTokens := CountTokenInput(countStr, request.Model)
+//		tkm += 8
+//		tkm += toolTokens
+//	}
+//
+//	return tkm, nil
+//}
+
 func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
 func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
 	tkm := 0
 	tkm := 0
 
 
@@ -338,58 +469,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
 	return textToken, audioToken, nil
 	return textToken, audioToken, nil
 }
 }
 
 
-func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
-	//recover when panic
-	tokenEncoder := getTokenEncoder(model)
-	// Reference:
-	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-	// https://github.com/pkoukk/tiktoken-go/issues/6
-	//
-	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
-	var tokensPerMessage int
-	var tokensPerName int
-	if model == "gpt-3.5-turbo-0301" {
-		tokensPerMessage = 4
-		tokensPerName = -1 // If there's a name, the role is omitted
-	} else {
-		tokensPerMessage = 3
-		tokensPerName = 1
-	}
-	tokenNum := 0
-	for _, message := range messages {
-		tokenNum += tokensPerMessage
-		tokenNum += getTokenNum(tokenEncoder, message.Role)
-		if message.Content != nil {
-			if message.Name != nil {
-				tokenNum += tokensPerName
-				tokenNum += getTokenNum(tokenEncoder, *message.Name)
-			}
-			arrayContent := message.ParseContent()
-			for _, m := range arrayContent {
-				if m.Type == dto.ContentTypeImageURL {
-					imageUrl := m.GetImageMedia()
-					imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
-					if err != nil {
-						return 0, err
-					}
-					tokenNum += imageTokenNum
-					log.Printf("image token num: %d", imageTokenNum)
-				} else if m.Type == dto.ContentTypeInputAudio {
-					// TODO: 音频token数量计算
-					tokenNum += 100
-				} else if m.Type == dto.ContentTypeFile {
-					tokenNum += 5000
-				} else if m.Type == dto.ContentTypeVideoUrl {
-					tokenNum += 5000
-				} else {
-					tokenNum += getTokenNum(tokenEncoder, m.Text)
-				}
-			}
-		}
-	}
-	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum, nil
-}
+//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
+//	//recover when panic
+//	tokenEncoder := getTokenEncoder(model)
+//	// Reference:
+//	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+//	// https://github.com/pkoukk/tiktoken-go/issues/6
+//	//
+//	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+//	var tokensPerMessage int
+//	var tokensPerName int
+//
+//	tokensPerMessage = 3
+//	tokensPerName = 1
+//
+//	tokenNum := 0
+//	for _, message := range messages {
+//		tokenNum += tokensPerMessage
+//		tokenNum += getTokenNum(tokenEncoder, message.Role)
+//		if message.Content != nil {
+//			if message.Name != nil {
+//				tokenNum += tokensPerName
+//				tokenNum += getTokenNum(tokenEncoder, *message.Name)
+//			}
+//			arrayContent := message.ParseContent()
+//			for _, m := range arrayContent {
+//				if m.Type == dto.ContentTypeImageURL {
+//					imageUrl := m.GetImageMedia()
+//					imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
+//					if err != nil {
+//						return 0, err
+//					}
+//					tokenNum += imageTokenNum
+//					log.Printf("image token num: %d", imageTokenNum)
+//				} else if m.Type == dto.ContentTypeInputAudio {
+//					// TODO: 音频token数量计算
+//					tokenNum += 100
+//				} else if m.Type == dto.ContentTypeFile {
+//					tokenNum += 5000
+//				} else if m.Type == dto.ContentTypeVideoUrl {
+//					tokenNum += 5000
+//				} else {
+//					tokenNum += getTokenNum(tokenEncoder, m.Text)
+//				}
+//			}
+//		}
+//	}
+//	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+//	return tokenNum, nil
+//}
 
 
 func CountTokenInput(input any, model string) int {
 func CountTokenInput(input any, model string) int {
 	switch v := input.(type) {
 	switch v := input.(type) {

Неке датотеке нису приказане због велике количине промена