Просмотр исходного кода

Merge pull request #4200 from yyhhyyyyyy/fix/vertex-gateway-base-url

fix(vertex): honor custom base_url as gateway prefix
Calcium-Ion 1 неделя назад
Родитель
Сommit
5114ad0677

+ 20 - 28
relay/channel/task/vertex/adaptor.go

@@ -95,20 +95,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
 	if strings.TrimSpace(region) == "" {
 		region = "global"
 	}
-	if region == "global" {
-		return fmt.Sprintf(
-			"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
-			adc.ProjectID,
-			modelName,
-		), nil
-	}
-	return fmt.Sprintf(
-		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
-		region,
-		adc.ProjectID,
-		region,
-		modelName,
-	), nil
+	return vertexcore.BuildGoogleModelURL(a.baseURL, vertexcore.DefaultAPIVersion, adc.ProjectID, region, modelName, "predictLongRunning"), nil
 }
 
 // BuildRequestHeader sets required headers.
@@ -238,6 +225,22 @@ func (a *TaskAdaptor) GetModelList() []string {
 }
 func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
 
+func buildFetchOperationURL(baseURL, upstreamName string) (string, error) {
+	region := extractRegionFromOperationName(upstreamName)
+	if region == "" {
+		region = "us-central1"
+	}
+	project := extractProjectFromOperationName(upstreamName)
+	modelName := extractModelFromOperationName(upstreamName)
+	if strings.TrimSpace(modelName) == "" {
+		return "", fmt.Errorf("cannot extract model from operation name")
+	}
+	if strings.TrimSpace(project) == "" {
+		return "", fmt.Errorf("cannot extract project from operation name")
+	}
+	return vertexcore.BuildGoogleModelURL(baseURL, vertexcore.DefaultAPIVersion, project, region, modelName, "fetchPredictOperation"), nil
+}
+
 // FetchTask fetch task status
 func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
 	taskID, ok := body["task_id"].(string)
@@ -248,20 +251,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
 	if err != nil {
 		return nil, fmt.Errorf("decode task_id failed: %w", err)
 	}
-	region := extractRegionFromOperationName(upstreamName)
-	if region == "" {
-		region = "us-central1"
-	}
-	project := extractProjectFromOperationName(upstreamName)
-	modelName := extractModelFromOperationName(upstreamName)
-	if project == "" || modelName == "" {
-		return nil, fmt.Errorf("cannot extract project/model from operation name")
-	}
-	var url string
-	if region == "global" {
-		url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
-	} else {
-		url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
+	url, err := buildFetchOperationURL(baseUrl, upstreamName)
+	if err != nil {
+		return nil, err
 	}
 	payload := fetchOperationPayload{OperationName: upstreamName}
 	data, err := common.Marshal(payload)

+ 9 - 48
relay/channel/vertex/adaptor.go

@@ -134,47 +134,11 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
 		a.AccountCredentials = *adc
 
 		if a.RequestMode == RequestModeGemini {
-			if region == "global" {
-				return fmt.Sprintf(
-					"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
-					adc.ProjectID,
-					modelName,
-					suffix,
-				), nil
-			} else {
-				return fmt.Sprintf(
-					"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-					region,
-					adc.ProjectID,
-					region,
-					modelName,
-					suffix,
-				), nil
-			}
+			return BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
 		} else if a.RequestMode == RequestModeClaude {
-			if region == "global" {
-				return fmt.Sprintf(
-					"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
-					adc.ProjectID,
-					modelName,
-					suffix,
-				), nil
-			} else {
-				return fmt.Sprintf(
-					"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
-					region,
-					adc.ProjectID,
-					region,
-					modelName,
-					suffix,
-				), nil
-			}
+			return BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
 		} else if a.RequestMode == RequestModeOpenSource {
-			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
-				adc.ProjectID,
-				region,
-			), nil
+			return BuildOpenSourceChatCompletionsURL(info.ChannelBaseUrl, adc.ProjectID, region), nil
 		}
 	} else {
 		var keyPrefix string
@@ -183,20 +147,17 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
 		} else {
 			keyPrefix = "?"
 		}
-		if region == "global" {
+		if a.RequestMode == RequestModeGemini {
 			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
-				modelName,
-				suffix,
+				"%s%skey=%s",
+				BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
 				keyPrefix,
 				info.ApiKey,
 			), nil
-		} else {
+		} else if a.RequestMode == RequestModeClaude {
 			return fmt.Sprintf(
-				"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
-				region,
-				modelName,
-				suffix,
+				"%s%skey=%s",
+				BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
 				keyPrefix,
 				info.ApiKey,
 			), nil

+ 86 - 0
relay/channel/vertex/url_builder.go

@@ -0,0 +1,86 @@
+package vertex
+
+import (
+	"fmt"
+	"strings"
+)
+
+const (
+	DefaultAPIVersion    = "v1"
+	OpenSourceAPIVersion = "v1beta1"
+	PublisherGoogle      = "google"
+	PublisherAnthropic   = "anthropic"
+)
+
+func normalizeVertexBaseURL(baseURL string) string {
+	return strings.TrimRight(strings.TrimSpace(baseURL), "/")
+}
+
+func normalizeVertexRegion(region string) string {
+	region = strings.TrimSpace(region)
+	if region == "" {
+		return "global"
+	}
+	return region
+}
+
+func appendVertexAPIVersion(baseURL, version string) string {
+	version = strings.Trim(strings.TrimSpace(version), "/")
+	if version == "" {
+		return baseURL
+	}
+	if strings.HasSuffix(baseURL, "/"+version) {
+		return baseURL
+	}
+	return baseURL + "/" + version
+}
+
+func BuildAPIBaseURL(baseURL, version, projectID, region string) string {
+	if normalized := normalizeVertexBaseURL(baseURL); normalized != "" {
+		normalized = appendVertexAPIVersion(normalized, version)
+
+		region = normalizeVertexRegion(region)
+		if strings.TrimSpace(projectID) != "" {
+			normalized = fmt.Sprintf("%s/projects/%s/locations/%s", normalized, projectID, region)
+		}
+		return normalized
+	}
+
+	region = normalizeVertexRegion(region)
+	if strings.TrimSpace(projectID) == "" {
+		if region == "global" {
+			return fmt.Sprintf("https://aiplatform.googleapis.com/%s", version)
+		}
+		return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s", region, version)
+	}
+
+	if region == "global" {
+		return fmt.Sprintf("https://aiplatform.googleapis.com/%s/projects/%s/locations/global", version, projectID)
+	}
+	return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s/projects/%s/locations/%s", region, version, projectID, region)
+}
+
+func BuildPublisherModelURL(baseURL, version, projectID, region, publisher, modelName, action string) string {
+	return fmt.Sprintf(
+		"%s/publishers/%s/models/%s:%s",
+		BuildAPIBaseURL(baseURL, version, projectID, region),
+		publisher,
+		modelName,
+		action,
+	)
+}
+
+func BuildGoogleModelURL(baseURL, version, projectID, region, modelName, action string) string {
+	return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherGoogle, modelName, action)
+}
+
+func BuildAnthropicModelURL(baseURL, version, projectID, region, modelName, action string) string {
+	return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherAnthropic, modelName, action)
+}
+
+func BuildOpenSourceChatCompletionsURL(baseURL, projectID, region string) string {
+	return fmt.Sprintf(
+		"%s/endpoints/openapi/chat/completions",
+		BuildAPIBaseURL(baseURL, OpenSourceAPIVersion, projectID, region),
+	)
+}