|
|
@@ -5,13 +5,13 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/model"
|
|
|
"strings"
|
|
|
-
|
|
|
- "github.com/gin-gonic/gin"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -20,12 +20,18 @@ const (
|
|
|
APITypePaLM
|
|
|
APITypeBaidu
|
|
|
APITypeZhipu
|
|
|
+ APITypeAli
|
|
|
+ APITypeXunfei
|
|
|
)
|
|
|
|
|
|
// httpClient is the package-level HTTP client; init assigns it a zero-value
// http.Client, i.e. one with no Timeout configured.
var httpClient *http.Client
|
|
|
// impatientHTTPClient is a second package-level HTTP client; init assigns it
// a client whose Timeout is 5 seconds.
+var impatientHTTPClient *http.Client
|
|
|
|
|
|
// init creates the two package-level HTTP clients. httpClient is a zero-value
// http.Client, so no overall request timeout is set on it; impatientHTTPClient
// is configured to give up after 5 seconds.
func init() {
|
|
|
httpClient = &http.Client{}
|
|
|
+ impatientHTTPClient = &http.Client{
|
|
|
+ Timeout: 5 * time.Second,
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
@@ -73,7 +79,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
// map model name
|
|
|
modelMapping := c.GetString("model_mapping")
|
|
|
isModelMapped := false
|
|
|
- if modelMapping != "" {
|
|
|
+ if modelMapping != "" && modelMapping != "{}" {
|
|
|
modelMap := make(map[string]string)
|
|
|
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
|
if err != nil {
|
|
|
@@ -94,34 +100,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
apiType = APITypePaLM
|
|
|
case common.ChannelTypeZhipu:
|
|
|
apiType = APITypeZhipu
|
|
|
+ case common.ChannelTypeAli:
|
|
|
+ apiType = APITypeAli
|
|
|
+ case common.ChannelTypeXunfei:
|
|
|
+ apiType = APITypeXunfei
|
|
|
}
|
|
|
isStable := c.GetBool("stable")
|
|
|
|
|
|
- //if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") {
|
|
|
- // nowUser, err := model.GetUserById(userId, false)
|
|
|
- // if err != nil {
|
|
|
- // return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError)
|
|
|
- // }
|
|
|
- // if nowUser.StableMode {
|
|
|
- // group = "svip"
|
|
|
- // isStable = true
|
|
|
- // ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
|
|
|
- // //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
|
|
|
- // //if userMaxPrice < common.StablePrice {
|
|
|
- // // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError)
|
|
|
- // //}
|
|
|
- // //
|
|
|
- // ////ratio = stableRatio * groupRatio
|
|
|
- // //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model)
|
|
|
- // //if err != nil {
|
|
|
- // // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model)
|
|
|
- // // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError)
|
|
|
- // //}
|
|
|
- // //channelType = channel.Type
|
|
|
- // } else {
|
|
|
- // return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError)
|
|
|
- // }
|
|
|
- //}
|
|
|
baseURL := common.ChannelBaseURLs[channelType]
|
|
|
requestURL := c.Request.URL.String()
|
|
|
if c.GetString("base_url") != "" {
|
|
|
@@ -162,10 +147,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
|
|
case "BLOOMZ-7B":
|
|
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
|
|
+ case "Embedding-V1":
|
|
|
+ fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
|
|
}
|
|
|
apiKey := c.Request.Header.Get("Authorization")
|
|
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
|
- fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
|
|
|
+ var err error
|
|
|
+ if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
|
|
+ return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ fullRequestURL += "?access_token=" + apiKey
|
|
|
case APITypePaLM:
|
|
|
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
|
|
if baseURL != "" {
|
|
|
@@ -180,6 +171,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
method = "sse-invoke"
|
|
|
}
|
|
|
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
|
|
+ case APITypeAli:
|
|
|
+ fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
|
|
}
|
|
|
var promptTokens int
|
|
|
var completionTokens int
|
|
|
@@ -195,7 +188,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if textRequest.MaxTokens != 0 {
|
|
|
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
|
|
}
|
|
|
- //stableRatio := common.GetStableRatio(textRequest.Model)
|
|
|
modelRatio := common.GetModelRatio(textRequest.Model)
|
|
|
stableRatio := modelRatio
|
|
|
groupRatio := common.GetGroupRatio(group)
|
|
|
@@ -209,7 +201,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if err != nil {
|
|
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
- if userQuota > 10*preConsumedQuota {
|
|
|
+ err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ if userQuota > 100*preConsumedQuota {
|
|
|
// in this case, we do not pre-consume quota
|
|
|
// because the user has enough quota
|
|
|
preConsumedQuota = 0
|
|
|
@@ -239,12 +235,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
}
|
|
|
requestBody = bytes.NewBuffer(jsonStr)
|
|
|
case APITypeBaidu:
|
|
|
- baiduRequest := requestOpenAI2Baidu(textRequest)
|
|
|
- jsonStr, err := json.Marshal(baiduRequest)
|
|
|
+ var jsonData []byte
|
|
|
+ var err error
|
|
|
+ switch relayMode {
|
|
|
+ case RelayModeEmbeddings:
|
|
|
+ baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
|
|
+ jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
|
|
+ default:
|
|
|
+ baiduRequest := requestOpenAI2Baidu(textRequest)
|
|
|
+ jsonData, err = json.Marshal(baiduRequest)
|
|
|
+ }
|
|
|
if err != nil {
|
|
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
- requestBody = bytes.NewBuffer(jsonStr)
|
|
|
+ requestBody = bytes.NewBuffer(jsonData)
|
|
|
case APITypePaLM:
|
|
|
palmRequest := requestOpenAI2PaLM(textRequest)
|
|
|
jsonStr, err := json.Marshal(palmRequest)
|
|
|
@@ -259,109 +263,114 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
requestBody = bytes.NewBuffer(jsonStr)
|
|
|
+ case APITypeAli:
|
|
|
+ aliRequest := requestOpenAI2Ali(textRequest)
|
|
|
+ jsonStr, err := json.Marshal(aliRequest)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ requestBody = bytes.NewBuffer(jsonStr)
|
|
|
}
|
|
|
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
|
- if err != nil {
|
|
|
- return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
|
- }
|
|
|
- apiKey := c.Request.Header.Get("Authorization")
|
|
|
- apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
|
- switch apiType {
|
|
|
- case APITypeOpenAI:
|
|
|
- if channelType == common.ChannelTypeAzure {
|
|
|
- req.Header.Set("api-key", apiKey)
|
|
|
- } else {
|
|
|
- req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
|
+
|
|
|
+ var req *http.Request
|
|
|
+ var resp *http.Response
|
|
|
+ isStream := textRequest.Stream
|
|
|
+
|
|
|
+ if apiType != APITypeXunfei { // cause xunfei use websocket
|
|
|
+ req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
- case APITypeClaude:
|
|
|
- req.Header.Set("x-api-key", apiKey)
|
|
|
- anthropicVersion := c.Request.Header.Get("anthropic-version")
|
|
|
- if anthropicVersion == "" {
|
|
|
- anthropicVersion = "2023-06-01"
|
|
|
+ apiKey := c.Request.Header.Get("Authorization")
|
|
|
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
|
|
+ switch apiType {
|
|
|
+ case APITypeOpenAI:
|
|
|
+ if channelType == common.ChannelTypeAzure {
|
|
|
+ req.Header.Set("api-key", apiKey)
|
|
|
+ } else {
|
|
|
+ req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
|
|
+ }
|
|
|
+ case APITypeClaude:
|
|
|
+ req.Header.Set("x-api-key", apiKey)
|
|
|
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
|
|
|
+ if anthropicVersion == "" {
|
|
|
+ anthropicVersion = "2023-06-01"
|
|
|
+ }
|
|
|
+ req.Header.Set("anthropic-version", anthropicVersion)
|
|
|
+ case APITypeZhipu:
|
|
|
+ token := getZhipuToken(apiKey)
|
|
|
+ req.Header.Set("Authorization", token)
|
|
|
+ case APITypeAli:
|
|
|
+ req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
+ if textRequest.Stream {
|
|
|
+ req.Header.Set("X-DashScope-SSE", "enable")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
|
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
|
+ //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
|
|
+ resp, err = httpClient.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ err = req.Body.Close()
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ err = c.Request.Body.Close()
|
|
|
+ if err != nil {
|
|
|
+ return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
|
+ }
|
|
|
+ isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
|
+
|
|
|
+ if resp.StatusCode != http.StatusOK {
|
|
|
+ return errorWrapper(
|
|
|
+ fmt.Errorf("bad status code: %d", resp.StatusCode), "bad_status_code", resp.StatusCode)
|
|
|
}
|
|
|
- req.Header.Set("anthropic-version", anthropicVersion)
|
|
|
- case APITypeZhipu:
|
|
|
- token := getZhipuToken(apiKey)
|
|
|
- req.Header.Set("Authorization", token)
|
|
|
- }
|
|
|
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
|
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
|
- //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
|
|
- resp, err := httpClient.Do(req)
|
|
|
- if err != nil {
|
|
|
- return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
|
- }
|
|
|
- err = req.Body.Close()
|
|
|
- if err != nil {
|
|
|
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
|
- }
|
|
|
- err = c.Request.Body.Close()
|
|
|
- if err != nil {
|
|
|
- return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
+
|
|
|
var textResponse TextResponse
|
|
|
- isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
|
- var streamResponseText string
|
|
|
+ tokenName := c.GetString("token_name")
|
|
|
+ channelId := c.GetInt("channel_id")
|
|
|
|
|
|
defer func() {
|
|
|
- if consumeQuota {
|
|
|
- quota := 0
|
|
|
- completionRatio := 1.0
|
|
|
- if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
|
|
|
- completionRatio = 1.333333
|
|
|
- }
|
|
|
- if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
|
|
- completionRatio = 2
|
|
|
- }
|
|
|
- if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
|
|
|
- completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
|
|
- } else {
|
|
|
+ // c.Writer.Flush()
|
|
|
+ go func() {
|
|
|
+ if consumeQuota {
|
|
|
+ quota := 0
|
|
|
+ completionRatio := common.GetCompletionRatio(textRequest.Model)
|
|
|
promptTokens = textResponse.Usage.PromptTokens
|
|
|
completionTokens = textResponse.Usage.CompletionTokens
|
|
|
- if apiType == APITypeZhipu {
|
|
|
- // zhipu's API does not return prompt tokens & completion tokens
|
|
|
- promptTokens = textResponse.Usage.TotalTokens
|
|
|
+
|
|
|
+ quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
|
|
+ quota = int(float64(quota) * ratio)
|
|
|
+ if ratio != 0 && quota <= 0 {
|
|
|
+ quota = 1
|
|
|
}
|
|
|
- }
|
|
|
- quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
|
|
- quota = int(float64(quota) * ratio)
|
|
|
- if ratio != 0 && quota <= 0 {
|
|
|
- quota = 1
|
|
|
- }
|
|
|
- totalTokens := promptTokens + completionTokens
|
|
|
- if totalTokens == 0 {
|
|
|
- // in this case, must be some error happened
|
|
|
- // we cannot just return, because we may have to return the pre-consumed quota
|
|
|
- quota = 0
|
|
|
- }
|
|
|
- //if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
|
|
- // if quota < 5000 && quota != 0 {
|
|
|
- // quota = 5000
|
|
|
- // }
|
|
|
- //}
|
|
|
- quotaDelta := quota - preConsumedQuota
|
|
|
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
- }
|
|
|
- err = model.CacheUpdateUserQuota(userId)
|
|
|
- if err != nil {
|
|
|
- common.SysError("error update user quota cache: " + err.Error())
|
|
|
- }
|
|
|
- if quota != 0 {
|
|
|
- tokenName := c.GetString("token_name")
|
|
|
- var logContent string
|
|
|
- if isStable {
|
|
|
- logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio)
|
|
|
- } else {
|
|
|
- logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
|
+ totalTokens := promptTokens + completionTokens
|
|
|
+ if totalTokens == 0 {
|
|
|
+ // in this case, must be some error happened
|
|
|
+ // we cannot just return, because we may have to return the pre-consumed quota
|
|
|
+ quota = 0
|
|
|
+ }
|
|
|
+ quotaDelta := quota - preConsumedQuota
|
|
|
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error consuming token remain quota: " + err.Error())
|
|
|
+ }
|
|
|
+ err = model.CacheUpdateUserQuota(userId)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error update user quota cache: " + err.Error())
|
|
|
+ }
|
|
|
+ if quota != 0 {
|
|
|
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
|
+ model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
|
+ model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
+
|
|
|
+ model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
}
|
|
|
- model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
|
|
|
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
- channelId := c.GetInt("channel_id")
|
|
|
- model.UpdateChannelUsedQuota(channelId, quota)
|
|
|
}
|
|
|
- }
|
|
|
+ }()
|
|
|
}()
|
|
|
switch apiType {
|
|
|
case APITypeOpenAI:
|
|
|
@@ -370,10 +379,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- streamResponseText = responseText
|
|
|
+ textResponse.Usage.PromptTokens = promptTokens
|
|
|
+ textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
|
return nil
|
|
|
} else {
|
|
|
- err, usage := openaiHandler(c, resp, consumeQuota)
|
|
|
+ err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -388,7 +398,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- streamResponseText = responseText
|
|
|
+ textResponse.Usage.PromptTokens = promptTokens
|
|
|
+ textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
|
return nil
|
|
|
} else {
|
|
|
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
|
|
@@ -411,7 +422,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
}
|
|
|
return nil
|
|
|
} else {
|
|
|
- err, usage := baiduHandler(c, resp)
|
|
|
+ var err *OpenAIErrorWithStatusCode
|
|
|
+ var usage *Usage
|
|
|
+ switch relayMode {
|
|
|
+ case RelayModeEmbeddings:
|
|
|
+ err, usage = baiduEmbeddingHandler(c, resp)
|
|
|
+ default:
|
|
|
+ err, usage = baiduHandler(c, resp)
|
|
|
+ }
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -426,7 +444,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- streamResponseText = responseText
|
|
|
+ textResponse.Usage.PromptTokens = promptTokens
|
|
|
+ textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
|
|
return nil
|
|
|
} else {
|
|
|
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
|
|
@@ -447,6 +466,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if usage != nil {
|
|
|
textResponse.Usage = *usage
|
|
|
}
|
|
|
+ // zhipu's API does not return prompt tokens & completion tokens
|
|
|
+ textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
|
return nil
|
|
|
} else {
|
|
|
err, usage := zhipuHandler(c, resp)
|
|
|
@@ -456,8 +477,49 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
if usage != nil {
|
|
|
textResponse.Usage = *usage
|
|
|
}
|
|
|
+ // zhipu's API does not return prompt tokens & completion tokens
|
|
|
+ textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
|
|
return nil
|
|
|
}
|
|
|
+ case APITypeAli:
|
|
|
+ if isStream {
|
|
|
+ err, usage := aliStreamHandler(c, resp)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if usage != nil {
|
|
|
+ textResponse.Usage = *usage
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ } else {
|
|
|
+ err, usage := aliHandler(c, resp)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if usage != nil {
|
|
|
+ textResponse.Usage = *usage
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ case APITypeXunfei:
|
|
|
+ if isStream {
|
|
|
+ auth := c.Request.Header.Get("Authorization")
|
|
|
+ auth = strings.TrimPrefix(auth, "Bearer ")
|
|
|
+ splits := strings.Split(auth, "|")
|
|
|
+ if len(splits) != 3 {
|
|
|
+ return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
|
|
+ }
|
|
|
+ err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if usage != nil {
|
|
|
+ textResponse.Usage = *usage
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ } else {
|
|
|
+ return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
|
|
+ }
|
|
|
default:
|
|
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
|
|
}
|