Selaa lähdekoodia

Merge pull request #2550 from shikaiwei1/patch-2

Seefs 2 kuukautta sitten
vanhempi
commit
1519e97bc6
1 muutettua tiedostoa jossa 45 lisäystä ja 2 poistoa
  1. 45 2
      relay/channel/openai/relay-openai.go

+ 45 - 2
relay/channel/openai/relay-openai.go

@@ -186,7 +186,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 		usage.CompletionTokens += toolCount * 7
 	}
 
-	applyUsagePostProcessing(info, usage, nil)
+	applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData))
 
 	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
 
@@ -596,7 +596,8 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
 		if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
 			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
 		}
-	case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot:
+	case constant.ChannelTypeZhipu_v4:
+		// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
 		if usage.PromptTokensDetails.CachedTokens == 0 {
 			if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
 				usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
@@ -606,6 +607,19 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
 				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
 			}
 		}
+	case constant.ChannelTypeMoonshot:
+		// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
+		if usage.PromptTokensDetails.CachedTokens == 0 {
+			if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
+			} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
+				usage.PromptTokensDetails.CachedTokens = cachedTokens
+			} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
+				usage.PromptTokensDetails.CachedTokens = cachedTokens
+			} else if usage.PromptCacheHitTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+			}
+		}
 	}
 }
 
@@ -639,3 +653,32 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
 	}
 	return 0, false
 }
+
+// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
+// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
+func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
+	if len(body) == 0 {
+		return 0, false
+	}
+
+	var payload struct {
+		Choices []struct {
+			Usage struct {
+				CachedTokens *int `json:"cached_tokens"`
+			} `json:"usage"`
+		} `json:"choices"`
+	}
+
+	if err := common.Unmarshal(body, &payload); err != nil {
+		return 0, false
+	}
+
+	// 遍历choices查找cached_tokens
+	for _, choice := range payload.Choices {
+		if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
+			return *choice.Usage.CachedTokens, true
+		}
+	}
+
+	return 0, false
+}