relay-aws.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. package aws
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/dto"
  12. "github.com/QuantumNous/new-api/relay/channel/claude"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. "github.com/QuantumNous/new-api/relay/helper"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/gin-gonic/gin"
  18. "github.com/pkg/errors"
  19. "github.com/QuantumNous/new-api/setting/model_setting"
  20. "github.com/aws/aws-sdk-go-v2/aws"
  21. "github.com/aws/aws-sdk-go-v2/credentials"
  22. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  23. bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  24. "github.com/aws/smithy-go/auth/bearer"
  25. )
  26. // getAwsErrorStatusCode extracts HTTP status code from AWS SDK error
  27. func getAwsErrorStatusCode(err error) int {
  28. // Check for HTTP response error which contains status code
  29. var httpErr interface{ HTTPStatusCode() int }
  30. if errors.As(err, &httpErr) {
  31. return httpErr.HTTPStatusCode()
  32. }
  33. // Default to 500 if we can't determine the status code
  34. return http.StatusInternalServerError
  35. }
  36. func newAwsInvokeContext() (context.Context, context.CancelFunc) {
  37. if common.RelayTimeout <= 0 {
  38. return context.Background(), func() {}
  39. }
  40. return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
  41. }
  42. func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
  43. var (
  44. httpClient *http.Client
  45. err error
  46. )
  47. if info.ChannelSetting.Proxy != "" {
  48. httpClient, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  49. if err != nil {
  50. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  51. }
  52. } else {
  53. httpClient = service.GetHttpClient()
  54. }
  55. awsSecret := strings.Split(info.ApiKey, "|")
  56. var client *bedrockruntime.Client
  57. switch len(awsSecret) {
  58. case 2:
  59. apiKey := awsSecret[0]
  60. region := awsSecret[1]
  61. client = bedrockruntime.New(bedrockruntime.Options{
  62. Region: region,
  63. BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
  64. HTTPClient: httpClient,
  65. })
  66. case 3:
  67. ak := awsSecret[0]
  68. sk := awsSecret[1]
  69. region := awsSecret[2]
  70. client = bedrockruntime.New(bedrockruntime.Options{
  71. Region: region,
  72. Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
  73. HTTPClient: httpClient,
  74. })
  75. default:
  76. return nil, errors.New("invalid aws secret key")
  77. }
  78. return client, nil
  79. }
  80. func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
  81. awsCli, err := newAwsClient(c, info)
  82. if err != nil {
  83. return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
  84. }
  85. a.AwsClient = awsCli
  86. // 获取对应的AWS模型ID
  87. awsModelId := getAwsModelID(info.UpstreamModelName)
  88. awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
  89. canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
  90. if canCrossRegion {
  91. awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
  92. }
  93. // init empty request.header
  94. requestHeader := http.Header{}
  95. a.SetupRequestHeader(c, &requestHeader, info)
  96. if isNovaModel(awsModelId) {
  97. var novaReq *NovaRequest
  98. err = common.DecodeJson(requestBody, &novaReq)
  99. if err != nil {
  100. return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
  101. }
  102. // 使用InvokeModel API,但使用Nova格式的请求体
  103. awsReq := &bedrockruntime.InvokeModelInput{
  104. ModelId: aws.String(awsModelId),
  105. Accept: aws.String("application/json"),
  106. ContentType: aws.String("application/json"),
  107. }
  108. reqBody, err := common.Marshal(novaReq)
  109. if err != nil {
  110. return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
  111. }
  112. awsReq.Body = reqBody
  113. a.AwsReq = awsReq
  114. return nil, nil
  115. } else {
  116. awsClaudeReq, err := formatRequest(requestBody, requestHeader)
  117. if err != nil {
  118. return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
  119. }
  120. if info.IsStream {
  121. awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
  122. ModelId: aws.String(awsModelId),
  123. Accept: aws.String("application/json"),
  124. ContentType: aws.String("application/json"),
  125. }
  126. awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
  127. if err != nil {
  128. return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
  129. }
  130. a.AwsReq = awsReq
  131. return nil, nil
  132. } else {
  133. awsReq := &bedrockruntime.InvokeModelInput{
  134. ModelId: aws.String(awsModelId),
  135. Accept: aws.String("application/json"),
  136. ContentType: aws.String("application/json"),
  137. }
  138. awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
  139. if err != nil {
  140. return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
  141. }
  142. a.AwsReq = awsReq
  143. return nil, nil
  144. }
  145. }
  146. }
  147. // buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
  148. func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
  149. if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
  150. body, err := common.GetRequestBody(c)
  151. if err != nil {
  152. return nil, errors.Wrap(err, "get request body for pass-through fail")
  153. }
  154. var data map[string]interface{}
  155. if err := common.Unmarshal(body, &data); err != nil {
  156. return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
  157. }
  158. delete(data, "model")
  159. delete(data, "stream")
  160. return common.Marshal(data)
  161. }
  162. return common.Marshal(awsClaudeReq)
  163. }
  164. func getAwsRegionPrefix(awsRegionId string) string {
  165. parts := strings.Split(awsRegionId, "-")
  166. regionPrefix := ""
  167. if len(parts) > 0 {
  168. regionPrefix = parts[0]
  169. }
  170. return regionPrefix
  171. }
  172. func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
  173. regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
  174. return exists && regionSet[awsRegionPrefix]
  175. }
  176. func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
  177. modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
  178. if !find {
  179. return awsModelId
  180. }
  181. return modelPrefix + "." + awsModelId
  182. }
  183. func getAwsModelID(requestModel string) string {
  184. if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
  185. return awsModelIDName
  186. }
  187. return requestModel
  188. }
  189. func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
  190. ctx, cancel := newAwsInvokeContext()
  191. defer cancel()
  192. awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
  193. if err != nil {
  194. statusCode := getAwsErrorStatusCode(err)
  195. return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
  196. }
  197. claudeInfo := &claude.ClaudeResponseInfo{
  198. ResponseId: helper.GetResponseID(c),
  199. Created: common.GetTimestamp(),
  200. Model: info.UpstreamModelName,
  201. ResponseText: strings.Builder{},
  202. Usage: &dto.Usage{},
  203. }
  204. // 复制上游 Content-Type 到客户端响应头
  205. if awsResp.ContentType != nil && *awsResp.ContentType != "" {
  206. c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
  207. }
  208. handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body)
  209. if handlerErr != nil {
  210. return handlerErr, nil
  211. }
  212. return nil, claudeInfo.Usage
  213. }
  214. func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
  215. ctx, cancel := newAwsInvokeContext()
  216. defer cancel()
  217. awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
  218. if err != nil {
  219. statusCode := getAwsErrorStatusCode(err)
  220. return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
  221. }
  222. stream := awsResp.GetStream()
  223. defer stream.Close()
  224. claudeInfo := &claude.ClaudeResponseInfo{
  225. ResponseId: helper.GetResponseID(c),
  226. Created: common.GetTimestamp(),
  227. Model: info.UpstreamModelName,
  228. ResponseText: strings.Builder{},
  229. Usage: &dto.Usage{},
  230. }
  231. for event := range stream.Events() {
  232. switch v := event.(type) {
  233. case *bedrockruntimeTypes.ResponseStreamMemberChunk:
  234. info.SetFirstResponseTime()
  235. respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes))
  236. if respErr != nil {
  237. return respErr, nil
  238. }
  239. case *bedrockruntimeTypes.UnknownUnionMember:
  240. fmt.Println("unknown tag:", v.Tag)
  241. return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
  242. default:
  243. fmt.Println("union is nil or unknown type")
  244. return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
  245. }
  246. }
  247. claude.HandleStreamFinalResponse(c, info, claudeInfo)
  248. return nil, claudeInfo.Usage
  249. }
  250. // Nova模型处理函数
  251. func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
  252. ctx, cancel := newAwsInvokeContext()
  253. defer cancel()
  254. awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
  255. if err != nil {
  256. statusCode := getAwsErrorStatusCode(err)
  257. return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
  258. }
  259. // 解析Nova响应
  260. var novaResp struct {
  261. Output struct {
  262. Message struct {
  263. Content []struct {
  264. Text string `json:"text"`
  265. } `json:"content"`
  266. } `json:"message"`
  267. } `json:"output"`
  268. Usage struct {
  269. InputTokens int `json:"inputTokens"`
  270. OutputTokens int `json:"outputTokens"`
  271. TotalTokens int `json:"totalTokens"`
  272. } `json:"usage"`
  273. }
  274. if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
  275. return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
  276. }
  277. // 构造OpenAI格式响应
  278. response := dto.OpenAITextResponse{
  279. Id: helper.GetResponseID(c),
  280. Object: "chat.completion",
  281. Created: common.GetTimestamp(),
  282. Model: info.UpstreamModelName,
  283. Choices: []dto.OpenAITextResponseChoice{{
  284. Index: 0,
  285. Message: dto.Message{
  286. Role: "assistant",
  287. Content: novaResp.Output.Message.Content[0].Text,
  288. },
  289. FinishReason: "stop",
  290. }},
  291. Usage: dto.Usage{
  292. PromptTokens: novaResp.Usage.InputTokens,
  293. CompletionTokens: novaResp.Usage.OutputTokens,
  294. TotalTokens: novaResp.Usage.TotalTokens,
  295. },
  296. }
  297. c.JSON(http.StatusOK, response)
  298. return nil, &response.Usage
  299. }