CaIon 1 год назад
Родитель
Сommit
6b71db7ce2
5 измененных файлов с 70 добавлено и 1 удалено
  1. 1 0
      middleware/distributor.go
  2. 8 0
      model/channel.go
  3. 7 1
      relay/relay-text.go
  4. 19 0
      service/error.go
  5. 35 0
      web/src/pages/Channel/EditChannel.js

+ 1 - 0
middleware/distributor.go

@@ -177,6 +177,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 	}
 	c.Set("auto_ban", ban)
 	c.Set("model_mapping", channel.GetModelMapping())
+	c.Set("status_code_mapping", channel.GetStatusCodeMapping())
 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 	c.Set("base_url", channel.GetBaseURL())
 	// TODO: api_version统一

+ 8 - 0
model/channel.go

@@ -25,6 +25,7 @@ type Channel struct {
 	Group              string  `json:"group" gorm:"type:varchar(64);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+	StatusCodeMapping  *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
 	AutoBan            *int    `json:"auto_ban" gorm:"default:1"`
 }
@@ -153,6 +154,13 @@ func (channel *Channel) GetModelMapping() string {
 	return *channel.ModelMapping
 }
 
+func (channel *Channel) GetStatusCodeMapping() string {
+	if channel.StatusCodeMapping == nil {
+		return ""
+	}
+	return *channel.StatusCodeMapping
+}
+
 func (channel *Channel) Insert() error {
 	var err error
 	err = DB.Create(channel).Error

+ 7 - 1
relay/relay-text.go

@@ -154,6 +154,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		requestBody = bytes.NewBuffer(jsonData)
 	}
 
+	statusCodeMappingStr := c.GetString("status_code_mapping")
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -162,12 +163,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 
 	if resp.StatusCode != http.StatusOK {
 		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
-		return service.RelayErrorHandler(resp)
+		openaiErr := service.RelayErrorHandler(resp)
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
 	}
 
 	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 	if openaiErr != nil {
 		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
 	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)

+ 19 - 0
service/error.go

@@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
 	}
 	return
 }
+
+func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
+	if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
+		return
+	}
+	statusCodeMapping := make(map[string]string)
+	err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
+	if err != nil {
+		return
+	}
+	if openaiErr.StatusCode == http.StatusOK {
+		return
+	}
+	codeStr := strconv.Itoa(openaiErr.StatusCode)
+	if _, ok := statusCodeMapping[codeStr]; ok {
+		intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
+		openaiErr.StatusCode = intCode
+	}
+}

+ 35 - 0
web/src/pages/Channel/EditChannel.js

@@ -29,6 +29,10 @@ const MODEL_MAPPING_EXAMPLE = {
   'gpt-4-32k-0314': 'gpt-4-32k',
 };
 
+const STATUS_CODE_MAPPING_EXAMPLE = {
+  400: '500',
+};
+
 function type2secretPrompt(type) {
   // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
   switch (type) {
@@ -61,6 +65,7 @@ const EditChannel = (props) => {
     base_url: '',
     other: '',
     model_mapping: '',
+    status_code_mapping: '',
     models: [],
     auto_ban: 1,
     test_model: '',
@@ -629,6 +634,36 @@ const EditChannel = (props) => {
           >
             填入模板
           </Typography.Text>
+          <div style={{ marginTop: 10 }}>
+            <Typography.Text strong>
+              状态码复写(仅影响本地判断,不修改返回到上游的状态码):
+            </Typography.Text>
+          </div>
+          <TextArea
+            placeholder={`此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:\n${JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)}`}
+            name='status_code_mapping'
+            onChange={(value) => {
+              handleInputChange('status_code_mapping', value);
+            }}
+            autosize
+            value={inputs.status_code_mapping}
+            autoComplete='new-password'
+          />
+          <Typography.Text
+            style={{
+              color: 'rgba(var(--semi-blue-5), 1)',
+              userSelect: 'none',
+              cursor: 'pointer',
+            }}
+            onClick={() => {
+              handleInputChange(
+                'status_code_mapping',
+                JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2),
+              );
+            }}
+          >
+            填入模板
+          </Typography.Text>
           <div style={{ marginTop: 10 }}>
             <Typography.Text strong>密钥:</Typography.Text>
           </div>