Browse Source

feat: support model mapping chain

#1033
joey 10 months ago
parent
commit
65ccfd0848
1 changed files with 33 additions and 4 deletions
  1. 33 4
      relay/helper/model_mapped.go

+ 33 - 4
relay/helper/model_mapped.go

@@ -2,9 +2,11 @@ package helper
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"one-api/relay/common"
+
+	"github.com/gin-gonic/gin"
 )
 
 func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
@@ -16,9 +18,36 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
 		if err != nil {
 			return fmt.Errorf("unmarshal_model_mapping_failed")
 		}
-		if modelMap[info.OriginModelName] != "" {
-			info.UpstreamModelName = modelMap[info.OriginModelName]
-			info.IsModelMapped = true
+
+		// 支持链式模型重定向,最终使用链尾的模型
+		currentModel := info.OriginModelName
+		visitedModels := map[string]bool{
+			currentModel: true,
+		}
+		for {
+			if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" {
+				// 模型重定向循环检测,避免无限循环
+				if visitedModels[mappedModel] {
+					if mappedModel == currentModel {
+						if currentModel == info.OriginModelName {
+							info.IsModelMapped = false
+							return nil
+						} else {
+							info.IsModelMapped = true
+							break
+						}
+					}
+					return errors.New("model_mapping_contains_cycle")
+				}
+				visitedModels[mappedModel] = true
+				currentModel = mappedModel
+				info.IsModelMapped = true
+			} else {
+				break
+			}
+		}
+		if info.IsModelMapped {
+			info.UpstreamModelName = currentModel
 		}
 	}
 	return nil