|
|
@@ -1,6 +1,7 @@
|
|
|
package model
|
|
|
|
|
|
import (
|
|
|
+ "fmt"
|
|
|
"log"
|
|
|
"one-api/common"
|
|
|
"one-api/constant"
|
|
|
@@ -15,18 +16,33 @@ import (
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
|
|
|
-var groupCol string
|
|
|
-var keyCol string
|
|
|
+var commonGroupCol string
|
|
|
+var commonKeyCol string
|
|
|
+
|
|
|
+var logKeyCol string
|
|
|
+var logGroupCol string
|
|
|
|
|
|
func initCol() {
|
|
|
+ // init common column names
|
|
|
if common.UsingPostgreSQL {
|
|
|
- groupCol = `"group"`
|
|
|
- keyCol = `"key"`
|
|
|
-
|
|
|
+ commonGroupCol = `"group"`
|
|
|
+ commonKeyCol = `"key"`
|
|
|
} else {
|
|
|
- groupCol = "`group`"
|
|
|
- keyCol = "`key`"
|
|
|
+ commonGroupCol = "`group`"
|
|
|
+ commonKeyCol = "`key`"
|
|
|
+ }
|
|
|
+ if DB != LOG_DB {
|
|
|
+ switch common.LogSqlType {
|
|
|
+ case common.DatabaseTypePostgreSQL:
|
|
|
+ logGroupCol = `"group"`
|
|
|
+ logKeyCol = `"key"`
|
|
|
+ default:
|
|
|
+ logGroupCol = commonGroupCol
|
|
|
+ logKeyCol = commonKeyCol
|
|
|
+ }
|
|
|
}
|
|
|
+ // log sql type and database type
|
|
|
+ common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
|
|
}
|
|
|
|
|
|
var DB *gorm.DB
|
|
|
@@ -83,7 +99,7 @@ func CheckSetup() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func chooseDB(envName string) (*gorm.DB, error) {
|
|
|
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
|
|
defer func() {
|
|
|
initCol()
|
|
|
}()
|
|
|
@@ -92,7 +108,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|
|
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
|
|
// Use PostgreSQL
|
|
|
common.SysLog("using PostgreSQL as database")
|
|
|
- common.UsingPostgreSQL = true
|
|
|
+ if !isLog {
|
|
|
+ common.UsingPostgreSQL = true
|
|
|
+ } else {
|
|
|
+ common.LogSqlType = common.DatabaseTypePostgreSQL
|
|
|
+ }
|
|
|
return gorm.Open(postgres.New(postgres.Config{
|
|
|
DSN: dsn,
|
|
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
|
|
@@ -102,7 +122,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|
|
}
|
|
|
if strings.HasPrefix(dsn, "local") {
|
|
|
common.SysLog("SQL_DSN not set, using SQLite as database")
|
|
|
- common.UsingSQLite = true
|
|
|
+ if !isLog {
|
|
|
+ common.UsingSQLite = true
|
|
|
+ } else {
|
|
|
+ common.LogSqlType = common.DatabaseTypeSQLite
|
|
|
+ }
|
|
|
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
|
|
PrepareStmt: true, // precompile SQL
|
|
|
})
|
|
|
@@ -117,7 +141,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|
|
dsn += "?parseTime=true"
|
|
|
}
|
|
|
}
|
|
|
- common.UsingMySQL = true
|
|
|
+ if !isLog {
|
|
|
+ common.UsingMySQL = true
|
|
|
+ } else {
|
|
|
+ common.LogSqlType = common.DatabaseTypeMySQL
|
|
|
+ }
|
|
|
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
|
PrepareStmt: true, // precompile SQL
|
|
|
})
|
|
|
@@ -131,7 +159,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
|
|
}
|
|
|
|
|
|
func InitDB() (err error) {
|
|
|
- db, err := chooseDB("SQL_DSN")
|
|
|
+ db, err := chooseDB("SQL_DSN", false)
|
|
|
if err == nil {
|
|
|
if common.DebugEnabled {
|
|
|
db = db.Debug()
|
|
|
@@ -149,7 +177,7 @@ func InitDB() (err error) {
|
|
|
return nil
|
|
|
}
|
|
|
if common.UsingMySQL {
|
|
|
- _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
|
|
+ //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
|
|
}
|
|
|
common.SysLog("database migration started")
|
|
|
err = migrateDB()
|
|
|
@@ -165,7 +193,7 @@ func InitLogDB() (err error) {
|
|
|
LOG_DB = DB
|
|
|
return
|
|
|
}
|
|
|
- db, err := chooseDB("LOG_SQL_DSN")
|
|
|
+ db, err := chooseDB("LOG_SQL_DSN", true)
|
|
|
if err == nil {
|
|
|
if common.DebugEnabled {
|
|
|
db = db.Debug()
|
|
|
@@ -198,54 +226,50 @@ func InitLogDB() (err error) {
|
|
|
}
|
|
|
|
|
|
func migrateDB() error {
|
|
|
- err := DB.AutoMigrate(&Channel{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Token{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&User{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Option{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Redemption{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Ability{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Log{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&Midjourney{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- err = DB.AutoMigrate(&TopUp{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ errChan := make(chan error, 12) // Buffer size matches number of migrations
|
|
|
+
|
|
|
+ migrations := []struct {
|
|
|
+ model interface{}
|
|
|
+ name string
|
|
|
+ }{
|
|
|
+ {&Channel{}, "Channel"},
|
|
|
+ {&Token{}, "Token"},
|
|
|
+ {&User{}, "User"},
|
|
|
+ {&Option{}, "Option"},
|
|
|
+ {&Redemption{}, "Redemption"},
|
|
|
+ {&Ability{}, "Ability"},
|
|
|
+ {&Log{}, "Log"},
|
|
|
+ {&Midjourney{}, "Midjourney"},
|
|
|
+ {&TopUp{}, "TopUp"},
|
|
|
+ {&QuotaData{}, "QuotaData"},
|
|
|
+ {&Task{}, "Task"},
|
|
|
+ {&Setup{}, "Setup"},
|
|
|
}
|
|
|
- err = DB.AutoMigrate(&QuotaData{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+
|
|
|
+ for _, m := range migrations {
|
|
|
+ wg.Add(1)
|
|
|
+ go func(model interface{}, name string) {
|
|
|
+ defer wg.Done()
|
|
|
+ if err := DB.AutoMigrate(model); err != nil {
|
|
|
+ errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
|
|
+ }
|
|
|
+ }(m.model, m.name)
|
|
|
}
|
|
|
- err = DB.AutoMigrate(&Task{})
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+
|
|
|
+ // Wait for all migrations to complete
|
|
|
+ wg.Wait()
|
|
|
+ close(errChan)
|
|
|
+
|
|
|
+ // Check for any errors
|
|
|
+ for err := range errChan {
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
}
|
|
|
- err = DB.AutoMigrate(&Setup{})
|
|
|
+
|
|
|
common.SysLog("database migrated")
|
|
|
- //err = createRootAccountIfNeed()
|
|
|
- return err
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
func migrateLOGDB() error {
|