relay_cloudflare.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package cloudflare
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "time"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/QuantumNous/new-api/relay/helper"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/QuantumNous/new-api/types"
  15. "github.com/samber/lo"
  16. "github.com/gin-gonic/gin"
  17. )
  18. func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
  19. p, _ := textRequest.Prompt.(string)
  20. return &CfRequest{
  21. Prompt: p,
  22. MaxTokens: textRequest.GetMaxTokens(),
  23. Stream: lo.FromPtrOr(textRequest.Stream, false),
  24. Temperature: textRequest.Temperature,
  25. }
  26. }
  27. func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  28. scanner := bufio.NewScanner(resp.Body)
  29. scanner.Split(bufio.ScanLines)
  30. helper.SetEventStreamHeaders(c)
  31. id := helper.GetResponseID(c)
  32. var responseText string
  33. isFirst := true
  34. for scanner.Scan() {
  35. data := scanner.Text()
  36. if len(data) < len("data: ") {
  37. continue
  38. }
  39. data = strings.TrimPrefix(data, "data: ")
  40. data = strings.TrimSuffix(data, "\r")
  41. if data == "[DONE]" {
  42. break
  43. }
  44. var response dto.ChatCompletionsStreamResponse
  45. err := json.Unmarshal([]byte(data), &response)
  46. if err != nil {
  47. logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
  48. continue
  49. }
  50. for _, choice := range response.Choices {
  51. choice.Delta.Role = "assistant"
  52. responseText += choice.Delta.GetContentString()
  53. }
  54. response.Id = id
  55. response.Model = info.UpstreamModelName
  56. err = helper.ObjectData(c, response)
  57. if isFirst {
  58. isFirst = false
  59. info.FirstResponseTime = time.Now()
  60. }
  61. if err != nil {
  62. logger.LogError(c, "error_rendering_stream_response: "+err.Error())
  63. }
  64. }
  65. if err := scanner.Err(); err != nil {
  66. logger.LogError(c, "error_scanning_stream_response: "+err.Error())
  67. }
  68. usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
  69. if info.ShouldIncludeUsage {
  70. response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
  71. err := helper.ObjectData(c, response)
  72. if err != nil {
  73. logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
  74. }
  75. }
  76. helper.Done(c)
  77. service.CloseResponseBodyGracefully(resp)
  78. return nil, usage
  79. }
  80. func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  81. responseBody, err := io.ReadAll(resp.Body)
  82. if err != nil {
  83. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  84. }
  85. service.CloseResponseBodyGracefully(resp)
  86. var response dto.TextResponse
  87. err = json.Unmarshal(responseBody, &response)
  88. if err != nil {
  89. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  90. }
  91. response.Model = info.UpstreamModelName
  92. var responseText string
  93. for _, choice := range response.Choices {
  94. responseText += choice.Message.StringContent()
  95. }
  96. usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.GetEstimatePromptTokens())
  97. response.Usage = *usage
  98. response.Id = helper.GetResponseID(c)
  99. jsonResponse, err := json.Marshal(response)
  100. if err != nil {
  101. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  102. }
  103. c.Writer.Header().Set("Content-Type", "application/json")
  104. c.Writer.WriteHeader(resp.StatusCode)
  105. _, _ = c.Writer.Write(jsonResponse)
  106. return nil, usage
  107. }
  108. func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  109. var cfResp CfAudioResponse
  110. responseBody, err := io.ReadAll(resp.Body)
  111. if err != nil {
  112. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  113. }
  114. service.CloseResponseBodyGracefully(resp)
  115. err = json.Unmarshal(responseBody, &cfResp)
  116. if err != nil {
  117. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  118. }
  119. audioResp := &dto.AudioResponse{
  120. Text: cfResp.Result.Text,
  121. }
  122. jsonResponse, err := json.Marshal(audioResp)
  123. if err != nil {
  124. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  125. }
  126. c.Writer.Header().Set("Content-Type", "application/json")
  127. c.Writer.WriteHeader(resp.StatusCode)
  128. _, _ = c.Writer.Write(jsonResponse)
  129. usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens())
  130. return nil, usage
  131. }