ratio_sync.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "time"
  17. "github.com/QuantumNous/new-api/common"
  18. "github.com/QuantumNous/new-api/logger"
  19. "github.com/QuantumNous/new-api/dto"
  20. "github.com/QuantumNous/new-api/model"
  21. "github.com/QuantumNous/new-api/setting/billing_setting"
  22. "github.com/QuantumNous/new-api/setting/ratio_setting"
  23. "github.com/samber/lo"
  24. "github.com/gin-gonic/gin"
  25. )
const (
	// defaultTimeoutSeconds is the per-upstream timeout (in seconds) used by
	// FetchUpstreamRatios when the request does not carry a positive timeout.
	defaultTimeoutSeconds = 10
	// defaultEndpoint is the path requested on an upstream that has no
	// endpoint configured.
	defaultEndpoint = "/api/pricing"
	// maxConcurrentFetches is the capacity of the semaphore that bounds how
	// many upstream fetches run at the same time.
	maxConcurrentFetches = 8
	// maxRatioConfigBytes is the limit applied when reading an upstream
	// response body.
	maxRatioConfigBytes = 10 << 20 // 10MB
	// floatEpsilon is the tolerance used by nearlyEqual.
	floatEpsilon = 1e-9

	// Synthetic "official ratio preset" entry appended by GetSyncableChannels.
	officialRatioPresetID      = -100
	officialRatioPresetName    = "官方倍率预设"
	officialRatioPresetBaseURL = "https://basellm.github.io"

	// Synthetic "models.dev price preset" entry appended by GetSyncableChannels.
	modelsDevPresetID      = -101
	modelsDevPresetName    = "models.dev 价格预设"
	modelsDevPresetBaseURL = "https://models.dev"

	// modelsDevHost and modelsDevPath identify the models.dev API endpoint;
	// isModelsDevAPIEndpoint compares a URL's host and path against them.
	modelsDevHost = "models.dev"
	modelsDevPath = "/api.json"
	// modelsDevInputCostRatioBase is the divisor used by
	// convertModelsDevToRatioData when turning an input cost into model_ratio.
	modelsDevInputCostRatioBase = 1000.0
)
  42. func nearlyEqual(a, b float64) bool {
  43. if a > b {
  44. return a-b < floatEpsilon
  45. }
  46. return b-a < floatEpsilon
  47. }
  48. func valuesEqual(a, b interface{}) bool {
  49. af, aok := a.(float64)
  50. bf, bok := b.(float64)
  51. if aok && bok {
  52. return nearlyEqual(af, bf)
  53. }
  54. return a == b
  55. }
// pricingSyncFields lists the per-model pricing fields that are compared
// between local settings and upstream data. FetchUpstreamRatios also uses it
// to recognise a "type1" upstream payload (one that contains at least one of
// these keys).
var pricingSyncFields = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"create_cache_ratio",
	"image_ratio",
	"audio_ratio",
	"audio_completion_ratio",
	"model_price",
	billing_setting.BillingModeField,
	billing_setting.BillingExprField,
}
// numericPricingSyncFields marks the sync fields whose values
// normalizeSyncValue converts to float64 before comparison.
var numericPricingSyncFields = map[string]bool{
	"model_ratio":            true,
	"completion_ratio":       true,
	"cache_ratio":            true,
	"create_cache_ratio":     true,
	"image_ratio":            true,
	"audio_ratio":            true,
	"audio_completion_ratio": true,
	"model_price":            true,
}
// upstreamResult is the outcome of fetching one upstream. A non-empty Err
// marks the fetch as failed; otherwise Data holds the upstream's pricing
// fields keyed by field name.
type upstreamResult struct {
	Name string         `json:"name"`
	Data map[string]any `json:"data,omitempty"`
	Err  string         `json:"err,omitempty"`
}
  83. func valueMap(value any) map[string]any {
  84. switch typed := value.(type) {
  85. case map[string]any:
  86. return typed
  87. case map[string]float64:
  88. return lo.MapValues(typed, func(value float64, _ string) any { return value })
  89. case map[string]string:
  90. return lo.MapValues(typed, func(value string, _ string) any { return value })
  91. default:
  92. return nil
  93. }
  94. }
  95. func asFloat64(value any) (float64, bool) {
  96. switch typed := value.(type) {
  97. case float64:
  98. return typed, true
  99. case float32:
  100. return float64(typed), true
  101. case int:
  102. return float64(typed), true
  103. case int64:
  104. return float64(typed), true
  105. case json.Number:
  106. parsed, err := typed.Float64()
  107. return parsed, err == nil
  108. default:
  109. return 0, false
  110. }
  111. }
  112. func normalizeSyncValue(field string, value any) any {
  113. if numericPricingSyncFields[field] {
  114. if parsed, ok := asFloat64(value); ok {
  115. return parsed
  116. }
  117. }
  118. return value
  119. }
  120. func getLocalPricingSyncData() map[string]any {
  121. data := billing_setting.GetPricingSyncData(map[string]any(ratio_setting.GetExposedData()))
  122. data["image_ratio"] = ratio_setting.GetImageRatioCopy()
  123. data["audio_ratio"] = ratio_setting.GetAudioRatioCopy()
  124. data["audio_completion_ratio"] = ratio_setting.GetAudioCompletionRatioCopy()
  125. return data
  126. }
  127. func FetchUpstreamRatios(c *gin.Context) {
  128. var req dto.UpstreamRequest
  129. if err := c.ShouldBindJSON(&req); err != nil {
  130. common.SysError("failed to bind upstream request: " + err.Error())
  131. c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求参数格式错误"})
  132. return
  133. }
  134. if req.Timeout <= 0 {
  135. req.Timeout = defaultTimeoutSeconds
  136. }
  137. var upstreams []dto.UpstreamDTO
  138. if len(req.Upstreams) > 0 {
  139. for _, u := range req.Upstreams {
  140. if strings.HasPrefix(u.BaseURL, "http") {
  141. if u.Endpoint == "" {
  142. u.Endpoint = defaultEndpoint
  143. }
  144. u.BaseURL = strings.TrimRight(u.BaseURL, "/")
  145. upstreams = append(upstreams, u)
  146. }
  147. }
  148. } else if len(req.ChannelIDs) > 0 {
  149. intIds := make([]int, 0, len(req.ChannelIDs))
  150. for _, id64 := range req.ChannelIDs {
  151. intIds = append(intIds, int(id64))
  152. }
  153. dbChannels, err := model.GetChannelsByIds(intIds)
  154. if err != nil {
  155. logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
  156. c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
  157. return
  158. }
  159. for _, ch := range dbChannels {
  160. if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
  161. upstreams = append(upstreams, dto.UpstreamDTO{
  162. ID: ch.Id,
  163. Name: ch.Name,
  164. BaseURL: strings.TrimRight(base, "/"),
  165. Endpoint: "",
  166. })
  167. }
  168. }
  169. }
  170. if len(upstreams) == 0 {
  171. c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
  172. return
  173. }
  174. var wg sync.WaitGroup
  175. ch := make(chan upstreamResult, len(upstreams))
  176. sem := make(chan struct{}, maxConcurrentFetches)
  177. dialer := &net.Dialer{Timeout: 10 * time.Second}
  178. transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
  179. if common.TLSInsecureSkipVerify {
  180. transport.TLSClientConfig = common.InsecureTLSConfig
  181. }
  182. transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  183. host, _, err := net.SplitHostPort(addr)
  184. if err != nil {
  185. host = addr
  186. }
  187. // 对 github.io 优先尝试 IPv4,失败则回退 IPv6
  188. if strings.HasSuffix(host, "github.io") {
  189. if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
  190. return conn, nil
  191. }
  192. return dialer.DialContext(ctx, "tcp6", addr)
  193. }
  194. return dialer.DialContext(ctx, network, addr)
  195. }
  196. client := &http.Client{Transport: transport}
  197. for _, chn := range upstreams {
  198. wg.Add(1)
  199. go func(chItem dto.UpstreamDTO) {
  200. defer wg.Done()
  201. sem <- struct{}{}
  202. defer func() { <-sem }()
  203. isOpenRouter := chItem.Endpoint == "openrouter"
  204. endpoint := chItem.Endpoint
  205. var fullURL string
  206. if isOpenRouter {
  207. fullURL = chItem.BaseURL + "/v1/models"
  208. } else if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
  209. fullURL = endpoint
  210. } else {
  211. if endpoint == "" {
  212. endpoint = defaultEndpoint
  213. } else if !strings.HasPrefix(endpoint, "/") {
  214. endpoint = "/" + endpoint
  215. }
  216. fullURL = chItem.BaseURL + endpoint
  217. }
  218. isModelsDev := isModelsDevAPIEndpoint(fullURL)
  219. uniqueName := chItem.Name
  220. if chItem.ID != 0 {
  221. uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
  222. }
  223. ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
  224. defer cancel()
  225. httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
  226. if err != nil {
  227. logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
  228. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  229. return
  230. }
  231. // OpenRouter requires Bearer token auth
  232. if isOpenRouter && chItem.ID != 0 {
  233. dbCh, err := model.GetChannelById(chItem.ID, true)
  234. if err != nil {
  235. ch <- upstreamResult{Name: uniqueName, Err: "failed to get channel key: " + err.Error()}
  236. return
  237. }
  238. key, _, apiErr := dbCh.GetNextEnabledKey()
  239. if apiErr != nil {
  240. ch <- upstreamResult{Name: uniqueName, Err: "failed to get enabled channel key: " + apiErr.Error()}
  241. return
  242. }
  243. if strings.TrimSpace(key) == "" {
  244. ch <- upstreamResult{Name: uniqueName, Err: "no API key configured for this channel"}
  245. return
  246. }
  247. httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
  248. } else if isOpenRouter {
  249. ch <- upstreamResult{Name: uniqueName, Err: "OpenRouter requires a valid channel with API key"}
  250. return
  251. }
  252. // 简单重试:最多 3 次,指数退避
  253. var resp *http.Response
  254. var lastErr error
  255. for attempt := 0; attempt < 3; attempt++ {
  256. resp, lastErr = client.Do(httpReq)
  257. if lastErr == nil {
  258. break
  259. }
  260. time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
  261. }
  262. if lastErr != nil {
  263. logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
  264. ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
  265. return
  266. }
  267. defer resp.Body.Close()
  268. if resp.StatusCode != http.StatusOK {
  269. logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
  270. ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
  271. return
  272. }
  273. // Content-Type 和响应体大小校验
  274. if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
  275. logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
  276. }
  277. limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
  278. bodyBytes, err := io.ReadAll(limited)
  279. if err != nil {
  280. logger.LogWarn(c.Request.Context(), "read response failed from "+chItem.Name+": "+err.Error())
  281. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  282. return
  283. }
  284. // type3: OpenRouter /v1/models -> convert per-token pricing to ratios
  285. if isOpenRouter {
  286. converted, err := convertOpenRouterToRatioData(bytes.NewReader(bodyBytes))
  287. if err != nil {
  288. logger.LogWarn(c.Request.Context(), "OpenRouter parse failed from "+chItem.Name+": "+err.Error())
  289. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  290. return
  291. }
  292. ch <- upstreamResult{Name: uniqueName, Data: converted}
  293. return
  294. }
  295. // type4: models.dev /api.json -> convert provider model pricing to ratios
  296. if isModelsDev {
  297. converted, err := convertModelsDevToRatioData(bytes.NewReader(bodyBytes))
  298. if err != nil {
  299. logger.LogWarn(c.Request.Context(), "models.dev parse failed from "+chItem.Name+": "+err.Error())
  300. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  301. return
  302. }
  303. ch <- upstreamResult{Name: uniqueName, Data: converted}
  304. return
  305. }
  306. // 兼容两种上游接口格式:
  307. // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
  308. // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
  309. var body struct {
  310. Success bool `json:"success"`
  311. Data json.RawMessage `json:"data"`
  312. Message string `json:"message"`
  313. }
  314. if err := common.DecodeJson(bytes.NewReader(bodyBytes), &body); err != nil {
  315. logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
  316. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  317. return
  318. }
  319. if !body.Success {
  320. ch <- upstreamResult{Name: uniqueName, Err: body.Message}
  321. return
  322. }
  323. // 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
  324. // 尝试按 type1 解析
  325. var type1Data map[string]any
  326. if err := common.Unmarshal(body.Data, &type1Data); err == nil {
  327. // 如果包含至少一个 ratioTypes 字段,则认为是 type1
  328. isType1 := false
  329. for _, rt := range pricingSyncFields {
  330. if _, ok := type1Data[rt]; ok {
  331. isType1 = true
  332. break
  333. }
  334. }
  335. if isType1 {
  336. ch <- upstreamResult{Name: uniqueName, Data: type1Data}
  337. return
  338. }
  339. }
  340. // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
  341. var pricingItems []struct {
  342. ModelName string `json:"model_name"`
  343. QuotaType int `json:"quota_type"`
  344. ModelRatio float64 `json:"model_ratio"`
  345. ModelPrice float64 `json:"model_price"`
  346. CompletionRatio float64 `json:"completion_ratio"`
  347. CacheRatio *float64 `json:"cache_ratio"`
  348. CreateCacheRatio *float64 `json:"create_cache_ratio"`
  349. ImageRatio *float64 `json:"image_ratio"`
  350. AudioRatio *float64 `json:"audio_ratio"`
  351. AudioCompletionRatio *float64 `json:"audio_completion_ratio"`
  352. BillingMode string `json:"billing_mode"`
  353. BillingExpr string `json:"billing_expr"`
  354. }
  355. if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
  356. logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
  357. ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
  358. return
  359. }
  360. modelRatioMap := make(map[string]float64)
  361. completionRatioMap := make(map[string]float64)
  362. cacheRatioMap := make(map[string]float64)
  363. createCacheRatioMap := make(map[string]float64)
  364. imageRatioMap := make(map[string]float64)
  365. audioRatioMap := make(map[string]float64)
  366. audioCompletionRatioMap := make(map[string]float64)
  367. modelPriceMap := make(map[string]float64)
  368. billingModeMap := make(map[string]string)
  369. billingExprMap := make(map[string]string)
  370. for _, item := range pricingItems {
  371. if item.ModelName == "" {
  372. continue
  373. }
  374. if item.BillingMode == billing_setting.BillingModeTieredExpr && strings.TrimSpace(item.BillingExpr) != "" {
  375. billingModeMap[item.ModelName] = billing_setting.BillingModeTieredExpr
  376. billingExprMap[item.ModelName] = item.BillingExpr
  377. }
  378. if item.QuotaType == 1 {
  379. modelPriceMap[item.ModelName] = item.ModelPrice
  380. } else {
  381. modelRatioMap[item.ModelName] = item.ModelRatio
  382. // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
  383. completionRatioMap[item.ModelName] = item.CompletionRatio
  384. }
  385. if item.CacheRatio != nil {
  386. cacheRatioMap[item.ModelName] = *item.CacheRatio
  387. }
  388. if item.CreateCacheRatio != nil {
  389. createCacheRatioMap[item.ModelName] = *item.CreateCacheRatio
  390. }
  391. if item.ImageRatio != nil {
  392. imageRatioMap[item.ModelName] = *item.ImageRatio
  393. }
  394. if item.AudioRatio != nil {
  395. audioRatioMap[item.ModelName] = *item.AudioRatio
  396. }
  397. if item.AudioCompletionRatio != nil {
  398. audioCompletionRatioMap[item.ModelName] = *item.AudioCompletionRatio
  399. }
  400. }
  401. converted := make(map[string]any)
  402. if len(modelRatioMap) > 0 {
  403. ratioAny := make(map[string]any, len(modelRatioMap))
  404. for k, v := range modelRatioMap {
  405. ratioAny[k] = v
  406. }
  407. converted["model_ratio"] = ratioAny
  408. }
  409. if len(completionRatioMap) > 0 {
  410. compAny := make(map[string]any, len(completionRatioMap))
  411. for k, v := range completionRatioMap {
  412. compAny[k] = v
  413. }
  414. converted["completion_ratio"] = compAny
  415. }
  416. if len(cacheRatioMap) > 0 {
  417. converted["cache_ratio"] = valueMap(cacheRatioMap)
  418. }
  419. if len(createCacheRatioMap) > 0 {
  420. converted["create_cache_ratio"] = valueMap(createCacheRatioMap)
  421. }
  422. if len(imageRatioMap) > 0 {
  423. converted["image_ratio"] = valueMap(imageRatioMap)
  424. }
  425. if len(audioRatioMap) > 0 {
  426. converted["audio_ratio"] = valueMap(audioRatioMap)
  427. }
  428. if len(audioCompletionRatioMap) > 0 {
  429. converted["audio_completion_ratio"] = valueMap(audioCompletionRatioMap)
  430. }
  431. if len(modelPriceMap) > 0 {
  432. priceAny := make(map[string]any, len(modelPriceMap))
  433. for k, v := range modelPriceMap {
  434. priceAny[k] = v
  435. }
  436. converted["model_price"] = priceAny
  437. }
  438. if len(billingModeMap) > 0 {
  439. converted[billing_setting.BillingModeField] = valueMap(billingModeMap)
  440. }
  441. if len(billingExprMap) > 0 {
  442. converted[billing_setting.BillingExprField] = valueMap(billingExprMap)
  443. }
  444. ch <- upstreamResult{Name: uniqueName, Data: converted}
  445. }(chn)
  446. }
  447. wg.Wait()
  448. close(ch)
  449. localData := getLocalPricingSyncData()
  450. var testResults []dto.TestResult
  451. var successfulChannels []struct {
  452. name string
  453. data map[string]any
  454. }
  455. for r := range ch {
  456. if r.Err != "" {
  457. testResults = append(testResults, dto.TestResult{
  458. Name: r.Name,
  459. Status: "error",
  460. Error: r.Err,
  461. })
  462. } else {
  463. testResults = append(testResults, dto.TestResult{
  464. Name: r.Name,
  465. Status: "success",
  466. })
  467. successfulChannels = append(successfulChannels, struct {
  468. name string
  469. data map[string]any
  470. }{name: r.Name, data: r.Data})
  471. }
  472. }
  473. differences := buildDifferences(localData, successfulChannels)
  474. c.JSON(http.StatusOK, gin.H{
  475. "success": true,
  476. "data": gin.H{
  477. "differences": differences,
  478. "test_results": testResults,
  479. },
  480. })
  481. }
  482. func buildDifferences(localData map[string]any, successfulChannels []struct {
  483. name string
  484. data map[string]any
  485. }) map[string]map[string]dto.DifferenceItem {
  486. differences := make(map[string]map[string]dto.DifferenceItem)
  487. allModels := make(map[string]struct{})
  488. for _, field := range pricingSyncFields {
  489. for modelName := range valueMap(localData[field]) {
  490. allModels[modelName] = struct{}{}
  491. }
  492. }
  493. for _, channel := range successfulChannels {
  494. for _, field := range pricingSyncFields {
  495. for modelName := range valueMap(channel.data[field]) {
  496. allModels[modelName] = struct{}{}
  497. }
  498. }
  499. }
  500. confidenceMap := make(map[string]map[string]bool)
  501. // 预处理阶段:检查pricing接口的可信度
  502. for _, channel := range successfulChannels {
  503. confidenceMap[channel.name] = make(map[string]bool)
  504. modelRatios := valueMap(channel.data["model_ratio"])
  505. completionRatios := valueMap(channel.data["completion_ratio"])
  506. if len(modelRatios) > 0 && len(completionRatios) > 0 {
  507. // 遍历所有模型,检查是否满足不可信条件
  508. for modelName := range allModels {
  509. // 默认为可信
  510. confidenceMap[channel.name][modelName] = true
  511. // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
  512. if modelRatioVal, ok := modelRatios[modelName]; ok {
  513. if completionRatioVal, ok := completionRatios[modelName]; ok {
  514. // 转换为float64进行比较
  515. modelRatioFloat, modelRatioOK := asFloat64(modelRatioVal)
  516. completionRatioFloat, completionRatioOK := asFloat64(completionRatioVal)
  517. if modelRatioOK && completionRatioOK && nearlyEqual(modelRatioFloat, 37.5) && nearlyEqual(completionRatioFloat, 1.0) {
  518. confidenceMap[channel.name][modelName] = false
  519. }
  520. }
  521. }
  522. }
  523. } else {
  524. // 如果不是从pricing接口获取的数据,则全部标记为可信
  525. for modelName := range allModels {
  526. confidenceMap[channel.name][modelName] = true
  527. }
  528. }
  529. }
  530. for modelName := range allModels {
  531. for _, ratioType := range pricingSyncFields {
  532. var localValue interface{} = nil
  533. if val, exists := valueMap(localData[ratioType])[modelName]; exists {
  534. localValue = normalizeSyncValue(ratioType, val)
  535. }
  536. upstreamValues := make(map[string]interface{})
  537. confidenceValues := make(map[string]bool)
  538. hasUpstreamValue := false
  539. hasDifference := false
  540. for _, channel := range successfulChannels {
  541. var upstreamValue interface{} = nil
  542. if val, exists := valueMap(channel.data[ratioType])[modelName]; exists {
  543. upstreamValue = normalizeSyncValue(ratioType, val)
  544. hasUpstreamValue = true
  545. if localValue != nil && !valuesEqual(localValue, upstreamValue) {
  546. hasDifference = true
  547. } else if valuesEqual(localValue, upstreamValue) {
  548. upstreamValue = "same"
  549. }
  550. }
  551. if upstreamValue == nil && localValue == nil {
  552. upstreamValue = "same"
  553. }
  554. if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
  555. hasDifference = true
  556. }
  557. upstreamValues[channel.name] = upstreamValue
  558. confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
  559. }
  560. shouldInclude := false
  561. if localValue != nil {
  562. if hasDifference {
  563. shouldInclude = true
  564. }
  565. } else {
  566. if hasUpstreamValue {
  567. shouldInclude = true
  568. }
  569. }
  570. if shouldInclude {
  571. if differences[modelName] == nil {
  572. differences[modelName] = make(map[string]dto.DifferenceItem)
  573. }
  574. differences[modelName][ratioType] = dto.DifferenceItem{
  575. Current: localValue,
  576. Upstreams: upstreamValues,
  577. Confidence: confidenceValues,
  578. }
  579. }
  580. }
  581. }
  582. channelHasDiff := make(map[string]bool)
  583. for _, ratioMap := range differences {
  584. for _, item := range ratioMap {
  585. for chName, val := range item.Upstreams {
  586. if val != nil && val != "same" {
  587. channelHasDiff[chName] = true
  588. }
  589. }
  590. }
  591. }
  592. for modelName, ratioMap := range differences {
  593. for ratioType, item := range ratioMap {
  594. for chName := range item.Upstreams {
  595. if !channelHasDiff[chName] {
  596. delete(item.Upstreams, chName)
  597. delete(item.Confidence, chName)
  598. }
  599. }
  600. allSame := true
  601. for _, v := range item.Upstreams {
  602. if v != "same" {
  603. allSame = false
  604. break
  605. }
  606. }
  607. if len(item.Upstreams) == 0 || allSame {
  608. delete(ratioMap, ratioType)
  609. } else {
  610. differences[modelName][ratioType] = item
  611. }
  612. }
  613. if len(ratioMap) == 0 {
  614. delete(differences, modelName)
  615. }
  616. }
  617. return differences
  618. }
  619. func roundRatioValue(value float64) float64 {
  620. return math.Round(value*1e6) / 1e6
  621. }
  622. func isModelsDevAPIEndpoint(rawURL string) bool {
  623. parsedURL, err := url.Parse(rawURL)
  624. if err != nil {
  625. return false
  626. }
  627. if strings.ToLower(parsedURL.Hostname()) != modelsDevHost {
  628. return false
  629. }
  630. path := strings.TrimSuffix(parsedURL.Path, "/")
  631. if path == "" {
  632. path = "/"
  633. }
  634. return path == modelsDevPath
  635. }
  636. // convertOpenRouterToRatioData parses OpenRouter's /v1/models response and converts
  637. // per-token USD pricing into the local ratio format.
  638. // model_ratio = prompt_price_per_token * 1_000_000 * (USD / 1000)
  639. //
  640. // since 1 ratio unit = $0.002/1K tokens and USD=500, the factor is 500_000
  641. //
  642. // completion_ratio = completion_price / prompt_price (output/input multiplier)
  643. func convertOpenRouterToRatioData(reader io.Reader) (map[string]any, error) {
  644. var orResp struct {
  645. Data []struct {
  646. ID string `json:"id"`
  647. Pricing struct {
  648. Prompt string `json:"prompt"`
  649. Completion string `json:"completion"`
  650. InputCacheRead string `json:"input_cache_read"`
  651. } `json:"pricing"`
  652. } `json:"data"`
  653. }
  654. if err := common.DecodeJson(reader, &orResp); err != nil {
  655. return nil, fmt.Errorf("failed to decode OpenRouter response: %w", err)
  656. }
  657. modelRatioMap := make(map[string]any)
  658. completionRatioMap := make(map[string]any)
  659. cacheRatioMap := make(map[string]any)
  660. for _, m := range orResp.Data {
  661. promptPrice, promptErr := strconv.ParseFloat(m.Pricing.Prompt, 64)
  662. completionPrice, compErr := strconv.ParseFloat(m.Pricing.Completion, 64)
  663. if promptErr != nil && compErr != nil {
  664. // Both unparseable — skip this model
  665. continue
  666. }
  667. // Treat parse errors as 0
  668. if promptErr != nil {
  669. promptPrice = 0
  670. }
  671. if compErr != nil {
  672. completionPrice = 0
  673. }
  674. // Negative values are sentinel values (e.g., -1 for dynamic/variable pricing) — skip
  675. if promptPrice < 0 || completionPrice < 0 {
  676. continue
  677. }
  678. if promptPrice == 0 && completionPrice == 0 {
  679. // Free model
  680. modelRatioMap[m.ID] = 0.0
  681. continue
  682. }
  683. if promptPrice <= 0 {
  684. // No meaningful prompt baseline, cannot derive ratios safely.
  685. continue
  686. }
  687. // Normal case: promptPrice > 0
  688. ratio := promptPrice * 1000 * ratio_setting.USD
  689. ratio = roundRatioValue(ratio)
  690. modelRatioMap[m.ID] = ratio
  691. compRatio := completionPrice / promptPrice
  692. compRatio = roundRatioValue(compRatio)
  693. completionRatioMap[m.ID] = compRatio
  694. // Convert input_cache_read to cache_ratio (= cache_read_price / prompt_price)
  695. if m.Pricing.InputCacheRead != "" {
  696. if cachePrice, err := strconv.ParseFloat(m.Pricing.InputCacheRead, 64); err == nil && cachePrice >= 0 {
  697. cacheRatio := cachePrice / promptPrice
  698. cacheRatio = roundRatioValue(cacheRatio)
  699. cacheRatioMap[m.ID] = cacheRatio
  700. }
  701. }
  702. }
  703. converted := make(map[string]any)
  704. if len(modelRatioMap) > 0 {
  705. converted["model_ratio"] = modelRatioMap
  706. }
  707. if len(completionRatioMap) > 0 {
  708. converted["completion_ratio"] = completionRatioMap
  709. }
  710. if len(cacheRatioMap) > 0 {
  711. converted["cache_ratio"] = cacheRatioMap
  712. }
  713. return converted, nil
  714. }
// modelsDevProvider is one provider entry of the models.dev payload; only its
// "models" object is decoded, keyed by the name used as the model name in
// convertModelsDevToRatioData.
type modelsDevProvider struct {
	Models map[string]modelsDevModel `json:"models"`
}
// modelsDevModel is one model entry of a models.dev provider; only its "cost"
// object is decoded.
type modelsDevModel struct {
	Cost modelsDevCost `json:"cost"`
}
// modelsDevCost holds the pricing fields read from a models.dev model entry.
// Each field is a pointer so that a missing or null value (nil) can be told
// apart from an explicit zero.
type modelsDevCost struct {
	Input     *float64 `json:"input"`
	Output    *float64 `json:"output"`
	CacheRead *float64 `json:"cache_read"`
}
// modelsDevCandidate is a validated pricing entry for one model from one
// provider, as produced by buildModelsDevCandidate. Output and CacheRead are
// nil when the corresponding cost was not usable.
type modelsDevCandidate struct {
	Provider  string
	Input     float64
	Output    *float64
	CacheRead *float64
}
  732. func cloneFloatPtr(v *float64) *float64 {
  733. if v == nil {
  734. return nil
  735. }
  736. out := *v
  737. return &out
  738. }
  739. func isValidNonNegativeCost(v float64) bool {
  740. if math.IsNaN(v) || math.IsInf(v, 0) {
  741. return false
  742. }
  743. return v >= 0
  744. }
  745. func buildModelsDevCandidate(provider string, cost modelsDevCost) (modelsDevCandidate, bool) {
  746. if cost.Input == nil {
  747. return modelsDevCandidate{}, false
  748. }
  749. input := *cost.Input
  750. if !isValidNonNegativeCost(input) {
  751. return modelsDevCandidate{}, false
  752. }
  753. var output *float64
  754. if cost.Output != nil {
  755. if !isValidNonNegativeCost(*cost.Output) {
  756. return modelsDevCandidate{}, false
  757. }
  758. output = cloneFloatPtr(cost.Output)
  759. }
  760. // input=0/output>0 cannot be transformed into local ratio.
  761. if input == 0 && output != nil && *output > 0 {
  762. return modelsDevCandidate{}, false
  763. }
  764. var cacheRead *float64
  765. if cost.CacheRead != nil && isValidNonNegativeCost(*cost.CacheRead) {
  766. cacheRead = cloneFloatPtr(cost.CacheRead)
  767. }
  768. return modelsDevCandidate{
  769. Provider: provider,
  770. Input: input,
  771. Output: output,
  772. CacheRead: cacheRead,
  773. }, true
  774. }
  775. func shouldReplaceModelsDevCandidate(current, next modelsDevCandidate) bool {
  776. currentNonZero := current.Input > 0
  777. nextNonZero := next.Input > 0
  778. if currentNonZero != nextNonZero {
  779. // Prefer non-zero pricing data; this matches "cheapest non-zero" conflict policy.
  780. return nextNonZero
  781. }
  782. if nextNonZero && !nearlyEqual(next.Input, current.Input) {
  783. return next.Input < current.Input
  784. }
  785. // Stable tie-breaker for deterministic result.
  786. return next.Provider < current.Provider
  787. }
  788. // convertModelsDevToRatioData parses models.dev /api.json and converts
  789. // provider pricing metadata into local ratio format.
  790. // models.dev costs are USD per 1M tokens:
  791. //
  792. // model_ratio = input_cost_per_1M / 2
  793. // completion_ratio = output_cost / input_cost
  794. // cache_ratio = cache_read_cost / input_cost
  795. //
  796. // Duplicate model keys across providers are resolved by selecting the
  797. // cheapest non-zero input cost. If only zero-priced candidates exist,
  798. // a zero ratio is kept.
  799. func convertModelsDevToRatioData(reader io.Reader) (map[string]any, error) {
  800. var upstreamData map[string]modelsDevProvider
  801. if err := common.DecodeJson(reader, &upstreamData); err != nil {
  802. return nil, fmt.Errorf("failed to decode models.dev response: %w", err)
  803. }
  804. if len(upstreamData) == 0 {
  805. return nil, fmt.Errorf("empty models.dev response")
  806. }
  807. providers := make([]string, 0, len(upstreamData))
  808. for provider := range upstreamData {
  809. providers = append(providers, provider)
  810. }
  811. sort.Strings(providers)
  812. selectedCandidates := make(map[string]modelsDevCandidate)
  813. for _, provider := range providers {
  814. providerData := upstreamData[provider]
  815. if len(providerData.Models) == 0 {
  816. continue
  817. }
  818. modelNames := make([]string, 0, len(providerData.Models))
  819. for modelName := range providerData.Models {
  820. modelNames = append(modelNames, modelName)
  821. }
  822. sort.Strings(modelNames)
  823. for _, modelName := range modelNames {
  824. candidate, ok := buildModelsDevCandidate(provider, providerData.Models[modelName].Cost)
  825. if !ok {
  826. continue
  827. }
  828. current, exists := selectedCandidates[modelName]
  829. if !exists || shouldReplaceModelsDevCandidate(current, candidate) {
  830. selectedCandidates[modelName] = candidate
  831. }
  832. }
  833. }
  834. if len(selectedCandidates) == 0 {
  835. return nil, fmt.Errorf("no valid models.dev pricing entries found")
  836. }
  837. modelRatioMap := make(map[string]any)
  838. completionRatioMap := make(map[string]any)
  839. cacheRatioMap := make(map[string]any)
  840. for modelName, candidate := range selectedCandidates {
  841. if candidate.Input == 0 {
  842. modelRatioMap[modelName] = 0.0
  843. continue
  844. }
  845. modelRatio := candidate.Input * float64(ratio_setting.USD) / modelsDevInputCostRatioBase
  846. modelRatioMap[modelName] = roundRatioValue(modelRatio)
  847. if candidate.Output != nil {
  848. completionRatio := *candidate.Output / candidate.Input
  849. completionRatioMap[modelName] = roundRatioValue(completionRatio)
  850. }
  851. if candidate.CacheRead != nil {
  852. cacheRatio := *candidate.CacheRead / candidate.Input
  853. cacheRatioMap[modelName] = roundRatioValue(cacheRatio)
  854. }
  855. }
  856. converted := make(map[string]any)
  857. if len(modelRatioMap) > 0 {
  858. converted["model_ratio"] = modelRatioMap
  859. }
  860. if len(completionRatioMap) > 0 {
  861. converted["completion_ratio"] = completionRatioMap
  862. }
  863. if len(cacheRatioMap) > 0 {
  864. converted["cache_ratio"] = cacheRatioMap
  865. }
  866. return converted, nil
  867. }
  868. func GetSyncableChannels(c *gin.Context) {
  869. channels, err := model.GetAllChannels(0, 0, true, false)
  870. if err != nil {
  871. c.JSON(http.StatusOK, gin.H{
  872. "success": false,
  873. "message": err.Error(),
  874. })
  875. return
  876. }
  877. var syncableChannels []dto.SyncableChannel
  878. for _, channel := range channels {
  879. if channel.GetBaseURL() != "" {
  880. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  881. ID: channel.Id,
  882. Name: channel.Name,
  883. BaseURL: channel.GetBaseURL(),
  884. Status: channel.Status,
  885. Type: channel.Type,
  886. })
  887. }
  888. }
  889. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  890. ID: officialRatioPresetID,
  891. Name: officialRatioPresetName,
  892. BaseURL: officialRatioPresetBaseURL,
  893. Status: 1,
  894. })
  895. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  896. ID: modelsDevPresetID,
  897. Name: modelsDevPresetName,
  898. BaseURL: modelsDevPresetBaseURL,
  899. Status: 1,
  900. })
  901. c.JSON(http.StatusOK, gin.H{
  902. "success": true,
  903. "message": "",
  904. "data": syncableChannels,
  905. })
  906. }