relay-aws.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package aws
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/jinzhu/copier"
  8. "github.com/pkg/errors"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. relaymodel "one-api/dto"
  13. "one-api/relay/channel/claude"
  14. relaycommon "one-api/relay/common"
  15. "strings"
  16. "time"
  17. "github.com/aws/aws-sdk-go-v2/aws"
  18. "github.com/aws/aws-sdk-go-v2/credentials"
  19. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  20. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  21. )
  22. func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
  23. awsSecret := strings.Split(info.ApiKey, "|")
  24. if len(awsSecret) != 3 {
  25. return nil, errors.New("invalid aws secret key")
  26. }
  27. ak := awsSecret[0]
  28. sk := awsSecret[1]
  29. region := awsSecret[2]
  30. client := bedrockruntime.New(bedrockruntime.Options{
  31. Region: region,
  32. Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
  33. })
  34. return client, nil
  35. }
  36. func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
  37. return &relaymodel.OpenAIErrorWithStatusCode{
  38. StatusCode: http.StatusInternalServerError,
  39. Error: relaymodel.OpenAIError{
  40. Message: fmt.Sprintf("%s", err.Error()),
  41. },
  42. }
  43. }
  44. func awsModelID(requestModel string) (string, error) {
  45. if awsModelID, ok := awsModelIDMap[requestModel]; ok {
  46. return awsModelID, nil
  47. }
  48. return "", errors.Errorf("model %s not found", requestModel)
  49. }
  50. func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  51. awsCli, err := newAwsClient(c, info)
  52. if err != nil {
  53. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  54. }
  55. awsModelId, err := awsModelID(c.GetString("request_model"))
  56. if err != nil {
  57. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  58. }
  59. awsReq := &bedrockruntime.InvokeModelInput{
  60. ModelId: aws.String(awsModelId),
  61. Accept: aws.String("application/json"),
  62. ContentType: aws.String("application/json"),
  63. }
  64. claudeReq_, ok := c.Get("converted_request")
  65. if !ok {
  66. return wrapErr(errors.New("request not found")), nil
  67. }
  68. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  69. awsClaudeReq := &AwsClaudeRequest{
  70. AnthropicVersion: "bedrock-2023-05-31",
  71. }
  72. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  73. return wrapErr(errors.Wrap(err, "copy request")), nil
  74. }
  75. awsReq.Body, err = json.Marshal(awsClaudeReq)
  76. if err != nil {
  77. return wrapErr(errors.Wrap(err, "marshal request")), nil
  78. }
  79. awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
  80. if err != nil {
  81. return wrapErr(errors.Wrap(err, "InvokeModel")), nil
  82. }
  83. claudeResponse := new(claude.ClaudeResponse)
  84. err = json.Unmarshal(awsResp.Body, claudeResponse)
  85. if err != nil {
  86. return wrapErr(errors.Wrap(err, "unmarshal response")), nil
  87. }
  88. openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
  89. usage := relaymodel.Usage{
  90. PromptTokens: claudeResponse.Usage.InputTokens,
  91. CompletionTokens: claudeResponse.Usage.OutputTokens,
  92. TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
  93. }
  94. openaiResp.Usage = usage
  95. c.JSON(http.StatusOK, openaiResp)
  96. return nil, &usage
  97. }
  98. func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  99. awsCli, err := newAwsClient(c, info)
  100. if err != nil {
  101. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  102. }
  103. awsModelId, err := awsModelID(c.GetString("request_model"))
  104. if err != nil {
  105. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  106. }
  107. awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
  108. ModelId: aws.String(awsModelId),
  109. Accept: aws.String("application/json"),
  110. ContentType: aws.String("application/json"),
  111. }
  112. claudeReq_, ok := c.Get("converted_request")
  113. if !ok {
  114. return wrapErr(errors.New("request not found")), nil
  115. }
  116. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  117. awsClaudeReq := &AwsClaudeRequest{
  118. AnthropicVersion: "bedrock-2023-05-31",
  119. }
  120. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  121. return wrapErr(errors.Wrap(err, "copy request")), nil
  122. }
  123. awsReq.Body, err = json.Marshal(awsClaudeReq)
  124. if err != nil {
  125. return wrapErr(errors.Wrap(err, "marshal request")), nil
  126. }
  127. awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
  128. if err != nil {
  129. return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
  130. }
  131. stream := awsResp.GetStream()
  132. defer stream.Close()
  133. c.Writer.Header().Set("Content-Type", "text/event-stream")
  134. var usage relaymodel.Usage
  135. var id string
  136. var model string
  137. isFirst := true
  138. createdTime := common.GetTimestamp()
  139. c.Stream(func(w io.Writer) bool {
  140. event, ok := <-stream.Events()
  141. if !ok {
  142. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  143. return false
  144. }
  145. switch v := event.(type) {
  146. case *types.ResponseStreamMemberChunk:
  147. if isFirst {
  148. isFirst = false
  149. info.FirstResponseTime = time.Now()
  150. }
  151. claudeResp := new(claude.ClaudeResponse)
  152. err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
  153. if err != nil {
  154. common.SysError("error unmarshalling stream response: " + err.Error())
  155. return false
  156. }
  157. response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
  158. if claudeUsage != nil {
  159. usage.PromptTokens += claudeUsage.InputTokens
  160. usage.CompletionTokens += claudeUsage.OutputTokens
  161. }
  162. if response == nil {
  163. return true
  164. }
  165. if response.Id != "" {
  166. id = response.Id
  167. }
  168. if response.Model != "" {
  169. model = response.Model
  170. }
  171. response.Created = createdTime
  172. response.Id = id
  173. response.Model = model
  174. jsonStr, err := json.Marshal(response)
  175. if err != nil {
  176. common.SysError("error marshalling stream response: " + err.Error())
  177. return true
  178. }
  179. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  180. return true
  181. case *types.UnknownUnionMember:
  182. fmt.Println("unknown tag:", v.Tag)
  183. return false
  184. default:
  185. fmt.Println("union is nil or unknown type")
  186. return false
  187. }
  188. })
  189. return nil, &usage
  190. }