Explorar el Código

fix: fix postgresql support (#606)

* fix postgresql support

fixes #517

* fix: fix pg support

* chore: delete useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Bryan hace 2 años
padre
commit
a398f35968
Se han modificado 9 ficheros con 48 adiciones y 22 borrados
  1. 0 3
      common/constants.go
  2. 6 0
      common/database.go
  3. 8 0
      common/utils.go
  4. 10 3
      model/ability.go
  5. 6 2
      model/cache.go
  6. 5 12
      model/channel.go
  7. 1 0
      model/main.go
  8. 6 1
      model/redemption.go
  9. 6 1
      model/user.go

+ 0 - 3
common/constants.go

@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 var DisplayInCurrencyEnabled = true
 var DisplayTokenStatEnabled = true
 
-var UsingSQLite = false
-
 // Any options with "Secret", "Token" in its key won't be return by GetOptions
 
 var SessionSecret = uuid.New().String()
-var SQLitePath = "one-api.db"
 
 var OptionMap map[string]string
 var OptionMapRWMutex sync.RWMutex

+ 6 - 0
common/database.go

@@ -0,0 +1,6 @@
+package common
+
+var UsingSQLite = false
+var UsingPostgreSQL = false
+
+var SQLitePath = "one-api.db"

+ 8 - 0
common/utils.go

@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
 func MessageWithRequestId(message string, id string) string {
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 }
+
+func String2Int(str string) int {
+	num, err := strconv.Atoi(str)
+	if err != nil {
+		return 0
+	}
+	return num
+}

+ 10 - 3
model/ability.go

@@ -15,10 +15,17 @@ type Ability struct {
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	ability := Ability{}
+	groupCol := "`group`"
+	trueVal := "1"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		trueVal = "true"
+	}
+
 	var err error = nil
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
-	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
-	if common.UsingSQLite {
+	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
+	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	if common.UsingSQLite || common.UsingPostgreSQL {
 		err = channelQuery.Order("RANDOM()").First(&ability).Error
 	} else {
 		err = channelQuery.Order("RAND()").First(&ability).Error

+ 6 - 2
model/cache.go

@@ -21,14 +21,18 @@ var (
 )
 
 func CacheGetTokenByKey(key string) (*Token, error) {
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
 	var token Token
 	if !common.RedisEnabled {
-		err := DB.Where("`key` = ?", key).First(&token).Error
+		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 		return &token, err
 	}
 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
 	if err != nil {
-		err := DB.Where("`key` = ?", key).First(&token).Error
+		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 		if err != nil {
 			return nil, err
 		}

+ 5 - 12
model/channel.go

@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 }
 
 func SearchChannels(keyword string) (channels []*Channel, err error) {
-	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
+	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
 	return channels, err
 }
 
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 	return &channel, err
 }
 
-func GetRandomChannel() (*Channel, error) {
-	channel := Channel{}
-	var err error = nil
-	if common.UsingSQLite {
-		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
-	} else {
-		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
-	}
-	return &channel, err
-}
-
 func BatchInsertChannels(channels []Channel) error {
 	var err error
 	err = DB.Create(&channels).Error

+ 1 - 0
model/main.go

@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
 		if strings.HasPrefix(dsn, "postgres://") {
 			// Use PostgreSQL
 			common.SysLog("using PostgreSQL as database")
+			common.UsingPostgreSQL = true
 			return gorm.Open(postgres.New(postgres.Config{
 				DSN:                  dsn,
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage

+ 6 - 1
model/redemption.go

@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
 	}
 	redemption := &Redemption{}
 
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
+
 	err = DB.Transaction(func(tx *gorm.DB) error {
-		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
+		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
 		if err != nil {
 			return errors.New("无效的兑换码")
 		}

+ 6 - 1
model/user.go

@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
 }
 
 func GetUserGroup(id int) (group string, err error) {
-	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
+	groupCol := "`group`"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+	}
+
+	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
 	return group, err
 }