adaptor.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package aws
import (
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/relay/channel"
	"github.com/QuantumNous/new-api/relay/channel/claude"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
)
  17. type ClientMode int
  18. const (
  19. ClientModeApiKey ClientMode = iota + 1
  20. ClientModeAKSK
  21. )
// Adaptor relays requests to AWS Bedrock. Its methods record state on it while
// a request is processed (ClientMode in GetRequestURL, IsNova in
// ConvertOpenAIRequest). NOTE(review): IsNova is only ever set to true in the
// visible code, so an instance is presumably not reused across requests —
// confirm.
type Adaptor struct {
	// ClientMode is set by GetRequestURL from the channel's AWS key type and is
	// read by SetupRequestHeader, DoRequest and DoResponse to choose between
	// the API-key HTTP path and the AK/SK path.
	ClientMode ClientMode

	// AwsClient is not assigned in the visible code; presumably the SDK client
	// used by doAwsClientRequest in AK/SK mode — TODO confirm.
	AwsClient *bedrockruntime.Client

	// AwsModelId is not assigned in the visible code; presumably the resolved
	// Bedrock model ID for AK/SK mode — TODO confirm where it is populated.
	AwsModelId string

	// AwsReq is not assigned in the visible code; presumably the SDK request
	// built for AK/SK mode — TODO confirm.
	AwsReq any

	// IsNova is set by ConvertOpenAIRequest for Nova models; in non-API-key
	// mode DoResponse then uses handleNovaRequest.
	IsNova bool
}
  29. func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
  30. //TODO implement me
  31. return nil, errors.New("not implemented")
  32. }
  33. func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
  34. for i, message := range request.Messages {
  35. updated := false
  36. if !message.IsStringContent() {
  37. content, err := message.ParseContent()
  38. if err != nil {
  39. return nil, errors.Wrap(err, "failed to parse message content")
  40. }
  41. for i2, mediaMessage := range content {
  42. if mediaMessage.Source != nil {
  43. if mediaMessage.Source.Type == "url" {
  44. // 使用统一的文件服务获取图片数据
  45. source := types.NewURLFileSource(mediaMessage.Source.Url)
  46. base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
  47. if err != nil {
  48. return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
  49. }
  50. mediaMessage.Source.MediaType = mimeType
  51. mediaMessage.Source.Data = base64Data
  52. mediaMessage.Source.Url = ""
  53. mediaMessage.Source.Type = "base64"
  54. content[i2] = mediaMessage
  55. updated = true
  56. }
  57. }
  58. }
  59. if updated {
  60. message.SetContent(content)
  61. }
  62. }
  63. if updated {
  64. request.Messages[i] = message
  65. }
  66. }
  67. return request, nil
  68. }
  69. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
  70. //TODO implement me
  71. return nil, errors.New("not implemented")
  72. }
  73. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  74. //TODO implement me
  75. return nil, errors.New("not implemented")
  76. }
  77. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  78. }
  79. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  80. if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
  81. awsModelId := getAwsModelID(info.UpstreamModelName)
  82. a.ClientMode = ClientModeApiKey
  83. awsSecret := strings.Split(info.ApiKey, "|")
  84. if len(awsSecret) != 2 {
  85. return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
  86. }
  87. return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
  88. } else {
  89. a.ClientMode = ClientModeAKSK
  90. return "", nil
  91. }
  92. }
  93. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  94. claude.CommonClaudeHeadersOperation(c, req, info)
  95. if a.ClientMode == ClientModeApiKey {
  96. req.Set("Authorization", "Bearer "+info.ApiKey)
  97. }
  98. return nil
  99. }
  100. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  101. if request == nil {
  102. return nil, errors.New("request is nil")
  103. }
  104. // 检查是否为Nova模型
  105. if isNovaModel(request.Model) {
  106. novaReq := convertToNovaRequest(request)
  107. a.IsNova = true
  108. return novaReq, nil
  109. }
  110. // 原有的Claude模型处理逻辑
  111. claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
  112. if err != nil {
  113. return nil, errors.Wrap(err, "failed to convert openai request to claude request")
  114. }
  115. info.UpstreamModelName = claudeReq.Model
  116. return claudeReq, err
  117. }
  118. func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  119. return nil, nil
  120. }
  121. func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
  122. //TODO implement me
  123. return nil, errors.New("not implemented")
  124. }
  125. func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
  126. // TODO implement me
  127. return nil, errors.New("not implemented")
  128. }
  129. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  130. if a.ClientMode == ClientModeApiKey {
  131. return channel.DoApiRequest(a, c, info, requestBody)
  132. } else {
  133. return doAwsClientRequest(c, info, a, requestBody)
  134. }
  135. }
  136. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  137. if a.ClientMode == ClientModeApiKey {
  138. claudeAdaptor := claude.Adaptor{}
  139. usage, err = claudeAdaptor.DoResponse(c, resp, info)
  140. } else {
  141. if a.IsNova {
  142. err, usage = handleNovaRequest(c, info, a)
  143. } else {
  144. if info.IsStream {
  145. err, usage = awsStreamHandler(c, info, a)
  146. } else {
  147. err, usage = awsHandler(c, info, a)
  148. }
  149. }
  150. }
  151. return
  152. }
  153. func (a *Adaptor) GetModelList() (models []string) {
  154. for n := range awsModelIDMap {
  155. models = append(models, n)
  156. }
  157. return
  158. }
  159. func (a *Adaptor) GetChannelName() string {
  160. return ChannelName
  161. }