Просмотр исходного кода

Merge pull request #1817 from wzxjohn/hotfix/relay_vertex_claude

fix(relay): wrong URL for claude model in GCP Vertex AI
Seefs 5 месяцев назад
Родитель
Сommit
aab82f22fa
1 измененных файлов с 48 добавлено и 21 удалено
  1. 48 21
      relay/channel/vertex/adaptor.go

+ 48 - 21
relay/channel/vertex/adaptor.go

@@ -91,7 +91,43 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
 		}
 		a.AccountCredentials = *adc
 
-		if a.RequestMode == RequestModeLlama {
+		if a.RequestMode == RequestModeGemini {
+			if region == "global" {
+				return fmt.Sprintf(
+					"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+					adc.ProjectID,
+					modelName,
+					suffix,
+				), nil
+			} else {
+				return fmt.Sprintf(
+					"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+					region,
+					adc.ProjectID,
+					region,
+					modelName,
+					suffix,
+				), nil
+			}
+		} else if a.RequestMode == RequestModeClaude {
+			if region == "global" {
+				return fmt.Sprintf(
+					"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+					adc.ProjectID,
+					modelName,
+					suffix,
+				), nil
+			} else {
+				return fmt.Sprintf(
+					"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+					region,
+					adc.ProjectID,
+					region,
+					modelName,
+					suffix,
+				), nil
+			}
+		} else if a.RequestMode == RequestModeLlama {
 			return fmt.Sprintf(
 				"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
 				region,
@@ -99,42 +135,33 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
 				region,
 			), nil
 		}
-
-		if region == "global" {
-			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
-				adc.ProjectID,
-				modelName,
-				suffix,
-			), nil
+	} else {
+		var keyPrefix string
+		if strings.HasSuffix(suffix, "?alt=sse") {
+			keyPrefix = "&"
 		} else {
-			return fmt.Sprintf(
-				"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
-				region,
-				adc.ProjectID,
-				region,
-				modelName,
-				suffix,
-			), nil
+			keyPrefix = "?"
 		}
-	} else {
 		if region == "global" {
 			return fmt.Sprintf(
-				"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
+				"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
 				modelName,
 				suffix,
+				keyPrefix,
 				info.ApiKey,
 			), nil
 		} else {
 			return fmt.Sprintf(
-				"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
+				"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
 				region,
 				modelName,
 				suffix,
+				keyPrefix,
 				info.ApiKey,
 			), nil
 		}
 	}
+	return "", errors.New("unsupported request mode")
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -188,7 +215,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		}
 		req.Set("Authorization", "Bearer "+accessToken)
 	}
-  if a.AccountCredentials.ProjectID != "" {
+	if a.AccountCredentials.ProjectID != "" {
 		req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
 	}
 	return nil