Prechádzať zdrojové kódy

feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks

1. Async task model redirection (aligned with sync tasks):
   - Integrate ModelMappedHelper in RelayTaskSubmit after model name
     determination, populating OriginModelName / UpstreamModelName on RelayInfo.
   - All task adaptors now send UpstreamModelName to upstream providers:
     - Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
     - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
     - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
       and unconditionally uses info.UpstreamModelName.
     - Sora: BuildRequestBody parses JSON and multipart bodies to replace
       the "model" field with UpstreamModelName.
   - Frontend log visibility: LogTaskConsumption and taskBillingOther now
     emit is_model_mapped / upstream_model_name in the "other" JSON field.
   - Billing safety: RecalculateTaskQuotaByTokens reads model name from
     BillingContext.OriginModelName (via taskModelName) instead of
     task.Data["model"], preventing billing leaks from upstream model names.

2. Per-call billing (TaskPricePatches lifecycle):
   - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
     bool field, populated from TaskPricePatches at submission time.
   - settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
     skipping both adaptor adjustments and token-based recalculation.
   - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
     consistently in controller/relay.go for billing context and logging.

3. Multipart retry boundary mismatch fix:
   - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
     new boundary and overwrites c.Request.Header["Content-Type"], subsequent
     calls to ParseMultipartFormReusable on retry would parse the cached
     original body with the wrong boundary, causing "NextPart: EOF".
   - Fix: ParseMultipartFormReusable now caches the original Content-Type in
     gin context key "_original_multipart_ct" on first call and reuses it for
     all subsequent parses, making multipart parsing retry-safe globally.
   - Sora adaptor reverted to the standard pattern (direct header set/get),
     which is now safe thanks to the root fix.

4. Tests:
   - task_billing_test.go: update makeTask to use OriginModelName; add
     PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
     add non-per-call adaptor adjustment test with refund verification.
CaIon 1 týždeň pred
rodič
commit
ec5c6b28ea

+ 9 - 1
common/gin.go

@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
 		return nil, err
 	}
 
-	contentType := c.Request.Header.Get("Content-Type")
+	// Use the original Content-Type saved on first call to avoid boundary
+	// mismatch when callers overwrite c.Request.Header after multipart rebuild.
+	var contentType string
+	if saved, ok := c.Get("_original_multipart_ct"); ok {
+		contentType = saved.(string)
+	} else {
+		contentType = c.Request.Header.Get("Content-Type")
+		c.Set("_original_multipart_ct", contentType)
+	}
 	boundary, err := parseBoundary(contentType)
 	if err != nil {
 		return nil, err

+ 9 - 8
controller/relay.go

@@ -518,7 +518,7 @@ func RelayTask(c *gin.Context) {
 		}
 
 		addUsedChannel(c, channel.Id)
-		requestBody, bodyErr := common.GetRequestBody(c)
+		bodyStorage, bodyErr := common.GetBodyStorage(c)
 		if bodyErr != nil {
 			if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
 				taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
@@ -527,7 +527,7 @@ func RelayTask(c *gin.Context) {
 			}
 			break
 		}
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		c.Request.Body = io.NopCloser(bodyStorage)
 
 		result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
 		if taskErr == nil {
@@ -557,7 +557,7 @@ func RelayTask(c *gin.Context) {
 		if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
 			common.SysError("settle task billing error: " + settleErr.Error())
 		}
-		service.LogTaskConsumption(c, relayInfo, result.ModelName)
+		service.LogTaskConsumption(c, relayInfo)
 
 		task := model.InitTask(result.Platform, relayInfo)
 		task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
@@ -565,11 +565,12 @@ func RelayTask(c *gin.Context) {
 		task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
 		task.PrivateData.TokenId = relayInfo.TokenId
 		task.PrivateData.BillingContext = &model.TaskBillingContext{
-			ModelPrice:  relayInfo.PriceData.ModelPrice,
-			GroupRatio:  relayInfo.PriceData.GroupRatioInfo.GroupRatio,
-			ModelRatio:  relayInfo.PriceData.ModelRatio,
-			OtherRatios: relayInfo.PriceData.OtherRatios,
-			ModelName:   result.ModelName,
+			ModelPrice:      relayInfo.PriceData.ModelPrice,
+			GroupRatio:      relayInfo.PriceData.GroupRatioInfo.GroupRatio,
+			ModelRatio:      relayInfo.PriceData.ModelRatio,
+			OtherRatios:     relayInfo.PriceData.OtherRatios,
+			OriginModelName: relayInfo.OriginModelName,
+			PerCallBilling:  common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
 		}
 		task.Quota = result.Quota
 		task.Data = result.TaskData

+ 23 - 3
controller/task.go

@@ -9,6 +9,7 @@ import (
 	"github.com/QuantumNous/new-api/model"
 	"github.com/QuantumNous/new-api/relay"
 	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
 
 	"github.com/gin-gonic/gin"
 )
@@ -37,7 +38,7 @@ func GetAllTask(c *gin.Context) {
 	items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllTasks(queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(tasksToDto(items))
+	pageInfo.SetItems(tasksToDto(items, true))
 	common.ApiSuccess(c, pageInfo)
 }
 
@@ -61,13 +62,32 @@ func GetUserTask(c *gin.Context) {
 	items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
 	total := model.TaskCountAllUserTask(userId, queryParams)
 	pageInfo.SetTotal(int(total))
-	pageInfo.SetItems(tasksToDto(items))
+	pageInfo.SetItems(tasksToDto(items, false))
 	common.ApiSuccess(c, pageInfo)
 }
 
-func tasksToDto(tasks []*model.Task) []*dto.TaskDto {
+func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
+	var userIdMap map[int]*model.UserBase
+	if fillUser {
+		userIdMap = make(map[int]*model.UserBase)
+		userIds := types.NewSet[int]()
+		for _, task := range tasks {
+			userIds.Add(task.UserId)
+		}
+		for _, userId := range userIds.Items() {
+			cacheUser, err := model.GetUserCache(userId)
+			if err == nil {
+				userIdMap[userId] = cacheUser
+			}
+		}
+	}
 	result := make([]*dto.TaskDto, len(tasks))
 	for i, task := range tasks {
+		if fillUser {
+			if user, ok := userIdMap[task.UserId]; ok {
+				task.Username = user.Username
+			}
+		}
 		result[i] = relay.TaskModel2Dto(task)
 	}
 	return result

+ 6 - 5
model/task.go

@@ -109,11 +109,12 @@ type TaskPrivateData struct {
 
 // TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
 type TaskBillingContext struct {
-	ModelPrice  float64            `json:"model_price,omitempty"`  // 模型单价
-	GroupRatio  float64            `json:"group_ratio,omitempty"`  // 分组倍率
-	ModelRatio  float64            `json:"model_ratio,omitempty"`  // 模型倍率
-	OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
-	ModelName   string             `json:"model_name,omitempty"`   // 模型名称
+	ModelPrice      float64            `json:"model_price,omitempty"`       // 模型单价
+	GroupRatio      float64            `json:"group_ratio,omitempty"`       // 分组倍率
+	ModelRatio      float64            `json:"model_ratio,omitempty"`       // 模型倍率
+	OtherRatios     map[string]float64 `json:"other_ratios,omitempty"`      // 附加倍率(时长、分辨率等)
+	OriginModelName string             `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
+	PerCallBilling  bool               `json:"per_call_billing,omitempty"`  // 按次计费:跳过轮询阶段的差额结算
 }
 
 // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)

+ 6 - 2
relay/channel/task/ali/adaptor.go

@@ -253,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
 }
 
 func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
+	upstreamModel := req.Model
+	if info.IsModelMapped {
+		upstreamModel = info.UpstreamModelName
+	}
 	aliReq := &AliVideoRequest{
-		Model: req.Model,
+		Model: upstreamModel,
 		Input: AliVideoInput{
 			Prompt: req.Prompt,
 			ImgURL: req.InputReference,
@@ -332,7 +336,7 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
 		}
 	}
 
-	if aliReq.Model != req.Model {
+	if aliReq.Model != upstreamModel {
 		return nil, errors.New("can't change model with metadata")
 	}
 

+ 5 - 1
relay/channel/task/doubao/adaptor.go

@@ -131,7 +131,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
-	info.UpstreamModelName = body.Model
+	if info.IsModelMapped {
+		body.Model = info.UpstreamModelName
+	} else {
+		info.UpstreamModelName = body.Model
+	}
 	data, err := common.Marshal(body)
 	if err != nil {
 		return nil, err

+ 1 - 1
relay/channel/task/gemini/adaptor.go

@@ -105,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 
 // BuildRequestURL constructs the upstream URL.
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	version := model_setting.GetGeminiVersionSetting(modelName)
 
 	return fmt.Sprintf(

+ 4 - 4
relay/channel/task/hailuo/adaptor.go

@@ -61,7 +61,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		return nil, fmt.Errorf("invalid request type in context")
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
@@ -142,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
 	return ChannelName
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
-	modelConfig := GetModelConfig(req.Model)
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
+	modelConfig := GetModelConfig(info.UpstreamModelName)
 	duration := DefaultDuration
 	if req.Duration > 0 {
 		duration = req.Duration
@@ -154,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 
 	videoRequest := &VideoRequest{
-		Model:      req.Model,
+		Model:      info.UpstreamModelName,
 		Prompt:     req.Prompt,
 		Duration:   &duration,
 		Resolution: resolution,

+ 3 - 3
relay/channel/task/jimeng/adaptor.go

@@ -165,7 +165,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 		}
 	}
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, errors.Wrap(err, "convert request payload failed")
 	}
@@ -378,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
 	return h.Sum(nil)
 }
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		ReqKey: req.Model,
+		ReqKey: info.UpstreamModelName,
 		Prompt: req.Prompt,
 	}
 

+ 5 - 4
relay/channel/task/kling/adaptor.go

@@ -150,7 +150,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
@@ -248,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
 		Prompt:         req.Prompt,
 		Image:          req.Image,
 		Mode:           taskcommon.DefaultString(req.Mode, "std"),
 		Duration:       fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
 		AspectRatio:    a.getAspectRatio(req.Size),
-		ModelName:      req.Model,
-		Model:          req.Model, // Keep consistent with model_name, double writing improves compatibility
+		ModelName:      info.UpstreamModelName,
+		Model:          info.UpstreamModelName,
 		CfgScale:       0.5,
 		StaticMask:     "",
 		DynamicMasks:   []DynamicMask{},
@@ -266,6 +266,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
 	}
 	if r.ModelName == "" {
 		r.ModelName = "kling-v1"
+		r.Model = "kling-v1"
 	}
 	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
 		return nil, errors.Wrap(err, "unmarshal metadata failed")

+ 55 - 0
relay/channel/task/sora/adaptor.go

@@ -1,8 +1,10 @@
 package sora
 
 import (
+	"bytes"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"strconv"
 	"strings"
@@ -145,6 +147,59 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "get_request_body_failed")
 	}
+	cachedBody, err := storage.Bytes()
+	if err != nil {
+		return nil, errors.Wrap(err, "read_body_bytes_failed")
+	}
+	contentType := c.GetHeader("Content-Type")
+
+	if strings.HasPrefix(contentType, "application/json") {
+		var bodyMap map[string]interface{}
+		if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
+			bodyMap["model"] = info.UpstreamModelName
+			if newBody, err := common.Marshal(bodyMap); err == nil {
+				return bytes.NewReader(newBody), nil
+			}
+		}
+		return bytes.NewReader(cachedBody), nil
+	}
+
+	if strings.Contains(contentType, "multipart/form-data") {
+		formData, err := common.ParseMultipartFormReusable(c)
+		if err != nil {
+			return bytes.NewReader(cachedBody), nil
+		}
+		var buf bytes.Buffer
+		writer := multipart.NewWriter(&buf)
+		writer.WriteField("model", info.UpstreamModelName)
+		for key, values := range formData.Value {
+			if key == "model" {
+				continue
+			}
+			for _, v := range values {
+				writer.WriteField(key, v)
+			}
+		}
+		for fieldName, fileHeaders := range formData.File {
+			for _, fh := range fileHeaders {
+				f, err := fh.Open()
+				if err != nil {
+					continue
+				}
+				part, err := writer.CreateFormFile(fieldName, fh.Filename)
+				if err != nil {
+					f.Close()
+					continue
+				}
+				io.Copy(part, f)
+				f.Close()
+			}
+		}
+		writer.Close()
+		c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+		return &buf, nil
+	}
+
 	return common.ReaderOnly(storage), nil
 }
 

+ 1 - 1
relay/channel/task/vertex/adaptor.go

@@ -86,7 +86,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
 	if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
 		return "", fmt.Errorf("failed to decode credentials: %w", err)
 	}
-	modelName := info.OriginModelName
+	modelName := info.UpstreamModelName
 	if modelName == "" {
 		modelName = "veo-3.0-generate-001"
 	}

+ 3 - 3
relay/channel/task/vidu/adaptor.go

@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	}
 	req := v.(relaycommon.TaskSubmitReq)
 
-	body, err := a.convertToRequestPayload(&req)
+	body, err := a.convertToRequestPayload(&req, info)
 	if err != nil {
 		return nil, err
 	}
@@ -224,9 +224,9 @@ func (a *TaskAdaptor) GetChannelName() string {
 // helpers
 // ============================
 
-func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
 	r := requestPayload{
-		Model:             taskcommon.DefaultString(req.Model, "viduq1"),
+		Model:             taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
 		Images:            req.Images,
 		Prompt:            req.Prompt,
 		Duration:          taskcommon.DefaultInt(req.Duration, 5),

+ 7 - 2
relay/relay_task.go

@@ -26,7 +26,6 @@ type TaskSubmitResult struct {
 	UpstreamTaskID string
 	TaskData       []byte
 	Platform       constant.TaskPlatform
-	ModelName      string
 	Quota          int
 	//PerCallPrice   types.PriceData
 }
@@ -163,6 +162,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		modelName = service.CoverTaskActionToModelName(platform, info.Action)
 	}
 
+	// 2.5 应用渠道的模型映射(与同步任务对齐)
+	info.OriginModelName = modelName
+	info.UpstreamModelName = modelName
+	if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+		return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
+	}
+
 	// 3. 预生成公开 task ID(仅首次)
 	if info.PublicTaskID == "" {
 		info.PublicTaskID = model.GenerateTaskID()
@@ -241,7 +247,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
 		UpstreamTaskID: upstreamTaskID,
 		TaskData:       taskData,
 		Platform:       platform,
-		ModelName:      modelName,
 		Quota:          finalQuota,
 	}, nil
 }

+ 15 - 14
service/task_billing.go

@@ -16,11 +16,11 @@ import (
 
 // LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
 // 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
-func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
+func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
 	tokenName := c.GetString("token_name")
 	logContent := fmt.Sprintf("操作 %s", info.Action)
 	// 支持任务仅按次计费
-	if common.StringsContains(constant.TaskPricePatches, modelName) {
+	if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
 		logContent = fmt.Sprintf("%s,按次计费", logContent)
 	} else {
 		if len(info.PriceData.OtherRatios) > 0 {
@@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s
 	if info.PriceData.GroupRatioInfo.HasSpecialRatio {
 		other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
 	}
+	if info.IsModelMapped {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = info.UpstreamModelName
+	}
 	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
 		ChannelId: info.ChannelId,
-		ModelName: modelName,
+		ModelName: info.OriginModelName,
 		TokenName: tokenName,
 		Quota:     info.PriceData.Quota,
 		Content:   logContent,
@@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
 			}
 		}
 	}
+	props := task.Properties
+	if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
+		other["is_model_mapped"] = true
+		other["upstream_model_name"] = props.UpstreamModelName
+	}
 	return other
 }
 
 // taskModelName 从 BillingContext 或 Properties 中获取模型名称。
 func taskModelName(task *model.Task) string {
-	if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" {
-		return bc.ModelName
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
+		return bc.OriginModelName
 	}
 	return task.Properties.OriginModelName
 }
@@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
 		return
 	}
 
-	// 获取模型名称
-	var taskData map[string]interface{}
-	if err := common.Unmarshal(task.Data, &taskData); err != nil {
-		return
-	}
-	modelName, ok := taskData["model"].(string)
-	if !ok || modelName == "" {
-		return
-	}
+	modelName := taskModelName(task)
 
 	// 获取模型价格和倍率
 	modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)

+ 107 - 1
service/task_billing_test.go

@@ -3,12 +3,14 @@ package service
 import (
 	"context"
 	"encoding/json"
+	"net/http"
 	"os"
 	"testing"
 	"time"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	"github.com/glebarez/sqlite"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
 			BillingContext: &model.TaskBillingContext{
 				ModelPrice: 0.02,
 				GroupRatio: 1.0,
-				ModelName:  "test-model",
+				OriginModelName: "test-model",
 			},
 		},
 	}
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
 	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
 	assert.Equal(t, "50%", reloaded.Progress)
 }
+
+// ===========================================================================
+// Mock adaptor for settleTaskBillingOnComplete tests
+// ===========================================================================
+
+type mockAdaptor struct {
+	adjustReturn int
+}
+
+func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo)                                            {}
+func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error)  { return nil, nil }
+func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error)                     { return nil, nil }
+func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
+	return m.adjustReturn
+}
+
+// ===========================================================================
+// PerCallBilling tests — settleTaskBillingOnComplete
+// ===========================================================================
+
+func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 30, 30, 30
+	const initQuota, preConsumed = 10000, 5000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 2000}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no adjustment despite adaptor returning 2000
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 31, 31, 31
+	const initQuota, preConsumed = 10000, 4000
+	const tokenRemain = 7000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	task.PrivateData.BillingContext.PerCallBilling = true
+
+	adaptor := &mockAdaptor{adjustReturn: 0}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Per-call: no recalculation by tokens
+	assert.Equal(t, initQuota, getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, preConsumed, task.Quota)
+	assert.Equal(t, int64(0), countLogs(t))
+}
+
+func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
+	truncate(t)
+	ctx := context.Background()
+
+	const userID, tokenID, channelID = 32, 32, 32
+	const initQuota, preConsumed = 10000, 5000
+	const adaptorQuota = 3000
+	const tokenRemain = 8000
+
+	seedUser(t, userID, initQuota)
+	seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
+	seedChannel(t, channelID)
+
+	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
+	// PerCallBilling defaults to false
+
+	adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
+	taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
+
+	settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
+
+	// Non-per-call: adaptor adjustment applies (refund 2000)
+	assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
+	assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
+	assert.Equal(t, adaptorQuota, task.Quota)
+
+	log := getLastLog(t)
+	require.NotNil(t, log)
+	assert.Equal(t, model.LogTypeRefund, log.Type)
+}

+ 5 - 0
service/task_polling.go

@@ -467,6 +467,11 @@ func truncateBase64(s string) string {
 //  2. taskResult.TotalTokens > 0 → 按 token 重算
 //  3. 都不满足 → 保持预扣额度不变
 func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
+	// 0. 按次计费的任务不做差额结算
+	if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
+		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
+		return
+	}
 	// 1. 优先让 adaptor 决定最终额度
 	if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
 		RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")

+ 13 - 23
web/src/components/table/task-logs/TaskLogsColumnDefs.jsx

@@ -84,8 +84,8 @@ function renderDuration(submit_time, finishTime) {
 
   // 返回带有样式的颜色标签
   return (
-    <Tag color={color} shape='circle' prefixIcon={<Clock size={14} />}>
-      {durationSec} 
+    <Tag color={color} shape='circle'>
+      {durationSec} s
     </Tag>
   );
 }
@@ -149,7 +149,7 @@ const renderPlatform = (platform, t) => {
   );
   if (option) {
     return (
-      <Tag color={option.color} shape='circle' prefixIcon={<Video size={14} />}>
+      <Tag color={option.color} shape='circle'>
         {option.label}
       </Tag>
     );
@@ -157,13 +157,13 @@ const renderPlatform = (platform, t) => {
   switch (platform) {
     case 'suno':
       return (
-        <Tag color='green' shape='circle' prefixIcon={<Music size={14} />}>
+        <Tag color='green' shape='circle'>
           Suno
         </Tag>
       );
     default:
       return (
-        <Tag color='white' shape='circle' prefixIcon={<HelpCircle size={14} />}>
+        <Tag color='white' shape='circle'>
           {t('未知')}
         </Tag>
       );
@@ -240,7 +240,6 @@ export const getTaskLogsColumns = ({
   openContentModal,
   isAdminUser,
   openVideoModal,
-  showUserInfoFunc,
 }) => {
   return [
     {
@@ -278,7 +277,6 @@ export const getTaskLogsColumns = ({
               color={colors[parseInt(text) % colors.length]}
               size='large'
               shape='circle'
-              prefixIcon={<Hash size={14} />}
               onClick={() => {
                 copyText(text);
               }}
@@ -294,7 +292,7 @@ export const getTaskLogsColumns = ({
     {
       key: COLUMN_KEYS.USERNAME,
       title: t('用户'),
-      dataIndex: 'user_id',
+      dataIndex: 'username',
       render: (userId, record, index) => {
         if (!isAdminUser) {
           return <></>;
@@ -302,22 +300,14 @@ export const getTaskLogsColumns = ({
         const displayText = String(record.username || userId || '?');
         return (
           <Space>
-            <Tooltip content={displayText}>
-              <Avatar
-                size='extra-small'
-                color={stringToColor(displayText)}
-                style={{ cursor: 'pointer' }}
-                onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
-              >
-                {displayText.slice(0, 1)}
-              </Avatar>
-            </Tooltip>
-            <Typography.Text
-              ellipsis={{ showTooltip: true }}
-              style={{ cursor: 'pointer', color: 'var(--semi-color-primary)' }}
-              onClick={() => showUserInfoFunc && showUserInfoFunc(userId)}
+            <Avatar
+              size='extra-small'
+              color={stringToColor(displayText)}
             >
-              {userId}
+              {displayText.slice(0, 1)}
+            </Avatar>
+            <Typography.Text>
+              {displayText}
             </Typography.Text>
           </Space>
         );

+ 0 - 2
web/src/components/table/task-logs/index.jsx

@@ -25,7 +25,6 @@ import TaskLogsActions from './TaskLogsActions';
 import TaskLogsFilters from './TaskLogsFilters';
 import ColumnSelectorModal from './modals/ColumnSelectorModal';
 import ContentModal from './modals/ContentModal';
-import UserInfoModal from '../usage-logs/modals/UserInfoModal';
 import { useTaskLogsData } from '../../../hooks/task-logs/useTaskLogsData';
 import { useIsMobile } from '../../../hooks/common/useIsMobile';
 import { createCardProPagination } from '../../../helpers/utils';
@@ -46,7 +45,6 @@ const TaskLogsPage = () => {
         modalContent={taskLogsData.videoUrl}
         isVideo={true}
       />
-      <UserInfoModal {...taskLogsData} />
 
       <Layout>
         <CardPro