Explorar el Código

feat: support Azure dall-e

CaIon hace 2 años
padre
commit
75b6327f4f
Se han modificado 4 ficheros con 55 adiciones y 8 borrados
  1. 17 1
      controller/relay-audio.go
  2. 15 3
      controller/relay-image.go
  3. 9 0
      controller/relay-utils.go
  4. 14 4
      controller/relay.go

+ 17 - 1
controller/relay-audio.go

@@ -106,13 +106,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 
 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
+		apiVersion := GetAPIVersion(c)
+		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
+	}
+
 	requestBody := c.Request.Body
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
-	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+
+	if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
+		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		req.Header.Set("api-key", apiKey)
+		req.ContentLength = c.Request.ContentLength
+	} else {
+		req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	}
+
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 

+ 15 - 3
controller/relay-image.go

@@ -31,7 +31,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 
 	if imageRequest.Model == "" {
-		imageRequest.Model = "dall-e"
+		imageRequest.Model = "dall-e-2"
 	}
 	if imageRequest.Size == "" {
 		imageRequest.Size = "1024x1024"
@@ -86,8 +86,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 		baseURL = c.GetString("base_url")
 	}
 	fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
+	if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
+		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
+		apiVersion := GetAPIVersion(c)
+		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
+		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
+	}
 	var requestBody io.Reader
-	if isModelMapped {
+	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
 		jsonStr, err := json.Marshal(imageRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -132,8 +138,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	if err != nil {
 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
-	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 
+	token := c.Request.Header.Get("Authorization")
+	if channelType == common.ChannelTypeAzure { // Azure authentication
+		token = strings.TrimPrefix(token, "Bearer ")
+		req.Header.Set("api-key", token)
+	} else {
+		req.Header.Set("Authorization", token)
+	}
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 

+ 9 - 0
controller/relay-utils.go

@@ -301,3 +301,12 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
 	}
 	return fullRequestURL
 }
+
+func GetAPIVersion(c *gin.Context) string { // GetAPIVersion picks the API version for a relayed request: URL query first, gin context second.
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version") // an "api-version" query parameter on the incoming request takes precedence
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version") // otherwise use the value stored on the context under "api_version" (presumably set per channel upstream; verify against the middleware)
+	}
+	return apiVersion // returned as-is; no validation or default is applied here
+}

+ 14 - 4
controller/relay.go

@@ -99,7 +99,9 @@ const (
 	RelayModeMidjourneyNotify
 	RelayModeMidjourneyTaskFetch
 	RelayModeMidjourneyTaskFetchByCondition
-	RelayModeAudio
+	RelayModeAudioSpeech
+	RelayModeAudioTranscription
+	RelayModeAudioTranslation
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -291,14 +293,22 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 		relayMode = RelayModeEdits
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-		relayMode = RelayModeAudio
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+		relayMode = RelayModeAudioSpeech
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+		relayMode = RelayModeAudioTranscription
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+		relayMode = RelayModeAudioTranslation
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
 	case RelayModeImagesGenerations:
 		err = relayImageHelper(c, relayMode)
-	case RelayModeAudio:
+	case RelayModeAudioSpeech:
+		fallthrough
+	case RelayModeAudioTranslation:
+		fallthrough
+	case RelayModeAudioTranscription:
 		err = relayAudioHelper(c, relayMode)
 	default:
 		err = relayTextHelper(c, relayMode)