Просмотр исходного кода

Merge branch 'main' into feat-copy-models

JoeyLearnsToCode 8 месяцев назад
Родитель
Сommit
ffc22b8dac
100 измененных файлов с 3893 добавлено и 756 удалено
  1. 14 6
      .github/workflows/docker-image-alpha.yml
  2. 1 6
      .github/workflows/docker-image-arm64.yml
  3. 9 4
      .github/workflows/linux-release.yml
  4. 9 4
      .github/workflows/macos-release.yml
  5. 9 4
      .github/workflows/windows-release.yml
  6. 1 2
      Dockerfile
  7. 4 0
      README.en.md
  8. 7 1
      README.md
  9. 2 0
      common/constants.go
  10. 7 0
      common/database.go
  11. 8 4
      common/redis.go
  12. 45 2
      common/utils.go
  13. 4 6
      constant/cache_key.go
  14. 1 0
      constant/task.go
  15. 1 0
      constant/user_setting.go
  16. 38 0
      controller/channel-billing.go
  17. 16 8
      controller/channel-test.go
  18. 97 16
      controller/channel.go
  19. 103 0
      controller/console_migrate.go
  20. 11 3
      controller/group.go
  21. 34 20
      controller/midjourney.go
  22. 72 42
      controller/misc.go
  23. 20 2
      controller/model.go
  24. 39 2
      controller/option.go
  25. 13 3
      controller/playground.go
  26. 13 6
      controller/pricing.go
  27. 24 0
      controller/ratio_config.go
  28. 322 0
      controller/ratio_sync.go
  29. 39 3
      controller/redemption.go
  30. 5 3
      controller/relay.go
  31. 8 0
      controller/setup.go
  32. 32 14
      controller/task.go
  33. 140 0
      controller/task_video.go
  34. 12 4
      controller/token.go
  35. 7 9
      controller/topup.go
  36. 154 0
      controller/uptime_kuma.go
  37. 8 0
      controller/user.go
  38. 87 48
      dto/claude.go
  39. 2 0
      dto/dalle.go
  40. 280 50
      dto/openai_request.go
  41. 49 0
      dto/ratio_sync.go
  42. 47 0
      dto/video.go
  43. 3 3
      go.mod
  44. 4 4
      go.sum
  45. 21 6
      main.go
  46. 1 1
      makefile
  47. 15 2
      middleware/auth.go
  48. 58 5
      middleware/distributor.go
  49. 41 0
      middleware/stats.go
  50. 55 25
      model/ability.go
  51. 45 3
      model/cache.go
  52. 71 10
      model/channel.go
  53. 46 9
      model/log.go
  54. 116 54
      model/main.go
  55. 37 0
      model/midjourney.go
  56. 27 11
      model/option.go
  57. 4 4
      model/pricing.go
  58. 11 1
      model/redemption.go
  59. 61 0
      model/task.go
  60. 9 2
      model/token.go
  61. 2 2
      model/token_cache.go
  62. 5 3
      model/user.go
  63. 4 3
      model/user_cache.go
  64. 19 2
      model/utils.go
  65. 3 8
      relay/audio_handler.go
  66. 2 0
      relay/channel/adapter.go
  67. 11 1
      relay/channel/ali/adaptor.go
  68. 1 0
      relay/channel/ali/constants.go
  69. 27 0
      relay/channel/ali/dto.go
  70. 83 0
      relay/channel/ali/rerank.go
  71. 10 9
      relay/channel/ali/text.go
  72. 115 48
      relay/channel/api_request.go
  73. 12 0
      relay/channel/aws/constants.go
  74. 1 2
      relay/channel/baidu/relay-baidu.go
  75. 13 0
      relay/channel/baidu_v2/adaptor.go
  76. 3 3
      relay/channel/claude/adaptor.go
  77. 4 0
      relay/channel/claude/constants.go
  78. 41 48
      relay/channel/claude/relay-claude.go
  79. 3 3
      relay/channel/cloudflare/relay_cloudflare.go
  80. 3 5
      relay/channel/cohere/relay-cohere.go
  81. 1 1
      relay/channel/coze/dto.go
  82. 5 7
      relay/channel/coze/relay-coze.go
  83. 2 10
      relay/channel/dify/relay-dify.go
  84. 15 3
      relay/channel/gemini/adaptor.go
  85. 70 14
      relay/channel/gemini/dto.go
  86. 146 0
      relay/channel/gemini/relay-gemini-native.go
  87. 270 109
      relay/channel/gemini/relay-gemini.go
  88. 42 0
      relay/channel/mistral/text.go
  89. 7 0
      relay/channel/openai/adaptor.go
  90. 21 30
      relay/channel/openai/relay-openai.go
  91. 1 1
      relay/channel/openai/relay_responses.go
  92. 9 0
      relay/channel/openrouter/dto.go
  93. 1 1
      relay/channel/palm/adaptor.go
  94. 3 5
      relay/channel/palm/relay-palm.go
  95. 312 0
      relay/channel/task/kling/adaptor.go
  96. 4 0
      relay/channel/task/suno/adaptor.go
  97. 1 1
      relay/channel/tencent/adaptor.go
  98. 1 2
      relay/channel/tencent/relay-tencent.go
  99. 53 21
      relay/channel/vertex/adaptor.go
  100. 148 2
      relay/channel/volcengine/adaptor.go

+ 14 - 6
.github/workflows/docker-image-amd64.yml → .github/workflows/docker-image-alpha.yml

@@ -1,14 +1,15 @@
-name: Publish Docker image (amd64)
+name: Publish Docker image (alpha)
 
 
 on:
 on:
   push:
   push:
-    tags:
-      - '*'
+    branches:
+      - alpha
   workflow_dispatch:
   workflow_dispatch:
     inputs:
     inputs:
       name:
       name:
-        description: 'reason'
+        description: "reason"
         required: false
         required: false
+
 jobs:
 jobs:
   push_to_registries:
   push_to_registries:
     name: Push Docker image to multiple registries
     name: Push Docker image to multiple registries
@@ -22,7 +23,7 @@ jobs:
 
 
       - name: Save version info
       - name: Save version info
         run: |
         run: |
-          git describe --tags > VERSION 
+          echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
 
 
       - name: Log in to Docker Hub
       - name: Log in to Docker Hub
         uses: docker/login-action@v3
         uses: docker/login-action@v3
@@ -37,6 +38,9 @@ jobs:
           username: ${{ github.actor }}
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}
           password: ${{ secrets.GITHUB_TOKEN }}
 
 
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
       - name: Extract metadata (tags, labels) for Docker
       - name: Extract metadata (tags, labels) for Docker
         id: meta
         id: meta
         uses: docker/metadata-action@v5
         uses: docker/metadata-action@v5
@@ -44,11 +48,15 @@ jobs:
           images: |
           images: |
             calciumion/new-api
             calciumion/new-api
             ghcr.io/${{ github.repository }}
             ghcr.io/${{ github.repository }}
+          tags: |
+            type=raw,value=alpha
+            type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
 
 
       - name: Build and push Docker images
       - name: Build and push Docker images
         uses: docker/build-push-action@v5
         uses: docker/build-push-action@v5
         with:
         with:
           context: .
           context: .
+          platforms: linux/amd64,linux/arm64
           push: true
           push: true
           tags: ${{ steps.meta.outputs.tags }}
           tags: ${{ steps.meta.outputs.tags }}
-          labels: ${{ steps.meta.outputs.labels }}
+          labels: ${{ steps.meta.outputs.labels }}

+ 1 - 6
.github/workflows/docker-image-arm64.yml

@@ -1,14 +1,9 @@
-name: Publish Docker image (arm64)
+name: Publish Docker image (Multi Registries)
 
 
 on:
 on:
   push:
   push:
     tags:
     tags:
       - '*'
       - '*'
-  workflow_dispatch:
-    inputs:
-      name:
-        description: 'reason'
-        required: false
 jobs:
 jobs:
   push_to_registries:
   push_to_registries:
     name: Push Docker image to multiple registries
     name: Push Docker image to multiple registries

+ 9 - 4
.github/workflows/linux-release.yml

@@ -3,6 +3,11 @@ permissions:
   contents: write
   contents: write
 
 
 on:
 on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
   push:
   push:
     tags:
     tags:
       - '*'
       - '*'
@@ -15,16 +20,16 @@ jobs:
         uses: actions/checkout@v3
         uses: actions/checkout@v3
         with:
         with:
           fetch-depth: 0
           fetch-depth: 0
-      - uses: actions/setup-node@v3
+      - uses: oven-sh/setup-bun@v2
         with:
         with:
-          node-version: 18
+          bun-version: latest
       - name: Build Frontend
       - name: Build Frontend
         env:
         env:
           CI: ""
           CI: ""
         run: |
         run: |
           cd web
           cd web
-          npm install
-          REACT_APP_VERSION=$(git describe --tags) npm run build
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
           cd ..
           cd ..
       - name: Set up Go
       - name: Set up Go
         uses: actions/setup-go@v3
         uses: actions/setup-go@v3

+ 9 - 4
.github/workflows/macos-release.yml

@@ -3,6 +3,11 @@ permissions:
   contents: write
   contents: write
 
 
 on:
 on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
   push:
   push:
     tags:
     tags:
       - '*'
       - '*'
@@ -15,16 +20,16 @@ jobs:
         uses: actions/checkout@v3
         uses: actions/checkout@v3
         with:
         with:
           fetch-depth: 0
           fetch-depth: 0
-      - uses: actions/setup-node@v3
+      - uses: oven-sh/setup-bun@v2
         with:
         with:
-          node-version: 18
+          bun-version: latest
       - name: Build Frontend
       - name: Build Frontend
         env:
         env:
           CI: ""
           CI: ""
         run: |
         run: |
           cd web
           cd web
-          npm install
-          REACT_APP_VERSION=$(git describe --tags) npm run build
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
           cd ..
           cd ..
       - name: Set up Go
       - name: Set up Go
         uses: actions/setup-go@v3
         uses: actions/setup-go@v3

+ 9 - 4
.github/workflows/windows-release.yml

@@ -3,6 +3,11 @@ permissions:
   contents: write
   contents: write
 
 
 on:
 on:
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
   push:
   push:
     tags:
     tags:
       - '*'
       - '*'
@@ -18,16 +23,16 @@ jobs:
         uses: actions/checkout@v3
         uses: actions/checkout@v3
         with:
         with:
           fetch-depth: 0
           fetch-depth: 0
-      - uses: actions/setup-node@v3
+      - uses: oven-sh/setup-bun@v2
         with:
         with:
-          node-version: 18
+          bun-version: latest
       - name: Build Frontend
       - name: Build Frontend
         env:
         env:
           CI: ""
           CI: ""
         run: |
         run: |
           cd web
           cd web
-          npm install
-          REACT_APP_VERSION=$(git describe --tags) npm run build
+          bun install
+          DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
           cd ..
           cd ..
       - name: Set up Go
       - name: Set up Go
         uses: actions/setup-go@v3
         uses: actions/setup-go@v3

+ 1 - 2
Dockerfile

@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
 
 
 FROM alpine
 FROM alpine
 
 
-RUN apk update \
-    && apk upgrade \
+RUN apk upgrade --no-cache \
     && apk add --no-cache ca-certificates tzdata ffmpeg \
     && apk add --no-cache ca-certificates tzdata ffmpeg \
     && update-ca-certificates
     && update-ca-certificates
 
 

+ 4 - 0
README.en.md

@@ -44,6 +44,9 @@
 
 
 For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
 For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
 
 
+You can also access the AI-generated DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
 ## ✨ Key Features
 ## ✨ Key Features
 
 
 New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
 New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
@@ -110,6 +113,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env
 - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
 - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
 - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
 - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
+- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
 
 
 ## Deployment
 ## Deployment
 
 

+ 7 - 1
README.md

@@ -27,6 +27,9 @@
   <a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
   <a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
     <img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
     <img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
   </a>
   </a>
+  <a href="https://coderabbit.ai">
+    <img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
+  </a>
 </p>
 </p>
 </div>
 </div>
 
 
@@ -44,6 +47,9 @@
 
 
 详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
 详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
 
 
+也可访问AI生成的DeepWiki:
+[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/QuantumNous/new-api)
+
 ## ✨ 主要特性
 ## ✨ 主要特性
 
 
 New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
 New API提供了丰富的功能,详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction):
@@ -110,6 +116,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
 - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
 - `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview`
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
+- `ERROR_LOG_ENABLED=true`: 是否记录并显示错误日志,默认`false`
 
 
 ## 部署
 ## 部署
 
 
@@ -176,7 +183,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
 
 
 其他基于New API的项目:
 其他基于New API的项目:
 - [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
 - [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
-- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
 
 
 ## 帮助支持
 ## 帮助支持
 
 

+ 2 - 0
common/constants.go

@@ -241,6 +241,7 @@ const (
 	ChannelTypeXinference     = 47
 	ChannelTypeXinference     = 47
 	ChannelTypeXai            = 48
 	ChannelTypeXai            = 48
 	ChannelTypeCoze           = 49
 	ChannelTypeCoze           = 49
+	ChannelTypeKling          = 50
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 
 )
 )
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
 	"",                                          //47
 	"",                                          //47
 	"https://api.x.ai",                          //48
 	"https://api.x.ai",                          //48
 	"https://api.coze.cn",                       //49
 	"https://api.coze.cn",                       //49
+	"https://api.klingai.com",                   //50
 }
 }

+ 7 - 0
common/database.go

@@ -1,7 +1,14 @@
 package common
 package common
 
 
+const (
+	DatabaseTypeMySQL      = "mysql"
+	DatabaseTypeSQLite     = "sqlite"
+	DatabaseTypePostgreSQL = "postgres"
+)
+
 var UsingSQLite = false
 var UsingSQLite = false
 var UsingPostgreSQL = false
 var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
 var UsingMySQL = false
 var UsingMySQL = false
 var UsingClickHouse = false
 var UsingClickHouse = false
 
 

+ 8 - 4
common/redis.go

@@ -92,12 +92,12 @@ func RedisDel(key string) error {
 	return RDB.Del(ctx, key).Err()
 	return RDB.Del(ctx, key).Err()
 }
 }
 
 
-func RedisHDelObj(key string) error {
+func RedisDelKey(key string) error {
 	if DebugEnabled {
 	if DebugEnabled {
-		SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
+		SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
 	}
 	}
 	ctx := context.Background()
 	ctx := context.Background()
-	return RDB.HDel(ctx, key).Err()
+	return RDB.Del(ctx, key).Err()
 }
 }
 
 
 func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
 func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
 
 
 	txn := RDB.TxPipeline()
 	txn := RDB.TxPipeline()
 	txn.HSet(ctx, key, data)
 	txn.HSet(ctx, key, data)
-	txn.Expire(ctx, key, expiration)
+
+	// 只有在 expiration 大于 0 时才设置过期时间
+	if expiration > 0 {
+		txn.Expire(ctx, key, expiration)
+	}
 
 
 	_, err := txn.Exec(ctx)
 	_, err := txn.Exec(ctx)
 	if err != nil {
 	if err != nil {

+ 45 - 2
common/utils.go

@@ -13,6 +13,7 @@ import (
 	"math/big"
 	"math/big"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
+	"net/url"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
 	"runtime"
 	"runtime"
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
 }
 }
 
 
 // GetAudioDuration returns the duration of an audio file in seconds.
 // GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
 	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
 	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
 	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
 	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
 	output, err := c.Output()
 	output, err := c.Output()
 	if err != nil {
 	if err != nil {
 		return 0, errors.Wrap(err, "failed to get audio duration")
 		return 0, errors.Wrap(err, "failed to get audio duration")
 	}
 	}
+  durationStr := string(bytes.TrimSpace(output))
+  if durationStr == "N/A" {
+    // Create a temporary output file name
+    tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+    if err != nil {
+      return 0, errors.Wrap(err, "failed to create temporary file")
+    }
+    tmpName := tmpFp.Name()
+    // Close immediately so ffmpeg can open the file on Windows.
+    _ = tmpFp.Close()
+    defer os.Remove(tmpName)
+
+    // ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
+    ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+    if err := ffmpegCmd.Run(); err != nil {
+      return 0, errors.Wrap(err, "failed to run ffmpeg")
+    }
+
+    // Recalculate the duration of the new file
+    c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+    output, err := c.Output()
+    if err != nil {
+      return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+    }
+    durationStr = string(bytes.TrimSpace(output))
+  }
+	return strconv.ParseFloat(durationStr, 64)
+}
 
 
-	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+// BuildURL concatenates base and endpoint, returns the complete url string
+func BuildURL(base string, endpoint string) string {
+	u, err := url.Parse(base)
+	if err != nil {
+		return base + endpoint
+	}
+	end := endpoint
+	if end == "" {
+		end = "/"
+	}
+	ref, err := url.Parse(end)
+	if err != nil {
+		return base + endpoint
+	}
+	return u.ResolveReference(ref).String()
 }
 }

+ 4 - 6
constant/cache_key.go

@@ -2,12 +2,10 @@ package constant
 
 
 import "one-api/common"
 import "one-api/common"
 
 
-var (
-	TokenCacheSeconds         = common.SyncFrequency
-	UserId2GroupCacheSeconds  = common.SyncFrequency
-	UserId2QuotaCacheSeconds  = common.SyncFrequency
-	UserId2StatusCacheSeconds = common.SyncFrequency
-)
+// 使用函数来避免初始化顺序带来的赋值问题
+func RedisKeyCacheSeconds() int {
+	return common.SyncFrequency
+}
 
 
 // Cache keys
 // Cache keys
 const (
 const (

+ 1 - 0
constant/task.go

@@ -5,6 +5,7 @@ type TaskPlatform string
 const (
 const (
 	TaskPlatformSuno       TaskPlatform = "suno"
 	TaskPlatformSuno       TaskPlatform = "suno"
 	TaskPlatformMidjourney              = "mj"
 	TaskPlatformMidjourney              = "mj"
+	TaskPlatformKling      TaskPlatform = "kling"
 )
 )
 
 
 const (
 const (

+ 1 - 0
constant/user_setting.go

@@ -7,6 +7,7 @@ var (
 	UserSettingWebhookSecret         = "webhook_secret"                 // WebhookSecret webhook密钥
 	UserSettingWebhookSecret         = "webhook_secret"                 // WebhookSecret webhook密钥
 	UserSettingNotificationEmail     = "notification_email"             // NotificationEmail 通知邮箱地址
 	UserSettingNotificationEmail     = "notification_email"             // NotificationEmail 通知邮箱地址
 	UserAcceptUnsetRatioModel        = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
 	UserAcceptUnsetRatioModel        = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
+	UserSettingRecordIpLog          = "record_ip_log"                   // 是否记录请求和错误日志IP
 )
 )
 
 
 var (
 var (

+ 38 - 0
controller/channel-billing.go

@@ -4,11 +4,13 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"github.com/shopspring/decimal"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"one-api/service"
 	"one-api/service"
+	"one-api/setting"
 	"strconv"
 	"strconv"
 	"time"
 	"time"
 
 
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
 	return balance, nil
 	return balance, nil
 }
 }
 
 
+func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
+	url := "https://api.moonshot.cn/v1/users/me/balance"
+	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+	if err != nil {
+		return 0, err
+	}
+
+	type MoonshotBalanceData struct {
+		AvailableBalance float64 `json:"available_balance"`
+		VoucherBalance   float64 `json:"voucher_balance"`
+		CashBalance      float64 `json:"cash_balance"`
+	}
+
+	type MoonshotBalanceResponse struct {
+		Code   int                 `json:"code"`
+		Data   MoonshotBalanceData `json:"data"`
+		Scode  string              `json:"scode"`
+		Status bool                `json:"status"`
+	}
+
+	response := MoonshotBalanceResponse{}
+	err = json.Unmarshal(body, &response)
+	if err != nil {
+		return 0, err
+	}
+	if !response.Status || response.Code != 0 {
+		return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
+	}
+	availableBalanceCny := response.Data.AvailableBalance
+	availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+	channel.UpdateBalance(availableBalanceUsd)
+	return availableBalanceUsd, nil
+}
+
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 func updateChannelBalance(channel *model.Channel) (float64, error) {
 	baseURL := common.ChannelBaseURLs[channel.Type]
 	baseURL := common.ChannelBaseURLs[channel.Type]
 	if channel.GetBaseURL() == "" {
 	if channel.GetBaseURL() == "" {
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
 		return updateChannelDeepSeekBalance(channel)
 		return updateChannelDeepSeekBalance(channel)
 	case common.ChannelTypeOpenRouter:
 	case common.ChannelTypeOpenRouter:
 		return updateChannelOpenRouterBalance(channel)
 		return updateChannelOpenRouterBalance(channel)
+	case common.ChannelTypeMoonshot:
+		return updateChannelMoonshotBalance(channel)
 	default:
 	default:
 		return 0, errors.New("尚未实现")
 		return 0, errors.New("尚未实现")
 	}
 	}

+ 16 - 8
controller/channel-test.go

@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	if channel.Type == common.ChannelTypeSunoAPI {
 	if channel.Type == common.ChannelTypeSunoAPI {
 		return errors.New("suno channel test is not supported"), nil
 		return errors.New("suno channel test is not supported"), nil
 	}
 	}
+	if channel.Type == common.ChannelTypeKling {
+		return errors.New("kling channel test is not supported"), nil
+	}
 	w := httptest.NewRecorder()
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 	c, _ := gin.CreateTestContext(w)
 
 
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 
 	info := relaycommon.GenRelayInfo(c)
 	info := relaycommon.GenRelayInfo(c)
 
 
-	err = helper.ModelMappedHelper(c, info)
+	err = helper.ModelMappedHelper(c, info, nil)
 	if err != nil {
 	if err != nil {
 		return err, nil
 		return err, nil
 	}
 	}
@@ -165,8 +168,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	tok := time.Now()
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	consumedTime := float64(milliseconds) / 1000.0
 	consumedTime := float64(milliseconds) / 1000.0
-	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
-		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
+	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
 	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
 	model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
 		quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
 		quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -200,10 +203,10 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 	} else {
 	} else {
 		testRequest.MaxTokens = 10
 		testRequest.MaxTokens = 10
 	}
 	}
-	content, _ := json.Marshal("hi")
+
 	testMessage := dto.Message{
 	testMessage := dto.Message{
 		Role:    "user",
 		Role:    "user",
-		Content: content,
+		Content: "hi",
 	}
 	}
 	testRequest.Model = model
 	testRequest.Model = model
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	testRequest.Messages = append(testRequest.Messages, testMessage)
@@ -271,6 +274,13 @@ func testAllChannels(notify bool) error {
 		disableThreshold = 10000000 // a impossible value
 		disableThreshold = 10000000 // a impossible value
 	}
 	}
 	gopool.Go(func() {
 	gopool.Go(func() {
+		// 使用 defer 确保无论如何都会重置运行状态,防止死锁
+		defer func() {
+			testAllChannelsLock.Lock()
+			testAllChannelsRunning = false
+			testAllChannelsLock.Unlock()
+		}()
+
 		for _, channel := range channels {
 		for _, channel := range channels {
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
 			tik := time.Now()
@@ -305,9 +315,7 @@ func testAllChannels(notify bool) error {
 			channel.UpdateResponseTime(milliseconds)
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(common.RequestInterval)
 			time.Sleep(common.RequestInterval)
 		}
 		}
-		testAllChannelsLock.Lock()
-		testAllChannelsRunning = false
-		testAllChannelsLock.Unlock()
+
 		if notify {
 		if notify {
 			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
 			service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
 		}
 		}

+ 97 - 16
controller/channel.go

@@ -43,22 +43,31 @@ type OpenAIModelsResponse struct {
 func GetAllChannels(c *gin.Context) {
 func GetAllChannels(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
 	pageSize, _ := strconv.Atoi(c.Query("page_size"))
 	pageSize, _ := strconv.Atoi(c.Query("page_size"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
 	}
 	}
-	if pageSize < 0 {
+	if pageSize < 1 {
 		pageSize = common.ItemsPerPage
 		pageSize = common.ItemsPerPage
 	}
 	}
 	channelData := make([]*model.Channel, 0)
 	channelData := make([]*model.Channel, 0)
 	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
 	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
 	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
 	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+	// type filter
+	typeStr := c.Query("type")
+	typeFilter := -1
+	if typeStr != "" {
+		if t, err := strconv.Atoi(typeStr); err == nil {
+			typeFilter = t
+		}
+	}
+
+	var total int64
+
 	if enableTagMode {
 	if enableTagMode {
-		tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
+		// tag 分页:先分页 tag,再取各 tag 下 channels
+		tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": err.Error(),
-			})
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
 			return
 			return
 		}
 		}
 		for _, tag := range tags {
 		for _, tag := range tags {
@@ -69,21 +78,39 @@ func GetAllChannels(c *gin.Context) {
 				}
 				}
 			}
 			}
 		}
 		}
+		// 计算 tag 总数用于分页
+		total, _ = model.CountAllTags()
+	} else if typeFilter >= 0 {
+		channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+			return
+		}
+		channelData = channels
+		total, _ = model.CountChannelsByType(typeFilter)
 	} else {
 	} else {
-		channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
+		channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
 		if err != nil {
 		if err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": err.Error(),
-			})
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
 			return
 			return
 		}
 		}
 		channelData = channels
 		channelData = channels
+		total, _ = model.CountAllChannels()
 	}
 	}
+
+	// calculate type counts
+	typeCounts, _ := model.CountChannelsGroupByType()
+
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    channelData,
+		"data": gin.H{
+			"items":       channelData,
+			"total":       total,
+			"page":        p,
+			"page_size":   pageSize,
+			"type_counts": typeCounts,
+		},
 	})
 	})
 	return
 	return
 }
 }
@@ -119,8 +146,11 @@ func FetchUpstreamModels(c *gin.Context) {
 		baseURL = channel.GetBaseURL()
 		baseURL = channel.GetBaseURL()
 	}
 	}
 	url := fmt.Sprintf("%s/v1/models", baseURL)
 	url := fmt.Sprintf("%s/v1/models", baseURL)
-	if channel.Type == common.ChannelTypeGemini {
+	switch channel.Type {
+	case common.ChannelTypeGemini:
 		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
 		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+	case common.ChannelTypeAli:
+		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	}
 	}
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
 	if err != nil {
 	if err != nil {
@@ -207,10 +237,20 @@ func SearchChannels(c *gin.Context) {
 		}
 		}
 		channelData = channels
 		channelData = channels
 	}
 	}
+
+	// calculate type counts for search results
+	typeCounts := make(map[int64]int64)
+	for _, channel := range channelData {
+		typeCounts[int64(channel.Type)]++
+	}
+
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    channelData,
+		"data": gin.H{
+			"items":       channelData,
+			"type_counts": typeCounts,
+		},
 	})
 	})
 	return
 	return
 }
 }
@@ -620,3 +660,44 @@ func BatchSetChannelTag(c *gin.Context) {
 	})
 	})
 	return
 	return
 }
 }
+
+func GetTagModels(c *gin.Context) {
+	tag := c.Query("tag")
+	if tag == "" {
+		c.JSON(http.StatusBadRequest, gin.H{
+			"success": false,
+			"message": "tag不能为空",
+		})
+		return
+	}
+
+	channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	var longestModels string
+	maxLength := 0
+
+	// Find the longest models string among all channels with the given tag
+	for _, channel := range channels {
+		if channel.Models != "" {
+			currentModels := strings.Split(channel.Models, ",")
+			if len(currentModels) > maxLength {
+				maxLength = len(currentModels)
+				longestModels = channel.Models
+			}
+		}
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    longestModels,
+	})
+	return
+}

+ 103 - 0
controller/console_migrate.go

@@ -0,0 +1,103 @@
+// 用于迁移检测的旧键,该文件下个版本会删除
+
+package controller
+
+import (
+    "encoding/json"
+    "net/http"
+    "one-api/common"
+    "one-api/model"
+    "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
+func MigrateConsoleSetting(c *gin.Context) {
+    // 读取全部 option
+    opts, err := model.AllOption()
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+        return
+    }
+    // 建立 map
+    valMap := map[string]string{}
+    for _, o := range opts {
+        valMap[o.Key] = o.Value
+    }
+
+    // 处理 APIInfo
+    if v := valMap["ApiInfo"]; v != "" {
+        var arr []map[string]interface{}
+        if err := json.Unmarshal([]byte(v), &arr); err == nil {
+            if len(arr) > 50 {
+                arr = arr[:50]
+            }
+            bytes, _ := json.Marshal(arr)
+            model.UpdateOption("console_setting.api_info", string(bytes))
+        }
+        model.UpdateOption("ApiInfo", "")
+    }
+    // Announcements 直接搬
+    if v := valMap["Announcements"]; v != "" {
+        model.UpdateOption("console_setting.announcements", v)
+        model.UpdateOption("Announcements", "")
+    }
+    // FAQ 转换
+    if v := valMap["FAQ"]; v != "" {
+        var arr []map[string]interface{}
+        if err := json.Unmarshal([]byte(v), &arr); err == nil {
+            out := []map[string]interface{}{}
+            for _, item := range arr {
+                q, _ := item["question"].(string)
+                if q == "" {
+                    q, _ = item["title"].(string)
+                }
+                a, _ := item["answer"].(string)
+                if a == "" {
+                    a, _ = item["content"].(string)
+                }
+                if q != "" && a != "" {
+                    out = append(out, map[string]interface{}{"question": q, "answer": a})
+                }
+            }
+            if len(out) > 50 {
+                out = out[:50]
+            }
+            bytes, _ := json.Marshal(out)
+            model.UpdateOption("console_setting.faq", string(bytes))
+        }
+        model.UpdateOption("FAQ", "")
+    }
+    // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+    url := valMap["UptimeKumaUrl"]
+    slug := valMap["UptimeKumaSlug"]
+    if url != "" && slug != "" {
+        // 仅当同时存在 URL 与 Slug 时才进行迁移
+        groups := []map[string]interface{}{
+            {
+                "id":           1,
+                "categoryName": "old",
+                "url":          url,
+                "slug":         slug,
+                "description":  "",
+            },
+        }
+        bytes, _ := json.Marshal(groups)
+        model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+    }
+    // 清空旧键内容
+    if url != "" {
+        model.UpdateOption("UptimeKumaUrl", "")
+    }
+    if slug != "" {
+        model.UpdateOption("UptimeKumaSlug", "")
+    }
+
+    // 删除旧键记录
+    oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+    model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+    // 重新加载 OptionMap
+    model.InitOptionMap()
+    common.SysLog("console setting migrated")
+    c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+} 

+ 11 - 3
controller/group.go

@@ -1,15 +1,17 @@
 package controller
 package controller
 
 
 import (
 import (
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
+	"one-api/setting/ratio_setting"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func GetGroups(c *gin.Context) {
 func GetGroups(c *gin.Context) {
 	groupNames := make([]string, 0)
 	groupNames := make([]string, 0)
-	for groupName, _ := range setting.GetGroupRatioCopy() {
+	for groupName := range ratio_setting.GetGroupRatioCopy() {
 		groupNames = append(groupNames, groupName)
 		groupNames = append(groupNames, groupName)
 	}
 	}
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
 	userGroup := ""
 	userGroup := ""
 	userId := c.GetInt("id")
 	userId := c.GetInt("id")
 	userGroup, _ = model.GetUserGroup(userId, false)
 	userGroup, _ = model.GetUserGroup(userId, false)
-	for groupName, ratio := range setting.GetGroupRatioCopy() {
+	for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
 		// UserUsableGroups contains the groups that the user can use
 		// UserUsableGroups contains the groups that the user can use
 		userUsableGroups := setting.GetUserUsableGroups(userGroup)
 		userUsableGroups := setting.GetUserUsableGroups(userGroup)
 		if desc, ok := userUsableGroups[groupName]; ok {
 		if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
 			}
 			}
 		}
 		}
 	}
 	}
+	if setting.GroupInUserUsableGroups("auto") {
+		usableGroups["auto"] = map[string]interface{}{
+			"ratio": "自动",
+			"desc":  setting.GetUsableGroupDescription("auto"),
+		}
+	}
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",

+ 34 - 20
controller/midjourney.go

@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
-	"log"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
@@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
 
 
 func GetAllMidjourney(c *gin.Context) {
 func GetAllMidjourney(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
+	}
+	pageSize, _ := strconv.Atoi(c.Query("page_size"))
+	if pageSize <= 0 {
+		pageSize = common.ItemsPerPage
 	}
 	}
 
 
 	// 解析其他查询参数
 	// 解析其他查询参数
@@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) {
 		EndTimestamp:   c.Query("end_timestamp"),
 		EndTimestamp:   c.Query("end_timestamp"),
 	}
 	}
 
 
-	logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
-	if logs == nil {
-		logs = make([]*model.Midjourney, 0)
-	}
+	items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
+	total := model.CountAllTasks(queryParams)
+
 	if setting.MjForwardUrlEnabled {
 	if setting.MjForwardUrlEnabled {
-		for i, midjourney := range logs {
+		for i, midjourney := range items {
 			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
 			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
-			logs[i] = midjourney
+			items[i] = midjourney
 		}
 		}
 	}
 	}
 	c.JSON(200, gin.H{
 	c.JSON(200, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    logs,
+		"data": gin.H{
+			"items":     items,
+			"total":     total,
+			"page":      p,
+			"page_size": pageSize,
+		},
 	})
 	})
 }
 }
 
 
 func GetUserMidjourney(c *gin.Context) {
 func GetUserMidjourney(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
+	}
+	pageSize, _ := strconv.Atoi(c.Query("page_size"))
+	if pageSize <= 0 {
+		pageSize = common.ItemsPerPage
 	}
 	}
 
 
 	userId := c.GetInt("id")
 	userId := c.GetInt("id")
-	log.Printf("userId = %d \n", userId)
 
 
 	queryParams := model.TaskQueryParams{
 	queryParams := model.TaskQueryParams{
 		MjID:           c.Query("mj_id"),
 		MjID:           c.Query("mj_id"),
@@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) {
 		EndTimestamp:   c.Query("end_timestamp"),
 		EndTimestamp:   c.Query("end_timestamp"),
 	}
 	}
 
 
-	logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
-	if logs == nil {
-		logs = make([]*model.Midjourney, 0)
-	}
+	items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+	total := model.CountAllUserTask(userId, queryParams)
+
 	if setting.MjForwardUrlEnabled {
 	if setting.MjForwardUrlEnabled {
-		for i, midjourney := range logs {
+		for i, midjourney := range items {
 			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
 			midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
-			logs[i] = midjourney
+			items[i] = midjourney
 		}
 		}
 	}
 	}
 	c.JSON(200, gin.H{
 	c.JSON(200, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    logs,
+		"data": gin.H{
+			"items":     items,
+			"total":     total,
+			"page":      p,
+			"page_size": pageSize,
+		},
 	})
 	})
 }
 }

+ 72 - 42
controller/misc.go

@@ -6,8 +6,10 @@ import (
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
+	"one-api/setting/console_setting"
 	"one-api/setting/operation_setting"
 	"one-api/setting/operation_setting"
 	"one-api/setting/system_setting"
 	"one-api/setting/system_setting"
 	"strings"
 	"strings"
@@ -24,57 +26,85 @@ func TestStatus(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
+	// 获取HTTP统计信息
+	httpStats := middleware.GetStats()
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
-		"success": true,
-		"message": "Server is running",
+		"success":    true,
+		"message":    "Server is running",
+		"http_stats": httpStats,
 	})
 	})
 	return
 	return
 }
 }
 
 
 func GetStatus(c *gin.Context) {
 func GetStatus(c *gin.Context) {
+
+	cs := console_setting.GetConsoleSetting()
+
+	data := gin.H{
+		"version":                  common.Version,
+		"start_time":               common.StartTime,
+		"email_verification":       common.EmailVerificationEnabled,
+		"github_oauth":             common.GitHubOAuthEnabled,
+		"github_client_id":         common.GitHubClientId,
+		"linuxdo_oauth":            common.LinuxDOOAuthEnabled,
+		"linuxdo_client_id":        common.LinuxDOClientId,
+		"telegram_oauth":           common.TelegramOAuthEnabled,
+		"telegram_bot_name":        common.TelegramBotName,
+		"system_name":              common.SystemName,
+		"logo":                     common.Logo,
+		"footer_html":              common.Footer,
+		"wechat_qrcode":            common.WeChatAccountQRCodeImageURL,
+		"wechat_login":             common.WeChatAuthEnabled,
+		"server_address":           setting.ServerAddress,
+		"price":                    setting.Price,
+		"min_topup":                setting.MinTopUp,
+		"turnstile_check":          common.TurnstileCheckEnabled,
+		"turnstile_site_key":       common.TurnstileSiteKey,
+		"top_up_link":              common.TopUpLink,
+		"docs_link":                operation_setting.GetGeneralSetting().DocsLink,
+		"quota_per_unit":           common.QuotaPerUnit,
+		"display_in_currency":      common.DisplayInCurrencyEnabled,
+		"enable_batch_update":      common.BatchUpdateEnabled,
+		"enable_drawing":           common.DrawingEnabled,
+		"enable_task":              common.TaskEnabled,
+		"enable_data_export":       common.DataExportEnabled,
+		"data_export_default_time": common.DataExportDefaultTime,
+		"default_collapse_sidebar": common.DefaultCollapseSidebar,
+		"enable_online_topup":      setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+		"mj_notify_enabled":        setting.MjNotifyEnabled,
+		"chats":                    setting.Chats,
+		"demo_site_enabled":        operation_setting.DemoSiteEnabled,
+		"self_use_mode_enabled":    operation_setting.SelfUseModeEnabled,
+		"default_use_auto_group":   setting.DefaultUseAutoGroup,
+		"pay_methods":              setting.PayMethods,
+
+		// 面板启用开关
+		"api_info_enabled":      cs.ApiInfoEnabled,
+		"uptime_kuma_enabled":   cs.UptimeKumaEnabled,
+		"announcements_enabled": cs.AnnouncementsEnabled,
+		"faq_enabled":           cs.FAQEnabled,
+
+		"oidc_enabled":                system_setting.GetOIDCSettings().Enabled,
+		"oidc_client_id":              system_setting.GetOIDCSettings().ClientId,
+		"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+		"setup":                       constant.Setup,
+	}
+
+	// 根据启用状态注入可选内容
+	if cs.ApiInfoEnabled {
+		data["api_info"] = console_setting.GetApiInfo()
+	}
+	if cs.AnnouncementsEnabled {
+		data["announcements"] = console_setting.GetAnnouncements()
+	}
+	if cs.FAQEnabled {
+		data["faq"] = console_setting.GetFAQ()
+	}
+
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data": gin.H{
-			"version":                     common.Version,
-			"start_time":                  common.StartTime,
-			"email_verification":          common.EmailVerificationEnabled,
-			"github_oauth":                common.GitHubOAuthEnabled,
-			"github_client_id":            common.GitHubClientId,
-			"linuxdo_oauth":               common.LinuxDOOAuthEnabled,
-			"linuxdo_client_id":           common.LinuxDOClientId,
-			"telegram_oauth":              common.TelegramOAuthEnabled,
-			"telegram_bot_name":           common.TelegramBotName,
-			"system_name":                 common.SystemName,
-			"logo":                        common.Logo,
-			"footer_html":                 common.Footer,
-			"wechat_qrcode":               common.WeChatAccountQRCodeImageURL,
-			"wechat_login":                common.WeChatAuthEnabled,
-			"server_address":              setting.ServerAddress,
-			"price":                       setting.Price,
-			"min_topup":                   setting.MinTopUp,
-			"turnstile_check":             common.TurnstileCheckEnabled,
-			"turnstile_site_key":          common.TurnstileSiteKey,
-			"top_up_link":                 common.TopUpLink,
-			"docs_link":                   operation_setting.GetGeneralSetting().DocsLink,
-			"quota_per_unit":              common.QuotaPerUnit,
-			"display_in_currency":         common.DisplayInCurrencyEnabled,
-			"enable_batch_update":         common.BatchUpdateEnabled,
-			"enable_drawing":              common.DrawingEnabled,
-			"enable_task":                 common.TaskEnabled,
-			"enable_data_export":          common.DataExportEnabled,
-			"data_export_default_time":    common.DataExportDefaultTime,
-			"default_collapse_sidebar":    common.DefaultCollapseSidebar,
-			"enable_online_topup":         setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
-			"mj_notify_enabled":           setting.MjNotifyEnabled,
-			"chats":                       setting.Chats,
-			"demo_site_enabled":           operation_setting.DemoSiteEnabled,
-			"self_use_mode_enabled":       operation_setting.SelfUseModeEnabled,
-			"oidc_enabled":                system_setting.GetOIDCSettings().Enabled,
-			"oidc_client_id":              system_setting.GetOIDCSettings().ClientId,
-			"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
-			"setup":                       constant.Setup,
-		},
+		"data":    data,
 	})
 	})
 	return
 	return
 }
 }

+ 20 - 2
controller/model.go

@@ -2,7 +2,7 @@ package controller
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
@@ -15,6 +15,9 @@ import (
 	"one-api/relay/channel/moonshot"
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
+	"one-api/setting"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 // https://platform.openai.com/docs/api-reference/models/list
 // https://platform.openai.com/docs/api-reference/models/list
@@ -134,6 +137,9 @@ func init() {
 		adaptor.Init(meta)
 		adaptor.Init(meta)
 		channelId2Models[i] = adaptor.GetModelList()
 		channelId2Models[i] = adaptor.GetModelList()
 	}
 	}
+	openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
+		return m.Id
+	})
 }
 }
 
 
 func ListModels(c *gin.Context) {
 func ListModels(c *gin.Context) {
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
 		if tokenGroup != "" {
 		if tokenGroup != "" {
 			group = tokenGroup
 			group = tokenGroup
 		}
 		}
-		models := model.GetGroupModels(group)
+		var models []string
+		if tokenGroup == "auto" {
+			for _, autoGroup := range setting.AutoGroups {
+				groupModels := model.GetGroupModels(autoGroup)
+				for _, g := range groupModels {
+					if !common.StringsContains(models, g) {
+						models = append(models, g)
+					}
+				}
+			}
+		} else {
+			models = model.GetGroupModels(group)
+		}
 		for _, s := range models {
 		for _, s := range models {
 			if _, ok := openAIModelsMap[s]; ok {
 			if _, ok := openAIModelsMap[s]; ok {
 				userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
 				userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])

+ 39 - 2
controller/option.go

@@ -6,6 +6,8 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
+	"one-api/setting/console_setting"
+	"one-api/setting/ratio_setting"
 	"one-api/setting/system_setting"
 	"one-api/setting/system_setting"
 	"strings"
 	"strings"
 
 
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
 			return
 			return
 		}
 		}
 	case "GroupRatio":
 	case "GroupRatio":
-		err = setting.CheckGroupRatio(option.Value)
+		err = ratio_setting.CheckGroupRatio(option.Value)
 		if err != nil {
 		if err != nil {
 			c.JSON(http.StatusOK, gin.H{
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
 				"success": false,
@@ -119,7 +121,42 @@ func UpdateOption(c *gin.Context) {
 			})
 			})
 			return
 			return
 		}
 		}
-
+	case "console_setting.api_info":
+		err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.announcements":
+		err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.faq":
+		err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
+	case "console_setting.uptime_kuma_groups":
+		err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
 	}
 	}
 	err = model.UpdateOption(option.Key, option.Value)
 	err = model.UpdateOption(option.Key, option.Value)
 	if err != nil {
 	if err != nil {

+ 13 - 3
controller/playground.go

@@ -3,7 +3,6 @@ package controller
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
@@ -13,6 +12,8 @@ import (
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting"
 	"time"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func Playground(c *gin.Context) {
 func Playground(c *gin.Context) {
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
 		c.Set("group", group)
 		c.Set("group", group)
 	}
 	}
 	c.Set("token_name", "playground-"+group)
 	c.Set("token_name", "playground-"+group)
-	channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
+	channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
 	if err != nil {
 	if err != nil {
-		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
+		message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
 		openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
 		openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
 		return
 		return
 	}
 	}
 	middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
 	middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
 	c.Set(constant.ContextKeyRequestStartTime, time.Now())
 	c.Set(constant.ContextKeyRequestStartTime, time.Now())
+
+	// Write user context to ensure acceptUnsetRatio is available
+	userId := c.GetInt("id")
+	userCache, err := model.GetUserCache(userId)
+	if err != nil {
+		openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
+		return
+	}
+	userCache.WriteContext(c)
 	Relay(c)
 	Relay(c)
 }
 }

+ 13 - 6
controller/pricing.go

@@ -1,10 +1,11 @@
 package controller
 package controller
 
 
 import (
 import (
-	"github.com/gin-gonic/gin"
 	"one-api/model"
 	"one-api/model"
 	"one-api/setting"
 	"one-api/setting"
-	"one-api/setting/operation_setting"
+	"one-api/setting/ratio_setting"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func GetPricing(c *gin.Context) {
 func GetPricing(c *gin.Context) {
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
 	userId, exists := c.Get("id")
 	userId, exists := c.Get("id")
 	usableGroup := map[string]string{}
 	usableGroup := map[string]string{}
 	groupRatio := map[string]float64{}
 	groupRatio := map[string]float64{}
-	for s, f := range setting.GetGroupRatioCopy() {
+	for s, f := range ratio_setting.GetGroupRatioCopy() {
 		groupRatio[s] = f
 		groupRatio[s] = f
 	}
 	}
 	var group string
 	var group string
@@ -20,12 +21,18 @@ func GetPricing(c *gin.Context) {
 		user, err := model.GetUserCache(userId.(int))
 		user, err := model.GetUserCache(userId.(int))
 		if err == nil {
 		if err == nil {
 			group = user.Group
 			group = user.Group
+			for g := range groupRatio {
+				ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
+				if ok {
+					groupRatio[g] = ratio
+				}
+			}
 		}
 		}
 	}
 	}
 
 
 	usableGroup = setting.GetUserUsableGroups(group)
 	usableGroup = setting.GetUserUsableGroups(group)
 	// check groupRatio contains usableGroup
 	// check groupRatio contains usableGroup
-	for group := range setting.GetGroupRatioCopy() {
+	for group := range ratio_setting.GetGroupRatioCopy() {
 		if _, ok := usableGroup[group]; !ok {
 		if _, ok := usableGroup[group]; !ok {
 			delete(groupRatio, group)
 			delete(groupRatio, group)
 		}
 		}
@@ -40,7 +47,7 @@ func GetPricing(c *gin.Context) {
 }
 }
 
 
 func ResetModelRatio(c *gin.Context) {
 func ResetModelRatio(c *gin.Context) {
-	defaultStr := operation_setting.DefaultModelRatio2JSONString()
+	defaultStr := ratio_setting.DefaultModelRatio2JSONString()
 	err := model.UpdateOption("ModelRatio", defaultStr)
 	err := model.UpdateOption("ModelRatio", defaultStr)
 	if err != nil {
 	if err != nil {
 		c.JSON(200, gin.H{
 		c.JSON(200, gin.H{
@@ -49,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
-	err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
+	err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
 	if err != nil {
 	if err != nil {
 		c.JSON(200, gin.H{
 		c.JSON(200, gin.H{
 			"success": false,
 			"success": false,

+ 24 - 0
controller/ratio_config.go

@@ -0,0 +1,24 @@
+package controller
+
+import (
+    "net/http"
+    "one-api/setting/ratio_setting"
+
+    "github.com/gin-gonic/gin"
+)
+
+func GetRatioConfig(c *gin.Context) {
+    if !ratio_setting.IsExposeRatioEnabled() {
+        c.JSON(http.StatusForbidden, gin.H{
+            "success": false,
+            "message": "倍率配置接口未启用",
+        })
+        return
+    }
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "message": "",
+        "data":    ratio_setting.GetExposedData(),
+    })
+} 

+ 322 - 0
controller/ratio_sync.go

@@ -0,0 +1,322 @@
+package controller
+
+import (
+    "context"
+    "encoding/json"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
+    "one-api/common"
+    "one-api/dto"
+    "one-api/model"
+    "one-api/setting/ratio_setting"
+
+    "github.com/gin-gonic/gin"
+)
+
+const (
+    defaultTimeoutSeconds  = 10
+    defaultEndpoint        = "/api/ratio_config"
+    maxConcurrentFetches   = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
+type upstreamResult struct {
+    Name string                 `json:"name"`
+    Data map[string]any         `json:"data,omitempty"`
+    Err  string                 `json:"err,omitempty"`
+}
+
+func FetchUpstreamRatios(c *gin.Context) {
+    var req dto.UpstreamRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+        return
+    }
+
+    if req.Timeout <= 0 {
+        req.Timeout = defaultTimeoutSeconds
+    }
+
+    var upstreams []dto.UpstreamDTO
+
+    if len(req.ChannelIDs) > 0 {
+        intIds := make([]int, 0, len(req.ChannelIDs))
+        for _, id64 := range req.ChannelIDs {
+            intIds = append(intIds, int(id64))
+        }
+        dbChannels, err := model.GetChannelsByIds(intIds)
+        if err != nil {
+            common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+            c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+            return
+        }
+        for _, ch := range dbChannels {
+            if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+                upstreams = append(upstreams, dto.UpstreamDTO{
+                    Name:     ch.Name,
+                    BaseURL:  strings.TrimRight(base, "/"),
+                    Endpoint: "",
+                })
+            }
+        }
+    }
+
+    if len(upstreams) == 0 {
+        c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+        return
+    }
+
+    var wg sync.WaitGroup
+    ch := make(chan upstreamResult, len(upstreams))
+
+    sem := make(chan struct{}, maxConcurrentFetches)
+
+    client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+    for _, chn := range upstreams {
+        wg.Add(1)
+        go func(chItem dto.UpstreamDTO) {
+            defer wg.Done()
+
+            sem <- struct{}{}
+            defer func() { <-sem }()
+
+            endpoint := chItem.Endpoint
+            if endpoint == "" {
+                endpoint = defaultEndpoint
+            } else if !strings.HasPrefix(endpoint, "/") {
+                endpoint = "/" + endpoint
+            }
+            fullURL := chItem.BaseURL + endpoint
+
+            ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+            defer cancel()
+
+            httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+            if err != nil {
+                common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+                ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+                return
+            }
+
+            resp, err := client.Do(httpReq)
+            if err != nil {
+                common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+                ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+                return
+            }
+            defer resp.Body.Close()
+            if resp.StatusCode != http.StatusOK {
+                common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+                ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
+                return
+            }
+            var body struct {
+                Success bool                   `json:"success"`
+                Data    map[string]any         `json:"data"`
+                Message string                 `json:"message"`
+            }
+            if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+                common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+                ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+                return
+            }
+            if !body.Success {
+                ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
+                return
+            }
+            ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
+        }(chn)
+    }
+
+    wg.Wait()
+    close(ch)
+
+    localData := ratio_setting.GetExposedData()
+
+    var testResults []dto.TestResult
+    var successfulChannels []struct {
+        name string
+        data map[string]any
+    }
+
+    for r := range ch {
+        if r.Err != "" {
+            testResults = append(testResults, dto.TestResult{
+                Name:   r.Name,
+                Status: "error",
+                Error:  r.Err,
+            })
+        } else {
+            testResults = append(testResults, dto.TestResult{
+                Name:   r.Name,
+                Status: "success",
+            })
+            successfulChannels = append(successfulChannels, struct {
+                name string
+                data map[string]any
+            }{name: r.Name, data: r.Data})
+        }
+    }
+
+    differences := buildDifferences(localData, successfulChannels)
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "data": gin.H{
+            "differences":  differences,
+            "test_results": testResults,
+        },
+    })
+}
+
+func buildDifferences(localData map[string]any, successfulChannels []struct {
+    name string
+    data map[string]any
+}) map[string]map[string]dto.DifferenceItem {
+    differences := make(map[string]map[string]dto.DifferenceItem)
+
+    allModels := make(map[string]struct{})
+    
+    for _, ratioType := range ratioTypes {
+        if localRatioAny, ok := localData[ratioType]; ok {
+            if localRatio, ok := localRatioAny.(map[string]float64); ok {
+                for modelName := range localRatio {
+                    allModels[modelName] = struct{}{}
+                }
+            }
+        }
+    }
+    
+    for _, channel := range successfulChannels {
+        for _, ratioType := range ratioTypes {
+            if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+                for modelName := range upstreamRatio {
+                    allModels[modelName] = struct{}{}
+                }
+            }
+        }
+    }
+
+    for modelName := range allModels {
+        for _, ratioType := range ratioTypes {
+            var localValue interface{} = nil
+            if localRatioAny, ok := localData[ratioType]; ok {
+                if localRatio, ok := localRatioAny.(map[string]float64); ok {
+                    if val, exists := localRatio[modelName]; exists {
+                        localValue = val
+                    }
+                }
+            }
+
+            upstreamValues := make(map[string]interface{})
+            hasUpstreamValue := false
+            hasDifference := false
+
+            for _, channel := range successfulChannels {
+                var upstreamValue interface{} = nil
+                
+                if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+                    if val, exists := upstreamRatio[modelName]; exists {
+                        upstreamValue = val
+                        hasUpstreamValue = true
+                        
+                        if localValue != nil && localValue != val {
+                            hasDifference = true
+                        } else if localValue == val {
+                            upstreamValue = "same"
+                        }
+                    }
+                }
+                if upstreamValue == nil && localValue == nil {
+                    upstreamValue = "same"
+                }
+                
+                if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+                    hasDifference = true
+                }
+                
+                upstreamValues[channel.name] = upstreamValue
+            }
+
+            shouldInclude := false
+            
+            if localValue != nil {
+                if hasDifference {
+                    shouldInclude = true
+                }
+            } else {
+                if hasUpstreamValue {
+                    shouldInclude = true
+                }
+            }
+
+            if shouldInclude {
+                if differences[modelName] == nil {
+                    differences[modelName] = make(map[string]dto.DifferenceItem)
+                }
+                differences[modelName][ratioType] = dto.DifferenceItem{
+                    Current:   localValue,
+                    Upstreams: upstreamValues,
+                }
+            }
+        }
+    }
+
+    channelHasDiff := make(map[string]bool)
+    for _, ratioMap := range differences {
+        for _, item := range ratioMap {
+            for chName, val := range item.Upstreams {
+                if val != nil && val != "same" {
+                    channelHasDiff[chName] = true
+                }
+            }
+        }
+    }
+
+    for modelName, ratioMap := range differences {
+        for ratioType, item := range ratioMap {
+            for chName := range item.Upstreams {
+                if !channelHasDiff[chName] {
+                    delete(item.Upstreams, chName)
+                }
+            }
+            differences[modelName][ratioType] = item
+        }
+    }
+
+    return differences
+}
+
+func GetSyncableChannels(c *gin.Context) {
+    channels, err := model.GetAllChannels(0, 0, true, false)
+    if err != nil {
+        c.JSON(http.StatusOK, gin.H{
+            "success": false,
+            "message": err.Error(),
+        })
+        return
+    }
+
+    var syncableChannels []dto.SyncableChannel
+    for _, channel := range channels {
+        if channel.GetBaseURL() != "" {
+            syncableChannels = append(syncableChannels, dto.SyncableChannel{
+                ID:      channel.Id,
+                Name:    channel.Name,
+                BaseURL: channel.GetBaseURL(),
+                Status:  channel.Status,
+            })
+        }
+    }
+
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "message": "",
+        "data":    syncableChannels,
+    })
+} 

+ 39 - 3
controller/redemption.go

@@ -5,6 +5,7 @@ import (
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
+	"errors"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
@@ -126,6 +127,10 @@ func AddRedemption(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
+	if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+		return
+	}
 	var keys []string
 	var keys []string
 	for i := 0; i < redemption.Count; i++ {
 	for i := 0; i < redemption.Count; i++ {
 		key := common.GetUUID()
 		key := common.GetUUID()
@@ -135,6 +140,7 @@ func AddRedemption(c *gin.Context) {
 			Key:         key,
 			Key:         key,
 			CreatedTime: common.GetTimestamp(),
 			CreatedTime: common.GetTimestamp(),
 			Quota:       redemption.Quota,
 			Quota:       redemption.Quota,
+			ExpiredTime: redemption.ExpiredTime,
 		}
 		}
 		err = cleanRedemption.Insert()
 		err = cleanRedemption.Insert()
 		if err != nil {
 		if err != nil {
@@ -191,12 +197,18 @@ func UpdateRedemption(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
-	if statusOnly != "" {
-		cleanRedemption.Status = redemption.Status
-	} else {
+	if statusOnly == "" {
+		if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+			return
+		}
 		// If you add more fields, please also update redemption.Update()
 		// If you add more fields, please also update redemption.Update()
 		cleanRedemption.Name = redemption.Name
 		cleanRedemption.Name = redemption.Name
 		cleanRedemption.Quota = redemption.Quota
 		cleanRedemption.Quota = redemption.Quota
+		cleanRedemption.ExpiredTime = redemption.ExpiredTime
+	}
+	if statusOnly != "" {
+		cleanRedemption.Status = redemption.Status
 	}
 	}
 	err = cleanRedemption.Update()
 	err = cleanRedemption.Update()
 	if err != nil {
 	if err != nil {
@@ -213,3 +225,27 @@ func UpdateRedemption(c *gin.Context) {
 	})
 	})
 	return
 	return
 }
 }
+
+func DeleteInvalidRedemption(c *gin.Context) {
+	rows, err := model.DeleteInvalidRedemptions()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data": rows,
+	})
+	return
+}
+
+func validateExpiredTime(expired int64) error {
+	if expired != 0 && expired < common.GetTimestamp() {
+		return errors.New("过期时间不能早于当前时间")
+	}
+	return nil
+}

+ 5 - 3
controller/relay.go

@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		err = relay.EmbeddingHelper(c)
 		err = relay.EmbeddingHelper(c)
 	case relayconstant.RelayModeResponses:
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
 		err = relay.ResponsesHelper(c)
+	case relayconstant.RelayModeGemini:
+		err = relay.GeminiHelper(c)
 	default:
 	default:
 		err = relay.TextHelper(c)
 		err = relay.TextHelper(c)
 	}
 	}
@@ -257,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
 			AutoBan: &autoBanInt,
 			AutoBan: &autoBanInt,
 		}, nil
 		}, nil
 	}
 	}
-	channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+	channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
 	if err != nil {
 	if err != nil {
 		return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
 		return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
 	}
 	}
@@ -386,7 +388,7 @@ func RelayTask(c *gin.Context) {
 		retryTimes = 0
 		retryTimes = 0
 	}
 	}
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
 	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
-		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+		channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
 		if err != nil {
 		if err != nil {
 			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
 			common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
 			break
 			break
@@ -418,7 +420,7 @@ func RelayTask(c *gin.Context) {
 func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
 func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
 	var err *dto.TaskError
 	var err *dto.TaskError
 	switch relayMode {
 	switch relayMode {
-	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
 		err = relay.RelayTaskFetch(c, relayMode)
 		err = relay.RelayTaskFetch(c, relayMode)
 	default:
 	default:
 		err = relay.RelayTaskSubmit(c, relayMode)
 		err = relay.RelayTaskSubmit(c, relayMode)

+ 8 - 0
controller/setup.go

@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
 
 
 	// If root doesn't exist, validate and create admin account
 	// If root doesn't exist, validate and create admin account
 	if !rootExists {
 	if !rootExists {
+		// Validate username length: max 12 characters to align with model.User validation
+		if len(req.Username) > 12 {
+			c.JSON(400, gin.H{
+				"success": false,
+				"message": "用户名长度不能超过12个字符",
+			})
+			return
+		}
 		// Validate password
 		// Validate password
 		if req.Password != req.ConfirmPassword {
 		if req.Password != req.ConfirmPassword {
 			c.JSON(400, gin.H{
 			c.JSON(400, gin.H{

+ 32 - 14
controller/task.go

@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
 		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
 		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
 	case constant.TaskPlatformSuno:
 	case constant.TaskPlatformSuno:
 		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
 		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+	case constant.TaskPlatformKling:
+		_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
 	default:
 	default:
 		common.SysLog("未知平台")
 		common.SysLog("未知平台")
 	}
 	}
@@ -224,9 +226,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
 
 
 func GetAllTask(c *gin.Context) {
 func GetAllTask(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
 	}
 	}
+	pageSize, _ := strconv.Atoi(c.Query("page_size"))
+	if pageSize <= 0 {
+		pageSize = common.ItemsPerPage
+	}
+
 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	// 解析其他查询参数
 	// 解析其他查询参数
@@ -237,24 +244,32 @@ func GetAllTask(c *gin.Context) {
 		Action:         c.Query("action"),
 		Action:         c.Query("action"),
 		StartTimestamp: startTimestamp,
 		StartTimestamp: startTimestamp,
 		EndTimestamp:   endTimestamp,
 		EndTimestamp:   endTimestamp,
+		ChannelID:      c.Query("channel_id"),
 	}
 	}
 
 
-	logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
-	if logs == nil {
-		logs = make([]*model.Task, 0)
-	}
+	items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
+	total := model.TaskCountAllTasks(queryParams)
 
 
 	c.JSON(200, gin.H{
 	c.JSON(200, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    logs,
+		"data": gin.H{
+			"items":     items,
+			"total":     total,
+			"page":      p,
+			"page_size": pageSize,
+		},
 	})
 	})
 }
 }
 
 
 func GetUserTask(c *gin.Context) {
 func GetUserTask(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
+	}
+	pageSize, _ := strconv.Atoi(c.Query("page_size"))
+	if pageSize <= 0 {
+		pageSize = common.ItemsPerPage
 	}
 	}
 
 
 	userId := c.GetInt("id")
 	userId := c.GetInt("id")
@@ -271,14 +286,17 @@ func GetUserTask(c *gin.Context) {
 		EndTimestamp:   endTimestamp,
 		EndTimestamp:   endTimestamp,
 	}
 	}
 
 
-	logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
-	if logs == nil {
-		logs = make([]*model.Task, 0)
-	}
+	items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
+	total := model.TaskCountAllUserTask(userId, queryParams)
 
 
 	c.JSON(200, gin.H{
 	c.JSON(200, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    logs,
+		"data": gin.H{
+			"items":     items,
+			"total":     total,
+			"page":      p,
+			"page_size": pageSize,
+		},
 	})
 	})
 }
 }

+ 140 - 0
controller/task_video.go

@@ -0,0 +1,140 @@
+package controller
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"one-api/relay"
+	"one-api/relay/channel"
+)
+
+func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
+			common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	cacheGetChannel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if errUpdate != nil {
+			common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+		}
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
+	if adaptor == nil {
+		return fmt.Errorf("video adaptor not found")
+	}
+	for _, taskId := range taskIds {
+		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+			common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+	baseURL := common.ChannelBaseURLs[channel.Type]
+	if channel.GetBaseURL() != "" {
+		baseURL = channel.GetBaseURL()
+	}
+	resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+		"task_id": taskId,
+	})
+	if err != nil {
+		return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
+	}
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
+	}
+
+	var responseItem map[string]interface{}
+	err = json.Unmarshal(responseBody, &responseItem)
+	if err != nil {
+		common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
+		return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
+	}
+
+	code, _ := responseItem["code"].(float64)
+	if code != 0 {
+		return fmt.Errorf("video task fetch failed for task %s", taskId)
+	}
+
+	data, ok := responseItem["data"].(map[string]interface{})
+	if !ok {
+		common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
+		return fmt.Errorf("video task data format error for task %s", taskId)
+	}
+
+	task := taskM[taskId]
+	if task == nil {
+		common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
+
+	if status, ok := data["task_status"].(string); ok {
+		switch status {
+		case "submitted", "queued":
+			task.Status = model.TaskStatusSubmitted
+		case "processing":
+			task.Status = model.TaskStatusInProgress
+		case "succeed":
+			task.Status = model.TaskStatusSuccess
+			task.Progress = "100%"
+			if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
+				task.FailReason = url
+			} else {
+				common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
+			}
+		case "failed":
+			task.Status = model.TaskStatusFailure
+			task.Progress = "100%"
+			if reason, ok := data["fail_reason"].(string); ok {
+				task.FailReason = reason
+			}
+		}
+	}
+
+	// If task failed, refund quota
+	if task.Status == model.TaskStatusFailure {
+		common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		quota := task.Quota
+		if quota != 0 {
+			if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+				common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+			}
+			logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+			model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+		}
+	}
+
+	task.Data = responseBody
+	if err := task.Update(); err != nil {
+		common.SysError("UpdateVideoTask task error: " + err.Error())
+	}
+
+	return nil
+}

+ 12 - 4
controller/token.go

@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
 	userId := c.GetInt("id")
 	userId := c.GetInt("id")
 	p, _ := strconv.Atoi(c.Query("p"))
 	p, _ := strconv.Atoi(c.Query("p"))
 	size, _ := strconv.Atoi(c.Query("size"))
 	size, _ := strconv.Atoi(c.Query("size"))
-	if p < 0 {
-		p = 0
+	if p < 1 {
+		p = 1
 	}
 	}
 	if size <= 0 {
 	if size <= 0 {
 		size = common.ItemsPerPage
 		size = common.ItemsPerPage
 	} else if size > 100 {
 	} else if size > 100 {
 		size = 100
 		size = 100
 	}
 	}
-	tokens, err := model.GetAllUserTokens(userId, p*size, size)
+	tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
 	if err != nil {
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"success": false,
@@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
+	// Get total count for pagination
+	total, _ := model.CountUserTokens(userId)
+
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
-		"data":    tokens,
+		"data": gin.H{
+			"items":     tokens,
+			"total":     total,
+			"page":      p,
+			"page_size": size,
+		},
 	})
 	})
 	return
 	return
 }
 }

+ 7 - 9
controller/topup.go

@@ -97,16 +97,14 @@ func RequestEpay(c *gin.Context) {
 		c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
 		c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
 		return
 		return
 	}
 	}
-	payType := "wxpay"
-	if req.PaymentMethod == "zfb" {
-		payType = "alipay"
-	}
-	if req.PaymentMethod == "wx" {
-		req.PaymentMethod = "wxpay"
-		payType = "wxpay"
+
+	if !setting.ContainsPayMethod(req.PaymentMethod) {
+		c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
+		return
 	}
 	}
+
 	callBackAddress := service.GetCallbackAddress()
 	callBackAddress := service.GetCallbackAddress()
-	returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
+	returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
 	notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
 	notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
 	tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
 	tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
 	tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
 	tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
 		return
 		return
 	}
 	}
 	uri, params, err := client.Purchase(&epay.PurchaseArgs{
 	uri, params, err := client.Purchase(&epay.PurchaseArgs{
-		Type:           payType,
+		Type:           req.PaymentMethod,
 		ServiceTradeNo: tradeNo,
 		ServiceTradeNo: tradeNo,
 		Name:           fmt.Sprintf("TUC%d", req.Amount),
 		Name:           fmt.Sprintf("TUC%d", req.Amount),
 		Money:          strconv.FormatFloat(payMoney, 'f', 2, 64),
 		Money:          strconv.FormatFloat(payMoney, 'f', 2, 64),

+ 154 - 0
controller/uptime_kuma.go

@@ -0,0 +1,154 @@
+package controller
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"one-api/setting/console_setting"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"golang.org/x/sync/errgroup"
+)
+
+const (
+	requestTimeout   = 30 * time.Second
+	httpTimeout      = 10 * time.Second
+	uptimeKeySuffix  = "_24"
+	apiStatusPath    = "/api/status-page/"
+	apiHeartbeatPath = "/api/status-page/heartbeat/"
+)
+
+type Monitor struct {
+	Name   string  `json:"name"`
+	Uptime float64 `json:"uptime"`
+	Status int     `json:"status"`
+	Group  string  `json:"group,omitempty"`
+}
+
+type UptimeGroupResult struct {
+	CategoryName string    `json:"categoryName"`
+	Monitors  []Monitor `json:"monitors"`
+}
+
+func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return err
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return errors.New("non-200 status")
+	}
+
+	return json.NewDecoder(resp.Body).Decode(dest)
+}
+
+func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
+	url, _ := groupConfig["url"].(string)
+	slug, _ := groupConfig["slug"].(string)
+	categoryName, _ := groupConfig["categoryName"].(string)
+	
+	result := UptimeGroupResult{
+		CategoryName: categoryName,
+		Monitors:  []Monitor{},
+	}
+	
+	if url == "" || slug == "" {
+		return result
+	}
+
+	baseURL := strings.TrimSuffix(url, "/")
+	
+	var statusData struct {
+		PublicGroupList []struct {
+			ID   int    `json:"id"`
+			Name string `json:"name"`
+			MonitorList []struct {
+				ID   int    `json:"id"`
+				Name string `json:"name"`
+			} `json:"monitorList"`
+		} `json:"publicGroupList"`
+	}
+	
+	var heartbeatData struct {
+		HeartbeatList map[string][]struct {
+			Status int `json:"status"`
+		} `json:"heartbeatList"`
+		UptimeList map[string]float64 `json:"uptimeList"`
+	}
+
+	g, gCtx := errgroup.WithContext(ctx)
+	g.Go(func() error { 
+		return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData) 
+	})
+	g.Go(func() error { 
+		return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData) 
+	})
+
+	if g.Wait() != nil {
+		return result
+	}
+
+	for _, pg := range statusData.PublicGroupList {
+		if len(pg.MonitorList) == 0 {
+			continue
+		}
+
+		for _, m := range pg.MonitorList {
+			monitor := Monitor{
+				Name:  m.Name,
+				Group: pg.Name,
+			}
+
+			monitorID := strconv.Itoa(m.ID)
+
+			if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
+				monitor.Uptime = uptime
+			}
+
+			if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
+				monitor.Status = heartbeats[0].Status
+			}
+
+			result.Monitors = append(result.Monitors, monitor)
+		}
+	}
+
+	return result
+}
+
+func GetUptimeKumaStatus(c *gin.Context) {
+	groups := console_setting.GetUptimeKumaGroups()
+	if len(groups) == 0 {
+		c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
+	defer cancel()
+
+	client := &http.Client{Timeout: httpTimeout}
+	results := make([]UptimeGroupResult, len(groups))
+	
+	g, gCtx := errgroup.WithContext(ctx)
+	for i, group := range groups {
+		i, group := i, group
+		g.Go(func() error {
+			results[i] = fetchGroupData(gCtx, client, group)
+			return nil
+		})
+	}
+	
+	g.Wait()
+	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
+} 

+ 8 - 0
controller/user.go

@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
 			UnlimitedQuota:     true,
 			UnlimitedQuota:     true,
 			ModelLimitsEnabled: false,
 			ModelLimitsEnabled: false,
 		}
 		}
+		if setting.DefaultUseAutoGroup {
+			token.Group = "auto"
+		}
 		if err := token.Insert(); err != nil {
 		if err := token.Insert(); err != nil {
 			c.JSON(http.StatusOK, gin.H{
 			c.JSON(http.StatusOK, gin.H{
 				"success": false,
 				"success": false,
@@ -459,6 +462,9 @@ func GetSelf(c *gin.Context) {
 		})
 		})
 		return
 		return
 	}
 	}
+	// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
+	user.Remark = ""
+
 	c.JSON(http.StatusOK, gin.H{
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"success": true,
 		"message": "",
 		"message": "",
@@ -943,6 +949,7 @@ type UpdateUserSettingRequest struct {
 	WebhookSecret              string  `json:"webhook_secret,omitempty"`
 	WebhookSecret              string  `json:"webhook_secret,omitempty"`
 	NotificationEmail          string  `json:"notification_email,omitempty"`
 	NotificationEmail          string  `json:"notification_email,omitempty"`
 	AcceptUnsetModelRatioModel bool    `json:"accept_unset_model_ratio_model"`
 	AcceptUnsetModelRatioModel bool    `json:"accept_unset_model_ratio_model"`
+	RecordIpLog                bool    `json:"record_ip_log"`
 }
 }
 
 
 func UpdateUserSetting(c *gin.Context) {
 func UpdateUserSetting(c *gin.Context) {
@@ -1019,6 +1026,7 @@ func UpdateUserSetting(c *gin.Context) {
 		constant.UserSettingNotifyType:            req.QuotaWarningType,
 		constant.UserSettingNotifyType:            req.QuotaWarningType,
 		constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
 		constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
 		"accept_unset_model_ratio_model":          req.AcceptUnsetModelRatioModel,
 		"accept_unset_model_ratio_model":          req.AcceptUnsetModelRatioModel,
+		constant.UserSettingRecordIpLog:           req.RecordIpLog,
 	}
 	}
 
 
 	// 如果是webhook类型,添加webhook相关设置
 	// 如果是webhook类型,添加webhook相关设置

+ 87 - 48
dto/claude.go

@@ -1,29 +1,33 @@
 package dto
 package dto
 
 
-import "encoding/json"
+import (
+	"encoding/json"
+	"one-api/common"
+)
 
 
 type ClaudeMetadata struct {
 type ClaudeMetadata struct {
 	UserId string `json:"user_id"`
 	UserId string `json:"user_id"`
 }
 }
 
 
 type ClaudeMediaMessage struct {
 type ClaudeMediaMessage struct {
-	Type        string               `json:"type,omitempty"`
-	Text        *string              `json:"text,omitempty"`
-	Model       string               `json:"model,omitempty"`
-	Source      *ClaudeMessageSource `json:"source,omitempty"`
-	Usage       *ClaudeUsage         `json:"usage,omitempty"`
-	StopReason  *string              `json:"stop_reason,omitempty"`
-	PartialJson *string              `json:"partial_json,omitempty"`
-	Role        string               `json:"role,omitempty"`
-	Thinking    string               `json:"thinking,omitempty"`
-	Signature   string               `json:"signature,omitempty"`
-	Delta       string               `json:"delta,omitempty"`
+	Type         string               `json:"type,omitempty"`
+	Text         *string              `json:"text,omitempty"`
+	Model        string               `json:"model,omitempty"`
+	Source       *ClaudeMessageSource `json:"source,omitempty"`
+	Usage        *ClaudeUsage         `json:"usage,omitempty"`
+	StopReason   *string              `json:"stop_reason,omitempty"`
+	PartialJson  *string              `json:"partial_json,omitempty"`
+	Role         string               `json:"role,omitempty"`
+	Thinking     string               `json:"thinking,omitempty"`
+	Signature    string               `json:"signature,omitempty"`
+	Delta        string               `json:"delta,omitempty"`
+	CacheControl json.RawMessage      `json:"cache_control,omitempty"`
 	// tool_calls
 	// tool_calls
-	Id        string          `json:"id,omitempty"`
-	Name      string          `json:"name,omitempty"`
-	Input     any             `json:"input,omitempty"`
-	Content   json.RawMessage `json:"content,omitempty"`
-	ToolUseId string          `json:"tool_use_id,omitempty"`
+	Id        string `json:"id,omitempty"`
+	Name      string `json:"name,omitempty"`
+	Input     any    `json:"input,omitempty"`
+	Content   any    `json:"content,omitempty"`
+	ToolUseId string `json:"tool_use_id,omitempty"`
 }
 }
 
 
 func (c *ClaudeMediaMessage) SetText(s string) {
 func (c *ClaudeMediaMessage) SetText(s string) {
@@ -38,15 +42,39 @@ func (c *ClaudeMediaMessage) GetText() string {
 }
 }
 
 
 func (c *ClaudeMediaMessage) IsStringContent() bool {
 func (c *ClaudeMediaMessage) IsStringContent() bool {
-	var content string
-	return json.Unmarshal(c.Content, &content) == nil
+	if c.Content == nil {
+		return false
+	}
+	_, ok := c.Content.(string)
+	if ok {
+		return true
+	}
+	return false
 }
 }
 
 
 func (c *ClaudeMediaMessage) GetStringContent() string {
 func (c *ClaudeMediaMessage) GetStringContent() string {
-	var content string
-	if err := json.Unmarshal(c.Content, &content); err == nil {
-		return content
+	if c.Content == nil {
+		return ""
+	}
+	switch c.Content.(type) {
+	case string:
+		return c.Content.(string)
+	case []any:
+		var contentStr string
+		for _, contentItem := range c.Content.([]any) {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == ContentTypeText {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
 	}
 	}
+
 	return ""
 	return ""
 }
 }
 
 
@@ -56,16 +84,12 @@ func (c *ClaudeMediaMessage) GetJsonRowString() string {
 }
 }
 
 
 func (c *ClaudeMediaMessage) SetContent(content any) {
 func (c *ClaudeMediaMessage) SetContent(content any) {
-	jsonContent, _ := json.Marshal(content)
-	c.Content = jsonContent
+	c.Content = content
 }
 }
 
 
 func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
 func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
-	var mediaContent []ClaudeMediaMessage
-	if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
-		return mediaContent
-	}
-	return make([]ClaudeMediaMessage, 0)
+	mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
+	return mediaContent
 }
 }
 
 
 type ClaudeMessageSource struct {
 type ClaudeMessageSource struct {
@@ -81,14 +105,36 @@ type ClaudeMessage struct {
 }
 }
 
 
 func (c *ClaudeMessage) IsStringContent() bool {
 func (c *ClaudeMessage) IsStringContent() bool {
+	if c.Content == nil {
+		return false
+	}
 	_, ok := c.Content.(string)
 	_, ok := c.Content.(string)
 	return ok
 	return ok
 }
 }
 
 
 func (c *ClaudeMessage) GetStringContent() string {
 func (c *ClaudeMessage) GetStringContent() string {
-	if c.IsStringContent() {
+	if c.Content == nil {
+		return ""
+	}
+	switch c.Content.(type) {
+	case string:
 		return c.Content.(string)
 		return c.Content.(string)
+	case []any:
+		var contentStr string
+		for _, contentItem := range c.Content.([]any) {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == ContentTypeText {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
 	}
 	}
+
 	return ""
 	return ""
 }
 }
 
 
@@ -97,15 +143,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
 }
 }
 
 
 func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
 func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
-	// map content to []ClaudeMediaMessage
-	// parse to json
-	jsonContent, _ := json.Marshal(c.Content)
-	var contentList []ClaudeMediaMessage
-	err := json.Unmarshal(jsonContent, &contentList)
-	if err != nil {
-		return make([]ClaudeMediaMessage, 0), err
-	}
-	return contentList, nil
+	return common.Any2Type[[]ClaudeMediaMessage](c.Content)
 }
 }
 
 
 type Tool struct {
 type Tool struct {
@@ -140,7 +178,14 @@ type ClaudeRequest struct {
 
 
 type Thinking struct {
 type Thinking struct {
 	Type         string `json:"type"`
 	Type         string `json:"type"`
-	BudgetTokens int    `json:"budget_tokens"`
+	BudgetTokens *int   `json:"budget_tokens,omitempty"`
+}
+
+func (c *Thinking) GetBudgetTokens() int {
+	if c.BudgetTokens == nil {
+		return 0
+	}
+	return *c.BudgetTokens
 }
 }
 
 
 func (c *ClaudeRequest) IsStringSystem() bool {
 func (c *ClaudeRequest) IsStringSystem() bool {
@@ -160,14 +205,8 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
 }
 }
 
 
 func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
 func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
-	// map content to []ClaudeMediaMessage
-	// parse to json
-	jsonContent, _ := json.Marshal(c.System)
-	var contentList []ClaudeMediaMessage
-	if err := json.Unmarshal(jsonContent, &contentList); err == nil {
-		return contentList
-	}
-	return make([]ClaudeMediaMessage, 0)
+	mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
+	return mediaContent
 }
 }
 
 
 type ClaudeError struct {
 type ClaudeError struct {

+ 2 - 0
dto/dalle.go

@@ -14,6 +14,8 @@ type ImageRequest struct {
 	ExtraFields    json.RawMessage `json:"extra_fields,omitempty"`
 	ExtraFields    json.RawMessage `json:"extra_fields,omitempty"`
 	Background     string          `json:"background,omitempty"`
 	Background     string          `json:"background,omitempty"`
 	Moderation     string          `json:"moderation,omitempty"`
 	Moderation     string          `json:"moderation,omitempty"`
+	OutputFormat   string          `json:"output_format,omitempty"`
+	Watermark      *bool           `json:"watermark,omitempty"`
 }
 }
 
 
 type ImageResponse struct {
 type ImageResponse struct {

+ 280 - 50
dto/openai_request.go

@@ -2,6 +2,7 @@ package dto
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"one-api/common"
 	"strings"
 	"strings"
 )
 )
 
 
@@ -18,41 +19,54 @@ type FormatJsonSchema struct {
 }
 }
 
 
 type GeneralOpenAIRequest struct {
 type GeneralOpenAIRequest struct {
-	Model               string         `json:"model,omitempty"`
-	Messages            []Message      `json:"messages,omitempty"`
-	Prompt              any            `json:"prompt,omitempty"`
-	Prefix              any            `json:"prefix,omitempty"`
-	Suffix              any            `json:"suffix,omitempty"`
-	Stream              bool           `json:"stream,omitempty"`
-	StreamOptions       *StreamOptions `json:"stream_options,omitempty"`
-	MaxTokens           uint           `json:"max_tokens,omitempty"`
-	MaxCompletionTokens uint           `json:"max_completion_tokens,omitempty"`
-	ReasoningEffort     string         `json:"reasoning_effort,omitempty"`
-	//Reasoning           json.RawMessage   `json:"reasoning,omitempty"`
-	Temperature      *float64          `json:"temperature,omitempty"`
-	TopP             float64           `json:"top_p,omitempty"`
-	TopK             int               `json:"top_k,omitempty"`
-	Stop             any               `json:"stop,omitempty"`
-	N                int               `json:"n,omitempty"`
-	Input            any               `json:"input,omitempty"`
-	Instruction      string            `json:"instruction,omitempty"`
-	Size             string            `json:"size,omitempty"`
-	Functions        any               `json:"functions,omitempty"`
-	FrequencyPenalty float64           `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64           `json:"presence_penalty,omitempty"`
-	ResponseFormat   *ResponseFormat   `json:"response_format,omitempty"`
-	EncodingFormat   any               `json:"encoding_format,omitempty"`
-	Seed             float64           `json:"seed,omitempty"`
-	Tools            []ToolCallRequest `json:"tools,omitempty"`
-	ToolChoice       any               `json:"tool_choice,omitempty"`
-	User             string            `json:"user,omitempty"`
-	LogProbs         bool              `json:"logprobs,omitempty"`
-	TopLogProbs      int               `json:"top_logprobs,omitempty"`
-	Dimensions       int               `json:"dimensions,omitempty"`
-	Modalities       any               `json:"modalities,omitempty"`
-	Audio            any               `json:"audio,omitempty"`
-	EnableThinking   any               `json:"enable_thinking,omitempty"` // ali
-	ExtraBody        any               `json:"extra_body,omitempty"`
+	Model               string            `json:"model,omitempty"`
+	Messages            []Message         `json:"messages,omitempty"`
+	Prompt              any               `json:"prompt,omitempty"`
+	Prefix              any               `json:"prefix,omitempty"`
+	Suffix              any               `json:"suffix,omitempty"`
+	Stream              bool              `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions    `json:"stream_options,omitempty"`
+	MaxTokens           uint              `json:"max_tokens,omitempty"`
+	MaxCompletionTokens uint              `json:"max_completion_tokens,omitempty"`
+	ReasoningEffort     string            `json:"reasoning_effort,omitempty"`
+	Temperature         *float64          `json:"temperature,omitempty"`
+	TopP                float64           `json:"top_p,omitempty"`
+	TopK                int               `json:"top_k,omitempty"`
+	Stop                any               `json:"stop,omitempty"`
+	N                   int               `json:"n,omitempty"`
+	Input               any               `json:"input,omitempty"`
+	Instruction         string            `json:"instruction,omitempty"`
+	Size                string            `json:"size,omitempty"`
+	Functions           json.RawMessage   `json:"functions,omitempty"`
+	FrequencyPenalty    float64           `json:"frequency_penalty,omitempty"`
+	PresencePenalty     float64           `json:"presence_penalty,omitempty"`
+	ResponseFormat      *ResponseFormat   `json:"response_format,omitempty"`
+	EncodingFormat      json.RawMessage   `json:"encoding_format,omitempty"`
+	Seed                float64           `json:"seed,omitempty"`
+	ParallelTooCalls    *bool             `json:"parallel_tool_calls,omitempty"`
+	Tools               []ToolCallRequest `json:"tools,omitempty"`
+	ToolChoice          any               `json:"tool_choice,omitempty"`
+	User                string            `json:"user,omitempty"`
+	LogProbs            bool              `json:"logprobs,omitempty"`
+	TopLogProbs         int               `json:"top_logprobs,omitempty"`
+	Dimensions          int               `json:"dimensions,omitempty"`
+	Modalities          json.RawMessage   `json:"modalities,omitempty"`
+	Audio               json.RawMessage   `json:"audio,omitempty"`
+	EnableThinking      any               `json:"enable_thinking,omitempty"` // ali
+	THINKING            json.RawMessage   `json:"thinking,omitempty"`        // doubao
+	ExtraBody           json.RawMessage   `json:"extra_body,omitempty"`
+	WebSearchOptions    *WebSearchOptions `json:"web_search_options,omitempty"`
+	// OpenRouter Params
+	Reasoning json.RawMessage `json:"reasoning,omitempty"`
+	// Ali Qwen Params
+	VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
+}
+
+func (r *GeneralOpenAIRequest) ToMap() map[string]any {
+	result := make(map[string]any)
+	data, _ := common.EncodeJson(r)
+	_ = common.DecodeJson(data, &result)
+	return result
 }
 }
 
 
 type ToolCallRequest struct {
 type ToolCallRequest struct {
@@ -72,11 +86,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 }
 
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int {
+func (r *GeneralOpenAIRequest) GetMaxTokens() int {
 	return int(r.MaxTokens)
 	return int(r.MaxTokens)
 }
 }
 
 
-func (r GeneralOpenAIRequest) ParseInput() []string {
+func (r *GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 	if r.Input == nil {
 		return nil
 		return nil
 	}
 	}
@@ -96,16 +110,16 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
 }
 }
 
 
 type Message struct {
 type Message struct {
-	Role                string          `json:"role"`
-	Content             json.RawMessage `json:"content"`
-	Name                *string         `json:"name,omitempty"`
-	Prefix              *bool           `json:"prefix,omitempty"`
-	ReasoningContent    string          `json:"reasoning_content,omitempty"`
-	Reasoning           string          `json:"reasoning,omitempty"`
-	ToolCalls           json.RawMessage `json:"tool_calls,omitempty"`
-	ToolCallId          string          `json:"tool_call_id,omitempty"`
-	parsedContent       []MediaContent
-	parsedStringContent *string
+	Role             string          `json:"role"`
+	Content          any             `json:"content"`
+	Name             *string         `json:"name,omitempty"`
+	Prefix           *bool           `json:"prefix,omitempty"`
+	ReasoningContent string          `json:"reasoning_content,omitempty"`
+	Reasoning        string          `json:"reasoning,omitempty"`
+	ToolCalls        json.RawMessage `json:"tool_calls,omitempty"`
+	ToolCallId       string          `json:"tool_call_id,omitempty"`
+	parsedContent    []MediaContent
+	//parsedStringContent *string
 }
 }
 
 
 type MediaContent struct {
 type MediaContent struct {
@@ -115,25 +129,56 @@ type MediaContent struct {
 	InputAudio any    `json:"input_audio,omitempty"`
 	InputAudio any    `json:"input_audio,omitempty"`
 	File       any    `json:"file,omitempty"`
 	File       any    `json:"file,omitempty"`
 	VideoUrl   any    `json:"video_url,omitempty"`
 	VideoUrl   any    `json:"video_url,omitempty"`
+	// OpenRouter Params
+	CacheControl json.RawMessage `json:"cache_control,omitempty"`
 }
 }
 
 
 func (m *MediaContent) GetImageMedia() *MessageImageUrl {
 func (m *MediaContent) GetImageMedia() *MessageImageUrl {
 	if m.ImageUrl != nil {
 	if m.ImageUrl != nil {
-		return m.ImageUrl.(*MessageImageUrl)
+		if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
+			return m.ImageUrl.(*MessageImageUrl)
+		}
+		if itemMap, ok := m.ImageUrl.(map[string]any); ok {
+			out := &MessageImageUrl{
+				Url:      common.Interface2String(itemMap["url"]),
+				Detail:   common.Interface2String(itemMap["detail"]),
+				MimeType: common.Interface2String(itemMap["mime_type"]),
+			}
+			return out
+		}
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
 func (m *MediaContent) GetInputAudio() *MessageInputAudio {
 func (m *MediaContent) GetInputAudio() *MessageInputAudio {
 	if m.InputAudio != nil {
 	if m.InputAudio != nil {
-		return m.InputAudio.(*MessageInputAudio)
+		if _, ok := m.InputAudio.(*MessageInputAudio); ok {
+			return m.InputAudio.(*MessageInputAudio)
+		}
+		if itemMap, ok := m.InputAudio.(map[string]any); ok {
+			out := &MessageInputAudio{
+				Data:   common.Interface2String(itemMap["data"]),
+				Format: common.Interface2String(itemMap["format"]),
+			}
+			return out
+		}
 	}
 	}
 	return nil
 	return nil
 }
 }
 
 
 func (m *MediaContent) GetFile() *MessageFile {
 func (m *MediaContent) GetFile() *MessageFile {
 	if m.File != nil {
 	if m.File != nil {
-		return m.File.(*MessageFile)
+		if _, ok := m.File.(*MessageFile); ok {
+			return m.File.(*MessageFile)
+		}
+		if itemMap, ok := m.File.(map[string]any); ok {
+			out := &MessageFile{
+				FileName: common.Interface2String(itemMap["file_name"]),
+				FileData: common.Interface2String(itemMap["file_data"]),
+				FileId:   common.Interface2String(itemMap["file_id"]),
+			}
+			return out
+		}
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -199,6 +244,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
 }
 }
 
 
 func (m *Message) StringContent() string {
 func (m *Message) StringContent() string {
+	switch m.Content.(type) {
+	case string:
+		return m.Content.(string)
+	case []any:
+		var contentStr string
+		for _, contentItem := range m.Content.([]any) {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == ContentTypeText {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
+	}
+
+	return ""
+}
+
+func (m *Message) SetNullContent() {
+	m.Content = nil
+	m.parsedContent = nil
+}
+
+func (m *Message) SetStringContent(content string) {
+	m.Content = content
+	m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) {
+	m.Content = content
+	m.parsedContent = content
+}
+
+func (m *Message) IsStringContent() bool {
+	_, ok := m.Content.(string)
+	if ok {
+		return true
+	}
+	return false
+}
+
+func (m *Message) ParseContent() []MediaContent {
+	if m.Content == nil {
+		return nil
+	}
+	if len(m.parsedContent) > 0 {
+		return m.parsedContent
+	}
+
+	var contentList []MediaContent
+	// 先尝试解析为字符串
+	content, ok := m.Content.(string)
+	if ok {
+		contentList = []MediaContent{{
+			Type: ContentTypeText,
+			Text: content,
+		}}
+		m.parsedContent = contentList
+		return contentList
+	}
+
+	// 尝试解析为数组
+	//var arrayContent []map[string]interface{}
+
+	arrayContent, ok := m.Content.([]any)
+	if !ok {
+		return contentList
+	}
+
+	for _, contentItemAny := range arrayContent {
+		mediaItem, ok := contentItemAny.(MediaContent)
+		if ok {
+			contentList = append(contentList, mediaItem)
+			continue
+		}
+
+		contentItem, ok := contentItemAny.(map[string]any)
+		if !ok {
+			continue
+		}
+		contentType, ok := contentItem["type"].(string)
+		if !ok {
+			continue
+		}
+
+		switch contentType {
+		case ContentTypeText:
+			if text, ok := contentItem["text"].(string); ok {
+				contentList = append(contentList, MediaContent{
+					Type: ContentTypeText,
+					Text: text,
+				})
+			}
+
+		case ContentTypeImageURL:
+			imageUrl := contentItem["image_url"]
+			temp := &MessageImageUrl{
+				Detail: "high",
+			}
+			switch v := imageUrl.(type) {
+			case string:
+				temp.Url = v
+			case map[string]interface{}:
+				url, ok1 := v["url"].(string)
+				detail, ok2 := v["detail"].(string)
+				if ok2 {
+					temp.Detail = detail
+				}
+				if ok1 {
+					temp.Url = url
+				}
+			}
+			contentList = append(contentList, MediaContent{
+				Type:     ContentTypeImageURL,
+				ImageUrl: temp,
+			})
+
+		case ContentTypeInputAudio:
+			if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+				data, ok1 := audioData["data"].(string)
+				format, ok2 := audioData["format"].(string)
+				if ok1 && ok2 {
+					temp := &MessageInputAudio{
+						Data:   data,
+						Format: format,
+					}
+					contentList = append(contentList, MediaContent{
+						Type:       ContentTypeInputAudio,
+						InputAudio: temp,
+					})
+				}
+			}
+		case ContentTypeFile:
+			if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+				fileId, ok3 := fileData["file_id"].(string)
+				if ok3 {
+					contentList = append(contentList, MediaContent{
+						Type: ContentTypeFile,
+						File: &MessageFile{
+							FileId: fileId,
+						},
+					})
+				} else {
+					fileName, ok1 := fileData["filename"].(string)
+					fileDataStr, ok2 := fileData["file_data"].(string)
+					if ok1 && ok2 {
+						contentList = append(contentList, MediaContent{
+							Type: ContentTypeFile,
+							File: &MessageFile{
+								FileName: fileName,
+								FileData: fileDataStr,
+							},
+						})
+					}
+				}
+			}
+		case ContentTypeVideoUrl:
+			if videoUrl, ok := contentItem["video_url"].(string); ok {
+				contentList = append(contentList, MediaContent{
+					Type: ContentTypeVideoUrl,
+					VideoUrl: &MessageVideoUrl{
+						Url: videoUrl,
+					},
+				})
+			}
+		}
+	}
+
+	if len(contentList) > 0 {
+		m.parsedContent = contentList
+	}
+	return contentList
+}
+
+// old code
+/*func (m *Message) StringContent() string {
 	if m.parsedStringContent != nil {
 	if m.parsedStringContent != nil {
 		return *m.parsedStringContent
 		return *m.parsedStringContent
 	}
 	}
@@ -369,6 +594,11 @@ func (m *Message) ParseContent() []MediaContent {
 		m.parsedContent = contentList
 		m.parsedContent = contentList
 	}
 	}
 	return contentList
 	return contentList
+}*/
+
+type WebSearchOptions struct {
+	SearchContextSize string          `json:"search_context_size,omitempty"`
+	UserLocation      json.RawMessage `json:"user_location,omitempty"`
 }
 }
 
 
 type OpenAIResponsesRequest struct {
 type OpenAIResponsesRequest struct {

+ 49 - 0
dto/ratio_sync.go

@@ -0,0 +1,49 @@
+package dto
+
+// UpstreamDTO 提交到后端同步倍率的上游渠道信息
+// Endpoint 可以为空,后端会默认使用 /api/ratio_config
+// BaseURL 必须以 http/https 开头,不要以 / 结尾
+// 例如: https://api.example.com
+// Endpoint: /api/ratio_config
+// 提交示例:
+// {
+//   "name": "openai",
+//   "base_url": "https://api.openai.com",
+//   "endpoint": "/ratio_config"
+// }
+
+type UpstreamDTO struct {
+    Name     string `json:"name" binding:"required"`
+    BaseURL  string `json:"base_url" binding:"required"`
+    Endpoint string `json:"endpoint"`
+}
+
+type UpstreamRequest struct {
+    ChannelIDs []int64 `json:"channel_ids"`
+    Timeout    int     `json:"timeout"`
+}
+
+// TestResult 上游测试连通性结果
+type TestResult struct {
+    Name   string `json:"name"`
+    Status string `json:"status"`
+    Error  string `json:"error,omitempty"`
+}
+
+// DifferenceItem 差异项
+// Current 为本地值,可能为 nil
+// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
+
+type DifferenceItem struct {
+    Current   interface{}            `json:"current"`
+    Upstreams map[string]interface{} `json:"upstreams"`
+}
+
+// SyncableChannel 可同步的渠道信息(base_url 不为空)
+
+type SyncableChannel struct {
+    ID      int    `json:"id"`
+    Name    string `json:"name"`
+    BaseURL string `json:"base_url"`
+    Status  int    `json:"status"`
+} 

+ 47 - 0
dto/video.go

@@ -0,0 +1,47 @@
+package dto
+
+type VideoRequest struct {
+	Model          string         `json:"model,omitempty" example:"kling-v1"`                                                                                                                                    // Model/style ID
+	Prompt         string         `json:"prompt,omitempty" example:"宇航员站起身走了"`                                                                                                                                   // Text prompt
+	Image          string         `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
+	Duration       float64        `json:"duration" example:"5.0"`                                                                                                                                                // Video duration (seconds)
+	Width          int            `json:"width" example:"512"`                                                                                                                                                   // Video width
+	Height         int            `json:"height" example:"512"`                                                                                                                                                  // Video height
+	Fps            int            `json:"fps,omitempty" example:"30"`                                                                                                                                            // Video frame rate
+	Seed           int            `json:"seed,omitempty" example:"20231234"`                                                                                                                                     // Random seed
+	N              int            `json:"n,omitempty" example:"1"`                                                                                                                                               // Number of videos to generate
+	ResponseFormat string         `json:"response_format,omitempty" example:"url"`                                                                                                                               // Response format
+	User           string         `json:"user,omitempty" example:"user-1234"`                                                                                                                                    // User identifier
+	Metadata       map[string]any `json:"metadata,omitempty"`                                                                                                                                                    // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
+}
+
+// VideoResponse 视频生成提交任务后的响应
+type VideoResponse struct {
+	TaskId string `json:"task_id"`
+	Status string `json:"status"`
+}
+
+// VideoTaskResponse 查询视频生成任务状态的响应
+type VideoTaskResponse struct {
+	TaskId   string             `json:"task_id" example:"abcd1234efgh"` // 任务ID
+	Status   string             `json:"status" example:"succeeded"`     // 任务状态
+	Url      string             `json:"url,omitempty"`                  // 视频资源URL(成功时)
+	Format   string             `json:"format,omitempty" example:"mp4"` // 视频格式
+	Metadata *VideoTaskMetadata `json:"metadata,omitempty"`             // 结果元数据
+	Error    *VideoTaskError    `json:"error,omitempty"`                // 错误信息(失败时)
+}
+
+// VideoTaskMetadata 视频任务元数据
+type VideoTaskMetadata struct {
+	Duration float64 `json:"duration" example:"5.0"`  // 实际生成的视频时长
+	Fps      int     `json:"fps" example:"30"`        // 实际帧率
+	Width    int     `json:"width" example:"512"`     // 实际宽度
+	Height   int     `json:"height" example:"512"`    // 实际高度
+	Seed     int     `json:"seed" example:"20231234"` // 使用的随机种子
+}
+
+// VideoTaskError 视频任务错误信息
+type VideoTaskError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}

+ 3 - 3
go.mod

@@ -11,7 +11,6 @@ require (
 	github.com/aws/aws-sdk-go-v2/credentials v1.17.11
 	github.com/aws/aws-sdk-go-v2/credentials v1.17.11
 	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
 	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
 	github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
 	github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
-	github.com/bytedance/sonic v1.11.6
 	github.com/gin-contrib/cors v1.7.2
 	github.com/gin-contrib/cors v1.7.2
 	github.com/gin-contrib/gzip v0.0.6
 	github.com/gin-contrib/gzip v0.0.6
 	github.com/gin-contrib/sessions v0.0.5
 	github.com/gin-contrib/sessions v0.0.5
@@ -25,10 +24,10 @@ require (
 	github.com/gorilla/websocket v1.5.0
 	github.com/gorilla/websocket v1.5.0
 	github.com/joho/godotenv v1.5.1
 	github.com/joho/godotenv v1.5.1
 	github.com/pkg/errors v0.9.1
 	github.com/pkg/errors v0.9.1
-	github.com/pkoukk/tiktoken-go v0.1.7
 	github.com/samber/lo v1.39.0
 	github.com/samber/lo v1.39.0
 	github.com/shirou/gopsutil v3.21.11+incompatible
 	github.com/shirou/gopsutil v3.21.11+incompatible
 	github.com/shopspring/decimal v1.4.0
 	github.com/shopspring/decimal v1.4.0
+	github.com/tiktoken-go/tokenizer v0.6.2
 	golang.org/x/crypto v0.35.0
 	golang.org/x/crypto v0.35.0
 	golang.org/x/image v0.23.0
 	golang.org/x/image v0.23.0
 	golang.org/x/net v0.35.0
 	golang.org/x/net v0.35.0
@@ -43,12 +42,13 @@ require (
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
 	github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
 	github.com/aws/smithy-go v1.20.2 // indirect
 	github.com/aws/smithy-go v1.20.2 // indirect
+	github.com/bytedance/sonic v1.11.6 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/cloudwego/base64x v0.1.4 // indirect
 	github.com/cloudwego/base64x v0.1.4 // indirect
 	github.com/cloudwego/iasm v0.2.0 // indirect
 	github.com/cloudwego/iasm v0.2.0 // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
-	github.com/dlclark/regexp2 v1.11.0 // indirect
+	github.com/dlclark/regexp2 v1.11.5 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 	github.com/gin-contrib/sse v0.1.0 // indirect
 	github.com/gin-contrib/sse v0.1.0 // indirect

+ 4 - 4
go.sum

@@ -38,8 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
-github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
-github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
+github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
 github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
 github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -167,8 +167,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
-github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
 github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -197,6 +195,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
+github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
 github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
 github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
 github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
 github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
 github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
 github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=

+ 21 - 6
main.go

@@ -12,7 +12,7 @@ import (
 	"one-api/model"
 	"one-api/model"
 	"one-api/router"
 	"one-api/router"
 	"one-api/service"
 	"one-api/service"
-	"one-api/setting/operation_setting"
+	"one-api/setting/ratio_setting"
 	"os"
 	"os"
 	"strconv"
 	"strconv"
 
 
@@ -74,7 +74,7 @@ func main() {
 	}
 	}
 
 
 	// Initialize model settings
 	// Initialize model settings
-	operation_setting.InitRatioSettings()
+	ratio_setting.InitRatioSettings()
 	// Initialize constants
 	// Initialize constants
 	constant.InitEnv()
 	constant.InitEnv()
 	// Initialize options
 	// Initialize options
@@ -89,13 +89,28 @@ func main() {
 	if common.MemoryCacheEnabled {
 	if common.MemoryCacheEnabled {
 		common.SysLog("memory cache enabled")
 		common.SysLog("memory cache enabled")
 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
 		common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
-		model.InitChannelCache()
-	}
-	if common.MemoryCacheEnabled {
-		go model.SyncOptions(common.SyncFrequency)
+
+		// Add panic recovery and retry for InitChannelCache
+		func() {
+			defer func() {
+				if r := recover(); r != nil {
+					common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+					// Retry once
+					_, fixErr := model.FixAbility()
+					if fixErr != nil {
+						common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+					}
+				}
+			}()
+			model.InitChannelCache()
+		}()
+
 		go model.SyncChannelCache(common.SyncFrequency)
 		go model.SyncChannelCache(common.SyncFrequency)
 	}
 	}
 
 
+	// 热更新配置
+	go model.SyncOptions(common.SyncFrequency)
+
 	// 数据看板
 	// 数据看板
 	go model.UpdateQuotaData()
 	go model.UpdateQuotaData()
 
 

+ 1 - 1
makefile

@@ -7,7 +7,7 @@ all: build-frontend start-backend
 
 
 build-frontend:
 build-frontend:
 	@echo "Building frontend..."
 	@echo "Building frontend..."
-	@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
+	@cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
 
 
 start-backend:
 start-backend:
 	@echo "Starting backend dev server..."
 	@echo "Starting backend dev server..."

+ 15 - 2
middleware/auth.go

@@ -1,13 +1,14 @@
 package middleware
 package middleware
 
 
 import (
 import (
-	"github.com/gin-contrib/sessions"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func validUserInfo(username string, role int) bool {
 func validUserInfo(username string, role int) bool {
@@ -182,6 +183,18 @@ func TokenAuth() func(c *gin.Context) {
 				c.Request.Header.Set("Authorization", "Bearer "+key)
 				c.Request.Header.Set("Authorization", "Bearer "+key)
 			}
 			}
 		}
 		}
+		// gemini api 从query中获取key
+		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+			skKey := c.Query("key")
+			if skKey != "" {
+				c.Request.Header.Set("Authorization", "Bearer "+skKey)
+			}
+			// 从x-goog-api-key header中获取key
+			xGoogKey := c.Request.Header.Get("x-goog-api-key")
+			if xGoogKey != "" {
+				c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
+			}
+		}
 		key := c.Request.Header.Get("Authorization")
 		key := c.Request.Header.Get("Authorization")
 		parts := make([]string, 0)
 		parts := make([]string, 0)
 		key = strings.TrimPrefix(key, "Bearer ")
 		key = strings.TrimPrefix(key, "Bearer ")

+ 58 - 5
middleware/distributor.go

@@ -11,6 +11,7 @@ import (
 	relayconstant "one-api/relay/constant"
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting"
 	"one-api/setting"
+	"one-api/setting/ratio_setting"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
 				return
 				return
 			}
 			}
 			// check group in common.GroupRatio
 			// check group in common.GroupRatio
-			if !setting.ContainsGroupRatio(tokenGroup) {
-				abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
-				return
+			if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+				if tokenGroup != "auto" {
+					abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+					return
+				}
 			}
 			}
 			userGroup = tokenGroup
 			userGroup = tokenGroup
 		}
 		}
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
 			}
 			}
 
 
 			if shouldSelectChannel {
 			if shouldSelectChannel {
-				channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
+				var selectGroup string
+				channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
 				if err != nil {
 				if err != nil {
-					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+					showGroup := userGroup
+					if userGroup == "auto" {
+						showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+					}
+					message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
 					// 如果错误,但是渠道不为空,说明是数据库一致性问题
 					// 如果错误,但是渠道不为空,说明是数据库一致性问题
 					if channel != nil {
 					if channel != nil {
 						common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 						common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -162,6 +170,23 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
+		relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
+		if relayMode == relayconstant.RelayModeKlingFetchByID {
+			shouldSelectChannel = false
+		} else {
+			err = common.UnmarshalBodyReusable(c, &modelRequest)
+		}
+		c.Set("platform", string(constant.TaskPlatformKling))
+		c.Set("relay_mode", relayMode)
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+		// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
+		relayMode := relayconstant.RelayModeGemini
+		modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+		if modelName != "" {
+			modelRequest.Model = modelName
+		}
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}
 	}
@@ -244,3 +269,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("bot_id", channel.Other)
 		c.Set("bot_id", channel.Other)
 	}
 	}
 }
 }
+
+// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
+// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
+// 输出: gemini-2.0-flash
+func extractModelNameFromGeminiPath(path string) string {
+	// 查找 "/models/" 的位置
+	modelsPrefix := "/models/"
+	modelsIndex := strings.Index(path, modelsPrefix)
+	if modelsIndex == -1 {
+		return ""
+	}
+
+	// 从 "/models/" 之后开始提取
+	startIndex := modelsIndex + len(modelsPrefix)
+	if startIndex >= len(path) {
+		return ""
+	}
+
+	// 查找 ":" 的位置,模型名在 ":" 之前
+	colonIndex := strings.Index(path[startIndex:], ":")
+	if colonIndex == -1 {
+		// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
+		return path[startIndex:]
+	}
+
+	// 返回模型名部分
+	return path[startIndex : startIndex+colonIndex]
+}

+ 41 - 0
middleware/stats.go

@@ -0,0 +1,41 @@
+package middleware
+
+import (
+	"sync/atomic"
+
+	"github.com/gin-gonic/gin"
+)
+
+// HTTPStats 存储HTTP统计信息
+type HTTPStats struct {
+	activeConnections int64
+}
+
+var globalStats = &HTTPStats{}
+
+// StatsMiddleware 统计中间件
+func StatsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		// 增加活跃连接数
+		atomic.AddInt64(&globalStats.activeConnections, 1)
+		
+		// 确保在请求结束时减少连接数
+		defer func() {
+			atomic.AddInt64(&globalStats.activeConnections, -1)
+		}()
+		
+		c.Next()
+	}
+}
+
+// StatsInfo 统计信息结构
+type StatsInfo struct {
+	ActiveConnections int64 `json:"active_connections"`
+}
+
+// GetStats 获取统计信息
+func GetStats() StatsInfo {
+	return StatsInfo{
+		ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
+	}
+} 

+ 55 - 25
model/ability.go

@@ -8,6 +8,7 @@ import (
 
 
 	"github.com/samber/lo"
 	"github.com/samber/lo"
 	"gorm.io/gorm"
 	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
 )
 )
 
 
 type Ability struct {
 type Ability struct {
@@ -23,7 +24,7 @@ type Ability struct {
 func GetGroupModels(group string) []string {
 func GetGroupModels(group string) []string {
 	var models []string
 	var models []string
 	// Find distinct models
 	// Find distinct models
-	DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+	DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
 	return models
 	return models
 }
 }
 
 
@@ -41,15 +42,11 @@ func GetAllEnableAbilities() []Ability {
 }
 }
 
 
 func getPriority(group string, model string, retry int) (int, error) {
 func getPriority(group string, model string, retry int) (int, error) {
-	trueVal := "1"
-	if common.UsingPostgreSQL {
-		trueVal = "true"
-	}
 
 
 	var priorities []int
 	var priorities []int
 	err := DB.Model(&Ability{}).
 	err := DB.Model(&Ability{}).
 		Select("DISTINCT(priority)").
 		Select("DISTINCT(priority)").
-		Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
+		Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
 		Order("priority DESC").              // 按优先级降序排序
 		Order("priority DESC").              // 按优先级降序排序
 		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
 		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
 
 
@@ -75,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
 }
 }
 
 
 func getChannelQuery(group string, model string, retry int) *gorm.DB {
 func getChannelQuery(group string, model string, retry int) *gorm.DB {
-	trueVal := "1"
-	if common.UsingPostgreSQL {
-		trueVal = "true"
-	}
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
-	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
+	channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
 	if retry != 0 {
 	if retry != 0 {
 		priority, err := getPriority(group, model, retry)
 		priority, err := getPriority(group, model, retry)
 		if err != nil {
 		if err != nil {
 			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
 			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
 		} else {
 		} else {
-			channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+			channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
 		}
 		}
 	}
 	}
 
 
@@ -133,9 +126,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
 func (channel *Channel) AddAbilities() error {
 func (channel *Channel) AddAbilities() error {
 	models_ := strings.Split(channel.Models, ",")
 	models_ := strings.Split(channel.Models, ",")
 	groups_ := strings.Split(channel.Group, ",")
 	groups_ := strings.Split(channel.Group, ",")
+	abilitySet := make(map[string]struct{})
 	abilities := make([]Ability, 0, len(models_))
 	abilities := make([]Ability, 0, len(models_))
 	for _, model := range models_ {
 	for _, model := range models_ {
 		for _, group := range groups_ {
 		for _, group := range groups_ {
+			key := group + "|" + model
+			if _, exists := abilitySet[key]; exists {
+				continue
+			}
+			abilitySet[key] = struct{}{}
 			ability := Ability{
 			ability := Ability{
 				Group:     group,
 				Group:     group,
 				Model:     model,
 				Model:     model,
@@ -152,7 +151,7 @@ func (channel *Channel) AddAbilities() error {
 		return nil
 		return nil
 	}
 	}
 	for _, chunk := range lo.Chunk(abilities, 50) {
 	for _, chunk := range lo.Chunk(abilities, 50) {
-		err := DB.Create(&chunk).Error
+		err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -194,9 +193,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
 	// Then add new abilities
 	// Then add new abilities
 	models_ := strings.Split(channel.Models, ",")
 	models_ := strings.Split(channel.Models, ",")
 	groups_ := strings.Split(channel.Group, ",")
 	groups_ := strings.Split(channel.Group, ",")
+	abilitySet := make(map[string]struct{})
 	abilities := make([]Ability, 0, len(models_))
 	abilities := make([]Ability, 0, len(models_))
 	for _, model := range models_ {
 	for _, model := range models_ {
 		for _, group := range groups_ {
 		for _, group := range groups_ {
+			key := group + "|" + model
+			if _, exists := abilitySet[key]; exists {
+				continue
+			}
+			abilitySet[key] = struct{}{}
 			ability := Ability{
 			ability := Ability{
 				Group:     group,
 				Group:     group,
 				Model:     model,
 				Model:     model,
@@ -212,7 +217,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
 
 
 	if len(abilities) > 0 {
 	if len(abilities) > 0 {
 		for _, chunk := range lo.Chunk(abilities, 50) {
 		for _, chunk := range lo.Chunk(abilities, 50) {
-			err = tx.Create(&chunk).Error
+			err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
 			if err != nil {
 			if err != nil {
 				if isNewTx {
 				if isNewTx {
 					tx.Rollback()
 					tx.Rollback()
@@ -261,12 +266,28 @@ func FixAbility() (int, error) {
 		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
 		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
 		return 0, err
 		return 0, err
 	}
 	}
-	// Delete abilities of channels that are not in channel table
-	err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
-	if err != nil {
-		common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
-		return 0, err
+
+	// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
+	if len(channelIds) > 0 {
+		// Process deletion in chunks to avoid "too many placeholders" error
+		for _, chunk := range lo.Chunk(channelIds, 100) {
+			err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
+			if err != nil {
+				common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
+				return 0, err
+			}
+		}
+	} else {
+		// If no channels exist, delete all abilities
+		err = DB.Delete(&Ability{}).Error
+		if err != nil {
+			common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
+			return 0, err
+		}
+		common.SysLog("Delete all abilities successfully")
+		return 0, nil
 	}
 	}
+
 	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
 	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
 	count += len(channelIds)
 	count += len(channelIds)
 
 
@@ -275,17 +296,26 @@ func FixAbility() (int, error) {
 	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
 	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
 	if err != nil {
 	if err != nil {
 		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
 		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
-		return 0, err
+		return count, err
 	}
 	}
+
 	var channels []Channel
 	var channels []Channel
 	if len(abilityChannelIds) == 0 {
 	if len(abilityChannelIds) == 0 {
 		err = DB.Find(&channels).Error
 		err = DB.Find(&channels).Error
 	} else {
 	} else {
-		err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
-	}
-	if err != nil {
-		return 0, err
+		// Process query in chunks to avoid "too many placeholders" error
+		err = nil
+		for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
+			var channelsChunk []Channel
+			err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
+			if err != nil {
+				common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
+				return count, err
+			}
+			channels = append(channels, channelsChunk...)
+		}
 	}
 	}
+
 	for _, channel := range channels {
 	for _, channel := range channels {
 		err := channel.UpdateAbilities(nil)
 		err := channel.UpdateAbilities(nil)
 		if err != nil {
 		if err != nil {

+ 45 - 3
model/cache.go

@@ -5,10 +5,13 @@ import (
 	"fmt"
 	"fmt"
 	"math/rand"
 	"math/rand"
 	"one-api/common"
 	"one-api/common"
+	"one-api/setting"
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 var group2model2channels map[string]map[string][]*Channel
 var group2model2channels map[string]map[string][]*Channel
@@ -16,6 +19,9 @@ var channelsIDM map[int]*Channel
 var channelSyncLock sync.RWMutex
 var channelSyncLock sync.RWMutex
 
 
 func InitChannelCache() {
 func InitChannelCache() {
+	if !common.MemoryCacheEnabled {
+		return
+	}
 	newChannelId2channel := make(map[int]*Channel)
 	newChannelId2channel := make(map[int]*Channel)
 	var channels []*Channel
 	var channels []*Channel
 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
 	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
@@ -72,7 +78,43 @@ func SyncChannelCache(frequency int) {
 	}
 	}
 }
 }
 
 
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+	var channel *Channel
+	var err error
+	selectGroup := group
+	if group == "auto" {
+		if len(setting.AutoGroups) == 0 {
+			return nil, selectGroup, errors.New("auto groups is not enabled")
+		}
+		for _, autoGroup := range setting.AutoGroups {
+			if common.DebugEnabled {
+				println("autoGroup:", autoGroup)
+			}
+			channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+			if channel == nil {
+				continue
+			} else {
+				c.Set("auto_group", autoGroup)
+				selectGroup = autoGroup
+				if common.DebugEnabled {
+					println("selectGroup:", selectGroup)
+				}
+				break
+			}
+		}
+	} else {
+		channel, err = getRandomSatisfiedChannel(group, model, retry)
+		if err != nil {
+			return nil, group, err
+		}
+	}
+	if channel == nil {
+		return nil, group, errors.New("channel not found")
+	}
+	return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
 	if strings.HasPrefix(model, "gpt-4-gizmo") {
 	if strings.HasPrefix(model, "gpt-4-gizmo") {
 		model = "gpt-4-gizmo-*"
 		model = "gpt-4-gizmo-*"
 	}
 	}
@@ -84,11 +126,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
 	if !common.MemoryCacheEnabled {
 	if !common.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model, retry)
 		return GetRandomSatisfiedChannel(group, model, retry)
 	}
 	}
-	
+
 	channelSyncLock.RLock()
 	channelSyncLock.RLock()
 	channels := group2model2channels[group][model]
 	channels := group2model2channels[group][model]
 	channelSyncLock.RUnlock()
 	channelSyncLock.RUnlock()
-	
+
 	if len(channels) == 0 {
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 		return nil, errors.New("channel not found")
 	}
 	}

+ 71 - 10
model/channel.go

@@ -46,6 +46,17 @@ func (channel *Channel) GetModels() []string {
 	return strings.Split(strings.Trim(channel.Models, ","), ",")
 	return strings.Split(strings.Trim(channel.Models, ","), ",")
 }
 }
 
 
+func (channel *Channel) GetGroups() []string {
+	if channel.Group == "" {
+		return []string{}
+	}
+	groups := strings.Split(strings.Trim(channel.Group, ","), ",")
+	for i, group := range groups {
+		groups[i] = strings.TrimSpace(group)
+	}
+	return groups
+}
+
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 func (channel *Channel) GetOtherInfo() map[string]interface{} {
 	otherInfo := make(map[string]interface{})
 	otherInfo := make(map[string]interface{})
 	if channel.OtherInfo != "" {
 	if channel.OtherInfo != "" {
@@ -134,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
 	}
 	}
 
 
 	// 构造基础查询
 	// 构造基础查询
-	baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+	baseQuery := DB.Model(&Channel{}).Omit("key")
 
 
 	// 构造WHERE子句
 	// 构造WHERE子句
 	var whereClause string
 	var whereClause string
@@ -142,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
 	if group != "" && group != "null" {
 	if group != "" && group != "null" {
 		var groupCondition string
 		var groupCondition string
 		if common.UsingMySQL {
 		if common.UsingMySQL {
-			groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+			groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
 		} else {
 		} else {
 			// sqlite, PostgreSQL
 			// sqlite, PostgreSQL
-			groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+			groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
 		}
 		}
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 	} else {
 	} else {
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 	}
 	}
 
 
@@ -467,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
 	}
 	}
 
 
 	// 构造基础查询
 	// 构造基础查询
-	baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+	baseQuery := DB.Model(&Channel{}).Omit("key")
 
 
 	// 构造WHERE子句
 	// 构造WHERE子句
 	var whereClause string
 	var whereClause string
@@ -475,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
 	if group != "" && group != "null" {
 	if group != "" && group != "null" {
 		var groupCondition string
 		var groupCondition string
 		if common.UsingMySQL {
 		if common.UsingMySQL {
-			groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+			groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
 		} else {
 		} else {
 			// sqlite, PostgreSQL
 			// sqlite, PostgreSQL
-			groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+			groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
 		}
 		}
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 	} else {
 	} else {
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 	}
 	}
 
 
@@ -572,3 +583,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
 	// 提交事务
 	// 提交事务
 	return tx.Commit().Error
 	return tx.Commit().Error
 }
 }
+
+// CountAllChannels returns total channels in DB
+func CountAllChannels() (int64, error) {
+	var total int64
+	err := DB.Model(&Channel{}).Count(&total).Error
+	return total, err
+}
+
+// CountAllTags returns number of non-empty distinct tags
+func CountAllTags() (int64, error) {
+	var total int64
+	err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
+	return total, err
+}
+
+// Get channels of specified type with pagination
+func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
+	var channels []*Channel
+	order := "priority desc"
+	if idSort {
+		order = "id desc"
+	}
+	err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+	return channels, err
+}
+
+// Count channels of specific type
+func CountChannelsByType(channelType int) (int64, error) {
+	var count int64
+	err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
+	return count, err
+}
+
+// Return map[type]count for all channels
+func CountChannelsGroupByType() (map[int64]int64, error) {
+	type result struct {
+		Type  int64 `gorm:"column:type"`
+		Count int64 `gorm:"column:count"`
+	}
+	var results []result
+	err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
+	if err != nil {
+		return nil, err
+	}
+	counts := make(map[int64]int64)
+	for _, r := range results {
+		counts[r.Type] = r.Count
+	}
+	return counts, nil
+}

+ 46 - 9
model/log.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"one-api/common"
 	"one-api/common"
+	"one-api/constant"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -32,6 +33,7 @@ type Log struct {
 	ChannelName      string `json:"channel_name" gorm:"->"`
 	ChannelName      string `json:"channel_name" gorm:"->"`
 	TokenId          int    `json:"token_id" gorm:"default:0;index"`
 	TokenId          int    `json:"token_id" gorm:"default:0;index"`
 	Group            string `json:"group" gorm:"index"`
 	Group            string `json:"group" gorm:"index"`
+	Ip               string `json:"ip" gorm:"index;default:''"`
 	Other            string `json:"other"`
 	Other            string `json:"other"`
 }
 }
 
 
@@ -61,7 +63,7 @@ func formatUserLogs(logs []*Log) {
 func GetLogByKey(key string) (logs []*Log, err error) {
 func GetLogByKey(key string) (logs []*Log, err error) {
 	if os.Getenv("LOG_SQL_DSN") != "" {
 	if os.Getenv("LOG_SQL_DSN") != "" {
 		var tk Token
 		var tk Token
-		if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+		if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
 		err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -95,6 +97,15 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 	common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
 	common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
 	username := c.GetString("username")
 	username := c.GetString("username")
 	otherStr := common.MapToJsonStr(other)
 	otherStr := common.MapToJsonStr(other)
+	// 判断是否需要记录 IP
+	needRecordIp := false
+	if settingMap, err := GetUserSetting(userId, false); err == nil {
+		if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+			if vb, ok := v.(bool); ok && vb {
+				needRecordIp = true
+			}
+		}
+	}
 	log := &Log{
 	log := &Log{
 		UserId:           userId,
 		UserId:           userId,
 		Username:         username,
 		Username:         username,
@@ -111,7 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 		UseTime:          useTimeSeconds,
 		UseTime:          useTimeSeconds,
 		IsStream:         isStream,
 		IsStream:         isStream,
 		Group:            group,
 		Group:            group,
-		Other:            otherStr,
+		Ip: func() string {
+			if needRecordIp {
+				return c.ClientIP()
+			}
+			return ""
+		}(),
+		Other: otherStr,
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
@@ -128,6 +145,15 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
 	}
 	}
 	username := c.GetString("username")
 	username := c.GetString("username")
 	otherStr := common.MapToJsonStr(other)
 	otherStr := common.MapToJsonStr(other)
+	// 判断是否需要记录 IP
+	needRecordIp := false
+	if settingMap, err := GetUserSetting(userId, false); err == nil {
+		if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
+			if vb, ok := v.(bool); ok && vb {
+				needRecordIp = true
+			}
+		}
+	}
 	log := &Log{
 	log := &Log{
 		UserId:           userId,
 		UserId:           userId,
 		Username:         username,
 		Username:         username,
@@ -144,7 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
 		UseTime:          useTimeSeconds,
 		UseTime:          useTimeSeconds,
 		IsStream:         isStream,
 		IsStream:         isStream,
 		Group:            group,
 		Group:            group,
-		Other:            otherStr,
+		Ip: func() string {
+			if needRecordIp {
+				return c.ClientIP()
+			}
+			return ""
+		}(),
+		Other: otherStr,
 	}
 	}
 	err := LOG_DB.Create(log).Error
 	err := LOG_DB.Create(log).Error
 	if err != nil {
 	if err != nil {
@@ -184,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 		tx = tx.Where("logs.channel_id = ?", channel)
 		tx = tx.Where("logs.channel_id = ?", channel)
 	}
 	}
 	if group != "" {
 	if group != "" {
-		tx = tx.Where("logs."+groupCol+" = ?", group)
+		tx = tx.Where("logs."+logGroupCol+" = ?", group)
 	}
 	}
 	err = tx.Model(&Log{}).Count(&total).Error
 	err = tx.Model(&Log{}).Count(&total).Error
 	if err != nil {
 	if err != nil {
@@ -195,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 		return nil, 0, err
 		return nil, 0, err
 	}
 	}
 
 
-	channelIds := make([]int, 0)
+	channelIdsMap := make(map[int]struct{})
 	channelMap := make(map[int]string)
 	channelMap := make(map[int]string)
 	for _, log := range logs {
 	for _, log := range logs {
 		if log.ChannelId != 0 {
 		if log.ChannelId != 0 {
-			channelIds = append(channelIds, log.ChannelId)
+			channelIdsMap[log.ChannelId] = struct{}{}
 		}
 		}
 	}
 	}
+
+	channelIds := make([]int, 0, len(channelIdsMap))
+	for channelId := range channelIdsMap {
+		channelIds = append(channelIds, channelId)
+	}
 	if len(channelIds) > 0 {
 	if len(channelIds) > 0 {
 		var channels []struct {
 		var channels []struct {
 			Id   int    `gorm:"column:id"`
 			Id   int    `gorm:"column:id"`
@@ -242,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 		tx = tx.Where("logs.created_at <= ?", endTimestamp)
 		tx = tx.Where("logs.created_at <= ?", endTimestamp)
 	}
 	}
 	if group != "" {
 	if group != "" {
-		tx = tx.Where("logs."+groupCol+" = ?", group)
+		tx = tx.Where("logs."+logGroupCol+" = ?", group)
 	}
 	}
 	err = tx.Model(&Log{}).Count(&total).Error
 	err = tx.Model(&Log{}).Count(&total).Error
 	if err != nil {
 	if err != nil {
@@ -303,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 		rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
 		rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
 	}
 	}
 	if group != "" {
 	if group != "" {
-		tx = tx.Where(groupCol+" = ?", group)
-		rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
+		tx = tx.Where(logGroupCol+" = ?", group)
+		rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
 	}
 	}
 
 
 	tx = tx.Where("type = ?", LogTypeConsume)
 	tx = tx.Where("type = ?", LogTypeConsume)

+ 116 - 54
model/main.go

@@ -1,6 +1,7 @@
 package model
 package model
 
 
 import (
 import (
+	"fmt"
 	"log"
 	"log"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
@@ -15,18 +16,48 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm"
 )
 )
 
 
-var groupCol string
-var keyCol string
+var commonGroupCol string
+var commonKeyCol string
+var commonTrueVal string
+var commonFalseVal string
+
+var logKeyCol string
+var logGroupCol string
 
 
 func initCol() {
 func initCol() {
+	// init common column names
 	if common.UsingPostgreSQL {
 	if common.UsingPostgreSQL {
-		groupCol = `"group"`
-		keyCol = `"key"`
-
+		commonGroupCol = `"group"`
+		commonKeyCol = `"key"`
+		commonTrueVal = "true"
+		commonFalseVal = "false"
 	} else {
 	} else {
-		groupCol = "`group`"
-		keyCol = "`key`"
+		commonGroupCol = "`group`"
+		commonKeyCol = "`key`"
+		commonTrueVal = "1"
+		commonFalseVal = "0"
+	}
+	if os.Getenv("LOG_SQL_DSN") != "" {
+		switch common.LogSqlType {
+		case common.DatabaseTypePostgreSQL:
+			logGroupCol = `"group"`
+			logKeyCol = `"key"`
+		default:
+			logGroupCol = commonGroupCol
+			logKeyCol = commonKeyCol
+		}
+	} else {
+		// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
+		if common.UsingPostgreSQL {
+			logGroupCol = `"group"`
+			logKeyCol = `"key"`
+		} else {
+			logGroupCol = commonGroupCol
+			logKeyCol = commonKeyCol
+		}
 	}
 	}
+	// log sql type and database type
+	common.SysLog("Using Log SQL Type: " + common.LogSqlType)
 }
 }
 
 
 var DB *gorm.DB
 var DB *gorm.DB
@@ -83,7 +114,7 @@ func CheckSetup() {
 	}
 	}
 }
 }
 
 
-func chooseDB(envName string) (*gorm.DB, error) {
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 	defer func() {
 	defer func() {
 		initCol()
 		initCol()
 	}()
 	}()
@@ -92,7 +123,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 		if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 		if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 			// Use PostgreSQL
 			// Use PostgreSQL
 			common.SysLog("using PostgreSQL as database")
 			common.SysLog("using PostgreSQL as database")
-			common.UsingPostgreSQL = true
+			if !isLog {
+				common.UsingPostgreSQL = true
+			} else {
+				common.LogSqlType = common.DatabaseTypePostgreSQL
+			}
 			return gorm.Open(postgres.New(postgres.Config{
 			return gorm.Open(postgres.New(postgres.Config{
 				DSN:                  dsn,
 				DSN:                  dsn,
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +137,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 		}
 		}
 		if strings.HasPrefix(dsn, "local") {
 		if strings.HasPrefix(dsn, "local") {
 			common.SysLog("SQL_DSN not set, using SQLite as database")
 			common.SysLog("SQL_DSN not set, using SQLite as database")
-			common.UsingSQLite = true
+			if !isLog {
+				common.UsingSQLite = true
+			} else {
+				common.LogSqlType = common.DatabaseTypeSQLite
+			}
 			return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 			return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 				PrepareStmt: true, // precompile SQL
 				PrepareStmt: true, // precompile SQL
 			})
 			})
@@ -117,7 +156,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 				dsn += "?parseTime=true"
 				dsn += "?parseTime=true"
 			}
 			}
 		}
 		}
-		common.UsingMySQL = true
+		if !isLog {
+			common.UsingMySQL = true
+		} else {
+			common.LogSqlType = common.DatabaseTypeMySQL
+		}
 		return gorm.Open(mysql.Open(dsn), &gorm.Config{
 		return gorm.Open(mysql.Open(dsn), &gorm.Config{
 			PrepareStmt: true, // precompile SQL
 			PrepareStmt: true, // precompile SQL
 		})
 		})
@@ -131,7 +174,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
 }
 }
 
 
 func InitDB() (err error) {
 func InitDB() (err error) {
-	db, err := chooseDB("SQL_DSN")
+	db, err := chooseDB("SQL_DSN", false)
 	if err == nil {
 	if err == nil {
 		if common.DebugEnabled {
 		if common.DebugEnabled {
 			db = db.Debug()
 			db = db.Debug()
@@ -149,7 +192,7 @@ func InitDB() (err error) {
 			return nil
 			return nil
 		}
 		}
 		if common.UsingMySQL {
 		if common.UsingMySQL {
-			_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+			//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
 		}
 		}
 		common.SysLog("database migration started")
 		common.SysLog("database migration started")
 		err = migrateDB()
 		err = migrateDB()
@@ -165,7 +208,7 @@ func InitLogDB() (err error) {
 		LOG_DB = DB
 		LOG_DB = DB
 		return
 		return
 	}
 	}
-	db, err := chooseDB("LOG_SQL_DSN")
+	db, err := chooseDB("LOG_SQL_DSN", true)
 	if err == nil {
 	if err == nil {
 		if common.DebugEnabled {
 		if common.DebugEnabled {
 			db = db.Debug()
 			db = db.Debug()
@@ -198,54 +241,73 @@ func InitLogDB() (err error) {
 }
 }
 
 
 func migrateDB() error {
 func migrateDB() error {
-	err := DB.AutoMigrate(&Channel{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Token{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&User{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Option{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Redemption{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Ability{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Log{})
-	if err != nil {
-		return err
+	if !common.UsingPostgreSQL {
+		return migrateDBFast()
 	}
 	}
-	err = DB.AutoMigrate(&Midjourney{})
+	err := DB.AutoMigrate(
+		&Channel{},
+		&Token{},
+		&User{},
+		&Option{},
+		&Redemption{},
+		&Ability{},
+		&Log{},
+		&Midjourney{},
+		&TopUp{},
+		&QuotaData{},
+		&Task{},
+		&Setup{},
+	)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	err = DB.AutoMigrate(&TopUp{})
-	if err != nil {
-		return err
+	return nil
+}
+
+func migrateDBFast() error {
+	var wg sync.WaitGroup
+	errChan := make(chan error, 12) // Buffer size matches number of migrations
+
+	migrations := []struct {
+		model interface{}
+		name  string
+	}{
+		{&Channel{}, "Channel"},
+		{&Token{}, "Token"},
+		{&User{}, "User"},
+		{&Option{}, "Option"},
+		{&Redemption{}, "Redemption"},
+		{&Ability{}, "Ability"},
+		{&Log{}, "Log"},
+		{&Midjourney{}, "Midjourney"},
+		{&TopUp{}, "TopUp"},
+		{&QuotaData{}, "QuotaData"},
+		{&Task{}, "Task"},
+		{&Setup{}, "Setup"},
 	}
 	}
-	err = DB.AutoMigrate(&QuotaData{})
-	if err != nil {
-		return err
+
+	for _, m := range migrations {
+		wg.Add(1)
+		go func(model interface{}, name string) {
+			defer wg.Done()
+			if err := DB.AutoMigrate(model); err != nil {
+				errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+			}
+		}(m.model, m.name)
 	}
 	}
-	err = DB.AutoMigrate(&Task{})
-	if err != nil {
-		return err
+
+	// Wait for all migrations to complete
+	wg.Wait()
+	close(errChan)
+
+	// Check for any errors
+	for err := range errChan {
+		if err != nil {
+			return err
+		}
 	}
 	}
-	err = DB.AutoMigrate(&Setup{})
 	common.SysLog("database migrated")
 	common.SysLog("database migrated")
-	//err = createRootAccountIfNeed()
-	return err
+	return nil
 }
 }
 
 
 func migrateLOGDB() error {
 func migrateLOGDB() error {

+ 37 - 0
model/midjourney.go

@@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
 		Where("id in (?)", taskIDs).
 		Where("id in (?)", taskIDs).
 		Updates(params).Error
 		Updates(params).Error
 }
 }
+
+// CountAllTasks returns total midjourney tasks for admin query
+func CountAllTasks(queryParams TaskQueryParams) int64 {
+	var total int64
+	query := DB.Model(&Midjourney{})
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.MjID != "" {
+		query = query.Where("mj_id = ?", queryParams.MjID)
+	}
+	if queryParams.StartTimestamp != "" {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != "" {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	_ = query.Count(&total).Error
+	return total
+}
+
+// CountAllUserTask returns total midjourney tasks for user
+func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
+	var total int64
+	query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
+	if queryParams.MjID != "" {
+		query = query.Where("mj_id = ?", queryParams.MjID)
+	}
+	if queryParams.StartTimestamp != "" {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != "" {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	_ = query.Count(&total).Error
+	return total
+}

+ 27 - 11
model/option.go

@@ -5,6 +5,7 @@ import (
 	"one-api/setting"
 	"one-api/setting"
 	"one-api/setting/config"
 	"one-api/setting/config"
 	"one-api/setting/operation_setting"
 	"one-api/setting/operation_setting"
+	"one-api/setting/ratio_setting"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -76,6 +77,9 @@ func InitOptionMap() {
 	common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
 	common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
 	common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
 	common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
 	common.OptionMap["Chats"] = setting.Chats2JsonString()
 	common.OptionMap["Chats"] = setting.Chats2JsonString()
+	common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+	common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
+	common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
 	common.OptionMap["GitHubClientId"] = ""
 	common.OptionMap["GitHubClientId"] = ""
 	common.OptionMap["GitHubClientSecret"] = ""
 	common.OptionMap["GitHubClientSecret"] = ""
 	common.OptionMap["TelegramBotToken"] = ""
 	common.OptionMap["TelegramBotToken"] = ""
@@ -94,12 +98,13 @@ func InitOptionMap() {
 	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
 	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
 	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
 	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
 	common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
 	common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
-	common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
-	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
-	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
-	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+	common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
+	common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
+	common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
+	common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
+	common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
 	common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
 	common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
-	common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
+	common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	//common.OptionMap["ChatLink"] = common.ChatLink
 	//common.OptionMap["ChatLink"] = common.ChatLink
 	//common.OptionMap["ChatLink2"] = common.ChatLink2
 	//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -122,6 +127,7 @@ func InitOptionMap() {
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
 	common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+	common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
 
 
 	// 自动添加所有注册的模型配置
 	// 自动添加所有注册的模型配置
 	modelConfigs := config.GlobalConfig.ExportAllConfigs()
 	modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -191,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
 			common.ImageDownloadPermission = intValue
 			common.ImageDownloadPermission = intValue
 		}
 		}
 	}
 	}
-	if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+	if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
 		boolValue := value == "true"
 		boolValue := value == "true"
 		switch key {
 		switch key {
 		case "PasswordRegisterEnabled":
 		case "PasswordRegisterEnabled":
@@ -260,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
 			common.SMTPSSLEnabled = boolValue
 			common.SMTPSSLEnabled = boolValue
 		case "WorkerAllowHttpImageRequestEnabled":
 		case "WorkerAllowHttpImageRequestEnabled":
 			setting.WorkerAllowHttpImageRequestEnabled = boolValue
 			setting.WorkerAllowHttpImageRequestEnabled = boolValue
+		case "DefaultUseAutoGroup":
+			setting.DefaultUseAutoGroup = boolValue
+		case "ExposeRatioEnabled":
+			ratio_setting.SetExposeRatioEnabled(boolValue)
 		}
 		}
 	}
 	}
 	switch key {
 	switch key {
@@ -286,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.PayAddress = value
 		setting.PayAddress = value
 	case "Chats":
 	case "Chats":
 		err = setting.UpdateChatsByJsonString(value)
 		err = setting.UpdateChatsByJsonString(value)
+	case "AutoGroups":
+		err = setting.UpdateAutoGroupsByJsonString(value)
 	case "CustomCallbackAddress":
 	case "CustomCallbackAddress":
 		setting.CustomCallbackAddress = value
 		setting.CustomCallbackAddress = value
 	case "EpayId":
 	case "EpayId":
@@ -351,17 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
 	case "DataExportDefaultTime":
 	case "DataExportDefaultTime":
 		common.DataExportDefaultTime = value
 		common.DataExportDefaultTime = value
 	case "ModelRatio":
 	case "ModelRatio":
-		err = operation_setting.UpdateModelRatioByJSONString(value)
+		err = ratio_setting.UpdateModelRatioByJSONString(value)
 	case "GroupRatio":
 	case "GroupRatio":
-		err = setting.UpdateGroupRatioByJSONString(value)
+		err = ratio_setting.UpdateGroupRatioByJSONString(value)
+	case "GroupGroupRatio":
+		err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
 	case "UserUsableGroups":
 	case "UserUsableGroups":
 		err = setting.UpdateUserUsableGroupsByJSONString(value)
 		err = setting.UpdateUserUsableGroupsByJSONString(value)
 	case "CompletionRatio":
 	case "CompletionRatio":
-		err = operation_setting.UpdateCompletionRatioByJSONString(value)
+		err = ratio_setting.UpdateCompletionRatioByJSONString(value)
 	case "ModelPrice":
 	case "ModelPrice":
-		err = operation_setting.UpdateModelPriceByJSONString(value)
+		err = ratio_setting.UpdateModelPriceByJSONString(value)
 	case "CacheRatio":
 	case "CacheRatio":
-		err = operation_setting.UpdateCacheRatioByJSONString(value)
+		err = ratio_setting.UpdateCacheRatioByJSONString(value)
 	case "TopUpLink":
 	case "TopUpLink":
 		common.TopUpLink = value
 		common.TopUpLink = value
 	//case "ChatLink":
 	//case "ChatLink":
@@ -378,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
 		operation_setting.AutomaticDisableKeywordsFromString(value)
 		operation_setting.AutomaticDisableKeywordsFromString(value)
 	case "StreamCacheQueueLength":
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+	case "PayMethods":
+		err = setting.UpdatePayMethodsByJsonString(value)
 	}
 	}
 	return err
 	return err
 }
 }

+ 4 - 4
model/pricing.go

@@ -2,7 +2,7 @@ package model
 
 
 import (
 import (
 	"one-api/common"
 	"one-api/common"
-	"one-api/setting/operation_setting"
+	"one-api/setting/ratio_setting"
 	"sync"
 	"sync"
 	"time"
 	"time"
 )
 )
@@ -65,14 +65,14 @@ func updatePricing() {
 			ModelName:   model,
 			ModelName:   model,
 			EnableGroup: groups,
 			EnableGroup: groups,
 		}
 		}
-		modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
+		modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
 		if findPrice {
 		if findPrice {
 			pricing.ModelPrice = modelPrice
 			pricing.ModelPrice = modelPrice
 			pricing.QuotaType = 1
 			pricing.QuotaType = 1
 		} else {
 		} else {
-			modelRatio, _ := operation_setting.GetModelRatio(model)
+			modelRatio, _ := ratio_setting.GetModelRatio(model)
 			pricing.ModelRatio = modelRatio
 			pricing.ModelRatio = modelRatio
-			pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
+			pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
 			pricing.QuotaType = 0
 			pricing.QuotaType = 0
 		}
 		}
 		pricingMap = append(pricingMap, pricing)
 		pricingMap = append(pricingMap, pricing)

+ 11 - 1
model/redemption.go

@@ -21,6 +21,7 @@ type Redemption struct {
 	Count        int            `json:"count" gorm:"-:all"` // only for api request
 	Count        int            `json:"count" gorm:"-:all"` // only for api request
 	UsedUserId   int            `json:"used_user_id"`
 	UsedUserId   int            `json:"used_user_id"`
 	DeletedAt    gorm.DeletedAt `gorm:"index"`
 	DeletedAt    gorm.DeletedAt `gorm:"index"`
+	ExpiredTime  int64          `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
 }
 }
 
 
 func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
 func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
@@ -131,6 +132,9 @@ func Redeem(key string, userId int) (quota int, err error) {
 		if redemption.Status != common.RedemptionCodeStatusEnabled {
 		if redemption.Status != common.RedemptionCodeStatusEnabled {
 			return errors.New("该兑换码已被使用")
 			return errors.New("该兑换码已被使用")
 		}
 		}
+		if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
+			return errors.New("该兑换码已过期")
+		}
 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 		if err != nil {
 		if err != nil {
 			return err
 			return err
@@ -162,7 +166,7 @@ func (redemption *Redemption) SelectUpdate() error {
 // Update Make sure your token's fields is completed, because this will update non-zero values
 // Update Make sure your token's fields is completed, because this will update non-zero values
 func (redemption *Redemption) Update() error {
 func (redemption *Redemption) Update() error {
 	var err error
 	var err error
-	err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
+	err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
 	return err
 	return err
 }
 }
 
 
@@ -183,3 +187,9 @@ func DeleteRedemptionById(id int) (err error) {
 	}
 	}
 	return redemption.Delete()
 	return redemption.Delete()
 }
 }
+
+func DeleteInvalidRedemptions() (int64, error) {
+	now := common.GetTimestamp()
+	result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
+	return result.RowsAffected, result.Error
+}

+ 61 - 0
model/task.go

@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
 	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
 	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
 	return stat, err
 	return stat, err
 }
 }
+
+// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
+func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
+	var total int64
+	query := DB.Model(&Task{})
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.UserID != "" {
+		query = query.Where("user_id = ?", queryParams.UserID)
+	}
+	if len(queryParams.UserIDs) != 0 {
+		query = query.Where("user_id in (?)", queryParams.UserIDs)
+	}
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	_ = query.Count(&total).Error
+	return total
+}
+
+// TaskCountAllUserTask returns total tasks for given user
+func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
+	var total int64
+	query := DB.Model(&Task{}).Where("user_id = ?", userId)
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	_ = query.Count(&total).Error
+	return total
+}

+ 9 - 2
model/token.go

@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
 	if token != "" {
 	if token != "" {
 		token = strings.Trim(token, "sk-")
 		token = strings.Trim(token, "sk-")
 	}
 	}
-	err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+	err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
 	return tokens, err
 	return tokens, err
 }
 }
 
 
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
 		// Don't return error - fall through to DB
 		// Don't return error - fall through to DB
 	}
 	}
 	fromDB = true
 	fromDB = true
-	err = DB.Where(keyCol+" = ?", key).First(&token).Error
+	err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
 	return token, err
 	return token, err
 }
 }
 
 
@@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) {
 	).Error
 	).Error
 	return err
 	return err
 }
 }
+
+// CountUserTokens returns total number of tokens for the given user, used for pagination
+func CountUserTokens(userId int) (int64, error) {
+	var total int64
+	err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
+	return total, err
+}

+ 2 - 2
model/token_cache.go

@@ -10,7 +10,7 @@ import (
 func cacheSetToken(token Token) error {
 func cacheSetToken(token Token) error {
 	key := common.GenerateHMAC(token.Key)
 	key := common.GenerateHMAC(token.Key)
 	token.Clean()
 	token.Clean()
-	err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
+	err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -19,7 +19,7 @@ func cacheSetToken(token Token) error {
 
 
 func cacheDeleteToken(key string) error {
 func cacheDeleteToken(key string) error {
 	key = common.GenerateHMAC(key)
 	key = common.GenerateHMAC(key)
-	err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
+	err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 5 - 3
model/user.go

@@ -41,6 +41,7 @@ type User struct {
 	DeletedAt        gorm.DeletedAt `gorm:"index"`
 	DeletedAt        gorm.DeletedAt `gorm:"index"`
 	LinuxDOId        string         `json:"linux_do_id" gorm:"column:linux_do_id;index"`
 	LinuxDOId        string         `json:"linux_do_id" gorm:"column:linux_do_id;index"`
 	Setting          string         `json:"setting" gorm:"type:text;column:setting"`
 	Setting          string         `json:"setting" gorm:"type:text;column:setting"`
+	Remark           string         `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
 }
 }
 
 
 func (user *User) ToBaseUser() *UserBase {
 func (user *User) ToBaseUser() *UserBase {
@@ -175,7 +176,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
 		// 如果是数字,同时搜索ID和其他字段
 		// 如果是数字,同时搜索ID和其他字段
 		likeCondition = "id = ? OR " + likeCondition
 		likeCondition = "id = ? OR " + likeCondition
 		if group != "" {
 		if group != "" {
-			query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
 				keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 				keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 		} else {
 		} else {
 			query = query.Where(likeCondition,
 			query = query.Where(likeCondition,
@@ -184,7 +185,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
 	} else {
 	} else {
 		// 非数字关键字,只搜索字符串字段
 		// 非数字关键字,只搜索字符串字段
 		if group != "" {
 		if group != "" {
-			query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
 				"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 				"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 		} else {
 		} else {
 			query = query.Where(likeCondition,
 			query = query.Where(likeCondition,
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
 		"display_name": newUser.DisplayName,
 		"display_name": newUser.DisplayName,
 		"group":        newUser.Group,
 		"group":        newUser.Group,
 		"quota":        newUser.Quota,
 		"quota":        newUser.Quota,
+		"remark":       newUser.Remark,
 	}
 	}
 	if updatePassword {
 	if updatePassword {
 		updates["password"] = newUser.Password
 		updates["password"] = newUser.Password
@@ -615,7 +617,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 		// Don't return error - fall through to DB
 		// Don't return error - fall through to DB
 	}
 	}
 	fromDB = true
 	fromDB = true
-	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
+	err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}

+ 4 - 3
model/user_cache.go

@@ -3,11 +3,12 @@ package model
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"one-api/common"
 	"one-api/common"
 	"one-api/constant"
 	"one-api/constant"
 	"time"
 	"time"
 
 
+	"github.com/gin-gonic/gin"
+
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/bytedance/gopkg/util/gopool"
 )
 )
 
 
@@ -57,7 +58,7 @@ func invalidateUserCache(userId int) error {
 	if !common.RedisEnabled {
 	if !common.RedisEnabled {
 		return nil
 		return nil
 	}
 	}
-	return common.RedisHDelObj(getUserCacheKey(userId))
+	return common.RedisDelKey(getUserCacheKey(userId))
 }
 }
 
 
 // updateUserCache updates all user cache fields using hash
 // updateUserCache updates all user cache fields using hash
@@ -69,7 +70,7 @@ func updateUserCache(user User) error {
 	return common.RedisHSetObj(
 	return common.RedisHSetObj(
 		getUserCacheKey(user.Id),
 		getUserCacheKey(user.Id),
 		user.ToBaseUser(),
 		user.ToBaseUser(),
-		time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
+		time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
 	)
 	)
 }
 }
 
 

+ 19 - 2
model/utils.go

@@ -2,11 +2,12 @@ package model
 
 
 import (
 import (
 	"errors"
 	"errors"
-	"github.com/bytedance/gopkg/util/gopool"
-	"gorm.io/gorm"
 	"one-api/common"
 	"one-api/common"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"gorm.io/gorm"
 )
 )
 
 
 const (
 const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
 }
 }
 
 
 func batchUpdate() {
 func batchUpdate() {
+	// check if there's any data to update
+	hasData := false
+	for i := 0; i < BatchUpdateTypeCount; i++ {
+		batchUpdateLocks[i].Lock()
+		if len(batchUpdateStores[i]) > 0 {
+			hasData = true
+			batchUpdateLocks[i].Unlock()
+			break
+		}
+		batchUpdateLocks[i].Unlock()
+	}
+
+	if !hasData {
+		return
+	}
+
 	common.SysLog("batch update started")
 	common.SysLog("batch update started")
 	for i := 0; i < BatchUpdateTypeCount; i++ {
 	for i := 0; i < BatchUpdateTypeCount; i++ {
 		batchUpdateLocks[i].Lock()
 		batchUpdateLocks[i].Lock()

+ 3 - 8
relay/relay-audio.go → relay/audio_handler.go

@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
 }
 }
 
 
 func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
+	relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 	audioRequest, err := getAndValidAudioRequest(c, relayInfo)
 
 
 	if err != nil {
 	if err != nil {
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	promptTokens := 0
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
-		}
+		promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		preConsumedTokens = promptTokens
 		preConsumedTokens = promptTokens
 		relayInfo.PromptTokens = promptTokens
 		relayInfo.PromptTokens = promptTokens
 	}
 	}
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		}
 		}
 	}()
 	}()
 
 
-	err = helper.ModelMappedHelper(c, relayInfo)
+	err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
 	}
 	}
 
 
-	audioRequest.Model = relayInfo.UpstreamModelName
-
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

+ 2 - 0
relay/channel/adapter.go

@@ -44,4 +44,6 @@ type TaskAdaptor interface {
 
 
 	// FetchTask
 	// FetchTask
 	FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
 	FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+
+	ParseResultUrl(resp map[string]any) (string, error)
 }
 }

+ 11 - 1
relay/channel/ali/adaptor.go

@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	switch info.RelayMode {
 	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 	case constant.RelayModeEmbeddings:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	case constant.RelayModeRerank:
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
 	case constant.RelayModeImagesGenerations:
 	case constant.RelayModeImagesGenerations:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
 	case constant.RelayModeCompletions:
 	case constant.RelayModeCompletions:
@@ -57,6 +59,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 	if request == nil {
 		return nil, errors.New("request is nil")
 		return nil, errors.New("request is nil")
 	}
 	}
+
+	// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
+	if !info.IsStream {
+		request.EnableThinking = false
+	}
+
 	switch info.RelayMode {
 	switch info.RelayMode {
 	default:
 	default:
 		aliReq := requestOpenAI2Ali(*request)
 		aliReq := requestOpenAI2Ali(*request)
@@ -70,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 }
 
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	return ConvertRerankRequest(request), nil
 }
 }
 
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -97,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = aliImageHandler(c, resp, info)
 		err, usage = aliImageHandler(c, resp, info)
 	case constant.RelayModeEmbeddings:
 	case constant.RelayModeEmbeddings:
 		err, usage = aliEmbeddingHandler(c, resp)
 		err, usage = aliEmbeddingHandler(c, resp)
+	case constant.RelayModeRerank:
+		err, usage = RerankHandler(c, resp, info)
 	default:
 	default:
 		if info.IsStream {
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 			err, usage = openai.OaiStreamHandler(c, resp, info)

+ 1 - 0
relay/channel/ali/constants.go

@@ -8,6 +8,7 @@ var ModelList = []string{
 	"qwq-32b",
 	"qwq-32b",
 	"qwen3-235b-a22b",
 	"qwen3-235b-a22b",
 	"text-embedding-v1",
 	"text-embedding-v1",
+	"gte-rerank-v2",
 }
 }
 
 
 var ChannelName = "ali"
 var ChannelName = "ali"

+ 27 - 0
relay/channel/ali/dto.go

@@ -1,5 +1,7 @@
 package ali
 package ali
 
 
+import "one-api/dto"
+
 type AliMessage struct {
 type AliMessage struct {
 	Content string `json:"content"`
 	Content string `json:"content"`
 	Role    string `json:"role"`
 	Role    string `json:"role"`
@@ -97,3 +99,28 @@ type AliImageRequest struct {
 	} `json:"parameters,omitempty"`
 	} `json:"parameters,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 }
 }
+
+type AliRerankParameters struct {
+	TopN            *int  `json:"top_n,omitempty"`
+	ReturnDocuments *bool `json:"return_documents,omitempty"`
+}
+
+type AliRerankInput struct {
+	Query     string `json:"query"`
+	Documents []any  `json:"documents"`
+}
+
+type AliRerankRequest struct {
+	Model      string              `json:"model"`
+	Input      AliRerankInput      `json:"input"`
+	Parameters AliRerankParameters `json:"parameters,omitempty"`
+}
+
+type AliRerankResponse struct {
+	Output struct {
+		Results []dto.RerankResponseResult `json:"results"`
+	} `json:"output"`
+	Usage     AliUsage `json:"usage"`
+	RequestId string   `json:"request_id"`
+	AliError
+}

+ 83 - 0
relay/channel/ali/rerank.go

@@ -0,0 +1,83 @@
+package ali
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
+	returnDocuments := request.ReturnDocuments
+	if returnDocuments == nil {
+		t := true
+		returnDocuments = &t
+	}
+	return &AliRerankRequest{
+		Model: request.Model,
+		Input: AliRerankInput{
+			Query:     request.Query,
+			Documents: request.Documents,
+		},
+		Parameters: AliRerankParameters{
+			TopN:            &request.TopN,
+			ReturnDocuments: returnDocuments,
+		},
+	}
+}
+
+func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	var aliResponse AliRerankResponse
+	err = json.Unmarshal(responseBody, &aliResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	usage := dto.Usage{
+		PromptTokens:     aliResponse.Usage.TotalTokens,
+		CompletionTokens: 0,
+		TotalTokens:      aliResponse.Usage.TotalTokens,
+	}
+	rerankResponse := dto.RerankResponse{
+		Results: aliResponse.Output.Results,
+		Usage:   usage,
+	}
+
+	jsonResponse, err := json.Marshal(rerankResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &usage
+}

+ 10 - 9
relay/channel/ali/text.go

@@ -3,7 +3,6 @@ package ali
 import (
 import (
 	"bufio"
 	"bufio"
 	"encoding/json"
 	"encoding/json"
-	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -11,6 +10,8 @@ import (
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
 	"strings"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -27,9 +28,6 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
 }
 }
 
 
 func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
 func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
-	if request.Model == "" {
-		request.Model = "text-embedding-v1"
-	}
 	return &AliEmbeddingRequest{
 	return &AliEmbeddingRequest{
 		Model: request.Model,
 		Model: request.Model,
 		Input: struct {
 		Input: struct {
@@ -64,7 +62,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 		}, nil
 		}, nil
 	}
 	}
 
 
-	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	model := c.GetString("model")
+	if model == "" {
+		model = "text-embedding-v4"
+	}
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -75,11 +77,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 	return nil, &fullTextResponse.Usage
 	return nil, &fullTextResponse.Usage
 }
 }
 
 
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
 	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 	openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
 		Object: "list",
 		Object: "list",
 		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 		Data:   make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
-		Model:  "text-embedding-v1",
+		Model:  model,
 		Usage:  dto.Usage{TotalTokens: response.Usage.TotalTokens},
 		Usage:  dto.Usage{TotalTokens: response.Usage.TotalTokens},
 	}
 	}
 
 
@@ -94,12 +96,11 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
 }
 }
 
 
 func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
 func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
-	content, _ := json.Marshal(response.Output.Text)
 	choice := dto.OpenAITextResponseChoice{
 	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
 		Index: 0,
 		Message: dto.Message{
 		Message: dto.Message{
 			Role:    "assistant",
 			Role:    "assistant",
-			Content: content,
+			Content: response.Output.Text,
 		},
 		},
 		FinishReason: response.Output.FinishReason,
 		FinishReason: response.Output.FinishReason,
 	}
 	}

+ 115 - 48
relay/channel/api_request.go

@@ -104,6 +104,105 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return targetConn, nil
 	return targetConn, nil
 }
 }
 
 
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+	pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+	gopool.Go(func() {
+		defer func() {
+			// 增加panic恢复处理
+			if r := recover(); r != nil {
+				if common2.DebugEnabled {
+					println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
+				}
+			}
+			if common2.DebugEnabled {
+				println("SSE ping goroutine stopped.")
+			}
+		}()
+
+		if pingInterval <= 0 {
+			pingInterval = helper.DefaultPingInterval
+		}
+
+		ticker := time.NewTicker(pingInterval)
+		// 确保在任何情况下都清理ticker
+		defer func() {
+			ticker.Stop()
+			if common2.DebugEnabled {
+				println("SSE ping ticker stopped")
+			}
+		}()
+
+		var pingMutex sync.Mutex
+		if common2.DebugEnabled {
+			println("SSE ping goroutine started")
+		}
+
+		// 增加超时控制,防止goroutine长时间运行
+		maxPingDuration := 120 * time.Minute // 最大ping持续时间
+		pingTimeout := time.NewTimer(maxPingDuration)
+		defer pingTimeout.Stop()
+
+		for {
+			select {
+			// 发送 ping 数据
+			case <-ticker.C:
+				if err := sendPingData(c, &pingMutex); err != nil {
+					if common2.DebugEnabled {
+						println("SSE ping error, stopping goroutine:", err.Error())
+					}
+					return
+				}
+			// 收到退出信号
+			case <-pingerCtx.Done():
+				return
+			// request 结束
+			case <-c.Request.Context().Done():
+				return
+			// 超时保护,防止goroutine无限运行
+			case <-pingTimeout.C:
+				if common2.DebugEnabled {
+					println("SSE ping goroutine timeout, stopping")
+				}
+				return
+			}
+		}
+	})
+
+	return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+	// 增加超时控制,防止锁死等待
+	done := make(chan error, 1)
+	go func() {
+		mutex.Lock()
+		defer mutex.Unlock()
+
+		err := helper.PingData(c)
+		if err != nil {
+			common2.LogError(c, "SSE ping error: "+err.Error())
+			done <- err
+			return
+		}
+
+		if common2.DebugEnabled {
+			println("SSE ping data sent.")
+		}
+		done <- nil
+	}()
+
+	// 设置发送ping数据的超时时间
+	select {
+	case err := <-done:
+		return err
+	case <-time.After(10 * time.Second):
+		return errors.New("SSE ping data send timeout")
+	case <-c.Request.Context().Done():
+		return errors.New("request context cancelled during ping")
+	}
+}
+
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var client *http.Client
 	var err error
 	var err error
@@ -115,68 +214,36 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 	} else {
 		client = service.GetHttpClient()
 		client = service.GetHttpClient()
 	}
 	}
-	// 流式请求 ping 保活
-	var stopPinger func()
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	var pingerWg sync.WaitGroup
+
+	var stopPinger context.CancelFunc
 	if info.IsStream {
 	if info.IsStream {
 		helper.SetEventStreamHeaders(c)
 		helper.SetEventStreamHeaders(c)
-		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-		var pingerCtx context.Context
-		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
-
-		if pingEnabled {
-			pingerWg.Add(1)
-			gopool.Go(func() {
-				defer pingerWg.Done()
-				if pingInterval <= 0 {
-					pingInterval = helper.DefaultPingInterval
-				}
-
-				ticker := time.NewTicker(pingInterval)
-				defer ticker.Stop()
-				var pingMutex sync.Mutex
-				if common2.DebugEnabled {
-					println("SSE ping goroutine started")
-				}
-
-				for {
-					select {
-					case <-ticker.C:
-						pingMutex.Lock()
-						err2 := helper.PingData(c)
-						pingMutex.Unlock()
-						if err2 != nil {
-							common2.LogError(c, "SSE ping error: "+err.Error())
-							return
-						}
-						if common2.DebugEnabled {
-							println("SSE ping data sent.")
-						}
-					case <-pingerCtx.Done():
-						if common2.DebugEnabled {
-							println("SSE ping goroutine stopped.")
-						}
-						return
+		// 处理流式请求的 ping 保活
+		generalSettings := operation_setting.GetGeneralSetting()
+		if generalSettings.PingIntervalEnabled {
+			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+			stopPinger = startPingKeepAlive(c, pingInterval)
+			// 使用defer确保在任何情况下都能停止ping goroutine
+			defer func() {
+				if stopPinger != nil {
+					stopPinger()
+					if common2.DebugEnabled {
+						println("SSE ping goroutine stopped by defer")
 					}
 					}
 				}
 				}
-			})
+			}()
 		}
 		}
 	}
 	}
 
 
 	resp, err := client.Do(req)
 	resp, err := client.Do(req)
-	// request结束后停止ping
-	if info.IsStream && pingEnabled {
-		stopPinger()
-		pingerWg.Wait()
-	}
+
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	if resp == nil {
 	if resp == nil {
 		return nil, errors.New("resp is nil")
 		return nil, errors.New("resp is nil")
 	}
 	}
+
 	_ = req.Body.Close()
 	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil
 	return resp, nil

+ 12 - 0
relay/channel/aws/constants.go

@@ -11,6 +11,8 @@ var awsModelIDMap = map[string]string{
 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
 	"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
 	"claude-3-5-haiku-20241022":  "anthropic.claude-3-5-haiku-20241022-v1:0",
 	"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
 	"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
+	"claude-sonnet-4-20250514":   "anthropic.claude-sonnet-4-20250514-v1:0",
+	"claude-opus-4-20250514":     "anthropic.claude-opus-4-20250514-v1:0",
 }
 }
 
 
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -41,6 +43,16 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
 	},
 	},
 	"anthropic.claude-3-7-sonnet-20250219-v1:0": {
 	"anthropic.claude-3-7-sonnet-20250219-v1:0": {
 		"us": true,
 		"us": true,
+		"ap": true,
+		"eu": true,
+	},
+	"anthropic.claude-sonnet-4-20250514-v1:0": {
+		"us": true,
+		"ap": true,
+		"eu": true,
+	},
+	"anthropic.claude-opus-4-20250514-v1:0": {
+		"us": true,
 	},
 	},
 }
 }
 
 

+ 1 - 2
relay/channel/baidu/relay-baidu.go

@@ -53,12 +53,11 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 }
 }
 
 
 func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
 func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
-	content, _ := json.Marshal(response.Result)
 	choice := dto.OpenAITextResponseChoice{
 	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
 		Index: 0,
 		Message: dto.Message{
 		Message: dto.Message{
 			Role:    "assistant",
 			Role:    "assistant",
-			Content: content,
+			Content: response.Result,
 		},
 		},
 		FinishReason: "stop",
 		FinishReason: "stop",
 	}
 	}

+ 13 - 0
relay/channel/baidu_v2/adaptor.go

@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"strings"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 	if request == nil {
 		return nil, errors.New("request is nil")
 		return nil, errors.New("request is nil")
 	}
 	}
+	if strings.HasSuffix(info.UpstreamModelName, "-search") {
+		info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+		request.Model = info.UpstreamModelName
+		toMap := request.ToMap()
+		toMap["web_search"] = map[string]any{
+			"enable":          true,
+			"enable_citation": true,
+			"enable_trace":    true,
+			"enable_status":   false,
+		}
+		return toMap, nil
+	}
 	return request, nil
 	return request, nil
 }
 }
 
 

+ 3 - 3
relay/channel/claude/adaptor.go

@@ -38,10 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 }
 
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
-		a.RequestMode = RequestModeMessage
-	} else {
+	if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
 		a.RequestMode = RequestModeCompletion
 		a.RequestMode = RequestModeCompletion
+	} else {
+		a.RequestMode = RequestModeMessage
 	}
 	}
 }
 }
 
 

+ 4 - 0
relay/channel/claude/constants.go

@@ -13,6 +13,10 @@ var ModelList = []string{
 	"claude-3-5-sonnet-20241022",
 	"claude-3-5-sonnet-20241022",
 	"claude-3-7-sonnet-20250219",
 	"claude-3-7-sonnet-20250219",
 	"claude-3-7-sonnet-20250219-thinking",
 	"claude-3-7-sonnet-20250219-thinking",
+	"claude-sonnet-4-20250514",
+	"claude-sonnet-4-20250514-thinking",
+	"claude-opus-4-20250514",
+	"claude-opus-4-20250514-thinking",
 }
 }
 
 
 var ChannelName = "claude"
 var ChannelName = "claude"

+ 41 - 48
relay/channel/claude/relay-claude.go

@@ -48,9 +48,9 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
 	prompt := ""
 	prompt := ""
 	for _, message := range textRequest.Messages {
 	for _, message := range textRequest.Messages {
 		if message.Role == "user" {
 		if message.Role == "user" {
-			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
+			prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
 		} else if message.Role == "assistant" {
 		} else if message.Role == "assistant" {
-			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
+			prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
 		} else if message.Role == "system" {
 		} else if message.Role == "system" {
 			if prompt == "" {
 			if prompt == "" {
 				prompt = message.StringContent()
 				prompt = message.StringContent()
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		// BudgetTokens 为 max_tokens 的 80%
 		// BudgetTokens 为 max_tokens 的 80%
 		claudeRequest.Thinking = &dto.Thinking{
 		claudeRequest.Thinking = &dto.Thinking{
 			Type:         "enabled",
 			Type:         "enabled",
-			BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+			BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
 		}
 		}
 		// TODO: 临时处理
 		// TODO: 临时处理
 		// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
 		// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
@@ -155,15 +155,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		}
 		}
 		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
-				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
-				fmtMessage.Content = content
+				fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				// delete last message
 				// delete last message
 				formatMessages = formatMessages[:len(formatMessages)-1]
 				formatMessages = formatMessages[:len(formatMessages)-1]
 			}
 			}
 		}
 		}
 		if fmtMessage.Content == nil {
 		if fmtMessage.Content == nil {
-			content, _ := json.Marshal("...")
-			fmtMessage.Content = content
+			fmtMessage.SetStringContent("...")
 		}
 		}
 		formatMessages = append(formatMessages, fmtMessage)
 		formatMessages = append(formatMessages, fmtMessage)
 		lastMessage = fmtMessage
 		lastMessage = fmtMessage
@@ -397,12 +395,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
 	thinkingContent := ""
 	thinkingContent := ""
 
 
 	if reqMode == RequestModeCompletion {
 	if reqMode == RequestModeCompletion {
-		content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
 		choice := dto.OpenAITextResponseChoice{
 		choice := dto.OpenAITextResponseChoice{
 			Index: 0,
 			Index: 0,
 			Message: dto.Message{
 			Message: dto.Message{
 				Role:    "assistant",
 				Role:    "assistant",
-				Content: content,
+				Content: strings.TrimPrefix(claudeResponse.Completion, " "),
 				Name:    nil,
 				Name:    nil,
 			},
 			},
 			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 			FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
@@ -457,6 +454,7 @@ type ClaudeResponseInfo struct {
 	Model        string
 	Model        string
 	ResponseText strings.Builder
 	ResponseText strings.Builder
 	Usage        *dto.Usage
 	Usage        *dto.Usage
+	Done         bool
 }
 }
 
 
 func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
 func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -464,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 		claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 		claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 	} else {
 	} else {
 		if claudeResponse.Type == "message_start" {
 		if claudeResponse.Type == "message_start" {
-			// message_start, 获取usage
 			claudeInfo.ResponseId = claudeResponse.Message.Id
 			claudeInfo.ResponseId = claudeResponse.Message.Id
 			claudeInfo.Model = claudeResponse.Message.Model
 			claudeInfo.Model = claudeResponse.Message.Model
+
+			// message_start, 获取usage
 			claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
 			claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+			claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+			claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
 		} else if claudeResponse.Type == "content_block_delta" {
 		} else if claudeResponse.Type == "content_block_delta" {
 			if claudeResponse.Delta.Text != nil {
 			if claudeResponse.Delta.Text != nil {
 				claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
 				claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
 			}
 			}
+			if claudeResponse.Delta.Thinking != "" {
+				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+			}
 		} else if claudeResponse.Type == "message_delta" {
 		} else if claudeResponse.Type == "message_delta" {
-			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			// 最终的usage获取
 			if claudeResponse.Usage.InputTokens > 0 {
 			if claudeResponse.Usage.InputTokens > 0 {
+				// 不叠加,只取最新的
 				claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
 				claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
 			}
 			}
-			claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
+			claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+			claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+			// 判断是否完整
+			claudeInfo.Done = true
 		} else if claudeResponse.Type == "content_block_start" {
 		} else if claudeResponse.Type == "content_block_start" {
 		} else {
 		} else {
 			return false
 			return false
@@ -509,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 		}
 	}
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
 		if requestMode == RequestModeCompletion {
 		if requestMode == RequestModeCompletion {
-			claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 		} else {
 		} else {
 			if claudeResponse.Type == "message_start" {
 			if claudeResponse.Type == "message_start" {
 				// message_start, 获取usage
 				// message_start, 获取usage
 				info.UpstreamModelName = claudeResponse.Message.Model
 				info.UpstreamModelName = claudeResponse.Message.Model
-				claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
-				claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
-				claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
-				claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
 			} else if claudeResponse.Type == "content_block_delta" {
 			} else if claudeResponse.Type == "content_block_delta" {
-				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
 			} else if claudeResponse.Type == "message_delta" {
 			} else if claudeResponse.Type == "message_delta" {
-				if claudeResponse.Usage.InputTokens > 0 {
-					// 不叠加,只取最新的
-					claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
-				}
-				claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
-				claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
 			}
 			}
 		}
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
 		helper.ClaudeChunkData(c, claudeResponse, data)
@@ -547,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 }
 }
 
 
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
-	if info.RelayFormat == relaycommon.RelayFormatClaude {
-		if requestMode == RequestModeCompletion {
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			// 说明流模式建立失败,可能为官方出错
-			if claudeInfo.Usage.PromptTokens == 0 {
-				//usage.PromptTokens = info.PromptTokens
-			}
-			if claudeInfo.Usage.CompletionTokens == 0 {
-				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
-			}
+
+	if requestMode == RequestModeCompletion {
+		claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+	} else {
+		if claudeInfo.Usage.PromptTokens == 0 {
+			//上游出错
 		}
 		}
-	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
-		if requestMode == RequestModeCompletion {
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
-		} else {
-			if claudeInfo.Usage.PromptTokens == 0 {
-				//上游出错
-			}
-			if claudeInfo.Usage.CompletionTokens == 0 {
-				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+		if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+			if common.DebugEnabled {
+				common.SysError("claude response usage is not complete, maybe upstream error")
 			}
 			}
+			claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 		}
 		}
+	}
+
+	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		//
+	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+
 		if info.ShouldIncludeUsage {
 		if info.ShouldIncludeUsage {
 			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
 			err := helper.ObjectData(c, response)
 			err := helper.ObjectData(c, response)
@@ -622,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		}
 		}
 	}
 	}
 	if requestMode == RequestModeCompletion {
 	if requestMode == RequestModeCompletion {
-		completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
-		if err != nil {
-			return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
-		}
+		completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 		claudeInfo.Usage.PromptTokens = info.PromptTokens
 		claudeInfo.Usage.PromptTokens = info.PromptTokens
 		claudeInfo.Usage.CompletionTokens = completionTokens
 		claudeInfo.Usage.CompletionTokens = completionTokens
 		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
 		claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens

+ 3 - 3
relay/channel/cloudflare/relay_cloudflare.go

@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	if err := scanner.Err(); err != nil {
 	if err := scanner.Err(); err != nil {
 		common.LogError(c, "error_scanning_stream_response: "+err.Error())
 		common.LogError(c, "error_scanning_stream_response: "+err.Error())
 	}
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	if info.ShouldIncludeUsage {
 	if info.ShouldIncludeUsage {
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
 		err := helper.ObjectData(c, response)
 		err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
 	for _, choice := range response.Choices {
 	for _, choice := range response.Choices {
 		responseText += choice.Message.StringContent()
 		responseText += choice.Message.StringContent()
 	}
 	}
-	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	response.Usage = *usage
 	response.Usage = *usage
 	response.Id = helper.GetResponseID(c)
 	response.Id = helper.GetResponseID(c)
 	jsonResponse, err := json.Marshal(response)
 	jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 
 	return nil, usage
 	return nil, usage

+ 3 - 5
relay/channel/cohere/relay-cohere.go

@@ -3,7 +3,6 @@ package cohere
 import (
 import (
 	"bufio"
 	"bufio"
 	"encoding/json"
 	"encoding/json"
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
 }
 }
 
 
 func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	responseId := helper.GetResponseID(c)
 	createdTime := common.GetTimestamp()
 	createdTime := common.GetTimestamp()
 	usage := &dto.Usage{}
 	usage := &dto.Usage{}
 	responseText := ""
 	responseText := ""
@@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}
 		}
 	})
 	})
 	if usage.PromptTokens == 0 {
 	if usage.PromptTokens == 0 {
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	}
 	return nil, usage
 	return nil, usage
 }
 }
@@ -195,11 +194,10 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
 	openaiResp.Model = modelName
 	openaiResp.Model = modelName
 	openaiResp.Usage = usage
 	openaiResp.Usage = usage
 
 
-	content, _ := json.Marshal(cohereResp.Text)
 	openaiResp.Choices = []dto.OpenAITextResponseChoice{
 	openaiResp.Choices = []dto.OpenAITextResponseChoice{
 		{
 		{
 			Index:        0,
 			Index:        0,
-			Message:      dto.Message{Content: content, Role: "assistant"},
+			Message:      dto.Message{Content: cohereResp.Text, Role: "assistant"},
 			FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
 			FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
 		},
 		},
 	}
 	}

+ 1 - 1
relay/channel/coze/dto.go

@@ -10,7 +10,7 @@ type CozeError struct {
 type CozeEnterMessage struct {
 type CozeEnterMessage struct {
 	Role        string          `json:"role"`
 	Role        string          `json:"role"`
 	Type        string          `json:"type,omitempty"`
 	Type        string          `json:"type,omitempty"`
-	Content     json.RawMessage `json:"content,omitempty"`
+	Content     any             `json:"content,omitempty"`
 	MetaData    json.RawMessage `json:"meta_data,omitempty"`
 	MetaData    json.RawMessage `json:"meta_data,omitempty"`
 	ContentType string          `json:"content_type,omitempty"`
 	ContentType string          `json:"content_type,omitempty"`
 }
 }

+ 5 - 7
relay/channel/coze/relay-coze.go

@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 
 	var currentEvent string
 	var currentEvent string
 	var currentData string
 	var currentData string
-	var usage dto.Usage
+	var usage = &dto.Usage{}
 
 
 	for scanner.Scan() {
 	for scanner.Scan() {
 		line := scanner.Text()
 		line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 		if line == "" {
 		if line == "" {
 			if currentEvent != "" && currentData != "" {
 			if currentEvent != "" && currentData != "" {
 				// handle last event
 				// handle last event
-				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
 				currentEvent = ""
 				currentEvent = ""
 				currentData = ""
 				currentData = ""
 			}
 			}
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 
 
 	// Last event
 	// Last event
 	if currentEvent != "" && currentData != "" {
 	if currentEvent != "" && currentData != "" {
-		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+		handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
 	}
 	}
 
 
 	if err := scanner.Err(); err != nil {
 	if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 	helper.Done(c)
 	helper.Done(c)
 
 
 	if usage.TotalTokens == 0 {
 	if usage.TotalTokens == 0 {
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
 	}
 	}
 
 
-	return nil, &usage
+	return nil, usage
 }
 }
 
 
 func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
 func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

+ 2 - 10
relay/channel/dify/relay-dify.go

@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 		return true
 		return true
 	})
 	})
 	helper.Done(c)
 	helper.Done(c)
-	err := resp.Body.Close()
-	if err != nil {
-		// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-		common.SysError("close_response_body_failed: " + err.Error())
-	}
 	if usage.TotalTokens == 0 {
 	if usage.TotalTokens == 0 {
-		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	}
 	usage.CompletionTokens += nodeToken
 	usage.CompletionTokens += nodeToken
 	return nil, usage
 	return nil, usage
@@ -278,12 +271,11 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
 		Created: common.GetTimestamp(),
 		Created: common.GetTimestamp(),
 		Usage:   difyResponse.MetaData.Usage,
 		Usage:   difyResponse.MetaData.Usage,
 	}
 	}
-	content, _ := json.Marshal(difyResponse.Answer)
 	choice := dto.OpenAITextResponseChoice{
 	choice := dto.OpenAITextResponseChoice{
 		Index: 0,
 		Index: 0,
 		Message: dto.Message{
 		Message: dto.Message{
 			Role:    "assistant",
 			Role:    "assistant",
-			Content: content,
+			Content: difyResponse.Answer,
 		},
 		},
 		FinishReason: "stop",
 		FinishReason: "stop",
 	}
 	}

+ 15 - 3
relay/channel/gemini/adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"strings"
 	"strings"
@@ -71,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		// suffix -thinking and -nothinking
-		if strings.HasSuffix(info.OriginModelName, "-thinking") {
+		// 新增逻辑:处理 -thinking-<budget> 格式
+		if strings.Contains(info.UpstreamModelName, "-thinking-") {
+			parts := strings.Split(info.UpstreamModelName, "-thinking-")
+			info.UpstreamModelName = parts[0]
+		} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
-		} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+		} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 			info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 		}
 		}
 	}
 	}
@@ -165,6 +169,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 }
 
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeGemini {
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
+	}
+
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 		return GeminiImageHandler(c, resp, info)
 		return GeminiImageHandler(c, resp, info)
 	}
 	}

+ 70 - 14
relay/channel/gemini/dto.go

@@ -1,11 +1,13 @@
 package gemini
 package gemini
 
 
+import "encoding/json"
+
 type GeminiChatRequest struct {
 type GeminiChatRequest struct {
 	Contents           []GeminiChatContent        `json:"contents"`
 	Contents           []GeminiChatContent        `json:"contents"`
-	SafetySettings     []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
-	GenerationConfig   GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+	SafetySettings     []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
+	GenerationConfig   GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
 	Tools              []GeminiChatTool           `json:"tools,omitempty"`
 	Tools              []GeminiChatTool           `json:"tools,omitempty"`
-	SystemInstructions *GeminiChatContent         `json:"system_instruction,omitempty"`
+	SystemInstructions *GeminiChatContent         `json:"systemInstruction,omitempty"`
 }
 }
 
 
 type GeminiThinkingConfig struct {
 type GeminiThinkingConfig struct {
@@ -22,19 +24,38 @@ type GeminiInlineData struct {
 	Data     string `json:"data"`
 	Data     string `json:"data"`
 }
 }
 
 
+// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
+func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
+	type Alias GeminiInlineData // Use type alias to avoid recursion
+	var aux struct {
+		Alias
+		MimeTypeSnake string `json:"mime_type"`
+	}
+
+	if err := json.Unmarshal(data, &aux); err != nil {
+		return err
+	}
+
+	*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
+
+	// Prioritize snake_case if present
+	if aux.MimeTypeSnake != "" {
+		g.MimeType = aux.MimeTypeSnake
+	} else if aux.MimeType != "" { // Fallback to camelCase from Alias
+		g.MimeType = aux.MimeType
+	}
+	// g.Data would be populated by aux.Alias.Data
+	return nil
+}
+
 type FunctionCall struct {
 type FunctionCall struct {
 	FunctionName string `json:"name"`
 	FunctionName string `json:"name"`
 	Arguments    any    `json:"args"`
 	Arguments    any    `json:"args"`
 }
 }
 
 
-type GeminiFunctionResponseContent struct {
-	Name    string `json:"name"`
-	Content any    `json:"content"`
-}
-
 type FunctionResponse struct {
 type FunctionResponse struct {
-	Name     string                        `json:"name"`
-	Response GeminiFunctionResponseContent `json:"response"`
+	Name     string                 `json:"name"`
+	Response map[string]interface{} `json:"response"`
 }
 }
 
 
 type GeminiPartExecutableCode struct {
 type GeminiPartExecutableCode struct {
@@ -54,6 +75,7 @@ type GeminiFileData struct {
 
 
 type GeminiPart struct {
 type GeminiPart struct {
 	Text                string                         `json:"text,omitempty"`
 	Text                string                         `json:"text,omitempty"`
+	Thought             bool                           `json:"thought,omitempty"`
 	InlineData          *GeminiInlineData              `json:"inlineData,omitempty"`
 	InlineData          *GeminiInlineData              `json:"inlineData,omitempty"`
 	FunctionCall        *FunctionCall                  `json:"functionCall,omitempty"`
 	FunctionCall        *FunctionCall                  `json:"functionCall,omitempty"`
 	FunctionResponse    *FunctionResponse              `json:"functionResponse,omitempty"`
 	FunctionResponse    *FunctionResponse              `json:"functionResponse,omitempty"`
@@ -62,6 +84,33 @@ type GeminiPart struct {
 	CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
 	CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
 }
 }
 
 
+// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
+func (p *GeminiPart) UnmarshalJSON(data []byte) error {
+	// Alias to avoid recursion during unmarshalling
+	type Alias GeminiPart
+	var aux struct {
+		Alias
+		InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
+	}
+
+	if err := json.Unmarshal(data, &aux); err != nil {
+		return err
+	}
+
+	// Assign fields from alias
+	*p = GeminiPart(aux.Alias)
+
+	// Prioritize snake_case for InlineData if present
+	if aux.InlineDataSnake != nil {
+		p.InlineData = aux.InlineDataSnake
+	} else if aux.InlineData != nil { // Fallback to camelCase from Alias
+		p.InlineData = aux.InlineData
+	}
+	// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
+
+	return nil
+}
+
 type GeminiChatContent struct {
 type GeminiChatContent struct {
 	Role  string       `json:"role,omitempty"`
 	Role  string       `json:"role,omitempty"`
 	Parts []GeminiPart `json:"parts"`
 	Parts []GeminiPart `json:"parts"`
@@ -91,6 +140,7 @@ type GeminiChatGenerationConfig struct {
 	Seed               int64                 `json:"seed,omitempty"`
 	Seed               int64                 `json:"seed,omitempty"`
 	ResponseModalities []string              `json:"responseModalities,omitempty"`
 	ResponseModalities []string              `json:"responseModalities,omitempty"`
 	ThinkingConfig     *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
 	ThinkingConfig     *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+	SpeechConfig       json.RawMessage       `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
 }
 }
 
 
 type GeminiChatCandidate struct {
 type GeminiChatCandidate struct {
@@ -116,10 +166,16 @@ type GeminiChatResponse struct {
 }
 }
 
 
 type GeminiUsageMetadata struct {
 type GeminiUsageMetadata struct {
-	PromptTokenCount     int `json:"promptTokenCount"`
-	CandidatesTokenCount int `json:"candidatesTokenCount"`
-	TotalTokenCount      int `json:"totalTokenCount"`
-	ThoughtsTokenCount   int `json:"thoughtsTokenCount"`
+	PromptTokenCount     int                         `json:"promptTokenCount"`
+	CandidatesTokenCount int                         `json:"candidatesTokenCount"`
+	TotalTokenCount      int                         `json:"totalTokenCount"`
+	ThoughtsTokenCount   int                         `json:"thoughtsTokenCount"`
+	PromptTokensDetails  []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+}
+
+type GeminiPromptTokensDetails struct {
+	Modality   string `json:"modality"`
+	TokenCount int    `json:"tokenCount"`
 }
 }
 
 
 // Imagen related structs
 // Imagen related structs

+ 146 - 0
relay/channel/gemini/relay-gemini-native.go

@@ -0,0 +1,146 @@
+package gemini
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	// 读取响应体
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	// 解析为 Gemini 原生响应格式
+	var geminiResponse GeminiChatResponse
+	err = common.DecodeJson(responseBody, &geminiResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// 计算使用量(基于 UsageMetadata)
+	usage := dto.Usage{
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
+	}
+
+	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+
+	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens = detail.TokenCount
+		}
+	}
+
+	// 直接返回 Gemini 原生格式的 JSON 响应
+	jsonResponse, err := json.Marshal(geminiResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// 设置响应头并写入响应
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
+	}
+
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	responseText := strings.Builder{}
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// 统计图片数量
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+				if part.Text != "" {
+					responseText.WriteString(part.Text)
+				}
+			}
+		}
+
+		// 更新使用量统计
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+				if detail.Modality == "AUDIO" {
+					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+				} else if detail.Modality == "TEXT" {
+					usage.PromptTokensDetails.TextTokens = detail.TokenCount
+				}
+			}
+		}
+
+		// 直接发送 GeminiChatResponse 响应
+		err = helper.StringData(c, data)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
+	if usage.CompletionTokens == 0 {
+		str := responseText.String()
+		if len(str) > 0 {
+			usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			// 空补全,不需要使用量
+			usage = &dto.Usage{}
+		}
+	}
+
+	// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
+	//helper.Done(c)
+
+	return usage, nil
+}

+ 270 - 109
relay/channel/gemini/relay-gemini.go

@@ -12,12 +12,72 @@ import (
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
+	"strconv"
 	"strings"
 	"strings"
 	"unicode/utf8"
 	"unicode/utf8"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 )
 )
 
 
+var geminiSupportedMimeTypes = map[string]bool{
+	"application/pdf": true,
+	"audio/mpeg":      true,
+	"audio/mp3":       true,
+	"audio/wav":       true,
+	"image/png":       true,
+	"image/jpeg":      true,
+	"text/plain":      true,
+	"video/mov":       true,
+	"video/mpeg":      true,
+	"video/mp4":       true,
+	"video/mpg":       true,
+	"video/avi":       true,
+	"video/wmv":       true,
+	"video/mpegps":    true,
+	"video/flv":       true,
+}
+
+// Gemini 允许的思考预算范围
+const (
+	pro25MinBudget       = 128
+	pro25MaxBudget       = 32768
+	flash25MaxBudget     = 24576
+	flash25LiteMinBudget = 512
+	flash25LiteMaxBudget = 24576
+)
+
+// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
+func clampThinkingBudget(modelName string, budget int) int {
+	isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+		!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+		!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+	is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+
+	if is25FlashLite {
+		if budget < flash25LiteMinBudget {
+			return flash25LiteMinBudget
+		}
+		if budget > flash25LiteMaxBudget {
+			return flash25LiteMaxBudget
+		}
+	} else if isNew25Pro {
+		if budget < pro25MinBudget {
+			return pro25MinBudget
+		}
+		if budget > pro25MaxBudget {
+			return pro25MaxBudget
+		}
+	} else { // 其他模型
+		if budget < 0 {
+			return 0
+		}
+		if budget > flash25MaxBudget {
+			return flash25MaxBudget
+		}
+	}
+	return budget
+}
+
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 // Setting safety to the lowest possible values since Gemini is already powerless enough
 func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
 func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
 
 
@@ -39,18 +99,54 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	}
 	}
 
 
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-		if strings.HasSuffix(info.OriginModelName, "-thinking") {
-			budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
-			if budgetTokens == 0 || budgetTokens > 24576 {
-				budgetTokens = 24576
+		modelName := info.UpstreamModelName
+		isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+			!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+
+		if strings.Contains(modelName, "-thinking-") {
+			parts := strings.SplitN(modelName, "-thinking-", 2)
+			if len(parts) == 2 && parts[1] != "" {
+				if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
+					clampedBudget := clampThinkingBudget(modelName, budgetTokens)
+					geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+						ThinkingBudget:  common.GetPointer(clampedBudget),
+						IncludeThoughts: true,
+					}
+				}
 			}
 			}
-			geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
-				ThinkingBudget:  common.GetPointer(int(budgetTokens)),
-				IncludeThoughts: true,
+		} else if strings.HasSuffix(modelName, "-thinking") {
+			unsupportedModels := []string{
+				"gemini-2.5-pro-preview-05-06",
+				"gemini-2.5-pro-preview-03-25",
 			}
 			}
-		} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
-			geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
-				ThinkingBudget: common.GetPointer(0),
+			isUnsupported := false
+			for _, unsupportedModel := range unsupportedModels {
+				if strings.HasPrefix(modelName, unsupportedModel) {
+					isUnsupported = true
+					break
+				}
+			}
+
+			if isUnsupported {
+				geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+					IncludeThoughts: true,
+				}
+			} else {
+				geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+					IncludeThoughts: true,
+				}
+				if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
+					budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
+					clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
+					geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
+				}
+			}
+		} else if strings.HasSuffix(modelName, "-nothinking") {
+			if !isNew25Pro {
+				geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+					ThinkingBudget: common.GetPointer(0),
+				}
 			}
 			}
 		}
 		}
 	}
 	}
@@ -112,12 +208,6 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 		// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
 		// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
 		// json_data, _ := json.Marshal(geminiRequest.Tools)
 		// json_data, _ := json.Marshal(geminiRequest.Tools)
 		// common.SysLog("tools_json: " + string(json_data))
 		// common.SysLog("tools_json: " + string(json_data))
-	} else if textRequest.Functions != nil {
-		//geminiRequest.Tools = []GeminiChatTool{
-		//	{
-		//		FunctionDeclarations: textRequest.Functions,
-		//	},
-		//}
 	}
 	}
 
 
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
 	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
@@ -148,17 +238,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
 			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
 				name = val
 				name = val
 			}
 			}
-			content := common.StrToMap(message.StringContent())
-			functionResp := &FunctionResponse{
-				Name: name,
-				Response: GeminiFunctionResponseContent{
-					Name:    name,
-					Content: content,
-				},
+			var contentMap map[string]interface{}
+			contentStr := message.StringContent()
+
+			// 1. 尝试解析为 JSON 对象
+			if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
+				// 2. 如果失败,尝试解析为 JSON 数组
+				var contentSlice []interface{}
+				if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
+					// 如果是数组,包装成对象
+					contentMap = map[string]interface{}{"result": contentSlice}
+				} else {
+					// 3. 如果再次失败,作为纯文本处理
+					contentMap = map[string]interface{}{"content": contentStr}
+				}
 			}
 			}
-			if content == nil {
-				functionResp.Response.Content = message.StringContent()
+
+			functionResp := &FunctionResponse{
+				Name:     name,
+				Response: contentMap,
 			}
 			}
+
 			*parts = append(*parts, GeminiPart{
 			*parts = append(*parts, GeminiPart{
 				FunctionResponse: functionResp,
 				FunctionResponse: functionResp,
 			})
 			})
@@ -208,14 +308,21 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 				}
 				}
 				// 判断是否是url
 				// 判断是否是url
 				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
 				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
-					// 是url,获取图片的类型和base64编码的数据
+					// 是url,获取文件的类型和base64编码的数据
 					fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
 					fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
 					if err != nil {
 					if err != nil {
-						return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+						return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
+					}
+
+					// 校验 MimeType 是否在 Gemini 支持的白名单中
+					if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
+						url := part.GetImageMedia().Url
+						return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
 					}
 					}
+
 					parts = append(parts, GeminiPart{
 					parts = append(parts, GeminiPart{
 						InlineData: &GeminiInlineData{
 						InlineData: &GeminiInlineData{
-							MimeType: fileData.MimeType,
+							MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
 							Data:     fileData.Base64Data,
 							Data:     fileData.Base64Data,
 						},
 						},
 					})
 					})
@@ -249,13 +356,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 				if part.GetInputAudio().Data == "" {
 				if part.GetInputAudio().Data == "" {
 					return nil, fmt.Errorf("only base64 audio is supported in gemini")
 					return nil, fmt.Errorf("only base64 audio is supported in gemini")
 				}
 				}
-				format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
+				base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
 				if err != nil {
 				if err != nil {
 					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
 					return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
 				}
 				}
 				parts = append(parts, GeminiPart{
 				parts = append(parts, GeminiPart{
 					InlineData: &GeminiInlineData{
 					InlineData: &GeminiInlineData{
-						MimeType: format,
+						MimeType: "audio/" + part.GetInputAudio().Format,
 						Data:     base64String,
 						Data:     base64String,
 					},
 					},
 				})
 				})
@@ -268,7 +375,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 		if content.Role == "assistant" {
 		if content.Role == "assistant" {
 			content.Role = "model"
 			content.Role = "model"
 		}
 		}
-		geminiRequest.Contents = append(geminiRequest.Contents, content)
+		if len(content.Parts) > 0 {
+			geminiRequest.Contents = append(geminiRequest.Contents, content)
+		}
 	}
 	}
 
 
 	if len(system_content) > 0 {
 	if len(system_content) > 0 {
@@ -284,100 +393,126 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 	return &geminiRequest, nil
 	return &geminiRequest, nil
 }
 }
 
 
+// Helper function to get a list of supported MIME types for error messages
+func getSupportedMimeTypesList() []string {
+	keys := make([]string, 0, len(geminiSupportedMimeTypes))
+	for k := range geminiSupportedMimeTypes {
+		keys = append(keys, k)
+	}
+	return keys
+}
+
 // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
 // cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
 func cleanFunctionParameters(params interface{}) interface{} {
 func cleanFunctionParameters(params interface{}) interface{} {
 	if params == nil {
 	if params == nil {
 		return nil
 		return nil
 	}
 	}
 
 
-	paramMap, ok := params.(map[string]interface{})
-	if !ok {
-		// Not a map, return as is (e.g., could be an array or primitive)
-		return params
-	}
-
-	// Create a copy to avoid modifying the original
-	cleanedMap := make(map[string]interface{})
-	for k, v := range paramMap {
-		cleanedMap[k] = v
-	}
+	switch v := params.(type) {
+	case map[string]interface{}:
+		// Create a copy to avoid modifying the original
+		cleanedMap := make(map[string]interface{})
+		for k, val := range v {
+			cleanedMap[k] = val
+		}
 
 
-	// Remove unsupported root-level fields
-	delete(cleanedMap, "default")
-	delete(cleanedMap, "exclusiveMaximum")
-	delete(cleanedMap, "exclusiveMinimum")
-	delete(cleanedMap, "$schema")
-	delete(cleanedMap, "additionalProperties")
+		// Remove unsupported root-level fields
+		delete(cleanedMap, "default")
+		delete(cleanedMap, "exclusiveMaximum")
+		delete(cleanedMap, "exclusiveMinimum")
+		delete(cleanedMap, "$schema")
+		delete(cleanedMap, "additionalProperties")
 
 
-	// Clean properties
-	if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
-		cleanedProps := make(map[string]interface{})
-		for propName, propValue := range props {
-			propMap, ok := propValue.(map[string]interface{})
-			if !ok {
-				cleanedProps[propName] = propValue // Keep non-map properties
-				continue
+		// Check and clean 'format' for string types
+		if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
+			if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
+				if formatValue != "enum" && formatValue != "date-time" {
+					delete(cleanedMap, "format")
+				}
 			}
 			}
+		}
 
 
-			// Create a copy of the property map
-			cleanedPropMap := make(map[string]interface{})
-			for k, v := range propMap {
-				cleanedPropMap[k] = v
+		// Clean properties
+		if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
+			cleanedProps := make(map[string]interface{})
+			for propName, propValue := range props {
+				cleanedProps[propName] = cleanFunctionParameters(propValue)
 			}
 			}
+			cleanedMap["properties"] = cleanedProps
+		}
 
 
-			// Remove unsupported fields
-			delete(cleanedPropMap, "default")
-			delete(cleanedPropMap, "exclusiveMaximum")
-			delete(cleanedPropMap, "exclusiveMinimum")
-			delete(cleanedPropMap, "$schema")
-			delete(cleanedPropMap, "additionalProperties")
+		// Recursively clean items in arrays
+		if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
+			cleanedMap["items"] = cleanFunctionParameters(items)
+		}
+		// Also handle items if it's an array of schemas
+		if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
+			cleanedItemsArray := make([]interface{}, len(itemsArray))
+			for i, item := range itemsArray {
+				cleanedItemsArray[i] = cleanFunctionParameters(item)
+			}
+			cleanedMap["items"] = cleanedItemsArray
+		}
 
 
-			// Check and clean 'format' for string types
-			if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
-				if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
-					if formatValue != "enum" && formatValue != "date-time" {
-						delete(cleanedPropMap, "format")
-					}
+		// Recursively clean other schema composition keywords
+		for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+			if nested, ok := cleanedMap[field].([]interface{}); ok {
+				cleanedNested := make([]interface{}, len(nested))
+				for i, item := range nested {
+					cleanedNested[i] = cleanFunctionParameters(item)
 				}
 				}
+				cleanedMap[field] = cleanedNested
 			}
 			}
+		}
 
 
-			// Recursively clean nested properties within this property if it's an object/array
-			// Check the type before recursing
-			if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
-				cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
-			} else {
-				cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
+		// Recursively clean patternProperties
+		if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
+			cleanedPatternProps := make(map[string]interface{})
+			for pattern, schema := range patternProps {
+				cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
 			}
 			}
+			cleanedMap["patternProperties"] = cleanedPatternProps
+		}
 
 
+		// Recursively clean definitions
+		if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
+			cleanedDefinitions := make(map[string]interface{})
+			for defName, defSchema := range definitions {
+				cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
+			}
+			cleanedMap["definitions"] = cleanedDefinitions
 		}
 		}
-		cleanedMap["properties"] = cleanedProps
-	}
 
 
-	// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
-	if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
-		cleanedMap["items"] = cleanFunctionParameters(items)
-	}
-	// Also handle items if it's an array of schemas
-	if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
-		cleanedItemsArray := make([]interface{}, len(itemsArray))
-		for i, item := range itemsArray {
-			cleanedItemsArray[i] = cleanFunctionParameters(item)
+		// Recursively clean $defs (newer JSON Schema draft)
+		if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
+			cleanedDefs := make(map[string]interface{})
+			for defName, defSchema := range defs {
+				cleanedDefs[defName] = cleanFunctionParameters(defSchema)
+			}
+			cleanedMap["$defs"] = cleanedDefs
 		}
 		}
-		cleanedMap["items"] = cleanedItemsArray
-	}
 
 
-	// Recursively clean other schema composition keywords if necessary
-	for _, field := range []string{"allOf", "anyOf", "oneOf"} {
-		if nested, ok := cleanedMap[field].([]interface{}); ok {
-			cleanedNested := make([]interface{}, len(nested))
-			for i, item := range nested {
-				cleanedNested[i] = cleanFunctionParameters(item)
+		// Clean conditional keywords
+		for _, field := range []string{"if", "then", "else", "not"} {
+			if nested, ok := cleanedMap[field]; ok {
+				cleanedMap[field] = cleanFunctionParameters(nested)
 			}
 			}
-			cleanedMap[field] = cleanedNested
 		}
 		}
-	}
 
 
-	return cleanedMap
+		return cleanedMap
+
+	case []interface{}:
+		// Handle arrays of schemas
+		cleanedArray := make([]interface{}, len(v))
+		for i, item := range v {
+			cleanedArray[i] = cleanFunctionParameters(item)
+		}
+		return cleanedArray
+
+	default:
+		// Not a map or array, return as is (e.g., could be a primitive)
+		return params
+	}
 }
 }
 
 
 func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
 func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
@@ -512,21 +647,20 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
 	}
 	}
 }
 }
 
 
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 	fullTextResponse := dto.OpenAITextResponse{
-		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Id:      helper.GetResponseID(c),
 		Object:  "chat.completion",
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Created: common.GetTimestamp(),
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	}
-	content, _ := json.Marshal("")
 	isToolCall := false
 	isToolCall := false
 	for _, candidate := range response.Candidates {
 	for _, candidate := range response.Candidates {
 		choice := dto.OpenAITextResponseChoice{
 		choice := dto.OpenAITextResponseChoice{
 			Index: int(candidate.Index),
 			Index: int(candidate.Index),
 			Message: dto.Message{
 			Message: dto.Message{
 				Role:    "assistant",
 				Role:    "assistant",
-				Content: content,
+				Content: "",
 			},
 			},
 			FinishReason: constant.FinishReasonStop,
 			FinishReason: constant.FinishReasonStop,
 		}
 		}
@@ -539,6 +673,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 					if call := getResponseToolCall(&part); call != nil {
 					if call := getResponseToolCall(&part); call != nil {
 						toolCalls = append(toolCalls, *call)
 						toolCalls = append(toolCalls, *call)
 					}
 					}
+				} else if part.Thought {
+					choice.Message.ReasoningContent = part.Text
 				} else {
 				} else {
 					if part.ExecutableCode != nil {
 					if part.ExecutableCode != nil {
 						texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
 						texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
@@ -556,7 +692,6 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 				choice.Message.SetToolCalls(toolCalls)
 				choice.Message.SetToolCalls(toolCalls)
 				isToolCall = true
 				isToolCall = true
 			}
 			}
-
 			choice.Message.SetStringContent(strings.Join(texts, "\n"))
 			choice.Message.SetStringContent(strings.Join(texts, "\n"))
 
 
 		}
 		}
@@ -596,6 +731,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 		}
 		}
 		var texts []string
 		var texts []string
 		isTools := false
 		isTools := false
+		isThought := false
 		if candidate.FinishReason != nil {
 		if candidate.FinishReason != nil {
 			// p := GeminiConvertFinishReason(*candidate.FinishReason)
 			// p := GeminiConvertFinishReason(*candidate.FinishReason)
 			switch *candidate.FinishReason {
 			switch *candidate.FinishReason {
@@ -620,6 +756,9 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 					call.SetIndex(len(choice.Delta.ToolCalls))
 					call.SetIndex(len(choice.Delta.ToolCalls))
 					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 				}
 				}
+			} else if part.Thought {
+				isThought = true
+				texts = append(texts, part.Text)
 			} else {
 			} else {
 				if part.ExecutableCode != nil {
 				if part.ExecutableCode != nil {
 					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
 					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
@@ -632,7 +771,11 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 				}
 				}
 			}
 			}
 		}
 		}
-		choice.Delta.SetContentString(strings.Join(texts, "\n"))
+		if isThought {
+			choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
+		} else {
+			choice.Delta.SetContentString(strings.Join(texts, "\n"))
+		}
 		if isTools {
 		if isTools {
 			choice.FinishReason = &constant.FinishReasonToolCalls
 			choice.FinishReason = &constant.FinishReasonToolCalls
 		}
 		}
@@ -647,7 +790,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 
 
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	// responseText := ""
 	// responseText := ""
-	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	id := helper.GetResponseID(c)
 	createAt := common.GetTimestamp()
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
 	var usage = &dto.Usage{}
 	var imageCount int
 	var imageCount int
@@ -672,6 +815,13 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
 			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
 			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
 			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
 			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
 			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+				if detail.Modality == "AUDIO" {
+					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+				} else if detail.Modality == "TEXT" {
+					usage.PromptTokensDetails.TextTokens = detail.TokenCount
+				}
+			}
 		}
 		}
 		err = helper.ObjectData(c, response)
 		err = helper.ObjectData(c, response)
 		if err != nil {
 		if err != nil {
@@ -716,8 +866,11 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
 	var geminiResponse GeminiChatResponse
 	var geminiResponse GeminiChatResponse
-	err = json.Unmarshal(responseBody, &geminiResponse)
+	err = common.DecodeJson(responseBody, &geminiResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
@@ -732,7 +885,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 			StatusCode: resp.StatusCode,
 			StatusCode: resp.StatusCode,
 		}, nil
 		}, nil
 	}
 	}
-	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
 	fullTextResponse.Model = info.UpstreamModelName
 	usage := dto.Usage{
 	usage := dto.Usage{
 		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
 		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
@@ -743,6 +896,14 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
 	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
 	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
 
 
+	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens = detail.TokenCount
+		}
+	}
+
 	fullTextResponse.Usage = usage
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 	if err != nil {

+ 42 - 0
relay/channel/mistral/text.go

@@ -1,13 +1,55 @@
 package mistral
 package mistral
 
 
 import (
 import (
+	"one-api/common"
 	"one-api/dto"
 	"one-api/dto"
+	"regexp"
 )
 )
 
 
+var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
+
 func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
 func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
 	messages := make([]dto.Message, 0, len(request.Messages))
 	messages := make([]dto.Message, 0, len(request.Messages))
+	idMap := make(map[string]string)
 	for _, message := range request.Messages {
 	for _, message := range request.Messages {
+		// 1. tool_calls.id
+		toolCalls := message.ParseToolCalls()
+		if toolCalls != nil {
+			for i := range toolCalls {
+				if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
+					if newId, ok := idMap[toolCalls[i].ID]; ok {
+						toolCalls[i].ID = newId
+					} else {
+						newId, err := common.GenerateRandomCharsKey(9)
+						if err == nil {
+							idMap[toolCalls[i].ID] = newId
+							toolCalls[i].ID = newId
+						}
+					}
+				}
+			}
+			message.SetToolCalls(toolCalls)
+		}
+
+		// 2. tool_call_id
+		if message.ToolCallId != "" {
+			if newId, ok := idMap[message.ToolCallId]; ok {
+				message.ToolCallId = newId
+			} else {
+				if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
+					newId, err := common.GenerateRandomCharsKey(9)
+					if err == nil {
+						idMap[message.ToolCallId] = newId
+						message.ToolCallId = newId
+					}
+				}
+			}
+		}
+
 		mediaMessages := message.ParseContent()
 		mediaMessages := message.ParseContent()
+		if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
+			mediaMessages = []dto.MediaContent{}
+		}
 		for j, mediaMessage := range mediaMessages {
 		for j, mediaMessage := range mediaMessages {
 			if mediaMessage.Type == dto.ContentTypeImageURL {
 			if mediaMessage.Type == dto.ContentTypeImageURL {
 				imageUrl := mediaMessage.GetImageMedia()
 				imageUrl := mediaMessage.GetImageMedia()

+ 7 - 0
relay/channel/openai/adaptor.go

@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		requestURL := strings.Split(info.RequestURLPath, "?")[0]
 		requestURL := strings.Split(info.RequestURLPath, "?")[0]
 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
 		task := strings.TrimPrefix(requestURL, "/v1/")
 		task := strings.TrimPrefix(requestURL, "/v1/")
+
+		// 特殊处理 responses API
+		if info.RelayMode == constant.RelayModeResponses {
+			requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
+			return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+		}
+
 		model_ := info.UpstreamModelName
 		model_ := info.UpstreamModelName
 		// 2025年5月10日后创建的渠道不移除.
 		// 2025年5月10日后创建的渠道不移除.
 		if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
 		if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {

+ 21 - 30
relay/channel/openai/relay-openai.go

@@ -15,6 +15,7 @@ import (
 	"one-api/relay/helper"
 	"one-api/relay/helper"
 	"one-api/service"
 	"one-api/service"
 	"os"
 	"os"
+	"path/filepath"
 	"strings"
 	"strings"
 
 
 	"github.com/bytedance/gopkg/util/gopool"
 	"github.com/bytedance/gopkg/util/gopool"
@@ -180,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	}
 	}
 
 
 	if !containStreamUsage {
 	if !containStreamUsage {
-		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
 		usage.CompletionTokens += toolCount * 7
 	} else {
 	} else {
 		if info.ChannelType == common.ChannelTypeDeepSeek {
 		if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -215,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 			StatusCode: resp.StatusCode,
 			StatusCode: resp.StatusCode,
 		}, nil
 		}, nil
 	}
 	}
-	
+
 	forceFormat := false
 	forceFormat := false
 	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
 	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
 		forceFormat = forceFmt
 		forceFormat = forceFmt
@@ -224,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+			ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
 			completionTokens += ctkm
 			completionTokens += ctkm
 		}
 		}
 		simpleResponse.Usage = dto.Usage{
 		simpleResponse.Usage = dto.Usage{
@@ -273,36 +274,25 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 }
 }
 
 
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	// Reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
+	// the status code has been judged before, if there is a body reading failure,
+	// it should be regarded as a non-recoverable error, so it should not return err for external retry.
+	// Analogous to nginx's load balancing, it will only retry if it can't be requested or
+	// if the upstream returns a specific status code, once the upstream has already written the header,
+	// the subsequent failure of the response body should be regarded as a non-recoverable error,
+	// and can be terminated directly.
+	defer resp.Body.Close()
+	usage := &dto.Usage{}
+	usage.PromptTokens = info.PromptTokens
+	usage.TotalTokens = info.PromptTokens
 	for k, v := range resp.Header {
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 		c.Writer.Header().Set(k, v[0])
 	}
 	}
 	c.Writer.WriteHeader(resp.StatusCode)
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
+	c.Writer.WriteHeaderNow()
+	_, err := io.Copy(c.Writer, resp.Body)
 	if err != nil {
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		common.LogError(c, err.Error())
 	}
 	}
-
-	usage := &dto.Usage{}
-	usage.PromptTokens = info.PromptTokens
-	usage.TotalTokens = info.PromptTokens
 	return nil, usage
 	return nil, usage
 }
 }
 
 
@@ -356,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
 	if err = c.ShouldBind(&reqBody); err != nil {
 	if err = c.ShouldBind(&reqBody); err != nil {
 		return 0, errors.WithStack(err)
 		return 0, errors.WithStack(err)
 	}
 	}
-
+	ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
 	reqFp, err := reqBody.File.Open()
 	reqFp, err := reqBody.File.Open()
 	if err != nil {
 	if err != nil {
 		return 0, errors.WithStack(err)
 		return 0, errors.WithStack(err)
 	}
 	}
+	defer reqFp.Close()
 
 
-	tmpFp, err := os.CreateTemp("", "audio-*")
+	tmpFp, err := os.CreateTemp("", "audio-*"+ext)
 	if err != nil {
 	if err != nil {
 		return 0, errors.WithStack(err)
 		return 0, errors.WithStack(err)
 	}
 	}
@@ -376,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
 		return 0, errors.WithStack(err)
 		return 0, errors.WithStack(err)
 	}
 	}
 
 
-	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+	duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
 	if err != nil {
 	if err != nil {
 		return 0, errors.WithStack(err)
 		return 0, errors.WithStack(err)
 	}
 	}

+ 1 - 1
relay/channel/openai/relay_responses.go

@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 		tempStr := responseTextBuilder.String()
 		tempStr := responseTextBuilder.String()
 		if len(tempStr) > 0 {
 		if len(tempStr) > 0 {
 			// 非正常结束,使用输出文本的 token 数量
 			// 非正常结束,使用输出文本的 token 数量
-			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
 			usage.CompletionTokens = completionTokens
 			usage.CompletionTokens = completionTokens
 		}
 		}
 	}
 	}

+ 9 - 0
relay/channel/openrouter/dto.go

@@ -0,0 +1,9 @@
+package openrouter
+
+type RequestReasoning struct {
+	// One of the following (not both):
+	Effort    string `json:"effort,omitempty"`     // Can be "high", "medium", or "low" (OpenAI-style)
+	MaxTokens int    `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
+	// Optional: Default is false. All models support this.
+	Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
+}

+ 1 - 1
relay/channel/palm/adaptor.go

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 	if info.IsStream {
 		var responseText string
 		var responseText string
 		err, responseText = palmStreamHandler(c, resp)
 		err, responseText = palmStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 	} else {
 		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 		err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	}

+ 3 - 5
relay/channel/palm/relay-palm.go

@@ -2,7 +2,6 @@ package palm
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
-	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -45,12 +44,11 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	}
 	for i, candidate := range response.Candidates {
 	for i, candidate := range response.Candidates {
-		content, _ := json.Marshal(candidate.Content)
 		choice := dto.OpenAITextResponseChoice{
 		choice := dto.OpenAITextResponseChoice{
 			Index: i,
 			Index: i,
 			Message: dto.Message{
 			Message: dto.Message{
 				Role:    "assistant",
 				Role:    "assistant",
-				Content: content,
+				Content: candidate.Content,
 			},
 			},
 			FinishReason: "stop",
 			FinishReason: "stop",
 		}
 		}
@@ -74,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
 
 
 func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	responseText := ""
-	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	responseId := helper.GetResponseID(c)
 	createdTime := common.GetTimestamp()
 	createdTime := common.GetTimestamp()
 	dataChan := make(chan string)
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
 	stopChan := make(chan bool)
@@ -157,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 		}, nil
 	}
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+	completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
 		CompletionTokens: completionTokens,

+ 312 - 0
relay/channel/task/kling/adaptor.go

@@ -0,0 +1,312 @@
+package kling
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/golang-jwt/jwt"
+	"github.com/pkg/errors"
+
+	"one-api/common"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type SubmitReq struct {
+	Prompt   string                 `json:"prompt"`
+	Model    string                 `json:"model,omitempty"`
+	Mode     string                 `json:"mode,omitempty"`
+	Image    string                 `json:"image,omitempty"`
+	Size     string                 `json:"size,omitempty"`
+	Duration int                    `json:"duration,omitempty"`
+	Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type requestPayload struct {
+	Prompt      string  `json:"prompt,omitempty"`
+	Image       string  `json:"image,omitempty"`
+	Mode        string  `json:"mode,omitempty"`
+	Duration    string  `json:"duration,omitempty"`
+	AspectRatio string  `json:"aspect_ratio,omitempty"`
+	Model       string  `json:"model,omitempty"`
+	ModelName   string  `json:"model_name,omitempty"`
+	CfgScale    float64 `json:"cfg_scale,omitempty"`
+}
+
+type responsePayload struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Data    struct {
+		TaskID string `json:"task_id"`
+	} `json:"data"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+	ChannelType int
+	accessKey   string
+	secretKey   string
+	baseURL     string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.BaseUrl
+
+	// apiKey format: "access_key,secret_key"
+	keyParts := strings.Split(info.ApiKey, ",")
+	if len(keyParts) == 2 {
+		a.accessKey = strings.TrimSpace(keyParts[0])
+		a.secretKey = strings.TrimSpace(keyParts[1])
+	}
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+	// Accept only POST /v1/video/generations as "generate" action.
+	action := "generate"
+	info.Action = action
+
+	var req SubmitReq
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	if strings.TrimSpace(req.Prompt) == "" {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+		return
+	}
+
+	// Store into context for later usage
+	c.Set("kling_request", req)
+	return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	token, err := a.createJWTToken()
+	if err != nil {
+		return fmt.Errorf("failed to create JWT token: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("User-Agent", "kling-sdk/1.0")
+	return nil
+}
+
+// BuildRequestBody converts request into Kling specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	v, exists := c.Get("kling_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req := v.(SubmitReq)
+
+	body := a.convertToRequestPayload(&req)
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	// Attempt Kling response parse first.
+	var kResp responsePayload
+	if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
+		c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
+		return kResp.Data.TaskID, responseBody, nil
+	}
+
+	// Fallback generic task response.
+	var generic dto.TaskResponse[string]
+	if err := json.Unmarshal(responseBody, &generic); err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if !generic.IsSuccess() {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
+	return generic.Data, responseBody, nil
+}
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+	url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	token, err := a.createJWTTokenWithKey(key)
+	if err != nil {
+		token = key
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+	defer cancel()
+
+	req = req.WithContext(ctx)
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("User-Agent", "kling-sdk/1.0")
+
+	return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return "kling"
+}
+
+// ============================
+// helpers
+// ============================
+
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
+	r := &requestPayload{
+		Prompt:      req.Prompt,
+		Image:       req.Image,
+		Mode:        defaultString(req.Mode, "std"),
+		Duration:    fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+		AspectRatio: a.getAspectRatio(req.Size),
+		Model:       req.Model,
+		ModelName:   req.Model,
+		CfgScale:    0.5,
+	}
+	if r.Model == "" {
+		r.Model = "kling-v1"
+		r.ModelName = "kling-v1"
+	}
+	return r
+}
+
+func (a *TaskAdaptor) getAspectRatio(size string) string {
+	switch size {
+	case "1024x1024", "512x512":
+		return "1:1"
+	case "1280x720", "1920x1080":
+		return "16:9"
+	case "720x1280", "1080x1920":
+		return "9:16"
+	default:
+		return "1:1"
+	}
+}
+
+func defaultString(s, def string) string {
+	if strings.TrimSpace(s) == "" {
+		return def
+	}
+	return s
+}
+
+func defaultInt(v int, def int) int {
+	if v == 0 {
+		return def
+	}
+	return v
+}
+
+// ============================
+// JWT helpers
+// ============================
+
+func (a *TaskAdaptor) createJWTToken() (string, error) {
+	return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+	parts := strings.Split(apiKey, ",")
+	if len(parts) != 2 {
+		return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
+	}
+	return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
+	if accessKey == "" || secretKey == "" {
+		return "", fmt.Errorf("access key and secret key are required")
+	}
+	now := time.Now().Unix()
+	claims := jwt.MapClaims{
+		"iss": accessKey,
+		"exp": now + 1800, // 30 minutes
+		"nbf": now - 5,
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	token.Header["typ"] = "JWT"
+	return token.SignedString([]byte(secretKey))
+}
+
+// ParseResultUrl 提取视频任务结果的 url
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+	data, ok := resp["data"].(map[string]any)
+	if !ok {
+		return "", fmt.Errorf("data field not found or invalid")
+	}
+	taskResult, ok := data["task_result"].(map[string]any)
+	if !ok {
+		return "", fmt.Errorf("task_result field not found or invalid")
+	}
+	videos, ok := taskResult["videos"].([]interface{})
+	if !ok || len(videos) == 0 {
+		return "", fmt.Errorf("videos field not found or empty")
+	}
+	video, ok := videos[0].(map[string]interface{})
+	if !ok {
+		return "", fmt.Errorf("video item invalid")
+	}
+	url, ok := video["url"].(string)
+	if !ok || url == "" {
+		return "", fmt.Errorf("url field not found or invalid")
+	}
+	return url, nil
+}

+ 4 - 0
relay/channel/task/suno/adaptor.go

@@ -22,6 +22,10 @@ type TaskAdaptor struct {
 	ChannelType int
 	ChannelType int
 }
 }
 
 
+func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
+	return "", nil // todo implement this method if needed
+}
+
 func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
 func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
 	a.ChannelType = info.ChannelType
 	a.ChannelType = info.ChannelType
 }
 }

+ 1 - 1
relay/channel/tencent/adaptor.go

@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 	if info.IsStream {
 		var responseText string
 		var responseText string
 		err, responseText = tencentStreamHandler(c, resp)
 		err, responseText = tencentStreamHandler(c, resp)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 	} else {
 		err, usage = tencentHandler(c, resp)
 		err, usage = tencentHandler(c, resp)
 	}
 	}

+ 1 - 2
relay/channel/tencent/relay-tencent.go

@@ -56,12 +56,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
 		},
 		},
 	}
 	}
 	if len(response.Choices) > 0 {
 	if len(response.Choices) > 0 {
-		content, _ := json.Marshal(response.Choices[0].Messages.Content)
 		choice := dto.OpenAITextResponseChoice{
 		choice := dto.OpenAITextResponseChoice{
 			Index: 0,
 			Index: 0,
 			Message: dto.Message{
 			Message: dto.Message{
 				Role:    "assistant",
 				Role:    "assistant",
-				Content: content,
+				Content: response.Choices[0].Messages.Content,
 			},
 			},
 			FinishReason: response.Choices[0].FinishReason,
 			FinishReason: response.Choices[0].FinishReason,
 		}
 		}

+ 53 - 21
relay/channel/vertex/adaptor.go

@@ -12,6 +12,7 @@ import (
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/setting/model_setting"
 	"one-api/setting/model_setting"
 	"strings"
 	"strings"
 
 
@@ -31,6 +32,8 @@ var claudeModelMap = map[string]string{
 	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
 	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
 	"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
 	"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
 	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
 	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
+	"claude-sonnet-4-20250514":   "claude-sonnet-4@20250514",
+	"claude-opus-4-20250514":     "claude-opus-4@20250514",
 }
 }
 
 
 const anthropicVersion = "vertex-2023-10-16"
 const anthropicVersion = "vertex-2023-10-16"
@@ -80,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	suffix := ""
 	suffix := ""
 	if a.RequestMode == RequestModeGemini {
 	if a.RequestMode == RequestModeGemini {
 		if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
 		if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
-			// suffix -thinking and -nothinking
-			if strings.HasSuffix(info.OriginModelName, "-thinking") {
+			// 新增逻辑:处理 -thinking-<budget> 格式
+			if strings.Contains(info.UpstreamModelName, "-thinking-") {
+				parts := strings.Split(info.UpstreamModelName, "-thinking-")
+				info.UpstreamModelName = parts[0]
+			} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
 				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
 				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
-			} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+			} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
 				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 				info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
 			}
 			}
 		}
 		}
@@ -93,14 +99,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		} else {
 		} else {
 			suffix = "generateContent"
 			suffix = "generateContent"
 		}
 		}
-		return fmt.Sprintf(
-			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-			region,
-			adc.ProjectID,
-			region,
-			info.UpstreamModelName,
-			suffix,
-		), nil
+		if region == "global" {
+			return fmt.Sprintf(
+				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+				adc.ProjectID,
+				info.UpstreamModelName,
+				suffix,
+			), nil
+		} else {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+				region,
+				adc.ProjectID,
+				region,
+				info.UpstreamModelName,
+				suffix,
+			), nil
+		}
 	} else if a.RequestMode == RequestModeClaude {
 	} else if a.RequestMode == RequestModeClaude {
 		if info.IsStream {
 		if info.IsStream {
 			suffix = "streamRawPredict?alt=sse"
 			suffix = "streamRawPredict?alt=sse"
@@ -111,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
 		if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
 			model = v
 			model = v
 		}
 		}
-		return fmt.Sprintf(
-			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
-			region,
-			adc.ProjectID,
-			region,
-			model,
-			suffix,
-		), nil
+		if region == "global" {
+			return fmt.Sprintf(
+				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+				adc.ProjectID,
+				model,
+				suffix,
+			), nil
+		} else {
+			return fmt.Sprintf(
+				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+				region,
+				adc.ProjectID,
+				region,
+				model,
+				suffix,
+			), nil
+		}
 	} else if a.RequestMode == RequestModeLlama {
 	} else if a.RequestMode == RequestModeLlama {
 		return fmt.Sprintf(
 		return fmt.Sprintf(
 			"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
 			"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
@@ -190,7 +214,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 		case RequestModeClaude:
 			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 		case RequestModeGemini:
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 		case RequestModeLlama:
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		}
 		}
@@ -199,7 +227,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 		case RequestModeClaude:
 			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
 			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
 		case RequestModeGemini:
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 		case RequestModeLlama:
 			err, usage = openai.OpenaiHandler(c, resp, info)
 			err, usage = openai.OpenaiHandler(c, resp, info)
 		}
 		}

+ 148 - 2
relay/channel/volcengine/adaptor.go

@@ -1,15 +1,19 @@
 package volcengine
 package volcengine
 
 
 import (
 import (
+	"bytes"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"net/http"
+	"net/textproto"
 	"one-api/dto"
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"one-api/relay/constant"
+	"path/filepath"
 	"strings"
 	"strings"
 
 
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 }
 
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	switch info.RelayMode {
+	case constant.RelayModeImagesEdits:
+
+		var requestBody bytes.Buffer
+		writer := multipart.NewWriter(&requestBody)
+
+		writer.WriteField("model", request.Model)
+		// 获取所有表单字段
+		formData := c.Request.PostForm
+		// 遍历表单字段并打印输出
+		for key, values := range formData {
+			if key == "model" {
+				continue
+			}
+			for _, value := range values {
+				writer.WriteField(key, value)
+			}
+		}
+
+		// Parse the multipart form to handle both single image and multiple images
+		if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+			return nil, errors.New("failed to parse multipart form")
+		}
+
+		if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+			// Check if "image" field exists in any form, including array notation
+			var imageFiles []*multipart.FileHeader
+			var exists bool
+
+			// First check for standard "image" field
+			if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+				// If not found, check for "image[]" field
+				if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+					// If still not found, iterate through all fields to find any that start with "image["
+					foundArrayImages := false
+					for fieldName, files := range c.Request.MultipartForm.File {
+						if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+							foundArrayImages = true
+							for _, file := range files {
+								imageFiles = append(imageFiles, file)
+							}
+						}
+					}
+
+					// If no image fields found at all
+					if !foundArrayImages && (len(imageFiles) == 0) {
+						return nil, errors.New("image is required")
+					}
+				}
+			}
+
+			// Process all image files
+			for i, fileHeader := range imageFiles {
+				file, err := fileHeader.Open()
+				if err != nil {
+					return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+				}
+				defer file.Close()
+
+				// If multiple images, use image[] as the field name
+				fieldName := "image"
+				if len(imageFiles) > 1 {
+					fieldName = "image[]"
+				}
+
+				// Determine MIME type based on file extension
+				mimeType := detectImageMimeType(fileHeader.Filename)
+
+				// Create a form file with the appropriate content type
+				h := make(textproto.MIMEHeader)
+				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+				h.Set("Content-Type", mimeType)
+
+				part, err := writer.CreatePart(h)
+				if err != nil {
+					return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+				}
+
+				if _, err := io.Copy(part, file); err != nil {
+					return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+				}
+			}
+
+			// Handle mask file if present
+			if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+				maskFile, err := maskFiles[0].Open()
+				if err != nil {
+					return nil, errors.New("failed to open mask file")
+				}
+				defer maskFile.Close()
+
+				// Determine MIME type for mask file
+				mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+				// Create a form file with the appropriate content type
+				h := make(textproto.MIMEHeader)
+				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+				h.Set("Content-Type", mimeType)
+
+				maskPart, err := writer.CreatePart(h)
+				if err != nil {
+					return nil, errors.New("create form file failed for mask")
+				}
+
+				if _, err := io.Copy(maskPart, maskFile); err != nil {
+					return nil, errors.New("copy mask file failed")
+				}
+			}
+		} else {
+			return nil, errors.New("no multipart form data found")
+		}
+
+		// 关闭 multipart 编写器以设置分界线
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return bytes.NewReader(requestBody.Bytes()), nil
+
+	default:
+		return request, nil
+	}
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+	ext := strings.ToLower(filepath.Ext(filename))
+	switch ext {
+	case ".jpg", ".jpeg":
+		return "image/jpeg"
+	case ".png":
+		return "image/png"
+	case ".webp":
+		return "image/webp"
+	default:
+		// Try to detect from extension if possible
+		if strings.HasPrefix(ext, ".jp") {
+			return "image/jpeg"
+		}
+		// Default to png as a fallback
+		return "image/png"
+	}
 }
 }
 
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
 		return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
 	case constant.RelayModeEmbeddings:
 	case constant.RelayModeEmbeddings:
 		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
 		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+	case constant.RelayModeImagesGenerations:
+		return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
 	default:
 	default:
 	}
 	}
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		}
 		}
 	case constant.RelayModeEmbeddings:
 	case constant.RelayModeEmbeddings:
 		err, usage = openai.OpenaiHandler(c, resp, info)
 		err, usage = openai.OpenaiHandler(c, resp, info)
+	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+		err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
 	}
 	}
 	return
 	return
 }
 }

Некоторые файлы не были показаны из-за большого количества измененных файлов