Ver Fonte

feat: add OpenRouter pricing support to upstream ratio sync

RedwindA há 3 semanas atrás
pai
commit
b1ef7d1517

+ 131 - 1
controller/ratio_sync.go

@@ -5,8 +5,10 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"net"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -139,9 +141,13 @@ func FetchUpstreamRatios(c *gin.Context) {
 			sem <- struct{}{}
 			defer func() { <-sem }()
 
+			isOpenRouter := chItem.Endpoint == "openrouter"
+
 			endpoint := chItem.Endpoint
 			var fullURL string
-			if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
+			if isOpenRouter {
+				fullURL = chItem.BaseURL + "/v1/models"
+			} else if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
 				fullURL = endpoint
 			} else {
 				if endpoint == "" {
@@ -167,6 +173,28 @@ func FetchUpstreamRatios(c *gin.Context) {
 				return
 			}
 
+			// OpenRouter requires Bearer token auth
+			if isOpenRouter && chItem.ID != 0 {
+				dbCh, err := model.GetChannelById(chItem.ID, true)
+				if err != nil {
+					ch <- upstreamResult{Name: uniqueName, Err: "failed to get channel key: " + err.Error()}
+					return
+				}
+				key, _, apiErr := dbCh.GetNextEnabledKey()
+				if apiErr != nil {
+					ch <- upstreamResult{Name: uniqueName, Err: "failed to get enabled channel key: " + apiErr.Error()}
+					return
+				}
+				if strings.TrimSpace(key) == "" {
+					ch <- upstreamResult{Name: uniqueName, Err: "no API key configured for this channel"}
+					return
+				}
+				httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
+			} else if isOpenRouter {
+				ch <- upstreamResult{Name: uniqueName, Err: "OpenRouter requires a valid channel with API key"}
+				return
+			}
+
 			// 简单重试:最多 3 次,指数退避
 			var resp *http.Response
 			var lastErr error
@@ -194,6 +222,19 @@ func FetchUpstreamRatios(c *gin.Context) {
 				logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
 			}
 			limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
+
+			// type3: OpenRouter /v1/models -> convert per-token pricing to ratios
+			if isOpenRouter {
+				converted, err := convertOpenRouterToRatioData(limited)
+				if err != nil {
+					logger.LogWarn(c.Request.Context(), "OpenRouter parse failed from "+chItem.Name+": "+err.Error())
+					ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+					return
+				}
+				ch <- upstreamResult{Name: uniqueName, Data: converted}
+				return
+			}
+
 			// 兼容两种上游接口格式:
 			//  type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
 			//  type2: /api/pricing      -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
@@ -508,6 +549,94 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
 	return differences
 }
 
+// convertOpenRouterToRatioData parses OpenRouter's /v1/models response and converts
+// per-token USD pricing into the local ratio format.
+// model_ratio = prompt_price_per_token * 1_000_000 * (USD / 1000)
+//
+//	since 1 ratio unit = $0.002/1K tokens and USD=500, the factor is 500_000
+//
+// completion_ratio = completion_price / prompt_price (output/input multiplier)
+func convertOpenRouterToRatioData(reader io.Reader) (map[string]any, error) {
+	var orResp struct {
+		Data []struct {
+			ID      string `json:"id"`
+			Pricing struct {
+				Prompt         string `json:"prompt"`
+				Completion     string `json:"completion"`
+				InputCacheRead string `json:"input_cache_read"`
+			} `json:"pricing"`
+		} `json:"data"`
+	}
+
+	if err := common.DecodeJson(reader, &orResp); err != nil {
+		return nil, fmt.Errorf("failed to decode OpenRouter response: %w", err)
+	}
+
+	modelRatioMap := make(map[string]any)
+	completionRatioMap := make(map[string]any)
+	cacheRatioMap := make(map[string]any)
+
+	for _, m := range orResp.Data {
+		promptPrice, promptErr := strconv.ParseFloat(m.Pricing.Prompt, 64)
+		completionPrice, compErr := strconv.ParseFloat(m.Pricing.Completion, 64)
+
+		if promptErr != nil && compErr != nil {
+			// Both unparseable — skip this model
+			continue
+		}
+
+		// Treat parse errors as 0
+		if promptErr != nil {
+			promptPrice = 0
+		}
+		if compErr != nil {
+			completionPrice = 0
+		}
+
+		// Negative values are sentinel values (e.g., -1 for dynamic/variable pricing) — skip
+		if promptPrice < 0 || completionPrice < 0 {
+			continue
+		}
+
+		if promptPrice == 0 && completionPrice == 0 {
+			// Free model
+			modelRatioMap[m.ID] = 0.0
+			continue
+		}
+
+		// Normal case: promptPrice > 0
+		ratio := promptPrice * 1000 * ratio_setting.USD
+		ratio = math.Round(ratio*1e6) / 1e6
+		modelRatioMap[m.ID] = ratio
+
+		compRatio := completionPrice / promptPrice
+		compRatio = math.Round(compRatio*1e6) / 1e6
+		completionRatioMap[m.ID] = compRatio
+
+		// Convert input_cache_read to cache_ratio (= cache_read_price / prompt_price)
+		if m.Pricing.InputCacheRead != "" {
+			if cachePrice, err := strconv.ParseFloat(m.Pricing.InputCacheRead, 64); err == nil && cachePrice >= 0 {
+				cacheRatio := cachePrice / promptPrice
+				cacheRatio = math.Round(cacheRatio*1e6) / 1e6
+				cacheRatioMap[m.ID] = cacheRatio
+			}
+		}
+	}
+
+	converted := make(map[string]any)
+	if len(modelRatioMap) > 0 {
+		converted["model_ratio"] = modelRatioMap
+	}
+	if len(completionRatioMap) > 0 {
+		converted["completion_ratio"] = completionRatioMap
+	}
+	if len(cacheRatioMap) > 0 {
+		converted["cache_ratio"] = cacheRatioMap
+	}
+
+	return converted, nil
+}
+
 func GetSyncableChannels(c *gin.Context) {
 	channels, err := model.GetAllChannels(0, 0, true, false)
 	if err != nil {
@@ -526,6 +655,7 @@ func GetSyncableChannels(c *gin.Context) {
 				Name:    channel.Name,
 				BaseURL: channel.GetBaseURL(),
 				Status:  channel.Status,
+				Type:    channel.Type,
 			})
 		}
 	}

+ 1 - 0
dto/ratio_sync.go

@@ -35,4 +35,5 @@ type SyncableChannel struct {
 	Name    string `json:"name"`
 	BaseURL string `json:"base_url"`
 	Status  int    `json:"status"`
+	Type    int    `json:"type"`
 }

+ 4 - 0
web/src/components/settings/ChannelSelectorModal.jsx

@@ -117,6 +117,7 @@ const ChannelSelectorModal = forwardRef(
       const getEndpointType = (ep) => {
         if (ep === '/api/ratio_config') return 'ratio_config';
         if (ep === '/api/pricing') return 'pricing';
+        if (ep === 'openrouter') return 'openrouter';
         return 'custom';
       };
 
@@ -127,6 +128,8 @@ const ChannelSelectorModal = forwardRef(
           updateEndpoint(channelId, '/api/ratio_config');
         } else if (val === 'pricing') {
           updateEndpoint(channelId, '/api/pricing');
+        } else if (val === 'openrouter') {
+          updateEndpoint(channelId, 'openrouter');
         } else {
           if (currentType !== 'custom') {
             updateEndpoint(channelId, '');
@@ -144,6 +147,7 @@ const ChannelSelectorModal = forwardRef(
             optionList={[
               { label: 'ratio_config', value: 'ratio_config' },
               { label: 'pricing', value: 'pricing' },
+              { label: 'OpenRouter', value: 'openrouter' },
               { label: 'custom', value: 'custom' },
             ]}
           />

+ 11 - 5
web/src/pages/Setting/Ratio/UpstreamRatioSync.jsx

@@ -154,14 +154,20 @@ export default function UpstreamRatioSync(props) {
             const id = channel.key;
             const base = channel._originalData?.base_url || '';
             const name = channel.label || '';
+            const channelType = channel._originalData?.type;
             const isOfficial =
               id === -100 ||
               base === 'https://basellm.github.io' ||
               name === '官方倍率预设';
+            const isOpenRouter = channelType === 20;
             if (!merged[id]) {
-              merged[id] = isOfficial
-                ? '/llm-metadata/api/newapi/ratio_config-v1-base.json'
-                : DEFAULT_ENDPOINT;
+              if (isOfficial) {
+                merged[id] = '/llm-metadata/api/newapi/ratio_config-v1-base.json';
+              } else if (isOpenRouter) {
+                merged[id] = 'openrouter';
+              } else {
+                merged[id] = DEFAULT_ENDPOINT;
+              }
             }
           });
           return merged;
@@ -652,7 +658,7 @@ export default function UpstreamRatioSync(props) {
             color={text !== null && text !== undefined ? 'blue' : 'default'}
             shape='circle'
           >
-            {text !== null && text !== undefined ? text : t('未设置')}
+            {text !== null && text !== undefined ? String(text) : t('未设置')}
           </Tag>
         ),
       },
@@ -774,7 +780,7 @@ export default function UpstreamRatioSync(props) {
                     }
                   }}
                 >
-                  {upstreamVal}
+                  {String(upstreamVal)}
                 </Checkbox>
                 {!isConfident && (
                   <Tooltip