Parcourir la source

Merge remote-tracking branch 'origin/alpha' into alpha

# Conflicts:
#	relay/channel/openai/adaptor.go
CaIon il y a 6 mois
Parent
commit
9d6d580cbd

+ 49 - 7
controller/model.go

@@ -16,6 +16,7 @@ import (
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
 	"one-api/setting"
+	"time"
 )
 
 // https://platform.openai.com/docs/api-reference/models/list
@@ -102,7 +103,7 @@ func init() {
 	})
 }
 
-func ListModels(c *gin.Context) {
+func ListModels(c *gin.Context, modelType int) {
 	userOpenAiModels := make([]dto.OpenAIModels, 0)
 
 	modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
@@ -171,10 +172,41 @@ func ListModels(c *gin.Context) {
 			}
 		}
 	}
-	c.JSON(200, gin.H{
-		"success": true,
-		"data":    userOpenAiModels,
-	})
+	switch modelType {
+	case constant.ChannelTypeAnthropic:
+		useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			useranthropicModels[i] = dto.AnthropicModel{
+				ID:          model.Id,
+				CreatedAt:   time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: model.Id,
+				Type:        "model",
+			}
+		}
+		c.JSON(200, gin.H{
+			"data":     useranthropicModels,
+			"first_id": useranthropicModels[0].ID,
+			"has_more": false,
+			"last_id":  useranthropicModels[len(useranthropicModels)-1].ID,
+		})
+	case constant.ChannelTypeGemini:
+		userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
+		for i, model := range userOpenAiModels {
+			userGeminiModels[i] = dto.GeminiModel{
+				Name:        model.Id,
+				DisplayName: model.Id,
+			}
+		}
+		c.JSON(200, gin.H{
+			"models":        userGeminiModels,
+			"nextPageToken": nil,
+		})
+	default:
+		c.JSON(200, gin.H{
+			"success": true,
+			"data":    userOpenAiModels,
+		})
+	}
 }
 
 func ChannelListModels(c *gin.Context) {
@@ -198,10 +230,20 @@ func EnabledListModels(c *gin.Context) {
 	})
 }
 
-func RetrieveModel(c *gin.Context) {
+func RetrieveModel(c *gin.Context, modelType int) {
 	modelId := c.Param("model")
 	if aiModel, ok := openAIModelsMap[modelId]; ok {
-		c.JSON(200, aiModel)
+		switch modelType {
+		case constant.ChannelTypeAnthropic:
+			c.JSON(200, dto.AnthropicModel{
+				ID:          aiModel.Id,
+				CreatedAt:   time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
+				DisplayName: aiModel.Id,
+				Type:        "model",
+			})
+		default:
+			c.JSON(200, aiModel)
+		}
 	} else {
 		openAIError := dto.OpenAIError{
 			Message: fmt.Sprintf("The model '%s' does not exist", modelId),

+ 109 - 8
controller/model_meta.go

@@ -3,8 +3,10 @@ package controller
 import (
 	"encoding/json"
 	"strconv"
+	"strings"
 
 	"one-api/common"
+	"one-api/constant"
 	"one-api/model"
 
 	"github.com/gin-gonic/gin"
@@ -162,17 +164,116 @@ func DeleteModelMeta(c *gin.Context) {
 
 // 辅助函数:填充 Endpoints 和 BoundChannels 和 EnableGroups
 func fillModelExtra(m *model.Model) {
-	if m.Endpoints == "" {
-		eps := model.GetModelSupportEndpointTypes(m.ModelName)
+	// 若为精确匹配,保持原有逻辑
+	if m.NameRule == model.NameRuleExact {
+		if m.Endpoints == "" {
+			eps := model.GetModelSupportEndpointTypes(m.ModelName)
+			if b, err := json.Marshal(eps); err == nil {
+				m.Endpoints = string(b)
+			}
+		}
+		if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
+			m.BoundChannels = channels
+		}
+		m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
+		m.QuotaType = model.GetModelQuotaType(m.ModelName)
+		return
+	}
+
+	// 非精确匹配:计算并集
+	pricings := model.GetPricing()
+
+	// 匹配到的模型名称集合
+	matchedNames := make([]string, 0)
+
+	// 端点去重集合
+	endpointSet := make(map[constant.EndpointType]struct{})
+
+	// 已绑定渠道去重集合
+	channelSet := make(map[string]model.BoundChannel)
+	// 分组去重集合
+	groupSet := make(map[string]struct{})
+	// 计费类型(若有任意模型为 1,则返回 1)
+	quotaTypeSet := make(map[int]struct{})
+
+	for _, p := range pricings {
+		var matched bool
+		switch m.NameRule {
+		case model.NameRulePrefix:
+			matched = strings.HasPrefix(p.ModelName, m.ModelName)
+		case model.NameRuleSuffix:
+			matched = strings.HasSuffix(p.ModelName, m.ModelName)
+		case model.NameRuleContains:
+			matched = strings.Contains(p.ModelName, m.ModelName)
+		}
+		if !matched {
+			continue
+		}
+
+		// 记录匹配到的模型名称
+		matchedNames = append(matchedNames, p.ModelName)
+
+		// 收集端点
+		for _, et := range p.SupportedEndpointTypes {
+			endpointSet[et] = struct{}{}
+		}
+
+		// 收集分组
+		for _, g := range p.EnableGroup {
+			groupSet[g] = struct{}{}
+		}
+
+		// 收集计费类型
+		quotaTypeSet[p.QuotaType] = struct{}{}
+	}
+
+	// 序列化端点
+	if len(endpointSet) > 0 && m.Endpoints == "" {
+		eps := make([]constant.EndpointType, 0, len(endpointSet))
+		for et := range endpointSet {
+			eps = append(eps, et)
+		}
 		if b, err := json.Marshal(eps); err == nil {
 			m.Endpoints = string(b)
 		}
 	}
-	if channels, err := model.GetBoundChannels(m.ModelName); err == nil {
-		m.BoundChannels = channels
+
+	// 序列化分组
+	if len(groupSet) > 0 {
+		groups := make([]string, 0, len(groupSet))
+		for g := range groupSet {
+			groups = append(groups, g)
+		}
+		m.EnableGroups = groups
 	}
-	// 填充启用分组
-	m.EnableGroups = model.GetModelEnableGroups(m.ModelName)
-	// 填充计费类型
-	m.QuotaType = model.GetModelQuotaType(m.ModelName)
+
+	// 确定计费类型:仅当所有匹配模型计费类型一致时才返回该类型,否则返回 -1 表示未知/不确定
+	if len(quotaTypeSet) == 1 {
+		for k := range quotaTypeSet {
+			m.QuotaType = k
+		}
+	} else {
+		m.QuotaType = -1
+	}
+
+	// 批量查询并序列化渠道
+	if len(matchedNames) > 0 {
+		if channels, err := model.GetBoundChannelsForModels(matchedNames); err == nil {
+			for _, ch := range channels {
+				key := ch.Name + "_" + strconv.Itoa(ch.Type)
+				channelSet[key] = ch
+			}
+		}
+		if len(channelSet) > 0 {
+			chs := make([]model.BoundChannel, 0, len(channelSet))
+			for _, ch := range channelSet {
+				chs = append(chs, ch)
+			}
+			m.BoundChannels = chs
+		}
+	}
+
+	// 设置匹配信息
+	m.MatchedModels = matchedNames
+	m.MatchedCount = len(matchedNames)
 }

+ 24 - 0
dto/pricing.go

@@ -2,6 +2,7 @@ package dto
 
 import "one-api/constant"
 
+// 这里不好动就不动了,本来想独立出来的(
 type OpenAIModels struct {
 	Id                     string                  `json:"id"`
 	Object                 string                  `json:"object"`
@@ -9,3 +10,26 @@ type OpenAIModels struct {
 	OwnedBy                string                  `json:"owned_by"`
 	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
 }
+
+type AnthropicModel struct {
+	ID          string `json:"id"`
+	CreatedAt   string `json:"created_at"`
+	DisplayName string `json:"display_name"`
+	Type        string `json:"type"`
+}
+
+type GeminiModel struct {
+	Name                       interface{}   `json:"name"`
+	BaseModelId                interface{}   `json:"baseModelId"`
+	Version                    interface{}   `json:"version"`
+	DisplayName                interface{}   `json:"displayName"`
+	Description                interface{}   `json:"description"`
+	InputTokenLimit            interface{}   `json:"inputTokenLimit"`
+	OutputTokenLimit           interface{}   `json:"outputTokenLimit"`
+	SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
+	Thinking                   interface{}   `json:"thinking"`
+	Temperature                interface{}   `json:"temperature"`
+	MaxTemperature             interface{}   `json:"maxTemperature"`
+	TopP                       interface{}   `json:"topP"`
+	TopK                       interface{}   `json:"topK"`
+}

+ 9 - 7
middleware/auth.go

@@ -192,16 +192,18 @@ func TokenAuth() func(c *gin.Context) {
 			}
 			c.Request.Header.Set("Authorization", "Bearer "+key)
 		}
+		anthropicKey := c.Request.Header.Get("x-api-key")
 		// 检查path包含/v1/messages
-		if strings.Contains(c.Request.URL.Path, "/v1/messages") {
-			// 从x-api-key中获取key
-			key := c.Request.Header.Get("x-api-key")
-			if key != "" {
-				c.Request.Header.Set("Authorization", "Bearer "+key)
-			}
+		// 或者是否 x-api-key 不为空且存在anthropic-version
+		// 谁知道有多少不符合规范没写anthropic-version的
+		// 所以就这样随它去吧(
+		if strings.Contains(c.Request.URL.Path, "/v1/messages") || (anthropicKey != "" && c.Request.Header.Get("anthropic-version") != "") {
+			c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
 		}
 		// gemini api 从query中获取key
-		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
+		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
+			strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
+			strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
 			skKey := c.Query("key")
 			if skKey != "" {
 				c.Request.Header.Set("Authorization", "Bearer "+skKey)

+ 23 - 15
model/main.go

@@ -66,18 +66,18 @@ var LOG_DB *gorm.DB
 
 // dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
 func dropIndexIfExists(tableName string, indexName string) {
-    if !common.UsingMySQL {
-        return
-    }
-    var count int64
-    // Check index existence via information_schema
-    err := DB.Raw(
-        "SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
-        tableName, indexName,
-    ).Scan(&count).Error
-    if err == nil && count > 0 {
-        _ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
-    }
+	if !common.UsingMySQL {
+		return
+	}
+	var count int64
+	// Check index existence via information_schema
+	err := DB.Raw(
+		"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
+		tableName, indexName,
+	).Scan(&count).Error
+	if err == nil && count > 0 {
+		_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
+	}
 }
 
 func createRootAccountIfNeed() error {
@@ -252,8 +252,12 @@ func InitLogDB() (err error) {
 
 func migrateDB() error {
 	// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
-	dropIndexIfExists("models", "uk_model_name")
-	dropIndexIfExists("vendors", "uk_vendor_name")
+	// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
+	dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
+	dropIndexIfExists("models", "model_name")    // 旧版列级唯一索引名称
+
+	dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
+	dropIndexIfExists("vendors", "name")           // 旧版列级唯一索引名称
 	if !common.UsingPostgreSQL {
 		return migrateDBFast()
 	}
@@ -284,8 +288,12 @@ func migrateDB() error {
 
 func migrateDBFast() error {
 	// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
+	// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
 	dropIndexIfExists("models", "uk_model_name")
+	dropIndexIfExists("models", "model_name")
+
 	dropIndexIfExists("vendors", "uk_vendor_name")
+	dropIndexIfExists("vendors", "name")
 
 	var wg sync.WaitGroup
 
@@ -305,7 +313,7 @@ func migrateDBFast() error {
 		{&QuotaData{}, "QuotaData"},
 		{&Task{}, "Task"},
 		{&Model{}, "Model"},
-        {&Vendor{}, "Vendor"},
+		{&Vendor{}, "Vendor"},
 		{&PrefillGroup{}, "PrefillGroup"},
 		{&Setup{}, "Setup"},
 		{&TwoFA{}, "TwoFA"},

+ 18 - 0
model/model_meta.go

@@ -51,6 +51,9 @@ type Model struct {
 	EnableGroups  []string       `json:"enable_groups,omitempty" gorm:"-"`
 	QuotaType     int            `json:"quota_type" gorm:"-"`
 	NameRule      int            `json:"name_rule" gorm:"default:0"`
+
+	MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
+	MatchedCount  int      `json:"matched_count,omitempty" gorm:"-"`
 }
 
 // Insert 创建新的模型元数据记录
@@ -136,6 +139,21 @@ func GetBoundChannels(modelName string) ([]BoundChannel, error) {
 	return channels, err
 }
 
+// GetBoundChannelsForModels 批量查询多模型的绑定渠道并去重返回
+func GetBoundChannelsForModels(modelNames []string) ([]BoundChannel, error) {
+	if len(modelNames) == 0 {
+		return make([]BoundChannel, 0), nil
+	}
+	var channels []BoundChannel
+	err := DB.Table("channels").
+		Select("channels.name, channels.type").
+		Joins("join abilities on abilities.channel_id = channels.id").
+		Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
+		Group("channels.id").
+		Scan(&channels).Error
+	return channels, err
+}
+
 // FindModelByNameWithRule 根据模型名称和匹配规则查找模型元数据,优先级:精确 > 前缀 > 后缀 > 包含
 func FindModelByNameWithRule(name string) (*Model, error) {
 	// 1. 精确匹配

+ 38 - 4
router/relay-router.go

@@ -1,11 +1,11 @@
 package router
 
 import (
+	"github.com/gin-gonic/gin"
+	"one-api/constant"
 	"one-api/controller"
 	"one-api/middleware"
 	"one-api/relay"
-
-	"github.com/gin-gonic/gin"
 )
 
 func SetRelayRouter(router *gin.Engine) {
@@ -16,9 +16,43 @@ func SetRelayRouter(router *gin.Engine) {
 	modelsRouter := router.Group("/v1/models")
 	modelsRouter.Use(middleware.TokenAuth())
 	{
-		modelsRouter.GET("", controller.ListModels)
-		modelsRouter.GET("/:model", controller.RetrieveModel)
+		modelsRouter.GET("", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.ListModels(c, constant.ChannelTypeAnthropic)
+			case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配
+				controller.RetrieveModel(c, constant.ChannelTypeGemini)
+			default:
+				controller.ListModels(c, constant.ChannelTypeOpenAI)
+			}
+		})
+
+		modelsRouter.GET("/:model", func(c *gin.Context) {
+			switch {
+			case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+				controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
+			default:
+				controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
+			}
+		})
 	}
+
+	geminiRouter := router.Group("/v1beta/models")
+	geminiRouter.Use(middleware.TokenAuth())
+	{
+		geminiRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeGemini)
+		})
+	}
+
+	geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+	geminiCompatibleRouter.Use(middleware.TokenAuth())
+	{
+		geminiCompatibleRouter.GET("", func(c *gin.Context) {
+			controller.ListModels(c, constant.ChannelTypeOpenAI)
+		})
+	}
+
 	playgroundRouter := router.Group("/pg")
 	playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
 	{

+ 11 - 6
web/src/components/table/model-pricing/modal/components/ModelPricingTable.jsx

@@ -63,7 +63,7 @@ const ModelPricingTable = ({
         key: group,
         group: group,
         ratio: groupRatioValue,
-        billingType: modelData?.quota_type === 0 ? t('按量计费') : t('按次计费'),
+        billingType: modelData?.quota_type === 0 ? t('按量计费') : (modelData?.quota_type === 1 ? t('按次计费') : '-'),
         inputPrice: modelData?.quota_type === 0 ? priceData.inputPrice : '-',
         outputPrice: modelData?.quota_type === 0 ? (priceData.completionPrice || priceData.outputPrice) : '-',
         fixedPrice: modelData?.quota_type === 1 ? priceData.price : '-',
@@ -100,11 +100,16 @@ const ModelPricingTable = ({
     columns.push({
       title: t('计费类型'),
       dataIndex: 'billingType',
-      render: (text) => (
-        <Tag color={text === t('按量计费') ? 'violet' : 'teal'} size="small" shape="circle">
-          {text}
-        </Tag>
-      ),
+      render: (text) => {
+        let color = 'white';
+        if (text === t('按量计费')) color = 'violet';
+        else if (text === t('按次计费')) color = 'teal';
+        return (
+          <Tag color={color} size="small" shape="circle">
+            {text || '-'}
+          </Tag>
+        );
+      },
     });
 
     // 根据计费类型添加价格列

+ 16 - 5
web/src/components/table/model-pricing/view/card/PricingCardView.jsx

@@ -144,13 +144,24 @@ const PricingCardView = ({
   // 渲染标签
   const renderTags = (record) => {
     // 计费类型标签(左边)
-    const billingType = record.quota_type === 1 ? 'teal' : 'violet';
-    const billingText = record.quota_type === 1 ? t('按次计费') : t('按量计费');
-    const billingTag = (
-      <Tag key="billing" shape='circle' color={billingType} size='small'>
-        {billingText}
+    let billingTag = (
+      <Tag key="billing" shape='circle' color='white' size='small'>
+        -
       </Tag>
     );
+    if (record.quota_type === 1) {
+      billingTag = (
+        <Tag key="billing" shape='circle' color='teal' size='small'>
+          {t('按次计费')}
+        </Tag>
+      );
+    } else if (record.quota_type === 0) {
+      billingTag = (
+        <Tag key="billing" shape='circle' color='violet' size='small'>
+          {t('按量计费')}
+        </Tag>
+      );
+    }
 
     // 自定义标签(右边)
     const customTags = [];

+ 24 - 7
web/src/components/table/models/ModelsColumnDefs.js

@@ -18,7 +18,7 @@ For commercial licensing, please contact support@quantumnous.com
 */
 
 import React from 'react';
-import { Button, Space, Tag, Typography, Modal } from '@douyinfe/semi-ui';
+import { Button, Space, Tag, Typography, Modal, Tooltip } from '@douyinfe/semi-ui';
 import {
   timestamp2string,
   getLobeHubIcon,
@@ -137,7 +137,8 @@ const renderQuotaType = (qt, t) => {
       </Tag>
     );
   }
-  return qt ?? '-';
+  // 未知
+  return '-';
 };
 
 // Render bound channels
@@ -207,8 +208,8 @@ const renderOperations = (text, record, setEditingModel, setShowEdit, manageMode
   );
 };
 
-// 名称匹配类型渲染
-const renderNameRule = (rule, t) => {
+// 名称匹配类型渲染(带匹配数量 Tooltip)
+const renderNameRule = (rule, record, t) => {
   const map = {
     0: { color: 'green', label: t('精确') },
     1: { color: 'blue', label: t('前缀') },
@@ -217,11 +218,27 @@ const renderNameRule = (rule, t) => {
   };
   const cfg = map[rule];
   if (!cfg) return '-';
-  return (
+
+  let label = cfg.label;
+  if (rule !== 0 && record.matched_count) {
+    label = `${cfg.label} ${record.matched_count}${t('个模型')}`;
+  }
+
+  const tagElement = (
     <Tag color={cfg.color} size="small" shape='circle'>
-      {cfg.label}
+      {label}
     </Tag>
   );
+
+  if (rule === 0 || !record.matched_models || record.matched_models.length === 0) {
+    return tagElement;
+  }
+
+  return (
+    <Tooltip content={record.matched_models.join(', ')} showArrow>
+      {tagElement}
+    </Tooltip>
+  );
 };
 
 export const getModelsColumns = ({
@@ -252,7 +269,7 @@ export const getModelsColumns = ({
     {
       title: t('匹配类型'),
       dataIndex: 'name_rule',
-      render: (val) => renderNameRule(val, t),
+      render: (val, record) => renderNameRule(val, record, t),
     },
     {
       title: t('描述'),

+ 14 - 4
web/src/helpers/utils.js

@@ -632,12 +632,22 @@ export const calculateModelPrice = ({
     };
   }
 
-  // 按次计费
-  const priceUSD = parseFloat(record.model_price) * usedGroupRatio;
-  const displayVal = displayPrice(priceUSD);
+  if (record.quota_type === 1) {
+    // 按次计费
+    const priceUSD = parseFloat(record.model_price) * usedGroupRatio;
+    const displayVal = displayPrice(priceUSD);
 
+    return {
+      price: displayVal,
+      isPerToken: false,
+      usedGroup,
+      usedGroupRatio,
+    };
+  }
+
+  // 未知计费类型,返回占位信息
   return {
-    price: displayVal,
+    price: '-',
     isPerToken: false,
     usedGroup,
     usedGroupRatio,