adaptor.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. package vertex
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "github.com/QuantumNous/new-api/common"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/relay/channel"
  12. "github.com/QuantumNous/new-api/relay/channel/claude"
  13. "github.com/QuantumNous/new-api/relay/channel/gemini"
  14. "github.com/QuantumNous/new-api/relay/channel/openai"
  15. relaycommon "github.com/QuantumNous/new-api/relay/common"
  16. "github.com/QuantumNous/new-api/relay/constant"
  17. "github.com/QuantumNous/new-api/setting/model_setting"
  18. "github.com/QuantumNous/new-api/setting/reasoning"
  19. "github.com/QuantumNous/new-api/types"
  20. "github.com/gin-gonic/gin"
  21. "github.com/samber/lo"
  22. )
  23. const (
  24. RequestModeClaude = 1
  25. RequestModeGemini = 2
  26. RequestModeOpenSource = 3
  27. )
// claudeModelMap translates a Claude model name (looked up by
// info.UpstreamModelName) into the model identifier placed in the Vertex
// request. Most dated names swap the last "-" for "@"; note that
// claude-3-5-sonnet-20241022 maps to the "-v2" identifier, and the undated
// names map to themselves. Names missing from this map are used unchanged by
// the callers in this file.
var claudeModelMap = map[string]string{
	"claude-3-sonnet-20240229":   "claude-3-sonnet@20240229",
	"claude-3-opus-20240229":     "claude-3-opus@20240229",
	"claude-3-haiku-20240307":    "claude-3-haiku@20240307",
	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
	"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
	"claude-sonnet-4-20250514":   "claude-sonnet-4@20250514",
	"claude-opus-4-20250514":     "claude-opus-4@20250514",
	"claude-opus-4-1-20250805":   "claude-opus-4-1@20250805",
	"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
	"claude-haiku-4-5-20251001":  "claude-haiku-4-5@20251001",
	"claude-opus-4-5-20251101":   "claude-opus-4-5@20251101",
	"claude-opus-4-6":            "claude-opus-4-6",
	"claude-opus-4-7":            "claude-opus-4-7",
}
// anthropicVersion is the version string handed to copyRequest whenever a
// Claude request body is built for Vertex.
const anthropicVersion = "vertex-2023-10-16"
// Adaptor relays requests to Vertex AI for Claude, Gemini and open-source
// (OpenAI-compatible) models.
type Adaptor struct {
	// RequestMode is one of the RequestMode* constants; it is set by Init.
	RequestMode int
	// AccountCredentials is filled in by getRequestUrl when the channel key is
	// not an API key, and is read by SetupRequestHeader to set the
	// x-goog-user-project header.
	AccountCredentials Credentials
}
  49. func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
  50. // Vertex AI does not support functionResponse.id; keep it stripped here for consistency.
  51. if model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
  52. removeFunctionResponseID(request)
  53. }
  54. geminiAdaptor := gemini.Adaptor{}
  55. return geminiAdaptor.ConvertGeminiRequest(c, info, request)
  56. }
  57. func removeFunctionResponseID(request *dto.GeminiChatRequest) {
  58. if request == nil {
  59. return
  60. }
  61. if len(request.Contents) > 0 {
  62. for i := range request.Contents {
  63. if len(request.Contents[i].Parts) == 0 {
  64. continue
  65. }
  66. for j := range request.Contents[i].Parts {
  67. part := &request.Contents[i].Parts[j]
  68. if part.FunctionResponse == nil {
  69. continue
  70. }
  71. if len(part.FunctionResponse.ID) > 0 {
  72. part.FunctionResponse.ID = nil
  73. }
  74. }
  75. }
  76. }
  77. if len(request.Requests) > 0 {
  78. for i := range request.Requests {
  79. removeFunctionResponseID(&request.Requests[i])
  80. }
  81. }
  82. }
  83. func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
  84. if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
  85. c.Set("request_model", v)
  86. } else {
  87. c.Set("request_model", request.Model)
  88. }
  89. vertexClaudeReq := copyRequest(request, anthropicVersion)
  90. return vertexClaudeReq, nil
  91. }
// ConvertAudioRequest is not supported by this adaptor yet; it always returns
// a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	// TODO: implement audio request conversion.
	return nil, errors.New("not implemented")
}
  96. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  97. geminiAdaptor := gemini.Adaptor{}
  98. return geminiAdaptor.ConvertImageRequest(c, info, request)
  99. }
  100. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  101. if strings.HasPrefix(info.UpstreamModelName, "claude") {
  102. a.RequestMode = RequestModeClaude
  103. } else if strings.Contains(info.UpstreamModelName, "llama") ||
  104. // open source models
  105. strings.Contains(info.UpstreamModelName, "-maas") {
  106. a.RequestMode = RequestModeOpenSource
  107. } else {
  108. a.RequestMode = RequestModeGemini
  109. }
  110. }
  111. func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
  112. region := GetModelRegion(info.ApiVersion, info.OriginModelName)
  113. if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
  114. adc := &Credentials{}
  115. if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
  116. return "", fmt.Errorf("failed to decode credentials file: %w", err)
  117. }
  118. a.AccountCredentials = *adc
  119. if a.RequestMode == RequestModeGemini {
  120. return BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
  121. } else if a.RequestMode == RequestModeClaude {
  122. return BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
  123. } else if a.RequestMode == RequestModeOpenSource {
  124. return BuildOpenSourceChatCompletionsURL(info.ChannelBaseUrl, adc.ProjectID, region), nil
  125. }
  126. } else {
  127. var keyPrefix string
  128. if strings.HasSuffix(suffix, "?alt=sse") {
  129. keyPrefix = "&"
  130. } else {
  131. keyPrefix = "?"
  132. }
  133. if a.RequestMode == RequestModeGemini {
  134. return fmt.Sprintf(
  135. "%s%skey=%s",
  136. BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
  137. keyPrefix,
  138. info.ApiKey,
  139. ), nil
  140. } else if a.RequestMode == RequestModeClaude {
  141. return fmt.Sprintf(
  142. "%s%skey=%s",
  143. BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
  144. keyPrefix,
  145. info.ApiKey,
  146. ), nil
  147. }
  148. }
  149. return "", errors.New("unsupported request mode")
  150. }
  151. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  152. suffix := ""
  153. if a.RequestMode == RequestModeGemini {
  154. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled &&
  155. !model_setting.ShouldPreserveThinkingSuffix(info.OriginModelName) {
  156. // 新增逻辑:处理 -thinking-<budget> 格式
  157. if strings.Contains(info.UpstreamModelName, "-thinking-") {
  158. parts := strings.Split(info.UpstreamModelName, "-thinking-")
  159. info.UpstreamModelName = parts[0]
  160. } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
  161. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
  162. } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  163. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
  164. } else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
  165. info.UpstreamModelName = baseModel
  166. }
  167. }
  168. if info.IsStream {
  169. suffix = "streamGenerateContent?alt=sse"
  170. } else {
  171. suffix = "generateContent"
  172. }
  173. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  174. suffix = "predict"
  175. }
  176. return a.getRequestUrl(info, info.UpstreamModelName, suffix)
  177. } else if a.RequestMode == RequestModeClaude {
  178. if info.IsStream {
  179. suffix = "streamRawPredict?alt=sse"
  180. } else {
  181. suffix = "rawPredict"
  182. }
  183. model := info.UpstreamModelName
  184. if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
  185. model = v
  186. }
  187. return a.getRequestUrl(info, model, suffix)
  188. } else if a.RequestMode == RequestModeOpenSource {
  189. return a.getRequestUrl(info, "", "")
  190. }
  191. return "", errors.New("unsupported request mode")
  192. }
  193. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  194. channel.SetupApiRequestHeader(info, c, req)
  195. if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
  196. accessToken, err := getAccessToken(a, info)
  197. if err != nil {
  198. return err
  199. }
  200. req.Set("Authorization", "Bearer "+accessToken)
  201. }
  202. if a.AccountCredentials.ProjectID != "" {
  203. req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
  204. }
  205. if strings.Contains(info.UpstreamModelName, "claude") {
  206. claude.CommonClaudeHeadersOperation(c, req, info)
  207. }
  208. return nil
  209. }
  210. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  211. if request == nil {
  212. return nil, errors.New("request is nil")
  213. }
  214. if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
  215. prompt := ""
  216. for _, m := range request.Messages {
  217. if m.Role == "user" {
  218. prompt = m.StringContent()
  219. if prompt != "" {
  220. break
  221. }
  222. }
  223. }
  224. if prompt == "" {
  225. if p, ok := request.Prompt.(string); ok {
  226. prompt = p
  227. }
  228. }
  229. if prompt == "" {
  230. return nil, errors.New("prompt is required for image generation")
  231. }
  232. imgReq := dto.ImageRequest{
  233. Model: request.Model,
  234. Prompt: prompt,
  235. N: lo.ToPtr(uint(1)),
  236. Size: "1024x1024",
  237. }
  238. if request.N != nil && *request.N > 0 {
  239. imgReq.N = lo.ToPtr(uint(*request.N))
  240. }
  241. if request.Size != "" {
  242. imgReq.Size = request.Size
  243. }
  244. if len(request.ExtraBody) > 0 {
  245. var extra map[string]any
  246. if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
  247. if n, ok := extra["n"].(float64); ok && n > 0 {
  248. imgReq.N = lo.ToPtr(uint(n))
  249. }
  250. if size, ok := extra["size"].(string); ok {
  251. imgReq.Size = size
  252. }
  253. // accept aspectRatio in extra body (top-level or under parameters)
  254. if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
  255. imgReq.Size = ar
  256. }
  257. if params, ok := extra["parameters"].(map[string]any); ok {
  258. if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
  259. imgReq.Size = ar
  260. }
  261. }
  262. }
  263. }
  264. c.Set("request_model", request.Model)
  265. return a.ConvertImageRequest(c, info, imgReq)
  266. }
  267. if a.RequestMode == RequestModeClaude {
  268. claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
  269. if err != nil {
  270. return nil, err
  271. }
  272. vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
  273. c.Set("request_model", claudeReq.Model)
  274. info.UpstreamModelName = claudeReq.Model
  275. return vertexClaudeReq, nil
  276. } else if a.RequestMode == RequestModeGemini {
  277. geminiRequest, err := gemini.CovertOpenAI2Gemini(c, *request, info)
  278. if err != nil {
  279. return nil, err
  280. }
  281. c.Set("request_model", request.Model)
  282. return geminiRequest, nil
  283. } else if a.RequestMode == RequestModeOpenSource {
  284. return request, nil
  285. }
  286. return nil, errors.New("unsupported request mode")
  287. }
  288. func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  289. return nil, nil
  290. }
// ConvertEmbeddingRequest is not supported by this adaptor yet; it always
// returns a "not implemented" error.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
	// TODO: implement embedding request conversion.
	return nil, errors.New("not implemented")
}
// ConvertOpenAIResponsesRequest is not supported by this adaptor yet; it
// always returns a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
	// TODO: implement Responses API request conversion.
	return nil, errors.New("not implemented")
}
// DoRequest sends requestBody upstream by delegating to channel.DoApiRequest
// with this adaptor.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	return channel.DoApiRequest(a, c, info, requestBody)
}
  302. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  303. claudeAdaptor := claude.Adaptor{}
  304. if info.IsStream {
  305. switch a.RequestMode {
  306. case RequestModeClaude:
  307. return claudeAdaptor.DoResponse(c, resp, info)
  308. case RequestModeGemini:
  309. if info.RelayMode == constant.RelayModeGemini {
  310. return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
  311. } else {
  312. return gemini.GeminiChatStreamHandler(c, info, resp)
  313. }
  314. case RequestModeOpenSource:
  315. return openai.OaiStreamHandler(c, info, resp)
  316. }
  317. } else {
  318. switch a.RequestMode {
  319. case RequestModeClaude:
  320. return claudeAdaptor.DoResponse(c, resp, info)
  321. case RequestModeGemini:
  322. if info.RelayMode == constant.RelayModeGemini {
  323. return gemini.GeminiTextGenerationHandler(c, info, resp)
  324. } else {
  325. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  326. return gemini.GeminiImageHandler(c, info, resp)
  327. }
  328. return gemini.GeminiChatHandler(c, info, resp)
  329. }
  330. case RequestModeOpenSource:
  331. return openai.OpenaiHandler(c, info, resp)
  332. }
  333. }
  334. return
  335. }
  336. func (a *Adaptor) GetModelList() []string {
  337. var modelList []string
  338. for i, s := range ModelList {
  339. modelList = append(modelList, s)
  340. ModelList[i] = s
  341. }
  342. for i, s := range claude.ModelList {
  343. modelList = append(modelList, s)
  344. claude.ModelList[i] = s
  345. }
  346. for i, s := range gemini.ModelList {
  347. modelList = append(modelList, s)
  348. gemini.ModelList[i] = s
  349. }
  350. return modelList
  351. }
// GetChannelName returns the package-level ChannelName.
func (a *Adaptor) GetChannelName() string {
	return ChannelName
}