Quellcode durchsuchen

Merge branch 'upstream-main' into feature/improve-param-override

# Conflicts:
#	relay/channel/api_request_test.go
#	relay/common/override_test.go
#	web/src/components/table/channels/modals/EditChannelModal.jsx
Seefs vor 1 Woche
Ursprung
Commit
db0b452ea2
100 geänderte Dateien mit 6436 neuen und 1765 gelöschten Zeilen
  1. 2 2
      README.fr.md
  2. 2 2
      README.ja.md
  3. 2 2
      README.md
  4. 2 2
      README.zh_CN.md
  5. 2 2
      README.zh_TW.md
  6. 16 2
      common/gin.go
  7. 2 0
      common/init.go
  8. 1 0
      constant/env.go
  9. 3 1
      controller/codex_oauth.go
  10. 2 3
      controller/codex_usage.go
  11. 208 21
      controller/custom_oauth.go
  12. 20 11
      controller/midjourney.go
  13. 4 0
      controller/misc.go
  14. 18 6
      controller/oauth.go
  15. 113 38
      controller/relay.go
  16. 33 215
      controller/task.go
  17. 0 313
      controller/task_video.go
  18. 38 0
      controller/user.go
  19. 30 81
      controller/video_proxy.go
  20. 4 4
      controller/video_proxy_gemini.go
  21. 10 8
      dto/channel_settings.go
  22. 9 5
      dto/claude.go
  23. 48 41
      dto/gemini.go
  24. 41 22
      dto/openai_request.go
  25. 0 32
      dto/suno.go
  26. 47 0
      dto/task.go
  27. 1 2
      logger/logger.go
  28. 10 0
      main.go
  29. 18 0
      middleware/auth.go
  30. 16 2
      middleware/logger.go
  31. 100 16
      model/custom_oauth_provider.go
  32. 61 2
      model/log.go
  33. 13 0
      model/midjourney.go
  34. 124 60
      model/task.go
  35. 217 0
      model/task_cas_test.go
  36. 3 3
      model/token.go
  37. 42 1
      model/user.go
  38. 402 2
      oauth/generic.go
  39. 9 0
      oauth/types.go
  40. 28 2
      relay/channel/adapter.go
  41. 3 2
      relay/channel/api_request.go
  42. 27 0
      relay/channel/api_request_test.go
  43. 1 16
      relay/channel/gemini/relay-gemini-native.go
  44. 44 49
      relay/channel/gemini/relay-gemini.go
  45. 333 0
      relay/channel/gemini/relay_gemini_usage_test.go
  46. 11 3
      relay/channel/minimax/adaptor.go
  47. 12 7
      relay/channel/minimax/relay-minimax.go
  48. 44 24
      relay/channel/task/ali/adaptor.go
  49. 15 16
      relay/channel/task/doubao/adaptor.go
  50. 17 33
      relay/channel/task/gemini/adaptor.go
  51. 13 12
      relay/channel/task/hailuo/adaptor.go
  52. 14 20
      relay/channel/task/jimeng/adaptor.go
  53. 17 36
      relay/channel/task/kling/adaptor.go
  54. 132 15
      relay/channel/task/sora/adaptor.go
  55. 24 35
      relay/channel/task/suno/adaptor.go
  56. 95 0
      relay/channel/task/taskcommon/helpers.go
  57. 47 46
      relay/channel/task/vertex/adaptor.go
  58. 15 35
      relay/channel/task/vidu/adaptor.go
  59. 2 2
      relay/chat_completions_via_responses.go
  60. 1 1
      relay/claude_handler.go
  61. 73 0
      relay/common/override_test.go
  62. 62 5
      relay/common/relay_info.go
  63. 40 0
      relay/common/relay_info_test.go
  64. 2 8
      relay/common/relay_utils.go
  65. 3 3
      relay/compatible_handler.go
  66. 13 2
      relay/helper/price.go
  67. 27 16
      relay/helper/stream_scanner.go
  68. 521 0
      relay/helper/stream_scanner_test.go
  69. 2 2
      relay/mjproxy_handler.go
  70. 324 280
      relay/relay_task.go
  71. 1 1
      relay/responses_handler.go
  72. 6 1
      router/api-router.go
  73. 1 0
      router/dashboard.go
  74. 2 0
      router/main.go
  75. 11 2
      router/relay-router.go
  76. 15 5
      router/video-router.go
  77. 1 0
      router/web-router.go
  78. 5 0
      service/billing_session.go
  79. 54 6
      service/channel_affinity.go
  80. 105 0
      service/channel_affinity_usage_cache_test.go
  81. 1 1
      service/codex_credential_refresh.go
  82. 33 4
      service/codex_oauth.go
  83. 13 0
      service/error.go
  84. 1 1
      service/log_info_generate.go
  85. 2 2
      service/midjourney.go
  86. 3 0
      service/quota.go
  87. 285 0
      service/task_billing.go
  88. 712 0
      service/task_billing_test.go
  89. 539 0
      service/task_polling.go
  90. 5 4
      service/violation_fee.go
  91. 13 0
      setting/operation_setting/status_code_ranges.go
  92. 8 0
      setting/operation_setting/status_code_ranges_test.go
  93. 2 7
      types/price_data.go
  94. 15 15
      web/src/components/auth/LoginForm.jsx
  95. 44 14
      web/src/components/auth/RegisterForm.jsx
  96. 214 0
      web/src/components/common/modals/RiskAcknowledgementModal.jsx
  97. 546 124
      web/src/components/settings/CustomOAuthSetting.jsx
  98. 9 6
      web/src/components/settings/personal/cards/AccountManagement.jsx
  99. 7 0
      web/src/components/settings/personal/cards/NotificationSettings.jsx
  100. 148 1
      web/src/components/table/channels/modals/EditChannelModal.jsx

+ 2 - 2
README.fr.md

@@ -30,8 +30,8 @@
 </p>
 
 <p align="center">
-  <a href="https://trendshift.io/repositories/8227" target="_blank">
-    <img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
+  <a href="https://trendshift.io/repositories/20180" target="_blank">
+    <img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
   </a>
   <br>
   <a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">

+ 2 - 2
README.ja.md

@@ -30,8 +30,8 @@
 </p>
 
 <p align="center">
-  <a href="https://trendshift.io/repositories/8227" target="_blank">
-    <img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
+  <a href="https://trendshift.io/repositories/20180" target="_blank">
+    <img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
   </a>
   <br>
   <a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">

+ 2 - 2
README.md

@@ -30,8 +30,8 @@
 </p>
 
 <p align="center">
-  <a href="https://trendshift.io/repositories/8227" target="_blank">
-    <img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
+  <a href="https://trendshift.io/repositories/20180" target="_blank">
+    <img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
   </a>
   <br>
   <a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">

+ 2 - 2
README.zh_CN.md

@@ -30,8 +30,8 @@
 </p>
 
 <p align="center">
-  <a href="https://trendshift.io/repositories/8227" target="_blank">
-    <img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
+  <a href="https://trendshift.io/repositories/20180" target="_blank">
+    <img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
   </a>
   <br>
   <a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">

+ 2 - 2
README.zh_TW.md

@@ -30,8 +30,8 @@
 </p>
 
 <p align="center">
-  <a href="https://trendshift.io/repositories/8227" target="_blank">
-    <img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
+  <a href="https://trendshift.io/repositories/20180" target="_blank">
+    <img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
   </a>
   <br>
   <a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">

+ 16 - 2
common/gin.go

@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
 		return nil, err
 	}
 
-	contentType := c.Request.Header.Get("Content-Type")
+	// Use the original Content-Type saved on first call to avoid boundary
+	// mismatch when callers overwrite c.Request.Header after multipart rebuild.
+	var contentType string
+	if saved, ok := c.Get("_original_multipart_ct"); ok {
+		contentType = saved.(string)
+	} else {
+		contentType = c.Request.Header.Get("Content-Type")
+		c.Set("_original_multipart_ct", contentType)
+	}
 	boundary, err := parseBoundary(contentType)
 	if err != nil {
 		return nil, err
@@ -295,7 +303,13 @@ func parseFormData(data []byte, v any) error {
 }
 
 func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
-	contentType := c.Request.Header.Get("Content-Type")
+	var contentType string
+	if saved, ok := c.Get("_original_multipart_ct"); ok {
+		contentType = saved.(string)
+	} else {
+		contentType = c.Request.Header.Get("Content-Type")
+		c.Set("_original_multipart_ct", contentType)
+	}
 	boundary, err := parseBoundary(contentType)
 	if err != nil {
 		if errors.Is(err, errBoundaryNotFound) {

+ 2 - 0
common/init.go

@@ -145,6 +145,8 @@ func initConstantEnv() {
 	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
 	// 任务轮询时查询的最大数量
 	constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
+	// 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。
+	constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
 
 	soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
 	if soraPatchStr != "" {

+ 1 - 0
constant/env.go

@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
 var GenerateDefaultToken bool
 var ErrorLogEnabled bool
 var TaskQueryLimit int
+var TaskTimeoutMinutes int
 
 // temporary variable for sora patch, will be removed in future
 var TaskPricePatches []string

+ 3 - 1
controller/codex_oauth.go

@@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 		return
 	}
 
+	channelProxy := ""
 	if channelID > 0 {
 		ch, err := model.GetChannelById(channelID, false)
 		if err != nil {
@@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 			c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
 			return
 		}
+		channelProxy = ch.GetSetting().Proxy
 	}
 
 	session := sessions.Default(c)
@@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 	ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
 	defer cancel()
 
-	tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
+	tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
 	if err != nil {
 		common.SysError("failed to exchange codex authorization code: " + err.Error())
 		c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})

+ 2 - 3
controller/codex_usage.go

@@ -2,7 +2,6 @@ package controller
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"net/http"
 	"strconv"
@@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) {
 		refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
 		defer refreshCancel()
 
-		res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+		res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
 		if refreshErr == nil {
 			oauthKey.AccessToken = res.AccessToken
 			oauthKey.RefreshToken = res.RefreshToken
@@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) {
 	}
 
 	var payload any
-	if json.Unmarshal(body, &payload) != nil {
+	if common.Unmarshal(body, &payload) != nil {
 		payload = string(body)
 	}
 

+ 208 - 21
controller/custom_oauth.go

@@ -1,8 +1,13 @@
 package controller
 
 import (
+	"context"
+	"io"
 	"net/http"
+	"net/url"
 	"strconv"
+	"strings"
+	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
@@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct {
 	Id                    int    `json:"id"`
 	Name                  string `json:"name"`
 	Slug                  string `json:"slug"`
+	Icon                  string `json:"icon"`
 	Enabled               bool   `json:"enabled"`
 	ClientId              string `json:"client_id"`
 	AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -28,6 +34,16 @@ type CustomOAuthProviderResponse struct {
 	EmailField            string `json:"email_field"`
 	WellKnown             string `json:"well_known"`
 	AuthStyle             int    `json:"auth_style"`
+	AccessPolicy          string `json:"access_policy"`
+	AccessDeniedMessage   string `json:"access_denied_message"`
+}
+
+type UserOAuthBindingResponse struct {
+	ProviderId     int    `json:"provider_id"`
+	ProviderName   string `json:"provider_name"`
+	ProviderSlug   string `json:"provider_slug"`
+	ProviderIcon   string `json:"provider_icon"`
+	ProviderUserId string `json:"provider_user_id"`
 }
 
 func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
@@ -35,6 +51,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
 		Id:                    p.Id,
 		Name:                  p.Name,
 		Slug:                  p.Slug,
+		Icon:                  p.Icon,
 		Enabled:               p.Enabled,
 		ClientId:              p.ClientId,
 		AuthorizationEndpoint: p.AuthorizationEndpoint,
@@ -47,6 +64,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
 		EmailField:            p.EmailField,
 		WellKnown:             p.WellKnown,
 		AuthStyle:             p.AuthStyle,
+		AccessPolicy:          p.AccessPolicy,
+		AccessDeniedMessage:   p.AccessDeniedMessage,
 	}
 }
 
@@ -96,6 +115,7 @@ func GetCustomOAuthProvider(c *gin.Context) {
 type CreateCustomOAuthProviderRequest struct {
 	Name                  string `json:"name" binding:"required"`
 	Slug                  string `json:"slug" binding:"required"`
+	Icon                  string `json:"icon"`
 	Enabled               bool   `json:"enabled"`
 	ClientId              string `json:"client_id" binding:"required"`
 	ClientSecret          string `json:"client_secret" binding:"required"`
@@ -109,6 +129,85 @@ type CreateCustomOAuthProviderRequest struct {
 	EmailField            string `json:"email_field"`
 	WellKnown             string `json:"well_known"`
 	AuthStyle             int    `json:"auth_style"`
+	AccessPolicy          string `json:"access_policy"`
+	AccessDeniedMessage   string `json:"access_denied_message"`
+}
+
+type FetchCustomOAuthDiscoveryRequest struct {
+	WellKnownURL string `json:"well_known_url"`
+	IssuerURL    string `json:"issuer_url"`
+}
+
+// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
+func FetchCustomOAuthDiscovery(c *gin.Context) {
+	var req FetchCustomOAuthDiscoveryRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
+		return
+	}
+
+	wellKnownURL := strings.TrimSpace(req.WellKnownURL)
+	issuerURL := strings.TrimSpace(req.IssuerURL)
+
+	if wellKnownURL == "" && issuerURL == "" {
+		common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
+		return
+	}
+
+	targetURL := wellKnownURL
+	if targetURL == "" {
+		targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
+	}
+	targetURL = strings.TrimSpace(targetURL)
+
+	parsedURL, err := url.Parse(targetURL)
+	if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
+		common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
+		return
+	}
+
+	ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
+	defer cancel()
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
+	if err != nil {
+		common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
+		return
+	}
+	httpReq.Header.Set("Accept", "application/json")
+
+	client := &http.Client{Timeout: 20 * time.Second}
+	resp, err := client.Do(httpReq)
+	if err != nil {
+		common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
+		return
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
+		message := strings.TrimSpace(string(body))
+		if message == "" {
+			message = resp.Status
+		}
+		common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
+		return
+	}
+
+	var discovery map[string]any
+	if err = common.DecodeJson(resp.Body, &discovery); err != nil {
+		common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data": gin.H{
+			"well_known_url": targetURL,
+			"discovery":      discovery,
+		},
+	})
 }
 
 // CreateCustomOAuthProvider creates a new custom OAuth provider
@@ -134,6 +233,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
 	provider := &model.CustomOAuthProvider{
 		Name:                  req.Name,
 		Slug:                  req.Slug,
+		Icon:                  req.Icon,
 		Enabled:               req.Enabled,
 		ClientId:              req.ClientId,
 		ClientSecret:          req.ClientSecret,
@@ -147,6 +247,8 @@ func CreateCustomOAuthProvider(c *gin.Context) {
 		EmailField:            req.EmailField,
 		WellKnown:             req.WellKnown,
 		AuthStyle:             req.AuthStyle,
+		AccessPolicy:          req.AccessPolicy,
+		AccessDeniedMessage:   req.AccessDeniedMessage,
 	}
 
 	if err := model.CreateCustomOAuthProvider(provider); err != nil {
@@ -168,9 +270,10 @@ func CreateCustomOAuthProvider(c *gin.Context) {
 type UpdateCustomOAuthProviderRequest struct {
 	Name                  string  `json:"name"`
 	Slug                  string  `json:"slug"`
-	Enabled               *bool   `json:"enabled"`               // Optional: if nil, keep existing
+	Icon                  *string `json:"icon"`    // Optional: if nil, keep existing
+	Enabled               *bool   `json:"enabled"` // Optional: if nil, keep existing
 	ClientId              string  `json:"client_id"`
-	ClientSecret          string  `json:"client_secret"`         // Optional: if empty, keep existing
+	ClientSecret          string  `json:"client_secret"` // Optional: if empty, keep existing
 	AuthorizationEndpoint string  `json:"authorization_endpoint"`
 	TokenEndpoint         string  `json:"token_endpoint"`
 	UserInfoEndpoint      string  `json:"user_info_endpoint"`
@@ -181,6 +284,8 @@ type UpdateCustomOAuthProviderRequest struct {
 	EmailField            string  `json:"email_field"`
 	WellKnown             *string `json:"well_known"`            // Optional: if nil, keep existing
 	AuthStyle             *int    `json:"auth_style"`            // Optional: if nil, keep existing
+	AccessPolicy          *string `json:"access_policy"`         // Optional: if nil, keep existing
+	AccessDeniedMessage   *string `json:"access_denied_message"` // Optional: if nil, keep existing
 }
 
 // UpdateCustomOAuthProvider updates an existing custom OAuth provider
@@ -227,6 +332,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
 	if req.Slug != "" {
 		provider.Slug = req.Slug
 	}
+	if req.Icon != nil {
+		provider.Icon = *req.Icon
+	}
 	if req.Enabled != nil {
 		provider.Enabled = *req.Enabled
 	}
@@ -266,6 +374,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
 	if req.AuthStyle != nil {
 		provider.AuthStyle = *req.AuthStyle
 	}
+	if req.AccessPolicy != nil {
+		provider.AccessPolicy = *req.AccessPolicy
+	}
+	if req.AccessDeniedMessage != nil {
+		provider.AccessDeniedMessage = *req.AccessDeniedMessage
+	}
 
 	if err := model.UpdateCustomOAuthProvider(provider); err != nil {
 		common.ApiError(c, err)
@@ -327,6 +441,30 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
 	})
 }
 
+func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) {
+	bindings, err := model.GetUserOAuthBindingsByUserId(userId)
+	if err != nil {
+		return nil, err
+	}
+
+	response := make([]UserOAuthBindingResponse, 0, len(bindings))
+	for _, binding := range bindings {
+		provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
+		if err != nil {
+			continue
+		}
+		response = append(response, UserOAuthBindingResponse{
+			ProviderId:     binding.ProviderId,
+			ProviderName:   provider.Name,
+			ProviderSlug:   provider.Slug,
+			ProviderIcon:   provider.Icon,
+			ProviderUserId: binding.ProviderUserId,
+		})
+	}
+
+	return response, nil
+}
+
 // GetUserOAuthBindings returns all OAuth bindings for the current user
 func GetUserOAuthBindings(c *gin.Context) {
 	userId := c.GetInt("id")
@@ -335,32 +473,43 @@ func GetUserOAuthBindings(c *gin.Context) {
 		return
 	}
 
-	bindings, err := model.GetUserOAuthBindingsByUserId(userId)
+	response, err := buildUserOAuthBindingsResponse(userId)
 	if err != nil {
 		common.ApiError(c, err)
 		return
 	}
 
-	// Build response with provider info
-	type BindingResponse struct {
-		ProviderId     int    `json:"provider_id"`
-		ProviderName   string `json:"provider_name"`
-		ProviderSlug   string `json:"provider_slug"`
-		ProviderUserId string `json:"provider_user_id"`
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    response,
+	})
+}
+
+func GetUserOAuthBindingsByAdmin(c *gin.Context) {
+	userIdStr := c.Param("id")
+	userId, err := strconv.Atoi(userIdStr)
+	if err != nil {
+		common.ApiErrorMsg(c, "invalid user id")
+		return
 	}
 
-	response := make([]BindingResponse, 0)
-	for _, binding := range bindings {
-		provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
-		if err != nil {
-			continue // Skip if provider not found
-		}
-		response = append(response, BindingResponse{
-			ProviderId:     binding.ProviderId,
-			ProviderName:   provider.Name,
-			ProviderSlug:   provider.Slug,
-			ProviderUserId: binding.ProviderUserId,
-		})
+	targetUser, err := model.GetUserById(userId, false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	myRole := c.GetInt("role")
+	if myRole <= targetUser.Role && myRole != common.RoleRootUser {
+		common.ApiErrorMsg(c, "no permission")
+		return
+	}
+
+	response, err := buildUserOAuthBindingsResponse(userId)
+	if err != nil {
+		common.ApiError(c, err)
+		return
 	}
 
 	c.JSON(http.StatusOK, gin.H{
@@ -395,3 +544,41 @@ func UnbindCustomOAuth(c *gin.Context) {
 		"message": "解绑成功",
 	})
 }
+
+func UnbindCustomOAuthByAdmin(c *gin.Context) {
+	userIdStr := c.Param("id")
+	userId, err := strconv.Atoi(userIdStr)
+	if err != nil {
+		common.ApiErrorMsg(c, "invalid user id")
+		return
+	}
+
+	targetUser, err := model.GetUserById(userId, false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	myRole := c.GetInt("role")
+	if myRole <= targetUser.Role && myRole != common.RoleRootUser {
+		common.ApiErrorMsg(c, "no permission")
+		return
+	}
+
+	providerIdStr := c.Param("provider_id")
+	providerId, err := strconv.Atoi(providerIdStr)
+	if err != nil {
+		common.ApiErrorMsg(c, "invalid provider id")
+		return
+	}
+
+	if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "success",
+	})
+}

+ 20 - 11
controller/midjourney.go

@@ -105,13 +105,13 @@ func UpdateMidjourneyTaskBulk() {
 			}
 			responseBody, err := io.ReadAll(resp.Body)
 			if err != nil {
-				logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+				logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error: %v", err))
 				continue
 			}
 			var responseItems []dto.MidjourneyDto
 			err = json.Unmarshal(responseBody, &responseItems)
 			if err != nil {
-				logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+				logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error2: %v, body: %s", err, string(responseBody)))
 				continue
 			}
 			resp.Body.Close()
@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
 				if !checkMjTaskNeedUpdate(task, responseItem) {
 					continue
 				}
+				preStatus := task.Status
 				task.Code = 1
 				task.Progress = responseItem.Progress
 				task.PromptEn = responseItem.PromptEn
@@ -172,18 +173,26 @@ func UpdateMidjourneyTaskBulk() {
 						shouldReturnQuota = true
 					}
 				}
-				err = task.Update()
+				won, err := task.UpdateWithStatus(preStatus)
 				if err != nil {
 					logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
-				} else {
-					if shouldReturnQuota {
-						err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
-						if err != nil {
-							logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-						}
-						logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
-						model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				} else if won && shouldReturnQuota {
+					err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+					if err != nil {
+						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
 					}
+					model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+						UserId:    task.UserId,
+						LogType:   model.LogTypeRefund,
+						Content:   "",
+						ChannelId: task.ChannelId,
+						ModelName: service.CovertMjpActionToModelName(task.Action),
+						Quota:     task.Quota,
+						Other: map[string]interface{}{
+							"task_id": task.MjId,
+							"reason":  "构图失败",
+						},
+					})
 				}
 			}
 		}

+ 4 - 0
controller/misc.go

@@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) {
 	customProviders := oauth.GetEnabledCustomProviders()
 	if len(customProviders) > 0 {
 		type CustomOAuthInfo struct {
+			Id                    int    `json:"id"`
 			Name                  string `json:"name"`
 			Slug                  string `json:"slug"`
+			Icon                  string `json:"icon"`
 			ClientId              string `json:"client_id"`
 			AuthorizationEndpoint string `json:"authorization_endpoint"`
 			Scopes                string `json:"scopes"`
@@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) {
 		for _, p := range customProviders {
 			config := p.GetConfig()
 			providersInfo = append(providersInfo, CustomOAuthInfo{
+				Id:                    config.Id,
 				Name:                  config.Name,
 				Slug:                  config.Slug,
+				Icon:                  config.Icon,
 				ClientId:              config.ClientId,
 				AuthorizationEndpoint: config.AuthorizationEndpoint,
 				Scopes:                config.Scopes,

+ 18 - 6
controller/oauth.go

@@ -237,6 +237,16 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
 
 	// Set up new user
 	user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
+
+	if oauthUser.Username != "" {
+		if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists {
+			// 防止索引退化
+			if len(oauthUser.Username) <= model.UserNameMaxLength {
+				user.Username = oauthUser.Username
+			}
+		}
+	}
+
 	if oauthUser.DisplayName != "" {
 		user.DisplayName = oauthUser.DisplayName
 	} else if oauthUser.Username != "" {
@@ -295,12 +305,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
 			// Set the provider user ID on the user model and update
 			provider.SetProviderUserID(user, oauthUser.ProviderUserID)
 			if err := tx.Model(user).Updates(map[string]interface{}{
-				"github_id":    user.GitHubId,
-				"discord_id":   user.DiscordId,
-				"oidc_id":      user.OidcId,
-				"linux_do_id":  user.LinuxDOId,
-				"wechat_id":    user.WeChatId,
-				"telegram_id":  user.TelegramId,
+				"github_id":   user.GitHubId,
+				"discord_id":  user.DiscordId,
+				"oidc_id":     user.OidcId,
+				"linux_do_id": user.LinuxDOId,
+				"wechat_id":   user.WeChatId,
+				"telegram_id": user.TelegramId,
 			}).Error; err != nil {
 				return err
 			}
@@ -340,6 +350,8 @@ func handleOAuthError(c *gin.Context, err error) {
 		} else {
 			common.ApiErrorI18n(c, e.MsgKey)
 		}
+	case *oauth.AccessDeniedError:
+		common.ApiErrorMsg(c, e.Message)
 	case *oauth.TrustLevelError:
 		common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
 	default:

+ 113 - 38
controller/relay.go

@@ -455,72 +455,147 @@ func RelayNotFound(c *gin.Context) {
 	})
 }
 
+func RelayTaskFetch(c *gin.Context) {
+	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, &dto.TaskError{
+			Code:       "gen_relay_info_failed",
+			Message:    err.Error(),
+			StatusCode: http.StatusInternalServerError,
+		})
+		return
+	}
+	if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
+		respondTaskError(c, taskErr)
+	}
+}
+
 func RelayTask(c *gin.Context) {
-	retryTimes := common.RetryTimes
-	channelId := c.GetInt("channel_id")
-	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
 	relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
 	if err != nil {
+		c.JSON(http.StatusInternalServerError, &dto.TaskError{
+			Code:       "gen_relay_info_failed",
+			Message:    err.Error(),
+			StatusCode: http.StatusInternalServerError,
+		})
 		return
 	}
-	taskErr := taskRelayHandler(c, relayInfo)
-	if taskErr == nil {
-		retryTimes = 0
+
+	if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
+		respondTaskError(c, taskErr)
+		return
 	}
+
+	var result *relay.TaskSubmitResult
+	var taskErr *dto.TaskError
+	defer func() {
+		if taskErr != nil && relayInfo.Billing != nil {
+			relayInfo.Billing.Refund(c)
+		}
+	}()
+
 	retryParam := &service.RetryParam{
 		Ctx:        c,
 		TokenGroup: relayInfo.TokenGroup,
 		ModelName:  relayInfo.OriginModelName,
 		Retry:      common.GetPointer(0),
 	}
-	for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
-		channel, newAPIError := getChannel(c, relayInfo, retryParam)
-		if newAPIError != nil {
-			logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
-			taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
-			break
+
+	for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
+		var channel *model.Channel
+
+		if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
+			channel = lockedCh
+			if retryParam.GetRetry() > 0 {
+				if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
+					taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
+					break
+				}
+			}
+		} else {
+			var channelErr *types.NewAPIError
+			channel, channelErr = getChannel(c, relayInfo, retryParam)
+			if channelErr != nil {
+				logger.LogError(c, channelErr.Error())
+				taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
+				break
+			}
 		}
-		channelId = channel.Id
-		useChannel := c.GetStringSlice("use_channel")
-		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
-		c.Set("use_channel", useChannel)
-		logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
-		//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
-
-		bodyStorage, err := common.GetBodyStorage(c)
-		if err != nil {
-			if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
-				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
+
+		addUsedChannel(c, channel.Id)
+		bodyStorage, bodyErr := common.GetBodyStorage(c)
+		if bodyErr != nil {
+			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
+				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
 			} else {
-				taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
+				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
 			}
 			break
 		}
 		c.Request.Body = io.NopCloser(bodyStorage)
-		taskErr = taskRelayHandler(c, relayInfo)
+
+		result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
+		if taskErr == nil {
+			break
+		}
+
+		if !taskErr.LocalError {
+			processChannelError(c,
+				*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
+					common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
+				types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
+		}
+
+		if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
+			break
+		}
 	}
+
 	useChannel := c.GetStringSlice("use_channel")
 	if len(useChannel) > 1 {
 		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
 		logger.LogInfo(c, retryLogStr)
 	}
-	if taskErr != nil {
-		if taskErr.StatusCode == http.StatusTooManyRequests {
-			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+
+	// ── 成功:结算 + 日志 + 插入任务 ──
+	if taskErr == nil {
+		if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
+			common.SysError("settle task billing error: " + settleErr.Error())
 		}
-		c.JSON(taskErr.StatusCode, taskErr)
+		service.LogTaskConsumption(c, relayInfo)
+
+		task := model.InitTask(result.Platform, relayInfo)
+		task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
+		task.PrivateData.BillingSource = relayInfo.BillingSource
+		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
+		task.PrivateData.TokenId = relayInfo.TokenId
+		task.PrivateData.BillingContext = &model.TaskBillingContext{
+			ModelPrice:      relayInfo.PriceData.ModelPrice,
+			GroupRatio:      relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+			ModelRatio:      relayInfo.PriceData.ModelRatio,
+			OtherRatios:     relayInfo.PriceData.OtherRatios,
+			OriginModelName: relayInfo.OriginModelName,
+			PerCallBilling:  common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
+		}
+		task.Quota = result.Quota
+		task.Data = result.TaskData
+		task.Action = relayInfo.Action
+		if insertErr := task.Insert(); insertErr != nil {
+			common.SysError("insert task error: " + insertErr.Error())
+		}
+	}
+
+	if taskErr != nil {
+		respondTaskError(c, taskErr)
 	}
 }
 
-func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
-	var err *dto.TaskError
-	switch relayInfo.RelayMode {
-	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
-		err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
-	default:
-		err = relay.RelayTaskSubmit(c, relayInfo)
+// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
+func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
 	}
-	return err
+	c.JSON(taskErr.StatusCode, taskErr)
 }
 
 func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
@@ -544,7 +619,7 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError,
 	}
 	if taskErr.StatusCode/100 == 5 {
 		// 超时不重试
-		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+		if operation_setting.IsAlwaysSkipRetryStatusCode(taskErr.StatusCode) {
 			return false
 		}
 		return true

+ 33 - 215
controller/task.go

@@ -1,231 +1,22 @@
 package controller
 
 import (
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"sort"
 	"strconv"
-	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
-	"github.com/samber/lo"
 )
 
+// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
 func UpdateTaskBulk() {
-	//revocer
-	//imageModel := "midjourney"
-	for {
-		time.Sleep(time.Duration(15) * time.Second)
-		common.SysLog("任务进度轮询开始")
-		ctx := context.TODO()
-		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
-		platformTask := make(map[constant.TaskPlatform][]*model.Task)
-		for _, t := range allTasks {
-			platformTask[t.Platform] = append(platformTask[t.Platform], t)
-		}
-		for platform, tasks := range platformTask {
-			if len(tasks) == 0 {
-				continue
-			}
-			taskChannelM := make(map[int][]string)
-			taskM := make(map[string]*model.Task)
-			nullTaskIds := make([]int64, 0)
-			for _, task := range tasks {
-				if task.TaskID == "" {
-					// 统计失败的未完成任务
-					nullTaskIds = append(nullTaskIds, task.ID)
-					continue
-				}
-				taskM[task.TaskID] = task
-				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
-			}
-			if len(nullTaskIds) > 0 {
-				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
-					"status":   "FAILURE",
-					"progress": "100%",
-				})
-				if err != nil {
-					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
-				} else {
-					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
-				}
-			}
-			if len(taskChannelM) == 0 {
-				continue
-			}
-
-			UpdateTaskByPlatform(platform, taskChannelM, taskM)
-		}
-		common.SysLog("任务进度轮询完成")
-	}
-}
-
-func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
-	switch platform {
-	case constant.TaskPlatformMidjourney:
-		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
-	case constant.TaskPlatformSuno:
-		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
-	default:
-		if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
-		}
-	}
-}
-
-func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
-		if err != nil {
-			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	channel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
-		err = model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if err != nil {
-			common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
-		}
-		return err
-	}
-	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
-	if adaptor == nil {
-		return errors.New("adaptor not found")
-	}
-	proxy := channel.GetSetting().Proxy
-	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
-		"ids": taskIds,
-	}, proxy)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
-		return err
-	}
-	if resp.StatusCode != http.StatusOK {
-		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
-	}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
-		return err
-	}
-	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
-	err = json.Unmarshal(responseBody, &responseItems)
-	if err != nil {
-		logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
-		return err
-	}
-	if !responseItems.IsSuccess() {
-		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
-		return err
-	}
-
-	for _, responseItem := range responseItems.Data {
-		task := taskM[responseItem.TaskID]
-		if !checkTaskNeedUpdate(task, responseItem) {
-			continue
-		}
-
-		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
-		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
-		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
-		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
-		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
-		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
-			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
-			task.Progress = "100%"
-			//err = model.CacheUpdateUserQuota(task.UserId) ?
-			if err != nil {
-				logger.LogError(ctx, "error update user quota cache: "+err.Error())
-			} else {
-				quota := task.Quota
-				if quota != 0 {
-					err = model.IncreaseUserQuota(task.UserId, quota, false)
-					if err != nil {
-						logger.LogError(ctx, "fail to increase user quota: "+err.Error())
-					}
-					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
-					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-				}
-			}
-		}
-		if responseItem.Status == model.TaskStatusSuccess {
-			task.Progress = "100%"
-		}
-		task.Data = responseItem.Data
-
-		err = task.Update()
-		if err != nil {
-			common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
-		}
-	}
-	return nil
-}
-
-func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
-
-	if oldTask.SubmitTime != newTask.SubmitTime {
-		return true
-	}
-	if oldTask.StartTime != newTask.StartTime {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-	if string(oldTask.Status) != newTask.Status {
-		return true
-	}
-	if oldTask.FailReason != newTask.FailReason {
-		return true
-	}
-	if oldTask.FinishTime != newTask.FinishTime {
-		return true
-	}
-
-	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
-		return true
-	}
-
-	oldData, _ := json.Marshal(oldTask.Data)
-	newData, _ := json.Marshal(newTask.Data)
-
-	sort.Slice(oldData, func(i, j int) bool {
-		return oldData[i] < oldData[j]
-	})
-	sort.Slice(newData, func(i, j int) bool {
-		return newData[i] < newData[j]
-	})
-
-	if string(oldData) != string(newData) {
-		return true
-	}
-	return false
+	service.TaskPollingLoop()
 }
 
 func GetAllTask(c *gin.Context) {
@@ -247,7 +38,7 @@ func GetAllTask(c *gin.Context) {
 	items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllTasks(queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items, true))
 	common.ApiSuccess(c, pageInfo)
 }
 
@@ -271,6 +62,33 @@ func GetUserTask(c *gin.Context) {
 	items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllUserTask(userId, queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(items)
+	pageInfo.SetItems(tasksToDto(items, false))
 	common.ApiSuccess(c, pageInfo)
 }
+
+func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
+	var userIdMap map[int]*model.UserBase
+	if fillUser {
+		userIdMap = make(map[int]*model.UserBase)
+		userIds := types.NewSet[int]()
+		for _, task := range tasks {
+			userIds.Add(task.UserId)
+		}
+		for _, userId := range userIds.Items() {
+			cacheUser, err := model.GetUserCache(userId)
+			if err == nil {
+				userIdMap[userId] = cacheUser
+			}
+		}
+	}
+	result := make([]*dto.TaskDto, len(tasks))
+	for i, task := range tasks {
+		if fillUser {
+			if user, ok := userIdMap[task.UserId]; ok {
+				task.Username = user.Username
+			}
+		}
+		result[i] = relay.TaskModel2Dto(task)
+	}
+	return result
+}

+ 0 - 313
controller/task_video.go

@@ -1,313 +0,0 @@
-package controller
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"time"
-
-	"github.com/QuantumNous/new-api/common"
-	"github.com/QuantumNous/new-api/constant"
-	"github.com/QuantumNous/new-api/dto"
-	"github.com/QuantumNous/new-api/logger"
-	"github.com/QuantumNous/new-api/model"
-	"github.com/QuantumNous/new-api/relay"
-	"github.com/QuantumNous/new-api/relay/channel"
-	relaycommon "github.com/QuantumNous/new-api/relay/common"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-)
-
-func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
-	for channelId, taskIds := range taskChannelM {
-		if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
-	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
-	if len(taskIds) == 0 {
-		return nil
-	}
-	cacheGetChannel, err := model.CacheGetChannel(channelId)
-	if err != nil {
-		errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
-			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
-			"status":      "FAILURE",
-			"progress":    "100%",
-		})
-		if errUpdate != nil {
-			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
-		}
-		return fmt.Errorf("CacheGetChannel failed: %w", err)
-	}
-	adaptor := relay.GetTaskAdaptor(platform)
-	if adaptor == nil {
-		return fmt.Errorf("video adaptor not found")
-	}
-	info := &relaycommon.RelayInfo{}
-	info.ChannelMeta = &relaycommon.ChannelMeta{
-		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
-	}
-	info.ApiKey = cacheGetChannel.Key
-	adaptor.Init(info)
-	for _, taskId := range taskIds {
-		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
-			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
-		}
-	}
-	return nil
-}
-
-func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
-	baseURL := constant.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseURL = channel.GetBaseURL()
-	}
-	proxy := channel.GetSetting().Proxy
-
-	task := taskM[taskId]
-	if task == nil {
-		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
-		return fmt.Errorf("task %s not found", taskId)
-	}
-	key := channel.Key
-
-	privateData := task.PrivateData
-	if privateData.Key != "" {
-		key = privateData.Key
-	}
-	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
-		"task_id": taskId,
-		"action":  task.Action,
-	}, proxy)
-	if err != nil {
-		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
-	}
-	//if resp.StatusCode != http.StatusOK {
-	//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
-	//}
-	defer resp.Body.Close()
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
-
-	taskResult := &relaycommon.TaskInfo{}
-	// try parse as New API response format
-	var responseItems dto.TaskResponse[model.Task]
-	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
-		logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
-		t := responseItems.Data
-		taskResult.TaskID = t.TaskID
-		taskResult.Status = string(t.Status)
-		taskResult.Url = t.FailReason
-		taskResult.Progress = t.Progress
-		taskResult.Reason = t.FailReason
-		task.Data = t.Data
-	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
-		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
-	} else {
-		task.Data = redactVideoResponseBody(responseBody)
-	}
-
-	logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
-
-	now := time.Now().Unix()
-	if taskResult.Status == "" {
-		//return fmt.Errorf("task %s status is empty", taskId)
-		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
-	}
-
-	// 记录原本的状态,防止重复退款
-	shouldRefund := false
-	quota := task.Quota
-	preStatus := task.Status
-
-	task.Status = model.TaskStatus(taskResult.Status)
-	switch taskResult.Status {
-	case model.TaskStatusSubmitted:
-		task.Progress = "10%"
-	case model.TaskStatusQueued:
-		task.Progress = "20%"
-	case model.TaskStatusInProgress:
-		task.Progress = "30%"
-		if task.StartTime == 0 {
-			task.StartTime = now
-		}
-	case model.TaskStatusSuccess:
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
-			task.FailReason = taskResult.Url
-		}
-
-		// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
-		if taskResult.TotalTokens > 0 {
-			// 获取模型名称
-			var taskData map[string]interface{}
-			if err := json.Unmarshal(task.Data, &taskData); err == nil {
-				if modelName, ok := taskData["model"].(string); ok && modelName != "" {
-					// 获取模型价格和倍率
-					modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
-					// 只有配置了倍率(非固定价格)时才按 token 重新计费
-					if hasRatioSetting && modelRatio > 0 {
-						// 获取用户和组的倍率信息
-						group := task.Group
-						if group == "" {
-							user, err := model.GetUserById(task.UserId, false)
-							if err == nil {
-								group = user.Group
-							}
-						}
-						if group != "" {
-							groupRatio := ratio_setting.GetGroupRatio(group)
-							userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
-
-							var finalGroupRatio float64
-							if hasUserGroupRatio {
-								finalGroupRatio = userGroupRatio
-							} else {
-								finalGroupRatio = groupRatio
-							}
-
-							// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
-							actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
-
-							// 计算差额
-							preConsumedQuota := task.Quota
-							quotaDelta := actualQuota - preConsumedQuota
-
-							if quotaDelta > 0 {
-								// 需要补扣费
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(quotaDelta),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
-								} else {
-									model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
-									model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录消费日志
-									logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else if quotaDelta < 0 {
-								// 需要退还多扣的费用
-								refundQuota := -quotaDelta
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
-									task.TaskID,
-									logger.LogQuota(refundQuota),
-									logger.LogQuota(actualQuota),
-									logger.LogQuota(preConsumedQuota),
-									taskResult.TotalTokens,
-								))
-								if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
-									logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
-								} else {
-									task.Quota = actualQuota // 更新任务记录的实际扣费额度
-
-									// 记录退款日志
-									logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
-										modelRatio, finalGroupRatio, taskResult.TotalTokens,
-										logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
-									model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-								}
-							} else {
-								// quotaDelta == 0, 预扣费刚好准确
-								logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
-									task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
-							}
-						}
-					}
-				}
-			}
-		}
-	case model.TaskStatusFailure:
-		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
-		task.Status = model.TaskStatusFailure
-		task.Progress = "100%"
-		if task.FinishTime == 0 {
-			task.FinishTime = now
-		}
-		task.FailReason = taskResult.Reason
-		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
-		taskResult.Progress = "100%"
-		if quota != 0 {
-			if preStatus != model.TaskStatusFailure {
-				shouldRefund = true
-			} else {
-				logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
-			}
-		}
-	default:
-		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
-	}
-	if taskResult.Progress != "" {
-		task.Progress = taskResult.Progress
-	}
-	if err := task.Update(); err != nil {
-		common.SysLog("UpdateVideoTask task error: " + err.Error())
-		shouldRefund = false
-	}
-
-	if shouldRefund {
-		// 任务失败且之前状态不是失败才退还额度,防止重复退还
-		if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
-			logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
-		}
-		logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
-		model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-	}
-
-	return nil
-}
-
-func redactVideoResponseBody(body []byte) []byte {
-	var m map[string]any
-	if err := json.Unmarshal(body, &m); err != nil {
-		return body
-	}
-	resp, _ := m["response"].(map[string]any)
-	if resp != nil {
-		delete(resp, "bytesBase64Encoded")
-		if v, ok := resp["video"].(string); ok {
-			resp["video"] = truncateBase64(v)
-		}
-		if vs, ok := resp["videos"].([]any); ok {
-			for i := range vs {
-				if vm, ok := vs[i].(map[string]any); ok {
-					delete(vm, "bytesBase64Encoded")
-				}
-			}
-		}
-	}
-	b, err := json.Marshal(m)
-	if err != nil {
-		return body
-	}
-	return b
-}
-
-func truncateBase64(s string) string {
-	const maxKeep = 256
-	if len(s) <= maxKeep {
-		return s
-	}
-	return s[:maxKeep] + "..."
-}

+ 38 - 0
controller/user.go

@@ -582,6 +582,44 @@ func UpdateUser(c *gin.Context) {
 	return
 }
 
+func AdminClearUserBinding(c *gin.Context) {
+	id, err := strconv.Atoi(c.Param("id"))
+	if err != nil {
+		common.ApiErrorI18n(c, i18n.MsgInvalidParams)
+		return
+	}
+
+	bindingType := strings.ToLower(strings.TrimSpace(c.Param("binding_type")))
+	if bindingType == "" {
+		common.ApiErrorI18n(c, i18n.MsgInvalidParams)
+		return
+	}
+
+	user, err := model.GetUserById(id, false)
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	myRole := c.GetInt("role")
+	if myRole <= user.Role && myRole != common.RoleRootUser {
+		common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
+		return
+	}
+
+	if err := user.ClearBinding(bindingType); err != nil {
+		common.ApiError(c, err)
+		return
+	}
+
+	model.RecordLog(user.Id, model.LogTypeManage, fmt.Sprintf("admin cleared %s binding for user %s", bindingType, user.Username))
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "success",
+	})
+}
+
 func UpdateSelf(c *gin.Context) {
 	var requestData map[string]interface{}
 	err := json.NewDecoder(c.Request.Body).Decode(&requestData)

+ 30 - 81
controller/video_proxy.go

@@ -16,59 +16,44 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+// videoProxyError returns a standardized OpenAI-style error response.
+func videoProxyError(c *gin.Context, status int, errType, message string) {
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"message": message,
+			"type":    errType,
+		},
+	})
+}
+
 func VideoProxy(c *gin.Context) {
 	taskID := c.Param("task_id")
 	if taskID == "" {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": "task_id is required",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
 		return
 	}
 
 	task, exists, err := model.GetByOnlyTaskId(taskID)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to query task",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
 		return
 	}
 	if !exists || task == nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
-		c.JSON(http.StatusNotFound, gin.H{
-			"error": gin.H{
-				"message": "Task not found",
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
 		return
 	}
 
 	if task.Status != model.TaskStatusSuccess {
-		c.JSON(http.StatusBadRequest, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
-				"type":    "invalid_request_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
+			fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
 		return
 	}
 
 	channel, err := model.CacheGetChannel(task.ChannelId)
 	if err != nil {
-		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to retrieve channel information",
-				"type":    "server_error",
-			},
-		})
+		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
 		return
 	}
 	baseURL := channel.GetBaseURL()
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
 	client, err := service.GetHttpClientWithProxy(proxy)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy client",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
 		return
 	}
 
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
 		apiKey := task.PrivateData.Key
 		if apiKey == "" {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
-			c.JSON(http.StatusInternalServerError, gin.H{
-				"error": gin.H{
-					"message": "API key not stored for task",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
 			return
 		}
-
 		videoURL, err = getGeminiVideoURL(channel, task, apiKey)
 		if err != nil {
 			logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
-			c.JSON(http.StatusBadGateway, gin.H{
-				"error": gin.H{
-					"message": "Failed to resolve Gemini video URL",
-					"type":    "server_error",
-				},
-			})
+			videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
 			return
 		}
 		req.Header.Set("x-goog-api-key", apiKey)
 	case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
-		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
+		videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 	default:
-		// Video URL is directly in task.FailReason
-		videoURL = task.FailReason
+		// Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
+		videoURL = task.GetResultURL()
 	}
 
 	req.URL, err = url.Parse(videoURL)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusInternalServerError, gin.H{
-			"error": gin.H{
-				"message": "Failed to create proxy request",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
 		return
 	}
 
 	resp, err := client.Do(req)
 	if err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": "Failed to fetch video content",
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
 		return
 	}
 	defer resp.Body.Close()
 
 	if resp.StatusCode != http.StatusOK {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
-		c.JSON(http.StatusBadGateway, gin.H{
-			"error": gin.H{
-				"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
-				"type":    "server_error",
-			},
-		})
+		videoProxyError(c, http.StatusBadGateway, "server_error",
+			fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
 		return
 	}
 
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
 		}
 	}
 
-	c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
+	c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
+	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 		logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
 	}
 }

+ 4 - 4
controller/video_proxy_gemini.go

@@ -1,12 +1,12 @@
 package controller
 
 import (
-	"encoding/json"
 	"fmt"
 	"io"
 	"strconv"
 	"strings"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
 
 	proxy := channel.GetSetting().Proxy
 	resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
-		"task_id": task.TaskID,
+		"task_id": task.GetUpstreamTaskID(),
 		"action":  task.Action,
 	}, proxy)
 	if err != nil {
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 		return ""
 	}
 	var payload map[string]any
-	if err := json.Unmarshal(task.Data, &payload); err != nil {
+	if err := common.Unmarshal(task.Data, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
 
 func extractGeminiVideoURLFromPayload(body []byte) string {
 	var payload map[string]any
-	if err := json.Unmarshal(body, &payload); err != nil {
+	if err := common.Unmarshal(body, &payload); err != nil {
 		return ""
 	}
 	return extractGeminiVideoURLFromMap(payload)

+ 10 - 8
dto/channel_settings.go

@@ -24,14 +24,16 @@ const (
 )
 
 type ChannelOtherSettings struct {
-	AzureResponsesVersion string        `json:"azure_responses_version,omitempty"`
-	VertexKeyType         VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
-	OpenRouterEnterprise  *bool         `json:"openrouter_enterprise,omitempty"`
-	ClaudeBetaQuery       bool          `json:"claude_beta_query,omitempty"`      // Claude 渠道是否强制追加 ?beta=true
-	AllowServiceTier      bool          `json:"allow_service_tier,omitempty"`      // 是否允许 service_tier 透传(默认过滤以避免额外计费)
-	DisableStore          bool          `json:"disable_store,omitempty"`           // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
-	AllowSafetyIdentifier bool          `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
-	AwsKeyType            AwsKeyType    `json:"aws_key_type,omitempty"`
+	AzureResponsesVersion   string        `json:"azure_responses_version,omitempty"`
+	VertexKeyType           VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
+	OpenRouterEnterprise    *bool         `json:"openrouter_enterprise,omitempty"`
+	ClaudeBetaQuery         bool          `json:"claude_beta_query,omitempty"`         // Claude 渠道是否强制追加 ?beta=true
+	AllowServiceTier        bool          `json:"allow_service_tier,omitempty"`        // 是否允许 service_tier 透传(默认过滤以避免额外计费)
+	AllowInferenceGeo       bool          `json:"allow_inference_geo,omitempty"`       // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规)
+	DisableStore            bool          `json:"disable_store,omitempty"`             // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
+	AllowSafetyIdentifier   bool          `json:"allow_safety_identifier,omitempty"`   // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
+	AllowIncludeObfuscation bool          `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护)
+	AwsKeyType              AwsKeyType    `json:"aws_key_type,omitempty"`
 }
 
 func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {

+ 9 - 5
dto/claude.go

@@ -190,10 +190,13 @@ type ClaudeToolChoice struct {
 }
 
 type ClaudeRequest struct {
-	Model             string          `json:"model"`
-	Prompt            string          `json:"prompt,omitempty"`
-	System            any             `json:"system,omitempty"`
-	Messages          []ClaudeMessage `json:"messages,omitempty"`
+	Model    string          `json:"model"`
+	Prompt   string          `json:"prompt,omitempty"`
+	System   any             `json:"system,omitempty"`
+	Messages []ClaudeMessage `json:"messages,omitempty"`
+	// InferenceGeo controls Claude data residency region.
+	// This field is filtered by default and can be enabled via channel setting allow_inference_geo.
+	InferenceGeo      string          `json:"inference_geo,omitempty"`
 	MaxTokens         uint            `json:"max_tokens,omitempty"`
 	MaxTokensToSample uint            `json:"max_tokens_to_sample,omitempty"`
 	StopSequences     []string        `json:"stop_sequences,omitempty"`
@@ -210,7 +213,8 @@ type ClaudeRequest struct {
 	Thinking          *Thinking       `json:"thinking,omitempty"`
 	McpServers        json.RawMessage `json:"mcp_servers,omitempty"`
 	Metadata          json.RawMessage `json:"metadata,omitempty"`
-	// 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
+	// ServiceTier specifies upstream service level and may affect billing.
+	// This field is filtered by default and can be enabled via channel setting allow_service_tier.
 	ServiceTier string `json:"service_tier,omitempty"`
 }
 

+ 48 - 41
dto/gemini.go

@@ -324,25 +324,26 @@ type GeminiChatTool struct {
 }
 
 type GeminiChatGenerationConfig struct {
-	Temperature        *float64              `json:"temperature,omitempty"`
-	TopP               float64               `json:"topP,omitempty"`
-	TopK               float64               `json:"topK,omitempty"`
-	MaxOutputTokens    uint                  `json:"maxOutputTokens,omitempty"`
-	CandidateCount     int                   `json:"candidateCount,omitempty"`
-	StopSequences      []string              `json:"stopSequences,omitempty"`
-	ResponseMimeType   string                `json:"responseMimeType,omitempty"`
-	ResponseSchema     any                   `json:"responseSchema,omitempty"`
-	ResponseJsonSchema json.RawMessage       `json:"responseJsonSchema,omitempty"`
-	PresencePenalty    *float32              `json:"presencePenalty,omitempty"`
-	FrequencyPenalty   *float32              `json:"frequencyPenalty,omitempty"`
-	ResponseLogprobs   bool                  `json:"responseLogprobs,omitempty"`
-	Logprobs           *int32                `json:"logprobs,omitempty"`
-	MediaResolution    MediaResolution       `json:"mediaResolution,omitempty"`
-	Seed               int64                 `json:"seed,omitempty"`
-	ResponseModalities []string              `json:"responseModalities,omitempty"`
-	ThinkingConfig     *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
-	SpeechConfig       json.RawMessage       `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
-	ImageConfig        json.RawMessage       `json:"imageConfig,omitempty"`  // RawMessage to allow flexible image config
+	Temperature                *float64              `json:"temperature,omitempty"`
+	TopP                       float64               `json:"topP,omitempty"`
+	TopK                       float64               `json:"topK,omitempty"`
+	MaxOutputTokens            uint                  `json:"maxOutputTokens,omitempty"`
+	CandidateCount             int                   `json:"candidateCount,omitempty"`
+	StopSequences              []string              `json:"stopSequences,omitempty"`
+	ResponseMimeType           string                `json:"responseMimeType,omitempty"`
+	ResponseSchema             any                   `json:"responseSchema,omitempty"`
+	ResponseJsonSchema         json.RawMessage       `json:"responseJsonSchema,omitempty"`
+	PresencePenalty            *float32              `json:"presencePenalty,omitempty"`
+	FrequencyPenalty           *float32              `json:"frequencyPenalty,omitempty"`
+	ResponseLogprobs           bool                  `json:"responseLogprobs,omitempty"`
+	Logprobs                   *int32                `json:"logprobs,omitempty"`
+	EnableEnhancedCivicAnswers *bool                 `json:"enableEnhancedCivicAnswers,omitempty"`
+	MediaResolution            MediaResolution       `json:"mediaResolution,omitempty"`
+	Seed                       int64                 `json:"seed,omitempty"`
+	ResponseModalities         []string              `json:"responseModalities,omitempty"`
+	ThinkingConfig             *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+	SpeechConfig               json.RawMessage       `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
+	ImageConfig                json.RawMessage       `json:"imageConfig,omitempty"`  // RawMessage to allow flexible image config
 }
 
 // UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields.
@@ -350,22 +351,23 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
 	type Alias GeminiChatGenerationConfig
 	var aux struct {
 		Alias
-		TopPSnake               float64               `json:"top_p,omitempty"`
-		TopKSnake               float64               `json:"top_k,omitempty"`
-		MaxOutputTokensSnake    uint                  `json:"max_output_tokens,omitempty"`
-		CandidateCountSnake     int                   `json:"candidate_count,omitempty"`
-		StopSequencesSnake      []string              `json:"stop_sequences,omitempty"`
-		ResponseMimeTypeSnake   string                `json:"response_mime_type,omitempty"`
-		ResponseSchemaSnake     any                   `json:"response_schema,omitempty"`
-		ResponseJsonSchemaSnake json.RawMessage       `json:"response_json_schema,omitempty"`
-		PresencePenaltySnake    *float32              `json:"presence_penalty,omitempty"`
-		FrequencyPenaltySnake   *float32              `json:"frequency_penalty,omitempty"`
-		ResponseLogprobsSnake   bool                  `json:"response_logprobs,omitempty"`
-		MediaResolutionSnake    MediaResolution       `json:"media_resolution,omitempty"`
-		ResponseModalitiesSnake []string              `json:"response_modalities,omitempty"`
-		ThinkingConfigSnake     *GeminiThinkingConfig `json:"thinking_config,omitempty"`
-		SpeechConfigSnake       json.RawMessage       `json:"speech_config,omitempty"`
-		ImageConfigSnake        json.RawMessage       `json:"image_config,omitempty"`
+		TopPSnake                       float64               `json:"top_p,omitempty"`
+		TopKSnake                       float64               `json:"top_k,omitempty"`
+		MaxOutputTokensSnake            uint                  `json:"max_output_tokens,omitempty"`
+		CandidateCountSnake             int                   `json:"candidate_count,omitempty"`
+		StopSequencesSnake              []string              `json:"stop_sequences,omitempty"`
+		ResponseMimeTypeSnake           string                `json:"response_mime_type,omitempty"`
+		ResponseSchemaSnake             any                   `json:"response_schema,omitempty"`
+		ResponseJsonSchemaSnake         json.RawMessage       `json:"response_json_schema,omitempty"`
+		PresencePenaltySnake            *float32              `json:"presence_penalty,omitempty"`
+		FrequencyPenaltySnake           *float32              `json:"frequency_penalty,omitempty"`
+		ResponseLogprobsSnake           bool                  `json:"response_logprobs,omitempty"`
+		EnableEnhancedCivicAnswersSnake *bool                 `json:"enable_enhanced_civic_answers,omitempty"`
+		MediaResolutionSnake            MediaResolution       `json:"media_resolution,omitempty"`
+		ResponseModalitiesSnake         []string              `json:"response_modalities,omitempty"`
+		ThinkingConfigSnake             *GeminiThinkingConfig `json:"thinking_config,omitempty"`
+		SpeechConfigSnake               json.RawMessage       `json:"speech_config,omitempty"`
+		ImageConfigSnake                json.RawMessage       `json:"image_config,omitempty"`
 	}
 
 	if err := common.Unmarshal(data, &aux); err != nil {
@@ -408,6 +410,9 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
 	if aux.ResponseLogprobsSnake {
 		c.ResponseLogprobs = aux.ResponseLogprobsSnake
 	}
+	if aux.EnableEnhancedCivicAnswersSnake != nil {
+		c.EnableEnhancedCivicAnswers = aux.EnableEnhancedCivicAnswersSnake
+	}
 	if aux.MediaResolutionSnake != "" {
 		c.MediaResolution = aux.MediaResolutionSnake
 	}
@@ -453,12 +458,14 @@ type GeminiChatResponse struct {
 }
 
 type GeminiUsageMetadata struct {
-	PromptTokenCount        int                         `json:"promptTokenCount"`
-	CandidatesTokenCount    int                         `json:"candidatesTokenCount"`
-	TotalTokenCount         int                         `json:"totalTokenCount"`
-	ThoughtsTokenCount      int                         `json:"thoughtsTokenCount"`
-	CachedContentTokenCount int                         `json:"cachedContentTokenCount"`
-	PromptTokensDetails     []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	PromptTokenCount           int                         `json:"promptTokenCount"`
+	ToolUsePromptTokenCount    int                         `json:"toolUsePromptTokenCount"`
+	CandidatesTokenCount       int                         `json:"candidatesTokenCount"`
+	TotalTokenCount            int                         `json:"totalTokenCount"`
+	ThoughtsTokenCount         int                         `json:"thoughtsTokenCount"`
+	CachedContentTokenCount    int                         `json:"cachedContentTokenCount"`
+	PromptTokensDetails        []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+	ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
 }
 
 type GeminiPromptTokensDetails struct {

+ 41 - 22
dto/openai_request.go

@@ -54,18 +54,22 @@ type GeneralOpenAIRequest struct {
 	ParallelTooCalls    *bool             `json:"parallel_tool_calls,omitempty"`
 	Tools               []ToolCallRequest `json:"tools,omitempty"`
 	ToolChoice          any               `json:"tool_choice,omitempty"`
+	FunctionCall        json.RawMessage   `json:"function_call,omitempty"`
 	User                string            `json:"user,omitempty"`
-	LogProbs            bool              `json:"logprobs,omitempty"`
-	TopLogProbs         int               `json:"top_logprobs,omitempty"`
-	Dimensions          int               `json:"dimensions,omitempty"`
-	Modalities          json.RawMessage   `json:"modalities,omitempty"`
-	Audio               json.RawMessage   `json:"audio,omitempty"`
+	// ServiceTier specifies upstream service level and may affect billing.
+	// This field is filtered by default and can be enabled via channel setting allow_service_tier.
+	ServiceTier string          `json:"service_tier,omitempty"`
+	LogProbs    bool            `json:"logprobs,omitempty"`
+	TopLogProbs int             `json:"top_logprobs,omitempty"`
+	Dimensions  int             `json:"dimensions,omitempty"`
+	Modalities  json.RawMessage `json:"modalities,omitempty"`
+	Audio       json.RawMessage `json:"audio,omitempty"`
 	// 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
-	// 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私
+	// 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启
 	SafetyIdentifier string `json:"safety_identifier,omitempty"`
 	// Whether or not to store the output of this chat completion request for use in our model distillation or evals products.
 	// 是否存储此次请求数据供 OpenAI 用于评估和优化产品
-	// 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用
+	// 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用
 	Store json.RawMessage `json:"store,omitempty"`
 	// Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field
 	PromptCacheKey       string          `json:"prompt_cache_key,omitempty"`
@@ -261,6 +265,9 @@ type FunctionRequest struct {
 
 type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
+	// IncludeObfuscation is only for /v1/responses stream payload.
+	// This field is filtered by default and can be enabled via channel setting allow_include_obfuscation.
+	IncludeObfuscation bool `json:"include_obfuscation,omitempty"`
 }
 
 func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
@@ -799,30 +806,42 @@ type WebSearchOptions struct {
 
 // https://platform.openai.com/docs/api-reference/responses/create
 type OpenAIResponsesRequest struct {
-	Model              string          `json:"model"`
-	Input              json.RawMessage `json:"input,omitempty"`
-	Include            json.RawMessage `json:"include,omitempty"`
+	Model   string          `json:"model"`
+	Input   json.RawMessage `json:"input,omitempty"`
+	Include json.RawMessage `json:"include,omitempty"`
+	// 在后台运行推理,暂时还不支持依赖的接口
+	// Background         json.RawMessage `json:"background,omitempty"`
+	Conversation       json.RawMessage `json:"conversation,omitempty"`
+	ContextManagement  json.RawMessage `json:"context_management,omitempty"`
 	Instructions       json.RawMessage `json:"instructions,omitempty"`
 	MaxOutputTokens    uint            `json:"max_output_tokens,omitempty"`
+	TopLogProbs        *int            `json:"top_logprobs,omitempty"`
 	Metadata           json.RawMessage `json:"metadata,omitempty"`
 	ParallelToolCalls  json.RawMessage `json:"parallel_tool_calls,omitempty"`
 	PreviousResponseID string          `json:"previous_response_id,omitempty"`
 	Reasoning          *Reasoning      `json:"reasoning,omitempty"`
-	// 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
-	ServiceTier          string          `json:"service_tier,omitempty"`
+	// ServiceTier specifies upstream service level and may affect billing.
+	// This field is filtered by default and can be enabled via channel setting allow_service_tier.
+	ServiceTier string `json:"service_tier,omitempty"`
+	// Store controls whether upstream may store request/response data.
+	// This field is allowed by default and can be disabled via channel setting disable_store.
 	Store                json.RawMessage `json:"store,omitempty"`
 	PromptCacheKey       json.RawMessage `json:"prompt_cache_key,omitempty"`
 	PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
-	Stream               bool            `json:"stream,omitempty"`
-	Temperature          *float64        `json:"temperature,omitempty"`
-	Text                 json.RawMessage `json:"text,omitempty"`
-	ToolChoice           json.RawMessage `json:"tool_choice,omitempty"`
-	Tools                json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
-	TopP                 *float64        `json:"top_p,omitempty"`
-	Truncation           string          `json:"truncation,omitempty"`
-	User                 string          `json:"user,omitempty"`
-	MaxToolCalls         uint            `json:"max_tool_calls,omitempty"`
-	Prompt               json.RawMessage `json:"prompt,omitempty"`
+	// SafetyIdentifier carries client identity for policy abuse detection.
+	// This field is filtered by default and can be enabled via channel setting allow_safety_identifier.
+	SafetyIdentifier string          `json:"safety_identifier,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	StreamOptions    *StreamOptions  `json:"stream_options,omitempty"`
+	Temperature      *float64        `json:"temperature,omitempty"`
+	Text             json.RawMessage `json:"text,omitempty"`
+	ToolChoice       json.RawMessage `json:"tool_choice,omitempty"`
+	Tools            json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
+	TopP             *float64        `json:"top_p,omitempty"`
+	Truncation       string          `json:"truncation,omitempty"`
+	User             string          `json:"user,omitempty"`
+	MaxToolCalls     uint            `json:"max_tool_calls,omitempty"`
+	Prompt           json.RawMessage `json:"prompt,omitempty"`
 	// qwen
 	EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
 	// perplexity

+ 0 - 32
dto/suno.go

@@ -4,10 +4,6 @@ import (
 	"encoding/json"
 )
 
-type TaskData interface {
-	SunoDataResponse | []SunoDataResponse | string | any
-}
-
 type SunoSubmitReq struct {
 	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
 	Prompt               string  `json:"prompt,omitempty"`
@@ -20,10 +16,6 @@ type SunoSubmitReq struct {
 	MakeInstrumental     bool    `json:"make_instrumental"`
 }
 
-type FetchReq struct {
-	IDs []string `json:"ids"`
-}
-
 type SunoDataResponse struct {
 	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
 	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
@@ -66,30 +58,6 @@ type SunoLyrics struct {
 	Text   string `json:"text"`
 }
 
-const TaskSuccessCode = "success"
-
-type TaskResponse[T TaskData] struct {
-	Code    string `json:"code"`
-	Message string `json:"message"`
-	Data    T      `json:"data"`
-}
-
-func (t *TaskResponse[T]) IsSuccess() bool {
-	return t.Code == TaskSuccessCode
-}
-
-type TaskDto struct {
-	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
-	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
-	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
-	FailReason string          `json:"fail_reason"`
-	SubmitTime int64           `json:"submit_time"`
-	StartTime  int64           `json:"start_time"`
-	FinishTime int64           `json:"finish_time"`
-	Progress   string          `json:"progress"`
-	Data       json.RawMessage `json:"data"`
-}
-
 type SunoGoAPISubmitReq struct {
 	CustomMode bool `json:"custom_mode"`
 

+ 47 - 0
dto/task.go

@@ -1,5 +1,9 @@
 package dto
 
+import (
+	"encoding/json"
+)
+
 type TaskError struct {
 	Code       string `json:"code"`
 	Message    string `json:"message"`
@@ -8,3 +12,46 @@ type TaskError struct {
 	LocalError bool   `json:"-"`
 	Error      error  `json:"-"`
 }
+
+type TaskData interface {
+	SunoDataResponse | []SunoDataResponse | string | any
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+	return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	ID         int64           `json:"id"`
+	CreatedAt  int64           `json:"created_at"`
+	UpdatedAt  int64           `json:"updated_at"`
+	TaskID     string          `json:"task_id"`
+	Platform   string          `json:"platform"`
+	UserId     int             `json:"user_id"`
+	Group      string          `json:"group"`
+	ChannelId  int             `json:"channel_id"`
+	Quota      int             `json:"quota"`
+	Action     string          `json:"action"`
+	Status     string          `json:"status"`
+	FailReason string          `json:"fail_reason"`
+	ResultURL  string          `json:"result_url,omitempty"` // 任务结果 URL(视频地址等)
+	SubmitTime int64           `json:"submit_time"`
+	StartTime  int64           `json:"start_time"`
+	FinishTime int64           `json:"finish_time"`
+	Progress   string          `json:"progress"`
+	Properties any             `json:"properties"`
+	Username   string          `json:"username,omitempty"`
+	Data       json.RawMessage `json:"data"`
+}
+
+type FetchReq struct {
+	IDs []string `json:"ids"`
+}

+ 1 - 2
logger/logger.go

@@ -2,7 +2,6 @@ package logger
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"log"
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
 
 // LogJson 仅供测试使用 only for test
 func LogJson(ctx context.Context, msg string, obj any) {
-	jsonStr, err := json.Marshal(obj)
+	jsonStr, err := common.Marshal(obj)
 	if err != nil {
 		LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
 		return

+ 10 - 0
main.go

@@ -19,6 +19,7 @@ import (
 	"github.com/QuantumNous/new-api/middleware"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/oauth"
+	"github.com/QuantumNous/new-api/relay"
 	"github.com/QuantumNous/new-api/router"
 	"github.com/QuantumNous/new-api/service"
 	_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -111,6 +112,15 @@ func main() {
 	// Subscription quota reset task (daily/weekly/monthly/custom)
 	service.StartSubscriptionQuotaResetTask()
 
+	// Wire task polling adaptor factory (breaks service -> relay import cycle)
+	service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor {
+		a := relay.GetTaskAdaptor(platform)
+		if a == nil {
+			return nil
+		}
+		return a
+	}
+
 	if common.IsMasterNode && constant.UpdateTask {
 		gopool.Go(func() {
 			controller.UpdateMidjourneyTaskBulk()

+ 18 - 0
middleware/auth.go

@@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) {
 
 }
 
+// TokenOrUserAuth allows either session-based user auth or API token auth.
+// Used for endpoints that need to be accessible from both the dashboard and API clients.
+func TokenOrUserAuth() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		// Try session auth first (dashboard users)
+		session := sessions.Default(c)
+		if id := session.Get("id"); id != nil {
+			if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
+				c.Set("id", id)
+				c.Next()
+				return
+			}
+		}
+		// Fall back to token auth (API clients)
+		TokenAuth()(c)
+	}
+}
+
 // TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
 // 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
 // 即使令牌已过期、已耗尽或已禁用,也允许访问。

+ 16 - 2
middleware/logger.go

@@ -7,14 +7,28 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const RouteTagKey = "route_tag"
+
+func RouteTag(tag string) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		c.Set(RouteTagKey, tag)
+		c.Next()
+	}
+}
+
 func SetUpLogger(server *gin.Engine) {
 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 		var requestID string
 		if param.Keys != nil {
-			requestID = param.Keys[common.RequestIdKey].(string)
+			requestID, _ = param.Keys[common.RequestIdKey].(string)
+		}
+		tag, _ := param.Keys[RouteTagKey].(string)
+		if tag == "" {
+			tag = "web"
 		}
-		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
+		return fmt.Sprintf("[GIN] %s | %s | %s | %3d | %13v | %15s | %7s %s\n",
 			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
+			tag,
 			requestID,
 			param.StatusCode,
 			param.Latency,

+ 100 - 16
model/custom_oauth_provider.go

@@ -2,32 +2,65 @@ package model
 
 import (
 	"errors"
+	"fmt"
 	"strings"
 	"time"
+
+	"github.com/QuantumNous/new-api/common"
 )
 
+type accessPolicyPayload struct {
+	Logic      string                `json:"logic"`
+	Conditions []accessConditionItem `json:"conditions"`
+	Groups     []accessPolicyPayload `json:"groups"`
+}
+
+type accessConditionItem struct {
+	Field string `json:"field"`
+	Op    string `json:"op"`
+	Value any    `json:"value"`
+}
+
+var supportedAccessPolicyOps = map[string]struct{}{
+	"eq":           {},
+	"ne":           {},
+	"gt":           {},
+	"gte":          {},
+	"lt":           {},
+	"lte":          {},
+	"in":           {},
+	"not_in":       {},
+	"contains":     {},
+	"not_contains": {},
+	"exists":       {},
+	"not_exists":   {},
+}
+
 // CustomOAuthProvider stores configuration for custom OAuth providers
 type CustomOAuthProvider struct {
-	Id                    int       `json:"id" gorm:"primaryKey"`
-	Name                  string    `json:"name" gorm:"type:varchar(64);not null"`                 // Display name, e.g., "GitHub Enterprise"
-	Slug                  string    `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`     // URL identifier, e.g., "github-enterprise"
-	Enabled               bool      `json:"enabled" gorm:"default:false"`                          // Whether this provider is enabled
-	ClientId              string    `json:"client_id" gorm:"type:varchar(256)"`                    // OAuth client ID
-	ClientSecret          string    `json:"-" gorm:"type:varchar(512)"`                            // OAuth client secret (not returned to frontend)
-	AuthorizationEndpoint string    `json:"authorization_endpoint" gorm:"type:varchar(512)"`       // Authorization URL
-	TokenEndpoint         string    `json:"token_endpoint" gorm:"type:varchar(512)"`               // Token exchange URL
-	UserInfoEndpoint      string    `json:"user_info_endpoint" gorm:"type:varchar(512)"`           // User info URL
-	Scopes                string    `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
+	Id                    int    `json:"id" gorm:"primaryKey"`
+	Name                  string `json:"name" gorm:"type:varchar(64);not null"`                          // Display name, e.g., "GitHub Enterprise"
+	Slug                  string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`              // URL identifier, e.g., "github-enterprise"
+	Icon                  string `json:"icon" gorm:"type:varchar(128);default:''"`                       // Icon name from @lobehub/icons
+	Enabled               bool   `json:"enabled" gorm:"default:false"`                                   // Whether this provider is enabled
+	ClientId              string `json:"client_id" gorm:"type:varchar(256)"`                             // OAuth client ID
+	ClientSecret          string `json:"-" gorm:"type:varchar(512)"`                                     // OAuth client secret (not returned to frontend)
+	AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"`                // Authorization URL
+	TokenEndpoint         string `json:"token_endpoint" gorm:"type:varchar(512)"`                        // Token exchange URL
+	UserInfoEndpoint      string `json:"user_info_endpoint" gorm:"type:varchar(512)"`                    // User info URL
+	Scopes                string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
 
 	// Field mapping configuration (supports JSONPath via gjson)
-	UserIdField       string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`                // User ID field path, e.g., "sub", "id", "data.user.id"
-	UsernameField     string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
-	DisplayNameField  string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`          // Display name field path
-	EmailField        string `json:"email_field" gorm:"type:varchar(128);default:'email'"`                // Email field path
+	UserIdField      string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`                 // User ID field path, e.g., "sub", "id", "data.user.id"
+	UsernameField    string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
+	DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`           // Display name field path
+	EmailField       string `json:"email_field" gorm:"type:varchar(128);default:'email'"`                 // Email field path
 
 	// Advanced options
-	WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
-	AuthStyle int    `json:"auth_style" gorm:"default:0"`         // 0=auto, 1=params, 2=header (Basic Auth)
+	WellKnown           string `json:"well_known" gorm:"type:varchar(512)"`            // OIDC discovery endpoint (optional)
+	AuthStyle           int    `json:"auth_style" gorm:"default:0"`                    // 0=auto, 1=params, 2=header (Basic Auth)
+	AccessPolicy        string `json:"access_policy" gorm:"type:text"`                 // JSON policy for access control based on user info
+	AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied
 
 	CreatedAt time.Time `json:"created_at"`
 	UpdatedAt time.Time `json:"updated_at"`
@@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
 	if provider.Scopes == "" {
 		provider.Scopes = "openid profile email"
 	}
+	if strings.TrimSpace(provider.AccessPolicy) != "" {
+		var policy accessPolicyPayload
+		if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil {
+			return errors.New("access_policy must be valid JSON")
+		}
+		if err := validateAccessPolicyPayload(&policy); err != nil {
+			return fmt.Errorf("access_policy is invalid: %w", err)
+		}
+	}
+
+	return nil
+}
+
+func validateAccessPolicyPayload(policy *accessPolicyPayload) error {
+	if policy == nil {
+		return errors.New("policy is nil")
+	}
+
+	logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+	if logic == "" {
+		logic = "and"
+	}
+	if logic != "and" && logic != "or" {
+		return fmt.Errorf("unsupported logic: %s", logic)
+	}
+
+	if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
+		return errors.New("policy requires at least one condition or group")
+	}
+
+	for index, condition := range policy.Conditions {
+		field := strings.TrimSpace(condition.Field)
+		if field == "" {
+			return fmt.Errorf("condition[%d].field is required", index)
+		}
+		op := strings.ToLower(strings.TrimSpace(condition.Op))
+		if _, ok := supportedAccessPolicyOps[op]; !ok {
+			return fmt.Errorf("condition[%d].op is unsupported: %s", index, op)
+		}
+		if op == "in" || op == "not_in" {
+			if _, ok := condition.Value.([]any); !ok {
+				return fmt.Errorf("condition[%d].value must be an array for op %s", index, op)
+			}
+		}
+	}
+
+	for index := range policy.Groups {
+		if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil {
+			return fmt.Errorf("group[%d]: %w", index, err)
+		}
+	}
 
 	return nil
 }

+ 61 - 2
model/log.go

@@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
 	}
 }
 
+type RecordTaskBillingLogParams struct {
+	UserId    int
+	LogType   int
+	Content   string
+	ChannelId int
+	ModelName string
+	Quota     int
+	TokenId   int
+	Group     string
+	Other     map[string]interface{}
+}
+
+func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
+	if params.LogType == LogTypeConsume && !common.LogConsumeEnabled {
+		return
+	}
+	username, _ := GetUsernameById(params.UserId, false)
+	tokenName := ""
+	if params.TokenId > 0 {
+		if token, err := GetTokenById(params.TokenId); err == nil {
+			tokenName = token.Name
+		}
+	}
+	log := &Log{
+		UserId:    params.UserId,
+		Username:  username,
+		CreatedAt: common.GetTimestamp(),
+		Type:      params.LogType,
+		Content:   params.Content,
+		TokenName: tokenName,
+		ModelName: params.ModelName,
+		Quota:     params.Quota,
+		ChannelId: params.ChannelId,
+		TokenId:   params.TokenId,
+		Group:     params.Group,
+		Other:     common.MapToJsonStr(params.Other),
+	}
+	err := LOG_DB.Create(log).Error
+	if err != nil {
+		common.SysLog("failed to record task billing log: " + err.Error())
+	}
+}
+
 func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
 	var tx *gorm.DB
 	if logType == LogTypeUnknown {
@@ -252,8 +295,24 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 			Id   int    `gorm:"column:id"`
 			Name string `gorm:"column:name"`
 		}
-		if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
-			return logs, total, err
+		if common.MemoryCacheEnabled {
+			// Cache get channel
+			for _, channelId := range channelIds.Items() {
+				if cacheChannel, err := CacheGetChannel(channelId); err == nil {
+					channels = append(channels, struct {
+						Id   int    `gorm:"column:id"`
+						Name string `gorm:"column:name"`
+					}{
+						Id:   channelId,
+						Name: cacheChannel.Name,
+					})
+				}
+			}
+		} else {
+			// Bulk query channels from DB
+			if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
+				return logs, total, err
+			}
 		}
 		channelMap := make(map[int]string, len(channels))
 		for _, channel := range channels {

+ 13 - 0
model/midjourney.go

@@ -157,6 +157,19 @@ func (midjourney *Midjourney) Update() error {
 	return err
 }
 
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
+func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
+	result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
+	if result.Error != nil {
+		return false, result.Error
+	}
+	return result.RowsAffected > 0, nil
+}
+
 func MjBulkUpdate(mjIds []string, params map[string]any) error {
 	return DB.Model(&Midjourney{}).
 		Where("mj_id in (?)", mjIds).

+ 124 - 60
model/task.go

@@ -1,10 +1,12 @@
 package model
 
 import (
+	"bytes"
 	"database/sql/driver"
 	"encoding/json"
 	"time"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	commonRelay "github.com/QuantumNous/new-api/relay/common"
@@ -64,13 +66,12 @@ type Task struct {
 }
 
 func (t *Task) SetData(data any) {
-	b, _ := json.Marshal(data)
+	b, _ := common.Marshal(data)
 	t.Data = json.RawMessage(b)
 }
 
 func (t *Task) GetData(v any) error {
-	err := json.Unmarshal(t.Data, &v)
-	return err
+	return common.Unmarshal(t.Data, &v)
 }
 
 type Properties struct {
@@ -85,18 +86,59 @@ func (m *Properties) Scan(val interface{}) error {
 		*m = Properties{}
 		return nil
 	}
-	return json.Unmarshal(bytesValue, m)
+	return common.Unmarshal(bytesValue, m)
 }
 
 func (m Properties) Value() (driver.Value, error) {
 	if m == (Properties{}) {
 		return nil, nil
 	}
-	return json.Marshal(m)
+	return common.Marshal(m)
 }
 
 type TaskPrivateData struct {
-	Key string `json:"key,omitempty"`
+	Key            string `json:"key,omitempty"`
+	UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
+	ResultURL      string `json:"result_url,omitempty"`       // 任务成功后的结果 URL(视频地址等)
+	// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
+	BillingSource  string              `json:"billing_source,omitempty"`  // "wallet" 或 "subscription"
+	SubscriptionId int                 `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
+	TokenId        int                 `json:"token_id,omitempty"`        // 令牌 ID,用于令牌额度退款
+	BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
+}
+
+// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
+type TaskBillingContext struct {
+	ModelPrice      float64            `json:"model_price,omitempty"`       // 模型单价
+	GroupRatio      float64            `json:"group_ratio,omitempty"`       // 分组倍率
+	ModelRatio      float64            `json:"model_ratio,omitempty"`       // 模型倍率
+	OtherRatios     map[string]float64 `json:"other_ratios,omitempty"`      // 附加倍率(时长、分辨率等)
+	OriginModelName string             `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
+	PerCallBilling  bool               `json:"per_call_billing,omitempty"`  // 按次计费:跳过轮询阶段的差额结算
+}
+
+// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
+// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID
+func (t *Task) GetUpstreamTaskID() string {
+	if t.PrivateData.UpstreamTaskID != "" {
+		return t.PrivateData.UpstreamTaskID
+	}
+	return t.TaskID
+}
+
+// GetResultURL 获取任务结果 URL(视频地址等)
+// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容)
+func (t *Task) GetResultURL() string {
+	if t.PrivateData.ResultURL != "" {
+		return t.PrivateData.ResultURL
+	}
+	return t.FailReason
+}
+
+// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID
+func GenerateTaskID() string {
+	key, _ := common.GenerateRandomCharsKey(32)
+	return "task_" + key
 }
 
 func (p *TaskPrivateData) Scan(val interface{}) error {
@@ -104,14 +146,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error {
 	if len(bytesValue) == 0 {
 		return nil
 	}
-	return json.Unmarshal(bytesValue, p)
+	return common.Unmarshal(bytesValue, p)
 }
 
 func (p TaskPrivateData) Value() (driver.Value, error) {
 	if (p == TaskPrivateData{}) {
 		return nil, nil
 	}
-	return json.Marshal(p)
+	return common.Marshal(p)
 }
 
 // SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -142,7 +184,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
 		}
 	}
 
+	// 使用预生成的公开 ID(如果有),否则新生成
+	taskID := ""
+	if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" {
+		taskID = relayInfo.TaskRelayInfo.PublicTaskID
+	} else {
+		taskID = GenerateTaskID()
+	}
+
 	t := &Task{
+		TaskID:      taskID,
 		UserId:      relayInfo.UserId,
 		Group:       relayInfo.UsingGroup,
 		SubmitTime:  time.Now().Unix(),
@@ -237,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
 	return tasks
 }
 
+func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
+	var tasks []*Task
+	err := DB.Where("progress != ?", "100%").
+		Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
+		Where("submit_time < ?", cutoffUnix).
+		Order("submit_time").
+		Limit(limit).
+		Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+	return tasks
+}
+
 func GetAllUnFinishSyncTasks(limit int) []*Task {
 	var tasks []*Task
 	var err error
@@ -291,40 +356,70 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
 	return task, nil
 }
 
-func TaskUpdateProgress(id int64, progress string) error {
-	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
-}
-
 func (Task *Task) Insert() error {
 	var err error
 	err = DB.Create(Task).Error
 	return err
 }
 
+type taskSnapshot struct {
+	Status     TaskStatus
+	Progress   string
+	StartTime  int64
+	FinishTime int64
+	FailReason string
+	ResultURL  string
+	Data       json.RawMessage
+}
+
+func (s taskSnapshot) Equal(other taskSnapshot) bool {
+	return s.Status == other.Status &&
+		s.Progress == other.Progress &&
+		s.StartTime == other.StartTime &&
+		s.FinishTime == other.FinishTime &&
+		s.FailReason == other.FailReason &&
+		s.ResultURL == other.ResultURL &&
+		bytes.Equal(s.Data, other.Data)
+}
+
+func (t *Task) Snapshot() taskSnapshot {
+	return taskSnapshot{
+		Status:     t.Status,
+		Progress:   t.Progress,
+		StartTime:  t.StartTime,
+		FinishTime: t.FinishTime,
+		FailReason: t.FailReason,
+		ResultURL:  t.PrivateData.ResultURL,
+		Data:       t.Data,
+	}
+}
+
 func (Task *Task) Update() error {
 	var err error
 	err = DB.Save(Task).Error
 	return err
 }
 
-func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
-	if len(TaskIds) == 0 {
-		return nil
-	}
-	return DB.Model(&Task{}).
-		Where("task_id in (?)", TaskIds).
-		Updates(params).Error
-}
-
-func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
-	if len(taskIDs) == 0 {
-		return nil
-	}
-	return DB.Model(&Task{}).
-		Where("id in (?)", taskIDs).
-		Updates(params).Error
+// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
+// Returns (true, nil) if this caller won the update, (false, nil) if
+// another process already moved the task out of fromStatus.
+//
+// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
+// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
+// zero rows, which silently bypasses the CAS guard.
+func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
+	result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
+	if result.Error != nil {
+		return false, result.Error
+	}
+	return result.RowsAffected > 0, nil
 }
 
+// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
+// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
+// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
+// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
+// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
 func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
 	if len(ids) == 0 {
 		return nil
@@ -339,37 +434,6 @@ type TaskQuotaUsage struct {
 	Count float64 `json:"count"`
 }
 
-func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
-	query := DB.Model(Task{})
-	// 添加过滤条件
-	if queryParams.ChannelID != "" {
-		query = query.Where("channel_id = ?", queryParams.ChannelID)
-	}
-	if queryParams.UserID != "" {
-		query = query.Where("user_id = ?", queryParams.UserID)
-	}
-	if len(queryParams.UserIDs) != 0 {
-		query = query.Where("user_id in (?)", queryParams.UserIDs)
-	}
-	if queryParams.TaskID != "" {
-		query = query.Where("task_id = ?", queryParams.TaskID)
-	}
-	if queryParams.Action != "" {
-		query = query.Where("action = ?", queryParams.Action)
-	}
-	if queryParams.Status != "" {
-		query = query.Where("status = ?", queryParams.Status)
-	}
-	if queryParams.StartTimestamp != 0 {
-		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
-	}
-	if queryParams.EndTimestamp != 0 {
-		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
-	}
-	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
-	return stat, err
-}
-
 // TaskCountAllTasks returns total tasks that match the given query params (admin usage)
 func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
 	var total int64
@@ -438,6 +502,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
 	openAIVideo.SetProgressStr(t.Progress)
 	openAIVideo.CreatedAt = t.CreatedAt
 	openAIVideo.CompletedAt = t.UpdatedAt
-	openAIVideo.SetMetadata("url", t.FailReason)
+	openAIVideo.SetMetadata("url", t.GetResultURL())
 	return openAIVideo
 }

+ 217 - 0
model/task_cas_test.go

@@ -0,0 +1,217 @@
+package model
+
+import (
+	"encoding/json"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	DB = db
+	LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+func truncateTables(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		DB.Exec("DELETE FROM tasks")
+		DB.Exec("DELETE FROM users")
+		DB.Exec("DELETE FROM tokens")
+		DB.Exec("DELETE FROM logs")
+		DB.Exec("DELETE FROM channels")
+	})
+}
+
+func insertTask(t *testing.T, task *Task) {
+	t.Helper()
+	task.CreatedAt = time.Now().Unix()
+	task.UpdatedAt = time.Now().Unix()
+	require.NoError(t, DB.Create(task).Error)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Equal — pure logic tests (no DB)
+// ---------------------------------------------------------------------------
+
+func TestSnapshotEqual_Same(t *testing.T) {
+	s := taskSnapshot{
+		Status:     TaskStatusInProgress,
+		Progress:   "50%",
+		StartTime:  1000,
+		FinishTime: 0,
+		FailReason: "",
+		ResultURL:  "",
+		Data:       json.RawMessage(`{"key":"value"}`),
+	}
+	assert.True(t, s.Equal(s))
+}
+
+func TestSnapshotEqual_DifferentStatus(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentProgress(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_DifferentData(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
+	assert.False(t, a.Equal(b))
+}
+
+func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
+	a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
+	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
+	// bytes.Equal(nil, []byte{}) == true
+	assert.True(t, a.Equal(b))
+}
+
+func TestSnapshot_Roundtrip(t *testing.T) {
+	task := &Task{
+		Status:     TaskStatusInProgress,
+		Progress:   "42%",
+		StartTime:  1234,
+		FinishTime: 5678,
+		FailReason: "timeout",
+		PrivateData: TaskPrivateData{
+			ResultURL: "https://example.com/result.mp4",
+		},
+		Data: json.RawMessage(`{"model":"test-model"}`),
+	}
+	snap := task.Snapshot()
+	assert.Equal(t, task.Status, snap.Status)
+	assert.Equal(t, task.Progress, snap.Progress)
+	assert.Equal(t, task.StartTime, snap.StartTime)
+	assert.Equal(t, task.FinishTime, snap.FinishTime)
+	assert.Equal(t, task.FailReason, snap.FailReason)
+	assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
+	assert.JSONEq(t, string(task.Data), string(snap.Data))
+}
+
+// ---------------------------------------------------------------------------
+// UpdateWithStatus CAS — DB integration tests
+// ---------------------------------------------------------------------------
+
+func TestUpdateWithStatus_Win(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID:   "task_cas_win",
+		Status:   TaskStatusInProgress,
+		Progress: "50%",
+		Data:     json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	task.Progress = "100%"
+	won, err := task.UpdateWithStatus(TaskStatusInProgress)
+	require.NoError(t, err)
+	assert.True(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
+	assert.Equal(t, "100%", reloaded.Progress)
+}
+
+func TestUpdateWithStatus_Lose(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_lose",
+		Status: TaskStatusFailure,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	task.Status = TaskStatusSuccess
+	won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
+	require.NoError(t, err)
+	assert.False(t, won)
+
+	var reloaded Task
+	require.NoError(t, DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
+}
+
+func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
+	truncateTables(t)
+
+	task := &Task{
+		TaskID: "task_cas_race",
+		Status: TaskStatusInProgress,
+		Quota:  1000,
+		Data:   json.RawMessage(`{}`),
+	}
+	insertTask(t, task)
+
+	const goroutines = 5
+	wins := make([]bool, goroutines)
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+
+	for i := 0; i < goroutines; i++ {
+		go func(idx int) {
+			defer wg.Done()
+			t := &Task{}
+			*t = Task{
+				ID:       task.ID,
+				TaskID:   task.TaskID,
+				Status:   TaskStatusSuccess,
+				Progress: "100%",
+				Quota:    task.Quota,
+				Data:     json.RawMessage(`{}`),
+			}
+			t.CreatedAt = task.CreatedAt
+			t.UpdatedAt = time.Now().Unix()
+			won, err := t.UpdateWithStatus(TaskStatusInProgress)
+			if err == nil {
+				wins[idx] = won
+			}
+		}(i)
+	}
+	wg.Wait()
+
+	winCount := 0
+	for _, w := range wins {
+		if w {
+			winCount++
+		}
+	}
+	assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
+}

+ 3 - 3
model/token.go

@@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) {
 	return token.Delete()
 }
 
-func IncreaseTokenQuota(id int, key string, quota int) (err error) {
+func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
@@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
 		})
 	}
 	if common.BatchUpdateEnabled {
-		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+		addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
 		return nil
 	}
-	return increaseTokenQuota(id, quota)
+	return increaseTokenQuota(tokenId, quota)
 }
 
 func increaseTokenQuota(id int, quota int) (err error) {

+ 42 - 1
model/user.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"database/sql"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -15,6 +16,8 @@ import (
 	"gorm.io/gorm"
 )
 
+const UserNameMaxLength = 20
+
 // User if you add sensitive fields, don't forget to clean them in setupLogin function.
 // Otherwise, the sensitive information will be saved on local storage in plain text!
 type User struct {
@@ -536,6 +539,37 @@ func (user *User) Edit(updatePassword bool) error {
 	return updateUserCache(*user)
 }
 
+func (user *User) ClearBinding(bindingType string) error {
+	if user.Id == 0 {
+		return errors.New("user id is empty")
+	}
+
+	bindingColumnMap := map[string]string{
+		"email":    "email",
+		"github":   "github_id",
+		"discord":  "discord_id",
+		"oidc":     "oidc_id",
+		"wechat":   "wechat_id",
+		"telegram": "telegram_id",
+		"linuxdo":  "linux_do_id",
+	}
+
+	column, ok := bindingColumnMap[bindingType]
+	if !ok {
+		return errors.New("invalid binding type")
+	}
+
+	if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil {
+		return err
+	}
+
+	if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil {
+		return err
+	}
+
+	return updateUserCache(*user)
+}
+
 func (user *User) Delete() error {
 	if user.Id == 0 {
 		return errors.New("id 为空!")
@@ -820,10 +854,17 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
 		// Don't return error - fall through to DB
 	}
 	fromDB = true
-	err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+	// can be nil setting
+	var safeSetting sql.NullString
+	err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error
 	if err != nil {
 		return settingMap, err
 	}
+	if safeSetting.Valid {
+		setting = safeSetting.String
+	} else {
+		setting = ""
+	}
 	userBase := &UserBase{
 		Setting: setting,
 	}

+ 402 - 2
oauth/generic.go

@@ -3,19 +3,24 @@ package oauth
 import (
 	"context"
 	"encoding/base64"
-	"encoding/json"
+	stdjson "encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
+	"regexp"
+	"strconv"
 	"strings"
 	"time"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/i18n"
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/setting/system_setting"
 	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
 	"github.com/tidwall/gjson"
 )
 
@@ -31,6 +36,40 @@ type GenericOAuthProvider struct {
 	config *model.CustomOAuthProvider
 }
 
+type accessPolicy struct {
+	Logic      string            `json:"logic"`
+	Conditions []accessCondition `json:"conditions"`
+	Groups     []accessPolicy    `json:"groups"`
+}
+
+type accessCondition struct {
+	Field string `json:"field"`
+	Op    string `json:"op"`
+	Value any    `json:"value"`
+}
+
+type accessPolicyFailure struct {
+	Field    string
+	Op       string
+	Expected any
+	Current  any
+}
+
+var supportedAccessPolicyOps = []string{
+	"eq",
+	"ne",
+	"gt",
+	"gte",
+	"lt",
+	"lte",
+	"in",
+	"not_in",
+	"contains",
+	"not_contains",
+	"exists",
+	"not_exists",
+}
+
 // NewGenericOAuthProvider creates a new generic OAuth provider from config
 func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
 	return &GenericOAuthProvider{config: config}
@@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c
 		ErrorDesc    string `json:"error_description"`
 	}
 
-	if err := json.Unmarshal(body, &tokenResponse); err != nil {
+	if err := common.Unmarshal(body, &tokenResponse); err != nil {
 		// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
 		parsedValues, parseErr := url.ParseQuery(bodyStr)
 		if parseErr != nil {
@@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke
 	logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
 		p.config.Slug, userId, username, displayName, email)
 
+	policyRaw := strings.TrimSpace(p.config.AccessPolicy)
+	if policyRaw != "" {
+		policy, err := parseAccessPolicy(policyRaw)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
+			return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
+		}
+		allowed, failure := evaluateAccessPolicy(bodyStr, policy)
+		if !allowed {
+			message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
+			logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
+				p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
+			return nil, &AccessDeniedError{Message: message}
+		}
+	}
+
 	return &OAuthUser{
 		ProviderUserID: userId,
 		Username:       username,
 		DisplayName:    displayName,
 		Email:          email,
+		Extra: map[string]any{
+			"provider": p.config.Slug,
+		},
 	}, nil
 }
 
@@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int {
 func (p *GenericOAuthProvider) IsGenericProvider() bool {
 	return true
 }
+
+func parseAccessPolicy(raw string) (*accessPolicy, error) {
+	var policy accessPolicy
+	if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
+		return nil, err
+	}
+	if err := validateAccessPolicy(&policy); err != nil {
+		return nil, err
+	}
+	return &policy, nil
+}
+
+func validateAccessPolicy(policy *accessPolicy) error {
+	if policy == nil {
+		return errors.New("policy is nil")
+	}
+
+	logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+	if logic == "" {
+		logic = "and"
+	}
+	if !lo.Contains([]string{"and", "or"}, logic) {
+		return fmt.Errorf("unsupported policy logic: %s", logic)
+	}
+	policy.Logic = logic
+
+	if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
+		return errors.New("policy requires at least one condition or group")
+	}
+
+	for index := range policy.Conditions {
+		if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
+			return err
+		}
+	}
+
+	for index := range policy.Groups {
+		if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
+			return fmt.Errorf("invalid policy group[%d]: %w", index, err)
+		}
+	}
+
+	return nil
+}
+
+func validateAccessCondition(condition *accessCondition, index int) error {
+	if condition == nil {
+		return fmt.Errorf("condition[%d] is nil", index)
+	}
+
+	condition.Field = strings.TrimSpace(condition.Field)
+	if condition.Field == "" {
+		return fmt.Errorf("condition[%d].field is required", index)
+	}
+
+	condition.Op = normalizePolicyOp(condition.Op)
+	if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
+		return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
+	}
+
+	if lo.Contains([]string{"in", "not_in"}, condition.Op) {
+		if _, ok := condition.Value.([]any); !ok {
+			return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
+		}
+	}
+
+	return nil
+}
+
+func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
+	if policy == nil {
+		return true, nil
+	}
+
+	logic := strings.ToLower(strings.TrimSpace(policy.Logic))
+	if logic == "" {
+		logic = "and"
+	}
+
+	hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
+	if !hasAny {
+		return true, nil
+	}
+
+	if logic == "or" {
+		var firstFailure *accessPolicyFailure
+		for _, cond := range policy.Conditions {
+			ok, failure := evaluateAccessCondition(body, cond)
+			if ok {
+				return true, nil
+			}
+			if firstFailure == nil {
+				firstFailure = failure
+			}
+		}
+		for _, group := range policy.Groups {
+			ok, failure := evaluateAccessPolicy(body, &group)
+			if ok {
+				return true, nil
+			}
+			if firstFailure == nil {
+				firstFailure = failure
+			}
+		}
+		return false, firstFailure
+	}
+
+	for _, cond := range policy.Conditions {
+		ok, failure := evaluateAccessCondition(body, cond)
+		if !ok {
+			return false, failure
+		}
+	}
+	for _, group := range policy.Groups {
+		ok, failure := evaluateAccessPolicy(body, &group)
+		if !ok {
+			return false, failure
+		}
+	}
+	return true, nil
+}
+
+func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
+	path := cond.Field
+	op := cond.Op
+	result := gjson.Get(body, path)
+	current := gjsonResultToValue(result)
+	failure := &accessPolicyFailure{
+		Field:    path,
+		Op:       op,
+		Expected: cond.Value,
+		Current:  current,
+	}
+
+	switch op {
+	case "exists":
+		return result.Exists(), failure
+	case "not_exists":
+		return !result.Exists(), failure
+	case "eq":
+		return compareAny(current, cond.Value) == 0, failure
+	case "ne":
+		return compareAny(current, cond.Value) != 0, failure
+	case "gt":
+		return compareAny(current, cond.Value) > 0, failure
+	case "gte":
+		return compareAny(current, cond.Value) >= 0, failure
+	case "lt":
+		return compareAny(current, cond.Value) < 0, failure
+	case "lte":
+		return compareAny(current, cond.Value) <= 0, failure
+	case "in":
+		return valueInSlice(current, cond.Value), failure
+	case "not_in":
+		return !valueInSlice(current, cond.Value), failure
+	case "contains":
+		return containsValue(current, cond.Value), failure
+	case "not_contains":
+		return !containsValue(current, cond.Value), failure
+	default:
+		return false, failure
+	}
+}
+
+func normalizePolicyOp(op string) string {
+	return strings.ToLower(strings.TrimSpace(op))
+}
+
+func gjsonResultToValue(result gjson.Result) any {
+	if !result.Exists() {
+		return nil
+	}
+	if result.IsArray() {
+		arr := result.Array()
+		values := make([]any, 0, len(arr))
+		for _, item := range arr {
+			values = append(values, gjsonResultToValue(item))
+		}
+		return values
+	}
+	switch result.Type {
+	case gjson.Null:
+		return nil
+	case gjson.True:
+		return true
+	case gjson.False:
+		return false
+	case gjson.Number:
+		return result.Num
+	case gjson.String:
+		return result.String()
+	case gjson.JSON:
+		var data any
+		if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
+			return data
+		}
+		return result.Raw
+	default:
+		return result.Value()
+	}
+}
+
+func compareAny(left any, right any) int {
+	if lf, ok := toFloat(left); ok {
+		if rf, ok2 := toFloat(right); ok2 {
+			switch {
+			case lf < rf:
+				return -1
+			case lf > rf:
+				return 1
+			default:
+				return 0
+			}
+		}
+	}
+
+	ls := strings.TrimSpace(fmt.Sprint(left))
+	rs := strings.TrimSpace(fmt.Sprint(right))
+	switch {
+	case ls < rs:
+		return -1
+	case ls > rs:
+		return 1
+	default:
+		return 0
+	}
+}
+
+func toFloat(v any) (float64, bool) {
+	switch value := v.(type) {
+	case float64:
+		return value, true
+	case float32:
+		return float64(value), true
+	case int:
+		return float64(value), true
+	case int8:
+		return float64(value), true
+	case int16:
+		return float64(value), true
+	case int32:
+		return float64(value), true
+	case int64:
+		return float64(value), true
+	case uint:
+		return float64(value), true
+	case uint8:
+		return float64(value), true
+	case uint16:
+		return float64(value), true
+	case uint32:
+		return float64(value), true
+	case uint64:
+		return float64(value), true
+	case stdjson.Number:
+		n, err := value.Float64()
+		if err == nil {
+			return n, true
+		}
+	case string:
+		n, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
+		if err == nil {
+			return n, true
+		}
+	}
+	return 0, false
+}
+
+func valueInSlice(current any, expected any) bool {
+	list, ok := expected.([]any)
+	if !ok {
+		return false
+	}
+	return lo.ContainsBy(list, func(item any) bool {
+		return compareAny(current, item) == 0
+	})
+}
+
+func containsValue(current any, expected any) bool {
+	switch value := current.(type) {
+	case string:
+		target := strings.TrimSpace(fmt.Sprint(expected))
+		return strings.Contains(value, target)
+	case []any:
+		return lo.ContainsBy(value, func(item any) bool {
+			return compareAny(item, expected) == 0
+		})
+	}
+	return false
+}
+
+func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
+	defaultMessage := "Access denied: your account does not meet this provider's access requirements."
+	message := strings.TrimSpace(template)
+	if message == "" {
+		return defaultMessage
+	}
+
+	if failure == nil {
+		failure = &accessPolicyFailure{}
+	}
+
+	replacements := map[string]string{
+		"{{provider}}": providerName,
+		"{{field}}":    failure.Field,
+		"{{op}}":       failure.Op,
+		"{{required}}": fmt.Sprint(failure.Expected),
+		"{{current}}":  fmt.Sprint(failure.Current),
+	}
+
+	for key, value := range replacements {
+		message = strings.ReplaceAll(message, key, value)
+	}
+
+	currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
+	message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
+		match := currentPattern.FindStringSubmatch(token)
+		if len(match) != 2 {
+			return ""
+		}
+		path := strings.TrimSpace(match[1])
+		if path == "" {
+			return ""
+		}
+		return strings.TrimSpace(gjson.Get(body, path).String())
+	})
+
+	requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
+	message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
+		match := requiredPattern.FindStringSubmatch(token)
+		if len(match) != 2 {
+			return ""
+		}
+		path := strings.TrimSpace(match[1])
+		if failure.Field == path {
+			return fmt.Sprint(failure.Expected)
+		}
+		return ""
+	})
+
+	return strings.TrimSpace(message)
+}

+ 9 - 0
oauth/types.go

@@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string)
 		RawError: rawError,
 	}
 }
+
+// AccessDeniedError is a direct user-facing access denial message.
+type AccessDeniedError struct {
+	Message string
+}
+
+func (e *AccessDeniedError) Error() string {
+	return e.Message
+}

+ 28 - 2
relay/channel/adapter.go

@@ -36,6 +36,32 @@ type TaskAdaptor interface {
 
 	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
 
+	// ── Billing ──────────────────────────────────────────────────────
+
+	// EstimateBilling returns OtherRatios for pre-charge based on user request.
+	// Called after ValidateRequestAndSetAction, before price calculation.
+	// Adaptors should extract duration, resolution, etc. from the parsed request
+	// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
+	// Return nil to use the base model price without extra ratios.
+	EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
+
+	// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
+	// submit response. Called after a successful DoResponse.
+	// If the upstream returned actual parameters that differ from the estimate
+	// (e.g. actual seconds), return updated ratios so the caller can recalculate
+	// the quota and settle the delta with the pre-charge.
+	// Return nil if no adjustment is needed.
+	AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
+
+	// AdjustBillingOnComplete returns the actual quota when a task reaches a
+	// terminal state (success/failure) during polling.
+	// Called by the polling loop after ParseTaskResult.
+	// Return a positive value to trigger delta settlement (supplement / refund).
+	// Return 0 to keep the pre-charged amount unchanged.
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+
+	// ── Request / Response ───────────────────────────────────────────
+
 	BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
 	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 	BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 
-	// FetchTask
-	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
+	// ── Polling ──────────────────────────────────────────────────────
 
+	FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
 	ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
 }
 

+ 3 - 2
relay/channel/api_request.go

@@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
 	"cookie": {},
 
 	// Additional headers that should not be forwarded by name-matching passthrough rules.
-	"host":           {},
-	"content-length": {},
+	"host":            {},
+	"content-length":  {},
+	"accept-encoding": {},
 
 	// Do not passthrough credentials by wildcard/regex.
 	"authorization":  {},

+ 27 - 0
relay/channel/api_request_test.go

@@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
 	_, ok := headers["X-Legacy"]
 	require.False(t, ok)
 }
+
+func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	recorder := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(recorder)
+	ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+	ctx.Request.Header.Set("X-Trace-Id", "trace-123")
+	ctx.Request.Header.Set("Accept-Encoding", "gzip")
+
+	info := &relaycommon.RelayInfo{
+		IsChannelTest: false,
+		ChannelMeta: &relaycommon.ChannelMeta{
+			HeadersOverride: map[string]any{
+				"*": "",
+			},
+		},
+	}
+
+	headers, err := processHeaderOverride(info, ctx)
+	require.NoError(t, err)
+	require.Equal(t, "trace-123", headers["X-Trace-Id"])
+
+	_, hasAcceptEncoding := headers["Accept-Encoding"]
+	require.False(t, hasAcceptEncoding)
+}

+ 1 - 16
relay/channel/gemini/relay-gemini-native.go

@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	}
 
 	// 计算使用量(基于 UsageMetadata)
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	service.IOCopyBytesGracefully(c, resp, responseBody)
 

+ 44 - 49
relay/channel/gemini/relay-gemini.go

@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
 	}
 }
 
+func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
+	promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
+	if promptTokens <= 0 && fallbackPromptTokens > 0 {
+		promptTokens = fallbackPromptTokens
+	}
+
+	usage := dto.Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
+		TotalTokens:      metadata.TotalTokenCount,
+	}
+	usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
+	usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
+
+	for _, detail := range metadata.PromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+	for _, detail := range metadata.ToolUsePromptTokensDetails {
+		if detail.Modality == "AUDIO" {
+			usage.PromptTokensDetails.AudioTokens += detail.TokenCount
+		} else if detail.Modality == "TEXT" {
+			usage.PromptTokensDetails.TextTokens += detail.TokenCount
+		}
+	}
+
+	if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
+		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+	}
+
+	if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
+		usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	}
+
+	return usage
+}
+
 func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      helper.GetResponseID(c),
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 
 		// 更新使用量统计
 		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
-			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
-			usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-			for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-				if detail.Modality == "AUDIO" {
-					usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-				} else if detail.Modality == "TEXT" {
-					usage.PromptTokensDetails.TextTokens = detail.TokenCount
-				}
-			}
+			mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
+			*usage = mappedUsage
 		}
 
 		return callback(data, &geminiResponse)
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 		}
 	}
 
-	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
-	if usage.TotalTokens > 0 {
-		usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-	}
-
 	if usage.CompletionTokens <= 0 {
 		if info.ReceivedResponseCount > 0 {
 			usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	if len(geminiResponse.Candidates) == 0 {
-		usage := dto.Usage{
-			PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
-		}
-		usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-		usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-		for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-			if detail.Modality == "AUDIO" {
-				usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-			} else if detail.Modality == "TEXT" {
-				usage.PromptTokensDetails.TextTokens = detail.TokenCount
-			}
-		}
-		if usage.PromptTokens <= 0 {
-			usage.PromptTokens = info.GetEstimatePromptTokens()
-		}
+		usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 		var newAPIError *types.NewAPIError
 		if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
 	fullTextResponse.Model = info.UpstreamModelName
-	usage := dto.Usage{
-		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
-		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
-		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
-	}
-
-	usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
-	usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
-	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
-	for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
-		if detail.Modality == "AUDIO" {
-			usage.PromptTokensDetails.AudioTokens = detail.TokenCount
-		} else if detail.Modality == "TEXT" {
-			usage.PromptTokensDetails.TextTokens = detail.TokenCount
-		}
-	}
+	usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
 
 	fullTextResponse.Usage = usage
 

+ 333 - 0
relay/channel/gemini/relay_gemini_usage_test.go

@@ -0,0 +1,333 @@
+package gemini
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        151,
+			ToolUsePromptTokenCount: 18329,
+			CandidatesTokenCount:    1089,
+			ThoughtsTokenCount:      1120,
+			TotalTokenCount:         20689,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 18480, usage.PromptTokens)
+	require.Equal(t, 2209, usage.CompletionTokens)
+	require.Equal(t, 20689, usage.TotalTokens)
+	require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
+}
+
+func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	info := &relaycommon.RelayInfo{
+		RelayFormat:     types.RelayFormatGemini,
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiChatHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldStreamingTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 300
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldStreamingTimeout
+	})
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	chunk := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "partial"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	chunkData, err := common.Marshal(chunk)
+	require.NoError(t, err)
+
+	streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(streamBody)),
+	}
+
+	usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
+		return true
+	})
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}
+
+func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
+	t.Parallel()
+
+	gin.SetMode(gin.TestMode)
+	c, _ := gin.CreateTestContext(httptest.NewRecorder())
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
+
+	info := &relaycommon.RelayInfo{
+		OriginModelName: "gemini-3-flash-preview",
+		ChannelMeta: &relaycommon.ChannelMeta{
+			UpstreamModelName: "gemini-3-flash-preview",
+		},
+	}
+	info.SetEstimatePromptTokens(20)
+
+	payload := dto.GeminiChatResponse{
+		Candidates: []dto.GeminiChatCandidate{
+			{
+				Content: dto.GeminiChatContent{
+					Role: "model",
+					Parts: []dto.GeminiPart{
+						{Text: "ok"},
+					},
+				},
+			},
+		},
+		UsageMetadata: dto.GeminiUsageMetadata{
+			PromptTokenCount:        0,
+			ToolUsePromptTokenCount: 0,
+			CandidatesTokenCount:    90,
+			ThoughtsTokenCount:      10,
+			TotalTokenCount:         110,
+		},
+	}
+
+	body, err := common.Marshal(payload)
+	require.NoError(t, err)
+
+	resp := &http.Response{
+		Body: io.NopCloser(bytes.NewReader(body)),
+	}
+
+	usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
+	require.Nil(t, newAPIError)
+	require.NotNil(t, usage)
+	require.Equal(t, 20, usage.PromptTokens)
+	require.Equal(t, 100, usage.CompletionTokens)
+	require.Equal(t, 110, usage.TotalTokens)
+}

+ 11 - 3
relay/channel/minimax/adaptor.go

@@ -10,6 +10,7 @@ import (
 
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/claude"
 	"github.com/QuantumNous/new-api/relay/channel/openai"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/relay/constant"
@@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	adaptor := claude.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return handleTTSResponse(c, resp, info)
 	}
 
-	adaptor := openai.Adaptor{}
-	return adaptor.DoResponse(c, resp, info)
+	switch info.RelayFormat {
+	case types.RelayFormatClaude:
+		adaptor := claude.Adaptor{}
+		return adaptor.DoResponse(c, resp, info)
+	default:
+		adaptor := openai.Adaptor{}
+		return adaptor.DoResponse(c, resp, info)
+	}
 }
 
 func (a *Adaptor) GetModelList() []string {

+ 12 - 7
relay/channel/minimax/relay-minimax.go

@@ -6,6 +6,7 @@ import (
 	channelconstant "github.com/QuantumNous/new-api/constant"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/relay/constant"
+	"github.com/QuantumNous/new-api/types"
 )
 
 func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if baseUrl == "" {
 		baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
 	}
-
-	switch info.RelayMode {
-	case constant.RelayModeChatCompletions:
-		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
-	case constant.RelayModeAudioSpeech:
-		return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
+	switch info.RelayFormat {
+	case types.RelayFormatClaude:
+		return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
 	default:
-		return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
+		switch info.RelayMode {
+		case constant.RelayModeChatCompletions:
+			return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
+		case constant.RelayModeAudioSpeech:
+			return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
+		default:
+			return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
+		}
 	}
 }

+ 44 - 24
relay/channel/task/ali/adaptor.go

@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/logger"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/samber/lo"
@@ -108,10 +109,10 @@ type AliMetadata struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
-	aliReq      *AliVideoRequest
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	// 阿里通义万相支持 JSON 格式,不使用 multipart
-	var taskReq relaycommon.TaskSubmitReq
-	if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
-		return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
-	}
-	aliReq, err := a.convertToAliRequest(info, taskReq)
-	if err != nil {
-		return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
-	}
-	a.aliReq = aliReq
-	logger.LogJson(c, "ali video request body", aliReq)
+	// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 }
 
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
-	bodyBytes, err := common.Marshal(a.aliReq)
+	taskReq, err := relaycommon.GetTaskRequest(c)
 	if err != nil {
-		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+		return nil, errors.Wrap(err, "get_task_request_failed")
 	}
 
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert_to_ali_request_failed")
+	}
+	logger.LogJson(c, "ali video request body", aliReq)
+
+	bodyBytes, err := common.Marshal(aliReq)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal_ali_request_failed")
+	}
 	return bytes.NewReader(bodyBytes), nil
 }
 
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
 }
 
 func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
+	upstreamModel := req.Model
+	if info.IsModelMapped {
+		upstreamModel = info.UpstreamModelName
+	}
 	aliReq := &AliVideoRequest{
-		Model: req.Model,
+		Model: upstreamModel,
 		Input: AliVideoInput{
 			Prompt: req.Prompt,
 			ImgURL: req.InputReference,
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 		}
 	}
 
-	if aliReq.Model != req.Model {
+	if aliReq.Model != upstreamModel {
 		return nil, errors.New("can't change model with metadata")
 	}
 
-	info.PriceData.OtherRatios = map[string]float64{
-		"seconds": float64(aliReq.Parameters.Duration),
+	return aliReq, nil
+}
+
+// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
+// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	taskReq, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	aliReq, err := a.convertToAliRequest(info, taskReq)
+	if err != nil {
+		return nil
 	}
 
+	otherRatios := map[string]float64{
+		"seconds": float64(aliReq.Parameters.Duration),
+	}
 	ratios, err := ProcessAliOtherRatios(aliReq)
 	if err != nil {
-		return nil, err
+		return otherRatios
 	}
-	for s, f := range ratios {
-		info.PriceData.OtherRatios[s] = f
+	for k, v := range ratios {
+		otherRatios[k] = v
 	}
-
-	return aliReq, nil
+	return otherRatios
 }
 
 // DoRequest delegates to common helper
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// 转换为 OpenAI 格式响应
 	openAIResp := dto.NewOpenAIVideo()
-	openAIResp.ID = aliResp.Output.TaskID
+	openAIResp.ID = info.PublicTaskID
+	openAIResp.TaskID = info.PublicTaskID
 	openAIResp.Model = c.GetString("model")
 	if openAIResp.Model == "" && info != nil {
 		openAIResp.Model = info.OriginModelName

+ 15 - 16
relay/channel/task/doubao/adaptor.go

@@ -2,7 +2,6 @@ package doubao
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -89,6 +89,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	info.UpstreamModelName = body.Model
-	data, err := json.Marshal(body)
+	if info.IsModelMapped {
+		body.Model = info.UpstreamModelName
+	} else {
+		info.UpstreamModelName = body.Model
+	}
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Doubao response
 	var dResp responsePayload
-	if err := json.Unmarshal(responseBody, &dResp); err != nil {
+	if err := common.Unmarshal(responseBody, &dResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = dResp.ID
-	ov.TaskID = dResp.ID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var dResp responseTask
-	if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal doubao task data failed")
 	}
 
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 17 - 33
relay/channel/task/gemini/adaptor.go

@@ -2,8 +2,6 @@ package gemini
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,10 +14,10 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/model_setting"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 )
@@ -87,6 +85,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	version := model_setting.GetGeminiVersionSetting(modelName)
 
 	return fmt.Sprintf(
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 
 	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &body.Parameters)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	taskID = encodeLocalTaskID(s.Name)
+	taskID = taskcommon.EncodeLocalTaskID(s.Name)
 	ov := dto.NewOpenAIVideo()
-	ov.ID = taskID
-	ov.TaskID = taskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		return nil, fmt.Errorf("invalid task_id")
 	}
 
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 	ti.Status = model.TaskStatusSuccess
 	ti.Progress = "100%"
 
-	taskID := encodeLocalTaskID(op.Name)
-	ti.TaskID = taskID
-	ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+	ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
+	// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 
 	// Extract URL from generateVideoResponse if available
 	if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
 
 func extractModelFromOperationName(name string) string {

+ 13 - 12
relay/channel/task/hailuo/adaptor.go

@@ -2,7 +2,6 @@ package hailuo
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -18,12 +17,14 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
 
 // https://platform.minimaxi.com/docs/api-reference/video-generation-intro
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("invalid request type in context")
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var hResp VideoResponse
-	if err := json.Unmarshal(responseBody, &hResp); err != nil {
+	if err := common.Unmarshal(responseBody, &hResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = hResp.TaskID
-	ov.TaskID = hResp.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
 	return ChannelName
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
-	modelConfig := GetModelConfig(req.Model)
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
+	modelConfig := GetModelConfig(info.UpstreamModelName)
 	duration := DefaultDuration
 	if req.Duration > 0 {
 		duration = req.Duration
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	videoRequest := &VideoRequest{
-		Model:      req.Model,
+		Model:      info.UpstreamModelName,
 		Prompt:     req.Prompt,
 		Duration:   &duration,
 		Resolution: resolution,
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := QueryTaskResponse{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var hailuoResp QueryTaskResponse
-	if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
 	}
 
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
 	}
 
 	var retrieveResp RetrieveFileResponse
-	if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
+	if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
 		return ""
 	}
 

+ 14 - 20
relay/channel/task/jimeng/adaptor.go

@@ -6,7 +6,6 @@ import (
 	"crypto/sha256"
 	"encoding/base64"
 	"encoding/hex"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -25,6 +24,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -77,6 +77,7 @@ const (
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	accessKey   string
 	secretKey   string
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 
 	// Parse Jimeng response
 	var jResp responsePayload
-	if err := json.Unmarshal(responseBody, &jResp); err != nil {
+	if err := common.Unmarshal(responseBody, &jResp); err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
 	}
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = jResp.Data.TaskID
-	ov.TaskID = jResp.Data.TaskID
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
 		"task_id": taskID,
 	}
-	payloadBytes, err := json.Marshal(payload)
+	payloadBytes, err := common.Marshal(payload)
 	if err != nil {
 		return nil, errors.Wrap(err, "marshal fetch task payload failed")
 	}
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
 	return h.Sum(nil)
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		ReqKey: req.Model,
+		ReqKey: info.UpstreamModelName,
 		Prompt: req.Prompt,
 	}
 
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 			r.BinaryDataBase64 = req.Images
 		}
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	resTask := responseTask{}
-	if err := json.Unmarshal(respBody, &resTask); err != nil {
+	if err := common.Unmarshal(respBody, &resTask); err != nil {
 		return nil, errors.Wrap(err, "unmarshal task result failed")
 	}
 	taskResult := relaycommon.TaskInfo{}
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var jimengResp responseTask
-	if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
 	}
 
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }
 
 func isNewAPIRelay(apiKey string) bool {

+ 17 - 36
relay/channel/task/kling/adaptor.go

@@ -2,7 +2,6 @@ package kling
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -21,6 +20,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 )
@@ -97,6 +97,7 @@ type responsePayload struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
 	if body.Image == "" && body.ImageTail == "" {
 		c.Set("action", constant.TaskActionTextGenerate)
 	}
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var kResp responsePayload
-	err = json.Unmarshal(responseBody, &kResp)
+	err = common.Unmarshal(responseBody, &kResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	ov := dto.NewOpenAIVideo()
-	ov.ID = kResp.Data.TaskId
-	ov.TaskID = kResp.Data.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
 		Prompt:         req.Prompt,
 		Image:          req.Image,
-		Mode:           defaultString(req.Mode, "std"),
-		Duration:       fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+		Mode:           taskcommon.DefaultString(req.Mode, "std"),
+		Duration:       fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
 		AspectRatio:    a.getAspectRatio(req.Size),
-		ModelName:      req.Model,
-		Model:          req.Model, // Keep consistent with model_name, double writing improves compatibility
+		ModelName:      info.UpstreamModelName,
+		Model:          info.UpstreamModelName,
 		CfgScale:       0.5,
 		StaticMask:     "",
 		DynamicMasks:   []DynamicMask{},
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 	if r.ModelName == "" {
 		r.ModelName = "kling-v1"
+		r.Model = "kling-v1"
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
 	}
 }
 
-func defaultString(s, def string) string {
-	if strings.TrimSpace(s) == "" {
-		return def
-	}
-	return s
-}
-
-func defaultInt(v int, def int) int {
-	if v == 0 {
-		return def
-	}
-	return v
-}
-
 // ============================
 // JWT helpers
 // ============================
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 	resPayload := responsePayload{}
-	err := json.Unmarshal(respBody, &resPayload)
+	err := common.Unmarshal(respBody, &resPayload)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var klingResp responsePayload
-	if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal kling task data failed")
 	}
 
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 			Code:    fmt.Sprintf("%d", klingResp.Code),
 		}
 	}
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 132 - 15
relay/channel/task/sora/adaptor.go

@@ -1,9 +1,13 @@
 package sora
 
 import (
+	"bytes"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
+	"net/textproto"
+	"strconv"
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
@@ -11,12 +15,13 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/system_setting"
 
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
+	"github.com/tidwall/sjson"
 )
 
 // ============================
@@ -57,6 +62,7 @@ type responseTask struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func validateRemixRequest(c *gin.Context) *dto.TaskError {
-	var req struct {
-		Prompt string `json:"prompt"`
-	}
+	var req relaycommon.TaskSubmitReq
 	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
 		return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
 	}
 	if strings.TrimSpace(req.Prompt) == "" {
 		return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
 	}
+	// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
+	c.Set("task_request", req)
 	return nil
 }
 
@@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 	return relaycommon.ValidateMultipartDirect(c, info)
 }
 
+// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
+	// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
+	if info.Action == constant.TaskActionRemix {
+		return nil
+	}
+
+	req, err := relaycommon.GetTaskRequest(c)
+	if err != nil {
+		return nil
+	}
+
+	seconds, _ := strconv.Atoi(req.Seconds)
+	if seconds == 0 {
+		seconds = req.Duration
+	}
+	if seconds <= 0 {
+		seconds = 4
+	}
+
+	size := req.Size
+	if size == "" {
+		size = "720x1280"
+	}
+
+	ratios := map[string]float64{
+		"seconds": float64(seconds),
+		"size":    1,
+	}
+	if size == "1792x1024" || size == "1024x1792" {
+		ratios["size"] = 1.666667
+	}
+	return ratios
+}
+
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if info.Action == constant.TaskActionRemix {
 		return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
@@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "get_request_body_failed")
 	}
+	cachedBody, err := storage.Bytes()
+	if err != nil {
+		return nil, errors.Wrap(err, "read_body_bytes_failed")
+	}
+	contentType := c.GetHeader("Content-Type")
+
+	if strings.HasPrefix(contentType, "application/json") {
+		var bodyMap map[string]interface{}
+		if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
+			bodyMap["model"] = info.UpstreamModelName
+			if newBody, err := common.Marshal(bodyMap); err == nil {
+				return bytes.NewReader(newBody), nil
+			}
+		}
+		return bytes.NewReader(cachedBody), nil
+	}
+
+	if strings.Contains(contentType, "multipart/form-data") {
+		formData, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return bytes.NewReader(cachedBody), nil
+		}
+		var buf bytes.Buffer
+		writer := multipart.NewWriter(&buf)
+		writer.WriteField("model", info.UpstreamModelName)
+		for key, values := range formData.Value {
+			if key == "model" {
+				continue
+			}
+			for _, v := range values {
+				writer.WriteField(key, v)
+			}
+		}
+		for fieldName, fileHeaders := range formData.File {
+			for _, fh := range fileHeaders {
+				f, err := fh.Open()
+				if err != nil {
+					continue
+				}
+				ct := fh.Header.Get("Content-Type")
+				if ct == "" || ct == "application/octet-stream" {
+					buf512 := make([]byte, 512)
+					n, _ := io.ReadFull(f, buf512)
+					ct = http.DetectContentType(buf512[:n])
+					// Re-open after sniffing so the full content is copied below
+					f.Close()
+					f, err = fh.Open()
+					if err != nil {
+						continue
+					}
+				}
+				h := make(textproto.MIMEHeader)
+				h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename))
+				h.Set("Content-Type", ct)
+				part, err := writer.CreatePart(h)
+				if err != nil {
+					f.Close()
+					continue
+				}
+				io.Copy(part, f)
+				f.Close()
+			}
+		}
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return &buf, nil
+	}
+
 	return common.ReaderOnly(storage), nil
 }
 
@@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
 }
 
 // DoResponse handles upstream response, returns taskID etc.
-func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
 		return
 	}
 
-	if dResp.ID == "" {
-		if dResp.TaskID == "" {
-			taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
-			return
-		}
-		dResp.ID = dResp.TaskID
-		dResp.TaskID = ""
+	upstreamID := dResp.ID
+	if upstreamID == "" {
+		upstreamID = dResp.TaskID
+	}
+	if upstreamID == "" {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
+		return
 	}
 
+	// 使用公开 task_xxxx ID 返回给客户端
+	dResp.ID = info.PublicTaskID
+	dResp.TaskID = info.PublicTaskID
 	c.JSON(http.StatusOK, dResp)
-	return dResp.ID, responseBody, nil
+	return upstreamID, responseBody, nil
 }
 
 // FetchTask fetch task status
@@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 		taskResult.Status = model.TaskStatusInProgress
 	case "completed":
 		taskResult.Status = model.TaskStatusSuccess
-		taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
+		// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
 	case "failed", "cancelled":
 		taskResult.Status = model.TaskStatusFailure
 		if resTask.Error != nil {
@@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	return task.Data, nil
+	data := task.Data
+	var err error
+	if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
+		return nil, errors.Wrap(err, "set id failed")
+	}
+	return data, nil
 }

+ 24 - 35
relay/channel/task/suno/adaptor.go

@@ -2,18 +2,16 @@ package suno
 
 import (
 	"bytes"
-	"context"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
-	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -21,11 +19,16 @@ import (
 )
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 }
 
+// ParseTaskResult is not used for Suno tasks.
+// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
+// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
+// This differs from the per-task polling used by video adaptors.
 func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
-	return nil, fmt.Errorf("not implement") // todo implement this method if needed
+	return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
@@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 		return
 	}
 
-	if sunoRequest.ContinueClipId != "" {
-		if sunoRequest.TaskID == "" {
-			taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
-			return
-		}
-		info.OriginTaskID = sunoRequest.TaskID
-	}
+	//if sunoRequest.ContinueClipId != "" {
+	//	if sunoRequest.TaskID == "" {
+	//		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
+	//		return
+	//	}
+	//	info.OriginTaskID = sunoRequest.TaskID
+	//}
 
 	info.Action = action
 	c.Set("task_request", sunoRequest)
@@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	sunoRequest, ok := c.Get("task_request")
 	if !ok {
-		err := common.UnmarshalBodyReusable(c, &sunoRequest)
-		if err != nil {
-			return nil, err
-		}
+		return nil, fmt.Errorf("task_request not found in context")
 	}
-	data, err := json.Marshal(sunoRequest)
+	data, err := common.Marshal(sunoRequest)
 	if err != nil {
 		return nil, err
 	}
@@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 	var sunoResponse dto.TaskResponse[string]
-	err = json.Unmarshal(responseBody, &sunoResponse)
+	err = common.Unmarshal(responseBody, &sunoResponse)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return
@@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 		return
 	}
 
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-
-	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-		return
+	// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
+	publicResponse := dto.TaskResponse[string]{
+		Code:    sunoResponse.Code,
+		Message: sunoResponse.Message,
+		Data:    info.PublicTaskID,
 	}
+	c.JSON(http.StatusOK, publicResponse)
 
 	return sunoResponse.Data, nil, nil
 }
@@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string {
 
 func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
-	byteBody, err := json.Marshal(body)
+	byteBody, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		common.SysLog(fmt.Sprintf("Get Task error: %v", err))
 		return nil, err
 	}
-	defer req.Body.Close()
-	// 设置超时时间
-	timeout := time.Second * 15
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	defer cancel()
-	// 使用带有超时的 context 创建新的请求
-	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Authorization", "Bearer "+key)
 	client, err := service.GetHttpClientWithProxy(proxy)

+ 95 - 0
relay/channel/task/taskcommon/helpers.go

@@ -0,0 +1,95 @@
+package taskcommon
+
+import (
+	"encoding/base64"
+	"fmt"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/system_setting"
+	"github.com/gin-gonic/gin"
+)
+
+// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
+// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
+func UnmarshalMetadata(metadata map[string]any, target any) error {
+	if metadata == nil {
+		return nil
+	}
+	metaBytes, err := common.Marshal(metadata)
+	if err != nil {
+		return fmt.Errorf("marshal metadata failed: %w", err)
+	}
+	if err := common.Unmarshal(metaBytes, target); err != nil {
+		return fmt.Errorf("unmarshal metadata failed: %w", err)
+	}
+	return nil
+}
+
+// DefaultString returns val if non-empty, otherwise fallback.
+func DefaultString(val, fallback string) string {
+	if val == "" {
+		return fallback
+	}
+	return val
+}
+
+// DefaultInt returns val if non-zero, otherwise fallback.
+func DefaultInt(val, fallback int) int {
+	if val == 0 {
+		return fallback
+	}
+	return val
+}
+
+// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
+// Used by Gemini/Vertex to store upstream names as task IDs.
+func EncodeLocalTaskID(name string) string {
+	return base64.RawURLEncoding.EncodeToString([]byte(name))
+}
+
+// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
+func DecodeLocalTaskID(id string) (string, error) {
+	b, err := base64.RawURLEncoding.DecodeString(id)
+	if err != nil {
+		return "", err
+	}
+	return string(b), nil
+}
+
+// BuildProxyURL constructs the video proxy URL using the public task ID.
+// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
+func BuildProxyURL(taskID string) string {
+	return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
+}
+
+// Status-to-progress mapping constants for polling updates.
+const (
+	ProgressSubmitted  = "10%"
+	ProgressQueued     = "20%"
+	ProgressInProgress = "30%"
+	ProgressComplete   = "100%"
+)
+
+// ---------------------------------------------------------------------------
+// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
+// Adaptors that do not need custom billing can embed this struct directly.
+// ---------------------------------------------------------------------------
+
+type BaseBilling struct{}
+
+// EstimateBilling returns nil (no extra ratios; use base model price).
+func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
+func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
+	return nil
+}
+
+// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
+func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return 0
+}

+ 47 - 46
relay/channel/task/vertex/adaptor.go

@@ -2,13 +2,12 @@ package vertex
 
 import (
 	"bytes"
-	"encoding/base64"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
 	"regexp"
 	"strings"
+	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
@@ -17,6 +16,7 @@ import (
 	"github.com/QuantumNous/new-api/constant"
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
@@ -62,6 +62,7 @@ type operationResponse struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	apiKey      string
 	baseURL     string
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return "", fmt.Errorf("failed to decode credentials: %w", err)
 	}
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	if modelName == "" {
 		modelName = "veo-3.0-generate-001"
 	}
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	req.Header.Set("Accept", "application/json")
 
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
+	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return fmt.Errorf("failed to decode credentials: %w", err)
 	}
 
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
 	return nil
 }
 
+// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
+func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
+	sampleCount := 1
+	v, ok := c.Get("task_request")
+	if ok {
+		req := v.(relaycommon.TaskSubmitReq)
+		if req.Metadata != nil {
+			if sc, exists := req.Metadata["sampleCount"]; exists {
+				if i, ok := sc.(int); ok && i > 0 {
+					sampleCount = i
+				}
+				if f, ok := sc.(float64); ok && int(f) > 0 {
+					sampleCount = int(f)
+				}
+			}
+		}
+	}
+	return map[string]float64{
+		"sampleCount": float64(sampleCount),
+	}
+}
+
 // BuildRequestBody converts request into Vertex specific format.
 func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
 	v, ok := c.Get("task_request")
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("sampleCount must be greater than 0")
 	}
 
-	// if req.Duration > 0 {
-	// 	body.Parameters["durationSeconds"] = req.Duration
-	// } else if req.Seconds != "" {
-	// 	seconds, err := strconv.Atoi(req.Seconds)
-	// 	if err != nil {
-	// 		return nil, errors.Wrap(err, "convert seconds to int failed")
-	// 	}
-	// 	body.Parameters["durationSeconds"] = seconds
-	// }
-
-	info.PriceData.OtherRatios = map[string]float64{
-		"sampleCount": float64(body.Parameters["sampleCount"].(int)),
-	}
-
-	// if v, ok := body.Parameters["durationSeconds"]; ok {
-	// 	info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
-	// }
-
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	_ = resp.Body.Close()
 
 	var s submitResponse
-	if err := json.Unmarshal(responseBody, &s); err != nil {
+	if err := common.Unmarshal(responseBody, &s); err != nil {
 		return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
 	}
 	if strings.TrimSpace(s.Name) == "" {
 		return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
 	}
-	localID := encodeLocalTaskID(s.Name)
-	c.JSON(http.StatusOK, gin.H{"task_id": localID})
+	localID := taskcommon.EncodeLocalTaskID(s.Name)
+	ov := dto.NewOpenAIVideo()
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
+	ov.CreatedAt = time.Now().Unix()
+	ov.Model = info.OriginModelName
+	c.JSON(http.StatusOK, ov)
 	return localID, responseBody, nil
 }
 
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 	if !ok {
 		return nil, fmt.Errorf("invalid task_id")
 	}
-	upstreamName, err := decodeLocalTaskID(taskID)
+	upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 		url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
 	}
 	payload := map[string]string{"operationName": upstreamName}
-	data, err := json.Marshal(payload)
+	data, err := common.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}
 	adc := &vertexcore.Credentials{}
-	if err := json.Unmarshal([]byte(key), adc); err != nil {
+	if err := common.Unmarshal([]byte(key), adc); err != nil {
 		return nil, fmt.Errorf("failed to decode credentials: %w", err)
 	}
 	token, err := vertexcore.AcquireAccessToken(*adc, proxy)
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	var op operationResponse
-	if err := json.Unmarshal(respBody, &op); err != nil {
+	if err := common.Unmarshal(respBody, &op); err != nil {
 		return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
 	}
 	ti := &relaycommon.TaskInfo{}
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 }
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
-	upstreamName, err := decodeLocalTaskID(task.TaskID)
+	// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
+	// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
+	upstreamTaskID := task.GetUpstreamTaskID()
+	upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
 	if err != nil {
 		upstreamName = ""
 	}
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 	v.SetProgressStr(task.Progress)
 	v.CreatedAt = task.CreatedAt
 	v.CompletedAt = task.UpdatedAt
-	if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
-		v.SetMetadata("url", task.FailReason)
+	if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
+		v.SetMetadata("url", resultURL)
 	}
 
 	return common.Marshal(v)
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
 // helpers
 // ============================
 
-func encodeLocalTaskID(name string) string {
-	return base64.RawURLEncoding.EncodeToString([]byte(name))
-}
-
-func decodeLocalTaskID(local string) (string, error) {
-	b, err := base64.RawURLEncoding.DecodeString(local)
-	if err != nil {
-		return "", err
-	}
-	return string(b), nil
-}
-
 var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
 
 func extractRegionFromOperationName(name string) string {

+ 15 - 35
relay/channel/task/vidu/adaptor.go

@@ -2,7 +2,6 @@ package vidu
 
 import (
 	"bytes"
-	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,6 +15,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/QuantumNous/new-api/service"
 
@@ -73,6 +73,7 @@ type creation struct {
 // ============================
 
 type TaskAdaptor struct {
+	taskcommon.BaseBilling
 	ChannelType int
 	baseURL     string
 }
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	data, err := json.Marshal(body)
+	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err
 	}
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	var vResp responsePayload
-	err = json.Unmarshal(responseBody, &vResp)
+	err = common.Unmarshal(responseBody, &vResp)
 	if err != nil {
 		taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
 		return
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	ov := dto.NewOpenAIVideo()
-	ov.ID = vResp.TaskId
-	ov.TaskID = vResp.TaskId
+	ov.ID = info.PublicTaskID
+	ov.TaskID = info.PublicTaskID
 	ov.CreatedAt = time.Now().Unix()
 	ov.Model = info.OriginModelName
 	c.JSON(http.StatusOK, ov)
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		Model:             defaultString(req.Model, "viduq1"),
+		Model:             taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
 		Images:            req.Images,
 		Prompt:            req.Prompt,
-		Duration:          defaultInt(req.Duration, 5),
-		Resolution:        defaultString(req.Size, "1080p"),
+		Duration:          taskcommon.DefaultInt(req.Duration, 5),
+		Resolution:        taskcommon.DefaultString(req.Size, "1080p"),
 		MovementAmplitude: "auto",
 		Bgm:               false,
 	}
-	metadata := req.Metadata
-	medaBytes, err := json.Marshal(metadata)
-	if err != nil {
-		return nil, errors.Wrap(err, "metadata marshal metadata failed")
-	}
-	err = json.Unmarshal(medaBytes, &r)
-	if err != nil {
+	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")
 	}
 	return &r, nil
 }
 
-func defaultString(value, defaultValue string) string {
-	if value == "" {
-		return defaultValue
-	}
-	return value
-}
-
-func defaultInt(value, defaultValue int) int {
-	if value == 0 {
-		return defaultValue
-	}
-	return value
-}
-
 func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
 	taskInfo := &relaycommon.TaskInfo{}
 
 	var taskResp taskResultResponse
-	err := json.Unmarshal(respBody, &taskResp)
+	err := common.Unmarshal(respBody, &taskResp)
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to unmarshal response body")
 	}
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
 
 func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
 	var viduResp taskResultResponse
-	if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
+	if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
 		return nil, errors.Wrap(err, "unmarshal vidu task data failed")
 	}
 
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
 		}
 	}
 
-	jsonData, _ := common.Marshal(openAIVideo)
-	return jsonData, nil
+	return common.Marshal(openAIVideo)
 }

+ 2 - 2
relay/chat_completions_via_responses.go

@@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
 		return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 
-	chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
+	chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
@@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
 		return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
 
-	jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+	jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
 	if err != nil {
 		return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}

+ 1 - 1
relay/claude_handler.go

@@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 		}
 
 		// remove disabled fields for Claude API
-		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}

+ 73 - 0
relay/common/override_test.go

@@ -6,6 +6,9 @@ import (
 	"testing"
 
 	"github.com/QuantumNous/new-api/types"
+
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/setting/model_setting"
 )
 
 func TestApplyParamOverrideTrimPrefix(t *testing.T) {
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
 	}
 }
 
+func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
+	input := `{
+		"service_tier":"flex",
+		"safety_identifier":"user-123",
+		"store":true,
+		"stream_options":{"include_obfuscation":false}
+	}`
+	settings := dto.ChannelOtherSettings{}
+
+	out, err := RemoveDisabledFields([]byte(input), settings, true)
+	if err != nil {
+		t.Fatalf("RemoveDisabledFields returned error: %v", err)
+	}
+	assertJSONEqual(t, input, string(out))
+}
+
+func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
+	original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
+	model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
+	t.Cleanup(func() {
+		model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
+	})
+
+	input := `{
+		"service_tier":"flex",
+		"safety_identifier":"user-123",
+		"stream_options":{"include_obfuscation":false}
+	}`
+	settings := dto.ChannelOtherSettings{}
+
+	out, err := RemoveDisabledFields([]byte(input), settings, false)
+	if err != nil {
+		t.Fatalf("RemoveDisabledFields returned error: %v", err)
+	}
+	assertJSONEqual(t, input, string(out))
+}
+
+func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
+	input := `{
+		"service_tier":"flex",
+		"inference_geo":"eu",
+		"safety_identifier":"user-123",
+		"store":true,
+		"stream_options":{"include_obfuscation":false}
+	}`
+	settings := dto.ChannelOtherSettings{}
+
+	out, err := RemoveDisabledFields([]byte(input), settings, false)
+	if err != nil {
+		t.Fatalf("RemoveDisabledFields returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"store":true}`, string(out))
+}
+
+func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
+	input := `{
+		"inference_geo":"eu",
+		"store":true
+	}`
+	settings := dto.ChannelOtherSettings{
+		AllowInferenceGeo: true,
+	}
+
+	out, err := RemoveDisabledFields([]byte(input), settings, false)
+	if err != nil {
+		t.Fatalf("RemoveDisabledFields returned error: %v", err)
+	}
+	assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
+}
+
 func assertJSONEqual(t *testing.T, want, got string) {
 	t.Helper()
 

+ 62 - 5
relay/common/relay_info.go

@@ -119,8 +119,12 @@ type RelayInfo struct {
 	SendResponseCount      int
 	ReceivedResponseCount  int
 	FinalPreConsumedQuota  int // 最终预消耗的配额
+	// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
+	// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
+	// 必须在提交前锁定全额。
+	ForcePreConsume bool
 	// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
-	// 免费模型和按次计费(MJ/Task)时为 nil。
+	// 免费模型时为 nil。
 	Billing BillingSettler
 	// BillingSource indicates whether this request is billed from wallet quota or subscription.
 	// "" or "wallet" => wallet; "subscription" => subscription
@@ -153,7 +157,8 @@ type RelayInfo struct {
 	// RequestConversionChain records request format conversions in order, e.g.
 	// ["openai", "openai_responses"] or ["openai", "claude"].
 	RequestConversionChain []types.RelayFormat
-	// 最终请求到上游的格式 TODO: 当前仅设置了Claude
+	// 最终请求到上游的格式。可由 adaptor 显式设置;
+	// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
 	FinalRequestRelayFormat types.RelayFormat
 
 	ThinkingContentInfo
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
 		return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
 	case types.RelayFormatTask:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	case types.RelayFormatMjProxy:
 		info = genBaseRelayInfo(c, nil)
+		info.TaskRelayInfo = &TaskRelayInfo{}
 	default:
 		err = errors.New("invalid relay format")
 	}
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
 	info.RequestConversionChain = append(info.RequestConversionChain, format)
 }
 
+func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
+	if info == nil {
+		return ""
+	}
+	if info.FinalRequestRelayFormat != "" {
+		return info.FinalRequestRelayFormat
+	}
+	if n := len(info.RequestConversionChain); n > 0 {
+		return info.RequestConversionChain[n-1]
+	}
+	return info.RelayFormat
+}
+
 func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
 	info := genBaseRelayInfo(c, request)
 	if info.RelayMode == relayconstant.RelayModeUnknown {
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
 type TaskRelayInfo struct {
 	Action       string
 	OriginTaskID string
+	// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
+	// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
+	PublicTaskID string
 
 	ConsumeQuota bool
+
+	// LockedChannel holds the full channel object when the request is bound to
+	// a specific channel (e.g., remix on origin task's channel). Stored as any
+	// to avoid an import cycle with model; callers type-assert to *model.Channel.
+	LockedChannel any
 }
 
 type TaskSubmitReq struct {
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
 func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
 	metadata := t.Metadata
 	if metadata != nil {
-		metadataBytes, err := json.Marshal(metadata)
+		metadataBytes, err := common.Marshal(metadata)
 		if err != nil {
 			return fmt.Errorf("marshal metadata failed: %w", err)
 		}
-		err = json.Unmarshal(metadataBytes, v)
+		err = common.Unmarshal(metadataBytes, v)
 		if err != nil {
 			return fmt.Errorf("unmarshal metadata to target failed: %w", err)
 		}
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
 
 // RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
 // service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
+// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
 // store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
 // safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
-func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
+// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
+func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
+		return jsonData, nil
+	}
+
 	var data map[string]interface{}
 	if err := common.Unmarshal(jsonData, &data); err != nil {
 		common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
 		}
 	}
 
+	// 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
+	if !channelOtherSettings.AllowInferenceGeo {
+		if _, exists := data["inference_geo"]; exists {
+			delete(data, "inference_geo")
+		}
+	}
+
 	// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
 	if channelOtherSettings.DisableStore {
 		if _, exists := data["store"]; exists {
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
 		}
 	}
 
+	// 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
+	if !channelOtherSettings.AllowIncludeObfuscation {
+		if streamOptionsAny, exists := data["stream_options"]; exists {
+			if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
+				if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
+					delete(streamOptions, "include_obfuscation")
+				}
+				if len(streamOptions) == 0 {
+					delete(data, "stream_options")
+				} else {
+					data["stream_options"] = streamOptions
+				}
+			}
+		}
+	}
+
 	jsonDataAfter, err := common.Marshal(data)
 	if err != nil {
 		common.SysError("RemoveDisabledFields Marshal error :" + err.Error())

+ 40 - 0
relay/common/relay_info_test.go

@@ -0,0 +1,40 @@
+package common
+
+import (
+	"testing"
+
+	"github.com/QuantumNous/new-api/types"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
+	info := &RelayInfo{
+		RelayFormat:             types.RelayFormatOpenAI,
+		RequestConversionChain:  []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
+		FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
+	}
+
+	require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
+	info := &RelayInfo{
+		RelayFormat:            types.RelayFormatOpenAI,
+		RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
+	}
+
+	require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
+	info := &RelayInfo{
+		RelayFormat: types.RelayFormatGemini,
+	}
+
+	require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
+}
+
+func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
+	var info *RelayInfo
+	require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
+}

+ 2 - 8
relay/common/relay_utils.go

@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
 		if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
 			return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
 		}
-		info.PriceData.OtherRatios = map[string]float64{
-			"seconds": float64(seconds),
-			"size":    1,
-		}
-		if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
-			info.PriceData.OtherRatios["size"] = 1.666667
-		}
+		// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
 	}
 
-	info.Action = action
+	storeTaskRequest(c, info, action, req)
 
 	return nil
 }

+ 3 - 3
relay/compatible_handler.go

@@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 		}
 
 		// remove disabled fields for OpenAI API
-		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}
@@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 	}
 
 	if originUsage != nil {
-		service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage)
+		service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
 	}
 
 	adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
@@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
 
 	var audioInputQuota decimal.Decimal
 	var audioInputPrice float64
-	isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
+	isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
 	if !relayInfo.PriceData.UsePrice {
 		baseTokens := dPromptTokens
 		// 减去 cached tokens

+ 13 - 2
relay/helper/price.go

@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
 }
 
 // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
-func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
 	groupRatioInfo := HandleGroupRatio(c, info)
 
 	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
 		}
 	}
 	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
-	priceData := types.PerCallPriceData{
+
+	// 免费模型检测(与 ModelPriceHelper 对齐)
+	freeModel := false
+	if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
+		if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
+			quota = 0
+			freeModel = true
+		}
+	}
+
+	priceData := types.PriceData{
+		FreeModel:      freeModel,
 		ModelPrice:     modelPrice,
 		Quota:          quota,
 		GroupRatioInfo: groupRatioInfo,

+ 27 - 16
relay/helper/stream_scanner.go

@@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		})
 	}
 
+	dataChan := make(chan string, 10)
+
+	wg.Add(1)
+	gopool.Go(func() {
+		defer func() {
+			wg.Done()
+			if r := recover(); r != nil {
+				logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
+			}
+			common.SafeSendBool(stopChan, true)
+		}()
+		for data := range dataChan {
+			writeMutex.Lock()
+			success := dataHandler(data)
+			writeMutex.Unlock()
+			if !success {
+				return
+			}
+		}
+	})
+
 	// Scanner goroutine with improved error handling
 	wg.Add(1)
 	common.RelayCtxGo(ctx, func() {
 		defer func() {
+			close(dataChan)
 			wg.Done()
 			if r := recover(); r != nil {
 				logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
@@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 				continue
 			}
 			data = data[5:]
-			data = strings.TrimLeft(data, " ")
-			data = strings.TrimSuffix(data, "\r")
+			data = strings.TrimSpace(data)
+			if data == "" {
+				continue
+			}
 			if !strings.HasPrefix(data, "[DONE]") {
 				info.SetFirstResponseTime()
 				info.ReceivedResponseCount++
-				// 使用超时机制防止写操作阻塞
-				done := make(chan bool, 1)
-				gopool.Go(func() {
-					writeMutex.Lock()
-					defer writeMutex.Unlock()
-					done <- dataHandler(data)
-				})
 
 				select {
-				case success := <-done:
-					if !success {
-						return
-					}
-				case <-time.After(10 * time.Second):
-					logger.LogError(c, "data handler timeout")
-					return
+				case dataChan <- data:
 				case <-ctx.Done():
 					return
 				case <-stopChan:

+ 521 - 0
relay/helper/stream_scanner_test.go

@@ -0,0 +1,521 @@
+package helper
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/constant"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/operation_setting"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func init() {
+	gin.SetMode(gin.TestMode)
+}
+
+func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
+	t.Helper()
+
+	oldTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 30
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldTimeout
+	})
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	resp := &http.Response{
+		Body: io.NopCloser(body),
+	}
+
+	info := &relaycommon.RelayInfo{
+		ChannelMeta: &relaycommon.ChannelMeta{},
+	}
+
+	return c, resp, info
+}
+
+func buildSSEBody(n int) string {
+	var b strings.Builder
+	for i := 0; i < n; i++ {
+		fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
+	}
+	b.WriteString("data: [DONE]\n")
+	return b.String()
+}
+
+// slowReader wraps a reader and injects a delay before each Read call,
+// simulating a slow upstream that trickles data.
+type slowReader struct {
+	r     io.Reader
+	delay time.Duration
+}
+
+func (s *slowReader) Read(p []byte) (int, error) {
+	time.Sleep(s.delay)
+	return s.r.Read(p)
+}
+
+// ---------- Basic correctness ----------
+
+func TestStreamScannerHandler_NilInputs(t *testing.T) {
+	t.Parallel()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+	StreamScannerHandler(c, nil, info, func(data string) bool { return true })
+	StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
+}
+
+func TestStreamScannerHandler_EmptyBody(t *testing.T) {
+	t.Parallel()
+
+	c, resp, info := setupStreamTest(t, strings.NewReader(""))
+
+	var called atomic.Bool
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		called.Store(true)
+		return true
+	})
+
+	assert.False(t, called.Load(), "handler should not be called for empty body")
+}
+
+func TestStreamScannerHandler_1000Chunks(t *testing.T) {
+	t.Parallel()
+
+	const numChunks = 1000
+	body := buildSSEBody(numChunks)
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	var count atomic.Int64
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		count.Add(1)
+		return true
+	})
+
+	assert.Equal(t, int64(numChunks), count.Load())
+	assert.Equal(t, numChunks, info.ReceivedResponseCount)
+}
+
+func TestStreamScannerHandler_10000Chunks(t *testing.T) {
+	t.Parallel()
+
+	const numChunks = 10000
+	body := buildSSEBody(numChunks)
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	var count atomic.Int64
+	start := time.Now()
+
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		count.Add(1)
+		return true
+	})
+
+	elapsed := time.Since(start)
+	assert.Equal(t, int64(numChunks), count.Load())
+	assert.Equal(t, numChunks, info.ReceivedResponseCount)
+	t.Logf("10000 chunks processed in %v", elapsed)
+}
+
+func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
+	t.Parallel()
+
+	const numChunks = 500
+	body := buildSSEBody(numChunks)
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	var mu sync.Mutex
+	received := make([]string, 0, numChunks)
+
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		mu.Lock()
+		received = append(received, data)
+		mu.Unlock()
+		return true
+	})
+
+	require.Equal(t, numChunks, len(received))
+	for i := 0; i < numChunks; i++ {
+		expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
+		assert.Equal(t, expected, received[i], "chunk %d out of order", i)
+	}
+}
+
+func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
+	t.Parallel()
+
+	body := buildSSEBody(50) + "data: should_not_appear\n"
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	var count atomic.Int64
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		count.Add(1)
+		return true
+	})
+
+	assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
+}
+
+func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
+	t.Parallel()
+
+	const numChunks = 200
+	body := buildSSEBody(numChunks)
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	const failAt = 50
+	var count atomic.Int64
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		n := count.Add(1)
+		return n < failAt
+	})
+
+	// The worker stops at failAt; the scanner may have read ahead,
+	// but the handler should not be called beyond failAt.
+	assert.Equal(t, int64(failAt), count.Load())
+}
+
+func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
+	t.Parallel()
+
+	var b strings.Builder
+	b.WriteString(": comment line\n")
+	b.WriteString("event: message\n")
+	b.WriteString("id: 12345\n")
+	b.WriteString("retry: 5000\n")
+	for i := 0; i < 100; i++ {
+		fmt.Fprintf(&b, "data: payload_%d\n", i)
+		b.WriteString(": interleaved comment\n")
+	}
+	b.WriteString("data: [DONE]\n")
+
+	c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
+
+	var count atomic.Int64
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		count.Add(1)
+		return true
+	})
+
+	assert.Equal(t, int64(100), count.Load())
+}
+
+func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
+	t.Parallel()
+
+	body := "data:   {\"trimmed\":true}  \ndata: [DONE]\n"
+	c, resp, info := setupStreamTest(t, strings.NewReader(body))
+
+	var got string
+	StreamScannerHandler(c, resp, info, func(data string) bool {
+		got = data
+		return true
+	})
+
+	assert.Equal(t, "{\"trimmed\":true}", got)
+}
+
+// ---------- Decoupling: scanner not blocked by slow handler ----------
+
+func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
+	t.Parallel()
+
+	// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
+	// If the scanner were synchronously coupled to the handler, total time would be
+	// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
+	// With decoupling, total time should be closer to
+	// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
+	// because the scanner reads ahead into the buffer while the handler processes.
+	const numChunks = 50
+	const upstreamDelay = 10 * time.Millisecond
+	const handlerDelay = 20 * time.Millisecond
+
+	pr, pw := io.Pipe()
+	go func() {
+		defer pw.Close()
+		for i := 0; i < numChunks; i++ {
+			fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
+			time.Sleep(upstreamDelay)
+		}
+		fmt.Fprint(pw, "data: [DONE]\n")
+	}()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 30
+	t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
+
+	resp := &http.Response{Body: pr}
+	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+	var count atomic.Int64
+	start := time.Now()
+	done := make(chan struct{})
+	go func() {
+		StreamScannerHandler(c, resp, info, func(data string) bool {
+			time.Sleep(handlerDelay)
+			count.Add(1)
+			return true
+		})
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(15 * time.Second):
+		t.Fatal("StreamScannerHandler did not complete in time")
+	}
+
+	elapsed := time.Since(start)
+	assert.Equal(t, int64(numChunks), count.Load())
+
+	coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
+	t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
+
+	// If decoupled, elapsed should be well under the coupled estimate.
+	assert.Less(t, elapsed, coupledTime*85/100,
+		"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
+}
+
+func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
+	t.Parallel()
+
+	const numChunks = 50
+	body := buildSSEBody(numChunks)
+	reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
+	c, resp, info := setupStreamTest(t, reader)
+
+	var count atomic.Int64
+	start := time.Now()
+
+	done := make(chan struct{})
+	go func() {
+		StreamScannerHandler(c, resp, info, func(data string) bool {
+			count.Add(1)
+			return true
+		})
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(15 * time.Second):
+		t.Fatal("timed out with slow upstream")
+	}
+
+	elapsed := time.Since(start)
+	assert.Equal(t, int64(numChunks), count.Load())
+	t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
+}
+
+// ---------- Ping tests ----------
+
+func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
+	t.Parallel()
+
+	setting := operation_setting.GetGeneralSetting()
+	oldEnabled := setting.PingIntervalEnabled
+	oldSeconds := setting.PingIntervalSeconds
+	setting.PingIntervalEnabled = true
+	setting.PingIntervalSeconds = 1
+	t.Cleanup(func() {
+		setting.PingIntervalEnabled = oldEnabled
+		setting.PingIntervalSeconds = oldSeconds
+	})
+
+	// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
+	// The ping interval is 1s, so we should see at least 2 pings.
+	pr, pw := io.Pipe()
+	go func() {
+		defer pw.Close()
+		for i := 0; i < 7; i++ {
+			fmt.Fprintf(pw, "data: chunk_%d\n", i)
+			time.Sleep(500 * time.Millisecond)
+		}
+		fmt.Fprint(pw, "data: [DONE]\n")
+	}()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 30
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldTimeout
+	})
+
+	resp := &http.Response{Body: pr}
+	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+	var count atomic.Int64
+	done := make(chan struct{})
+	go func() {
+		StreamScannerHandler(c, resp, info, func(data string) bool {
+			count.Add(1)
+			return true
+		})
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(15 * time.Second):
+		t.Fatal("timed out waiting for stream to finish")
+	}
+
+	assert.Equal(t, int64(7), count.Load())
+
+	body := recorder.Body.String()
+	pingCount := strings.Count(body, ": PING")
+	t.Logf("received %d pings in response body", pingCount)
+	assert.GreaterOrEqual(t, pingCount, 2,
+		"expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
+}
+
+func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
+	t.Parallel()
+
+	setting := operation_setting.GetGeneralSetting()
+	oldEnabled := setting.PingIntervalEnabled
+	oldSeconds := setting.PingIntervalSeconds
+	setting.PingIntervalEnabled = true
+	setting.PingIntervalSeconds = 1
+	t.Cleanup(func() {
+		setting.PingIntervalEnabled = oldEnabled
+		setting.PingIntervalSeconds = oldSeconds
+	})
+
+	pr, pw := io.Pipe()
+	go func() {
+		defer pw.Close()
+		for i := 0; i < 5; i++ {
+			fmt.Fprintf(pw, "data: chunk_%d\n", i)
+			time.Sleep(500 * time.Millisecond)
+		}
+		fmt.Fprint(pw, "data: [DONE]\n")
+	}()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 30
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldTimeout
+	})
+
+	resp := &http.Response{Body: pr}
+	info := &relaycommon.RelayInfo{
+		DisablePing: true,
+		ChannelMeta: &relaycommon.ChannelMeta{},
+	}
+
+	var count atomic.Int64
+	done := make(chan struct{})
+	go func() {
+		StreamScannerHandler(c, resp, info, func(data string) bool {
+			count.Add(1)
+			return true
+		})
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(15 * time.Second):
+		t.Fatal("timed out")
+	}
+
+	assert.Equal(t, int64(5), count.Load())
+
+	body := recorder.Body.String()
+	pingCount := strings.Count(body, ": PING")
+	assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
+}
+
+func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
+	t.Parallel()
+
+	setting := operation_setting.GetGeneralSetting()
+	oldEnabled := setting.PingIntervalEnabled
+	oldSeconds := setting.PingIntervalSeconds
+	setting.PingIntervalEnabled = true
+	setting.PingIntervalSeconds = 1
+	t.Cleanup(func() {
+		setting.PingIntervalEnabled = oldEnabled
+		setting.PingIntervalSeconds = oldSeconds
+	})
+
+	// Slow upstream + slow handler. Total stream takes ~5 seconds.
+	// The ping goroutine stays alive as long as the scanner is reading,
+	// so pings should fire between data writes.
+	pr, pw := io.Pipe()
+	go func() {
+		defer pw.Close()
+		for i := 0; i < 10; i++ {
+			fmt.Fprintf(pw, "data: chunk_%d\n", i)
+			time.Sleep(500 * time.Millisecond)
+		}
+		fmt.Fprint(pw, "data: [DONE]\n")
+	}()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+	oldTimeout := constant.StreamingTimeout
+	constant.StreamingTimeout = 30
+	t.Cleanup(func() {
+		constant.StreamingTimeout = oldTimeout
+	})
+
+	resp := &http.Response{Body: pr}
+	info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
+
+	var count atomic.Int64
+	done := make(chan struct{})
+	go func() {
+		StreamScannerHandler(c, resp, info, func(data string) bool {
+			count.Add(1)
+			return true
+		})
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(15 * time.Second):
+		t.Fatal("timed out")
+	}
+
+	assert.Equal(t, int64(10), count.Load())
+
+	body := recorder.Body.String()
+	pingCount := strings.Count(body, ": PING")
+	t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
+	assert.GreaterOrEqual(t, pingCount, 3,
+		"expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
+}

+ 2 - 2
relay/mjproxy_handler.go

@@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
 	if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
 	}
-	modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+	modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace)
 
 	priceData := helper.ModelPriceHelperPerCall(c, info)
 
@@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
 
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 
-	modelName := service.CoverActionToModelName(midjRequest.Action)
+	modelName := service.CovertMjpActionToModelName(midjRequest.Action)
 
 	priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
 

+ 324 - 280
relay/relay_task.go

@@ -2,7 +2,6 @@ package relay
 
 import (
 	"bytes"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -15,29 +14,33 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay/channel"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relayconstant "github.com/QuantumNous/new-api/relay/constant"
+	"github.com/QuantumNous/new-api/relay/helper"
 	"github.com/QuantumNous/new-api/service"
-	"github.com/QuantumNous/new-api/setting/ratio_setting"
-
 	"github.com/gin-gonic/gin"
 )
 
-/*
-Task 任务通过平台、Action 区分任务
-*/
-func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
-	info.InitChannelMeta(c)
-	// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
-	if info.TaskRelayInfo == nil {
-		info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
-	}
+type TaskSubmitResult struct {
+	UpstreamTaskID string
+	TaskData       []byte
+	Platform       constant.TaskPlatform
+	Quota          int
+	//PerCallPrice   types.PriceData
+}
+
+// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
+// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
+// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
+// 以及提取 OtherRatios(时长、分辨率)。
+// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
+func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
+	// 检测 remix action
 	path := c.Request.URL.Path
 	if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
 		info.Action = constant.TaskActionRemix
 	}
-
-	// 提取 remix 任务的 video_id
 	if info.Action == constant.TaskActionRemix {
 		videoID := c.Param("video_id")
 		if strings.TrimSpace(videoID) == "" {
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 		info.OriginTaskID = videoID
 	}
 
-	platform := constant.TaskPlatform(c.GetString("platform"))
+	if info.OriginTaskID == "" {
+		return nil
+	}
 
-	// 获取原始任务信息
-	if info.OriginTaskID != "" {
-		originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
-		if err != nil {
-			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
-			return
-		}
-		if !exist {
-			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
-			return
-		}
-		if info.OriginModelName == "" {
-			if originTask.Properties.OriginModelName != "" {
-				info.OriginModelName = originTask.Properties.OriginModelName
-			} else if originTask.Properties.UpstreamModelName != "" {
-				info.OriginModelName = originTask.Properties.UpstreamModelName
-			} else {
-				var taskData map[string]interface{}
-				_ = json.Unmarshal(originTask.Data, &taskData)
-				if m, ok := taskData["model"].(string); ok && m != "" {
-					info.OriginModelName = m
-					platform = originTask.Platform
-				}
+	// 查找原始任务
+	originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
+	if err != nil {
+		return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+	}
+	if !exist {
+		return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+	}
+
+	// 从原始任务推导模型名称
+	if info.OriginModelName == "" {
+		if originTask.Properties.OriginModelName != "" {
+			info.OriginModelName = originTask.Properties.OriginModelName
+		} else if originTask.Properties.UpstreamModelName != "" {
+			info.OriginModelName = originTask.Properties.UpstreamModelName
+		} else {
+			var taskData map[string]interface{}
+			_ = common.Unmarshal(originTask.Data, &taskData)
+			if m, ok := taskData["model"].(string); ok && m != "" {
+				info.OriginModelName = m
 			}
 		}
-		if originTask.ChannelId != info.ChannelId {
-			channel, err := model.GetChannelById(originTask.ChannelId, true)
-			if err != nil {
-				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
-				return
-			}
-			if channel.Status != common.ChannelStatusEnabled {
-				taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
-				return
-			}
-			key, _, newAPIError := channel.GetNextEnabledKey()
-			if newAPIError != nil {
-				taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
-				return
-			}
-			common.SetContextKey(c, constant.ContextKeyChannelKey, key)
-			common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
-			common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
-			common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
-
-			info.ChannelBaseUrl = channel.GetBaseURL()
-			info.ChannelId = originTask.ChannelId
-			info.ChannelType = channel.Type
-			info.ApiKey = key
-			platform = originTask.Platform
+	}
+
+	// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
+	ch, err := model.GetChannelById(originTask.ChannelId, true)
+	if err != nil {
+		return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+	}
+	if ch.Status != common.ChannelStatusEnabled {
+		return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
+	}
+	info.LockedChannel = ch
+
+	if originTask.ChannelId != info.ChannelId {
+		key, _, newAPIError := ch.GetNextEnabledKey()
+		if newAPIError != nil {
+			return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
 		}
+		common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+		common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
+		common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
+		common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
+
+		info.ChannelBaseUrl = ch.GetBaseURL()
+		info.ChannelId = originTask.ChannelId
+		info.ChannelType = ch.Type
+		info.ApiKey = key
+	}
 
-		// 使用原始任务的参数
-		if info.Action == constant.TaskActionRemix {
+	// 提取 remix 参数(时长、分辨率 → OtherRatios)
+	if info.Action == constant.TaskActionRemix {
+		if originTask.PrivateData.BillingContext != nil {
+			// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
+			for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
+				info.PriceData.AddOtherRatio(s, f)
+			}
+		} else {
+			// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
 			var taskData map[string]interface{}
-			_ = json.Unmarshal(originTask.Data, &taskData)
+			_ = common.Unmarshal(originTask.Data, &taskData)
 			secondsStr, _ := taskData["seconds"].(string)
 			seconds, _ := strconv.Atoi(secondsStr)
 			if seconds <= 0 {
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 			}
 		}
 	}
+
+	return nil
+}
+
+// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
+// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
+// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
+// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
+// 控制器负责 defer Refund 和成功后 Settle。
+func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
+	info.InitChannelMeta(c)
+
+	// 1. 确定 platform → 创建适配器 → 验证请求
+	platform := constant.TaskPlatform(c.GetString("platform"))
 	if platform == "" {
 		platform = GetTaskPlatform(c)
 	}
-
-	info.InitChannelMeta(c)
 	adaptor := GetTaskAdaptor(platform)
 	if adaptor == nil {
-		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+		return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
 	}
 	adaptor.Init(info)
-	// get & validate taskRequest 获取并验证文本请求
-	taskErr = adaptor.ValidateRequestAndSetAction(c, info)
-	if taskErr != nil {
-		return
+	if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
+		return nil, taskErr
 	}
 
+	// 2. 确定模型名称
 	modelName := info.OriginModelName
 	if modelName == "" {
 		modelName = service.CoverTaskActionToModelName(platform, info.Action)
 	}
-	modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
-	if !success {
-		defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
-		if !ok {
-			modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
-		} else {
-			modelPrice = defaultPrice
-		}
+
+	// 2.5 应用渠道的模型映射(与同步任务对齐)
+	info.OriginModelName = modelName
+	info.UpstreamModelName = modelName
+	if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+		return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
 	}
 
-	// 处理 auto 分组:从 context 获取实际选中的分组
-	// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
-	if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
-		if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
-			info.UsingGroup = groupStr
-		}
+	// 3. 预生成公开 task ID(仅首次)
+	if info.PublicTaskID == "" {
+		info.PublicTaskID = model.GenerateTaskID()
 	}
 
-	// 预扣
-	groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
-	var ratio float64
-	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
-	if hasUserGroupRatio {
-		ratio = modelPrice * userGroupRatio
-	} else {
-		ratio = modelPrice * groupRatio
+	// 4. 价格计算:基础模型价格
+	info.OriginModelName = modelName
+	info.PriceData = helper.ModelPriceHelperPerCall(c, info)
+
+	// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
+	//    必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
+	//    ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
+	if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
+		for k, v := range estimatedRatios {
+			info.PriceData.AddOtherRatio(k, v)
+		}
 	}
-	// FIXME: 临时修补,支持任务仅按次计费
+
+	// 6. 将 OtherRatios 应用到基础额度
 	if !common.StringsContains(constant.TaskPricePatches, modelName) {
-		if len(info.PriceData.OtherRatios) > 0 {
-			for _, ra := range info.PriceData.OtherRatios {
-				if 1.0 != ra {
-					ratio *= ra
-				}
+		for _, ra := range info.PriceData.OtherRatios {
+			if ra != 1.0 {
+				info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
 			}
 		}
 	}
-	println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
-	userQuota, err := model.GetUserQuota(info.UserId, false)
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-		return
-	}
-	quota := int(ratio * common.QuotaPerUnit)
-	if userQuota-quota < 0 {
-		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
-		return
+
+	// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
+	if info.Billing == nil && !info.PriceData.FreeModel {
+		info.ForcePreConsume = true
+		if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
+			return nil, service.TaskErrorFromAPIError(apiErr)
+		}
 	}
 
-	// build body
+	// 8. 构建请求体
 	requestBody, err := adaptor.BuildRequestBody(c, info)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
 	}
-	// do request
+
+	// 9. 发送请求
 	resp, err := adaptor.DoRequest(c, info, requestBody)
 	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-		return
+		return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
-	// handle response
 	if resp != nil && resp.StatusCode != http.StatusOK {
 		responseBody, _ := io.ReadAll(resp.Body)
-		taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
-		return
+		return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
 	}
 
-	defer func() {
-		// release quota
-		if info.ConsumeQuota && taskErr == nil {
-
-			err := service.PostConsumeQuota(info, quota, 0, true)
-			if err != nil {
-				common.SysLog("error consuming token remain quota: " + err.Error())
-			}
-			if quota != 0 {
-				tokenName := c.GetString("token_name")
-				//gRatio := groupRatio
-				//if hasUserGroupRatio {
-				//	gRatio = userGroupRatio
-				//}
-				logContent := fmt.Sprintf("操作 %s", info.Action)
-				// FIXME: 临时修补,支持任务仅按次计费
-				if common.StringsContains(constant.TaskPricePatches, modelName) {
-					logContent = fmt.Sprintf("%s,按次计费", logContent)
-				} else {
-					if len(info.PriceData.OtherRatios) > 0 {
-						var contents []string
-						for key, ra := range info.PriceData.OtherRatios {
-							if 1.0 != ra {
-								contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
-							}
-						}
-						if len(contents) > 0 {
-							logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
-						}
-					}
-				}
-				other := make(map[string]interface{})
-				if c != nil && c.Request != nil && c.Request.URL != nil {
-					other["request_path"] = c.Request.URL.Path
-				}
-				other["model_price"] = modelPrice
-				other["group_ratio"] = groupRatio
-				if hasUserGroupRatio {
-					other["user_group_ratio"] = userGroupRatio
-				}
-				model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
-					ChannelId: info.ChannelId,
-					ModelName: modelName,
-					TokenName: tokenName,
-					Quota:     quota,
-					Content:   logContent,
-					TokenId:   info.TokenId,
-					Group:     info.UsingGroup,
-					Other:     other,
-				})
-				model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
-				model.UpdateChannelUsedQuota(info.ChannelId, quota)
-			}
-		}
-	}()
+	// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
+	otherRatios := info.PriceData.OtherRatios
+	if otherRatios == nil {
+		otherRatios = map[string]float64{}
+	}
+	ratiosJSON, _ := common.Marshal(otherRatios)
+	c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
 
-	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
+	// 11. 解析响应
+	upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
 	if taskErr != nil {
-		return
+		return nil, taskErr
 	}
-	info.ConsumeQuota = true
-	// insert task
-	task := model.InitTask(platform, info)
-	task.TaskID = taskID
-	task.Quota = quota
-	task.Data = taskData
-	task.Action = info.Action
-	err = task.Insert()
-	if err != nil {
-		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
-		return
+
+	// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
+	finalQuota := info.PriceData.Quota
+	if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
+		// 基于调整后的 ratios 重新计算 quota
+		finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
+		info.PriceData.OtherRatios = adjustedRatios
+		info.PriceData.Quota = finalQuota
 	}
-	return nil
+
+	return &TaskSubmitResult{
+		UpstreamTaskID: upstreamTaskID,
+		TaskData:       taskData,
+		Platform:       platform,
+		Quota:          finalQuota,
+	}, nil
+}
+
+// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
+// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
+func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
+	// 从 PriceData 获取不含 OtherRatios 的基础价格
+	baseQuota := info.PriceData.Quota
+	// 先除掉原有的 OtherRatios 恢复基础额度
+	for _, ra := range info.PriceData.OtherRatios {
+		if ra != 1.0 && ra > 0 {
+			baseQuota = int(float64(baseQuota) / ra)
+		}
+	}
+	// 应用新的 ratios
+	result := float64(baseQuota)
+	for _, ra := range ratios {
+		if ra != 1.0 {
+			result *= ra
+		}
+	}
+	return int(result)
 }
 
 var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
 	} else {
 		tasks = make([]any, 0)
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+	respBody, err = common.Marshal(dto.TaskResponse[[]any]{
 		Code: "success",
 		Data: tasks,
 	})
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
 		return
 	}
 
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 		return
 	}
 
-	func() {
-		channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
-		if err2 != nil {
-			return
-		}
-		if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
-			return
-		}
-		baseURL := constant.ChannelBaseURLs[channelModel.Type]
-		if channelModel.GetBaseURL() != "" {
-			baseURL = channelModel.GetBaseURL()
-		}
-		proxy := channelModel.GetSetting().Proxy
-		adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
-		if adaptor == nil {
-			return
-		}
-		resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
-			"task_id": originTask.TaskID,
-			"action":  originTask.Action,
-		}, proxy)
-		if err2 != nil || resp == nil {
-			return
-		}
-		defer resp.Body.Close()
-		body, err2 := io.ReadAll(resp.Body)
-		if err2 != nil {
-			return
-		}
-		ti, err2 := adaptor.ParseTaskResult(body)
-		if err2 == nil && ti != nil {
-			if ti.Status != "" {
-				originTask.Status = model.TaskStatus(ti.Status)
-			}
-			if ti.Progress != "" {
-				originTask.Progress = ti.Progress
-			}
-			if ti.Url != "" {
-				if strings.HasPrefix(ti.Url, "data:") {
-				} else {
-					originTask.FailReason = ti.Url
-				}
-			}
-			_ = originTask.Update()
-			var raw map[string]any
-			_ = json.Unmarshal(body, &raw)
-			format := "mp4"
-			if respObj, ok := raw["response"].(map[string]any); ok {
-				if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
-					if v0, ok := vids[0].(map[string]any); ok {
-						if mt, ok := v0["mimeType"].(string); ok && mt != "" {
-							if strings.Contains(mt, "mp4") {
-								format = "mp4"
-							} else {
-								format = mt
-							}
-						}
-					}
-				}
-			}
-			status := "processing"
-			switch originTask.Status {
-			case model.TaskStatusSuccess:
-				status = "succeeded"
-			case model.TaskStatusFailure:
-				status = "failed"
-			case model.TaskStatusQueued, model.TaskStatusSubmitted:
-				status = "queued"
-			}
-			if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
-				out := map[string]any{
-					"error":    nil,
-					"format":   format,
-					"metadata": nil,
-					"status":   status,
-					"task_id":  originTask.TaskID,
-					"url":      originTask.FailReason,
-				}
-				respBody, _ = json.Marshal(dto.TaskResponse[any]{
-					Code: "success",
-					Data: out,
-				})
-			}
-		}
-	}()
+	isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
 
-	if len(respBody) != 0 {
+	// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
+	if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
+		respBody = realtimeResp
 		return
 	}
 
-	if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
+	// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
+	if isOpenAIVideoAPI {
 		adaptor := GetTaskAdaptor(originTask.Platform)
 		if adaptor == nil {
 			taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 			respBody = openAIVideoData
 			return
 		}
-		taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
+		taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
 		return
 	}
-	respBody, err = json.Marshal(dto.TaskResponse[any]{
+
+	// 通用 TaskDto 格式
+	respBody, err = common.Marshal(dto.TaskResponse[any]{
 		Code: "success",
 		Data: TaskModel2Dto(originTask),
 	})
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
 	return
 }
 
+// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
+// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
+// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
+func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
+	channelModel, err := model.GetChannelById(task.ChannelId, true)
+	if err != nil {
+		return nil
+	}
+	if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
+		return nil
+	}
+
+	baseURL := constant.ChannelBaseURLs[channelModel.Type]
+	if channelModel.GetBaseURL() != "" {
+		baseURL = channelModel.GetBaseURL()
+	}
+	proxy := channelModel.GetSetting().Proxy
+	adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
+	if adaptor == nil {
+		return nil
+	}
+
+	resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil || resp == nil {
+		return nil
+	}
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil
+	}
+
+	ti, err := adaptor.ParseTaskResult(body)
+	if err != nil || ti == nil {
+		return nil
+	}
+
+	snap := task.Snapshot()
+
+	// 将上游最新状态更新到 task
+	if ti.Status != "" {
+		task.Status = model.TaskStatus(ti.Status)
+	}
+	if ti.Progress != "" {
+		task.Progress = ti.Progress
+	}
+	if strings.HasPrefix(ti.Url, "data:") {
+		// data: URI — kept in Data, not ResultURL
+	} else if ti.Url != "" {
+		task.PrivateData.ResultURL = ti.Url
+	} else if task.Status == model.TaskStatusSuccess {
+		// No URL from adaptor — construct proxy URL using public task ID
+		task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+	}
+
+	if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
+
+	// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
+	if isOpenAIVideoAPI {
+		return nil
+	}
+
+	// 非 OpenAI Video API: 构建自定义格式响应
+	format := detectVideoFormat(body)
+	out := map[string]any{
+		"error":    nil,
+		"format":   format,
+		"metadata": nil,
+		"status":   mapTaskStatusToSimple(task.Status),
+		"task_id":  task.TaskID,
+		"url":      task.GetResultURL(),
+	}
+	respBody, _ := common.Marshal(dto.TaskResponse[any]{
+		Code: "success",
+		Data: out,
+	})
+	return respBody
+}
+
+// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
+func detectVideoFormat(rawBody []byte) string {
+	var raw map[string]any
+	if err := common.Unmarshal(rawBody, &raw); err != nil {
+		return "mp4"
+	}
+	respObj, ok := raw["response"].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	vids, ok := respObj["videos"].([]any)
+	if !ok || len(vids) == 0 {
+		return "mp4"
+	}
+	v0, ok := vids[0].(map[string]any)
+	if !ok {
+		return "mp4"
+	}
+	mt, ok := v0["mimeType"].(string)
+	if !ok || mt == "" || strings.Contains(mt, "mp4") {
+		return "mp4"
+	}
+	return mt
+}
+
+// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
+func mapTaskStatusToSimple(status model.TaskStatus) string {
+	switch status {
+	case model.TaskStatusSuccess:
+		return "succeeded"
+	case model.TaskStatusFailure:
+		return "failed"
+	case model.TaskStatusQueued, model.TaskStatusSubmitted:
+		return "queued"
+	default:
+		return "processing"
+	}
+}
+
 func TaskModel2Dto(task *model.Task) *dto.TaskDto {
 	return &dto.TaskDto{
+		ID:         task.ID,
+		CreatedAt:  task.CreatedAt,
+		UpdatedAt:  task.UpdatedAt,
 		TaskID:     task.TaskID,
+		Platform:   string(task.Platform),
+		UserId:     task.UserId,
+		Group:      task.Group,
+		ChannelId:  task.ChannelId,
+		Quota:      task.Quota,
 		Action:     task.Action,
 		Status:     string(task.Status),
 		FailReason: task.FailReason,
+		ResultURL:  task.GetResultURL(),
 		SubmitTime: task.SubmitTime,
 		StartTime:  task.StartTime,
 		FinishTime: task.FinishTime,
 		Progress:   task.Progress,
+		Properties: task.Properties,
+		Username:   task.Username,
 		Data:       task.Data,
 	}
 }

+ 1 - 1
relay/responses_handler.go

@@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		}
 
 		// remove disabled fields for OpenAI Responses API
-		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
+		jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
 		if err != nil {
 			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 		}

+ 6 - 1
router/api-router.go

@@ -13,6 +13,7 @@ import (
 
 func SetApiRouter(router *gin.Engine) {
 	apiRouter := router.Group("/api")
+	apiRouter.Use(middleware.RouteTag("api"))
 	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
 	apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储
 	apiRouter.Use(middleware.GlobalAPIRateLimit())
@@ -114,6 +115,9 @@ func SetApiRouter(router *gin.Engine) {
 				adminRoute.GET("/topup", controller.GetAllTopUps)
 				adminRoute.POST("/topup/complete", controller.AdminCompleteTopUp)
 				adminRoute.GET("/search", controller.SearchUsers)
+				adminRoute.GET("/:id/oauth/bindings", controller.GetUserOAuthBindingsByAdmin)
+				adminRoute.DELETE("/:id/oauth/bindings/:provider_id", controller.UnbindCustomOAuthByAdmin)
+				adminRoute.DELETE("/:id/bindings/:binding_type", controller.AdminClearUserBinding)
 				adminRoute.GET("/:id", controller.GetUser)
 				adminRoute.POST("/", controller.CreateUser)
 				adminRoute.POST("/manage", controller.ManageUser)
@@ -170,10 +174,11 @@ func SetApiRouter(router *gin.Engine) {
 			optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
 		}
 
-		// Custom OAuth provider management (admin only)
+		// Custom OAuth provider management (root only)
 		customOAuthRoute := apiRouter.Group("/custom-oauth-provider")
 		customOAuthRoute.Use(middleware.RootAuth())
 		{
+			customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery)
 			customOAuthRoute.GET("/", controller.GetCustomOAuthProviders)
 			customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider)
 			customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider)

+ 1 - 0
router/dashboard.go

@@ -9,6 +9,7 @@ import (
 
 func SetDashboardRouter(router *gin.Engine) {
 	apiRouter := router.Group("/")
+	apiRouter.Use(middleware.RouteTag("old_api"))
 	apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
 	apiRouter.Use(middleware.GlobalAPIRateLimit())
 	apiRouter.Use(middleware.CORS())

+ 2 - 0
router/main.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/middleware"
 
 	"github.com/gin-gonic/gin"
 )
@@ -27,6 +28,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 	} else {
 		frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
 		router.NoRoute(func(c *gin.Context) {
+			c.Set(middleware.RouteTagKey, "web")
 			c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
 		})
 	}

+ 11 - 2
router/relay-router.go

@@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
 	router.Use(middleware.StatsMiddleware())
 	// https://platform.openai.com/docs/api-reference/introduction
 	modelsRouter := router.Group("/v1/models")
+	modelsRouter.Use(middleware.RouteTag("relay"))
 	modelsRouter.Use(middleware.TokenAuth())
 	{
 		modelsRouter.GET("", func(c *gin.Context) {
@@ -41,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
 	}
 
 	geminiRouter := router.Group("/v1beta/models")
+	geminiRouter.Use(middleware.RouteTag("relay"))
 	geminiRouter.Use(middleware.TokenAuth())
 	{
 		geminiRouter.GET("", func(c *gin.Context) {
@@ -49,6 +51,7 @@ func SetRelayRouter(router *gin.Engine) {
 	}
 
 	geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+	geminiCompatibleRouter.Use(middleware.RouteTag("relay"))
 	geminiCompatibleRouter.Use(middleware.TokenAuth())
 	{
 		geminiCompatibleRouter.GET("", func(c *gin.Context) {
@@ -57,12 +60,14 @@ func SetRelayRouter(router *gin.Engine) {
 	}
 
 	playgroundRouter := router.Group("/pg")
+	playgroundRouter.Use(middleware.RouteTag("relay"))
 	playgroundRouter.Use(middleware.SystemPerformanceCheck())
 	playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
 	{
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 	}
 	relayV1Router := router.Group("/v1")
+	relayV1Router.Use(middleware.RouteTag("relay"))
 	relayV1Router.Use(middleware.SystemPerformanceCheck())
 	relayV1Router.Use(middleware.TokenAuth())
 	relayV1Router.Use(middleware.ModelRequestRateLimit())
@@ -161,24 +166,28 @@ func SetRelayRouter(router *gin.Engine) {
 	}
 
 	relayMjRouter := router.Group("/mj")
+	relayMjRouter.Use(middleware.RouteTag("relay"))
 	relayMjRouter.Use(middleware.SystemPerformanceCheck())
 	registerMjRouterGroup(relayMjRouter)
 
 	relayMjModeRouter := router.Group("/:mode/mj")
+	relayMjModeRouter.Use(middleware.RouteTag("relay"))
 	relayMjModeRouter.Use(middleware.SystemPerformanceCheck())
 	registerMjRouterGroup(relayMjModeRouter)
 	//relayMjRouter.Use()
 
 	relaySunoRouter := router.Group("/suno")
+	relaySunoRouter.Use(middleware.RouteTag("relay"))
 	relaySunoRouter.Use(middleware.SystemPerformanceCheck())
 	relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
 		relaySunoRouter.POST("/submit/:action", controller.RelayTask)
-		relaySunoRouter.POST("/fetch", controller.RelayTask)
-		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+		relaySunoRouter.POST("/fetch", controller.RelayTaskFetch)
+		relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch)
 	}
 
 	relayGeminiRouter := router.Group("/v1beta")
+	relayGeminiRouter.Use(middleware.RouteTag("relay"))
 	relayGeminiRouter.Use(middleware.SystemPerformanceCheck())
 	relayGeminiRouter.Use(middleware.TokenAuth())
 	relayGeminiRouter.Use(middleware.ModelRequestRateLimit())

+ 15 - 5
router/video-router.go

@@ -8,32 +8,42 @@ import (
 )
 
 func SetVideoRouter(router *gin.Engine) {
+	// Video proxy: accepts either session auth (dashboard) or token auth (API clients)
+	videoProxyRouter := router.Group("/v1")
+	videoProxyRouter.Use(middleware.RouteTag("relay"))
+	videoProxyRouter.Use(middleware.TokenOrUserAuth())
+	{
+		videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy)
+	}
+
 	videoV1Router := router.Group("/v1")
+	videoV1Router.Use(middleware.RouteTag("relay"))
 	videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
 		videoV1Router.POST("/video/generations", controller.RelayTask)
-		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+		videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch)
 		videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
 	}
 	// openai compatible API video routes
 	// docs: https://platform.openai.com/docs/api-reference/videos/create
 	{
 		videoV1Router.POST("/videos", controller.RelayTask)
-		videoV1Router.GET("/videos/:task_id", controller.RelayTask)
+		videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch)
 	}
 
 	klingV1Router := router.Group("/kling/v1")
+	klingV1Router.Use(middleware.RouteTag("relay"))
 	klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
 	{
 		klingV1Router.POST("/videos/text2video", controller.RelayTask)
 		klingV1Router.POST("/videos/image2video", controller.RelayTask)
-		klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
-		klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
+		klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch)
+		klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch)
 	}
 
 	// Jimeng official API routes - direct mapping to official API format
 	jimengOfficialGroup := router.Group("jimeng")
+	jimengOfficialGroup.Use(middleware.RouteTag("relay"))
 	jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
 	{
 		// Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31

+ 1 - 0
router/web-router.go

@@ -19,6 +19,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
 	router.Use(middleware.Cache())
 	router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist")))
 	router.NoRoute(func(c *gin.Context) {
+		c.Set(middleware.RouteTagKey, "web")
 		if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
 			controller.RelayNotFound(c)
 			return

+ 5 - 0
service/billing_session.go

@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
 
 // shouldTrust 统一信任额度检查,适用于钱包和订阅。
 func (s *BillingSession) shouldTrust(c *gin.Context) bool {
+	// 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
+	if s.relayInfo.ForcePreConsume {
+		return false
+	}
+
 	trustQuota := common.GetTrustQuota()
 	if trustQuota <= 0 {
 		return false

+ 54 - 6
service/channel_affinity.go

@@ -13,6 +13,7 @@ import (
 	"github.com/QuantumNous/new-api/dto"
 	"github.com/QuantumNous/new-api/pkg/cachex"
 	"github.com/QuantumNous/new-api/setting/operation_setting"
+	"github.com/QuantumNous/new-api/types"
 	"github.com/gin-gonic/gin"
 	"github.com/samber/hot"
 	"github.com/tidwall/gjson"
@@ -61,6 +62,12 @@ type ChannelAffinityStatsContext struct {
 	TTLSeconds     int64
 }
 
+const (
+	cacheTokenRateModeCachedOverPrompt           = "cached_over_prompt"
+	cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached"
+	cacheTokenRateModeMixed                      = "mixed"
+)
+
 type ChannelAffinityCacheStats struct {
 	Enabled       bool           `json:"enabled"`
 	Total         int            `json:"total"`
@@ -565,9 +572,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) {
 }
 
 type ChannelAffinityUsageCacheStats struct {
-	RuleName       string `json:"rule_name"`
-	UsingGroup     string `json:"using_group"`
-	KeyFingerprint string `json:"key_fp"`
+	RuleName            string `json:"rule_name"`
+	UsingGroup          string `json:"using_group"`
+	KeyFingerprint      string `json:"key_fp"`
+	CachedTokenRateMode string `json:"cached_token_rate_mode"`
 
 	Hit           int64 `json:"hit"`
 	Total         int64 `json:"total"`
@@ -582,6 +590,8 @@ type ChannelAffinityUsageCacheStats struct {
 }
 
 type ChannelAffinityUsageCacheCounters struct {
+	CachedTokenRateMode string `json:"cached_token_rate_mode"`
+
 	Hit           int64 `json:"hit"`
 	Total         int64 `json:"total"`
 	WindowSeconds int64 `json:"window_seconds"`
@@ -596,12 +606,17 @@ type ChannelAffinityUsageCacheCounters struct {
 
 var channelAffinityUsageCacheStatsLocks [64]sync.Mutex
 
-func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) {
+// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format.
+func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) {
+	ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat))
+}
+
+func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) {
 	statsCtx, ok := GetChannelAffinityStatsContext(c)
 	if !ok {
 		return
 	}
-	observeChannelAffinityUsageCache(statsCtx, usage)
+	observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode)
 }
 
 func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats {
@@ -628,6 +643,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
 		}
 	}
 	return ChannelAffinityUsageCacheStats{
+		CachedTokenRateMode:  v.CachedTokenRateMode,
 		RuleName:             ruleName,
 		UsingGroup:           usingGroup,
 		KeyFingerprint:       keyFp,
@@ -643,7 +659,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
 	}
 }
 
-func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) {
+func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) {
 	entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint)
 	if entryKey == "" {
 		return
@@ -669,6 +685,14 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
 	if !found {
 		next = ChannelAffinityUsageCacheCounters{}
 	}
+	currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode)
+	if currentMode != "" {
+		if next.CachedTokenRateMode == "" {
+			next.CachedTokenRateMode = currentMode
+		} else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed {
+			next.CachedTokenRateMode = cacheTokenRateModeMixed
+		}
+	}
 	next.Total++
 	hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage)
 	if hit {
@@ -684,6 +708,30 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
 	_ = cache.SetWithTTL(entryKey, next, ttl)
 }
 
+func normalizeCachedTokenRateMode(mode string) string {
+	switch mode {
+	case cacheTokenRateModeCachedOverPrompt:
+		return cacheTokenRateModeCachedOverPrompt
+	case cacheTokenRateModeCachedOverPromptPlusCached:
+		return cacheTokenRateModeCachedOverPromptPlusCached
+	case cacheTokenRateModeMixed:
+		return cacheTokenRateModeMixed
+	default:
+		return ""
+	}
+}
+
+func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string {
+	switch relayFormat {
+	case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction:
+		return cacheTokenRateModeCachedOverPrompt
+	case types.RelayFormatClaude:
+		return cacheTokenRateModeCachedOverPromptPlusCached
+	default:
+		return ""
+	}
+}
+
 func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string {
 	ruleName = strings.TrimSpace(ruleName)
 	usingGroup = strings.TrimSpace(usingGroup)

+ 105 - 0
service/channel_affinity_usage_cache_test.go

@@ -0,0 +1,105 @@
+package service
+
+import (
+	"fmt"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/types"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/require"
+)
+
+func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context {
+	rec := httptest.NewRecorder()
+	ctx, _ := gin.CreateTestContext(rec)
+	setChannelAffinityContext(ctx, channelAffinityMeta{
+		CacheKey:       fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP),
+		TTLSeconds:     600,
+		RuleName:       ruleName,
+		UsingGroup:     usingGroup,
+		KeyFingerprint: keyFP,
+	})
+	return ctx
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) {
+	ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+	usingGroup := "default"
+	keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+	ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+	usage := &dto.Usage{
+		PromptTokens:     100,
+		CompletionTokens: 40,
+		TotalTokens:      140,
+		PromptTokensDetails: dto.InputTokenDetails{
+			CachedTokens: 30,
+		},
+	}
+
+	ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude)
+	stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+	require.EqualValues(t, 1, stats.Total)
+	require.EqualValues(t, 1, stats.Hit)
+	require.EqualValues(t, 100, stats.PromptTokens)
+	require.EqualValues(t, 40, stats.CompletionTokens)
+	require.EqualValues(t, 140, stats.TotalTokens)
+	require.EqualValues(t, 30, stats.CachedTokens)
+	require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode)
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) {
+	ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+	usingGroup := "default"
+	keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+	ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+	openAIUsage := &dto.Usage{
+		PromptTokens: 100,
+		PromptTokensDetails: dto.InputTokenDetails{
+			CachedTokens: 10,
+		},
+	}
+	claudeUsage := &dto.Usage{
+		PromptTokens: 80,
+		PromptTokensDetails: dto.InputTokenDetails{
+			CachedTokens: 20,
+		},
+	}
+
+	ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI)
+	ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude)
+	stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+	require.EqualValues(t, 2, stats.Total)
+	require.EqualValues(t, 2, stats.Hit)
+	require.EqualValues(t, 180, stats.PromptTokens)
+	require.EqualValues(t, 30, stats.CachedTokens)
+	require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode)
+}
+
+func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) {
+	ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
+	usingGroup := "default"
+	keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
+	ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
+
+	usage := &dto.Usage{
+		PromptTokens: 100,
+		PromptTokensDetails: dto.InputTokenDetails{
+			CachedTokens: 25,
+		},
+	}
+
+	ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini)
+	stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
+
+	require.EqualValues(t, 1, stats.Total)
+	require.EqualValues(t, 1, stats.Hit)
+	require.EqualValues(t, 25, stats.CachedTokens)
+	require.Equal(t, "", stats.CachedTokenRateMode)
+}

+ 1 - 1
service/codex_credential_refresh.go

@@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code
 	refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
 	defer cancel()
 
-	res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+	res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
 	if err != nil {
 		return nil, nil, err
 	}

+ 33 - 4
service/codex_oauth.go

@@ -12,6 +12,8 @@ import (
 	"net/url"
 	"strings"
 	"time"
+
+	"github.com/QuantumNous/new-api/common"
 )
 
 const (
@@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct {
 }
 
 func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
-	client := &http.Client{Timeout: defaultHTTPTimeout}
+	return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "")
+}
+
+func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) {
+	client, err := getCodexOAuthHTTPClient(proxyURL)
+	if err != nil {
+		return nil, err
+	}
 	return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
 }
 
 func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
-	client := &http.Client{Timeout: defaultHTTPTimeout}
+	return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "")
+}
+
+func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) {
+	client, err := getCodexOAuthHTTPClient(proxyURL)
+	if err != nil {
+		return nil, err
+	}
 	return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
 }
 
@@ -104,7 +120,7 @@ func refreshCodexOAuthToken(
 		ExpiresIn    int    `json:"expires_in"`
 	}
 
-	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+	if err := common.DecodeJson(resp.Body, &payload); err != nil {
 		return nil, err
 	}
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode(
 		RefreshToken string `json:"refresh_token"`
 		ExpiresIn    int    `json:"expires_in"`
 	}
-	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+	if err := common.DecodeJson(resp.Body, &payload); err != nil {
 		return nil, err
 	}
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode(
 	}, nil
 }
 
+func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) {
+	baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL))
+	if err != nil {
+		return nil, err
+	}
+	if baseClient == nil {
+		return &http.Client{Timeout: defaultHTTPTimeout}, nil
+	}
+	clientCopy := *baseClient
+	clientCopy.Timeout = defaultHTTPTimeout
+	return &clientCopy, nil
+}
+
 func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
 	u, err := url.Parse(codexOAuthAuthorizeURL)
 	if err != nil {

+ 13 - 0
service/error.go

@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
 
 	return taskError
 }
+
+// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
+func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
+	if apiErr == nil {
+		return nil
+	}
+	return &dto.TaskError{
+		Code:       string(apiErr.GetErrorCode()),
+		Message:    apiErr.Err.Error(),
+		StatusCode: apiErr.StatusCode,
+		Error:      apiErr.Err,
+	}
+}

+ 1 - 1
service/log_info_generate.go

@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	return info
 }
 
-func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
+func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
 	other := make(map[string]interface{})
 	other["model_price"] = priceData.ModelPrice
 	other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio

+ 2 - 2
service/midjourney.go

@@ -19,7 +19,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func CoverActionToModelName(mjAction string) string {
+func CovertMjpActionToModelName(mjAction string) string {
 	modelName := "mj_" + strings.ToLower(mjAction)
 	if mjAction == constant.MjActionSwapFace {
 		modelName = "swap_face"
@@ -70,7 +70,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
 			return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
 		}
 	}
-	modelName := CoverActionToModelName(action)
+	modelName := CovertMjpActionToModelName(action)
 	return modelName, nil, true
 }
 

+ 3 - 0
service/quota.go

@@ -236,6 +236,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 }
 
 func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
+	if usage != nil {
+		ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
+	}
 
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens

+ 285 - 0
service/task_billing.go

@@ -0,0 +1,285 @@
+package service
+
+import (
+	"context"
+	"fmt"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/setting/ratio_setting"
+	"github.com/gin-gonic/gin"
+)
+
+// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
+// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
+	tokenName := c.GetString("token_name")
+	logContent := fmt.Sprintf("操作 %s", info.Action)
+	// 支持任务仅按次计费
+	if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
+		logContent = fmt.Sprintf("%s,按次计费", logContent)
+	} else {
+		if len(info.PriceData.OtherRatios) > 0 {
+			var contents []string
+			for key, ra := range info.PriceData.OtherRatios {
+				if 1.0 != ra {
+					contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
+				}
+			}
+			if len(contents) > 0 {
+				logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
+			}
+		}
+	}
+	other := make(map[string]interface{})
+	other["request_path"] = c.Request.URL.Path
+	other["model_price"] = info.PriceData.ModelPrice
+	other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
+	if info.PriceData.GroupRatioInfo.HasSpecialRatio {
+		other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
+	}
+	if info.IsModelMapped {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = info.UpstreamModelName
+	}
+	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
+		ChannelId: info.ChannelId,
+		ModelName: info.OriginModelName,
+		TokenName: tokenName,
+		Quota:     info.PriceData.Quota,
+		Content:   logContent,
+		TokenId:   info.TokenId,
+		Group:     info.UsingGroup,
+		Other:     other,
+	})
+	model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
+	model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
+}
+
+// ---------------------------------------------------------------------------
+// 异步任务计费辅助函数
+// ---------------------------------------------------------------------------
+
+// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
+// 如果令牌已被删除或查询失败,返回空字符串。
+func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
+	token, err := model.GetTokenById(tokenId)
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
+		return ""
+	}
+	return token.Key
+}
+
+// taskIsSubscription 判断任务是否通过订阅计费。
+func taskIsSubscription(task *model.Task) bool {
+	return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
+}
+
+// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
+func taskAdjustFunding(task *model.Task, delta int) error {
+	if taskIsSubscription(task) {
+		return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
+	}
+	if delta > 0 {
+		return model.DecreaseUserQuota(task.UserId, delta)
+	}
+	return model.IncreaseUserQuota(task.UserId, -delta, false)
+}
+
+// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
+// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
+func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
+	if task.PrivateData.TokenId <= 0 || delta == 0 {
+		return
+	}
+	tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
+	if tokenKey == "" {
+		return
+	}
+	var err error
+	if delta > 0 {
+		err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
+	} else {
+		err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
+	}
+	if err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
+	}
+}
+
+// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
+func taskBillingOther(task *model.Task) map[string]interface{} {
+	other := make(map[string]interface{})
+	if bc := task.PrivateData.BillingContext; bc != nil {
+		other["model_price"] = bc.ModelPrice
+		other["group_ratio"] = bc.GroupRatio
+		if len(bc.OtherRatios) > 0 {
+			for k, v := range bc.OtherRatios {
+				other[k] = v
+			}
+		}
+	}
+	props := task.Properties
+	if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = props.UpstreamModelName
+	}
+	return other
+}
+
+// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
+func taskModelName(task *model.Task) string {
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
+		return bc.OriginModelName
+	}
+	return task.Properties.OriginModelName
+}
+
+// RefundTaskQuota 统一的任务失败退款逻辑。
+// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
+func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
+	quota := task.Quota
+	if quota == 0 {
+		return
+	}
+
+	// 1. 退还资金来源(钱包或订阅)
+	if err := taskAdjustFunding(task, -quota); err != nil {
+		logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 2. 退还令牌额度
+	taskAdjustTokenQuota(ctx, task, -quota)
+
+	// 3. 记录日志
+	other := taskBillingOther(task)
+	other["task_id"] = task.TaskID
+	other["reason"] = reason
+	model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+		UserId:    task.UserId,
+		LogType:   model.LogTypeRefund,
+		Content:   "",
+		ChannelId: task.ChannelId,
+		ModelName: taskModelName(task),
+		Quota:     quota,
+		TokenId:   task.PrivateData.TokenId,
+		Group:     task.Group,
+		Other:     other,
+	})
+}
+
+// RecalculateTaskQuota 通用的异步差额结算。
+// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
+// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
+func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
+	if actualQuota <= 0 {
+		return
+	}
+	preConsumedQuota := task.Quota
+	quotaDelta := actualQuota - preConsumedQuota
+
+	if quotaDelta == 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
+			task.TaskID, logger.LogQuota(actualQuota), reason))
+		return
+	}
+
+	logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
+		task.TaskID,
+		logger.LogQuota(quotaDelta),
+		logger.LogQuota(actualQuota),
+		logger.LogQuota(preConsumedQuota),
+		reason,
+	))
+
+	// 调整资金来源
+	if err := taskAdjustFunding(task, quotaDelta); err != nil {
+		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
+		return
+	}
+
+	// 调整令牌额度
+	taskAdjustTokenQuota(ctx, task, quotaDelta)
+
+	task.Quota = actualQuota
+
+	var logType int
+	var logQuota int
+	if quotaDelta > 0 {
+		logType = model.LogTypeConsume
+		logQuota = quotaDelta
+		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
+		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
+	} else {
+		logType = model.LogTypeRefund
+		logQuota = -quotaDelta
+	}
+	other := taskBillingOther(task)
+	other["task_id"] = task.TaskID
+	other["reason"] = reason
+	other["pre_consumed_quota"] = preConsumedQuota
+	other["actual_quota"] = actualQuota
+	model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
+		UserId:    task.UserId,
+		LogType:   logType,
+		Content:   "",
+		ChannelId: task.ChannelId,
+		ModelName: taskModelName(task),
+		Quota:     logQuota,
+		TokenId:   task.PrivateData.TokenId,
+		Group:     task.Group,
+		Other:     other,
+	})
+}
+
+// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
+// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
+// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
+func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
+	if totalTokens <= 0 {
+		return
+	}
+
+	modelName := taskModelName(task)
+
+	// 获取模型价格和倍率
+	modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
+	// 只有配置了倍率(非固定价格)时才按 token 重新计费
+	if !hasRatioSetting || modelRatio <= 0 {
+		return
+	}
+
+	// 获取用户和组的倍率信息
+	group := task.Group
+	if group == "" {
+		user, err := model.GetUserById(task.UserId, false)
+		if err == nil {
+			group = user.Group
+		}
+	}
+	if group == "" {
+		return
+	}
+
+	groupRatio := ratio_setting.GetGroupRatio(group)
+	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
+
+	var finalGroupRatio float64
+	if hasUserGroupRatio {
+		finalGroupRatio = userGroupRatio
+	} else {
+		finalGroupRatio = groupRatio
+	}
+
+	// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
+	actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
+
+	reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
+	RecalculateTaskQuota(ctx, task, actualQuota, reason)
+}

+ 712 - 0
service/task_billing_test.go

@@ -0,0 +1,712 @@
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/glebarez/sqlite"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+)
+
+func TestMain(m *testing.M) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	if err != nil {
+		panic("failed to open test db: " + err.Error())
+	}
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic("failed to get sql.DB: " + err.Error())
+	}
+	sqlDB.SetMaxOpenConns(1)
+
+	model.DB = db
+	model.LOG_DB = db
+
+	common.UsingSQLite = true
+	common.RedisEnabled = false
+	common.BatchUpdateEnabled = false
+	common.LogConsumeEnabled = true
+
+	if err := db.AutoMigrate(
+		&model.Task{},
+		&model.User{},
+		&model.Token{},
+		&model.Log{},
+		&model.Channel{},
+		&model.UserSubscription{},
+	); err != nil {
+		panic("failed to migrate: " + err.Error())
+	}
+
+	os.Exit(m.Run())
+}
+
+// ---------------------------------------------------------------------------
+// Seed helpers
+// ---------------------------------------------------------------------------
+
+func truncate(t *testing.T) {
+	t.Helper()
+	t.Cleanup(func() {
+		model.DB.Exec("DELETE FROM tasks")
+		model.DB.Exec("DELETE FROM users")
+		model.DB.Exec("DELETE FROM tokens")
+		model.DB.Exec("DELETE FROM logs")
+		model.DB.Exec("DELETE FROM channels")
+		model.DB.Exec("DELETE FROM user_subscriptions")
+	})
+}
+
+func seedUser(t *testing.T, id int, quota int) {
+	t.Helper()
+	user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
+	require.NoError(t, model.DB.Create(user).Error)
+}
+
+func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
+	t.Helper()
+	token := &model.Token{
+		Id:          id,
+		UserId:      userId,
+		Key:         key,
+		Name:        "test_token",
+		Status:      common.TokenStatusEnabled,
+		RemainQuota: remainQuota,
+		UsedQuota:   0,
+	}
+	require.NoError(t, model.DB.Create(token).Error)
+}
+
+func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
+	t.Helper()
+	sub := &model.UserSubscription{
+		Id:          id,
+		UserId:      userId,
+		AmountTotal: amountTotal,
+		AmountUsed:  amountUsed,
+		Status:      "active",
+		StartTime:   time.Now().Unix(),
+		EndTime:     time.Now().Add(30 * 24 * time.Hour).Unix(),
+	}
+	require.NoError(t, model.DB.Create(sub).Error)
+}
+
+func seedChannel(t *testing.T, id int) {
+	t.Helper()
+	ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
+	require.NoError(t, model.DB.Create(ch).Error)
+}
+
+func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
+	return &model.Task{
+		TaskID:    "task_" + time.Now().Format("150405.000"),
+		UserId:    userId,
+		ChannelId: channelId,
+		Quota:     quota,
+		Status:    model.TaskStatus(model.TaskStatusInProgress),
+		Group:     "default",
+		Data:      json.RawMessage(`{}`),
+		CreatedAt: time.Now().Unix(),
+		UpdatedAt: time.Now().Unix(),
+		Properties: model.Properties{
+			OriginModelName: "test-model",
+		},
+		PrivateData: model.TaskPrivateData{
+			BillingSource:  billingSource,
+			SubscriptionId: subscriptionId,
+			TokenId:        tokenId,
+			BillingContext: &model.TaskBillingContext{
+				ModelPrice: 0.02,
+				GroupRatio: 1.0,
+				OriginModelName: "test-model",
+			},
+		},
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Read-back helpers
+// ---------------------------------------------------------------------------
+
+func getUserQuota(t *testing.T, id int) int {
+	t.Helper()
+	var user model.User
+	require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
+	return user.Quota
+}
+
+func getTokenRemainQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
+	return token.RemainQuota
+}
+
+func getTokenUsedQuota(t *testing.T, id int) int {
+	t.Helper()
+	var token model.Token
+	require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
+	return token.UsedQuota
+}
+
+func getSubscriptionUsed(t *testing.T, id int) int64 {
+	t.Helper()
+	var sub model.UserSubscription
+	require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
+	return sub.AmountUsed
+}
+
+func getLastLog(t *testing.T) *model.Log {
+	t.Helper()
+	var log model.Log
+	err := model.LOG_DB.Order("id desc").First(&log).Error
+	if err != nil {
+		return nil
+	}
+	return &log
+}
+
+func countLogs(t *testing.T) int64 {
+	t.Helper()
+	var count int64
+	model.LOG_DB.Model(&model.Log{}).Count(&count)
+	return count
+}
+
+// ===========================================================================
+// RefundTaskQuota tests
+// ===========================================================================
+
+func TestRefundTaskQuota_Wallet(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 1, 1, 1
+	const initQuota, preConsumed = 10000, 3000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "task failed: upstream error")
+
+	// User quota should increase by preConsumed
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Token remain_quota should increase, used_quota should decrease
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
+
+	// A refund log should be created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed, log.Quota)
+	assert.Equal(t, "test-model", log.ModelName)
+}
+
+func TestRefundTaskQuota_Subscription(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 2, 2, 2, 1
+	const preConsumed = 2000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RefundTaskQuota(ctx, task, "subscription task failed")
+
+	// Subscription used should decrease by preConsumed
+	assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
+
+	// Token should also be refunded
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 3
+	seedUser(t, userID, 5000)
+
+	task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
+
+	RefundTaskQuota(ctx, task, "zero quota task")
+
+	// No change to user quota
+	assert.Equal(t, 5000, getUserQuota(t, userID))
+
+	// No log created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRefundTaskQuota_NoToken(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 4, 4
+	const initQuota, preConsumed = 10000, 1500
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
+
+	RefundTaskQuota(ctx, task, "no token task failed")
+
+	// User quota refunded
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+
+	// Log created
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// RecalculateTaskQuota tests
+// ===========================================================================
+
+func TestRecalculate_PositiveDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 10, 10, 10
+	const initQuota, preConsumed = 10000, 2000
+	const actualQuota = 3000 // under-charged by 1000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should decrease by the delta (1000 additional charge)
+	assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
+
+	// Token should also be charged the delta
+	assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Consume (additional charge)
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeConsume, log.Type)
+	assert.Equal(t, actualQuota-preConsumed, log.Quota)
+}
+
+func TestRecalculate_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 11, 11, 11
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged by 2000
+	const tokenRemain = 5000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
+
+	// User quota should increase by abs(delta) = 2000 (refund overpayment)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+
+	// Token should be refunded the difference
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota updated
+	assert.Equal(t, actualQuota, task.Quota)
+
+	// Log type should be Refund
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+	assert.Equal(t, preConsumed-actualQuota, log.Quota)
+}
+
+func TestRecalculate_ZeroDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 12
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
+
+	// No change to user quota
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No log created (delta is zero)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_ActualQuotaZero(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID = 13
+	const initQuota = 10000
+
+	seedUser(t, userID, initQuota)
+
+	task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
+
+	RecalculateTaskQuota(ctx, task, 0, "zero actual")
+
+	// No change (early return)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID, subID = 14, 14, 14, 2
+	const preConsumed = 5000
+	const actualQuota = 2000 // over-charged by 3000
+	const subTotal, subUsed int64 = 100000, 50000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, 0)
+	seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
+	seedChannel(t, channelID)
+	seedSubscription(t, subID, userID, subTotal, subUsed)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
+
+	RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
+
+	// Subscription used should decrease by delta (refund 3000)
+	assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
+
+	// Token refunded
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	assert.Equal(t, actualQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+// ===========================================================================
+// CAS + Billing integration tests
+// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
+// ===========================================================================
+
+// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
+// It takes a persisted task (already in DB), applies the new status, and performs
+// the conditional update + billing exactly as the polling loop does.
+func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
+	snap := task.Snapshot()
+
+	shouldRefund := false
+	shouldSettle := false
+	quota := task.Quota
+
+	task.Status = newStatus
+	switch string(newStatus) {
+	case model.TaskStatusSuccess:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		shouldSettle = true
+	case model.TaskStatusFailure:
+		task.Progress = "100%"
+		task.FinishTime = 9999
+		task.FailReason = "upstream error"
+		if quota != 0 {
+			shouldRefund = true
+		}
+	default:
+		task.Progress = "50%"
+	}
+
+	isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		_, _ = task.UpdateWithStatus(snap.Status)
+	}
+
+	if shouldSettle && actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
+	}
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+}
+
+func TestCASGuardedRefund_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 20, 20, 20
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS wins: task in DB should now be FAILURE
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
+
+	// Refund should have happened
+	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}
+
+func TestCASGuardedRefund_Lose(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 21, 21, 21
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 6000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
+	seedChannel(t, channelID)
+
+	// Create task with IN_PROGRESS in DB
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate another process already transitioning to FAILURE
+	model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
+
+	// Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
+	// task.Status is still IN_PROGRESS in the snapshot
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
+
+	// CAS lost: user quota should NOT change (no double refund)
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+
+	// No billing log should be created
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestCASGuardedSettle_Win(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 22, 22, 22
+	const initQuota, preConsumed = 10000, 5000
+	const actualQuota = 3000 // over-charged, should get partial refund
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	require.NoError(t, model.DB.Create(task).Error)
+
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
+
+	// CAS wins: task should be SUCCESS
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
+
+	// Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
+	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
+
+	// task.Quota should be updated to actualQuota
+	assert.Equal(t, actualQuota, task.Quota)
+}
+
+func TestNonTerminalUpdate_NoBilling(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, channelID = 23, 23
+	const initQuota, preConsumed = 10000, 3000
+
+	seedUser(t, userID, initQuota)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
+	task.Status = model.TaskStatus(model.TaskStatusInProgress)
+	task.Progress = "20%"
+	require.NoError(t, model.DB.Create(task).Error)
+
+	// Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
+	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
+
+	// User quota should NOT change
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+
+	// No billing log
+	assert.Equal(t, int64(0), countLogs(t))
+
+	// Task progress should be updated in DB
+	var reloaded model.Task
+	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
+	assert.Equal(t, "50%", reloaded.Progress)
+}
+
+// ===========================================================================
+// Mock adaptor for settleTaskBillingOnComplete tests
+// ===========================================================================
+
+type mockAdaptor struct {
+	adjustReturn int
+}
+
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo)                                            {}
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error)  { return nil, nil }
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error)                     { return nil, nil }
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return m.adjustReturn
+}
+
+// ===========================================================================
+// PerCallBilling tests — settleTaskBillingOnComplete
+// ===========================================================================
+
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 30, 30, 30
+	const initQuota, preConsumed = 10000, 5000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 2000}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no adjustment despite adaptor returning 2000
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 31, 31, 31
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 7000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 0}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no recalculation by tokens
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 32, 32, 32
+	const initQuota, preConsumed = 10000, 5000
+	const adaptorQuota = 3000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	// PerCallBilling defaults to false
+
+	adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Non-per-call: adaptor adjustment applies (refund 2000)
+	assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, adaptorQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}

+ 539 - 0
service/task_polling.go

@@ -0,0 +1,539 @@
+package service
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+
+	"github.com/samber/lo"
+)
+
+// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
+type TaskPollingAdaptor interface {
+	Init(info *relaycommon.RelayInfo)
+	FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
+	ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
+	// AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
+	// 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
+	AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
+}
+
+// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
+// 打破 service -> relay -> relay/channel -> service 的循环依赖。
+var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
+
+// sweepTimedOutTasks 在主轮询之前独立清理超时任务。
+// 每次最多处理 100 条,剩余的下个周期继续处理。
+// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
+func sweepTimedOutTasks(ctx context.Context) {
+	if constant.TaskTimeoutMinutes <= 0 {
+		return
+	}
+	cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
+	tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
+	if len(tasks) == 0 {
+		return
+	}
+
+	const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
+	reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
+	legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
+	now := time.Now().Unix()
+	timedOutCount := 0
+
+	for _, task := range tasks {
+		isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
+
+		oldStatus := task.Status
+		task.Status = model.TaskStatusFailure
+		task.Progress = "100%"
+		task.FinishTime = now
+		if isLegacy {
+			task.FailReason = legacyReason
+		} else {
+			task.FailReason = reason
+		}
+
+		won, err := task.UpdateWithStatus(oldStatus)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
+			continue
+		}
+		if !won {
+			logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
+			continue
+		}
+		timedOutCount++
+		if !isLegacy && task.Quota != 0 {
+			RefundTaskQuota(ctx, task, reason)
+		}
+	}
+
+	if timedOutCount > 0 {
+		logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
+	}
+}
+
+// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
+func TaskPollingLoop() {
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
+		sweepTimedOutTasks(ctx)
+		allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				upstreamID := task.GetUpstreamTaskID()
+				if upstreamID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[upstreamID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			DispatchPlatformUpdate(platform, taskChannelM, taskM)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+// DispatchPlatformUpdate 按平台分发轮询更新
+func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		// MJ 轮询由其自身处理,这里预留入口
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
+	default:
+		if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
+		}
+	}
+}
+
+// UpdateSunoTasks 按渠道更新所有 Suno 任务
+func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTasks(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	ch, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
+		}
+		return err
+	}
+	adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	proxy := ch.GetSetting().Proxy
+	resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
+		"ids": taskIds,
+	}, proxy)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = common.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !taskNeedsUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			RefundTaskQuota(ctx, task, task.FailReason)
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysLog("UpdateSunoTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+// taskNeedsUpdate 检查 Suno 任务是否需要更新
+func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := common.Marshal(oldTask.Data)
+	newData, _ := common.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+// UpdateVideoTasks 按渠道更新所有视频任务
+func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	cacheGetChannel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
+		var failedIDs []int64
+		for _, upstreamID := range taskIds {
+			if t, ok := taskM[upstreamID]; ok {
+				failedIDs = append(failedIDs, t.ID)
+			}
+		}
+		errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
+			"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if errUpdate != nil {
+			common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+		}
+		return fmt.Errorf("CacheGetChannel failed: %w", err)
+	}
+	adaptor := GetTaskAdaptorFunc(platform)
+	if adaptor == nil {
+		return fmt.Errorf("video adaptor not found")
+	}
+	info := &relaycommon.RelayInfo{}
+	info.ChannelMeta = &relaycommon.ChannelMeta{
+		ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
+	}
+	info.ApiKey = cacheGetChannel.Key
+	adaptor.Init(info)
+	for _, taskId := range taskIds {
+		if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
+	baseURL := constant.ChannelBaseURLs[ch.Type]
+	if ch.GetBaseURL() != "" {
+		baseURL = ch.GetBaseURL()
+	}
+	proxy := ch.GetSetting().Proxy
+
+	task := taskM[taskId]
+	if task == nil {
+		logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+		return fmt.Errorf("task %s not found", taskId)
+	}
+	key := ch.Key
+
+	privateData := task.PrivateData
+	if privateData.Key != "" {
+		key = privateData.Key
+	}
+	resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
+		"task_id": task.GetUpstreamTaskID(),
+		"action":  task.Action,
+	}, proxy)
+	if err != nil {
+		return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
+
+	snap := task.Snapshot()
+
+	taskResult := &relaycommon.TaskInfo{}
+	// try parse as New API response format
+	var responseItems dto.TaskResponse[model.Task]
+	if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
+		logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
+		t := responseItems.Data
+		taskResult.TaskID = t.TaskID
+		taskResult.Status = string(t.Status)
+		taskResult.Url = t.GetResultURL()
+		taskResult.Progress = t.Progress
+		taskResult.Reason = t.FailReason
+		task.Data = t.Data
+	} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
+		return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+	} else {
+		task.Data = redactVideoResponseBody(responseBody)
+	}
+
+	logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
+
+	now := time.Now().Unix()
+	if taskResult.Status == "" {
+		taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
+	}
+
+	shouldRefund := false
+	shouldSettle := false
+	quota := task.Quota
+
+	task.Status = model.TaskStatus(taskResult.Status)
+	switch taskResult.Status {
+	case model.TaskStatusSubmitted:
+		task.Progress = taskcommon.ProgressSubmitted
+	case model.TaskStatusQueued:
+		task.Progress = taskcommon.ProgressQueued
+	case model.TaskStatusInProgress:
+		task.Progress = taskcommon.ProgressInProgress
+		if task.StartTime == 0 {
+			task.StartTime = now
+		}
+	case model.TaskStatusSuccess:
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		if strings.HasPrefix(taskResult.Url, "data:") {
+			// data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
+		} else if taskResult.Url != "" {
+			// Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
+			task.PrivateData.ResultURL = taskResult.Url
+		} else {
+			// No URL from adaptor — construct proxy URL using public task ID
+			task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
+		}
+		shouldSettle = true
+	case model.TaskStatusFailure:
+		logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
+		task.Status = model.TaskStatusFailure
+		task.Progress = taskcommon.ProgressComplete
+		if task.FinishTime == 0 {
+			task.FinishTime = now
+		}
+		task.FailReason = taskResult.Reason
+		logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+		taskResult.Progress = taskcommon.ProgressComplete
+		if quota != 0 {
+			shouldRefund = true
+		}
+	default:
+		return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
+	}
+	if taskResult.Progress != "" {
+		task.Progress = taskResult.Progress
+	}
+
+	isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
+	if isDone && snap.Status != task.Status {
+		won, err := task.UpdateWithStatus(snap.Status)
+		if err != nil {
+			logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
+			shouldRefund = false
+			shouldSettle = false
+		} else if !won {
+			logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
+			shouldRefund = false
+			shouldSettle = false
+		}
+	} else if !snap.Equal(task.Snapshot()) {
+		if _, err := task.UpdateWithStatus(snap.Status); err != nil {
+			logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
+		}
+	} else {
+		// No changes, skip update
+		logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
+	}
+
+	if shouldSettle {
+		settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+	}
+	if shouldRefund {
+		RefundTaskQuota(ctx, task, task.FailReason)
+	}
+
+	return nil
+}
+
+func redactVideoResponseBody(body []byte) []byte {
+	var m map[string]any
+	if err := common.Unmarshal(body, &m); err != nil {
+		return body
+	}
+	resp, _ := m["response"].(map[string]any)
+	if resp != nil {
+		delete(resp, "bytesBase64Encoded")
+		if v, ok := resp["video"].(string); ok {
+			resp["video"] = truncateBase64(v)
+		}
+		if vs, ok := resp["videos"].([]any); ok {
+			for i := range vs {
+				if vm, ok := vs[i].(map[string]any); ok {
+					delete(vm, "bytesBase64Encoded")
+				}
+			}
+		}
+	}
+	b, err := common.Marshal(m)
+	if err != nil {
+		return body
+	}
+	return b
+}
+
+func truncateBase64(s string) string {
+	const maxKeep = 256
+	if len(s) <= maxKeep {
+		return s
+	}
+	return s[:maxKeep] + "..."
+}
+
+// settleTaskBillingOnComplete 任务完成时的统一计费调整。
+// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
+//
+//  2. taskResult.TotalTokens > 0 → 按 token 重算
+//  3. 都不满足 → 保持预扣额度不变
+func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+	// 0. 按次计费的任务不做差额结算
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
+		return
+	}
+	// 1. 优先让 adaptor 决定最终额度
+	if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
+		RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
+		return
+	}
+	// 2. 回退到 token 重算
+	if taskResult.TotalTokens > 0 {
+		RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
+		return
+	}
+	// 3. 无调整,保持预扣额度
+}

+ 5 - 4
service/violation_fee.go

@@ -18,8 +18,9 @@ import (
 )
 
 const (
-	ViolationFeeCodePrefix = "violation_fee."
-	CSAMViolationMarker    = "Failed check: SAFETY_CHECK_TYPE_CSAM"
+	ViolationFeeCodePrefix     = "violation_fee."
+	CSAMViolationMarker        = "Failed check: SAFETY_CHECK_TYPE"
+	ContentViolatesUsageMarker = "Content violates usage guidelines"
 )
 
 func IsViolationFeeCode(code types.ErrorCode) bool {
@@ -30,11 +31,11 @@ func HasCSAMViolationMarker(err *types.NewAPIError) bool {
 	if err == nil {
 		return false
 	}
-	if strings.Contains(err.Error(), CSAMViolationMarker) {
+	if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) {
 		return true
 	}
 	msg := err.ToOpenAIError().Message
-	return strings.Contains(msg, CSAMViolationMarker)
+	return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker)
 }
 
 func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError {

+ 13 - 0
setting/operation_setting/status_code_ranges.go

@@ -26,6 +26,11 @@ var AutomaticRetryStatusCodeRanges = []StatusCodeRange{
 	{Start: 525, End: 599},
 }
 
+var alwaysSkipRetryStatusCodes = map[int]struct{}{
+	504: {},
+	524: {},
+}
+
 func AutomaticDisableStatusCodesToString() string {
 	return statusCodeRangesToString(AutomaticDisableStatusCodeRanges)
 }
@@ -56,7 +61,15 @@ func AutomaticRetryStatusCodesFromString(s string) error {
 	return nil
 }
 
+func IsAlwaysSkipRetryStatusCode(code int) bool {
+	_, exists := alwaysSkipRetryStatusCodes[code]
+	return exists
+}
+
 func ShouldRetryByStatusCode(code int) bool {
+	if IsAlwaysSkipRetryStatusCode(code) {
+		return false
+	}
 	return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code)
 }
 

+ 8 - 0
setting/operation_setting/status_code_ranges_test.go

@@ -62,6 +62,8 @@ func TestShouldRetryByStatusCode(t *testing.T) {
 
 	require.True(t, ShouldRetryByStatusCode(429))
 	require.True(t, ShouldRetryByStatusCode(500))
+	require.False(t, ShouldRetryByStatusCode(504))
+	require.False(t, ShouldRetryByStatusCode(524))
 	require.False(t, ShouldRetryByStatusCode(400))
 	require.False(t, ShouldRetryByStatusCode(200))
 }
@@ -77,3 +79,9 @@ func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) {
 	require.False(t, ShouldRetryByStatusCode(524))
 	require.True(t, ShouldRetryByStatusCode(599))
 }
+
+func TestIsAlwaysSkipRetryStatusCode(t *testing.T) {
+	require.True(t, IsAlwaysSkipRetryStatusCode(504))
+	require.True(t, IsAlwaysSkipRetryStatusCode(524))
+	require.False(t, IsAlwaysSkipRetryStatusCode(500))
+}

+ 2 - 7
types/price_data.go

@@ -22,7 +22,8 @@ type PriceData struct {
 	AudioCompletionRatio float64
 	OtherRatios          map[string]float64
 	UsePrice             bool
-	QuotaToPreConsume    int // 预消耗额度
+	Quota                int // 按次计费的最终额度(MJ / Task)
+	QuotaToPreConsume    int // 按量计费的预消耗额度
 	GroupRatioInfo       GroupRatioInfo
 }
 
@@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) {
 	p.OtherRatios[key] = ratio
 }
 
-type PerCallPriceData struct {
-	ModelPrice     float64
-	Quota          int
-	GroupRatioInfo GroupRatioInfo
-}
-
 func (p *PriceData) ToSetting() string {
 	return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
 }

+ 15 - 15
web/src/components/auth/LoginForm.jsx

@@ -29,6 +29,7 @@ import {
   showSuccess,
   updateAPI,
   getSystemName,
+  getOAuthProviderIcon,
   setUserData,
   onGitHubOAuthClicked,
   onDiscordOAuthClicked,
@@ -130,6 +131,17 @@ const LoginForm = () => {
       return {};
     }
   }, [statusState?.status]);
+  const hasCustomOAuthProviders =
+    (status.custom_oauth_providers || []).length > 0;
+  const hasOAuthLoginOptions = Boolean(
+    status.github_oauth ||
+      status.discord_oauth ||
+      status.oidc_enabled ||
+      status.wechat_login ||
+      status.linuxdo_oauth ||
+      status.telegram_oauth ||
+      hasCustomOAuthProviders,
+  );
 
   useEffect(() => {
     if (status?.turnstile_check) {
@@ -598,7 +610,7 @@ const LoginForm = () => {
                       theme='outline'
                       className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
                       type='tertiary'
-                      icon={<IconLock size='large' />}
+                      icon={getOAuthProviderIcon(provider.icon || '', 20)}
                       onClick={() => handleCustomOAuthClick(provider)}
                       loading={customOAuthLoading[provider.slug]}
                     >
@@ -817,12 +829,7 @@ const LoginForm = () => {
                 </div>
               </Form>
 
-              {(status.github_oauth ||
-                status.discord_oauth ||
-                status.oidc_enabled ||
-                status.wechat_login ||
-                status.linuxdo_oauth ||
-                status.telegram_oauth) && (
+              {hasOAuthLoginOptions && (
                 <>
                   <Divider margin='12px' align='center'>
                     {t('或')}
@@ -952,14 +959,7 @@ const LoginForm = () => {
       />
       <div className='w-full max-w-sm mt-[60px]'>
         {showEmailLogin ||
-        !(
-          status.github_oauth ||
-          status.discord_oauth ||
-          status.oidc_enabled ||
-          status.wechat_login ||
-          status.linuxdo_oauth ||
-          status.telegram_oauth
-        )
+        !hasOAuthLoginOptions
           ? renderEmailLoginForm()
           : renderOAuthOptions()}
         {renderWeChatLoginModal()}

+ 44 - 14
web/src/components/auth/RegisterForm.jsx

@@ -27,8 +27,10 @@ import {
   showSuccess,
   updateAPI,
   getSystemName,
+  getOAuthProviderIcon,
   setUserData,
   onDiscordOAuthClicked,
+  onCustomOAuthClicked,
 } from '../../helpers';
 import Turnstile from 'react-turnstile';
 import {
@@ -98,6 +100,7 @@ const RegisterForm = () => {
   const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] =
     useState(false);
   const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
+  const [customOAuthLoading, setCustomOAuthLoading] = useState({});
   const [disableButton, setDisableButton] = useState(false);
   const [countdown, setCountdown] = useState(30);
   const [agreedToTerms, setAgreedToTerms] = useState(false);
@@ -126,6 +129,17 @@ const RegisterForm = () => {
       return {};
     }
   }, [statusState?.status]);
+  const hasCustomOAuthProviders =
+    (status.custom_oauth_providers || []).length > 0;
+  const hasOAuthRegisterOptions = Boolean(
+    status.github_oauth ||
+      status.discord_oauth ||
+      status.oidc_enabled ||
+      status.wechat_login ||
+      status.linuxdo_oauth ||
+      status.telegram_oauth ||
+      hasCustomOAuthProviders,
+  );
 
   const [showEmailVerification, setShowEmailVerification] = useState(false);
 
@@ -319,6 +333,17 @@ const RegisterForm = () => {
     }
   };
 
+  const handleCustomOAuthClick = (provider) => {
+    setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true }));
+    try {
+      onCustomOAuthClicked(provider, { shouldLogout: true });
+    } finally {
+      setTimeout(() => {
+        setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false }));
+      }, 3000);
+    }
+  };
+
   const handleEmailRegisterClick = () => {
     setEmailRegisterLoading(true);
     setShowEmailRegister(true);
@@ -469,6 +494,23 @@ const RegisterForm = () => {
                   </Button>
                 )}
 
+                {status.custom_oauth_providers &&
+                  status.custom_oauth_providers.map((provider) => (
+                    <Button
+                      key={provider.slug}
+                      theme='outline'
+                      className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
+                      type='tertiary'
+                      icon={getOAuthProviderIcon(provider.icon || '', 20)}
+                      onClick={() => handleCustomOAuthClick(provider)}
+                      loading={customOAuthLoading[provider.slug]}
+                    >
+                      <span className='ml-3'>
+                        {t('使用 {{name}} 继续', { name: provider.name })}
+                      </span>
+                    </Button>
+                  ))}
+
                 {status.telegram_oauth && (
                   <div className='flex justify-center my-2'>
                     <TelegramLoginButton
@@ -650,12 +692,7 @@ const RegisterForm = () => {
                 </div>
               </Form>
 
-              {(status.github_oauth ||
-                status.discord_oauth ||
-                status.oidc_enabled ||
-                status.wechat_login ||
-                status.linuxdo_oauth ||
-                status.telegram_oauth) && (
+              {hasOAuthRegisterOptions && (
                 <>
                   <Divider margin='12px' align='center'>
                     {t('或')}
@@ -745,14 +782,7 @@ const RegisterForm = () => {
       />
       <div className='w-full max-w-sm mt-[60px]'>
         {showEmailRegister ||
-        !(
-          status.github_oauth ||
-          status.discord_oauth ||
-          status.oidc_enabled ||
-          status.wechat_login ||
-          status.linuxdo_oauth ||
-          status.telegram_oauth
-        )
+        !hasOAuthRegisterOptions
           ? renderEmailRegisterForm()
           : renderOAuthOptions()}
         {renderWeChatLoginModal()}

+ 214 - 0
web/src/components/common/modals/RiskAcknowledgementModal.jsx

@@ -0,0 +1,214 @@
+/*
+Copyright (C) 2025 QuantumNous
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as
+published by the Free Software Foundation, either version 3 of the
+License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+For commercial licensing, please contact support@quantumnous.com
+*/
+
+import React, { useCallback, useEffect, useMemo, useState } from 'react';
+import {
+  Modal,
+  Button,
+  Typography,
+  Checkbox,
+  Input,
+  Space,
+} from '@douyinfe/semi-ui';
+import { IconAlertTriangle } from '@douyinfe/semi-icons';
+import { useIsMobile } from '../../../hooks/common/useIsMobile';
+import MarkdownRenderer from '../markdown/MarkdownRenderer';
+
+const { Text } = Typography;
+
+const RiskMarkdownBlock = React.memo(function RiskMarkdownBlock({
+  markdownContent,
+}) {
+  if (!markdownContent) {
+    return null;
+  }
+
+  return (
+    <div
+      className='rounded-lg'
+      style={{
+        border: '1px solid var(--semi-color-warning-light-hover)',
+        padding: '12px',
+        contentVisibility: 'auto',
+      }}
+    >
+      <MarkdownRenderer content={markdownContent} />
+    </div>
+  );
+});
+
+const RiskAcknowledgementModal = React.memo(function RiskAcknowledgementModal({
+  visible,
+  title,
+  markdownContent = '',
+  detailTitle = '',
+  detailItems = [],
+  checklist = [],
+  inputPrompt = '',
+  requiredText = '',
+  inputPlaceholder = '',
+  mismatchText = '',
+  cancelText = '',
+  confirmText = '',
+  onCancel,
+  onConfirm,
+}) {
+  const isMobile = useIsMobile();
+  const [checkedItems, setCheckedItems] = useState([]);
+  const [typedText, setTypedText] = useState('');
+
+  useEffect(() => {
+    if (!visible) return;
+    setCheckedItems(Array(checklist.length).fill(false));
+    setTypedText('');
+  }, [visible, checklist.length]);
+
+  const allChecked = useMemo(() => {
+    if (checklist.length === 0) return true;
+    return checkedItems.length === checklist.length && checkedItems.every(Boolean);
+  }, [checkedItems, checklist.length]);
+
+  const typedMatched = useMemo(() => {
+    if (!requiredText) return true;
+    return typedText.trim() === requiredText.trim();
+  }, [typedText, requiredText]);
+
+  const detailText = useMemo(() => detailItems.join(', '), [detailItems]);
+  const canConfirm = allChecked && typedMatched;
+
+  const handleChecklistChange = useCallback((index, checked) => {
+    setCheckedItems((previous) => {
+      const next = [...previous];
+      next[index] = checked;
+      return next;
+    });
+  }, []);
+
+  return (
+    <Modal
+      visible={visible}
+      title={
+        <Space align='center'>
+          <IconAlertTriangle style={{ color: 'var(--semi-color-warning)' }} />
+          <span>{title}</span>
+        </Space>
+      }
+      width={isMobile ? '100%' : 860}
+      centered
+      maskClosable={false}
+      closeOnEsc={false}
+      onCancel={onCancel}
+      bodyStyle={{
+        maxHeight: isMobile ? '70vh' : '72vh',
+        overflowY: 'auto',
+        padding: isMobile ? '12px 16px' : '18px 22px',
+      }}
+      footer={
+        <Space>
+          <Button onClick={onCancel}>{cancelText}</Button>
+          <Button
+            theme='solid'
+            type='danger'
+            disabled={!canConfirm}
+            onClick={onConfirm}
+          >
+            {confirmText}
+          </Button>
+        </Space>
+      }
+    >
+      <div className='flex flex-col gap-4'>
+
+        <RiskMarkdownBlock markdownContent={markdownContent} />
+
+        {detailItems.length > 0 ? (
+          <div
+            className='flex flex-col gap-2 rounded-lg'
+            style={{
+              border: '1px solid var(--semi-color-warning-light-hover)',
+              background: 'var(--semi-color-fill-0)',
+              padding: isMobile ? '10px 12px' : '12px 14px',
+            }}
+          >
+            {detailTitle ? <Text strong>{detailTitle}</Text> : null}
+            <div className='font-mono text-xs break-all bg-orange-50 border border-orange-200 rounded-md p-2'>
+              {detailText}
+            </div>
+          </div>
+        ) : null}
+
+        {checklist.length > 0 ? (
+          <div
+            className='flex flex-col gap-2 rounded-lg'
+            style={{
+              border: '1px solid var(--semi-color-border)',
+              background: 'var(--semi-color-fill-0)',
+              padding: isMobile ? '10px 12px' : '12px 14px',
+            }}
+          >
+            {checklist.map((item, index) => (
+              <Checkbox
+                key={`risk-check-${index}`}
+                checked={!!checkedItems[index]}
+                onChange={(event) => {
+                  handleChecklistChange(index, event.target.checked);
+                }}
+              >
+                {item}
+              </Checkbox>
+            ))}
+          </div>
+        ) : null}
+
+        {requiredText ? (
+          <div
+            className='flex flex-col gap-2 rounded-lg'
+            style={{
+              border: '1px solid var(--semi-color-danger-light-hover)',
+              background: 'var(--semi-color-danger-light-default)',
+              padding: isMobile ? '10px 12px' : '12px 14px',
+            }}
+          >
+            {inputPrompt ? <Text strong>{inputPrompt}</Text> : null}
+            <div className='font-mono text-xs break-all rounded-md p-2 bg-gray-50 border border-gray-200'>
+              {requiredText}
+            </div>
+            <Input
+              value={typedText}
+              onChange={setTypedText}
+              placeholder={inputPlaceholder}
+              autoFocus={visible}
+              onCopy={(event) => event.preventDefault()}
+              onCut={(event) => event.preventDefault()}
+              onPaste={(event) => event.preventDefault()}
+              onDrop={(event) => event.preventDefault()}
+            />
+            {!typedMatched && typedText ? (
+              <Text type='danger' size='small'>
+                {mismatchText}
+              </Text>
+            ) : null}
+          </div>
+        ) : null}
+      </div>
+    </Modal>
+  );
+});
+
+export default RiskAcknowledgementModal;

+ 546 - 124
web/src/components/settings/CustomOAuthSetting.jsx

@@ -27,14 +27,20 @@ import {
   Modal,
   Banner,
   Card,
+  Collapse,
+  Switch,
   Table,
   Tag,
   Popconfirm,
   Space,
-  Select,
 } from '@douyinfe/semi-ui';
-import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons';
-import { API, showError, showSuccess } from '../../helpers';
+import {
+  IconPlus,
+  IconEdit,
+  IconDelete,
+  IconRefresh,
+} from '@douyinfe/semi-icons';
+import { API, showError, showSuccess, getOAuthProviderIcon } from '../../helpers';
 import { useTranslation } from 'react-i18next';
 
 const { Text } = Typography;
@@ -120,6 +126,69 @@ const OAUTH_PRESETS = {
   },
 };
 
+const OAUTH_PRESET_ICONS = {
+  'github-enterprise': 'github',
+  gitlab: 'gitlab',
+  gitea: 'gitea',
+  nextcloud: 'nextcloud',
+  keycloak: 'keycloak',
+  authentik: 'authentik',
+  ory: 'openid',
+};
+
+const getPresetIcon = (preset) => OAUTH_PRESET_ICONS[preset] || '';
+
+const PRESET_RESET_VALUES = {
+  name: '',
+  slug: '',
+  icon: '',
+  authorization_endpoint: '',
+  token_endpoint: '',
+  user_info_endpoint: '',
+  scopes: '',
+  user_id_field: '',
+  username_field: '',
+  display_name_field: '',
+  email_field: '',
+  well_known: '',
+  auth_style: 0,
+  access_policy: '',
+  access_denied_message: '',
+};
+
+const DISCOVERY_FIELD_LABELS = {
+  authorization_endpoint: 'Authorization Endpoint',
+  token_endpoint: 'Token Endpoint',
+  user_info_endpoint: 'User Info Endpoint',
+  scopes: 'Scopes',
+  user_id_field: 'User ID Field',
+  username_field: 'Username Field',
+  display_name_field: 'Display Name Field',
+  email_field: 'Email Field',
+};
+
+const ACCESS_POLICY_TEMPLATES = {
+  level_active: `{
+  "logic": "and",
+  "conditions": [
+    {"field": "trust_level", "op": "gte", "value": 2},
+    {"field": "active", "op": "eq", "value": true}
+  ]
+}`,
+  org_or_role: `{
+  "logic": "or",
+  "conditions": [
+    {"field": "org", "op": "eq", "value": "core"},
+    {"field": "roles", "op": "contains", "value": "admin"}
+  ]
+}`,
+};
+
+const ACCESS_DENIED_TEMPLATES = {
+  level_hint: '需要等级 {{required}},你当前等级 {{current}}(字段:{{field}})',
+  org_hint: '仅限指定组织或角色访问。组织={{current.org}},角色={{current.roles}}',
+};
+
 const CustomOAuthSetting = ({ serverAddress }) => {
   const { t } = useTranslation();
   const [providers, setProviders] = useState([]);
@@ -129,8 +198,47 @@ const CustomOAuthSetting = ({ serverAddress }) => {
   const [formValues, setFormValues] = useState({});
   const [selectedPreset, setSelectedPreset] = useState('');
   const [baseUrl, setBaseUrl] = useState('');
+  const [discoveryLoading, setDiscoveryLoading] = useState(false);
+  const [discoveryInfo, setDiscoveryInfo] = useState(null);
+  const [advancedActiveKeys, setAdvancedActiveKeys] = useState([]);
   const formApiRef = React.useRef(null);
 
+  const mergeFormValues = (newValues) => {
+    setFormValues((prev) => ({ ...prev, ...newValues }));
+    if (!formApiRef.current) return;
+    Object.entries(newValues).forEach(([key, value]) => {
+      formApiRef.current.setValue(key, value);
+    });
+  };
+
+  const getLatestFormValues = () => {
+    const values = formApiRef.current?.getValues?.();
+    return values && typeof values === 'object' ? values : formValues;
+  };
+
+  const normalizeBaseUrl = (url) => (url || '').trim().replace(/\/+$/, '');
+
+  const inferBaseUrlFromProvider = (provider) => {
+    const endpoint = provider?.authorization_endpoint || provider?.token_endpoint;
+    if (!endpoint) return '';
+    try {
+      const url = new URL(endpoint);
+      return `${url.protocol}//${url.host}`;
+    } catch (error) {
+      return '';
+    }
+  };
+
+  const resetDiscoveryState = () => {
+    setDiscoveryInfo(null);
+  };
+
+  const closeModal = () => {
+    setModalVisible(false);
+    resetDiscoveryState();
+    setAdvancedActiveKeys([]);
+  };
+
   const fetchProviders = async () => {
     setLoading(true);
     try {
@@ -154,23 +262,30 @@ const CustomOAuthSetting = ({ serverAddress }) => {
     setEditingProvider(null);
     setFormValues({
       enabled: false,
+      icon: '',
       scopes: 'openid profile email',
       user_id_field: 'sub',
       username_field: 'preferred_username',
       display_name_field: 'name',
       email_field: 'email',
       auth_style: 0,
+      access_policy: '',
+      access_denied_message: '',
     });
     setSelectedPreset('');
     setBaseUrl('');
+    resetDiscoveryState();
+    setAdvancedActiveKeys([]);
     setModalVisible(true);
   };
 
   const handleEdit = (provider) => {
     setEditingProvider(provider);
     setFormValues({ ...provider });
-    setSelectedPreset('');
-    setBaseUrl('');
+    setSelectedPreset(OAUTH_PRESETS[provider.slug] ? provider.slug : '');
+    setBaseUrl(inferBaseUrlFromProvider(provider));
+    resetDiscoveryState();
+    setAdvancedActiveKeys([]);
     setModalVisible(true);
   };
 
@@ -189,6 +304,8 @@ const CustomOAuthSetting = ({ serverAddress }) => {
   };
 
   const handleSubmit = async () => {
+    const currentValues = getLatestFormValues();
+
     // Validate required fields
     const requiredFields = [
       'name',
@@ -204,7 +321,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
     }
 
     for (const field of requiredFields) {
-      if (!formValues[field]) {
+      if (!currentValues[field]) {
         showError(t(`请填写 ${field}`));
         return;
       }
@@ -213,11 +330,11 @@ const CustomOAuthSetting = ({ serverAddress }) => {
     // Validate endpoint URLs must be full URLs
     const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint'];
     for (const field of endpointFields) {
-      const value = formValues[field];
+      const value = currentValues[field];
       if (value && !value.startsWith('http://') && !value.startsWith('https://')) {
-        // Check if user selected a preset but forgot to fill server address
+        // Check if user selected a preset but forgot to fill issuer URL
         if (selectedPreset && !baseUrl) {
-          showError(t('请先填写服务器地址,以自动生成完整的端点 URL'));
+          showError(t('请先填写 Issuer URL,以自动生成完整的端点 URL'));
         } else {
           showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)'));
         }
@@ -226,80 +343,199 @@ const CustomOAuthSetting = ({ serverAddress }) => {
     }
 
     try {
+      const payload = { ...currentValues, enabled: !!currentValues.enabled };
+      delete payload.preset;
+      delete payload.base_url;
+
       let res;
       if (editingProvider) {
         res = await API.put(
           `/api/custom-oauth-provider/${editingProvider.id}`,
-          formValues
+          payload
         );
       } else {
-        res = await API.post('/api/custom-oauth-provider/', formValues);
+        res = await API.post('/api/custom-oauth-provider/', payload);
       }
 
       if (res.data.success) {
         showSuccess(editingProvider ? t('更新成功') : t('创建成功'));
-        setModalVisible(false);
+        closeModal();
         fetchProviders();
       } else {
         showError(res.data.message);
       }
     } catch (error) {
-      showError(editingProvider ? t('更新失败') : t('创建失败'));
+      showError(
+        error?.response?.data?.message ||
+          (editingProvider ? t('更新失败') : t('创建失败')),
+      );
     }
   };
 
-  const handlePresetChange = (preset) => {
-    setSelectedPreset(preset);
-    if (preset && OAUTH_PRESETS[preset]) {
-      const presetConfig = OAUTH_PRESETS[preset];
-      const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : '';
-      const newValues = {
-        name: presetConfig.name,
-        slug: preset,
-        scopes: presetConfig.scopes,
-        user_id_field: presetConfig.user_id_field,
-        username_field: presetConfig.username_field,
-        display_name_field: presetConfig.display_name_field,
-        email_field: presetConfig.email_field,
-        auth_style: presetConfig.auth_style ?? 0,
+  const handleFetchFromDiscovery = async () => {
+    const cleanBaseUrl = normalizeBaseUrl(baseUrl);
+    const configuredWellKnown = (formValues.well_known || '').trim();
+    const wellKnownUrl =
+      configuredWellKnown ||
+      (cleanBaseUrl ? `${cleanBaseUrl}/.well-known/openid-configuration` : '');
+
+    if (!wellKnownUrl) {
+      showError(t('请先填写 Discovery URL 或 Issuer URL'));
+      return;
+    }
+
+    setDiscoveryLoading(true);
+    try {
+      const res = await API.post('/api/custom-oauth-provider/discovery', {
+        well_known_url: configuredWellKnown || '',
+        issuer_url: cleanBaseUrl || '',
+      });
+      if (!res.data.success) {
+        throw new Error(res.data.message || t('未知错误'));
+      }
+      const data = res.data.data?.discovery || {};
+      const resolvedWellKnown = res.data.data?.well_known_url || wellKnownUrl;
+
+      const discoveredValues = {
+        well_known: resolvedWellKnown,
       };
-      // Only fill endpoints if server address is provided
-      if (cleanUrl) {
-        newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint;
-        newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
-        newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
+      const autoFilledFields = [];
+      if (data.authorization_endpoint) {
+        discoveredValues.authorization_endpoint = data.authorization_endpoint;
+        autoFilledFields.push('authorization_endpoint');
       }
-      setFormValues((prev) => ({ ...prev, ...newValues }));
-      // Update form fields directly via formApi
-      if (formApiRef.current) {
-        Object.entries(newValues).forEach(([key, value]) => {
-          formApiRef.current.setValue(key, value);
-        });
+      if (data.token_endpoint) {
+        discoveredValues.token_endpoint = data.token_endpoint;
+        autoFilledFields.push('token_endpoint');
       }
+      if (data.userinfo_endpoint) {
+        discoveredValues.user_info_endpoint = data.userinfo_endpoint;
+        autoFilledFields.push('user_info_endpoint');
+      }
+
+      const scopesSupported = Array.isArray(data.scopes_supported)
+        ? data.scopes_supported
+        : [];
+      if (scopesSupported.length > 0 && !formValues.scopes) {
+        const preferredScopes = ['openid', 'profile', 'email'].filter((scope) =>
+          scopesSupported.includes(scope),
+        );
+        discoveredValues.scopes =
+          preferredScopes.length > 0
+            ? preferredScopes.join(' ')
+            : scopesSupported.slice(0, 5).join(' ');
+        autoFilledFields.push('scopes');
+      }
+
+      const claimsSupported = Array.isArray(data.claims_supported)
+        ? data.claims_supported
+        : [];
+      const claimMap = {
+        user_id_field: 'sub',
+        username_field: 'preferred_username',
+        display_name_field: 'name',
+        email_field: 'email',
+      };
+      Object.entries(claimMap).forEach(([field, claim]) => {
+        if (!formValues[field] && claimsSupported.includes(claim)) {
+          discoveredValues[field] = claim;
+          autoFilledFields.push(field);
+        }
+      });
+
+      const hasCoreEndpoint =
+        discoveredValues.authorization_endpoint ||
+        discoveredValues.token_endpoint ||
+        discoveredValues.user_info_endpoint;
+      if (!hasCoreEndpoint) {
+        showError(t('未在 Discovery 响应中找到可用的 OAuth 端点'));
+        return;
+      }
+
+      mergeFormValues(discoveredValues);
+      setDiscoveryInfo({
+        wellKnown: wellKnownUrl,
+        autoFilledFields,
+        scopesSupported: scopesSupported.slice(0, 12),
+        claimsSupported: claimsSupported.slice(0, 12),
+      });
+      showSuccess(t('已从 Discovery 自动填充配置'));
+    } catch (error) {
+      showError(
+        t('获取 Discovery 配置失败:') + (error?.message || t('未知错误')),
+      );
+    } finally {
+      setDiscoveryLoading(false);
+    }
+  };
+
+  const handlePresetChange = (preset) => {
+    setSelectedPreset(preset);
+    resetDiscoveryState();
+    const cleanUrl = normalizeBaseUrl(baseUrl);
+    if (!preset || !OAUTH_PRESETS[preset]) {
+      mergeFormValues(PRESET_RESET_VALUES);
+      return;
+    }
+
+    const presetConfig = OAUTH_PRESETS[preset];
+    const newValues = {
+      ...PRESET_RESET_VALUES,
+      name: presetConfig.name,
+      slug: preset,
+      icon: getPresetIcon(preset),
+      scopes: presetConfig.scopes,
+      user_id_field: presetConfig.user_id_field,
+      username_field: presetConfig.username_field,
+      display_name_field: presetConfig.display_name_field,
+      email_field: presetConfig.email_field,
+      auth_style: presetConfig.auth_style ?? 0,
+    };
+    if (cleanUrl) {
+      newValues.authorization_endpoint =
+        cleanUrl + presetConfig.authorization_endpoint;
+      newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
+      newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
     }
+    mergeFormValues(newValues);
   };
 
   const handleBaseUrlChange = (url) => {
     setBaseUrl(url);
     if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) {
       const presetConfig = OAUTH_PRESETS[selectedPreset];
-      const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes
+      const cleanUrl = normalizeBaseUrl(url);
       const newValues = {
         authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint,
         token_endpoint: cleanUrl + presetConfig.token_endpoint,
         user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint,
       };
-      setFormValues((prev) => ({ ...prev, ...newValues }));
-      // Update form fields directly via formApi (use merge mode to preserve other fields)
-      if (formApiRef.current) {
-        Object.entries(newValues).forEach(([key, value]) => {
-          formApiRef.current.setValue(key, value);
-        });
-      }
+      mergeFormValues(newValues);
     }
   };
 
+  const applyAccessPolicyTemplate = (templateKey) => {
+    const template = ACCESS_POLICY_TEMPLATES[templateKey];
+    if (!template) return;
+    mergeFormValues({ access_policy: template });
+    showSuccess(t('已填充策略模板'));
+  };
+
+  const applyDeniedTemplate = (templateKey) => {
+    const template = ACCESS_DENIED_TEMPLATES[templateKey];
+    if (!template) return;
+    mergeFormValues({ access_denied_message: template });
+    showSuccess(t('已填充提示模板'));
+  };
+
   const columns = [
+    {
+      title: t('图标'),
+      dataIndex: 'icon',
+      key: 'icon',
+      width: 80,
+      render: (icon) => getOAuthProviderIcon(icon || '', 18),
+    },
     {
       title: t('名称'),
       dataIndex: 'name',
@@ -325,7 +561,10 @@ const CustomOAuthSetting = ({ serverAddress }) => {
       title: t('Client ID'),
       dataIndex: 'client_id',
       key: 'client_id',
-      render: (id) => (id ? id.substring(0, 20) + '...' : '-'),
+      render: (id) => {
+        if (!id) return '-';
+        return id.length > 20 ? `${id.substring(0, 20)}...` : id;
+      },
     },
     {
       title: t('操作'),
@@ -352,6 +591,10 @@ const CustomOAuthSetting = ({ serverAddress }) => {
     },
   ];
 
+  const discoveryAutoFilledLabels = (discoveryInfo?.autoFilledFields || [])
+    .map((field) => DISCOVERY_FIELD_LABELS[field] || field)
+    .join(', ');
+
   return (
     <Card>
       <Form.Section text={t('自定义 OAuth 提供商')}>
@@ -391,56 +634,142 @@ const CustomOAuthSetting = ({ serverAddress }) => {
         <Modal
           title={editingProvider ? t('编辑 OAuth 提供商') : t('添加 OAuth 提供商')}
           visible={modalVisible}
-          onOk={handleSubmit}
-          onCancel={() => setModalVisible(false)}
-          okText={t('保存')}
-          cancelText={t('取消')}
-          width={800}
+          onCancel={closeModal}
+          width={860}
+          centered
+          bodyStyle={{ maxHeight: '72vh', overflowY: 'auto', paddingRight: 6 }}
+          footer={
+            <div
+              style={{
+                display: 'flex',
+                justifyContent: 'flex-end',
+                alignItems: 'center',
+                gap: 12,
+                flexWrap: 'wrap',
+              }}
+            >
+              <Space spacing={8} align='center'>
+                <Text type='secondary'>{t('启用供应商')}</Text>
+                <Switch
+                  checked={!!formValues.enabled}
+                  size='large'
+                  onChange={(checked) => mergeFormValues({ enabled: !!checked })}
+                />
+                <Tag color={formValues.enabled ? 'green' : 'grey'}>
+                  {formValues.enabled ? t('已启用') : t('已禁用')}
+                </Tag>
+              </Space>
+              <Button onClick={closeModal}>{t('取消')}</Button>
+              <Button type='primary' onClick={handleSubmit}>
+                {t('保存')}
+              </Button>
+            </div>
+          }
         >
           <Form
             initValues={formValues}
-            onValueChange={(values) => setFormValues(values)}
+            onValueChange={() => {
+              setFormValues((prev) => ({ ...prev, ...getLatestFormValues() }));
+            }}
             getFormApi={(api) => (formApiRef.current = api)}
           >
-            {!editingProvider && (
-              <Row gutter={16} style={{ marginBottom: 16 }}>
-                <Col span={12}>
-                  <Form.Select
-                    field="preset"
-                    label={t('预设模板')}
-                    placeholder={t('选择预设模板(可选)')}
-                    value={selectedPreset}
-                    onChange={handlePresetChange}
-                    optionList={[
-                      { value: '', label: t('自定义') },
-                      ...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
-                        value: key,
-                        label: config.name,
-                      })),
-                    ]}
-                  />
-                </Col>
-                <Col span={12}>
-                  <Form.Input
-                    field="base_url"
-                    label={
-                      selectedPreset
-                        ? t('服务器地址') + ' *'
-                        : t('服务器地址')
-                    }
-                    placeholder={t('例如:https://gitea.example.com')}
-                    value={baseUrl}
-                    onChange={handleBaseUrlChange}
-                    extraText={
-                      selectedPreset
-                        ? t('必填:请输入服务器地址以自动生成完整端点 URL')
-                        : t('选择预设模板后填写服务器地址可自动填充端点')
-                    }
-                  />
-                </Col>
-              </Row>
+            <Text strong style={{ display: 'block', marginBottom: 8 }}>
+              {t('Configuration')}
+            </Text>
+            <Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
+              {t('先填写配置,再自动填充 OAuth 端点,能显著减少手工输入')}
+            </Text>
+            {discoveryInfo && (
+              <Banner
+                type='success'
+                closeIcon={null}
+                style={{ marginBottom: 12 }}
+                description={
+                  <div>
+                    <div>
+                      {t('已从 Discovery 获取配置,可继续手动修改所有字段。')}
+                    </div>
+                    {discoveryAutoFilledLabels ? (
+                      <div>
+                        {t('自动填充字段')}:
+                        {' '}
+                        {discoveryAutoFilledLabels}
+                      </div>
+                    ) : null}
+                    {discoveryInfo.scopesSupported?.length ? (
+                      <div>
+                        {t('Discovery scopes')}:
+                        {' '}
+                        {discoveryInfo.scopesSupported.join(', ')}
+                      </div>
+                    ) : null}
+                    {discoveryInfo.claimsSupported?.length ? (
+                      <div>
+                        {t('Discovery claims')}:
+                        {' '}
+                        {discoveryInfo.claimsSupported.join(', ')}
+                      </div>
+                    ) : null}
+                  </div>
+                }
+              />
             )}
 
+            <Row gutter={16}>
+              <Col span={8}>
+                <Form.Select
+                  field="preset"
+                  label={t('预设模板')}
+                  placeholder={t('选择预设模板(可选)')}
+                  value={selectedPreset}
+                  onChange={handlePresetChange}
+                  optionList={[
+                    { value: '', label: t('自定义') },
+                    ...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
+                      value: key,
+                      label: config.name,
+                    })),
+                  ]}
+                />
+              </Col>
+              <Col span={10}>
+                <Form.Input
+                  field="base_url"
+                  label={t('发行者 URL(Issuer URL)')}
+                  placeholder={t('例如:https://gitea.example.com')}
+                  value={baseUrl}
+                  onChange={handleBaseUrlChange}
+                  extraText={
+                    selectedPreset
+                      ? t('填写后会自动拼接预设端点')
+                      : t('可选:用于自动生成端点或 Discovery URL')
+                  }
+                />
+              </Col>
+              <Col span={6}>
+                <div style={{ display: 'flex', alignItems: 'flex-end', height: '100%' }}>
+                  <Button
+                    icon={<IconRefresh />}
+                    onClick={handleFetchFromDiscovery}
+                    loading={discoveryLoading}
+                    block
+                  >
+                    {t('获取 Discovery 配置')}
+                  </Button>
+                </div>
+              </Col>
+            </Row>
+            <Row gutter={16}>
+              <Col span={24}>
+                <Form.Input
+                  field="well_known"
+                  label={t('发现文档地址(Discovery URL,可选)')}
+                  placeholder={t('例如:https://example.com/.well-known/openid-configuration')}
+                  extraText={t('可留空;留空时会尝试使用 Issuer URL + /.well-known/openid-configuration')}
+                />
+              </Col>
+            </Row>
+
             <Row gutter={16}>
               <Col span={12}>
                 <Form.Input
@@ -461,6 +790,41 @@ const CustomOAuthSetting = ({ serverAddress }) => {
               </Col>
             </Row>
 
+            <Row gutter={16}>
+              <Col span={18}>
+                <Form.Input
+                  field='icon'
+                  label={t('图标')}
+                  placeholder={t('例如:github / si:google / https://example.com/logo.png / 🐱')}
+                  extraText={
+                    <span>
+                      {t(
+                        '图标使用 react-icons(Simple Icons)或 URL/emoji,例如:github、gitlab、si:google',
+                      )}
+                    </span>
+                  }
+                  showClear
+                />
+              </Col>
+              <Col span={6} style={{ display: 'flex', alignItems: 'flex-end' }}>
+                <div
+                  style={{
+                    width: '100%',
+                    minHeight: 74,
+                    border: '1px solid var(--semi-color-border)',
+                    borderRadius: 8,
+                    display: 'flex',
+                    alignItems: 'center',
+                    justifyContent: 'center',
+                    marginBottom: 24,
+                    background: 'var(--semi-color-fill-0)',
+                  }}
+                >
+                  {getOAuthProviderIcon(formValues.icon || '', 24)}
+                </div>
+              </Col>
+            </Row>
+
             <Row gutter={16}>
               <Col span={12}>
                 <Form.Input
@@ -500,7 +864,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
                   label={t('Authorization Endpoint')}
                   placeholder={
                     selectedPreset && OAUTH_PRESETS[selectedPreset]
-                      ? t('填写服务器地址后自动生成:') +
+                      ? t('填写 Issuer URL 后自动生成:') +
                         OAUTH_PRESETS[selectedPreset].authorization_endpoint
                       : 'https://example.com/oauth/authorize'
                   }
@@ -544,15 +908,14 @@ const CustomOAuthSetting = ({ serverAddress }) => {
               <Col span={12}>
                 <Form.Input
                   field="scopes"
-                  label={t('Scopes')}
+                  label={t('Scopes(可选)')}
                   placeholder="openid profile email"
-                />
-              </Col>
-              <Col span={12}>
-                <Form.Input
-                  field="well_known"
-                  label={t('Well-Known URL')}
-                  placeholder={t('OIDC Discovery 端点(可选)')}
+                  extraText={
+                    discoveryInfo?.scopesSupported?.length
+                      ? t('Discovery 建议 scopes:') +
+                        discoveryInfo.scopesSupported.join(', ')
+                      : t('可手动填写,多个 scope 用空格分隔')
+                  }
                 />
               </Col>
             </Row>
@@ -568,7 +931,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
               <Col span={12}>
                 <Form.Input
                   field="user_id_field"
-                  label={t('用户 ID 字段')}
+                  label={t('用户 ID 字段(可选)')}
                   placeholder={t('例如:sub、id、data.user.id')}
                   extraText={t('用于唯一标识用户的字段路径')}
                 />
@@ -576,7 +939,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
               <Col span={12}>
                 <Form.Input
                   field="username_field"
-                  label={t('用户名字段')}
+                  label={t('用户名字段(可选)')}
                   placeholder={t('例如:preferred_username、login')}
                 />
               </Col>
@@ -586,41 +949,100 @@ const CustomOAuthSetting = ({ serverAddress }) => {
               <Col span={12}>
                 <Form.Input
                   field="display_name_field"
-                  label={t('显示名称字段')}
+                  label={t('显示名称字段(可选)')}
                   placeholder={t('例如:name、full_name')}
                 />
               </Col>
               <Col span={12}>
                 <Form.Input
                   field="email_field"
-                  label={t('邮箱字段')}
+                  label={t('邮箱字段(可选)')}
                   placeholder={t('例如:email')}
                 />
               </Col>
             </Row>
 
-            <Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
-              {t('高级选项')}
-            </Text>
+            <Collapse
+              keepDOM
+              activeKey={advancedActiveKeys}
+              style={{ marginTop: 16 }}
+              onChange={(activeKey) => {
+                const keys = Array.isArray(activeKey) ? activeKey : [activeKey];
+                setAdvancedActiveKeys(keys.filter(Boolean));
+              }}
+            >
+              <Collapse.Panel header={t('高级选项')} itemKey='advanced'>
+                <Row gutter={16}>
+                  <Col span={12}>
+                    <Form.Select
+                      field="auth_style"
+                      label={t('认证方式')}
+                      optionList={[
+                        { value: 0, label: t('自动检测') },
+                        { value: 1, label: t('POST 参数') },
+                        { value: 2, label: t('Basic Auth 头') },
+                      ]}
+                    />
+                  </Col>
+                </Row>
 
-            <Row gutter={16}>
-              <Col span={12}>
-                <Form.Select
-                  field="auth_style"
-                  label={t('认证方式')}
-                  optionList={[
-                    { value: 0, label: t('自动检测') },
-                    { value: 1, label: t('POST 参数') },
-                    { value: 2, label: t('Basic Auth 头') },
-                  ]}
-                />
-              </Col>
-              <Col span={12}>
-                <Form.Checkbox field="enabled" noLabel>
-                  {t('启用此 OAuth 提供商')}
-                </Form.Checkbox>
-              </Col>
-            </Row>
+                <Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
+                  {t('准入策略')}
+                </Text>
+                <Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
+                  {t('可选:基于用户信息 JSON 做组合条件准入,条件不满足时返回自定义提示')}
+                </Text>
+                <Row gutter={16}>
+                  <Col span={24}>
+                    <Form.TextArea
+                      field='access_policy'
+                      value={formValues.access_policy || ''}
+                      onChange={(value) => mergeFormValues({ access_policy: value })}
+                      label={t('准入策略 JSON(可选)')}
+                      rows={6}
+                      placeholder={`{
+  "logic": "and",
+  "conditions": [
+    {"field": "trust_level", "op": "gte", "value": 2},
+    {"field": "active", "op": "eq", "value": true}
+  ]
+}`}
+                      extraText={t('支持逻辑 and/or 与嵌套 groups;操作符支持 eq/ne/gt/gte/lt/lte/in/not_in/contains/exists')}
+                      showClear
+                    />
+                    <Space spacing={8} style={{ marginTop: 8 }}>
+                      <Button size='small' theme='light' onClick={() => applyAccessPolicyTemplate('level_active')}>
+                        {t('填充模板:等级+激活')}
+                      </Button>
+                      <Button size='small' theme='light' onClick={() => applyAccessPolicyTemplate('org_or_role')}>
+                        {t('填充模板:组织或角色')}
+                      </Button>
+                    </Space>
+                  </Col>
+                </Row>
+                <Row gutter={16}>
+                  <Col span={24}>
+                    <Form.Input
+                      field='access_denied_message'
+                      value={formValues.access_denied_message || ''}
+                      onChange={(value) => mergeFormValues({ access_denied_message: value })}
+                      label={t('拒绝提示模板(可选)')}
+                      placeholder={t('例如:需要等级 {{required}},你当前等级 {{current}}')}
+                      extraText={t('可用变量:{{provider}} {{field}} {{op}} {{required}} {{current}} 以及 {{current.path}}')}
+                      showClear
+                    />
+                    <Space spacing={8} style={{ marginTop: 8 }}>
+                      <Button size='small' theme='light' onClick={() => applyDeniedTemplate('level_hint')}>
+                        {t('填充模板:等级提示')}
+                      </Button>
+                      <Button size='small' theme='light' onClick={() => applyDeniedTemplate('org_hint')}>
+                        {t('填充模板:组织提示')}
+                      </Button>
+                    </Space>
+                  </Col>
+                </Row>
+              </Collapse.Panel>
+            </Collapse>
           </Form>
         </Modal>
       </Form.Section>

+ 9 - 6
web/src/components/settings/personal/cards/AccountManagement.jsx

@@ -50,6 +50,7 @@ import {
   onLinuxDOOAuthClicked,
   onDiscordOAuthClicked,
   onCustomOAuthClicked,
+  getOAuthProviderIcon,
 } from '../../../../helpers';
 import TwoFASetting from '../components/TwoFASetting';
 
@@ -148,12 +149,14 @@ const AccountManagement = ({
 
   // Check if custom OAuth provider is bound
   const isCustomOAuthBound = (providerId) => {
-    return customOAuthBindings.some((b) => b.provider_id === providerId);
+    const normalizedId = Number(providerId);
+    return customOAuthBindings.some((b) => Number(b.provider_id) === normalizedId);
   };
 
   // Get binding info for a provider
   const getCustomOAuthBinding = (providerId) => {
-    return customOAuthBindings.find((b) => b.provider_id === providerId);
+    const normalizedId = Number(providerId);
+    return customOAuthBindings.find((b) => Number(b.provider_id) === normalizedId);
   };
 
   React.useEffect(() => {
@@ -524,10 +527,10 @@ const AccountManagement = ({
                       <div className='flex items-center justify-between gap-3'>
                         <div className='flex items-center flex-1 min-w-0'>
                           <div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
-                            <IconLock
-                              size='default'
-                              className='text-slate-600 dark:text-slate-300'
-                            />
+                            {getOAuthProviderIcon(
+                              provider.icon || binding?.provider_icon || '',
+                              20,
+                            )}
                           </div>
                           <div className='flex-1 min-w-0'>
                             <div className='font-medium text-gray-900'>

+ 7 - 0
web/src/components/settings/personal/cards/NotificationSettings.jsx

@@ -86,6 +86,7 @@ const NotificationSettings = ({
       channel: true,
       models: true,
       deployment: true,
+      subscription: true,
       redemption: true,
       user: true,
       setting: true,
@@ -169,6 +170,7 @@ const NotificationSettings = ({
         channel: true,
         models: true,
         deployment: true,
+        subscription: true,
         redemption: true,
         user: true,
         setting: true,
@@ -296,6 +298,11 @@ const NotificationSettings = ({
           title: t('模型部署'),
           description: t('模型部署管理'),
         },
+        {
+          key: 'subscription',
+          title: t('订阅管理'),
+          description: t('订阅套餐管理'),
+        },
         {
           key: 'redemption',
           title: t('兑换码管理'),

+ 148 - 1
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -62,9 +62,14 @@ import CodexOAuthModal from './CodexOAuthModal';
 import ParamOverrideEditorModal from './ParamOverrideEditorModal';
 import JSONEditor from '../../../common/ui/JSONEditor';
 import SecureVerificationModal from '../../../common/modals/SecureVerificationModal';
+import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal';
 import ChannelKeyDisplay from '../../../common/ui/ChannelKeyDisplay';
 import { useSecureVerification } from '../../../../hooks/common/useSecureVerification';
 import { createApiCalls } from '../../../../services/secureVerification';
+import {
+  collectInvalidStatusCodeEntries,
+  collectNewDisallowedStatusCodeRedirects,
+} from './statusCodeRiskGuard';
 import {
   IconSave,
   IconClose,
@@ -195,6 +200,8 @@ const EditChannelModal = (props) => {
     allow_service_tier: false,
     disable_store: false, // false = 允许透传(默认开启)
     allow_safety_identifier: false,
+    allow_include_obfuscation: false,
+    allow_inference_geo: false,
     claude_beta_query: false,
   };
   const [batch, setBatch] = useState(false);
@@ -209,6 +216,7 @@ const EditChannelModal = (props) => {
   const [fullModels, setFullModels] = useState([]);
   const [modelGroups, setModelGroups] = useState([]);
   const [customModel, setCustomModel] = useState('');
+  const [modelSearchValue, setModelSearchValue] = useState('');
   const [modalImageUrl, setModalImageUrl] = useState('');
   const [isModalOpenurl, setIsModalOpenurl] = useState(false);
   const [modelModalVisible, setModelModalVisible] = useState(false);
@@ -249,6 +257,25 @@ const EditChannelModal = (props) => {
       return [];
     }
   }, [inputs.model_mapping]);
+  const modelSearchMatchedCount = useMemo(() => {
+    const keyword = modelSearchValue.trim();
+    if (!keyword) {
+      return modelOptions.length;
+    }
+    return modelOptions.reduce(
+      (count, option) => count + (selectFilter(keyword, option) ? 1 : 0),
+      0,
+    );
+  }, [modelOptions, modelSearchValue]);
+  const modelSearchHintText = useMemo(() => {
+    const keyword = modelSearchValue.trim();
+    if (!keyword || modelSearchMatchedCount !== 0) {
+      return '';
+    }
+    return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', {
+      name: keyword,
+    });
+  }, [modelSearchMatchedCount, modelSearchValue, t]);
   const paramOverrideMeta = useMemo(() => {
     const raw =
       typeof inputs.param_override === 'string'
@@ -338,6 +365,12 @@ const EditChannelModal = (props) => {
     window.open(targetUrl, '_blank', 'noopener');
   };
   const [verifyLoading, setVerifyLoading] = useState(false);
+  const statusCodeRiskConfirmResolverRef = useRef(null);
+  const [statusCodeRiskConfirmVisible, setStatusCodeRiskConfirmVisible] =
+    useState(false);
+  const [statusCodeRiskDetailItems, setStatusCodeRiskDetailItems] = useState(
+    [],
+  );
 
   // 表单块导航相关状态
   const formSectionRefs = useRef({
@@ -359,6 +392,7 @@ const EditChannelModal = (props) => {
   const doubaoApiClickCountRef = useRef(0);
   const initialModelsRef = useRef([]);
   const initialModelMappingRef = useRef('');
+  const initialStatusCodeMappingRef = useRef('');
 
   // 2FA状态更新辅助函数
   const updateTwoFAState = (updates) => {
@@ -811,6 +845,10 @@ const EditChannelModal = (props) => {
           data.disable_store = parsedSettings.disable_store || false;
           data.allow_safety_identifier =
             parsedSettings.allow_safety_identifier || false;
+          data.allow_include_obfuscation =
+            parsedSettings.allow_include_obfuscation || false;
+          data.allow_inference_geo =
+            parsedSettings.allow_inference_geo || false;
           data.claude_beta_query = parsedSettings.claude_beta_query || false;
         } catch (error) {
           console.error('解析其他设置失败:', error);
@@ -822,6 +860,8 @@ const EditChannelModal = (props) => {
           data.allow_service_tier = false;
           data.disable_store = false;
           data.allow_safety_identifier = false;
+          data.allow_include_obfuscation = false;
+          data.allow_inference_geo = false;
           data.claude_beta_query = false;
         }
       } else {
@@ -832,6 +872,8 @@ const EditChannelModal = (props) => {
         data.allow_service_tier = false;
         data.disable_store = false;
         data.allow_safety_identifier = false;
+        data.allow_include_obfuscation = false;
+        data.allow_inference_geo = false;
         data.claude_beta_query = false;
       }
 
@@ -868,6 +910,7 @@ const EditChannelModal = (props) => {
         .map((model) => (model || '').trim())
         .filter(Boolean);
       initialModelMappingRef.current = data.model_mapping || '';
+      initialStatusCodeMappingRef.current = data.status_code_mapping || '';
 
       let parsedIonet = null;
       if (data.other_info) {
@@ -1173,6 +1216,7 @@ const EditChannelModal = (props) => {
   }, [inputs]);
 
   useEffect(() => {
+    setModelSearchValue('');
     if (props.visible) {
       if (isEdit) {
         loadChannel();
@@ -1194,11 +1238,22 @@ const EditChannelModal = (props) => {
     if (!isEdit) {
       initialModelsRef.current = [];
       initialModelMappingRef.current = '';
+      initialStatusCodeMappingRef.current = '';
     }
   }, [isEdit, props.visible]);
 
+  useEffect(() => {
+    return () => {
+      if (statusCodeRiskConfirmResolverRef.current) {
+        statusCodeRiskConfirmResolverRef.current(false);
+        statusCodeRiskConfirmResolverRef.current = null;
+      }
+    };
+  }, []);
+
   // 统一的模态框重置函数
   const resetModalState = () => {
+    resolveStatusCodeRiskConfirm(false);
     formApiRef.current?.reset();
     // 重置渠道设置状态
     setChannelSettings({
@@ -1216,6 +1271,7 @@ const EditChannelModal = (props) => {
     // 重置豆包隐藏入口状态
     setDoubaoApiEditUnlocked(false);
     doubaoApiClickCountRef.current = 0;
+    setModelSearchValue('');
     // 清空表单中的key_mode字段
     if (formApiRef.current) {
       formApiRef.current.setValue('key_mode', undefined);
@@ -1328,6 +1384,22 @@ const EditChannelModal = (props) => {
       });
     });
 
+  const resolveStatusCodeRiskConfirm = (confirmed) => {
+    setStatusCodeRiskConfirmVisible(false);
+    setStatusCodeRiskDetailItems([]);
+    if (statusCodeRiskConfirmResolverRef.current) {
+      statusCodeRiskConfirmResolverRef.current(confirmed);
+      statusCodeRiskConfirmResolverRef.current = null;
+    }
+  };
+
+  const confirmStatusCodeRisk = (detailItems) =>
+    new Promise((resolve) => {
+      statusCodeRiskConfirmResolverRef.current = resolve;
+      setStatusCodeRiskDetailItems(detailItems);
+      setStatusCodeRiskConfirmVisible(true);
+    });
+
   const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
     if (!isEdit) return true;
     const initialModels = initialModelsRef.current;
@@ -1518,6 +1590,27 @@ const EditChannelModal = (props) => {
       }
     }
 
+    const invalidStatusCodeEntries = collectInvalidStatusCodeEntries(
+      localInputs.status_code_mapping,
+    );
+    if (invalidStatusCodeEntries.length > 0) {
+      showError(
+        `${t('状态码复写包含无效的状态码')}: ${invalidStatusCodeEntries.join(', ')}`,
+      );
+      return;
+    }
+
+    const riskyStatusCodeRedirects = collectNewDisallowedStatusCodeRedirects(
+      initialStatusCodeMappingRef.current,
+      localInputs.status_code_mapping,
+    );
+    if (riskyStatusCodeRedirects.length > 0) {
+      const confirmed = await confirmStatusCodeRisk(riskyStatusCodeRedirects);
+      if (!confirmed) {
+        return;
+      }
+    }
+
     if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
       localInputs.base_url = localInputs.base_url.slice(
         0,
@@ -1570,13 +1663,16 @@ const EditChannelModal = (props) => {
     // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
     if (localInputs.type === 1 || localInputs.type === 14) {
       settings.allow_service_tier = localInputs.allow_service_tier === true;
-      // 仅 OpenAI 渠道需要 store 和 safety_identifier
+      // 仅 OpenAI 渠道需要 store / safety_identifier / include_obfuscation
       if (localInputs.type === 1) {
         settings.disable_store = localInputs.disable_store === true;
         settings.allow_safety_identifier =
           localInputs.allow_safety_identifier === true;
+        settings.allow_include_obfuscation =
+          localInputs.allow_include_obfuscation === true;
       }
       if (localInputs.type === 14) {
+        settings.allow_inference_geo = localInputs.allow_inference_geo === true;
         settings.claude_beta_query = localInputs.claude_beta_query === true;
       }
     }
@@ -1599,6 +1695,8 @@ const EditChannelModal = (props) => {
     delete localInputs.allow_service_tier;
     delete localInputs.disable_store;
     delete localInputs.allow_safety_identifier;
+    delete localInputs.allow_include_obfuscation;
+    delete localInputs.allow_inference_geo;
     delete localInputs.claude_beta_query;
 
     let res;
@@ -2917,9 +3015,18 @@ const EditChannelModal = (props) => {
                       rules={[{ required: true, message: t('请选择模型') }]}
                       multiple
                       filter={selectFilter}
+                      allowCreate
                       autoClearSearchValue={false}
                       searchPosition='dropdown'
                       optionList={modelOptions}
+                      onSearch={(value) => setModelSearchValue(value)}
+                      innerBottomSlot={
+                        modelSearchHintText ? (
+                          <Text className='px-3 py-2 block text-xs !text-semi-color-text-2'>
+                            {modelSearchHintText}
+                          </Text>
+                        ) : null
+                      }
                       style={{ width: '100%' }}
                       onChange={(value) => handleInputChange('models', value)}
                       renderSelectedItem={(optionNode) => {
@@ -3444,6 +3551,24 @@ const EditChannelModal = (props) => {
                             'safety_identifier 字段用于帮助 OpenAI 识别可能违反使用政策的应用程序用户。默认关闭以保护用户隐私',
                           )}
                         />
+
+                        <Form.Switch
+                          field='allow_include_obfuscation'
+                          label={t(
+                            '允许 stream_options.include_obfuscation 透传',
+                          )}
+                          checkedText={t('开')}
+                          uncheckedText={t('关')}
+                          onChange={(value) =>
+                            handleChannelOtherSettingsChange(
+                              'allow_include_obfuscation',
+                              value,
+                            )
+                          }
+                          extraText={t(
+                            'include_obfuscation 用于控制 Responses 流混淆字段。默认关闭以避免客户端关闭该安全保护',
+                          )}
+                        />
                       </>
                     )}
 
@@ -3469,6 +3594,22 @@ const EditChannelModal = (props) => {
                             'service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用',
                           )}
                         />
+
+                        <Form.Switch
+                          field='allow_inference_geo'
+                          label={t('允许 inference_geo 透传')}
+                          checkedText={t('开')}
+                          uncheckedText={t('关')}
+                          onChange={(value) =>
+                            handleChannelOtherSettingsChange(
+                              'allow_inference_geo',
+                              value,
+                            )
+                          }
+                          extraText={t(
+                            'inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息',
+                          )}
+                        />
                       </>
                     )}
                   </Card>
@@ -3613,6 +3754,12 @@ const EditChannelModal = (props) => {
           onVisibleChange={(visible) => setIsModalOpenurl(visible)}
         />
       </SideSheet>
+      <StatusCodeRiskGuardModal
+        visible={statusCodeRiskConfirmVisible}
+        detailItems={statusCodeRiskDetailItems}
+        onCancel={() => resolveStatusCodeRiskConfirm(false)}
+        onConfirm={() => resolveStatusCodeRiskConfirm(true)}
+      />
       {/* 使用通用安全验证模态框 */}
       <SecureVerificationModal
         visible={isModalVisible}

Einige Dateien werden nicht angezeigt, da zu viele Dateien in diesem Diff geändert wurden.