ratio_sync.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package controller
  2. import (
  3. "encoding/json"
  4. "net/http"
  5. "one-api/model"
  6. "one-api/setting/ratio_setting"
  7. "one-api/dto"
  8. "sync"
  9. "time"
  10. "github.com/gin-gonic/gin"
  11. )
// upstreamResult carries the outcome of fetching the ratio configuration of
// a single upstream. A non-empty Err marks the fetch as failed; otherwise
// Data holds the payload the upstream returned.
type upstreamResult struct {
	Name string         `json:"name"`           // name of the upstream the result belongs to
	Data map[string]any `json:"data,omitempty"` // decoded "data" object of the upstream response
	Err  string         `json:"err,omitempty"`  // error text; empty when the fetch succeeded
}
// TestResult is the JSON shape of a per-upstream fetch report: the upstream
// name, a status string and an optional error message.
//
// NOTE(review): the handlers visible in this file build dto.TestResult rather
// than this type — confirm whether this local declaration is still referenced.
type TestResult struct {
	Name   string `json:"name"`
	Status string `json:"status"`
	Error  string `json:"error,omitempty"` // omitted from the JSON output when empty
}
// DifferenceItem pairs the local value of one ratio entry with the values
// reported by the upstreams.
//
// NOTE(review): the code visible in this file builds dto.DifferenceItem rather
// than this type — confirm whether this local declaration is still referenced.
type DifferenceItem struct {
	Current   interface{}            `json:"current"`   // current local value, may be null
	Upstreams map[string]interface{} `json:"upstreams"` // upstream values: a concrete value, "same", or null
}
// SyncableChannel describes a channel that can be used for ratio
// synchronisation.
type SyncableChannel struct {
	ID      int    `json:"id"`
	Name    string `json:"name"`
	BaseURL string `json:"base_url"`
	Status  int    `json:"status"`
}
  33. // FetchUpstreamRatios 后端并发拉取上游倍率
  34. func FetchUpstreamRatios(c *gin.Context) {
  35. var req dto.UpstreamRequest
  36. if err := c.ShouldBindJSON(&req); err != nil {
  37. c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
  38. return
  39. }
  40. if req.Timeout <= 0 {
  41. req.Timeout = 10
  42. }
  43. // build upstream list from ids
  44. var upstreams []dto.UpstreamDTO
  45. if len(req.ChannelIDs) > 0 {
  46. // convert []int64 -> []int for model function
  47. intIds := make([]int, 0, len(req.ChannelIDs))
  48. for _, id64 := range req.ChannelIDs {
  49. intIds = append(intIds, int(id64))
  50. }
  51. dbChannels, _ := model.GetChannelsByIds(intIds)
  52. for _, ch := range dbChannels {
  53. upstreams = append(upstreams, dto.UpstreamDTO{
  54. Name: ch.Name,
  55. BaseURL: ch.GetBaseURL(),
  56. Endpoint: "", // assume default endpoint
  57. })
  58. }
  59. }
  60. var wg sync.WaitGroup
  61. ch := make(chan upstreamResult, len(upstreams))
  62. for _, chn := range upstreams {
  63. wg.Add(1)
  64. go func(chItem dto.UpstreamDTO) {
  65. defer wg.Done()
  66. endpoint := chItem.Endpoint
  67. if endpoint == "" {
  68. endpoint = "/api/ratio_config"
  69. }
  70. url := chItem.BaseURL + endpoint
  71. client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second}
  72. resp, err := client.Get(url)
  73. if err != nil {
  74. ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
  75. return
  76. }
  77. defer resp.Body.Close()
  78. if resp.StatusCode != http.StatusOK {
  79. ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
  80. return
  81. }
  82. var body struct {
  83. Success bool `json:"success"`
  84. Data map[string]any `json:"data"`
  85. Message string `json:"message"`
  86. }
  87. if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
  88. ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
  89. return
  90. }
  91. if !body.Success {
  92. ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
  93. return
  94. }
  95. ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
  96. }(chn)
  97. }
  98. wg.Wait()
  99. close(ch)
  100. // 本地倍率配置
  101. localData := ratio_setting.GetExposedData()
  102. var testResults []dto.TestResult
  103. var successfulChannels []struct {
  104. name string
  105. data map[string]any
  106. }
  107. for r := range ch {
  108. if r.Err != "" {
  109. testResults = append(testResults, dto.TestResult{
  110. Name: r.Name,
  111. Status: "error",
  112. Error: r.Err,
  113. })
  114. } else {
  115. testResults = append(testResults, dto.TestResult{
  116. Name: r.Name,
  117. Status: "success",
  118. })
  119. successfulChannels = append(successfulChannels, struct {
  120. name string
  121. data map[string]any
  122. }{name: r.Name, data: r.Data})
  123. }
  124. }
  125. // 构建差异化数据
  126. differences := buildDifferences(localData, successfulChannels)
  127. c.JSON(http.StatusOK, gin.H{
  128. "success": true,
  129. "data": gin.H{
  130. "differences": differences,
  131. "test_results": testResults,
  132. },
  133. })
  134. }
  135. // buildDifferences 构建差异化数据,只返回有意义的差异
  136. func buildDifferences(localData map[string]any, successfulChannels []struct {
  137. name string
  138. data map[string]any
  139. }) map[string]map[string]dto.DifferenceItem {
  140. differences := make(map[string]map[string]dto.DifferenceItem)
  141. ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
  142. // 收集所有模型名称
  143. allModels := make(map[string]struct{})
  144. // 从本地数据收集模型名称
  145. for _, ratioType := range ratioTypes {
  146. if localRatioAny, ok := localData[ratioType]; ok {
  147. if localRatio, ok := localRatioAny.(map[string]float64); ok {
  148. for modelName := range localRatio {
  149. allModels[modelName] = struct{}{}
  150. }
  151. }
  152. }
  153. }
  154. // 从上游数据收集模型名称
  155. for _, channel := range successfulChannels {
  156. for _, ratioType := range ratioTypes {
  157. if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
  158. for modelName := range upstreamRatio {
  159. allModels[modelName] = struct{}{}
  160. }
  161. }
  162. }
  163. }
  164. // 对每个模型和每个比率类型进行分析
  165. for modelName := range allModels {
  166. for _, ratioType := range ratioTypes {
  167. // 获取本地值
  168. var localValue interface{} = nil
  169. if localRatioAny, ok := localData[ratioType]; ok {
  170. if localRatio, ok := localRatioAny.(map[string]float64); ok {
  171. if val, exists := localRatio[modelName]; exists {
  172. localValue = val
  173. }
  174. }
  175. }
  176. // 收集上游值
  177. upstreamValues := make(map[string]interface{})
  178. hasUpstreamValue := false
  179. hasDifference := false
  180. for _, channel := range successfulChannels {
  181. var upstreamValue interface{} = nil
  182. if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
  183. if val, exists := upstreamRatio[modelName]; exists {
  184. upstreamValue = val
  185. hasUpstreamValue = true
  186. // 检查是否与本地值不同
  187. if localValue != nil && localValue != val {
  188. hasDifference = true
  189. } else if localValue == val {
  190. upstreamValue = "same"
  191. }
  192. }
  193. }
  194. // 如果本地值为空但上游有值,这也是差异
  195. if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
  196. hasDifference = true
  197. }
  198. upstreamValues[channel.name] = upstreamValue
  199. }
  200. // 应用过滤逻辑
  201. shouldInclude := false
  202. if localValue != nil {
  203. // 规则1: 本地值存在,至少有一个上游与本地值不同
  204. if hasDifference {
  205. shouldInclude = true
  206. }
  207. // 规则2: 本地值存在,但所有上游都未设置 - 不包含
  208. } else {
  209. // 规则3: 本地值不存在,至少有一个上游设置了值
  210. if hasUpstreamValue {
  211. shouldInclude = true
  212. }
  213. }
  214. if shouldInclude {
  215. if differences[modelName] == nil {
  216. differences[modelName] = make(map[string]dto.DifferenceItem)
  217. }
  218. differences[modelName][ratioType] = dto.DifferenceItem{
  219. Current: localValue,
  220. Upstreams: upstreamValues,
  221. }
  222. }
  223. }
  224. }
  225. return differences
  226. }
  227. // GetSyncableChannels 获取可用于倍率同步的渠道(base_url 不为空的渠道)
  228. func GetSyncableChannels(c *gin.Context) {
  229. channels, err := model.GetAllChannels(0, 0, true, false)
  230. if err != nil {
  231. c.JSON(http.StatusOK, gin.H{
  232. "success": false,
  233. "message": err.Error(),
  234. })
  235. return
  236. }
  237. var syncableChannels []dto.SyncableChannel
  238. for _, channel := range channels {
  239. // 只返回 base_url 不为空的渠道
  240. if channel.GetBaseURL() != "" {
  241. syncableChannels = append(syncableChannels, dto.SyncableChannel{
  242. ID: channel.Id,
  243. Name: channel.Name,
  244. BaseURL: channel.GetBaseURL(),
  245. Status: channel.Status,
  246. })
  247. }
  248. }
  249. c.JSON(http.StatusOK, gin.H{
  250. "success": true,
  251. "message": "",
  252. "data": syncableChannels,
  253. })
  254. }