text.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package xai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/relay/channel/openai"
  11. relaycommon "one-api/relay/common"
  12. "one-api/relay/helper"
  13. "one-api/service"
  14. "strings"
  15. )
  16. func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
  17. if xAIResp == nil {
  18. return nil
  19. }
  20. if xAIResp.Usage != nil {
  21. xAIResp.Usage.CompletionTokens = usage.CompletionTokens
  22. }
  23. openAIResp := &dto.ChatCompletionsStreamResponse{
  24. Id: xAIResp.Id,
  25. Object: xAIResp.Object,
  26. Created: xAIResp.Created,
  27. Model: xAIResp.Model,
  28. Choices: xAIResp.Choices,
  29. Usage: xAIResp.Usage,
  30. }
  31. return openAIResp
  32. }
  33. func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  34. usage := &dto.Usage{}
  35. var responseTextBuilder strings.Builder
  36. var toolCount int
  37. var containStreamUsage bool
  38. helper.SetEventStreamHeaders(c)
  39. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  40. var xAIResp *dto.ChatCompletionsStreamResponse
  41. err := json.Unmarshal([]byte(data), &xAIResp)
  42. if err != nil {
  43. common.SysError("error unmarshalling stream response: " + err.Error())
  44. return true
  45. }
  46. // 把 xAI 的usage转换为 OpenAI 的usage
  47. if xAIResp.Usage != nil {
  48. containStreamUsage = true
  49. usage.PromptTokens = xAIResp.Usage.PromptTokens
  50. usage.TotalTokens = xAIResp.Usage.TotalTokens
  51. usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  52. }
  53. openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
  54. _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
  55. err = helper.ObjectData(c, openaiResponse)
  56. if err != nil {
  57. common.SysError(err.Error())
  58. }
  59. return true
  60. })
  61. if !containStreamUsage {
  62. usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  63. usage.CompletionTokens += toolCount * 7
  64. }
  65. helper.Done(c)
  66. common.CloseResponseBodyGracefully(resp)
  67. return nil, usage
  68. }
  69. func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  70. responseBody, err := io.ReadAll(resp.Body)
  71. var response *dto.TextResponse
  72. err = common.DecodeJson(responseBody, &response)
  73. if err != nil {
  74. common.SysError("error unmarshalling stream response: " + err.Error())
  75. return nil, nil
  76. }
  77. response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
  78. response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
  79. // new body
  80. encodeJson, err := common.EncodeJson(response)
  81. if err != nil {
  82. common.SysError("error marshalling stream response: " + err.Error())
  83. return nil, nil
  84. }
  85. // set new body
  86. resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
  87. for k, v := range resp.Header {
  88. c.Writer.Header().Set(k, v[0])
  89. }
  90. c.Writer.WriteHeader(resp.StatusCode)
  91. _, err = io.Copy(c.Writer, resp.Body)
  92. if err != nil {
  93. return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  94. }
  95. common.CloseResponseBodyGracefully(resp)
  96. return nil, &response.Usage
  97. }