relay-aws.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package aws
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "github.com/pkg/errors"
  8. "io"
  9. "net/http"
  10. "one-api/common"
  11. relaymodel "one-api/dto"
  12. "one-api/relay/channel/claude"
  13. relaycommon "one-api/relay/common"
  14. "one-api/relay/helper"
  15. "one-api/service"
  16. "strings"
  17. "time"
  18. "github.com/aws/aws-sdk-go-v2/aws"
  19. "github.com/aws/aws-sdk-go-v2/credentials"
  20. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
  21. "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
  22. )
  // newAwsClient builds a Bedrock runtime client from the channel key.
  //
  // info.ApiKey must contain exactly three "|"-separated fields, in the order
  // access key, secret key, region; any other shape is rejected with
  // "invalid aws secret key". The static credentials are wrapped in a
  // credentials cache. The gin context is currently unused.
  func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
  	fields := strings.Split(info.ApiKey, "|")
  	if len(fields) != 3 {
  		return nil, errors.New("invalid aws secret key")
  	}
  	accessKey, secretKey, region := fields[0], fields[1], fields[2]

  	provider := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
  	opts := bedrockruntime.Options{
  		Region:      region,
  		Credentials: aws.NewCredentialsCache(provider),
  	}
  	return bedrockruntime.New(opts), nil
  }
  // wrapErr converts err into an OpenAI-style error payload with HTTP 500.
  //
  // The message is err.Error() verbatim. A nil err yields "unknown error"
  // instead of a nil-pointer panic.
  func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
  	msg := "unknown error"
  	if err != nil {
  		msg = err.Error()
  	}
  	return &relaymodel.OpenAIErrorWithStatusCode{
  		StatusCode: http.StatusInternalServerError,
  		Error: relaymodel.OpenAIError{
  			Message: msg,
  		},
  	}
  }
  45. func awsModelID(requestModel string) (string, error) {
  46. if awsModelID, ok := awsModelIDMap[requestModel]; ok {
  47. return awsModelID, nil
  48. }
  49. return requestModel, nil
  50. }
  // awsHandler serves a non-streaming request through Bedrock's InvokeModel
  // API and writes the result back as an OpenAI-style JSON response.
  //
  // It reads the *claude.ClaudeRequest stored under the "converted_request"
  // context key, rewrites it with copyRequest, and invokes the model resolved
  // from the "request_model" context key. requestMode is forwarded to
  // claude.ResponseClaude2OpenAI.
  //
  // On success it writes HTTP 200 and returns (nil, usage), where usage comes
  // from the Claude response's input/output token counts. On any failure it
  // returns a wrapErr (HTTP 500) error and nil usage.
  func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  	awsCli, err := newAwsClient(c, info)
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  	}
  	modelID, err := awsModelID(c.GetString("request_model"))
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "awsModelID")), nil
  	}
  	rawReq, ok := c.Get("converted_request")
  	if !ok {
  		return wrapErr(errors.New("request not found")), nil
  	}
  	// Checked assertion: an unexpected (or typed-nil) value becomes an error
  	// response instead of a panic in the request goroutine.
  	claudeReq, ok := rawReq.(*claude.ClaudeRequest)
  	if !ok || claudeReq == nil {
  		return wrapErr(errors.Errorf("invalid converted request (%T)", rawReq)), nil
  	}
  	body, err := json.Marshal(copyRequest(claudeReq))
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "marshal request")), nil
  	}
  	awsResp, err := awsCli.InvokeModel(c.Request.Context(), &bedrockruntime.InvokeModelInput{
  		ModelId:     aws.String(modelID),
  		Accept:      aws.String("application/json"),
  		ContentType: aws.String("application/json"),
  		Body:        body,
  	})
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "InvokeModel")), nil
  	}
  	claudeResponse := new(claude.ClaudeResponse)
  	if err := json.Unmarshal(awsResp.Body, claudeResponse); err != nil {
  		return wrapErr(errors.Wrap(err, "unmarshal response")), nil
  	}
  	openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
  	usage := relaymodel.Usage{
  		PromptTokens:     claudeResponse.Usage.InputTokens,
  		CompletionTokens: claudeResponse.Usage.OutputTokens,
  		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
  	}
  	openaiResp.Usage = usage
  	c.JSON(http.StatusOK, openaiResp)
  	return nil, &usage
  }
  // awsStreamHandler serves a streaming request through Bedrock's
  // InvokeModelWithResponseStream API and re-emits each Claude chunk as an
  // OpenAI-style SSE "data: ..." event.
  //
  // The request is taken from the "converted_request" context key (a
  // *claude.ClaudeRequest, rewritten with copyRequest), and the model ID is
  // resolved from the "request_model" context key. requestMode is used both to
  // convert chunks and to accumulate text and usage into claudeInfo. If no
  // completion token count was reported, usage is estimated from the
  // accumulated response text. When info.ShouldIncludeUsage is set, a final
  // usage chunk is sent before helper.Done(c). A non-nil resp has its body closed.
  //
  // Setup failures return a wrapErr (HTTP 500) error and nil usage. Failures
  // after streaming has started are logged and end the stream; they are not
  // returned as errors.
  func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
  	awsCli, err := newAwsClient(c, info)
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "newAwsClient")), nil
  	}
  	modelID, err := awsModelID(c.GetString("request_model"))
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "awsModelID")), nil
  	}
  	rawReq, ok := c.Get("converted_request")
  	if !ok {
  		return wrapErr(errors.New("request not found")), nil
  	}
  	// Checked assertion: an unexpected (or typed-nil) value becomes an error
  	// response instead of a panic in the request goroutine.
  	claudeReq, ok := rawReq.(*claude.ClaudeRequest)
  	if !ok || claudeReq == nil {
  		return wrapErr(errors.Errorf("invalid converted request (%T)", rawReq)), nil
  	}
  	body, err := json.Marshal(copyRequest(claudeReq))
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "marshal request")), nil
  	}
  	awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), &bedrockruntime.InvokeModelWithResponseStreamInput{
  		ModelId:     aws.String(modelID),
  		Accept:      aws.String("application/json"),
  		ContentType: aws.String("application/json"),
  		Body:        body,
  	})
  	if err != nil {
  		return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
  	}
  	stream := awsResp.GetStream()
  	defer stream.Close()

  	c.Writer.Header().Set("Content-Type", "text/event-stream")
  	claudeInfo := &claude.ClaudeResponseInfo{
  		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  		Created:      common.GetTimestamp(),
  		Model:        info.UpstreamModelName,
  		ResponseText: strings.Builder{},
  		Usage:        &relaymodel.Usage{},
  	}
  	isFirst := true
  	c.Stream(func(w io.Writer) bool {
  		event, ok := <-stream.Events()
  		if !ok {
  			return false
  		}
  		switch v := event.(type) {
  		case *types.ResponseStreamMemberChunk:
  			if isFirst {
  				isFirst = false
  				info.FirstResponseTime = time.Now()
  			}
  			claudeResponse := new(claude.ClaudeResponse)
  			if err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse); err != nil {
  				common.SysError("error unmarshalling stream response: " + err.Error())
  				return false
  			}
  			response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
  			// Use the same requestMode as the conversion above. A hard-coded
  			// message mode would discard every completion-mode chunk.
  			if !claude.FormatClaudeResponseInfo(requestMode, claudeResponse, response, claudeInfo) {
  				return true
  			}
  			jsonStr, err := json.Marshal(response)
  			if err != nil {
  				common.SysError("error marshalling stream response: " + err.Error())
  				return true
  			}
  			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  			return true
  		case *types.UnknownUnionMember:
  			common.SysError("unknown tag: " + v.Tag)
  			return false
  		default:
  			common.SysError("union is nil or unknown type")
  			return false
  		}
  	})
  	// The events channel also closes on transport or decoding failures;
  	// Err reports why, so log it rather than ending the stream silently.
  	if streamErr := stream.Err(); streamErr != nil {
  		common.SysError("aws response stream error: " + streamErr.Error())
  	}

  	// PromptTokens == 0 at this point means the upstream never reported
  	// input usage (likely an upstream error); the estimate below then
  	// starts from zero prompt tokens.
  	if claudeInfo.Usage.CompletionTokens == 0 {
  		estimated, err := service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
  		if err != nil {
  			common.SysError("estimate usage from response text failed: " + err.Error())
  		}
  		// Never replace the usage with nil; it is dereferenced below.
  		if estimated != nil {
  			claudeInfo.Usage = estimated
  		}
  	}
  	if info.ShouldIncludeUsage {
  		response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
  		if err := helper.ObjectData(c, response); err != nil {
  			common.SysError("send final response failed: " + err.Error())
  		}
  	}
  	helper.Done(c)
  	if resp != nil {
  		if err := resp.Body.Close(); err != nil {
  			return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
  		}
  	}
  	return nil, claudeInfo.Usage
  }