relay-aws.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package aws
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/jinzhu/copier"
  8. "github.com/pkg/errors"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. relaymodel "one-api/dto"
  13. "one-api/relay/channel/claude"
  14. relaycommon "one-api/relay/common"
  15. "strings"
  16. "github.com/aws/aws-sdk-go-v2/aws"
  17. "github.com/aws/aws-sdk-go-v2/credentials"
  18. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  19. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  20. )
  21. func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
  22. awsSecret := strings.Split(info.ApiKey, "|")
  23. if len(awsSecret) != 3 {
  24. return nil, errors.New("invalid aws secret key")
  25. }
  26. ak := awsSecret[0]
  27. sk := awsSecret[1]
  28. region := awsSecret[2]
  29. client := bedrockruntime.New(bedrockruntime.Options{
  30. Region: region,
  31. Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
  32. })
  33. return client, nil
  34. }
  35. func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
  36. return &relaymodel.OpenAIErrorWithStatusCode{
  37. StatusCode: http.StatusInternalServerError,
  38. Error: relaymodel.OpenAIError{
  39. Message: fmt.Sprintf("%s", err.Error()),
  40. },
  41. }
  42. }
  43. func awsModelID(requestModel string) (string, error) {
  44. if awsModelID, ok := awsModelIDMap[requestModel]; ok {
  45. return awsModelID, nil
  46. }
  47. return "", errors.Errorf("model %s not found", requestModel)
  48. }
  49. func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  50. awsCli, err := newAwsClient(c, info)
  51. if err != nil {
  52. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  53. }
  54. awsModelId, err := awsModelID(c.GetString("request_model"))
  55. if err != nil {
  56. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  57. }
  58. awsReq := &bedrockruntime.InvokeModelInput{
  59. ModelId: aws.String(awsModelId),
  60. Accept: aws.String("application/json"),
  61. ContentType: aws.String("application/json"),
  62. }
  63. claudeReq_, ok := c.Get("converted_request")
  64. if !ok {
  65. return wrapErr(errors.New("request not found")), nil
  66. }
  67. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  68. awsClaudeReq := &AwsClaudeRequest{
  69. AnthropicVersion: "bedrock-2023-05-31",
  70. }
  71. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  72. return wrapErr(errors.Wrap(err, "copy request")), nil
  73. }
  74. awsReq.Body, err = json.Marshal(awsClaudeReq)
  75. if err != nil {
  76. return wrapErr(errors.Wrap(err, "marshal request")), nil
  77. }
  78. awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
  79. if err != nil {
  80. return wrapErr(errors.Wrap(err, "InvokeModel")), nil
  81. }
  82. claudeResponse := new(claude.ClaudeResponse)
  83. err = json.Unmarshal(awsResp.Body, claudeResponse)
  84. if err != nil {
  85. return wrapErr(errors.Wrap(err, "unmarshal response")), nil
  86. }
  87. openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
  88. usage := relaymodel.Usage{
  89. PromptTokens: claudeResponse.Usage.InputTokens,
  90. CompletionTokens: claudeResponse.Usage.OutputTokens,
  91. TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
  92. }
  93. openaiResp.Usage = usage
  94. c.JSON(http.StatusOK, openaiResp)
  95. return nil, &usage
  96. }
  97. func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  98. awsCli, err := newAwsClient(c, info)
  99. if err != nil {
  100. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  101. }
  102. awsModelId, err := awsModelID(c.GetString("request_model"))
  103. if err != nil {
  104. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  105. }
  106. awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
  107. ModelId: aws.String(awsModelId),
  108. Accept: aws.String("application/json"),
  109. ContentType: aws.String("application/json"),
  110. }
  111. claudeReq_, ok := c.Get("converted_request")
  112. if !ok {
  113. return wrapErr(errors.New("request not found")), nil
  114. }
  115. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  116. awsClaudeReq := &AwsClaudeRequest{
  117. AnthropicVersion: "bedrock-2023-05-31",
  118. }
  119. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  120. return wrapErr(errors.Wrap(err, "copy request")), nil
  121. }
  122. awsReq.Body, err = json.Marshal(awsClaudeReq)
  123. if err != nil {
  124. return wrapErr(errors.Wrap(err, "marshal request")), nil
  125. }
  126. awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
  127. if err != nil {
  128. return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
  129. }
  130. stream := awsResp.GetStream()
  131. defer stream.Close()
  132. c.Writer.Header().Set("Content-Type", "text/event-stream")
  133. var usage relaymodel.Usage
  134. var id string
  135. var model string
  136. createdTime := common.GetTimestamp()
  137. c.Stream(func(w io.Writer) bool {
  138. event, ok := <-stream.Events()
  139. if !ok {
  140. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  141. return false
  142. }
  143. switch v := event.(type) {
  144. case *types.ResponseStreamMemberChunk:
  145. claudeResp := new(claude.ClaudeResponse)
  146. err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
  147. if err != nil {
  148. common.SysError("error unmarshalling stream response: " + err.Error())
  149. return false
  150. }
  151. response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
  152. if claudeUsage != nil {
  153. usage.PromptTokens += claudeUsage.InputTokens
  154. usage.CompletionTokens += claudeUsage.OutputTokens
  155. }
  156. if response == nil {
  157. return true
  158. }
  159. if response.Id != "" {
  160. id = response.Id
  161. }
  162. if response.Model != "" {
  163. model = response.Model
  164. }
  165. response.Created = createdTime
  166. response.Id = id
  167. response.Model = model
  168. jsonStr, err := json.Marshal(response)
  169. if err != nil {
  170. common.SysError("error marshalling stream response: " + err.Error())
  171. return true
  172. }
  173. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  174. return true
  175. case *types.UnknownUnionMember:
  176. fmt.Println("unknown tag:", v.Tag)
  177. return false
  178. default:
  179. fmt.Println("union is nil or unknown type")
  180. return false
  181. }
  182. })
  183. return nil, &usage
  184. }