relay-aws.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package aws
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "net/http"
  10. "one-api/common"
  11. relaymodel "one-api/dto"
  12. "one-api/relay/channel/claude"
  13. relaycommon "one-api/relay/common"
  14. "one-api/service"
  15. "strings"
  16. "time"
  17. "github.com/aws/aws-sdk-go-v2/aws"
  18. "github.com/aws/aws-sdk-go-v2/credentials"
  19. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  20. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  21. )
  // newAwsClient builds a Bedrock runtime client from the channel key in
// info.ApiKey, which must consist of exactly three "|"-separated parts:
// access key, secret key and region. Static credentials are wrapped in a
// credentials cache. The gin context parameter is currently unused.
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
	parts := strings.Split(info.ApiKey, "|")
	if len(parts) != 3 {
		return nil, errors.New("invalid aws secret key")
	}
	accessKey, secretKey, region := parts[0], parts[1], parts[2]

	provider := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
	opts := bedrockruntime.Options{
		Region:      region,
		Credentials: aws.NewCredentialsCache(provider),
	}
	return bedrockruntime.New(opts), nil
}
  // wrapErr converts err into an OpenAI-style error payload with HTTP status
// 500 (Internal Server Error), using the error's text verbatim as the message.
// A nil err yields a generic message instead of panicking on err.Error().
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
	message := "unknown error"
	if err != nil {
		// err.Error() is already a string; formatting it again is redundant.
		message = err.Error()
	}
	return &relaymodel.OpenAIErrorWithStatusCode{
		StatusCode: http.StatusInternalServerError,
		Error: relaymodel.OpenAIError{
			Message: message,
		},
	}
}
  44. func awsModelID(requestModel string) (string, error) {
  45. if awsModelID, ok := awsModelIDMap[requestModel]; ok {
  46. return awsModelID, nil
  47. }
  48. return requestModel, nil
  49. }
  // awsHandler relays a non-streaming request to AWS Bedrock via InvokeModel.
//
// It reads the Claude-format request stored in the gin context under
// "converted_request" and sends it to the model resolved from "request_model".
// It then converts the Claude response with claude.ResponseClaude2OpenAI and
// writes the result as JSON with status 200.
//
// On failure it returns a wrapped error (HTTP 500) and a nil usage. On success
// it returns a nil error and the token usage reported by the upstream response.
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
	awsCli, err := newAwsClient(c, info)
	if err != nil {
		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
	}
	modelID, err := awsModelID(c.GetString("request_model"))
	if err != nil {
		return wrapErr(errors.Wrap(err, "awsModelID")), nil
	}

	awsReq := &bedrockruntime.InvokeModelInput{
		ModelId:     aws.String(modelID),
		Accept:      aws.String("application/json"),
		ContentType: aws.String("application/json"),
	}

	rawReq, ok := c.Get("converted_request")
	if !ok {
		return wrapErr(errors.New("request not found")), nil
	}
	// Comma-ok assertion: a mistyped or nil value becomes an error response
	// instead of a panic in the request goroutine.
	claudeReq, ok := rawReq.(*claude.ClaudeRequest)
	if !ok || claudeReq == nil {
		return wrapErr(errors.Errorf("unexpected converted request type %T", rawReq)), nil
	}
	awsReq.Body, err = json.Marshal(copyRequest(claudeReq))
	if err != nil {
		return wrapErr(errors.Wrap(err, "marshal request")), nil
	}

	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
	if err != nil {
		return wrapErr(errors.Wrap(err, "InvokeModel")), nil
	}

	claudeResponse := new(claude.ClaudeResponse)
	if err := json.Unmarshal(awsResp.Body, claudeResponse); err != nil {
		return wrapErr(errors.Wrap(err, "unmarshal response")), nil
	}

	openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
	usage := relaymodel.Usage{
		PromptTokens:     claudeResponse.Usage.InputTokens,
		CompletionTokens: claudeResponse.Usage.OutputTokens,
		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
	}
	openaiResp.Usage = usage
	c.JSON(http.StatusOK, openaiResp)
	return nil, &usage
}
  // awsStreamHandler relays a streaming request to AWS Bedrock via
// InvokeModelWithResponseStream and re-emits the chunks to the client as
// OpenAI-style "data: ..." server-sent events.
//
// Each Bedrock chunk is decoded as a Claude stream event and converted with
// claude.StreamResponseClaude2OpenAI. Token counts reported by the events are
// accumulated into the returned usage, including TotalTokens. When
// info.ShouldIncludeUsage is set, a final usage chunk is sent before
// service.Done.
//
// resp is not used for the upstream call. If it is non-nil, its body is closed
// at the end. The handler returns a wrapped error (and nil usage) when the
// request cannot be started, and otherwise a nil error and the accumulated
// usage.
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
	awsCli, err := newAwsClient(c, info)
	if err != nil {
		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
	}
	modelID, err := awsModelID(c.GetString("request_model"))
	if err != nil {
		return wrapErr(errors.Wrap(err, "awsModelID")), nil
	}

	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
		ModelId:     aws.String(modelID),
		Accept:      aws.String("application/json"),
		ContentType: aws.String("application/json"),
	}

	rawReq, ok := c.Get("converted_request")
	if !ok {
		return wrapErr(errors.New("request not found")), nil
	}
	// Comma-ok assertion: a mistyped or nil value becomes an error response
	// instead of a panic in the request goroutine.
	claudeReq, ok := rawReq.(*claude.ClaudeRequest)
	if !ok || claudeReq == nil {
		return wrapErr(errors.Errorf("unexpected converted request type %T", rawReq)), nil
	}
	awsReq.Body, err = json.Marshal(copyRequest(claudeReq))
	if err != nil {
		return wrapErr(errors.Wrap(err, "marshal request")), nil
	}

	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
	if err != nil {
		return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
	}
	stream := awsResp.GetStream()
	defer stream.Close()

	c.Writer.Header().Set("Content-Type", "text/event-stream")

	var (
		usage relaymodel.Usage
		id    string
		model string
	)
	isFirst := true
	createdTime := common.GetTimestamp()

	c.Stream(func(w io.Writer) bool {
		event, ok := <-stream.Events()
		if !ok {
			return false
		}
		switch v := event.(type) {
		case *types.ResponseStreamMemberChunk:
			if isFirst {
				isFirst = false
				info.FirstResponseTime = time.Now()
			}
			claudeResp := new(claude.ClaudeResponse)
			if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp); err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return false
			}
			response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
			if claudeUsage != nil {
				usage.PromptTokens += claudeUsage.InputTokens
				usage.CompletionTokens += claudeUsage.OutputTokens
			}
			if response == nil {
				return true
			}
			// Keep the most recent non-empty id/model and stamp them, plus one
			// creation time, on every emitted chunk.
			if response.Id != "" {
				id = response.Id
			}
			if response.Model != "" {
				model = response.Model
			}
			response.Created = createdTime
			response.Id = id
			response.Model = model
			jsonStr, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case *types.UnknownUnionMember:
			common.SysError(fmt.Sprintf("aws stream: unknown event tag: %s", v.Tag))
			return false
		default:
			common.SysError(fmt.Sprintf("aws stream: unexpected event type %T", event))
			return false
		}
	})

	// The events channel closing does not by itself mean success. The SDK
	// reports read or transport failures through Err, so surface them in the
	// logs. The response is already committed, so it is not returned.
	if streamErr := stream.Err(); streamErr != nil {
		common.SysError("aws stream error: " + streamErr.Error())
	}

	// Stream events only carry prompt/completion counts; derive the total so
	// it matches the non-streaming handler.
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

	if info.ShouldIncludeUsage {
		response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
		if err := service.ObjectData(c, response); err != nil {
			common.SysError("send final response failed: " + err.Error())
		}
	}
	service.Done(c)

	if resp != nil {
		if err := resp.Body.Close(); err != nil {
			return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
		}
	}
	return nil, &usage
}