relay-aws.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package aws
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/jinzhu/copier"
  8. "github.com/pkg/errors"
  9. "io"
  10. "net/http"
  11. "one-api/common"
  12. relaymodel "one-api/dto"
  13. "one-api/relay/channel/claude"
  14. relaycommon "one-api/relay/common"
  15. "one-api/service"
  16. "strings"
  17. "time"
  18. "github.com/aws/aws-sdk-go-v2/aws"
  19. "github.com/aws/aws-sdk-go-v2/credentials"
  20. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  21. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  22. )
  23. func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
  24. awsSecret := strings.Split(info.ApiKey, "|")
  25. if len(awsSecret) != 3 {
  26. return nil, errors.New("invalid aws secret key")
  27. }
  28. ak := awsSecret[0]
  29. sk := awsSecret[1]
  30. region := awsSecret[2]
  31. client := bedrockruntime.New(bedrockruntime.Options{
  32. Region: region,
  33. Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
  34. })
  35. return client, nil
  36. }
  37. func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
  38. return &relaymodel.OpenAIErrorWithStatusCode{
  39. StatusCode: http.StatusInternalServerError,
  40. Error: relaymodel.OpenAIError{
  41. Message: fmt.Sprintf("%s", err.Error()),
  42. },
  43. }
  44. }
  45. func awsModelID(requestModel string) (string, error) {
  46. if awsModelID, ok := awsModelIDMap[requestModel]; ok {
  47. return awsModelID, nil
  48. }
  49. return "", errors.Errorf("model %s not found", requestModel)
  50. }
  51. func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  52. awsCli, err := newAwsClient(c, info)
  53. if err != nil {
  54. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  55. }
  56. awsModelId, err := awsModelID(c.GetString("request_model"))
  57. if err != nil {
  58. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  59. }
  60. awsReq := &bedrockruntime.InvokeModelInput{
  61. ModelId: aws.String(awsModelId),
  62. Accept: aws.String("application/json"),
  63. ContentType: aws.String("application/json"),
  64. }
  65. claudeReq_, ok := c.Get("converted_request")
  66. if !ok {
  67. return wrapErr(errors.New("request not found")), nil
  68. }
  69. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  70. awsClaudeReq := &AwsClaudeRequest{
  71. AnthropicVersion: "bedrock-2023-05-31",
  72. }
  73. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  74. return wrapErr(errors.Wrap(err, "copy request")), nil
  75. }
  76. awsReq.Body, err = json.Marshal(awsClaudeReq)
  77. if err != nil {
  78. return wrapErr(errors.Wrap(err, "marshal request")), nil
  79. }
  80. awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
  81. if err != nil {
  82. return wrapErr(errors.Wrap(err, "InvokeModel")), nil
  83. }
  84. claudeResponse := new(claude.ClaudeResponse)
  85. err = json.Unmarshal(awsResp.Body, claudeResponse)
  86. if err != nil {
  87. return wrapErr(errors.Wrap(err, "unmarshal response")), nil
  88. }
  89. openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
  90. usage := relaymodel.Usage{
  91. PromptTokens: claudeResponse.Usage.InputTokens,
  92. CompletionTokens: claudeResponse.Usage.OutputTokens,
  93. TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
  94. }
  95. openaiResp.Usage = usage
  96. c.JSON(http.StatusOK, openaiResp)
  97. return nil, &usage
  98. }
  99. func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  100. awsCli, err := newAwsClient(c, info)
  101. if err != nil {
  102. return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  103. }
  104. awsModelId, err := awsModelID(c.GetString("request_model"))
  105. if err != nil {
  106. return wrapErr(errors.Wrap(err, "awsModelID")), nil
  107. }
  108. awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
  109. ModelId: aws.String(awsModelId),
  110. Accept: aws.String("application/json"),
  111. ContentType: aws.String("application/json"),
  112. }
  113. claudeReq_, ok := c.Get("converted_request")
  114. if !ok {
  115. return wrapErr(errors.New("request not found")), nil
  116. }
  117. claudeReq := claudeReq_.(*claude.ClaudeRequest)
  118. awsClaudeReq := &AwsClaudeRequest{
  119. AnthropicVersion: "bedrock-2023-05-31",
  120. }
  121. if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
  122. return wrapErr(errors.Wrap(err, "copy request")), nil
  123. }
  124. awsReq.Body, err = json.Marshal(awsClaudeReq)
  125. if err != nil {
  126. return wrapErr(errors.Wrap(err, "marshal request")), nil
  127. }
  128. awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
  129. if err != nil {
  130. return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
  131. }
  132. stream := awsResp.GetStream()
  133. defer stream.Close()
  134. c.Writer.Header().Set("Content-Type", "text/event-stream")
  135. var usage relaymodel.Usage
  136. var id string
  137. var model string
  138. isFirst := true
  139. createdTime := common.GetTimestamp()
  140. c.Stream(func(w io.Writer) bool {
  141. event, ok := <-stream.Events()
  142. if !ok {
  143. return false
  144. }
  145. switch v := event.(type) {
  146. case *types.ResponseStreamMemberChunk:
  147. if isFirst {
  148. isFirst = false
  149. info.FirstResponseTime = time.Now()
  150. }
  151. claudeResp := new(claude.ClaudeResponse)
  152. err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
  153. if err != nil {
  154. common.SysError("error unmarshalling stream response: " + err.Error())
  155. return false
  156. }
  157. response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
  158. if claudeUsage != nil {
  159. usage.PromptTokens += claudeUsage.InputTokens
  160. usage.CompletionTokens += claudeUsage.OutputTokens
  161. }
  162. if response == nil {
  163. return true
  164. }
  165. if response.Id != "" {
  166. id = response.Id
  167. }
  168. if response.Model != "" {
  169. model = response.Model
  170. }
  171. response.Created = createdTime
  172. response.Id = id
  173. response.Model = model
  174. jsonStr, err := json.Marshal(response)
  175. if err != nil {
  176. common.SysError("error marshalling stream response: " + err.Error())
  177. return true
  178. }
  179. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  180. return true
  181. case *types.UnknownUnionMember:
  182. fmt.Println("unknown tag:", v.Tag)
  183. return false
  184. default:
  185. fmt.Println("union is nil or unknown type")
  186. return false
  187. }
  188. })
  189. if info.ShouldIncludeUsage {
  190. response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
  191. err := service.ObjectData(c, response)
  192. if err != nil {
  193. common.SysError("send final response failed: " + err.Error())
  194. }
  195. }
  196. service.Done(c)
  197. err = resp.Body.Close()
  198. if err != nil {
  199. return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
  200. }
  201. return nil, &usage
  202. }