فهرست منبع

refactor: replace json.Marshal with common.Marshal for consistency and error handling

CaIon 6 ماه پیش
والد
کامیت
621d2b0b6a
2فایلهای تغییر یافته به همراه13 افزوده شده و 8 حذف شده
  1. 11 5
      relay/channel/openai/relay_responses.go
  2. 2 3
      relay/helper/common.go

+ 11 - 5
relay/channel/openai/relay_responses.go

@@ -73,9 +73,15 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 			switch streamResponse.Type {
 			case "response.completed":
 				if streamResponse.Response.Usage != nil {
-					usage.PromptTokens = streamResponse.Response.Usage.InputTokens
-					usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
-					usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+					if streamResponse.Response.Usage.InputTokens != 0 {
+						usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+					}
+					if streamResponse.Response.Usage.OutputTokens != 0 {
+						usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+					}
+					if streamResponse.Response.Usage.TotalTokens != 0 {
+						usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+					}
 					if streamResponse.Response.Usage.InputTokensDetails != nil {
 						usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
 					}
@@ -110,9 +116,9 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 
 	if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
 		usage.PromptTokens = info.PromptTokens
-	} else {
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
 	return usage, nil
 }

+ 2 - 3
relay/helper/common.go

@@ -1,7 +1,6 @@
 package helper
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
@@ -42,7 +41,7 @@ func SetEventStreamHeaders(c *gin.Context) {
 }
 
 func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
-	jsonData, err := json.Marshal(resp)
+	jsonData, err := common.Marshal(resp)
 	if err != nil {
 		common.SysError("error marshalling stream response: " + err.Error())
 	} else {
@@ -104,7 +103,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 }
 
 func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
-	jsonData, err := json.Marshal(object)
+	jsonData, err := common.Marshal(object)
 	if err != nil {
 		return fmt.Errorf("error marshalling object: %w", err)
 	}