relay_responses.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package openai
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/dto"
  8. relaycommon "one-api/relay/common"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "strings"
  12. "github.com/gin-gonic/gin"
  13. )
  14. func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  15. defer common.CloseResponseBodyGracefully(resp)
  16. // read response body
  17. var responsesResponse dto.OpenAIResponsesResponse
  18. responseBody, err := io.ReadAll(resp.Body)
  19. if err != nil {
  20. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  21. }
  22. err = common.UnmarshalJson(responseBody, &responsesResponse)
  23. if err != nil {
  24. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  25. }
  26. if responsesResponse.Error != nil {
  27. return &dto.OpenAIErrorWithStatusCode{
  28. Error: dto.OpenAIError{
  29. Message: responsesResponse.Error.Message,
  30. Type: "openai_error",
  31. Code: responsesResponse.Error.Code,
  32. },
  33. StatusCode: resp.StatusCode,
  34. }, nil
  35. }
  36. // 写入新的 response body
  37. common.IOCopyBytesGracefully(c, resp, responseBody)
  38. // compute usage
  39. usage := dto.Usage{}
  40. usage.PromptTokens = responsesResponse.Usage.InputTokens
  41. usage.CompletionTokens = responsesResponse.Usage.OutputTokens
  42. usage.TotalTokens = responsesResponse.Usage.TotalTokens
  43. // 解析 Tools 用量
  44. for _, tool := range responsesResponse.Tools {
  45. info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
  46. }
  47. return nil, &usage
  48. }
  49. func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  50. if resp == nil || resp.Body == nil {
  51. common.LogError(c, "invalid response or response body")
  52. return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
  53. }
  54. var usage = &dto.Usage{}
  55. var responseTextBuilder strings.Builder
  56. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  57. // 检查当前数据是否包含 completed 状态和 usage 信息
  58. var streamResponse dto.ResponsesStreamResponse
  59. if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
  60. sendResponsesStreamData(c, streamResponse, data)
  61. switch streamResponse.Type {
  62. case "response.completed":
  63. usage.PromptTokens = streamResponse.Response.Usage.InputTokens
  64. usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
  65. usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
  66. case "response.output_text.delta":
  67. // 处理输出文本
  68. responseTextBuilder.WriteString(streamResponse.Delta)
  69. case dto.ResponsesOutputTypeItemDone:
  70. // 函数调用处理
  71. if streamResponse.Item != nil {
  72. switch streamResponse.Item.Type {
  73. case dto.BuildInCallWebSearchCall:
  74. info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
  75. }
  76. }
  77. }
  78. }
  79. return true
  80. })
  81. if usage.CompletionTokens == 0 {
  82. // 计算输出文本的 token 数量
  83. tempStr := responseTextBuilder.String()
  84. if len(tempStr) > 0 {
  85. // 非正常结束,使用输出文本的 token 数量
  86. completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
  87. usage.CompletionTokens = completionTokens
  88. }
  89. }
  90. return nil, usage
  91. }