Sfoglia il codice sorgente

Merge remote-tracking branch 'origin/alpha' into refactor/model-pricing

t0ng7u 7 mesi fa
parent
commit
d61a862fa2
3 ha cambiato i file con 10 aggiunte e 6 eliminazioni
  1. 4 0
      controller/channel.go
  2. 3 3
      model/channel.go
  3. 3 3
      relay/channel/openai/adaptor.go

+ 4 - 0
controller/channel.go

@@ -1107,6 +1107,10 @@ func ManageMultiKeys(c *gin.Context) {
 		return
 		return
 	}
 	}
 
 
+	lock := model.GetChannelPollingLock(channel.Id)
+	lock.Lock()
+	defer lock.Unlock()
+
 	switch request.Action {
 	switch request.Action {
 	case "get_key_status":
 	case "get_key_status":
 		keys := channel.GetKeys()
 		keys := channel.GetKeys()

+ 3 - 3
model/channel.go

@@ -141,7 +141,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
 		return keys[selectedIdx], selectedIdx, nil
 		return keys[selectedIdx], selectedIdx, nil
 	case constant.MultiKeyModePolling:
 	case constant.MultiKeyModePolling:
 		// Use channel-specific lock to ensure thread-safe polling
 		// Use channel-specific lock to ensure thread-safe polling
-		lock := getChannelPollingLock(channel.Id)
+		lock := GetChannelPollingLock(channel.Id)
 		lock.Lock()
 		lock.Lock()
 		defer lock.Unlock()
 		defer lock.Unlock()
 
 
@@ -500,8 +500,8 @@ var channelStatusLock sync.Mutex
 // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
 // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
 var channelPollingLocks sync.Map
 var channelPollingLocks sync.Map
 
 
-// getChannelPollingLock returns or creates a mutex for the given channel ID
-func getChannelPollingLock(channelId int) *sync.Mutex {
+// GetChannelPollingLock returns or creates a mutex for the given channel ID
+func GetChannelPollingLock(channelId int) *sync.Mutex {
 	if lock, exists := channelPollingLocks.Load(channelId); exists {
 	if lock, exists := channelPollingLocks.Load(channelId); exists {
 		return lock.(*sync.Mutex)
 		return lock.(*sync.Mutex)
 	}
 	}

+ 3 - 3
relay/channel/openai/adaptor.go

@@ -73,9 +73,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 }
 
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
-		return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
-	}
 	if info.RelayMode == relayconstant.RelayModeRealtime {
 	if info.RelayMode == relayconstant.RelayModeRealtime {
 		if strings.HasPrefix(info.BaseUrl, "https://") {
 		if strings.HasPrefix(info.BaseUrl, "https://") {
 			baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
 			baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -122,6 +119,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
 		url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
 		return url, nil
 		return url, nil
 	default:
 	default:
+		if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
+			return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+		}
 		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
 		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
 	}
 	}
 }
 }