text.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package xai
  2. import (
  3. "io"
  4. "net/http"
  5. "strings"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/dto"
  8. "github.com/QuantumNous/new-api/relay/channel/openai"
  9. relaycommon "github.com/QuantumNous/new-api/relay/common"
  10. "github.com/QuantumNous/new-api/relay/helper"
  11. "github.com/QuantumNous/new-api/service"
  12. "github.com/QuantumNous/new-api/types"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
  16. if xAIResp == nil {
  17. return nil
  18. }
  19. if xAIResp.Usage != nil {
  20. xAIResp.Usage.CompletionTokens = usage.CompletionTokens
  21. }
  22. openAIResp := &dto.ChatCompletionsStreamResponse{
  23. Id: xAIResp.Id,
  24. Object: xAIResp.Object,
  25. Created: xAIResp.Created,
  26. Model: xAIResp.Model,
  27. Choices: xAIResp.Choices,
  28. Usage: xAIResp.Usage,
  29. }
  30. return openAIResp
  31. }
  32. func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  33. usage := &dto.Usage{}
  34. var responseTextBuilder strings.Builder
  35. var toolCount int
  36. var containStreamUsage bool
  37. helper.SetEventStreamHeaders(c)
  38. helper.StreamScannerHandler(c, resp, info, func(data string, sr *helper.StreamResult) {
  39. var xAIResp *dto.ChatCompletionsStreamResponse
  40. if err := common.UnmarshalJsonStr(data, &xAIResp); err != nil {
  41. common.SysLog("error unmarshalling stream response: " + err.Error())
  42. sr.Error(err)
  43. return
  44. }
  45. // 把 xAI 的usage转换为 OpenAI 的usage
  46. if xAIResp.Usage != nil {
  47. containStreamUsage = true
  48. usage.PromptTokens = xAIResp.Usage.PromptTokens
  49. usage.TotalTokens = xAIResp.Usage.TotalTokens
  50. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  51. }
  52. openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
  53. _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
  54. if err := helper.ObjectData(c, openaiResponse); err != nil {
  55. common.SysLog(err.Error())
  56. sr.Error(err)
  57. }
  58. })
  59. if !containStreamUsage {
  60. usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
  61. usage.CompletionTokens += toolCount * 7
  62. }
  63. helper.Done(c)
  64. service.CloseResponseBodyGracefully(resp)
  65. return usage, nil
  66. }
  67. func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  68. defer service.CloseResponseBodyGracefully(resp)
  69. responseBody, err := io.ReadAll(resp.Body)
  70. if err != nil {
  71. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  72. }
  73. var xaiResponse ChatCompletionResponse
  74. err = common.Unmarshal(responseBody, &xaiResponse)
  75. if err != nil {
  76. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  77. }
  78. if xaiResponse.Usage != nil {
  79. xaiResponse.Usage.CompletionTokens = xaiResponse.Usage.TotalTokens - xaiResponse.Usage.PromptTokens
  80. xaiResponse.Usage.CompletionTokenDetails.TextTokens = xaiResponse.Usage.CompletionTokens - xaiResponse.Usage.CompletionTokenDetails.ReasoningTokens
  81. }
  82. // new body
  83. encodeJson, err := common.Marshal(xaiResponse)
  84. if err != nil {
  85. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  86. }
  87. service.IOCopyBytesGracefully(c, resp, encodeJson)
  88. return xaiResponse.Usage, nil
  89. }