Bläddra i källkod

Merge pull request #2973 from RedwindA/feat/modelsdotdev

feat(ratio-sync): support models.dev ratio sync and fix Gemini cache ratios
Seefs 2 veckor sedan
förälder
incheckning
31deb0daac

+ 254 - 15
controller/ratio_sync.go

@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -8,6 +9,8 @@ import (
 	"math"
 	"net"
 	"net/http"
+	"net/url"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -24,11 +27,20 @@ import (
 )
 
 const (
-	defaultTimeoutSeconds = 10
-	defaultEndpoint       = "/api/ratio_config"
-	maxConcurrentFetches  = 8
-	maxRatioConfigBytes   = 10 << 20 // 10MB
-	floatEpsilon          = 1e-9
+	defaultTimeoutSeconds       = 10
+	defaultEndpoint             = "/api/ratio_config"
+	maxConcurrentFetches        = 8
+	maxRatioConfigBytes         = 10 << 20 // 10MB
+	floatEpsilon                = 1e-9
+	officialRatioPresetID       = -100
+	officialRatioPresetName     = "官方倍率预设"
+	officialRatioPresetBaseURL  = "https://basellm.github.io"
+	modelsDevPresetID           = -101
+	modelsDevPresetName         = "models.dev 价格预设"
+	modelsDevPresetBaseURL      = "https://models.dev"
+	modelsDevHost               = "models.dev"
+	modelsDevPath               = "/api.json"
+	modelsDevInputCostRatioBase = 1000.0
 )
 
 func nearlyEqual(a, b float64) bool {
@@ -157,6 +169,7 @@ func FetchUpstreamRatios(c *gin.Context) {
 				}
 				fullURL = chItem.BaseURL + endpoint
 			}
+			isModelsDev := isModelsDevAPIEndpoint(fullURL)
 
 			uniqueName := chItem.Name
 			if chItem.ID != 0 {
@@ -222,10 +235,16 @@ func FetchUpstreamRatios(c *gin.Context) {
 				logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
 			}
 			limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
+			bodyBytes, err := io.ReadAll(limited)
+			if err != nil {
+				logger.LogWarn(c.Request.Context(), "read response failed from "+chItem.Name+": "+err.Error())
+				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+				return
+			}
 
 			// type3: OpenRouter /v1/models -> convert per-token pricing to ratios
 			if isOpenRouter {
-				converted, err := convertOpenRouterToRatioData(limited)
+				converted, err := convertOpenRouterToRatioData(bytes.NewReader(bodyBytes))
 				if err != nil {
 					logger.LogWarn(c.Request.Context(), "OpenRouter parse failed from "+chItem.Name+": "+err.Error())
 					ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
@@ -235,6 +254,18 @@ func FetchUpstreamRatios(c *gin.Context) {
 				return
 			}
 
+			// type4: models.dev /api.json -> convert provider model pricing to ratios
+			if isModelsDev {
+				converted, err := convertModelsDevToRatioData(bytes.NewReader(bodyBytes))
+				if err != nil {
+					logger.LogWarn(c.Request.Context(), "models.dev parse failed from "+chItem.Name+": "+err.Error())
+					ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+					return
+				}
+				ch <- upstreamResult{Name: uniqueName, Data: converted}
+				return
+			}
+
 			// 兼容两种上游接口格式:
 			//  type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
 			//  type2: /api/pricing      -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
@@ -244,7 +275,7 @@ func FetchUpstreamRatios(c *gin.Context) {
 				Message string          `json:"message"`
 			}
 
-			if err := json.NewDecoder(limited).Decode(&body); err != nil {
+			if err := common.DecodeJson(bytes.NewReader(bodyBytes), &body); err != nil {
 				logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
 				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
 				return
@@ -259,7 +290,7 @@ func FetchUpstreamRatios(c *gin.Context) {
 
 			// 尝试按 type1 解析
 			var type1Data map[string]any
-			if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+			if err := common.Unmarshal(body.Data, &type1Data); err == nil {
 				// 如果包含至少一个 ratioTypes 字段,则认为是 type1
 				isType1 := false
 				for _, rt := range ratioTypes {
@@ -282,7 +313,7 @@ func FetchUpstreamRatios(c *gin.Context) {
 				ModelPrice      float64 `json:"model_price"`
 				CompletionRatio float64 `json:"completion_ratio"`
 			}
-			if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+			if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
 				logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
 				ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
 				return
@@ -549,6 +580,25 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
 	return differences
 }
 
+func roundRatioValue(value float64) float64 {
+	return math.Round(value*1e6) / 1e6
+}
+
+func isModelsDevAPIEndpoint(rawURL string) bool {
+	parsedURL, err := url.Parse(rawURL)
+	if err != nil {
+		return false
+	}
+	if strings.ToLower(parsedURL.Hostname()) != modelsDevHost {
+		return false
+	}
+	path := strings.TrimSuffix(parsedURL.Path, "/")
+	if path == "" {
+		path = "/"
+	}
+	return path == modelsDevPath
+}
+
 // convertOpenRouterToRatioData parses OpenRouter's /v1/models response and converts
 // per-token USD pricing into the local ratio format.
 // model_ratio = prompt_price_per_token * 1_000_000 * (USD / 1000)
@@ -603,21 +653,25 @@ func convertOpenRouterToRatioData(reader io.Reader) (map[string]any, error) {
 			modelRatioMap[m.ID] = 0.0
 			continue
 		}
+		if promptPrice <= 0 {
+			// No meaningful prompt baseline, cannot derive ratios safely.
+			continue
+		}
 
 		// Normal case: promptPrice > 0
 		ratio := promptPrice * 1000 * ratio_setting.USD
-		ratio = math.Round(ratio*1e6) / 1e6
+		ratio = roundRatioValue(ratio)
 		modelRatioMap[m.ID] = ratio
 
 		compRatio := completionPrice / promptPrice
-		compRatio = math.Round(compRatio*1e6) / 1e6
+		compRatio = roundRatioValue(compRatio)
 		completionRatioMap[m.ID] = compRatio
 
 		// Convert input_cache_read to cache_ratio (= cache_read_price / prompt_price)
 		if m.Pricing.InputCacheRead != "" {
 			if cachePrice, err := strconv.ParseFloat(m.Pricing.InputCacheRead, 64); err == nil && cachePrice >= 0 {
 				cacheRatio := cachePrice / promptPrice
-				cacheRatio = math.Round(cacheRatio*1e6) / 1e6
+				cacheRatio = roundRatioValue(cacheRatio)
 				cacheRatioMap[m.ID] = cacheRatio
 			}
 		}
@@ -637,6 +691,184 @@ func convertOpenRouterToRatioData(reader io.Reader) (map[string]any, error) {
 	return converted, nil
 }
 
+type modelsDevProvider struct {
+	Models map[string]modelsDevModel `json:"models"`
+}
+
+type modelsDevModel struct {
+	Cost modelsDevCost `json:"cost"`
+}
+
+type modelsDevCost struct {
+	Input     *float64 `json:"input"`
+	Output    *float64 `json:"output"`
+	CacheRead *float64 `json:"cache_read"`
+}
+
+type modelsDevCandidate struct {
+	Provider  string
+	Input     float64
+	Output    *float64
+	CacheRead *float64
+}
+
+func cloneFloatPtr(v *float64) *float64 {
+	if v == nil {
+		return nil
+	}
+	out := *v
+	return &out
+}
+
+func isValidNonNegativeCost(v float64) bool {
+	if math.IsNaN(v) || math.IsInf(v, 0) {
+		return false
+	}
+	return v >= 0
+}
+
+func buildModelsDevCandidate(provider string, cost modelsDevCost) (modelsDevCandidate, bool) {
+	if cost.Input == nil {
+		return modelsDevCandidate{}, false
+	}
+
+	input := *cost.Input
+	if !isValidNonNegativeCost(input) {
+		return modelsDevCandidate{}, false
+	}
+
+	var output *float64
+	if cost.Output != nil {
+		if !isValidNonNegativeCost(*cost.Output) {
+			return modelsDevCandidate{}, false
+		}
+		output = cloneFloatPtr(cost.Output)
+	}
+
+	// input=0/output>0 cannot be transformed into local ratio.
+	if input == 0 && output != nil && *output > 0 {
+		return modelsDevCandidate{}, false
+	}
+
+	var cacheRead *float64
+	if cost.CacheRead != nil && isValidNonNegativeCost(*cost.CacheRead) {
+		cacheRead = cloneFloatPtr(cost.CacheRead)
+	}
+
+	return modelsDevCandidate{
+		Provider:  provider,
+		Input:     input,
+		Output:    output,
+		CacheRead: cacheRead,
+	}, true
+}
+
+func shouldReplaceModelsDevCandidate(current, next modelsDevCandidate) bool {
+	currentNonZero := current.Input > 0
+	nextNonZero := next.Input > 0
+	if currentNonZero != nextNonZero {
+		// Prefer non-zero pricing data; this matches "cheapest non-zero" conflict policy.
+		return nextNonZero
+	}
+	if nextNonZero && !nearlyEqual(next.Input, current.Input) {
+		return next.Input < current.Input
+	}
+	// Stable tie-breaker for deterministic result.
+	return next.Provider < current.Provider
+}
+
+// convertModelsDevToRatioData parses models.dev /api.json and converts
+// provider pricing metadata into local ratio format.
+// models.dev costs are USD per 1M tokens:
+//
+//	model_ratio = input_cost_per_1M / 2
+//	completion_ratio = output_cost / input_cost
+//	cache_ratio = cache_read_cost / input_cost
+//
+// Duplicate model keys across providers are resolved by selecting the
+// cheapest non-zero input cost. If only zero-priced candidates exist,
+// a zero ratio is kept.
+func convertModelsDevToRatioData(reader io.Reader) (map[string]any, error) {
+	var upstreamData map[string]modelsDevProvider
+	if err := common.DecodeJson(reader, &upstreamData); err != nil {
+		return nil, fmt.Errorf("failed to decode models.dev response: %w", err)
+	}
+	if len(upstreamData) == 0 {
+		return nil, fmt.Errorf("empty models.dev response")
+	}
+
+	providers := make([]string, 0, len(upstreamData))
+	for provider := range upstreamData {
+		providers = append(providers, provider)
+	}
+	sort.Strings(providers)
+
+	selectedCandidates := make(map[string]modelsDevCandidate)
+	for _, provider := range providers {
+		providerData := upstreamData[provider]
+		if len(providerData.Models) == 0 {
+			continue
+		}
+
+		modelNames := make([]string, 0, len(providerData.Models))
+		for modelName := range providerData.Models {
+			modelNames = append(modelNames, modelName)
+		}
+		sort.Strings(modelNames)
+
+		for _, modelName := range modelNames {
+			candidate, ok := buildModelsDevCandidate(provider, providerData.Models[modelName].Cost)
+			if !ok {
+				continue
+			}
+			current, exists := selectedCandidates[modelName]
+			if !exists || shouldReplaceModelsDevCandidate(current, candidate) {
+				selectedCandidates[modelName] = candidate
+			}
+		}
+	}
+
+	if len(selectedCandidates) == 0 {
+		return nil, fmt.Errorf("no valid models.dev pricing entries found")
+	}
+
+	modelRatioMap := make(map[string]any)
+	completionRatioMap := make(map[string]any)
+	cacheRatioMap := make(map[string]any)
+
+	for modelName, candidate := range selectedCandidates {
+		if candidate.Input == 0 {
+			modelRatioMap[modelName] = 0.0
+			continue
+		}
+
+		modelRatio := candidate.Input * float64(ratio_setting.USD) / modelsDevInputCostRatioBase
+		modelRatioMap[modelName] = roundRatioValue(modelRatio)
+
+		if candidate.Output != nil {
+			completionRatio := *candidate.Output / candidate.Input
+			completionRatioMap[modelName] = roundRatioValue(completionRatio)
+		}
+
+		if candidate.CacheRead != nil {
+			cacheRatio := *candidate.CacheRead / candidate.Input
+			cacheRatioMap[modelName] = roundRatioValue(cacheRatio)
+		}
+	}
+
+	converted := make(map[string]any)
+	if len(modelRatioMap) > 0 {
+		converted["model_ratio"] = modelRatioMap
+	}
+	if len(completionRatioMap) > 0 {
+		converted["completion_ratio"] = completionRatioMap
+	}
+	if len(cacheRatioMap) > 0 {
+		converted["cache_ratio"] = cacheRatioMap
+	}
+	return converted, nil
+}
+
 func GetSyncableChannels(c *gin.Context) {
 	channels, err := model.GetAllChannels(0, 0, true, false)
 	if err != nil {
@@ -661,9 +893,16 @@ func GetSyncableChannels(c *gin.Context) {
 	}
 
 	syncableChannels = append(syncableChannels, dto.SyncableChannel{
-		ID:      -100,
-		Name:    "官方倍率预设",
-		BaseURL: "https://basellm.github.io",
+		ID:      officialRatioPresetID,
+		Name:    officialRatioPresetName,
+		BaseURL: officialRatioPresetBaseURL,
+		Status:  1,
+	})
+
+	syncableChannels = append(syncableChannels, dto.SyncableChannel{
+		ID:      modelsDevPresetID,
+		Name:    modelsDevPresetName,
+		BaseURL: modelsDevPresetBaseURL,
 		Status:  1,
 	})
 

+ 3 - 2
setting/ratio_setting/cache_ratio.go

@@ -5,8 +5,9 @@ import (
 )
 
 var defaultCacheRatio = map[string]float64{
-	"gemini-3-flash-preview":              0.25,
-	"gemini-3-pro-preview":                0.25,
+	"gemini-3-flash-preview":              0.1,
+	"gemini-3-pro-preview":                0.1,
+	"gemini-3.1-pro-preview":              0.1,
 	"gpt-4":                               0.5,
 	"o1":                                  0.5,
 	"o1-2024-12-17":                       0.5,

+ 13 - 3
web/src/components/settings/ChannelSelectorModal.jsx

@@ -35,6 +35,13 @@ import {
 } from '@douyinfe/semi-ui';
 import { IconSearch } from '@douyinfe/semi-icons';
 
+const OFFICIAL_RATIO_PRESET_ID = -100;
+const MODELS_DEV_PRESET_ID = -101;
+const OFFICIAL_RATIO_PRESET_NAME = '官方倍率预设';
+const MODELS_DEV_PRESET_NAME = 'models.dev 价格预设';
+const OFFICIAL_RATIO_PRESET_BASE_URL = 'https://basellm.github.io';
+const MODELS_DEV_PRESET_BASE_URL = 'https://models.dev';
+
 const ChannelSelectorModal = forwardRef(
   (
     {
@@ -70,9 +77,12 @@ const ChannelSelectorModal = forwardRef(
       const base = record?._originalData?.base_url || '';
       const name = record?.label || '';
       return (
-        id === -100 ||
-        base === 'https://basellm.github.io' ||
-        name === '官方倍率预设'
+        id === OFFICIAL_RATIO_PRESET_ID ||
+        id === MODELS_DEV_PRESET_ID ||
+        base === OFFICIAL_RATIO_PRESET_BASE_URL ||
+        base === MODELS_DEV_PRESET_BASE_URL ||
+        name === OFFICIAL_RATIO_PRESET_NAME ||
+        name === MODELS_DEV_PRESET_NAME
       );
     };
 

+ 22 - 6
web/src/pages/Setting/Ratio/UpstreamRatioSync.jsx

@@ -53,6 +53,16 @@ import {
 } from '@douyinfe/semi-illustrations';
 import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal';
 
+const OFFICIAL_RATIO_PRESET_ID = -100;
+const OFFICIAL_RATIO_PRESET_NAME = '官方倍率预设';
+const OFFICIAL_RATIO_PRESET_BASE_URL = 'https://basellm.github.io';
+const OFFICIAL_RATIO_PRESET_ENDPOINT =
+  '/llm-metadata/api/newapi/ratio_config-v1-base.json';
+const MODELS_DEV_PRESET_ID = -101;
+const MODELS_DEV_PRESET_NAME = 'models.dev 价格预设';
+const MODELS_DEV_PRESET_BASE_URL = 'https://models.dev';
+const MODELS_DEV_PRESET_ENDPOINT = 'https://models.dev/api.json';
+
 function ConflictConfirmModal({ t, visible, items, onOk, onCancel }) {
   const isMobile = useIsMobile();
   const columns = [
@@ -155,14 +165,20 @@ export default function UpstreamRatioSync(props) {
             const base = channel._originalData?.base_url || '';
             const name = channel.label || '';
             const channelType = channel._originalData?.type;
-            const isOfficial =
-              id === -100 ||
-              base === 'https://basellm.github.io' ||
-              name === '官方倍率预设';
+            const isOfficialRatioPreset =
+              id === OFFICIAL_RATIO_PRESET_ID ||
+              base === OFFICIAL_RATIO_PRESET_BASE_URL ||
+              name === OFFICIAL_RATIO_PRESET_NAME;
+            const isModelsDevPreset =
+              id === MODELS_DEV_PRESET_ID ||
+              base === MODELS_DEV_PRESET_BASE_URL ||
+              name === MODELS_DEV_PRESET_NAME;
             const isOpenRouter = channelType === 20;
             if (!merged[id]) {
-              if (isOfficial) {
-                merged[id] = '/llm-metadata/api/newapi/ratio_config-v1-base.json';
+              if (isModelsDevPreset) {
+                merged[id] = MODELS_DEV_PRESET_ENDPOINT;
+              } else if (isOfficialRatioPreset) {
+                merged[id] = OFFICIAL_RATIO_PRESET_ENDPOINT;
               } else if (isOpenRouter) {
                 merged[id] = 'openrouter';
               } else {