Просмотр исходного кода

Merge branch 'songquanpeng:main' into main

Calcium-Ion 2 лет назад
Родитель
Сommit
7688e9f9dd

+ 4 - 0
README.md

@@ -370,6 +370,10 @@ graph LR
 13. 请求频率限制:
 13. 请求频率限制:
     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
     + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
     + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
+14. 编码器缓存设置:
+    + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+    + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
+15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
 
 
 ### 命令行参数
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

+ 2 - 3
common/constants.go

@@ -25,12 +25,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 var DisplayInCurrencyEnabled = true
 var DisplayInCurrencyEnabled = true
 var DisplayTokenStatEnabled = true
 var DisplayTokenStatEnabled = true
 
 
-var UsingSQLite = false
-
 // Any options with "Secret", "Token" in its key won't be return by GetOptions
 // Any options with "Secret", "Token" in its key won't be return by GetOptions
 
 
 var SessionSecret = uuid.New().String()
 var SessionSecret = uuid.New().String()
-var SQLitePath = "one-api.db"
 
 
 var OptionMap map[string]string
 var OptionMap map[string]string
 var OptionMapRWMutex sync.RWMutex
 var OptionMapRWMutex sync.RWMutex
@@ -102,6 +99,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
 var BatchUpdateEnabled = false
 var BatchUpdateEnabled = false
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 
 
+var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
+
 const (
 const (
 	RequestIdKey = "X-Oneapi-Request-Id"
 	RequestIdKey = "X-Oneapi-Request-Id"
 )
 )

+ 6 - 0
common/database.go

@@ -0,0 +1,6 @@
+package common
+
+var UsingSQLite = false
+var UsingPostgreSQL = false
+
+var SQLitePath = "one-api.db"

+ 1 - 0
common/model-ratio.go

@@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{
 	"claude-2":                  5.51,   // $11.02 / 1M tokens
 	"claude-2":                  5.51,   // $11.02 / 1M tokens
 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
 	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
+	"ERNIE-Bot-4":               8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":                    1,
 	"PaLM-2":                    1,
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens

+ 8 - 0
common/utils.go

@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
 func MessageWithRequestId(message string, id string) string {
 func MessageWithRequestId(message string, id string) string {
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 	return fmt.Sprintf("%s (request id: %s)", message, id)
 }
 }
+
+func String2Int(str string) int {
+	num, err := strconv.Atoi(str)
+	if err != nil {
+		return 0
+	}
+	return num
+}

+ 2 - 1
controller/channel-test.go

@@ -5,13 +5,14 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {

+ 9 - 0
controller/model.go

@@ -306,6 +306,15 @@ func init() {
 			Root:       "ERNIE-Bot-turbo",
 			Root:       "ERNIE-Bot-turbo",
 			Parent:     nil,
 			Parent:     nil,
 		},
 		},
+		{
+			Id:         "ERNIE-Bot-4",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "ERNIE-Bot-4",
+			Parent:     nil,
+		},
 		{
 		{
 			Id:         "Embedding-V1",
 			Id:         "Embedding-V1",
 			Object:     "model",
 			Object:     "model",

+ 2 - 4
controller/relay-audio.go

@@ -6,12 +6,11 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
-
-	"github.com/gin-gonic/gin"
 )
 )
 
 
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 
 	baseURL := common.ChannelBaseURLs[channelType]
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	requestURL := c.Request.URL.String()
-
 	if c.GetString("base_url") != "" {
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 		baseURL = c.GetString("base_url")
 	}
 	}
 
 
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	requestBody := c.Request.Body
 	requestBody := c.Request.Body
 
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

+ 2 - 7
controller/relay-image.go

@@ -6,12 +6,11 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
-
-	"github.com/gin-gonic/gin"
 )
 )
 
 
 func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			isModelMapped = true
 			isModelMapped = true
 		}
 		}
 	}
 	}
-
 	baseURL := common.ChannelBaseURLs[channelType]
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	requestURL := c.Request.URL.String()
-
 	if c.GetString("base_url") != "" {
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 		baseURL = c.GetString("base_url")
 	}
 	}
-
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	var requestBody io.Reader
 	var requestBody io.Reader
 	if isModelMapped {
 	if isModelMapped {
 		jsonStr, err := json.Marshal(imageRequest)
 		jsonStr, err := json.Marshal(imageRequest)

+ 16 - 8
controller/relay-text.go

@@ -6,13 +6,14 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
 	"one-api/model"
 	"one-api/model"
 	"strings"
 	"strings"
 	"time"
 	"time"
+
+	"github.com/gin-gonic/gin"
 )
 )
 
 
 const (
 const (
@@ -31,7 +32,14 @@ var httpClient *http.Client
 var impatientHTTPClient *http.Client
 var impatientHTTPClient *http.Client
 
 
 func init() {
 func init() {
-	httpClient = &http.Client{}
+	if common.RelayTimeout == 0 {
+		httpClient = &http.Client{}
+	} else {
+		httpClient = &http.Client{
+			Timeout: time.Duration(common.RelayTimeout) * time.Second,
+		}
+	}
+
 	impatientHTTPClient = &http.Client{
 	impatientHTTPClient = &http.Client{
 		Timeout: 5 * time.Second,
 		Timeout: 5 * time.Second,
 	}
 	}
@@ -118,12 +126,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if c.GetString("base_url") != "" {
 	if c.GetString("base_url") != "" {
 		baseURL = c.GetString("base_url")
 		baseURL = c.GetString("base_url")
 	}
 	}
-	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-	if channelType == common.ChannelTypeOpenAI {
-		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
-			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
-		}
-	}
+	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
 	switch apiType {
 	switch apiType {
 	case APITypeOpenAI:
 	case APITypeOpenAI:
 		if channelType == common.ChannelTypeAzure {
 		if channelType == common.ChannelTypeAzure {
@@ -156,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
 		case "ERNIE-Bot-turbo":
 		case "ERNIE-Bot-turbo":
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+		case "ERNIE-Bot-4":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
 		case "BLOOMZ-7B":
 		case "BLOOMZ-7B":
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
 		case "Embedding-V1":
 		case "Embedding-V1":
@@ -373,6 +378,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		}
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		if isStream && c.Request.Header.Get("Accept") == "" {
+			req.Header.Set("Accept", "text/event-stream")
+		}
 		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
 		resp, err = httpClient.Do(req)
 		resp, err = httpClient.Do(req)
 		if err != nil {
 		if err != nil {

+ 10 - 0
controller/relay-utils.go

@@ -187,3 +187,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 	return
 	return
 }
 }
+
+func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	if channelType == common.ChannelTypeOpenAI {
+		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+		}
+	}
+	return fullRequestURL
+}

+ 2 - 2
controller/relay-xunfei.go

@@ -298,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
 		common.SysLog("api_version not found, use default: " + apiVersion)
 		common.SysLog("api_version not found, use default: " + apiVersion)
 	}
 	}
 	domain := "general"
 	domain := "general"
-	if apiVersion == "v2.1" {
-		domain = "generalv2"
+	if apiVersion != "v1.1" {
+		domain += strings.Split(apiVersion, ".")[0]
 	}
 	}
 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
 	return domain, authUrl
 	return domain, authUrl

+ 10 - 3
model/ability.go

@@ -15,10 +15,17 @@ type Ability struct {
 
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	ability := Ability{}
 	ability := Ability{}
+	groupCol := "`group`"
+	trueVal := "1"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		trueVal = "true"
+	}
+
 	var err error = nil
 	var err error = nil
-	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
-	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
-	if common.UsingSQLite {
+	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
+	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	if common.UsingSQLite || common.UsingPostgreSQL {
 		err = channelQuery.Order("RANDOM()").First(&ability).Error
 		err = channelQuery.Order("RANDOM()").First(&ability).Error
 	} else {
 	} else {
 		err = channelQuery.Order("RAND()").First(&ability).Error
 		err = channelQuery.Order("RAND()").First(&ability).Error

+ 6 - 2
model/cache.go

@@ -21,14 +21,18 @@ var (
 )
 )
 
 
 func CacheGetTokenByKey(key string) (*Token, error) {
 func CacheGetTokenByKey(key string) (*Token, error) {
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
 	var token Token
 	var token Token
 	if !common.RedisEnabled {
 	if !common.RedisEnabled {
-		err := DB.Where("`key` = ?", key).First(&token).Error
+		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 		return &token, err
 		return &token, err
 	}
 	}
 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
 	tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
 	if err != nil {
 	if err != nil {
-		err := DB.Where("`key` = ?", key).First(&token).Error
+		err := DB.Where(keyCol+" = ?", key).First(&token).Error
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}

+ 5 - 12
model/channel.go

@@ -40,7 +40,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
 }
 }
 
 
 func SearchChannels(keyword string) (channels []*Channel, err error) {
 func SearchChannels(keyword string) (channels []*Channel, err error) {
-	err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
+	err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
 	return channels, err
 	return channels, err
 }
 }
 
 
@@ -55,17 +59,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 	return &channel, err
 	return &channel, err
 }
 }
 
 
-func GetRandomChannel() (*Channel, error) {
-	channel := Channel{}
-	var err error = nil
-	if common.UsingSQLite {
-		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
-	} else {
-		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
-	}
-	return &channel, err
-}
-
 func BatchInsertChannels(channels []Channel) error {
 func BatchInsertChannels(channels []Channel) error {
 	var err error
 	var err error
 	err = DB.Create(&channels).Error
 	err = DB.Create(&channels).Error

+ 1 - 0
model/main.go

@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
 		if strings.HasPrefix(dsn, "postgres://") {
 		if strings.HasPrefix(dsn, "postgres://") {
 			// Use PostgreSQL
 			// Use PostgreSQL
 			common.SysLog("using PostgreSQL as database")
 			common.SysLog("using PostgreSQL as database")
+			common.UsingPostgreSQL = true
 			return gorm.Open(postgres.New(postgres.Config{
 			return gorm.Open(postgres.New(postgres.Config{
 				DSN:                  dsn,
 				DSN:                  dsn,
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage
 				PreferSimpleProtocol: true, // disables implicit prepared statement usage

+ 6 - 1
model/redemption.go

@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
 	}
 	}
 	redemption := &Redemption{}
 	redemption := &Redemption{}
 
 
+	keyCol := "`key`"
+	if common.UsingPostgreSQL {
+		keyCol = `"key"`
+	}
+
 	err = DB.Transaction(func(tx *gorm.DB) error {
 	err = DB.Transaction(func(tx *gorm.DB) error {
-		err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
+		err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
 		if err != nil {
 		if err != nil {
 			return errors.New("无效的兑换码")
 			return errors.New("无效的兑换码")
 		}
 		}

+ 6 - 1
model/user.go

@@ -269,7 +269,12 @@ func GetUserEmail(id int) (email string, err error) {
 }
 }
 
 
 func GetUserGroup(id int) (group string, err error) {
 func GetUserGroup(id int) (group string, err error) {
-	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
+	groupCol := "`group`"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+	}
+
+	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
 	return group, err
 	return group, err
 }
 }
 
 

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -70,7 +70,7 @@ const EditChannel = () => {
           localModels = ['PaLM-2'];
           localModels = ['PaLM-2'];
           break;
           break;
         case 15:
         case 15:
-          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
+          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
           break;
           break;
         case 17:
         case 17:
           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
           localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];