
feat: improve function call billing

CaIon committed 1 year ago
commit 2841669246

dto/text_request.go (+11 -0)

@@ -32,6 +32,17 @@ type GeneralOpenAIRequest struct {
 	TopLogProbs      int             `json:"top_logprobs,omitempty"`
 }
 
+type OpenAITools struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+type OpenAIFunction struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil

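A minimal decoding sketch for the new dto types, assuming an OpenAI-style function-calling payload; the payload and package scaffolding below are fabricated for illustration:

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the new dto.OpenAITools / dto.OpenAIFunction types.
type OpenAITools struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}

type OpenAIFunction struct {
	Description string `json:"description,omitempty"`
	Name        string `json:"name"`
	Parameters  any    `json:"parameters,omitempty"`
}

func main() {
	// A typical "tools" array as sent in a chat completion request.
	payload := `[{"type":"function","function":{"name":"get_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"city":{"type":"string"}}}}}]`

	var tools []OpenAITools
	if err := json.Unmarshal([]byte(payload), &tools); err != nil {
		panic(err)
	}
	fmt.Println(tools[0].Function.Name) // get_weather
}
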
relay/channel/ollama/adaptor.go (+1 -1)

@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {

relay/channel/openai/adaptor.go (+3 -1)

@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7 // flat 7-token surcharge per streamed tool call
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

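In the stream path the upstream usage block is unavailable, so completion tokens are re-counted from the streamed text and each tool call adds a flat surcharge of 7 tokens. A sketch of the resulting arithmetic, with made-up figures:

package main

import "fmt"

func main() {
	// Fabricated figures: 120 tokens counted from the streamed text,
	// two tool calls observed in the widest delta of the stream.
	responseTokens := 120
	toolCount := 2

	// Mirrors usage.CompletionTokens += toolCount * 7 in DoResponse.
	completionTokens := responseTokens + toolCount*7
	fmt.Println(completionTokens) // 134
}
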
relay/channel/openai/relay-openai.go (+10 -3)

@@ -16,9 +16,10 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -69,6 +70,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 						for _, choice := range streamResponse.Choices {
 							responseTextBuilder.WriteString(choice.Delta.Content)
 							if choice.Delta.ToolCalls != nil {
+								if len(choice.Delta.ToolCalls) > toolCount {
+									toolCount = len(choice.Delta.ToolCalls)
+								}
 								for _, tool := range choice.Delta.ToolCalls {
 									responseTextBuilder.WriteString(tool.Function.Name)
 									responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -82,6 +86,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.Content)
 						if choice.Delta.ToolCalls != nil {
+							if len(choice.Delta.ToolCalls) > toolCount {
+								toolCount = len(choice.Delta.ToolCalls)
+							}
 							for _, tool := range choice.Delta.ToolCalls {
 								responseTextBuilder.WriteString(tool.Function.Name)
 								responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -135,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String()
+	return nil, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

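OpenaiStreamHandler takes toolCount as the largest tool_calls slice seen in any single delta, since later chunks typically only append argument fragments to calls opened earlier. A self-contained sketch over two fabricated stream deltas:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed-down delta shapes, just enough to parse tool_calls
// (field names follow the OpenAI streaming format).
type toolCall struct {
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type delta struct {
	Content   string     `json:"content"`
	ToolCalls []toolCall `json:"tool_calls"`
}

func main() {
	// Two fabricated stream deltas: the first opens one tool call,
	// the second only appends an argument fragment to it.
	chunks := []string{
		`{"content":"","tool_calls":[{"function":{"name":"get_weather","arguments":""}}]}`,
		`{"content":"","tool_calls":[{"function":{"name":"","arguments":"{\"city\":\"Paris\"}"}}]}`,
	}

	toolCount := 0
	for _, raw := range chunks {
		var d delta
		if err := json.Unmarshal([]byte(raw), &d); err != nil {
			panic(err)
		}
		// Same rule as OpenaiStreamHandler: keep the widest tool_calls
		// slice seen in any single delta.
		if len(d.ToolCalls) > toolCount {
			toolCount = len(d.ToolCalls)
		}
	}
	fmt.Println(toolCount) // 1
}
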
relay/channel/perplexity/adaptor.go (+1 -1)

@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

relay/channel/zhipu_4v/adaptor.go (+3 -1)

@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7 // flat 7-token surcharge per streamed tool call
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

relay/relay-text.go (+1 -1)

@@ -189,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeCompletions:
 		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeModerations:

service/token_counter.go (+35 -0)

@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
+func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
+	tkm := 0
+	msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
+	if err != nil {
+		return 0, err, b
+	}
+	tkm += msgTokens
+	if request.Tools != nil {
+		toolsData, _ := json.Marshal(request.Tools)
+		var openaiTools []dto.OpenAITools
+		err := json.Unmarshal(toolsData, &openaiTools)
+		if err != nil {
+			return 0, fmt.Errorf("count tools token fail: %s", err.Error()), false
+		}
+		countStr := ""
+		for _, tool := range openaiTools {
+			countStr += tool.Function.Name
+			if tool.Function.Description != "" {
+				countStr += tool.Function.Description
+			}
+			if tool.Function.Parameters != nil {
+				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+			}
+		}
+		toolTokens, err, _ := CountTokenInput(countStr, model, false)
+		if err != nil {
+			return 0, err, false
+		}
+		tkm += 8 // fixed 8-token overhead for the tools block
+		tkm += toolTokens
+	}
+
+	return tkm, nil, false
+}
+
 func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
	// recover on panic
 	tokenEncoder := getTokenEncoder(model)
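
CountTokenChatRequest above bills a tools-bearing chat request as message tokens, plus the tokens of every tool's name, description, and stringified parameters, plus a fixed 8-token overhead. A sketch of that recipe; countTokens is a whitespace-splitting stand-in for service.CountTokenInput, and the tool definitions are fabricated:

package main

import (
	"fmt"
	"strings"
)

// Stand-in for service.CountTokenInput: a real counter would run the
// model's tokenizer; whitespace splitting is only for illustration.
func countTokens(s string) int {
	return len(strings.Fields(s))
}

type function struct {
	Name, Description string
	Parameters        any
}

func main() {
	tools := []function{
		{Name: "get_weather", Description: "Get the current weather",
			Parameters: map[string]any{"type": "object"}},
		{Name: "get_time", Description: "Get the current time"},
	}

	// Build the same string CountTokenChatRequest counts:
	// name + description + fmt.Sprintf("%v", parameters) per tool.
	var countStr strings.Builder
	for _, f := range tools {
		countStr.WriteString(f.Name)
		if f.Description != "" {
			countStr.WriteString(f.Description)
		}
		if f.Parameters != nil {
			countStr.WriteString(fmt.Sprintf("%v", f.Parameters))
		}
	}

	toolTokens := countTokens(countStr.String())
	fmt.Println(toolTokens + 8) // tool tokens plus the fixed 8-token overhead
}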