Просмотр исходного кода

🚀 chore(controller, dto): elevate ratio-sync feature to production readiness

WHAT’S NEW
• controller/ratio_sync.go
  – Deleted unused local structs (TestResult, DifferenceItem, SyncableChannel).
  – Centralised config with constants: defaultTimeoutSeconds, defaultEndpoint, maxConcurrentFetches, ratioTypes.
  – Replaced magic numbers; added semaphore-based concurrency limit and shared http.Client (with TLS & Expect-Continue timeouts).
  – Added comprehensive error handling and context-aware logging via common.Log* helpers.
  – Checked DB errors from GetChannelsByIds; early-return on failures or empty upstream list.
  – Removed custom-channel support; logic now relies solely on ChannelIDs.
  – Minor clean-ups: import grouping, string trimming, endpoint normalisation.

• dto/ratio_sync.go
  – Simplified UpstreamRequest: dropped unused CustomChannels field.

WHY
These improvements harden the ratio-sync endpoint for production use by preventing silent failures, controlling resource usage, and making behaviour configurable and observable.

HOW
No business logic change—only structural refactor, logging, and safeguards—so existing API contracts (aside from removed custom_channels) remain intact.
Apple\Apple 8 месяцев назад
Родитель
Сommit
150c506ece
2 измененных файлов с 64 добавлено и 36 удалено
  1. 62 33
      controller/ratio_sync.go
  2. 2 3
      dto/ratio_sync.go

+ 62 - 33
controller/ratio_sync.go

@@ -1,41 +1,35 @@
 package controller
 package controller
 
 
 import (
 import (
+    "context"
     "encoding/json"
     "encoding/json"
     "net/http"
     "net/http"
-    "one-api/model"
-    "one-api/setting/ratio_setting"
-    "one-api/dto"
+    "strings"
     "sync"
     "sync"
     "time"
     "time"
 
 
+    "one-api/common"
+    "one-api/dto"
+    "one-api/model"
+    "one-api/setting/ratio_setting"
+
     "github.com/gin-gonic/gin"
     "github.com/gin-gonic/gin"
 )
 )
 
 
+const (
+    defaultTimeoutSeconds  = 10
+    defaultEndpoint        = "/api/ratio_config"
+    maxConcurrentFetches   = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
 type upstreamResult struct {
 type upstreamResult struct {
     Name string                 `json:"name"`
     Name string                 `json:"name"`
     Data map[string]any         `json:"data,omitempty"`
     Data map[string]any         `json:"data,omitempty"`
     Err  string                 `json:"err,omitempty"`
     Err  string                 `json:"err,omitempty"`
 }
 }
 
 
-type TestResult struct {
-    Name   string `json:"name"`
-    Status string `json:"status"`
-    Error  string `json:"error,omitempty"`
-}
-
-type DifferenceItem struct {
-    Current   interface{}            `json:"current"`   // 当前本地值,可能为null
-    Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null
-}
-
-type SyncableChannel struct {
-    ID      int    `json:"id"`
-    Name    string `json:"name"`
-    BaseURL string `json:"base_url"`
-    Status  int    `json:"status"`
-}
-
 func FetchUpstreamRatios(c *gin.Context) {
 func FetchUpstreamRatios(c *gin.Context) {
     var req dto.UpstreamRequest
     var req dto.UpstreamRequest
     if err := c.ShouldBindJSON(&req); err != nil {
     if err := c.ShouldBindJSON(&req); err != nil {
@@ -44,45 +38,80 @@ func FetchUpstreamRatios(c *gin.Context) {
     }
     }
 
 
     if req.Timeout <= 0 {
     if req.Timeout <= 0 {
-        req.Timeout = 10
+        req.Timeout = defaultTimeoutSeconds
     }
     }
 
 
     var upstreams []dto.UpstreamDTO
     var upstreams []dto.UpstreamDTO
+
     if len(req.ChannelIDs) > 0 {
     if len(req.ChannelIDs) > 0 {
         intIds := make([]int, 0, len(req.ChannelIDs))
         intIds := make([]int, 0, len(req.ChannelIDs))
         for _, id64 := range req.ChannelIDs {
         for _, id64 := range req.ChannelIDs {
             intIds = append(intIds, int(id64))
             intIds = append(intIds, int(id64))
         }
         }
-        dbChannels, _ := model.GetChannelsByIds(intIds)
+        dbChannels, err := model.GetChannelsByIds(intIds)
+        if err != nil {
+            common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+            c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+            return
+        }
         for _, ch := range dbChannels {
         for _, ch := range dbChannels {
-            upstreams = append(upstreams, dto.UpstreamDTO{
-                Name:     ch.Name,
-                BaseURL:  ch.GetBaseURL(),
-                Endpoint: "",
-            })
+            if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+                upstreams = append(upstreams, dto.UpstreamDTO{
+                    Name:     ch.Name,
+                    BaseURL:  strings.TrimRight(base, "/"),
+                    Endpoint: "",
+                })
+            }
         }
         }
     }
     }
 
 
+    if len(upstreams) == 0 {
+        c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+        return
+    }
+
     var wg sync.WaitGroup
     var wg sync.WaitGroup
     ch := make(chan upstreamResult, len(upstreams))
     ch := make(chan upstreamResult, len(upstreams))
 
 
+    sem := make(chan struct{}, maxConcurrentFetches)
+
+    client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
     for _, chn := range upstreams {
     for _, chn := range upstreams {
         wg.Add(1)
         wg.Add(1)
         go func(chItem dto.UpstreamDTO) {
         go func(chItem dto.UpstreamDTO) {
             defer wg.Done()
             defer wg.Done()
+
+            sem <- struct{}{}
+            defer func() { <-sem }()
+
             endpoint := chItem.Endpoint
             endpoint := chItem.Endpoint
             if endpoint == "" {
             if endpoint == "" {
-                endpoint = "/api/ratio_config"
+                endpoint = defaultEndpoint
+            } else if !strings.HasPrefix(endpoint, "/") {
+                endpoint = "/" + endpoint
+            }
+            fullURL := chItem.BaseURL + endpoint
+
+            ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+            defer cancel()
+
+            httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+            if err != nil {
+                common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+                ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
+                return
             }
             }
-            url := chItem.BaseURL + endpoint
-            client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second}
-            resp, err := client.Get(url)
+
+            resp, err := client.Do(httpReq)
             if err != nil {
             if err != nil {
+                common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
                 ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
                 ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
                 return
                 return
             }
             }
             defer resp.Body.Close()
             defer resp.Body.Close()
             if resp.StatusCode != http.StatusOK {
             if resp.StatusCode != http.StatusOK {
+                common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
                 ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
                 ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
                 return
                 return
             }
             }
@@ -92,6 +121,7 @@ func FetchUpstreamRatios(c *gin.Context) {
                 Message string                 `json:"message"`
                 Message string                 `json:"message"`
             }
             }
             if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
             if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+                common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
                 ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
                 ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
                 return
                 return
             }
             }
@@ -149,7 +179,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
     data map[string]any
     data map[string]any
 }) map[string]map[string]dto.DifferenceItem {
 }) map[string]map[string]dto.DifferenceItem {
     differences := make(map[string]map[string]dto.DifferenceItem)
     differences := make(map[string]map[string]dto.DifferenceItem)
-    ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
 
 
     allModels := make(map[string]struct{})
     allModels := make(map[string]struct{})
     
     

+ 2 - 3
dto/ratio_sync.go

@@ -19,9 +19,8 @@ type UpstreamDTO struct {
 }
 }
 
 
 type UpstreamRequest struct {
 type UpstreamRequest struct {
-    ChannelIDs     []int64       `json:"channel_ids"`
-    CustomChannels []UpstreamDTO `json:"custom_channels"`
-    Timeout        int           `json:"timeout"`
+    ChannelIDs []int64 `json:"channel_ids"`
+    Timeout    int     `json:"timeout"`
 }
 }
 
 
 // TestResult 上游测试连通性结果
 // TestResult 上游测试连通性结果