Browse Source

fix: support AI Gateway, close #52

CaIon 2 years ago
parent
commit
4274b925da
3 changed files with 9 additions and 4 deletions
  1. 1 1
      controller/channel-test.go
  2. 2 1
      controller/relay-text.go
  3. 6 2
      controller/relay-utils.go

+ 1 - 1
controller/channel-test.go

@@ -45,7 +45,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	}
 	requestURL := common.ChannelBaseURLs[channel.Type]
 	if channel.Type == common.ChannelTypeAzure {
-		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
+		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
 	} else {
 		if channel.GetBaseURL() != "" {
 			requestURL = channel.GetBaseURL()

+ 2 - 1
controller/relay-text.go

@@ -149,7 +149,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			model_ = strings.TrimSuffix(model_, "-0301")
 			model_ = strings.TrimSuffix(model_, "-0314")
 			model_ = strings.TrimSuffix(model_, "-0613")
-			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+			fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
 		}
 	case APITypeClaude:
 		fullRequestURL = "https://api.anthropic.com/v1/complete"

+ 6 - 2
controller/relay-utils.go

@@ -295,9 +295,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 
 func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-	if channelType == common.ChannelTypeOpenAI {
-		if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+
+	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+		switch channelType {
+		case common.ChannelTypeOpenAI:
 			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+		case common.ChannelTypeAzure:
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
 		}
 	}
 	return fullRequestURL