Просмотр исходного кода

fix: fix baidu & ali's quota calculation (#444)

* 修复阿里计费问题

* 修复百度计费问题
glzjin 2 лет назад
Родитель
Сommit
dfaa0183b7
2 измененных файлов с 10 добавлено и 6 удалено
  1. 5 3
      controller/relay-ali.go
  2. 5 3
      controller/relay-baidu.go

+ 5 - 3
controller/relay-ali.go

@@ -177,9 +177,11 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
-			usage.PromptTokens += aliResponse.Usage.InputTokens
-			usage.CompletionTokens += aliResponse.Usage.OutputTokens
-			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+			if aliResponse.Usage.OutputTokens != 0 {
+				usage.PromptTokens = aliResponse.Usage.InputTokens
+				usage.CompletionTokens = aliResponse.Usage.OutputTokens
+				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+			}
 			response := streamResponseAli2OpenAI(&aliResponse)
 			response := streamResponseAli2OpenAI(&aliResponse)
 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 			lastResponseText = aliResponse.Output.Text
 			lastResponseText = aliResponse.Output.Text

+ 5 - 3
controller/relay-baidu.go

@@ -215,9 +215,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 				return true
 			}
 			}
-			usage.PromptTokens += baiduResponse.Usage.PromptTokens
-			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
-			usage.TotalTokens += baiduResponse.Usage.TotalTokens
+			if baiduResponse.Usage.TotalTokens != 0 {
+				usage.TotalTokens = baiduResponse.Usage.TotalTokens
+				usage.PromptTokens = baiduResponse.Usage.PromptTokens
+				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+			}
 			response := streamResponseBaidu2OpenAI(&baiduResponse)
 			response := streamResponseBaidu2OpenAI(&baiduResponse)
 			jsonResponse, err := json.Marshal(response)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 			if err != nil {