Ver código fonte

refactor: enhance error handling and masking for model not found scenarios

CaIon 6 meses atrás
pai
commit
44e9b02b3f
4 arquivos alterados com 49 adições e 5 exclusões
  1. 29 1
      common/str.go
  2. 2 2
      middleware/distributor.go
  3. 6 1
      middleware/utils.go
  4. 12 1
      types/error.go

+ 29 - 1
common/str.go

@@ -99,12 +99,15 @@ func GetJsonString(data any) string {
 	return string(b)
 }
 
-// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
+// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
 // Example:
 // http://example.com -> http://***.com
 // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
 // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
 // 192.168.1.1 -> ***.***.***.***
+// openai.com -> ***.com
+// www.openai.com -> ***.***.com
+// api.openai.com -> ***.***.com
 func MaskSensitiveInfo(str string) string {
 	// Mask URLs
 	urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
@@ -184,6 +187,31 @@ func MaskSensitiveInfo(str string) string {
 		return result
 	})
 
+	// Mask domain names without protocol (like openai.com, www.openai.com)
+	domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
+	str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
+		// Skip if it's already been processed as part of a URL
+		if strings.Contains(str, "://"+domain) {
+			return domain
+		}
+
+		parts := strings.Split(domain, ".")
+		if len(parts) < 2 {
+			return domain
+		}
+
+		// Handle different domain patterns
+		if len(parts) == 2 {
+			// openai.com -> ***.com
+			return "***." + parts[1]
+		} else {
+			// www.openai.com -> ***.***.com
+			// api.openai.com -> ***.***.com
+			lastPart := parts[len(parts)-1]
+			return "***.***." + lastPart
+		}
+	})
+
 	// Mask IP addresses
 	ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
 	str = ipPattern.ReplaceAllString(str, "***.***.***.***")

+ 2 - 2
middleware/distributor.go

@@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) {
 					//	common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 					//	message = "数据库一致性已被破坏,请联系管理员"
 					//}
-					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
 					return
 				}
 				if channel == nil {
-					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model))
+					abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
 					return
 				}
 			}

+ 6 - 1
middleware/utils.go

@@ -7,12 +7,17 @@ import (
 	"one-api/logger"
 )
 
-func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
+	codeStr := ""
+	if len(code) > 0 {
+		codeStr = code[0]
+	}
 	userId := c.GetInt("id")
 	c.JSON(statusCode, gin.H{
 		"error": gin.H{
 			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
 			"type":    "new_api_error",
+			"code":    codeStr,
 		},
 	})
 	c.Abort()

+ 12 - 1
types/error.go

@@ -67,6 +67,7 @@ const (
 	ErrorCodeBadResponseBody        ErrorCode = "bad_response_body"
 	ErrorCodeEmptyResponse          ErrorCode = "empty_response"
 	ErrorCodeAwsInvokeError         ErrorCode = "aws_invoke_error"
+	ErrorCodeModelNotFound          ErrorCode = "model_not_found"
 
 	// sql error
 	ErrorCodeQueryDataError  ErrorCode = "query_data_error"
@@ -119,7 +120,17 @@ func (e *NewAPIError) MaskSensitiveError() string {
 	if e.Err == nil {
 		return string(e.errorCode)
 	}
-	return common.MaskSensitiveInfo(e.Err.Error())
+	errStr := e.Err.Error()
+	if e.StatusCode == http.StatusServiceUnavailable {
+		if e.errorCode == ErrorCodeModelNotFound {
+			errStr = "上游分组模型服务不可用,请稍后再试"
+		} else {
+			if strings.Contains(errStr, "分组") || strings.Contains(errStr, "渠道") {
+				errStr = "上游分组模型服务不可用,请稍后再试"
+			}
+		}
+	}
+	return common.MaskSensitiveInfo(errStr)
 }
 
 func (e *NewAPIError) SetMessage(message string) {