@@ -0,0 +1,234 @@
+package openai
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/dto"
+	"github.com/QuantumNous/new-api/logger"
+	relaycommon "github.com/QuantumNous/new-api/relay/common"
+	"github.com/QuantumNous/new-api/relay/helper"
+	"github.com/QuantumNous/new-api/service"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
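+// OaiResponsesToChatHandler converts a non-streaming Responses API response
+// into an OpenAI Chat Completions response and writes it to the client.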
+func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	if resp == nil || resp.Body == nil {
+		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+	}
+
+	defer service.CloseResponseBodyGracefully(resp)
+
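+	// Read at most maxResponseBodyBytes; reading one extra byte lets us detect
+	// an oversized body instead of silently truncating it.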
+	var responsesResp dto.OpenAIResponsesResponse
+	const maxResponseBodyBytes = 10 << 20 // 10MB
+	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1))
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+	}
+	if int64(len(body)) > maxResponseBodyBytes {
+		return nil, types.NewOpenAIError(fmt.Errorf("response body exceeds %d bytes", maxResponseBodyBytes), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if err := common.Unmarshal(body, &responsesResp); err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
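+	// Propagate an error payload from the upstream response as-is.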
+	if oaiError := responsesResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+		return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
+	}
+
+	chatId := helper.GetResponseID(c)
+	chatResp, usage, err := service.ResponsesResponseToChatCompletionsResponse(&responsesResp, chatId)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
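+	// Fall back to token estimation when the upstream response reports no usage.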
+	if usage == nil || usage.TotalTokens == 0 {
+		text := service.ExtractOutputTextFromResponses(&responsesResp)
+		usage = service.ResponseText2Usage(c, text, info.UpstreamModelName, info.GetEstimatePromptTokens())
+		chatResp.Usage = *usage
+	}
+
+	chatBody, err := common.Marshal(chatResp)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
+	}
+
+	service.IOCopyBytesGracefully(c, resp, chatBody)
+	return usage, nil
+}
+
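+// OaiResponsesToChatStreamHandler converts a Responses API SSE stream into
+// Chat Completions stream chunks, forwarding text deltas and collecting usage.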
+func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	if resp == nil || resp.Body == nil {
+		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+	}
+
+	defer service.CloseResponseBodyGracefully(resp)
+
+	responseId := helper.GetResponseID(c)
+	createAt := time.Now().Unix()
+	model := info.UpstreamModelName
+
+	var (
+		usage       = &dto.Usage{}
+		textBuilder strings.Builder
+		sentStart   bool
+		sentStop    bool
+		streamErr   *types.NewAPIError
+	)
+
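+	// Consume SSE events one by one; returning false from the callback stops
+	// the scan early.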
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		if streamErr != nil {
+			return false
+		}
+
+		var streamResp dto.ResponsesStreamResponse
+		if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
+			logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
+			return true
+		}
+
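+		// Map Responses API event types onto chat-completions stream chunks.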
+		switch streamResp.Type {
+		case "response.created":
+			if streamResp.Response != nil {
+				if streamResp.Response.Model != "" {
+					model = streamResp.Response.Model
+				}
+				if streamResp.Response.CreatedAt != 0 {
+					createAt = int64(streamResp.Response.CreatedAt)
+				}
+			}
+
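+		// Text deltas: emit the initial empty chunk lazily, then forward each
+		// delta, accumulating the text for the usage fallback below.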
+		case "response.output_text.delta":
+			if !sentStart {
+				if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+					return false
+				}
+				sentStart = true
+			}
+
+			if streamResp.Delta != "" {
+				textBuilder.WriteString(streamResp.Delta)
+				delta := streamResp.Delta
+				chunk := &dto.ChatCompletionsStreamResponse{
+					Id:      responseId,
+					Object:  "chat.completion.chunk",
+					Created: createAt,
+					Model:   model,
+					Choices: []dto.ChatCompletionsStreamResponseChoice{
+						{
+							Index: 0,
+							Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+								Content: &delta,
+							},
+						},
+					},
+				}
+				if err := helper.ObjectData(c, chunk); err != nil {
+					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+					return false
+				}
+			}
+
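+		// response.completed: capture the final model/usage and emit the stop chunk.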
+		case "response.completed":
+			if streamResp.Response != nil {
+				if streamResp.Response.Model != "" {
+					model = streamResp.Response.Model
+				}
+				if streamResp.Response.CreatedAt != 0 {
+					createAt = int64(streamResp.Response.CreatedAt)
+				}
+				if streamResp.Response.Usage != nil {
+					if streamResp.Response.Usage.InputTokens != 0 {
+						usage.PromptTokens = streamResp.Response.Usage.InputTokens
+						usage.InputTokens = streamResp.Response.Usage.InputTokens
+					}
+					if streamResp.Response.Usage.OutputTokens != 0 {
+						usage.CompletionTokens = streamResp.Response.Usage.OutputTokens
+						usage.OutputTokens = streamResp.Response.Usage.OutputTokens
+					}
+					if streamResp.Response.Usage.TotalTokens != 0 {
+						usage.TotalTokens = streamResp.Response.Usage.TotalTokens
+					} else {
+						usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+					}
+					if streamResp.Response.Usage.InputTokensDetails != nil {
+						usage.PromptTokensDetails.CachedTokens = streamResp.Response.Usage.InputTokensDetails.CachedTokens
+						usage.PromptTokensDetails.ImageTokens = streamResp.Response.Usage.InputTokensDetails.ImageTokens
+						usage.PromptTokensDetails.AudioTokens = streamResp.Response.Usage.InputTokensDetails.AudioTokens
+					}
+					if streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
+						usage.CompletionTokenDetails.ReasoningTokens = streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens
+					}
+				}
+			}
+
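+			// Guarantee well-formed output: send the start and stop chunks even
+			// if no text deltas were received.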
+			if !sentStart {
+				if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+					return false
+				}
+				sentStart = true
+			}
+			if !sentStop {
+				stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
+				if err := helper.ObjectData(c, stop); err != nil {
+					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+					return false
+				}
+				sentStop = true
+			}
+
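+		// Upstream failure events terminate the stream with an error.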
+		case "response.error", "response.failed":
+			if streamResp.Response != nil {
+				if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
+					streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
+					return false
+				}
+			}
+			streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
+			return false
+
+		case "response.output_item.added", "response.output_item.done":
+			// Structural events with no chat-completions equivalent; ignored.
+
+		default:
+			// Unrecognized event types are ignored.
+		}
+
+		return true
+	})
+
+	if streamErr != nil {
+		return nil, streamErr
+	}
+
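+	// If the stream never reported usage, estimate it from the accumulated text.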
+	if usage.TotalTokens == 0 {
+		usage = service.ResponseText2Usage(c, textBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
+	}
+
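+	// Ensure the client always receives valid start and stop chunks, even when
+	// the stream ended without emitting any events.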
+	if !sentStart {
+		if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+		}
+	}
+	if !sentStop {
+		stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
+		if err := helper.ObjectData(c, stop); err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+		}
+	}
+	if info.ShouldIncludeUsage && usage != nil {
+		if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+		}
+	}
+
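+	// Signal end of stream to the client.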
+	helper.Done(c)
+	return usage, nil
+}