Procházet zdrojové kódy

Merge pull request #1564 from feitianbubu/pr/init-default-vendor

feat: init default vendor
IcedTangerine před 6 měsíci
rodič
revize
c05dc07666
2 změnil soubory, kde provedl 94 přidání a 2 odebrání
  1. 5 2
      model/pricing.go
  2. 89 0
      model/pricing_default.go

+ 5 - 2
model/pricing.go

@@ -155,9 +155,12 @@ func updatePricing() {
 		vendorMap[vendors[i].Id] = &vendors[i]
 	}
 
+	// 初始化默认供应商映射
+	initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
+
 	// 构建对前端友好的供应商列表
-	vendorsList = make([]PricingVendor, 0, len(vendors))
-	for _, v := range vendors {
+	vendorsList = make([]PricingVendor, 0, len(vendorMap))
+	for _, v := range vendorMap {
 		vendorsList = append(vendorsList, PricingVendor{
 			ID:          v.Id,
 			Name:        v.Name,

+ 89 - 0
model/pricing_default.go

@@ -0,0 +1,89 @@
+package model
+
+import (
+	"strings"
+)
+
+// 简化的供应商映射规则
+var defaultVendorRules = map[string]string{
+	"gpt":      "OpenAI",
+	"dall-e":   "OpenAI",
+	"whisper":  "OpenAI",
+	"o1":       "OpenAI",
+	"o3":       "OpenAI",
+	"claude":   "Anthropic",
+	"gemini":   "Google",
+	"moonshot": "Moonshot",
+	"kimi":     "Moonshot",
+	"chatglm":  "智谱",
+	"glm-":     "智谱",
+	"qwen":     "阿里巴巴",
+	"deepseek": "DeepSeek",
+	"abab":     "MiniMax",
+	"ernie":    "百度",
+	"spark":    "讯飞",
+	"hunyuan":  "腾讯",
+	"command":  "Cohere",
+	"@cf/":     "Cloudflare",
+	"360":      "360",
+	"yi":       "零一万物",
+	"jina":     "Jina",
+	"mistral":  "Mistral",
+	"grok":     "xAI",
+	"llama":    "Meta",
+	"doubao":   "字节跳动",
+	"kling":    "快手",
+	"jimeng":   "即梦",
+	"vidu":     "Vidu",
+}
+
+// initDefaultVendorMapping 简化的默认供应商映射
+func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) {
+	for _, ability := range enableAbilities {
+		modelName := ability.Model
+		if _, exists := metaMap[modelName]; exists {
+			continue
+		}
+
+		// 匹配供应商
+		vendorID := 0
+		modelLower := strings.ToLower(modelName)
+		for pattern, vendorName := range defaultVendorRules {
+			if strings.Contains(modelLower, pattern) {
+				vendorID = getOrCreateVendor(vendorName, vendorMap)
+				break
+			}
+		}
+
+		// 创建模型元数据
+		metaMap[modelName] = &Model{
+			ModelName: modelName,
+			VendorID:  vendorID,
+			Status:    1,
+			NameRule:  NameRuleExact,
+		}
+	}
+}
+
+// 查找或创建供应商
+func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int {
+	// 查找现有供应商
+	for id, vendor := range vendorMap {
+		if vendor.Name == vendorName {
+			return id
+		}
+	}
+
+	// 创建新供应商
+	newVendor := &Vendor{
+		Name:   vendorName,
+		Status: 1,
+	}
+
+	if err := newVendor.Insert(); err != nil {
+		return 0
+	}
+
+	vendorMap[newVendor.Id] = newVendor
+	return newVendor.Id
+}