Explorar o código

perf: validate the request first before send to OpenAI's server

JustSong %!s(int64=2) %!d(string=hai) anos
pai
achega
f6eb4e5628
Modificáronse 1 ficheiros con 20 adicións e 0 borrados
  1. 20 0
      controller/relay-text.go

+ 20 - 0
controller/relay-text.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
@@ -29,6 +30,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if relayMode == RelayModeModeration && textRequest.Model == "" {
 		textRequest.Model = "text-moderation-latest"
 	}
+	// request validation
+	if textRequest.Model == "" {
+		return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
+	}
+	switch relayMode {
+	case RelayModeCompletions:
+		if textRequest.Prompt == "" {
+			return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
+		}
+	case RelayModeChatCompletions:
+		if len(textRequest.Messages) == 0 {
+			return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest)
+		}
+	case RelayModeEmbeddings:
+	case RelayModeModeration:
+		if textRequest.Input == "" {
+			return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest)
+		}
+	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if c.GetString("base_url") != "" {