Jelajahi Sumber

feat: 错误内容脱敏

CaIon 7 bulan lalu
induk
melakukan
f7b284ad73
3 mengubah file dengan 119 tambahan dan 8 penghapusan
  1. 95 0
      common/str.go
  2. 1 2
      controller/relay.go
  3. 23 6
      types/error.go

+ 95 - 0
common/str.go

@@ -4,7 +4,10 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"math/rand"
+	"net/url"
+	"regexp"
 	"strconv"
+	"strings"
 	"unsafe"
 )
 
@@ -95,3 +98,95 @@ func GetJsonString(data any) string {
 	b, _ := json.Marshal(data)
 	return string(b)
 }
+
+// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
+// Example:
+// http://example.com -> http://***.com
+// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
+// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
+// 192.168.1.1 -> ***.***.***.***
+func MaskSensitiveInfo(str string) string {
+	// Mask URLs
+	urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
+	str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
+		u, err := url.Parse(urlStr)
+		if err != nil {
+			return urlStr
+		}
+
+		host := u.Host
+		if host == "" {
+			return urlStr
+		}
+
+		// Split host by dots
+		parts := strings.Split(host, ".")
+		if len(parts) < 2 {
+			// If less than 2 parts, just mask the whole host
+			return u.Scheme + "://***" + u.Path
+		}
+
+		// Keep the TLD (Top Level Domain) and mask the rest
+		var maskedHost string
+		if len(parts) == 2 {
+			// example.com -> ***.com
+			maskedHost = "***." + parts[len(parts)-1]
+		} else {
+			// Handle cases like sub.domain.co.uk or api.example.com
+			// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
+			lastPart := parts[len(parts)-1]
+			secondLastPart := parts[len(parts)-2]
+
+			if len(lastPart) == 2 && len(secondLastPart) <= 3 {
+				// Likely country code TLD like co.uk, com.cn
+				maskedHost = "***." + secondLastPart + "." + lastPart
+			} else {
+				// Regular TLD like .com, .org
+				maskedHost = "***." + lastPart
+			}
+		}
+
+		result := u.Scheme + "://" + maskedHost
+
+		// Mask path
+		if u.Path != "" && u.Path != "/" {
+			pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
+			maskedPathParts := make([]string, len(pathParts))
+			for i := range pathParts {
+				if pathParts[i] != "" {
+					maskedPathParts[i] = "***"
+				}
+			}
+			if len(maskedPathParts) > 0 {
+				result += "/" + strings.Join(maskedPathParts, "/")
+			}
+		} else if u.Path == "/" {
+			result += "/"
+		}
+
+		// Mask query parameters
+		if u.RawQuery != "" {
+			values, err := url.ParseQuery(u.RawQuery)
+			if err != nil {
+				// If can't parse query, just mask the whole query string
+				result += "?***"
+			} else {
+				maskedParams := make([]string, 0, len(values))
+				for key := range values {
+					maskedParams = append(maskedParams, key+"=***")
+				}
+				if len(maskedParams) > 0 {
+					result += "?" + strings.Join(maskedParams, "&")
+				}
+			}
+		}
+
+		return result
+	})
+
+	// Mask IP addresses
+	ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
+	str = ipPattern.ReplaceAllString(str, "***.***.***.***")
+
+	return str
+}

+ 1 - 2
controller/relay.go

@@ -62,8 +62,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 		other["channel_id"] = channelId
 		other["channel_name"] = c.GetString("channel_name")
 		other["channel_type"] = c.GetInt("channel_type")
-
-		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
+		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
 	}
 
 	return err

+ 23 - 6
types/error.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"one-api/common"
 	"strings"
 )
 
@@ -107,19 +108,30 @@ func (e *NewAPIError) Error() string {
 	return e.Err.Error()
 }
 
+func (e *NewAPIError) MaskSensitiveError() string {
+	if e == nil {
+		return ""
+	}
+	if e.Err == nil {
+		return string(e.errorCode)
+	}
+	return common.MaskSensitiveInfo(e.Err.Error())
+}
+
 func (e *NewAPIError) SetMessage(message string) {
 	e.Err = errors.New(message)
 }
 
 func (e *NewAPIError) ToOpenAIError() OpenAIError {
+	var result OpenAIError
 	switch e.errorType {
 	case ErrorTypeOpenAIError:
 		if openAIError, ok := e.RelayError.(OpenAIError); ok {
-			return openAIError
+			result = openAIError
 		}
 	case ErrorTypeClaudeError:
 		if claudeError, ok := e.RelayError.(ClaudeError); ok {
-			return OpenAIError{
+			result = OpenAIError{
 				Message: e.Error(),
 				Type:    claudeError.Type,
 				Param:   "",
@@ -127,30 +139,35 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
 			}
 		}
 	}
-	return OpenAIError{
+	result = OpenAIError{
 		Message: e.Error(),
 		Type:    string(e.errorType),
 		Param:   "",
 		Code:    e.errorCode,
 	}
+	result.Message = common.MaskSensitiveInfo(result.Message)
+	return result
 }
 
 func (e *NewAPIError) ToClaudeError() ClaudeError {
+	var result ClaudeError
 	switch e.errorType {
 	case ErrorTypeOpenAIError:
 		openAIError := e.RelayError.(OpenAIError)
-		return ClaudeError{
+		result = ClaudeError{
 			Message: e.Error(),
 			Type:    fmt.Sprintf("%v", openAIError.Code),
 		}
 	case ErrorTypeClaudeError:
-		return e.RelayError.(ClaudeError)
+		result = e.RelayError.(ClaudeError)
 	default:
-		return ClaudeError{
+		result = ClaudeError{
 			Message: e.Error(),
 			Type:    string(e.errorType),
 		}
 	}
+	result.Message = common.MaskSensitiveInfo(result.Message)
+	return result
 }
 
 func NewError(err error, errorCode ErrorCode) *NewAPIError {