task_polling.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "sort"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/common"
  12. "github.com/QuantumNous/new-api/constant"
  13. "github.com/QuantumNous/new-api/dto"
  14. "github.com/QuantumNous/new-api/logger"
  15. "github.com/QuantumNous/new-api/model"
  16. "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. "github.com/samber/lo"
  19. )
  20. // TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
  21. type TaskPollingAdaptor interface {
  22. Init(info *relaycommon.RelayInfo)
  23. FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
  24. ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
  25. // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
  26. // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
  27. AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
  28. }
  29. // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
  30. // 打破 service -> relay -> relay/channel -> service 的循环依赖。
  31. var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
  32. // sweepTimedOutTasks 在主轮询之前独立清理超时任务。
  33. // 每次最多处理 100 条,剩余的下个周期继续处理。
  34. // 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
  35. func sweepTimedOutTasks(ctx context.Context) {
  36. if constant.TaskTimeoutMinutes <= 0 {
  37. return
  38. }
  39. cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
  40. tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
  41. if len(tasks) == 0 {
  42. return
  43. }
  44. const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
  45. reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
  46. legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
  47. now := time.Now().Unix()
  48. timedOutCount := 0
  49. for _, task := range tasks {
  50. isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
  51. oldStatus := task.Status
  52. task.Status = model.TaskStatusFailure
  53. task.Progress = "100%"
  54. task.FinishTime = now
  55. if isLegacy {
  56. task.FailReason = legacyReason
  57. } else {
  58. task.FailReason = reason
  59. }
  60. won, err := task.UpdateWithStatus(oldStatus)
  61. if err != nil {
  62. logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
  63. continue
  64. }
  65. if !won {
  66. logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
  67. continue
  68. }
  69. timedOutCount++
  70. if !isLegacy && task.Quota != 0 {
  71. RefundTaskQuota(ctx, task, reason)
  72. }
  73. }
  74. if timedOutCount > 0 {
  75. logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
  76. }
  77. }
  78. // TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
  79. func TaskPollingLoop() {
  80. for {
  81. time.Sleep(time.Duration(15) * time.Second)
  82. common.SysLog("任务进度轮询开始")
  83. ctx := context.TODO()
  84. sweepTimedOutTasks(ctx)
  85. allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
  86. platformTask := make(map[constant.TaskPlatform][]*model.Task)
  87. for _, t := range allTasks {
  88. platformTask[t.Platform] = append(platformTask[t.Platform], t)
  89. }
  90. for platform, tasks := range platformTask {
  91. if len(tasks) == 0 {
  92. continue
  93. }
  94. taskChannelM := make(map[int][]string)
  95. taskM := make(map[string]*model.Task)
  96. nullTaskIds := make([]int64, 0)
  97. for _, task := range tasks {
  98. upstreamID := task.GetUpstreamTaskID()
  99. if upstreamID == "" {
  100. // 统计失败的未完成任务
  101. nullTaskIds = append(nullTaskIds, task.ID)
  102. continue
  103. }
  104. taskM[upstreamID] = task
  105. taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
  106. }
  107. if len(nullTaskIds) > 0 {
  108. err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
  109. "status": "FAILURE",
  110. "progress": "100%",
  111. })
  112. if err != nil {
  113. logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
  114. } else {
  115. logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
  116. }
  117. }
  118. if len(taskChannelM) == 0 {
  119. continue
  120. }
  121. DispatchPlatformUpdate(platform, taskChannelM, taskM)
  122. }
  123. common.SysLog("任务进度轮询完成")
  124. }
  125. }
  126. // DispatchPlatformUpdate 按平台分发轮询更新
  127. func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
  128. switch platform {
  129. case constant.TaskPlatformMidjourney:
  130. // MJ 轮询由其自身处理,这里预留入口
  131. case constant.TaskPlatformSuno:
  132. _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
  133. default:
  134. if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
  135. common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
  136. }
  137. }
  138. }
  139. // UpdateSunoTasks 按渠道更新所有 Suno 任务
  140. func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  141. for channelId, taskIds := range taskChannelM {
  142. err := updateSunoTasks(ctx, channelId, taskIds, taskM)
  143. if err != nil {
  144. logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
  145. }
  146. }
  147. return nil
  148. }
  149. func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  150. logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
  151. if len(taskIds) == 0 {
  152. return nil
  153. }
  154. ch, err := model.CacheGetChannel(channelId)
  155. if err != nil {
  156. common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
  157. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  158. var failedIDs []int64
  159. for _, upstreamID := range taskIds {
  160. if t, ok := taskM[upstreamID]; ok {
  161. failedIDs = append(failedIDs, t.ID)
  162. }
  163. }
  164. err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
  165. "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
  166. "status": "FAILURE",
  167. "progress": "100%",
  168. })
  169. if err != nil {
  170. common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
  171. }
  172. return err
  173. }
  174. adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
  175. if adaptor == nil {
  176. return errors.New("adaptor not found")
  177. }
  178. proxy := ch.GetSetting().Proxy
  179. resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
  180. "ids": taskIds,
  181. }, proxy)
  182. if err != nil {
  183. common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
  184. return err
  185. }
  186. if resp.StatusCode != http.StatusOK {
  187. logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
  188. return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
  189. }
  190. defer resp.Body.Close()
  191. responseBody, err := io.ReadAll(resp.Body)
  192. if err != nil {
  193. common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err))
  194. return err
  195. }
  196. var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
  197. err = common.Unmarshal(responseBody, &responseItems)
  198. if err != nil {
  199. logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody)))
  200. return err
  201. }
  202. if !responseItems.IsSuccess() {
  203. common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
  204. return err
  205. }
  206. for _, responseItem := range responseItems.Data {
  207. task := taskM[responseItem.TaskID]
  208. if !taskNeedsUpdate(task, responseItem) {
  209. continue
  210. }
  211. task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
  212. task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
  213. task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
  214. task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
  215. task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
  216. if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
  217. logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
  218. task.Progress = "100%"
  219. RefundTaskQuota(ctx, task, task.FailReason)
  220. }
  221. if responseItem.Status == model.TaskStatusSuccess {
  222. task.Progress = "100%"
  223. }
  224. task.Data = responseItem.Data
  225. err = task.Update()
  226. if err != nil {
  227. common.SysLog("UpdateSunoTask task error: " + err.Error())
  228. }
  229. }
  230. return nil
  231. }
  232. // taskNeedsUpdate 检查 Suno 任务是否需要更新
  233. func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
  234. if oldTask.SubmitTime != newTask.SubmitTime {
  235. return true
  236. }
  237. if oldTask.StartTime != newTask.StartTime {
  238. return true
  239. }
  240. if oldTask.FinishTime != newTask.FinishTime {
  241. return true
  242. }
  243. if string(oldTask.Status) != newTask.Status {
  244. return true
  245. }
  246. if oldTask.FailReason != newTask.FailReason {
  247. return true
  248. }
  249. if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
  250. return true
  251. }
  252. oldData, _ := common.Marshal(oldTask.Data)
  253. newData, _ := common.Marshal(newTask.Data)
  254. sort.Slice(oldData, func(i, j int) bool {
  255. return oldData[i] < oldData[j]
  256. })
  257. sort.Slice(newData, func(i, j int) bool {
  258. return newData[i] < newData[j]
  259. })
  260. if string(oldData) != string(newData) {
  261. return true
  262. }
  263. return false
  264. }
  265. // UpdateVideoTasks 按渠道更新所有视频任务
  266. func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
  267. for channelId, taskIds := range taskChannelM {
  268. if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
  269. logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
  270. }
  271. }
  272. return nil
  273. }
  274. func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
  275. logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
  276. if len(taskIds) == 0 {
  277. return nil
  278. }
  279. cacheGetChannel, err := model.CacheGetChannel(channelId)
  280. if err != nil {
  281. // Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
  282. var failedIDs []int64
  283. for _, upstreamID := range taskIds {
  284. if t, ok := taskM[upstreamID]; ok {
  285. failedIDs = append(failedIDs, t.ID)
  286. }
  287. }
  288. errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
  289. "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
  290. "status": "FAILURE",
  291. "progress": "100%",
  292. })
  293. if errUpdate != nil {
  294. common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
  295. }
  296. return fmt.Errorf("CacheGetChannel failed: %w", err)
  297. }
  298. adaptor := GetTaskAdaptorFunc(platform)
  299. if adaptor == nil {
  300. return fmt.Errorf("video adaptor not found")
  301. }
  302. info := &relaycommon.RelayInfo{}
  303. info.ChannelMeta = &relaycommon.ChannelMeta{
  304. ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
  305. }
  306. info.ApiKey = cacheGetChannel.Key
  307. adaptor.Init(info)
  308. for _, taskId := range taskIds {
  309. if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
  310. logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
  311. }
  312. // sleep 1 second between each task to avoid hitting rate limits of upstream platforms
  313. time.Sleep(1 * time.Second)
  314. }
  315. return nil
  316. }
  317. func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
  318. baseURL := constant.ChannelBaseURLs[ch.Type]
  319. if ch.GetBaseURL() != "" {
  320. baseURL = ch.GetBaseURL()
  321. }
  322. proxy := ch.GetSetting().Proxy
  323. task := taskM[taskId]
  324. if task == nil {
  325. logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
  326. return fmt.Errorf("task %s not found", taskId)
  327. }
  328. key := ch.Key
  329. privateData := task.PrivateData
  330. if privateData.Key != "" {
  331. key = privateData.Key
  332. }
  333. resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
  334. "task_id": task.GetUpstreamTaskID(),
  335. "action": task.Action,
  336. }, proxy)
  337. if err != nil {
  338. return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
  339. }
  340. defer resp.Body.Close()
  341. responseBody, err := io.ReadAll(resp.Body)
  342. if err != nil {
  343. return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
  344. }
  345. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
  346. snap := task.Snapshot()
  347. taskResult := &relaycommon.TaskInfo{}
  348. // try parse as New API response format
  349. var responseItems dto.TaskResponse[model.Task]
  350. if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
  351. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
  352. t := responseItems.Data
  353. taskResult.TaskID = t.TaskID
  354. taskResult.Status = string(t.Status)
  355. taskResult.Url = t.GetResultURL()
  356. taskResult.Progress = t.Progress
  357. taskResult.Reason = t.FailReason
  358. task.Data = t.Data
  359. } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
  360. return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
  361. }
  362. task.Data = redactVideoResponseBody(responseBody)
  363. logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
  364. now := time.Now().Unix()
  365. if taskResult.Status == "" {
  366. //taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
  367. errorResult := &dto.GeneralErrorResponse{}
  368. if err = common.Unmarshal(responseBody, &errorResult); err == nil {
  369. openaiError := errorResult.TryToOpenAIError()
  370. if openaiError != nil {
  371. // 返回规范的 OpenAI 错误格式,提取错误信息,判断错误是否为任务失败
  372. if openaiError.Code == "429" {
  373. // 429 错误通常表示请求过多或速率限制,暂时不认为是任务失败,保持原状态等待下一轮轮询
  374. return nil
  375. }
  376. // 其他错误认为是任务失败,记录错误信息并更新任务状态
  377. taskResult = relaycommon.FailTaskInfo("upstream returned error")
  378. } else {
  379. // unknown error format, log original response
  380. logger.LogError(ctx, fmt.Sprintf("Task %s returned empty status with unrecognized error format, response: %s", taskId, string(responseBody)))
  381. taskResult = relaycommon.FailTaskInfo("upstream returned unrecognized message")
  382. }
  383. }
  384. }
  385. shouldRefund := false
  386. shouldSettle := false
  387. quota := task.Quota
  388. task.Status = model.TaskStatus(taskResult.Status)
  389. switch taskResult.Status {
  390. case model.TaskStatusSubmitted:
  391. task.Progress = taskcommon.ProgressSubmitted
  392. case model.TaskStatusQueued:
  393. task.Progress = taskcommon.ProgressQueued
  394. case model.TaskStatusInProgress:
  395. task.Progress = taskcommon.ProgressInProgress
  396. if task.StartTime == 0 {
  397. task.StartTime = now
  398. }
  399. case model.TaskStatusSuccess:
  400. task.Progress = taskcommon.ProgressComplete
  401. if task.FinishTime == 0 {
  402. task.FinishTime = now
  403. }
  404. if strings.HasPrefix(taskResult.Url, "data:") {
  405. // data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
  406. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  407. } else if taskResult.Url != "" {
  408. // Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
  409. task.PrivateData.ResultURL = taskResult.Url
  410. } else {
  411. // No URL from adaptor — construct proxy URL using public task ID
  412. task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
  413. }
  414. shouldSettle = true
  415. case model.TaskStatusFailure:
  416. logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
  417. task.Status = model.TaskStatusFailure
  418. task.Progress = taskcommon.ProgressComplete
  419. if task.FinishTime == 0 {
  420. task.FinishTime = now
  421. }
  422. task.FailReason = taskResult.Reason
  423. logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
  424. taskResult.Progress = taskcommon.ProgressComplete
  425. if quota != 0 {
  426. shouldRefund = true
  427. }
  428. default:
  429. return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
  430. }
  431. if taskResult.Progress != "" {
  432. task.Progress = taskResult.Progress
  433. }
  434. isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
  435. if isDone && snap.Status != task.Status {
  436. won, err := task.UpdateWithStatus(snap.Status)
  437. if err != nil {
  438. logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
  439. shouldRefund = false
  440. shouldSettle = false
  441. } else if !won {
  442. logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
  443. shouldRefund = false
  444. shouldSettle = false
  445. }
  446. } else if !snap.Equal(task.Snapshot()) {
  447. if _, err := task.UpdateWithStatus(snap.Status); err != nil {
  448. logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
  449. }
  450. } else {
  451. // No changes, skip update
  452. logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
  453. }
  454. if shouldSettle {
  455. settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
  456. }
  457. if shouldRefund {
  458. RefundTaskQuota(ctx, task, task.FailReason)
  459. }
  460. return nil
  461. }
  462. func redactVideoResponseBody(body []byte) []byte {
  463. var m map[string]any
  464. if err := common.Unmarshal(body, &m); err != nil {
  465. return body
  466. }
  467. resp, _ := m["response"].(map[string]any)
  468. if resp != nil {
  469. delete(resp, "bytesBase64Encoded")
  470. if v, ok := resp["video"].(string); ok {
  471. resp["video"] = truncateBase64(v)
  472. }
  473. if vs, ok := resp["videos"].([]any); ok {
  474. for i := range vs {
  475. if vm, ok := vs[i].(map[string]any); ok {
  476. delete(vm, "bytesBase64Encoded")
  477. }
  478. }
  479. }
  480. }
  481. b, err := common.Marshal(m)
  482. if err != nil {
  483. return body
  484. }
  485. return b
  486. }
  // truncateBase64 keeps only the first 256 bytes of s, followed by "...", so
  // large inline payloads do not bloat stored task data. Strings of 256 bytes or
  // fewer are returned unchanged.
  func truncateBase64(s string) string {
  	const keep = 256
  	if len(s) > keep {
  		return s[:keep] + "..."
  	}
  	return s
  }
  494. // settleTaskBillingOnComplete 任务完成时的统一计费调整。
  495. // 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
  496. //
  497. // 2. taskResult.TotalTokens > 0 → 按 token 重算
  498. // 3. 都不满足 → 保持预扣额度不变
  499. func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
  500. // 0. 按次计费的任务不做差额结算
  501. if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
  502. logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
  503. return
  504. }
  505. // 1. 优先让 adaptor 决定最终额度
  506. if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
  507. RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
  508. return
  509. }
  510. // 2. 回退到 token 重算
  511. if taskResult.TotalTokens > 0 {
  512. RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
  513. return
  514. }
  515. // 3. 无调整,保持预扣额度
  516. }