|
@@ -1,41 +1,35 @@
|
|
|
package controller
|
|
package controller
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
|
+ "context"
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
"net/http"
|
|
"net/http"
|
|
|
- "one-api/model"
|
|
|
|
|
- "one-api/setting/ratio_setting"
|
|
|
|
|
- "one-api/dto"
|
|
|
|
|
|
|
+ "strings"
|
|
|
"sync"
|
|
"sync"
|
|
|
"time"
|
|
"time"
|
|
|
|
|
|
|
|
|
|
+ "one-api/common"
|
|
|
|
|
+ "one-api/dto"
|
|
|
|
|
+ "one-api/model"
|
|
|
|
|
+ "one-api/setting/ratio_setting"
|
|
|
|
|
+
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+const (
|
|
|
|
|
+ defaultTimeoutSeconds = 10
|
|
|
|
|
+ defaultEndpoint = "/api/ratio_config"
|
|
|
|
|
+ maxConcurrentFetches = 8
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
|
|
|
|
+
|
|
|
type upstreamResult struct {
|
|
type upstreamResult struct {
|
|
|
Name string `json:"name"`
|
|
Name string `json:"name"`
|
|
|
Data map[string]any `json:"data,omitempty"`
|
|
Data map[string]any `json:"data,omitempty"`
|
|
|
Err string `json:"err,omitempty"`
|
|
Err string `json:"err,omitempty"`
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-type TestResult struct {
|
|
|
|
|
- Name string `json:"name"`
|
|
|
|
|
- Status string `json:"status"`
|
|
|
|
|
- Error string `json:"error,omitempty"`
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-type DifferenceItem struct {
|
|
|
|
|
- Current interface{} `json:"current"` // 当前本地值,可能为null
|
|
|
|
|
- Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-type SyncableChannel struct {
|
|
|
|
|
- ID int `json:"id"`
|
|
|
|
|
- Name string `json:"name"`
|
|
|
|
|
- BaseURL string `json:"base_url"`
|
|
|
|
|
- Status int `json:"status"`
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
func FetchUpstreamRatios(c *gin.Context) {
|
|
func FetchUpstreamRatios(c *gin.Context) {
|
|
|
var req dto.UpstreamRequest
|
|
var req dto.UpstreamRequest
|
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
@@ -44,45 +38,80 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if req.Timeout <= 0 {
|
|
if req.Timeout <= 0 {
|
|
|
- req.Timeout = 10
|
|
|
|
|
|
|
+ req.Timeout = defaultTimeoutSeconds
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
var upstreams []dto.UpstreamDTO
|
|
var upstreams []dto.UpstreamDTO
|
|
|
|
|
+
|
|
|
if len(req.ChannelIDs) > 0 {
|
|
if len(req.ChannelIDs) > 0 {
|
|
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
|
|
for _, id64 := range req.ChannelIDs {
|
|
for _, id64 := range req.ChannelIDs {
|
|
|
intIds = append(intIds, int(id64))
|
|
intIds = append(intIds, int(id64))
|
|
|
}
|
|
}
|
|
|
- dbChannels, _ := model.GetChannelsByIds(intIds)
|
|
|
|
|
|
|
+ dbChannels, err := model.GetChannelsByIds(intIds)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
for _, ch := range dbChannels {
|
|
for _, ch := range dbChannels {
|
|
|
- upstreams = append(upstreams, dto.UpstreamDTO{
|
|
|
|
|
- Name: ch.Name,
|
|
|
|
|
- BaseURL: ch.GetBaseURL(),
|
|
|
|
|
- Endpoint: "",
|
|
|
|
|
- })
|
|
|
|
|
|
|
+ if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
|
|
|
|
+ upstreams = append(upstreams, dto.UpstreamDTO{
|
|
|
|
|
+ Name: ch.Name,
|
|
|
|
|
+ BaseURL: strings.TrimRight(base, "/"),
|
|
|
|
|
+ Endpoint: "",
|
|
|
|
|
+ })
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ if len(upstreams) == 0 {
|
|
|
|
|
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
var wg sync.WaitGroup
|
|
var wg sync.WaitGroup
|
|
|
ch := make(chan upstreamResult, len(upstreams))
|
|
ch := make(chan upstreamResult, len(upstreams))
|
|
|
|
|
|
|
|
|
|
+ sem := make(chan struct{}, maxConcurrentFetches)
|
|
|
|
|
+
|
|
|
|
|
+ client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
|
|
|
|
+
|
|
|
for _, chn := range upstreams {
|
|
for _, chn := range upstreams {
|
|
|
wg.Add(1)
|
|
wg.Add(1)
|
|
|
go func(chItem dto.UpstreamDTO) {
|
|
go func(chItem dto.UpstreamDTO) {
|
|
|
defer wg.Done()
|
|
defer wg.Done()
|
|
|
|
|
+
|
|
|
|
|
+ sem <- struct{}{}
|
|
|
|
|
+ defer func() { <-sem }()
|
|
|
|
|
+
|
|
|
endpoint := chItem.Endpoint
|
|
endpoint := chItem.Endpoint
|
|
|
if endpoint == "" {
|
|
if endpoint == "" {
|
|
|
- endpoint = "/api/ratio_config"
|
|
|
|
|
|
|
+ endpoint = defaultEndpoint
|
|
|
|
|
+ } else if !strings.HasPrefix(endpoint, "/") {
|
|
|
|
|
+ endpoint = "/" + endpoint
|
|
|
|
|
+ }
|
|
|
|
|
+ fullURL := chItem.BaseURL + endpoint
|
|
|
|
|
+
|
|
|
|
|
+ ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
|
|
|
|
+ defer cancel()
|
|
|
|
|
+
|
|
|
|
|
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
|
|
|
|
+ ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
|
|
|
|
+ return
|
|
|
}
|
|
}
|
|
|
- url := chItem.BaseURL + endpoint
|
|
|
|
|
- client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second}
|
|
|
|
|
- resp, err := client.Get(url)
|
|
|
|
|
|
|
+
|
|
|
|
|
+ resp, err := client.Do(httpReq)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
|
|
+ common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
|
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
defer resp.Body.Close()
|
|
defer resp.Body.Close()
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
+ common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
|
|
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
|
|
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
@@ -92,6 +121,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
|
|
Message string `json:"message"`
|
|
Message string `json:"message"`
|
|
|
}
|
|
}
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
|
|
|
|
+ common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
|
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
@@ -149,7 +179,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
|
|
data map[string]any
|
|
data map[string]any
|
|
|
}) map[string]map[string]dto.DifferenceItem {
|
|
}) map[string]map[string]dto.DifferenceItem {
|
|
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
|
|
- ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
|
|
|
|
|
|
|
|
|
allModels := make(map[string]struct{})
|
|
allModels := make(map[string]struct{})
|
|
|
|
|
|