ratio_sync.go 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. package controller
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "math"
  9. "net"
  10. "net/http"
  11. "net/url"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "time"
  17. "github.com/QuantumNous/new-api/common"
  18. "github.com/QuantumNous/new-api/logger"
  19. "github.com/QuantumNous/new-api/dto"
  20. "github.com/QuantumNous/new-api/model"
  21. "github.com/QuantumNous/new-api/setting/ratio_setting"
  22. "github.com/gin-gonic/gin"
  23. )
  const (
  	// defaultTimeoutSeconds is the per-upstream fetch timeout, in seconds, used
  	// when the request does not supply a positive timeout.
  	defaultTimeoutSeconds = 10
  	// defaultEndpoint is the path requested when an upstream has no endpoint.
  	defaultEndpoint = "/api/ratio_config"
  	// maxConcurrentFetches bounds how many upstream requests run at once.
  	maxConcurrentFetches = 8
  	// maxRatioConfigBytes caps how much of an upstream response body is read (10MB).
  	maxRatioConfigBytes = 10 << 20

  	// floatEpsilon is the absolute tolerance used by nearlyEqual.
  	floatEpsilon = 1e-9

  	// Built-in preset sources appended by GetSyncableChannels.
  	officialRatioPresetID      = -100
  	officialRatioPresetName    = "官方倍率预设"
  	officialRatioPresetBaseURL = "https://basellm.github.io"
  	modelsDevPresetID          = -101
  	modelsDevPresetName        = "models.dev 价格预设"
  	modelsDevPresetBaseURL     = "https://models.dev"

  	// modelsDevHost and modelsDevPath identify the models.dev pricing document.
  	modelsDevHost = "models.dev"
  	modelsDevPath = "/api.json"
  	// modelsDevInputCostRatioBase is the divisor in
  	// model_ratio = input_cost * ratio_setting.USD / modelsDevInputCostRatioBase.
  	modelsDevInputCostRatioBase = 1000.0
  )
  40. func nearlyEqual(a, b float64) bool {
  41. if a > b {
  42. return a-b < floatEpsilon
  43. }
  44. return b-a < floatEpsilon
  45. }
  46. func valuesEqual(a, b interface{}) bool {
  47. af, aok := a.(float64)
  48. bf, bok := b.(float64)
  49. if aok && bok {
  50. return nearlyEqual(af, bf)
  51. }
  52. return a == b
  53. }
  // ratioTypes lists the ratio tables that are fetched, compared and
  // synchronized, in the order they are processed.
  var ratioTypes = []string{
  	"model_ratio",
  	"completion_ratio",
  	"cache_ratio",
  	"model_price",
  }
  // upstreamResult is the outcome of fetching a single upstream. Name
  // identifies the upstream. On success Data holds its ratio tables, keyed by
  // ratio type; on failure Err carries the error text.
  type upstreamResult struct {
  	Name string         `json:"name"`
  	Data map[string]any `json:"data,omitempty"`
  	Err  string         `json:"err,omitempty"`
  }
  60. func FetchUpstreamRatios(c *gin.Context) {
  61. var req dto.UpstreamRequest
  62. if err := c.ShouldBindJSON(&req); err != nil {
  63. common.SysError("failed to bind upstream request: " + err.Error())
  64. c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求参数格式错误"})
  65. return
  66. }
  67. if req.Timeout <= 0 {
  68. req.Timeout = defaultTimeoutSeconds
  69. }
  70. var upstreams []dto.UpstreamDTO
  71. if len(req.Upstreams) > 0 {
  72. for _, u := range req.Upstreams {
  73. if strings.HasPrefix(u.BaseURL, "http") {
  74. if u.Endpoint == "" {
  75. u.Endpoint = defaultEndpoint
  76. }
  77. u.BaseURL = strings.TrimRight(u.BaseURL, "/")
  78. upstreams = append(upstreams, u)
  79. }
  80. }
  81. } else if len(req.ChannelIDs) > 0 {
  82. intIds := make([]int, 0, len(req.ChannelIDs))
  83. for _, id64 := range req.ChannelIDs {
  84. intIds = append(intIds, int(id64))
  85. }
  86. dbChannels, err := model.GetChannelsByIds(intIds)
  87. if err != nil {
  88. logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
  89. c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
  90. return
  91. }
  92. for _, ch := range dbChannels {
  93. if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
  94. upstreams = append(upstreams, dto.UpstreamDTO{
  95. ID: ch.Id,
  96. Name: ch.Name,
  97. BaseURL: strings.TrimRight(base, "/"),
  98. Endpoint: "",
  99. })
  100. }
  101. }
  102. }
  103. if len(upstreams) == 0 {
  104. c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
  105. return
  106. }
  107. var wg sync.WaitGroup
  108. ch := make(chan upstreamResult, len(upstreams))
  109. sem := make(chan struct{}, maxConcurrentFetches)
  110. dialer := &net.Dialer{Timeout: 10 * time.Second}
  111. transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
  112. if common.TLSInsecureSkipVerify {
  113. transport.TLSClientConfig = common.InsecureTLSConfig
  114. }
  115. transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
  116. host, _, err := net.SplitHostPort(addr)
  117. if err != nil {
  118. host = addr
  119. }
  120. // 对 github.io 优先尝试 IPv4,失败则回退 IPv6
  121. if strings.HasSuffix(host, "github.io") {
  122. if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
  123. return conn, nil
  124. }
  125. return dialer.DialContext(ctx, "tcp6", addr)
  126. }
  127. return dialer.DialContext(ctx, network, addr)
  128. }
  129. client := &http.Client{Transport: transport}
  130. for _, chn := range upstreams {
  131. wg.Add(1)
  132. go func(chItem dto.UpstreamDTO) {
  133. defer wg.Done()
  134. sem <- struct{}{}
  135. defer func() { <-sem }()
  136. isOpenRouter := chItem.Endpoint == "openrouter"
  137. endpoint := chItem.Endpoint
  138. var fullURL string
  139. if isOpenRouter {
  140. fullURL = chItem.BaseURL + "/v1/models"
  141. } else if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
  142. fullURL = endpoint
  143. } else {
  144. if endpoint == "" {
  145. endpoint = defaultEndpoint
  146. } else if !strings.HasPrefix(endpoint, "/") {
  147. endpoint = "/" + endpoint
  148. }
  149. fullURL = chItem.BaseURL + endpoint
  150. }
  151. isModelsDev := isModelsDevAPIEndpoint(fullURL)
  152. uniqueName := chItem.Name
  153. if chItem.ID != 0 {
  154. uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
  155. }
  156. ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
  157. defer cancel()
  158. httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
  159. if err != nil {
  160. logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
  161. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  162. return
  163. }
  164. // OpenRouter requires Bearer token auth
  165. if isOpenRouter && chItem.ID != 0 {
  166. dbCh, err := model.GetChannelById(chItem.ID, true)
  167. if err != nil {
  168. ch <- upstreamResult{Name: uniqueName, Err: "failed to get channel key: " + err.Error()}
  169. return
  170. }
  171. key, _, apiErr := dbCh.GetNextEnabledKey()
  172. if apiErr != nil {
  173. ch <- upstreamResult{Name: uniqueName, Err: "failed to get enabled channel key: " + apiErr.Error()}
  174. return
  175. }
  176. if strings.TrimSpace(key) == "" {
  177. ch <- upstreamResult{Name: uniqueName, Err: "no API key configured for this channel"}
  178. return
  179. }
  180. httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(key))
  181. } else if isOpenRouter {
  182. ch <- upstreamResult{Name: uniqueName, Err: "OpenRouter requires a valid channel with API key"}
  183. return
  184. }
  185. // 简单重试:最多 3 次,指数退避
  186. var resp *http.Response
  187. var lastErr error
  188. for attempt := 0; attempt < 3; attempt++ {
  189. resp, lastErr = client.Do(httpReq)
  190. if lastErr == nil {
  191. break
  192. }
  193. time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
  194. }
  195. if lastErr != nil {
  196. logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
  197. ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
  198. return
  199. }
  200. defer resp.Body.Close()
  201. if resp.StatusCode != http.StatusOK {
  202. logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
  203. ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
  204. return
  205. }
  206. // Content-Type 和响应体大小校验
  207. if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
  208. logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
  209. }
  210. limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
  211. bodyBytes, err := io.ReadAll(limited)
  212. if err != nil {
  213. logger.LogWarn(c.Request.Context(), "read response failed from "+chItem.Name+": "+err.Error())
  214. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  215. return
  216. }
  217. // type3: OpenRouter /v1/models -> convert per-token pricing to ratios
  218. if isOpenRouter {
  219. converted, err := convertOpenRouterToRatioData(bytes.NewReader(bodyBytes))
  220. if err != nil {
  221. logger.LogWarn(c.Request.Context(), "OpenRouter parse failed from "+chItem.Name+": "+err.Error())
  222. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  223. return
  224. }
  225. ch <- upstreamResult{Name: uniqueName, Data: converted}
  226. return
  227. }
  228. // type4: models.dev /api.json -> convert provider model pricing to ratios
  229. if isModelsDev {
  230. converted, err := convertModelsDevToRatioData(bytes.NewReader(bodyBytes))
  231. if err != nil {
  232. logger.LogWarn(c.Request.Context(), "models.dev parse failed from "+chItem.Name+": "+err.Error())
  233. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  234. return
  235. }
  236. ch <- upstreamResult{Name: uniqueName, Data: converted}
  237. return
  238. }
  239. // 兼容两种上游接口格式:
  240. // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
  241. // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
  242. var body struct {
  243. Success bool `json:"success"`
  244. Data json.RawMessage `json:"data"`
  245. Message string `json:"message"`
  246. }
  247. if err := common.DecodeJson(bytes.NewReader(bodyBytes), &body); err != nil {
  248. logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
  249. ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
  250. return
  251. }
  252. if !body.Success {
  253. ch <- upstreamResult{Name: uniqueName, Err: body.Message}
  254. return
  255. }
  256. // 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
  257. // 尝试按 type1 解析
  258. var type1Data map[string]any
  259. if err := common.Unmarshal(body.Data, &type1Data); err == nil {
  260. // 如果包含至少一个 ratioTypes 字段,则认为是 type1
  261. isType1 := false
  262. for _, rt := range ratioTypes {
  263. if _, ok := type1Data[rt]; ok {
  264. isType1 = true
  265. break
  266. }
  267. }
  268. if isType1 {
  269. ch <- upstreamResult{Name: uniqueName, Data: type1Data}
  270. return
  271. }
  272. }
  273. // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
  274. var pricingItems []struct {
  275. ModelName string `json:"model_name"`
  276. QuotaType int `json:"quota_type"`
  277. ModelRatio float64 `json:"model_ratio"`
  278. ModelPrice float64 `json:"model_price"`
  279. CompletionRatio float64 `json:"completion_ratio"`
  280. }
  281. if err := common.Unmarshal(body.Data, &pricingItems); err != nil {
  282. logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
  283. ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
  284. return
  285. }
  286. modelRatioMap := make(map[string]float64)
  287. completionRatioMap := make(map[string]float64)
  288. modelPriceMap := make(map[string]float64)
  289. for _, item := range pricingItems {
  290. if item.QuotaType == 1 {
  291. modelPriceMap[item.ModelName] = item.ModelPrice
  292. } else {
  293. modelRatioMap[item.ModelName] = item.ModelRatio
  294. // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
  295. completionRatioMap[item.ModelName] = item.CompletionRatio
  296. }
  297. }
  298. converted := make(map[string]any)
  299. if len(modelRatioMap) > 0 {
  300. ratioAny := make(map[string]any, len(modelRatioMap))
  301. for k, v := range modelRatioMap {
  302. ratioAny[k] = v
  303. }
  304. converted["model_ratio"] = ratioAny
  305. }
  306. if len(completionRatioMap) > 0 {
  307. compAny := make(map[string]any, len(completionRatioMap))
  308. for k, v := range completionRatioMap {
  309. compAny[k] = v
  310. }
  311. converted["completion_ratio"] = compAny
  312. }
  313. if len(modelPriceMap) > 0 {
  314. priceAny := make(map[string]any, len(modelPriceMap))
  315. for k, v := range modelPriceMap {
  316. priceAny[k] = v
  317. }
  318. converted["model_price"] = priceAny
  319. }
  320. ch <- upstreamResult{Name: uniqueName, Data: converted}
  321. }(chn)
  322. }
  323. wg.Wait()
  324. close(ch)
  325. localData := ratio_setting.GetExposedData()
  326. var testResults []dto.TestResult
  327. var successfulChannels []struct {
  328. name string
  329. data map[string]any
  330. }
  331. for r := range ch {
  332. if r.Err != "" {
  333. testResults = append(testResults, dto.TestResult{
  334. Name: r.Name,
  335. Status: "error",
  336. Error: r.Err,
  337. })
  338. } else {
  339. testResults = append(testResults, dto.TestResult{
  340. Name: r.Name,
  341. Status: "success",
  342. })
  343. successfulChannels = append(successfulChannels, struct {
  344. name string
  345. data map[string]any
  346. }{name: r.Name, data: r.Data})
  347. }
  348. }
  349. differences := buildDifferences(localData, successfulChannels)
  350. c.JSON(http.StatusOK, gin.H{
  351. "success": true,
  352. "data": gin.H{
  353. "differences": differences,
  354. "test_results": testResults,
  355. },
  356. })
  357. }
  358. func buildDifferences(localData map[string]any, successfulChannels []struct {
  359. name string
  360. data map[string]any
  361. }) map[string]map[string]dto.DifferenceItem {
  362. differences := make(map[string]map[string]dto.DifferenceItem)
  363. allModels := make(map[string]struct{})
  364. for _, ratioType := range ratioTypes {
  365. if localRatioAny, ok := localData[ratioType]; ok {
  366. if localRatio, ok := localRatioAny.(map[string]float64); ok {
  367. for modelName := range localRatio {
  368. allModels[modelName] = struct{}{}
  369. }
  370. }
  371. }
  372. }
  373. for _, channel := range successfulChannels {
  374. for _, ratioType := range ratioTypes {
  375. if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
  376. for modelName := range upstreamRatio {
  377. allModels[modelName] = struct{}{}
  378. }
  379. }
  380. }
  381. }
  382. confidenceMap := make(map[string]map[string]bool)
  383. // 预处理阶段:检查pricing接口的可信度
  384. for _, channel := range successfulChannels {
  385. confidenceMap[channel.name] = make(map[string]bool)
  386. modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
  387. completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
  388. if hasModelRatio && hasCompletionRatio {
  389. // 遍历所有模型,检查是否满足不可信条件
  390. for modelName := range allModels {
  391. // 默认为可信
  392. confidenceMap[channel.name][modelName] = true
  393. // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
  394. if modelRatioVal, ok := modelRatios[modelName]; ok {
  395. if completionRatioVal, ok := completionRatios[modelName]; ok {
  396. // 转换为float64进行比较
  397. if modelRatioFloat, ok := modelRatioVal.(float64); ok {
  398. if completionRatioFloat, ok := completionRatioVal.(float64); ok {
  399. if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
  400. confidenceMap[channel.name][modelName] = false
  401. }
  402. }
  403. }
  404. }
  405. }
  406. }
  407. } else {
  408. // 如果不是从pricing接口获取的数据,则全部标记为可信
  409. for modelName := range allModels {
  410. confidenceMap[channel.name][modelName] = true
  411. }
  412. }
  413. }
  414. for modelName := range allModels {
  415. for _, ratioType := range ratioTypes {
  416. var localValue interface{} = nil
  417. if localRatioAny, ok := localData[ratioType]; ok {
  418. if localRatio, ok := localRatioAny.(map[string]float64); ok {
  419. if val, exists := localRatio[modelName]; exists {
  420. localValue = val
  421. }
  422. }
  423. }
  424. upstreamValues := make(map[string]interface{})
  425. confidenceValues := make(map[string]bool)
  426. hasUpstreamValue := false
  427. hasDifference := false
  428. for _, channel := range successfulChannels {
  429. var upstreamValue interface{} = nil
  430. if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
  431. if val, exists := upstreamRatio[modelName]; exists {
  432. upstreamValue = val
  433. hasUpstreamValue = true
  434. if localValue != nil && !valuesEqual(localValue, val) {
  435. hasDifference = true
  436. } else if valuesEqual(localValue, val) {
  437. upstreamValue = "same"
  438. }
  439. }
  440. }
  441. if upstreamValue == nil && localValue == nil {
  442. upstreamValue = "same"
  443. }
  444. if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
  445. hasDifference = true
  446. }
  447. upstreamValues[channel.name] = upstreamValue
  448. confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
  449. }
  450. shouldInclude := false
  451. if localValue != nil {
  452. if hasDifference {
  453. shouldInclude = true
  454. }
  455. } else {
  456. if hasUpstreamValue {
  457. shouldInclude = true
  458. }
  459. }
  460. if shouldInclude {
  461. if differences[modelName] == nil {
  462. differences[modelName] = make(map[string]dto.DifferenceItem)
  463. }
  464. differences[modelName][ratioType] = dto.DifferenceItem{
  465. Current: localValue,
  466. Upstreams: upstreamValues,
  467. Confidence: confidenceValues,
  468. }
  469. }
  470. }
  471. }
  472. channelHasDiff := make(map[string]bool)
  473. for _, ratioMap := range differences {
  474. for _, item := range ratioMap {
  475. for chName, val := range item.Upstreams {
  476. if val != nil && val != "same" {
  477. channelHasDiff[chName] = true
  478. }
  479. }
  480. }
  481. }
  482. for modelName, ratioMap := range differences {
  483. for ratioType, item := range ratioMap {
  484. for chName := range item.Upstreams {
  485. if !channelHasDiff[chName] {
  486. delete(item.Upstreams, chName)
  487. delete(item.Confidence, chName)
  488. }
  489. }
  490. allSame := true
  491. for _, v := range item.Upstreams {
  492. if v != "same" {
  493. allSame = false
  494. break
  495. }
  496. }
  497. if len(item.Upstreams) == 0 || allSame {
  498. delete(ratioMap, ratioType)
  499. } else {
  500. differences[modelName][ratioType] = item
  501. }
  502. }
  503. if len(ratioMap) == 0 {
  504. delete(differences, modelName)
  505. }
  506. }
  507. return differences
  508. }
  // roundRatioValue rounds value to 6 decimal places, with halves rounded away
  // from zero (math.Round semantics).
  func roundRatioValue(value float64) float64 {
  	const scale = 1e6
  	scaled := value * scale
  	return math.Round(scaled) / scale
  }
  512. func isModelsDevAPIEndpoint(rawURL string) bool {
  513. parsedURL, err := url.Parse(rawURL)
  514. if err != nil {
  515. return false
  516. }
  517. if strings.ToLower(parsedURL.Hostname()) != modelsDevHost {
  518. return false
  519. }
  520. path := strings.TrimSuffix(parsedURL.Path, "/")
  521. if path == "" {
  522. path = "/"
  523. }
  524. return path == modelsDevPath
  525. }
  526. // convertOpenRouterToRatioData parses OpenRouter's /v1/models response and converts
  527. // per-token USD pricing into the local ratio format.
  528. // model_ratio = prompt_price_per_token * 1_000_000 * (USD / 1000)
  529. //
  530. // since 1 ratio unit = $0.002/1K tokens and USD=500, the factor is 500_000
  531. //
  532. // completion_ratio = completion_price / prompt_price (output/input multiplier)
  533. func convertOpenRouterToRatioData(reader io.Reader) (map[string]any, error) {
  534. var orResp struct {
  535. Data []struct {
  536. ID string `json:"id"`
  537. Pricing struct {
  538. Prompt string `json:"prompt"`
  539. Completion string `json:"completion"`
  540. InputCacheRead string `json:"input_cache_read"`
  541. } `json:"pricing"`
  542. } `json:"data"`
  543. }
  544. if err := common.DecodeJson(reader, &orResp); err != nil {
  545. return nil, fmt.Errorf("failed to decode OpenRouter response: %w", err)
  546. }
  547. modelRatioMap := make(map[string]any)
  548. completionRatioMap := make(map[string]any)
  549. cacheRatioMap := make(map[string]any)
  550. for _, m := range orResp.Data {
  551. promptPrice, promptErr := strconv.ParseFloat(m.Pricing.Prompt, 64)
  552. completionPrice, compErr := strconv.ParseFloat(m.Pricing.Completion, 64)
  553. if promptErr != nil && compErr != nil {
  554. // Both unparseable — skip this model
  555. continue
  556. }
  557. // Treat parse errors as 0
  558. if promptErr != nil {
  559. promptPrice = 0
  560. }
  561. if compErr != nil {
  562. completionPrice = 0
  563. }
  564. // Negative values are sentinel values (e.g., -1 for dynamic/variable pricing) — skip
  565. if promptPrice < 0 || completionPrice < 0 {
  566. continue
  567. }
  568. if promptPrice == 0 && completionPrice == 0 {
  569. // Free model
  570. modelRatioMap[m.ID] = 0.0
  571. continue
  572. }
  573. if promptPrice <= 0 {
  574. // No meaningful prompt baseline, cannot derive ratios safely.
  575. continue
  576. }
  577. // Normal case: promptPrice > 0
  578. ratio := promptPrice * 1000 * ratio_setting.USD
  579. ratio = roundRatioValue(ratio)
  580. modelRatioMap[m.ID] = ratio
  581. compRatio := completionPrice / promptPrice
  582. compRatio = roundRatioValue(compRatio)
  583. completionRatioMap[m.ID] = compRatio
  584. // Convert input_cache_read to cache_ratio (= cache_read_price / prompt_price)
  585. if m.Pricing.InputCacheRead != "" {
  586. if cachePrice, err := strconv.ParseFloat(m.Pricing.InputCacheRead, 64); err == nil && cachePrice >= 0 {
  587. cacheRatio := cachePrice / promptPrice
  588. cacheRatio = roundRatioValue(cacheRatio)
  589. cacheRatioMap[m.ID] = cacheRatio
  590. }
  591. }
  592. }
  593. converted := make(map[string]any)
  594. if len(modelRatioMap) > 0 {
  595. converted["model_ratio"] = modelRatioMap
  596. }
  597. if len(completionRatioMap) > 0 {
  598. converted["completion_ratio"] = completionRatioMap
  599. }
  600. if len(cacheRatioMap) > 0 {
  601. converted["cache_ratio"] = cacheRatioMap
  602. }
  603. return converted, nil
  604. }
  // modelsDevProvider is one provider entry of the models.dev /api.json
  // document. Only its table of models, keyed by model name, is decoded.
  type modelsDevProvider struct {
  	Models map[string]modelsDevModel `json:"models"`
  }

  // modelsDevModel is a single model entry. Only its cost block is decoded.
  type modelsDevModel struct {
  	Cost modelsDevCost `json:"cost"`
  }

  // modelsDevCost holds a model's prices, which this file treats as USD per 1M
  // tokens. Pointers distinguish an absent field (nil) from an explicit zero.
  type modelsDevCost struct {
  	Input     *float64 `json:"input"`
  	Output    *float64 `json:"output"`
  	CacheRead *float64 `json:"cache_read"`
  }

  // modelsDevCandidate is a validated price entry for one model from one
  // provider. It is used to pick a single price when several providers list
  // the same model. Output and CacheRead stay nil when not available.
  type modelsDevCandidate struct {
  	Provider  string
  	Input     float64
  	Output    *float64
  	CacheRead *float64
  }
  // cloneFloatPtr returns a pointer to a fresh copy of *v, or nil when v is nil.
  // The result never aliases the input.
  func cloneFloatPtr(v *float64) *float64 {
  	if v != nil {
  		copied := *v
  		return &copied
  	}
  	return nil
  }
  // isValidNonNegativeCost reports whether v is a finite number >= 0. NaN, ±Inf
  // and negative values are rejected; -0 is accepted.
  func isValidNonNegativeCost(v float64) bool {
  	return !math.IsNaN(v) && !math.IsInf(v, 0) && v >= 0
  }
  635. func buildModelsDevCandidate(provider string, cost modelsDevCost) (modelsDevCandidate, bool) {
  636. if cost.Input == nil {
  637. return modelsDevCandidate{}, false
  638. }
  639. input := *cost.Input
  640. if !isValidNonNegativeCost(input) {
  641. return modelsDevCandidate{}, false
  642. }
  643. var output *float64
  644. if cost.Output != nil {
  645. if !isValidNonNegativeCost(*cost.Output) {
  646. return modelsDevCandidate{}, false
  647. }
  648. output = cloneFloatPtr(cost.Output)
  649. }
  650. // input=0/output>0 cannot be transformed into local ratio.
  651. if input == 0 && output != nil && *output > 0 {
  652. return modelsDevCandidate{}, false
  653. }
  654. var cacheRead *float64
  655. if cost.CacheRead != nil && isValidNonNegativeCost(*cost.CacheRead) {
  656. cacheRead = cloneFloatPtr(cost.CacheRead)
  657. }
  658. return modelsDevCandidate{
  659. Provider: provider,
  660. Input: input,
  661. Output: output,
  662. CacheRead: cacheRead,
  663. }, true
  664. }
  665. func shouldReplaceModelsDevCandidate(current, next modelsDevCandidate) bool {
  666. currentNonZero := current.Input > 0
  667. nextNonZero := next.Input > 0
  668. if currentNonZero != nextNonZero {
  669. // Prefer non-zero pricing data; this matches "cheapest non-zero" conflict policy.
  670. return nextNonZero
  671. }
  672. if nextNonZero && !nearlyEqual(next.Input, current.Input) {
  673. return next.Input < current.Input
  674. }
  675. // Stable tie-breaker for deterministic result.
  676. return next.Provider < current.Provider
  677. }
  678. // convertModelsDevToRatioData parses models.dev /api.json and converts
  679. // provider pricing metadata into local ratio format.
  680. // models.dev costs are USD per 1M tokens:
  681. //
  682. // model_ratio = input_cost_per_1M / 2
  683. // completion_ratio = output_cost / input_cost
  684. // cache_ratio = cache_read_cost / input_cost
  685. //
  686. // Duplicate model keys across providers are resolved by selecting the
  687. // cheapest non-zero input cost. If only zero-priced candidates exist,
  688. // a zero ratio is kept.
  689. func convertModelsDevToRatioData(reader io.Reader) (map[string]any, error) {
  690. var upstreamData map[string]modelsDevProvider
  691. if err := common.DecodeJson(reader, &upstreamData); err != nil {
  692. return nil, fmt.Errorf("failed to decode models.dev response: %w", err)
  693. }
  694. if len(upstreamData) == 0 {
  695. return nil, fmt.Errorf("empty models.dev response")
  696. }
  697. providers := make([]string, 0, len(upstreamData))
  698. for provider := range upstreamData {
  699. providers = append(providers, provider)
  700. }
  701. sort.Strings(providers)
  702. selectedCandidates := make(map[string]modelsDevCandidate)
  703. for _, provider := range providers {
  704. providerData := upstreamData[provider]
  705. if len(providerData.Models) == 0 {
  706. continue
  707. }
  708. modelNames := make([]string, 0, len(providerData.Models))
  709. for modelName := range providerData.Models {
  710. modelNames = append(modelNames, modelName)
  711. }
  712. sort.Strings(modelNames)
  713. for _, modelName := range modelNames {
  714. candidate, ok := buildModelsDevCandidate(provider, providerData.Models[modelName].Cost)
  715. if !ok {
  716. continue
  717. }
  718. current, exists := selectedCandidates[modelName]
  719. if !exists || shouldReplaceModelsDevCandidate(current, candidate) {
  720. selectedCandidates[modelName] = candidate
  721. }
  722. }
  723. }
  724. if len(selectedCandidates) == 0 {
  725. return nil, fmt.Errorf("no valid models.dev pricing entries found")
  726. }
  727. modelRatioMap := make(map[string]any)
  728. completionRatioMap := make(map[string]any)
  729. cacheRatioMap := make(map[string]any)
  730. for modelName, candidate := range selectedCandidates {
  731. if candidate.Input == 0 {
  732. modelRatioMap[modelName] = 0.0
  733. continue
  734. }
  735. modelRatio := candidate.Input * float64(ratio_setting.USD) / modelsDevInputCostRatioBase
  736. modelRatioMap[modelName] = roundRatioValue(modelRatio)
  737. if candidate.Output != nil {
  738. completionRatio := *candidate.Output / candidate.Input
  739. completionRatioMap[modelName] = roundRatioValue(completionRatio)
  740. }
  741. if candidate.CacheRead != nil {
  742. cacheRatio := *candidate.CacheRead / candidate.Input
  743. cacheRatioMap[modelName] = roundRatioValue(cacheRatio)
  744. }
  745. }
  746. converted := make(map[string]any)
  747. if len(modelRatioMap) > 0 {
  748. converted["model_ratio"] = modelRatioMap
  749. }
  750. if len(completionRatioMap) > 0 {
  751. converted["completion_ratio"] = completionRatioMap
  752. }
  753. if len(cacheRatioMap) > 0 {
  754. converted["cache_ratio"] = cacheRatioMap
  755. }
  756. return converted, nil
  757. }
  758. func GetSyncableChannels(c *gin.Context) {
  759. channels, err := model.GetAllChannels(0, 0, true, false)
  760. if err != nil {
  761. c.JSON(http.StatusOK, gin.H{
  762. "success": false,
  763. "message": err.Error(),
  764. })
  765. return
  766. }
  767. var syncableChannels []dto.SyncableChannel
  768. for _, channel := range channels {
  769. if channel.GetBaseURL() != "" {
  770. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  771. ID: channel.Id,
  772. Name: channel.Name,
  773. BaseURL: channel.GetBaseURL(),
  774. Status: channel.Status,
  775. Type: channel.Type,
  776. })
  777. }
  778. }
  779. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  780. ID: officialRatioPresetID,
  781. Name: officialRatioPresetName,
  782. BaseURL: officialRatioPresetBaseURL,
  783. Status: 1,
  784. })
  785. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  786. ID: modelsDevPresetID,
  787. Name: modelsDevPresetName,
  788. BaseURL: modelsDevPresetBaseURL,
  789. Status: 1,
  790. })
  791. c.JSON(http.StatusOK, gin.H{
  792. "success": true,
  793. "message": "",
  794. "data": syncableChannels,
  795. })
  796. }