adaptor.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package aws
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "strings"
  7. "github.com/QuantumNous/new-api/dto"
  8. "github.com/QuantumNous/new-api/relay/channel"
  9. "github.com/QuantumNous/new-api/relay/channel/claude"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/service"
  12. "github.com/QuantumNous/new-api/types"
  13. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  14. "github.com/pkg/errors"
  15. "github.com/gin-gonic/gin"
  16. )
  // ClientMode identifies which transport the AWS adaptor uses for a request.
  type ClientMode int

  // Supported client modes. Numbering starts at 1, so the zero value of
  // ClientMode matches neither mode (presumably "not yet chosen" until
  // GetRequestURL assigns one — confirm).
  const (
  	// ClientModeApiKey sends plain HTTP requests authenticated with a bearer API key.
  	ClientModeApiKey ClientMode = 1
  	// ClientModeAKSK sends requests through doAwsClientRequest using access-key/secret-key credentials.
  	ClientModeAKSK ClientMode = 2
  )
  22. type Adaptor struct {
  23. ClientMode ClientMode
  24. AwsClient *bedrockruntime.Client
  25. AwsModelId string
  26. AwsReq any
  27. IsNova bool
  28. }
  // ConvertGeminiRequest is unsupported by the AWS adaptor; it always fails
  // with a "not implemented" error.
  func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
  	// TODO: Gemini-format input is not translated for Bedrock yet.
  	return nil, errors.New("not implemented")
  }
  33. func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
  34. for i, message := range request.Messages {
  35. updated := false
  36. if !message.IsStringContent() {
  37. content, err := message.ParseContent()
  38. if err != nil {
  39. return nil, errors.Wrap(err, "failed to parse message content")
  40. }
  41. for i2, mediaMessage := range content {
  42. if mediaMessage.Source != nil {
  43. if mediaMessage.Source.Type == "url" {
  44. fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
  45. if err != nil {
  46. return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
  47. }
  48. mediaMessage.Source.MediaType = fileData.MimeType
  49. mediaMessage.Source.Data = fileData.Base64Data
  50. mediaMessage.Source.Url = ""
  51. mediaMessage.Source.Type = "base64"
  52. content[i2] = mediaMessage
  53. updated = true
  54. }
  55. }
  56. }
  57. if updated {
  58. message.SetContent(content)
  59. }
  60. }
  61. if updated {
  62. request.Messages[i] = message
  63. }
  64. }
  65. return request, nil
  66. }
  // ConvertAudioRequest is unsupported by the AWS adaptor; it always fails
  // with a "not implemented" error.
  func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
  	// TODO: audio requests are not translated for Bedrock yet.
  	return nil, errors.New("not implemented")
  }
  // ConvertImageRequest is unsupported by the AWS adaptor; it always fails
  // with a "not implemented" error.
  func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  	// TODO: image-generation requests are not translated for Bedrock yet.
  	return nil, errors.New("not implemented")
  }
  75. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  76. }
  // GetRequestURL picks the client mode from the channel's AWS key type and,
  // for API-key channels, builds the Bedrock runtime endpoint URL.
  //
  // API-key channels store info.ApiKey as "<api-key>|<region>". The region
  // becomes the host label and the mapped model ID becomes the path segment.
  // Any other key type switches to ClientModeAKSK and returns an empty URL,
  // because DoRequest then sends the request through doAwsClientRequest
  // rather than a plain HTTP call.
  //
  // Returns an error when an API key is not exactly two non-empty parts
  // separated by "|".
  func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  	if info.ChannelOtherSettings.AwsKeyType != dto.AwsKeyTypeApiKey {
  		a.ClientMode = ClientModeAKSK
  		return "", nil
  	}
  	a.ClientMode = ClientModeApiKey
  	awsSecret := strings.Split(info.ApiKey, "|")
  	// An empty key or region would produce a bogus header or host, so
  	// reject it here with the same format error.
  	if len(awsSecret) != 2 || awsSecret[0] == "" || awsSecret[1] == "" {
  		return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
  	}
  	region := awsSecret[1]
  	awsModelId := getAwsModelID(info.UpstreamModelName)
  	// The region goes in the host and the model ID in the path. These were
  	// previously passed in reverse order, which put the model ID into the
  	// hostname.
  	// NOTE(review): API-key responses are parsed by claude.Adaptor.DoResponse
  	// (see DoResponse). Confirm that the /converse response schema is what
  	// that handler expects, or whether /invoke was intended.
  	return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", region, awsModelId), nil
  }
  // SetupRequestHeader applies the common Claude headers and, for API-key
  // channels, the Bedrock bearer token.
  //
  // For API-key channels, info.ApiKey has the form "<api-key>|<region>" (the
  // format GetRequestURL validates). Only the part before the first "|" is
  // the credential, so the region suffix is stripped before building the
  // Authorization header. Always returns nil.
  func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  	claude.CommonClaudeHeadersOperation(c, req, info)
  	if a.ClientMode == ClientModeApiKey {
  		// Previously the whole "<api-key>|<region>" string was sent as the token.
  		apiKey, _, _ := strings.Cut(info.ApiKey, "|")
  		req.Set("Authorization", "Bearer "+apiKey)
  	}
  	return nil
  }
  98. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  99. if request == nil {
  100. return nil, errors.New("request is nil")
  101. }
  102. // 检查是否为Nova模型
  103. if isNovaModel(request.Model) {
  104. novaReq := convertToNovaRequest(request)
  105. a.IsNova = true
  106. return novaReq, nil
  107. }
  108. // 原有的Claude模型处理逻辑
  109. claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
  110. if err != nil {
  111. return nil, errors.Wrap(err, "failed to convert openai request to claude request")
  112. }
  113. info.UpstreamModelName = claudeReq.Model
  114. return claudeReq, err
  115. }
  // ConvertRerankRequest is unsupported by the AWS adaptor.
  //
  // It previously returned (nil, nil), which let callers continue with a nil
  // request body. It now fails explicitly, like the adaptor's other
  // unsupported conversions.
  func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  	return nil, errors.New("not implemented")
  }
  // ConvertEmbeddingRequest is unsupported by the AWS adaptor; it always
  // fails with a "not implemented" error.
  func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
  	// TODO: embedding requests are not translated for Bedrock yet.
  	return nil, errors.New("not implemented")
  }
  // ConvertOpenAIResponsesRequest is unsupported by the AWS adaptor; it
  // always fails with a "not implemented" error.
  func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
  	// TODO: OpenAI Responses-API input is not translated for Bedrock yet.
  	return nil, errors.New("not implemented")
  }
  // DoRequest sends the converted request upstream.
  //
  // API-key channels use the generic HTTP path (channel.DoApiRequest), which
  // relies on GetRequestURL and SetupRequestHeader. Every other client mode
  // is handled by doAwsClientRequest.
  func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  	if a.ClientMode != ClientModeApiKey {
  		return doAwsClientRequest(c, info, a, requestBody)
  	}
  	return channel.DoApiRequest(a, c, info, requestBody)
  }
  // DoResponse dispatches the upstream result to the matching handler.
  //
  // In API-key mode, resp is parsed by a fresh claude.Adaptor. In any other
  // mode, resp is not used. Instead the handler receives the adaptor itself:
  // handleNovaRequest for Nova requests, otherwise awsStreamHandler or
  // awsHandler depending on info.IsStream.
  func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  	switch {
  	case a.ClientMode == ClientModeApiKey:
  		claudeAdaptor := claude.Adaptor{}
  		usage, err = claudeAdaptor.DoResponse(c, resp, info)
  	case a.IsNova:
  		err, usage = handleNovaRequest(c, info, a)
  	case info.IsStream:
  		err, usage = awsStreamHandler(c, info, a)
  	default:
  		err, usage = awsHandler(c, info, a)
  	}
  	return usage, err
  }
  151. func (a *Adaptor) GetModelList() (models []string) {
  152. for n := range awsModelIDMap {
  153. models = append(models, n)
  154. }
  155. return
  156. }
  157. func (a *Adaptor) GetChannelName() string {
  158. return ChannelName
  159. }