|
@@ -95,20 +95,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
|
|
if strings.TrimSpace(region) == "" {
|
|
if strings.TrimSpace(region) == "" {
|
|
|
region = "global"
|
|
region = "global"
|
|
|
}
|
|
}
|
|
|
- if region == "global" {
|
|
|
|
|
- return fmt.Sprintf(
|
|
|
|
|
- "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
|
|
|
|
- adc.ProjectID,
|
|
|
|
|
- modelName,
|
|
|
|
|
- ), nil
|
|
|
|
|
- }
|
|
|
|
|
- return fmt.Sprintf(
|
|
|
|
|
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
|
|
|
|
- region,
|
|
|
|
|
- adc.ProjectID,
|
|
|
|
|
- region,
|
|
|
|
|
- modelName,
|
|
|
|
|
- ), nil
|
|
|
|
|
|
|
+ return vertexcore.BuildGoogleModelURL(a.baseURL, vertexcore.DefaultAPIVersion, adc.ProjectID, region, modelName, "predictLongRunning"), nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// BuildRequestHeader sets required headers.
|
|
// BuildRequestHeader sets required headers.
|
|
@@ -238,6 +225,22 @@ func (a *TaskAdaptor) GetModelList() []string {
|
|
|
}
|
|
}
|
|
|
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
|
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
|
|
|
|
|
|
|
|
|
+func buildFetchOperationURL(baseURL, upstreamName string) (string, error) {
|
|
|
|
|
+ region := extractRegionFromOperationName(upstreamName)
|
|
|
|
|
+ if region == "" {
|
|
|
|
|
+ region = "us-central1"
|
|
|
|
|
+ }
|
|
|
|
|
+ project := extractProjectFromOperationName(upstreamName)
|
|
|
|
|
+ modelName := extractModelFromOperationName(upstreamName)
|
|
|
|
|
+ if strings.TrimSpace(modelName) == "" {
|
|
|
|
|
+ return "", fmt.Errorf("cannot extract model from operation name")
|
|
|
|
|
+ }
|
|
|
|
|
+ if strings.TrimSpace(project) == "" {
|
|
|
|
|
+ return "", fmt.Errorf("cannot extract project from operation name")
|
|
|
|
|
+ }
|
|
|
|
|
+ return vertexcore.BuildGoogleModelURL(baseURL, vertexcore.DefaultAPIVersion, project, region, modelName, "fetchPredictOperation"), nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// FetchTask fetch task status
|
|
// FetchTask fetch task status
|
|
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
|
|
taskID, ok := body["task_id"].(string)
|
|
taskID, ok := body["task_id"].(string)
|
|
@@ -248,20 +251,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
|
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
|
|
}
|
|
}
|
|
|
- region := extractRegionFromOperationName(upstreamName)
|
|
|
|
|
- if region == "" {
|
|
|
|
|
- region = "us-central1"
|
|
|
|
|
- }
|
|
|
|
|
- project := extractProjectFromOperationName(upstreamName)
|
|
|
|
|
- modelName := extractModelFromOperationName(upstreamName)
|
|
|
|
|
- if project == "" || modelName == "" {
|
|
|
|
|
- return nil, fmt.Errorf("cannot extract project/model from operation name")
|
|
|
|
|
- }
|
|
|
|
|
- var url string
|
|
|
|
|
- if region == "global" {
|
|
|
|
|
- url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
|
|
|
|
- } else {
|
|
|
|
|
- url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
|
|
|
|
|
|
+ url, err := buildFetchOperationURL(baseUrl, upstreamName)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
payload := fetchOperationPayload{OperationName: upstreamName}
|
|
payload := fetchOperationPayload{OperationName: upstreamName}
|
|
|
data, err := common.Marshal(payload)
|
|
data, err := common.Marshal(payload)
|