| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- package relay
import (
	"bytes"
	"errors"
	"net/http"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/dto"
	"github.com/QuantumNous/new-api/relay/channel"
	openaichannel "github.com/QuantumNous/new-api/relay/channel/openai"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	relayconstant "github.com/QuantumNous/new-api/relay/constant"
	"github.com/QuantumNous/new-api/service"
	"github.com/QuantumNous/new-api/types"
	"github.com/gin-gonic/gin"
)
- func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) {
- if info == nil || request == nil {
- return
- }
- if info.ChannelSetting.SystemPrompt == "" {
- return
- }
- systemRole := request.GetSystemRoleName()
- containSystemPrompt := false
- for _, message := range request.Messages {
- if message.Role == systemRole {
- containSystemPrompt = true
- break
- }
- }
- if !containSystemPrompt {
- systemMessage := dto.Message{
- Role: systemRole,
- Content: info.ChannelSetting.SystemPrompt,
- }
- request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
- return
- }
- if !info.ChannelSetting.SystemPromptOverride {
- return
- }
- common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
- for i, message := range request.Messages {
- if message.Role != systemRole {
- continue
- }
- if message.IsStringContent() {
- request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
- return
- }
- contents := message.ParseContent()
- contents = append([]dto.MediaContent{
- {
- Type: dto.ContentTypeText,
- Text: info.ChannelSetting.SystemPrompt,
- },
- }, contents...)
- request.Messages[i].Content = contents
- return
- }
- }
- func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
- overrideCtx := relaycommon.BuildParamOverrideContext(info)
- chatJSON, err := common.Marshal(request)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- if len(info.ParamOverride) > 0 {
- chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- }
- var overriddenChatReq dto.GeneralOpenAIRequest
- if err := common.Unmarshal(chatJSON, &overriddenChatReq); err != nil {
- return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
- }
- responsesReq, err := service.ChatCompletionsRequestToResponsesRequest(&overriddenChatReq)
- if err != nil {
- return nil, types.NewErrorWithStatusCode(err, types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
- }
- info.AppendRequestConversion(types.RelayFormatOpenAIResponses)
- savedRelayMode := info.RelayMode
- savedRequestURLPath := info.RequestURLPath
- defer func() {
- info.RelayMode = savedRelayMode
- info.RequestURLPath = savedRequestURLPath
- }()
- info.RelayMode = relayconstant.RelayModeResponses
- info.RequestURLPath = "/v1/responses"
- convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *responsesReq)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
- jsonData, err := common.Marshal(convertedRequest)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
- }
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
- }
- if resp == nil {
- return nil, types.NewOpenAIError(nil, types.ErrorCodeBadResponse, http.StatusInternalServerError)
- }
- statusCodeMappingStr := c.GetString("status_code_mapping")
- httpResp = resp.(*http.Response)
- info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
- service.ResetStatusCode(newApiErr, statusCodeMappingStr)
- return nil, newApiErr
- }
- if info.IsStream {
- usage, newApiErr := openaichannel.OaiResponsesToChatStreamHandler(c, info, httpResp)
- if newApiErr != nil {
- service.ResetStatusCode(newApiErr, statusCodeMappingStr)
- return nil, newApiErr
- }
- return usage, nil
- }
- usage, newApiErr := openaichannel.OaiResponsesToChatHandler(c, info, httpResp)
- if newApiErr != nil {
- service.ResetStatusCode(newApiErr, statusCodeMappingStr)
- return nil, newApiErr
- }
- return usage, nil
- }
|