Переглянути джерело

fix: Resolving conflicts caused by mixing multiple databases

CaIon 8 місяців тому
батько
коміт
b9b4b24961
7 змінених файлів з 129 додано та 88 видалено
  1. 7 0
      common/database.go
  2. 6 6
      model/ability.go
  3. 10 10
      model/channel.go
  4. 19 9
      model/log.go
  5. 82 58
      model/main.go
  6. 2 2
      model/token.go
  7. 3 3
      model/user.go

+ 7 - 0
common/database.go

@@ -1,7 +1,14 @@
 package common
 
+const (
+	DatabaseTypeMySQL      = "mysql"
+	DatabaseTypeSQLite     = "sqlite"
+	DatabaseTypePostgreSQL = "postgres"
+)
+
 var UsingSQLite = false
 var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
 var UsingMySQL = false
 var UsingClickHouse = false
 

+ 6 - 6
model/ability.go

@@ -24,7 +24,7 @@ type Ability struct {
 func GetGroupModels(group string) []string {
 	var models []string
 	// Find distinct models
-	DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+	DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
 	return models
 }
 
@@ -50,8 +50,8 @@ func getPriority(group string, model string, retry int) (int, error) {
 	var priorities []int
 	err := DB.Model(&Ability{}).
 		Select("DISTINCT(priority)").
-		Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
-		Order("priority DESC"). // 按优先级降序排序
+		Where(commonGroupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
+		Order("priority DESC").              // 按优先级降序排序
 		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
 
 	if err != nil {
@@ -80,14 +80,14 @@ func getChannelQuery(group string, model string, retry int) *gorm.DB {
 	if common.UsingPostgreSQL {
 		trueVal = "true"
 	}
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
-	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
+	channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
 	if retry != 0 {
 		priority, err := getPriority(group, model, retry)
 		if err != nil {
 			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
 		} else {
-			channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+			channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
 		}
 	}
 

+ 10 - 10
model/channel.go

@@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
 	}
 
 	// 构造基础查询
-	baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+	baseQuery := DB.Model(&Channel{}).Omit(commonKeyCol)
 
 	// 构造WHERE子句
 	var whereClause string
@@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
 	if group != "" && group != "null" {
 		var groupCondition string
 		if common.UsingMySQL {
-			groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+			groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
 		} else {
 			// sqlite, PostgreSQL
-			groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+			groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
 		}
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 	} else {
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 	}
 
@@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
 	}
 
 	// 构造基础查询
-	baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+	baseQuery := DB.Model(&Channel{}).Omit(commonKeyCol)
 
 	// 构造WHERE子句
 	var whereClause string
@@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
 	if group != "" && group != "null" {
 		var groupCondition string
 		if common.UsingMySQL {
-			groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+			groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
 		} else {
 			// sqlite, PostgreSQL
-			groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+			groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
 		}
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
 	} else {
-		whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+		whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
 		args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
 	}
 

+ 19 - 9
model/log.go

@@ -63,7 +63,7 @@ func formatUserLogs(logs []*Log) {
 func GetLogByKey(key string) (logs []*Log, err error) {
 	if os.Getenv("LOG_SQL_DSN") != "" {
 		var tk Token
-		if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+		if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
 			return nil, err
 		}
 		err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -122,8 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
 		UseTime:          useTimeSeconds,
 		IsStream:         isStream,
 		Group:            group,
-		Ip:               func() string { if needRecordIp { return c.ClientIP() }; return "" }(),
-		Other:            otherStr,
+		Ip: func() string {
+			if needRecordIp {
+				return c.ClientIP()
+			}
+			return ""
+		}(),
+		Other: otherStr,
 	}
 	err := LOG_DB.Create(log).Error
 	if err != nil {
@@ -165,8 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
 		UseTime:          useTimeSeconds,
 		IsStream:         isStream,
 		Group:            group,
-		Ip:               func() string { if needRecordIp { return c.ClientIP() }; return "" }(),
-		Other:            otherStr,
+		Ip: func() string {
+			if needRecordIp {
+				return c.ClientIP()
+			}
+			return ""
+		}(),
+		Other: otherStr,
 	}
 	err := LOG_DB.Create(log).Error
 	if err != nil {
@@ -206,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 		tx = tx.Where("logs.channel_id = ?", channel)
 	}
 	if group != "" {
-		tx = tx.Where("logs."+groupCol+" = ?", group)
+		tx = tx.Where("logs."+logGroupCol+" = ?", group)
 	}
 	err = tx.Model(&Log{}).Count(&total).Error
 	if err != nil {
@@ -264,7 +274,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
 		tx = tx.Where("logs.created_at <= ?", endTimestamp)
 	}
 	if group != "" {
-		tx = tx.Where("logs."+groupCol+" = ?", group)
+		tx = tx.Where("logs."+logGroupCol+" = ?", group)
 	}
 	err = tx.Model(&Log{}).Count(&total).Error
 	if err != nil {
@@ -325,8 +335,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 		rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
 	}
 	if group != "" {
-		tx = tx.Where(groupCol+" = ?", group)
-		rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
+		tx = tx.Where(logGroupCol+" = ?", group)
+		rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
 	}
 
 	tx = tx.Where("type = ?", LogTypeConsume)

+ 82 - 58
model/main.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"fmt"
 	"log"
 	"one-api/common"
 	"one-api/constant"
@@ -15,18 +16,33 @@ import (
 	"gorm.io/gorm"
 )
 
-var groupCol string
-var keyCol string
+var commonGroupCol string
+var commonKeyCol string
+
+var logKeyCol string
+var logGroupCol string
 
 func initCol() {
+	// init common column names
 	if common.UsingPostgreSQL {
-		groupCol = `"group"`
-		keyCol = `"key"`
-
+		commonGroupCol = `"group"`
+		commonKeyCol = `"key"`
 	} else {
-		groupCol = "`group`"
-		keyCol = "`key`"
+		commonGroupCol = "`group`"
+		commonKeyCol = "`key`"
+	}
+	if DB != LOG_DB {
+		switch common.LogSqlType {
+		case common.DatabaseTypePostgreSQL:
+			logGroupCol = `"group"`
+			logKeyCol = `"key"`
+		default:
+			logGroupCol = commonGroupCol
+			logKeyCol = commonKeyCol
+		}
 	}
+	// log sql type and database type
+	common.SysLog("Using Log SQL Type: " + common.LogSqlType)
 }
 
 var DB *gorm.DB
@@ -83,7 +99,7 @@ func CheckSetup() {
 	}
 }
 
-func chooseDB(envName string) (*gorm.DB, error) {
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
 	defer func() {
 		initCol()
 	}()
@@ -92,7 +108,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 		if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
 			// Use PostgreSQL
 			common.SysLog("using PostgreSQL as database")
-			common.UsingPostgreSQL = true
+			if !isLog {
+				common.UsingPostgreSQL = true
+			} else {
+				common.LogSqlType = common.DatabaseTypePostgreSQL
+			}
 			return gorm.Open(postgres.New(postgres.Config{
 				DSN:                  dsn,
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +122,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 		}
 		if strings.HasPrefix(dsn, "local") {
 			common.SysLog("SQL_DSN not set, using SQLite as database")
-			common.UsingSQLite = true
+			if !isLog {
+				common.UsingSQLite = true
+			} else {
+				common.LogSqlType = common.DatabaseTypeSQLite
+			}
 			return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 				PrepareStmt: true, // precompile SQL
 			})
@@ -117,7 +141,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
 				dsn += "?parseTime=true"
 			}
 		}
-		common.UsingMySQL = true
+		if !isLog {
+			common.UsingMySQL = true
+		} else {
+			common.LogSqlType = common.DatabaseTypeMySQL
+		}
 		return gorm.Open(mysql.Open(dsn), &gorm.Config{
 			PrepareStmt: true, // precompile SQL
 		})
@@ -131,7 +159,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
 }
 
 func InitDB() (err error) {
-	db, err := chooseDB("SQL_DSN")
+	db, err := chooseDB("SQL_DSN", false)
 	if err == nil {
 		if common.DebugEnabled {
 			db = db.Debug()
@@ -149,7 +177,7 @@ func InitDB() (err error) {
 			return nil
 		}
 		if common.UsingMySQL {
-			_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+			//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
 		}
 		common.SysLog("database migration started")
 		err = migrateDB()
@@ -165,7 +193,7 @@ func InitLogDB() (err error) {
 		LOG_DB = DB
 		return
 	}
-	db, err := chooseDB("LOG_SQL_DSN")
+	db, err := chooseDB("LOG_SQL_DSN", true)
 	if err == nil {
 		if common.DebugEnabled {
 			db = db.Debug()
@@ -198,54 +226,50 @@ func InitLogDB() (err error) {
 }
 
 func migrateDB() error {
-	err := DB.AutoMigrate(&Channel{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Token{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&User{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Option{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Redemption{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Ability{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Log{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&Midjourney{})
-	if err != nil {
-		return err
-	}
-	err = DB.AutoMigrate(&TopUp{})
-	if err != nil {
-		return err
+	var wg sync.WaitGroup
+	errChan := make(chan error, 12) // Buffer size matches number of migrations
+
+	migrations := []struct {
+		model interface{}
+		name  string
+	}{
+		{&Channel{}, "Channel"},
+		{&Token{}, "Token"},
+		{&User{}, "User"},
+		{&Option{}, "Option"},
+		{&Redemption{}, "Redemption"},
+		{&Ability{}, "Ability"},
+		{&Log{}, "Log"},
+		{&Midjourney{}, "Midjourney"},
+		{&TopUp{}, "TopUp"},
+		{&QuotaData{}, "QuotaData"},
+		{&Task{}, "Task"},
+		{&Setup{}, "Setup"},
 	}
-	err = DB.AutoMigrate(&QuotaData{})
-	if err != nil {
-		return err
+
+	for _, m := range migrations {
+		wg.Add(1)
+		go func(model interface{}, name string) {
+			defer wg.Done()
+			if err := DB.AutoMigrate(model); err != nil {
+				errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+			}
+		}(m.model, m.name)
 	}
-	err = DB.AutoMigrate(&Task{})
-	if err != nil {
-		return err
+
+	// Wait for all migrations to complete
+	wg.Wait()
+	close(errChan)
+
+	// Check for any errors
+	for err := range errChan {
+		if err != nil {
+			return err
+		}
 	}
-	err = DB.AutoMigrate(&Setup{})
+
 	common.SysLog("database migrated")
-	//err = createRootAccountIfNeed()
-	return err
+	return nil
 }
 
 func migrateLOGDB() error {

+ 2 - 2
model/token.go

@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
 	if token != "" {
 		token = strings.Trim(token, "sk-")
 	}
-	err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+	err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
 	return tokens, err
 }
 
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
 		// Don't return error - fall through to DB
 	}
 	fromDB = true
-	err = DB.Where(keyCol+" = ?", key).First(&token).Error
+	err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
 	return token, err
 }
 

+ 3 - 3
model/user.go

@@ -175,7 +175,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
 		// 如果是数字,同时搜索ID和其他字段
 		likeCondition = "id = ? OR " + likeCondition
 		if group != "" {
-			query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
 				keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 		} else {
 			query = query.Where(likeCondition,
@@ -184,7 +184,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
 	} else {
 		// 非数字关键字,只搜索字符串字段
 		if group != "" {
-			query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+			query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
 				"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
 		} else {
 			query = query.Where(likeCondition,
@@ -615,7 +615,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
 		// Don't return error - fall through to DB
 	}
 	fromDB = true
-	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
+	err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
 	if err != nil {
 		return "", err
 	}