adaptor.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package vertex
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/relay/channel"
  11. "one-api/relay/channel/claude"
  12. "one-api/relay/channel/gemini"
  13. "one-api/relay/channel/openai"
  14. relaycommon "one-api/relay/common"
  15. "one-api/relay/constant"
  16. "one-api/setting/model_setting"
  17. "one-api/types"
  18. "strings"
  19. "github.com/gin-gonic/gin"
  20. )
  // Request modes select the upstream protocol the Vertex adaptor speaks.
// Init picks one of them from the upstream model name.
const (
	RequestModeClaude = iota + 1 // upstream model name starts with "claude"
	RequestModeGemini            // default when no other mode matches
	RequestModeLlama             // upstream model name contains "llama"
)
  // claudeModelMap translates Anthropic's public model IDs (date joined with a
// dash) into the Vertex AI model IDs (date joined with "@"). Upstream model
// names that are not listed here are sent to Vertex unchanged.
var claudeModelMap = map[string]string{
	// Claude 3 family
	"claude-3-haiku-20240307":  "claude-3-haiku@20240307",
	"claude-3-opus-20240229":   "claude-3-opus@20240229",
	"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",

	// Claude 3.5 / 3.7 family (the October 2024 Sonnet carries a "-v2" tag on Vertex)
	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
	"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",

	// Claude 4 family
	"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
	"claude-opus-4-20250514":   "claude-opus-4@20250514",
	"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
}
  // anthropicVersion is handed to copyRequest whenever a Claude request body is
// built for Vertex (presumably it fills the body's anthropic_version field —
// confirm against copyRequest).
const (
	anthropicVersion = "vertex-2023-10-16"
)
  38. type Adaptor struct {
  39. RequestMode int
  40. AccountCredentials Credentials
  41. }
  42. func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
  43. geminiAdaptor := gemini.Adaptor{}
  44. return geminiAdaptor.ConvertGeminiRequest(c, info, request)
  45. }
  46. func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
  47. if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
  48. c.Set("request_model", v)
  49. } else {
  50. c.Set("request_model", request.Model)
  51. }
  52. vertexClaudeReq := copyRequest(request, anthropicVersion)
  53. return vertexClaudeReq, nil
  54. }
  55. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
  56. //TODO implement me
  57. return nil, errors.New("not implemented")
  58. }
  59. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  60. geminiAdaptor := gemini.Adaptor{}
  61. return geminiAdaptor.ConvertImageRequest(c, info, request)
  62. }
  63. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  64. if strings.HasPrefix(info.UpstreamModelName, "claude") {
  65. a.RequestMode = RequestModeClaude
  66. } else if strings.Contains(info.UpstreamModelName, "llama") {
  67. a.RequestMode = RequestModeLlama
  68. } else {
  69. a.RequestMode = RequestModeGemini
  70. }
  71. }
  72. func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
  73. region := GetModelRegion(info.ApiVersion, info.OriginModelName)
  74. if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
  75. adc := &Credentials{}
  76. if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
  77. return "", fmt.Errorf("failed to decode credentials file: %w", err)
  78. }
  79. a.AccountCredentials = *adc
  80. if a.RequestMode == RequestModeLlama {
  81. return fmt.Sprintf(
  82. "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
  83. region,
  84. adc.ProjectID,
  85. region,
  86. ), nil
  87. }
  88. if region == "global" {
  89. return fmt.Sprintf(
  90. "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
  91. adc.ProjectID,
  92. modelName,
  93. suffix,
  94. ), nil
  95. } else {
  96. return fmt.Sprintf(
  97. "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
  98. region,
  99. adc.ProjectID,
  100. region,
  101. modelName,
  102. suffix,
  103. ), nil
  104. }
  105. } else {
  106. if region == "global" {
  107. return fmt.Sprintf(
  108. "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
  109. modelName,
  110. suffix,
  111. info.ApiKey,
  112. ), nil
  113. } else {
  114. return fmt.Sprintf(
  115. "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
  116. region,
  117. modelName,
  118. suffix,
  119. info.ApiKey,
  120. ), nil
  121. }
  122. }
  123. }
  124. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  125. suffix := ""
  126. if a.RequestMode == RequestModeGemini {
  127. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  128. // 新增逻辑:处理 -thinking-<budget> 格式
  129. if strings.Contains(info.UpstreamModelName, "-thinking-") {
  130. parts := strings.Split(info.UpstreamModelName, "-thinking-")
  131. info.UpstreamModelName = parts[0]
  132. } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
  133. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
  134. } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  135. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
  136. }
  137. }
  138. if info.IsStream {
  139. suffix = "streamGenerateContent?alt=sse"
  140. } else {
  141. suffix = "generateContent"
  142. }
  143. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  144. suffix = "predict"
  145. }
  146. return a.getRequestUrl(info, info.UpstreamModelName, suffix)
  147. } else if a.RequestMode == RequestModeClaude {
  148. if info.IsStream {
  149. suffix = "streamRawPredict?alt=sse"
  150. } else {
  151. suffix = "rawPredict"
  152. }
  153. model := info.UpstreamModelName
  154. if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
  155. model = v
  156. }
  157. return a.getRequestUrl(info, model, suffix)
  158. } else if a.RequestMode == RequestModeLlama {
  159. return a.getRequestUrl(info, "", "")
  160. }
  161. return "", errors.New("unsupported request mode")
  162. }
  163. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  164. channel.SetupApiRequestHeader(info, c, req)
  165. if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
  166. accessToken, err := getAccessToken(a, info)
  167. if err != nil {
  168. return err
  169. }
  170. req.Set("Authorization", "Bearer "+accessToken)
  171. }
  172. if a.AccountCredentials.ProjectID != "" {
  173. req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
  174. }
  175. return nil
  176. }
  177. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  178. if request == nil {
  179. return nil, errors.New("request is nil")
  180. }
  181. if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
  182. prompt := ""
  183. for _, m := range request.Messages {
  184. if m.Role == "user" {
  185. prompt = m.StringContent()
  186. if prompt != "" {
  187. break
  188. }
  189. }
  190. }
  191. if prompt == "" {
  192. if p, ok := request.Prompt.(string); ok {
  193. prompt = p
  194. }
  195. }
  196. if prompt == "" {
  197. return nil, errors.New("prompt is required for image generation")
  198. }
  199. imgReq := dto.ImageRequest{
  200. Model: request.Model,
  201. Prompt: prompt,
  202. N: 1,
  203. Size: "1024x1024",
  204. }
  205. if request.N > 0 {
  206. imgReq.N = uint(request.N)
  207. }
  208. if request.Size != "" {
  209. imgReq.Size = request.Size
  210. }
  211. if len(request.ExtraBody) > 0 {
  212. var extra map[string]any
  213. if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
  214. if n, ok := extra["n"].(float64); ok && n > 0 {
  215. imgReq.N = uint(n)
  216. }
  217. if size, ok := extra["size"].(string); ok {
  218. imgReq.Size = size
  219. }
  220. // accept aspectRatio in extra body (top-level or under parameters)
  221. if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
  222. imgReq.Size = ar
  223. }
  224. if params, ok := extra["parameters"].(map[string]any); ok {
  225. if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
  226. imgReq.Size = ar
  227. }
  228. }
  229. }
  230. }
  231. c.Set("request_model", request.Model)
  232. return a.ConvertImageRequest(c, info, imgReq)
  233. }
  234. if a.RequestMode == RequestModeClaude {
  235. claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
  236. if err != nil {
  237. return nil, err
  238. }
  239. vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
  240. c.Set("request_model", claudeReq.Model)
  241. info.UpstreamModelName = claudeReq.Model
  242. return vertexClaudeReq, nil
  243. } else if a.RequestMode == RequestModeGemini {
  244. geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
  245. if err != nil {
  246. return nil, err
  247. }
  248. c.Set("request_model", request.Model)
  249. return geminiRequest, nil
  250. } else if a.RequestMode == RequestModeLlama {
  251. return request, nil
  252. }
  253. return nil, errors.New("unsupported request mode")
  254. }
  255. func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  256. return nil, nil
  257. }
  258. func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
  259. //TODO implement me
  260. return nil, errors.New("not implemented")
  261. }
  262. func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
  263. // TODO implement me
  264. return nil, errors.New("not implemented")
  265. }
  266. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  267. return channel.DoApiRequest(a, c, info, requestBody)
  268. }
  269. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  270. if info.IsStream {
  271. switch a.RequestMode {
  272. case RequestModeClaude:
  273. return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
  274. case RequestModeGemini:
  275. if info.RelayMode == constant.RelayModeGemini {
  276. return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
  277. } else {
  278. return gemini.GeminiChatStreamHandler(c, info, resp)
  279. }
  280. case RequestModeLlama:
  281. return openai.OaiStreamHandler(c, info, resp)
  282. }
  283. } else {
  284. switch a.RequestMode {
  285. case RequestModeClaude:
  286. return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
  287. case RequestModeGemini:
  288. if info.RelayMode == constant.RelayModeGemini {
  289. return gemini.GeminiTextGenerationHandler(c, info, resp)
  290. } else {
  291. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  292. return gemini.GeminiImageHandler(c, info, resp)
  293. }
  294. return gemini.GeminiChatHandler(c, info, resp)
  295. }
  296. case RequestModeLlama:
  297. return openai.OpenaiHandler(c, info, resp)
  298. }
  299. }
  300. return
  301. }
  302. func (a *Adaptor) GetModelList() []string {
  303. var modelList []string
  304. for i, s := range ModelList {
  305. modelList = append(modelList, s)
  306. ModelList[i] = s
  307. }
  308. for i, s := range claude.ModelList {
  309. modelList = append(modelList, s)
  310. claude.ModelList[i] = s
  311. }
  312. for i, s := range gemini.ModelList {
  313. modelList = append(modelList, s)
  314. gemini.ModelList[i] = s
  315. }
  316. return modelList
  317. }
  318. func (a *Adaptor) GetChannelName() string {
  319. return ChannelName
  320. }