relay_responses.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package openai
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "strings"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  16. // read response body
  17. var responsesResponse dto.OpenAIResponsesResponse
  18. responseBody, err := io.ReadAll(resp.Body)
  19. if err != nil {
  20. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  21. }
  22. err = resp.Body.Close()
  23. if err != nil {
  24. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  25. }
  26. err = common.DecodeJson(responseBody, &responsesResponse)
  27. if err != nil {
  28. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  29. }
  30. if responsesResponse.Error != nil {
  31. return &dto.OpenAIErrorWithStatusCode{
  32. Error: dto.OpenAIError{
  33. Message: responsesResponse.Error.Message,
  34. Type: "openai_error",
  35. Code: responsesResponse.Error.Code,
  36. },
  37. StatusCode: resp.StatusCode,
  38. }, nil
  39. }
  40. // reset response body
  41. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  42. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  43. // And then we will have to send an error response, but in this case, the header has already been set.
  44. // So the httpClient will be confused by the response.
  45. // For example, Postman will report error, and we cannot check the response at all.
  46. for k, v := range resp.Header {
  47. c.Writer.Header().Set(k, v[0])
  48. }
  49. c.Writer.WriteHeader(resp.StatusCode)
  50. // copy response body
  51. _, err = io.Copy(c.Writer, resp.Body)
  52. if err != nil {
  53. common.SysError("error copying response body: " + err.Error())
  54. }
  55. resp.Body.Close()
  56. // compute usage
  57. usage := dto.Usage{}
  58. usage.PromptTokens = responsesResponse.Usage.InputTokens
  59. usage.CompletionTokens = responsesResponse.Usage.OutputTokens
  60. usage.TotalTokens = responsesResponse.Usage.TotalTokens
  61. // 解析 Tools 用量
  62. for _, tool := range responsesResponse.Tools {
  63. info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
  64. }
  65. return nil, &usage
  66. }
  67. func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  68. if resp == nil || resp.Body == nil {
  69. common.LogError(c, "invalid response or response body")
  70. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  71. }
  72. var usage = &dto.Usage{}
  73. var responseTextBuilder strings.Builder
  74. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  75. // 检查当前数据是否包含 completed 状态和 usage 信息
  76. var streamResponse dto.ResponsesStreamResponse
  77. if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
  78. sendResponsesStreamData(c, streamResponse, data)
  79. switch streamResponse.Type {
  80. case "response.completed":
  81. usage.PromptTokens = streamResponse.Response.Usage.InputTokens
  82. usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
  83. usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
  84. case "response.output_text.delta":
  85. // 处理输出文本
  86. responseTextBuilder.WriteString(streamResponse.Delta)
  87. case dto.ResponsesOutputTypeItemDone:
  88. // 函数调用处理
  89. if streamResponse.Item != nil {
  90. switch streamResponse.Item.Type {
  91. case dto.BuildInCallWebSearchCall:
  92. info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
  93. }
  94. }
  95. }
  96. }
  97. return true
  98. })
  99. if usage.CompletionTokens == 0 {
  100. // 计算输出文本的 token 数量
  101. tempStr := responseTextBuilder.String()
  102. if len(tempStr) > 0 {
  103. // 非正常结束,使用输出文本的 token 数量
  104. completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
  105. usage.CompletionTokens = completionTokens
  106. }
  107. }
  108. return nil, usage
  109. }