text.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package xai
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/relay/channel/openai"
  11. relaycommon "one-api/relay/common"
  12. "one-api/relay/helper"
  13. "one-api/service"
  14. "strings"
  15. )
  16. func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
  17. if xAIResp == nil {
  18. return nil
  19. }
  20. if xAIResp.Usage != nil {
  21. xAIResp.Usage.CompletionTokens = usage.CompletionTokens
  22. }
  23. openAIResp := &dto.ChatCompletionsStreamResponse{
  24. Id: xAIResp.Id,
  25. Object: xAIResp.Object,
  26. Created: xAIResp.Created,
  27. Model: xAIResp.Model,
  28. Choices: xAIResp.Choices,
  29. Usage: xAIResp.Usage,
  30. }
  31. return openAIResp
  32. }
  // xAIStreamHandler relays an xAI event stream to the client. Each upstream
  // chunk is re-emitted through helper.ObjectData after conversion by
  // streamResponseXAI2OpenAI, and the stream is terminated with helper.Done.
  //
  // Usage accounting:
  //   - If any chunk carries usage, PromptTokens and TotalTokens come from the
  //     last such chunk, and CompletionTokens = TotalTokens - PromptTokens.
  //   - Otherwise usage is estimated with service.ResponseText2Usage from the
  //     text accumulated by openai.ProcessStreamResponse. A fixed 7 completion
  //     tokens are then added per counted tool call. The meaning of 7 is not
  //     visible here; it is kept as-is.
  //
  // Chunks that fail to unmarshal are logged and skipped. The returned error is
  // always nil, and a failure to close the upstream body is only logged.
  func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  	usage := &dto.Usage{}
  	var responseTextBuilder strings.Builder
  	var toolCount int
  	var containStreamUsage bool

  	helper.SetEventStreamHeaders(c)
  	// The callback always returns true; presumably this means "keep scanning"
  	// (NOTE(review): confirm against helper.StreamScannerHandler).
  	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  		var xAIResp *dto.ChatCompletionsStreamResponse
  		if err := json.Unmarshal([]byte(data), &xAIResp); err != nil {
  			common.SysError("error unmarshalling stream response: " + err.Error())
  			return true
  		}
  		// A literal `null` payload unmarshals into a nil pointer without an
  		// error; skip it instead of dereferencing nil below.
  		if xAIResp == nil {
  			return true
  		}
  		// Convert xAI usage to OpenAI usage: completion tokens are derived as
  		// total minus prompt.
  		if xAIResp.Usage != nil {
  			containStreamUsage = true
  			usage.PromptTokens = xAIResp.Usage.PromptTokens
  			usage.TotalTokens = xAIResp.Usage.TotalTokens
  			usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
  		}
  		openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
  		// Error deliberately ignored: the accumulated text and tool count only
  		// feed the fallback usage estimate below, not the relayed chunk.
  		_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
  		if err := helper.ObjectData(c, openaiResponse); err != nil {
  			common.SysError(err.Error())
  		}
  		return true
  	})

  	if !containStreamUsage {
  		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
  		usage.CompletionTokens += toolCount * 7
  	}
  	helper.Done(c)

  	// Best effort: the stream has already been relayed, so a close failure is
  	// logged rather than returned.
  	if err := resp.Body.Close(); err != nil {
  		common.SysError("close_response_body_failed: " + err.Error())
  	}
  	return nil, usage
  }
  // xAIHandler relays a non-streaming xAI chat completion to the client.
  //
  // It reads the whole upstream body, decodes it as a dto.TextResponse, and
  // recomputes usage:
  //   - CompletionTokens = TotalTokens - PromptTokens
  //   - CompletionTokenDetails.TextTokens = CompletionTokens - ReasoningTokens
  //
  // It then re-encodes the JSON and writes it with the upstream status code and
  // headers. The upstream Content-Length is dropped because the re-encoded body
  // can differ in length. On success it returns the recomputed usage.
  //
  // Read, decode, encode and write failures are returned through
  // service.OpenAIErrorWrapper with http.StatusInternalServerError. A failure
  // to close the upstream body is only logged, because the body has already
  // been fully read by then.
  func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  	responseBody, err := io.ReadAll(resp.Body)
  	// Close the real upstream body; it is never reassigned below, so this is
  	// the body the transport handed us.
  	if closeErr := resp.Body.Close(); closeErr != nil {
  		common.SysError("close_response_body_failed: " + closeErr.Error())
  	}
  	if err != nil {
  		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  	}

  	// Decode into a value rather than a pointer: a literal `null` body then
  	// leaves a zero response instead of a nil pointer that would panic below.
  	var response dto.TextResponse
  	if err = common.DecodeJson(responseBody, &response); err != nil {
  		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  	}
  	response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
  	response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens

  	encodeJson, err := common.EncodeJson(&response)
  	if err != nil {
  		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  	}

  	for k, v := range resp.Header {
  		// The body was re-encoded, so the upstream length no longer applies.
  		if k == "Content-Length" || len(v) == 0 {
  			continue
  		}
  		c.Writer.Header().Set(k, v[0])
  	}
  	c.Writer.WriteHeader(resp.StatusCode)
  	if _, err = io.Copy(c.Writer, bytes.NewReader(encodeJson)); err != nil {
  		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
  	}
  	return nil, &response.Usage
  }