Przeglądaj źródła

fix: 修复mj错误返还费用问题

CaIon 2 lat temu
rodzic
commit
2a9c3ac6af
3 zmienionych plików z 26 dodań i 7 usunięć
  1. 16 5
      controller/midjourney.go
  2. 1 0
      controller/relay-mj.go
  3. 9 2
      model/midjourney.go

+ 16 - 5
controller/midjourney.go

@@ -154,7 +154,7 @@ func UpdateMidjourneyTaskBulk() {
 			log.Printf("UpdateMidjourneyTask panic: %v", err)
 		}
 	}()
-	imageModel := "midjourney"
+	//imageModel := "midjourney"
 	ctx := context.TODO()
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
@@ -167,13 +167,27 @@ func UpdateMidjourneyTaskBulk() {
 		common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
 		taskChannelM := make(map[int][]string)
 		taskM := make(map[string]*model.Midjourney)
+		nullTaskIds := make([]int, 0)
 		for _, task := range tasks {
 			if task.MjId == "" {
+				// 统计失败的未完成任务
+				nullTaskIds = append(nullTaskIds, task.Id)
 				continue
 			}
 			taskM[task.MjId] = task
 			taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
 		}
+		if len(nullTaskIds) > 0 {
+			err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
+				"status":   "FAILURE",
+				"progress": "100%",
+			})
+			if err != nil {
+				common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+			} else {
+				common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+			}
+		}
 		if len(taskChannelM) == 0 {
 			continue
 		}
@@ -256,10 +270,7 @@ func UpdateMidjourneyTaskBulk() {
 					if err != nil {
 						common.LogError(ctx, "error update user quota cache: "+err.Error())
 					} else {
-						modelRatio := common.GetModelRatio(imageModel)
-						groupRatio := common.GetGroupRatio("default")
-						ratio := modelRatio * groupRatio
-						quota := int(ratio * 1 * 1000)
+						quota := task.Quota
 						if quota != 0 {
 							err = model.IncreaseUserQuota(task.UserId, quota)
 							if err != nil {

+ 1 - 0
controller/relay-mj.go

@@ -544,6 +544,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
 		Progress:    "0%",
 		FailReason:  "",
 		ChannelId:   c.GetInt("channel_id"),
+		Quota:       quota,
 	}
 
 	if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

+ 9 - 2
model/midjourney.go

@@ -18,6 +18,7 @@ type Midjourney struct {
 	Progress    string `json:"progress"`
 	FailReason  string `json:"fail_reason"`
 	ChannelId   int    `json:"channel_id"`
+	Quota       int    `json:"quota"`
 }
 
 // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -152,8 +153,14 @@ func (midjourney *Midjourney) Update() error {
 	return err
 }
 
-func MjBulkUpdate(taskIDs []string, params map[string]any) error {
+func MjBulkUpdate(mjIds []string, params map[string]any) error {
 	return DB.Model(&Midjourney{}).
-		Where("mj_id in (?)", taskIDs).
+		Where("mj_id in (?)", mjIds).
+		Updates(params).Error
+}
+
+func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
+	return DB.Model(&Midjourney{}).
+		Where("id in (?)", taskIDs).
 		Updates(params).Error
 }