Seefs 1 місяць тому
батько
коміт
fd25b60e7a
1 змінених файлів з 21 додано та 0 видалено
  1. 21 0
      relay/channel/openai/relay-openai.go

+ 21 - 0
relay/channel/openai/relay-openai.go

@@ -1,6 +1,7 @@
 package openai
 
 import (
+	"bytes"
 	"fmt"
 	"io"
 	"net/http"
@@ -22,6 +23,19 @@ import (
 	"github.com/gorilla/websocket"
 )
 
+const xaiCSAMSafetyCheckType = "SAFETY_CHECK_TYPE_CSAM"
+
+func maybeMarkXaiCSAMRefusal(c *gin.Context, info *relaycommon.RelayInfo, responseBody []byte) bool {
+	if c == nil || info == nil || len(responseBody) == 0 {
+		return false
+	}
+	if !bytes.Contains(responseBody, []byte(xaiCSAMSafetyCheckType)) {
+		return false
+	}
+	common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "grok_safety_check_type=csam")
+	return true
+}
+
 func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
 	if data == "" {
 		return nil
@@ -201,6 +215,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
+	isXaiCSAMRefusal := maybeMarkXaiCSAMRefusal(c, info, responseBody)
 	if common.DebugEnabled {
 		println("upstream response body:", string(responseBody))
 	}
@@ -222,10 +237,16 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 
 	err = common.Unmarshal(responseBody, &simpleResponse)
 	if err != nil {
+		if isXaiCSAMRefusal {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError, types.ErrOptionWithSkipRetry())
+		}
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 
 	if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+		if isXaiCSAMRefusal {
+			return nil, types.WithOpenAIError(*oaiError, resp.StatusCode, types.ErrOptionWithSkipRetry())
+		}
 		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
 	}