Преглед изворни кода

feat: support model remap now

JustSong пре 2 година
родитељ
комит
0941e294bf
5 измењених фајлова са 57 додато и 12 уклоњено
  1. 25 1
      controller/relay-text.go
  2. 10 10
      controller/relay.go
  3. 1 0
      middleware/distributor.go
  4. 1 0
      model/channel.go
  5. 20 1
      web/src/pages/Channel/EditChannel.js

+ 25 - 1
controller/relay-text.go

@@ -53,6 +53,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
 		}
 	}
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	isModelMapped := false
+	if modelMapping != "" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[textRequest.Model] != "" {
+			textRequest.Model = modelMap[textRequest.Model]
+			isModelMapped = true
+		}
+	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if c.GetString("base_url") != "" {
@@ -114,7 +128,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 		}
 	}
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+	var requestBody io.Reader
+	if isModelMapped {
+		jsonStr, err := json.Marshal(textRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
+	} else {
+		requestBody = c.Request.Body
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}

+ 10 - 10
controller/relay.go

@@ -27,16 +27,16 @@ const (
 // https://platform.openai.com/docs/api-reference/chat
 
 type GeneralOpenAIRequest struct {
-	Model       string    `json:"model"`
-	Messages    []Message `json:"messages"`
-	Prompt      any       `json:"prompt"`
-	Stream      bool      `json:"stream"`
-	MaxTokens   int       `json:"max_tokens"`
-	Temperature float64   `json:"temperature"`
-	TopP        float64   `json:"top_p"`
-	N           int       `json:"n"`
-	Input       any       `json:"input"`
-	Instruction string    `json:"instruction"`
+	Model       string    `json:"model,omitempty"`
+	Messages    []Message `json:"messages,omitempty"`
+	Prompt      any       `json:"prompt,omitempty"`
+	Stream      bool      `json:"stream,omitempty"`
+	MaxTokens   int       `json:"max_tokens,omitempty"`
+	Temperature float64   `json:"temperature,omitempty"`
+	TopP        float64   `json:"top_p,omitempty"`
+	N           int       `json:"n,omitempty"`
+	Input       any       `json:"input,omitempty"`
+	Instruction string    `json:"instruction,omitempty"`
 }
 
 type ChatRequest struct {

+ 1 - 0
middleware/distributor.go

@@ -88,6 +88,7 @@ func Distribute() func(c *gin.Context) {
 		c.Set("channel", channel.Type)
 		c.Set("channel_id", channel.Id)
 		c.Set("channel_name", channel.Name)
+		c.Set("model_mapping", channel.ModelMapping)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.BaseURL)
 		if channel.Type == common.ChannelTypeAzure {

+ 1 - 0
model/channel.go

@@ -22,6 +22,7 @@ type Channel struct {
 	Models             string  `json:"models"`
 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
+	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

+ 20 - 1
web/src/pages/Channel/EditChannel.js

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
 import { useParams } from 'react-router-dom';
-import { API, showError, showInfo, showSuccess } from '../../helpers';
+import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';
 
 const EditChannel = () => {
@@ -15,6 +15,7 @@ const EditChannel = () => {
     key: '',
     base_url: '',
     other: '',
+    model_mapping:'',
     models: [],
     groups: ['default']
   };
@@ -42,6 +43,9 @@ const EditChannel = () => {
       } else {
         data.groups = data.group.split(',');
       }
+      if (data.model_mapping !== '') {
+        data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
+      }
       setInputs(data);
     } else {
       showError(message);
@@ -94,6 +98,10 @@ const EditChannel = () => {
       showInfo('请至少选择一个模型!');
       return;
     }
+    if (inputs.model_mapping !== "" && !verifyJSON(inputs.model_mapping)) {
+      showInfo('模型映射必须是合法的 JSON 格式!');
+      return;
+    }
     let localInputs = inputs;
     if (localInputs.base_url.endsWith('/')) {
       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@@ -246,6 +254,17 @@ const EditChannel = () => {
               handleInputChange(null, { name: 'models', value: [] });
             }}>清除所有模型</Button>
           </div>
+          <Form.Field>
+            <Form.TextArea
+              label='模型映射'
+              placeholder={'为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称'}
+              name='model_mapping'
+              onChange={handleInputChange}
+              value={inputs.model_mapping}
+              style={{ minHeight: 100, fontFamily: 'JetBrains Mono, Consolas' }}
+              autoComplete='new-password'
+            />
+          </Form.Field>
           {
             batch ? <Form.Field>
               <Form.TextArea