Browse Source

feat: support /v1/completions (close #115)

JustSong 2 years ago
parent
commit
4b6adaec0b
4 changed files with 100 additions and 15 deletions
  1. common/model-ratio.go (+1 -1)
  2. controller/model.go (+45 -0)
  3. controller/relay.go (+53 -13)
  4. router/relay-router.go (+1 -1)

+ 1 - 1
common/model-ratio.go

@@ -10,7 +10,7 @@ var ModelRatio = map[string]float64{
 	"gpt-4-0314":              15,
 	"gpt-4-32k":               30,
 	"gpt-4-32k-0314":          30,
-	"gpt-3.5-turbo":           1,
+	"gpt-3.5-turbo":           1, // $0.002 / 1K tokens
 	"gpt-3.5-turbo-0301":      1,
 	"text-ada-001":            0.2,
 	"text-babbage-001":        0.25,

+ 45 - 0
controller/model.go

@@ -116,6 +116,51 @@ func init() {
 			Root:       "text-embedding-ada-002",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-davinci-003",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-davinci-003",
+			Parent:     nil,
+		},
+		{
+			Id:         "text-davinci-002",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-davinci-002",
+			Parent:     nil,
+		},
+		{
+			Id:         "text-curie-001",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-curie-001",
+			Parent:     nil,
+		},
+		{
+			Id:         "text-babbage-001",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-babbage-001",
+			Parent:     nil,
+		},
+		{
+			Id:         "text-ada-001",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-ada-001",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 53 - 13
controller/relay.go

@@ -19,6 +19,13 @@ type Message struct {
 	Name    *string `json:"name,omitempty"`
 }
 
+const (
+	RelayModeUnknown = iota
+	RelayModeChatCompletions
+	RelayModeCompletions
+	RelayModeEmbeddings
+)
+
 // https://platform.openai.com/docs/api-reference/chat
 
 type GeneralOpenAIRequest struct {
@@ -69,7 +76,7 @@ type TextResponse struct {
 	Error OpenAIError `json:"error"`
 }
 
-type StreamResponse struct {
+type ChatCompletionsStreamResponse struct {
 	Choices []struct {
 		Delta struct {
 			Content string `json:"content"`
@@ -78,8 +85,23 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
+type CompletionsStreamResponse struct {
+	Choices []struct {
+		Text         string `json:"text"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
 func Relay(c *gin.Context) {
-	err := relayHelper(c)
+	relayMode := RelayModeUnknown
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
+		relayMode = RelayModeChatCompletions
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
+		relayMode = RelayModeCompletions
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
+		relayMode = RelayModeEmbeddings
+	}
+	err := relayHelper(c, relayMode)
 	if err != nil {
 		if err.StatusCode == http.StatusTooManyRequests {
 			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
@@ -110,7 +132,7 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 	}
 }
 
-func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
+func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
@@ -148,8 +170,13 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		err := relayPaLM(textRequest, c)
 		return err
 	}
-
-	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
+	var promptTokens int
+	switch relayMode {
+	case RelayModeChatCompletions:
+		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+	case RelayModeCompletions:
+		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
+	}
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -245,14 +272,27 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 				dataChan <- data
 				data = data[6:]
 				if !strings.HasPrefix(data, "[DONE]") {
-					var streamResponse StreamResponse
-					err = json.Unmarshal([]byte(data), &streamResponse)
-					if err != nil {
-						common.SysError("Error unmarshalling stream response: " + err.Error())
-						return
-					}
-					for _, choice := range streamResponse.Choices {
-						streamResponseText += choice.Delta.Content
+					switch relayMode {
+					case RelayModeChatCompletions:
+						var streamResponse ChatCompletionsStreamResponse
+						err = json.Unmarshal([]byte(data), &streamResponse)
+						if err != nil {
+							common.SysError("Error unmarshalling stream response: " + err.Error())
+							return
+						}
+						for _, choice := range streamResponse.Choices {
+							streamResponseText += choice.Delta.Content
+						}
+					case RelayModeCompletions:
+						var streamResponse CompletionsStreamResponse
+						err = json.Unmarshal([]byte(data), &streamResponse)
+						if err != nil {
+							common.SysError("Error unmarshalling stream response: " + err.Error())
+							return
+						}
+						for _, choice := range streamResponse.Choices {
+							streamResponseText += choice.Text
+						}
 					}
 				}
 			}

+ 1 - 1
router/relay-router.go

@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
 	relayV1Router := router.Group("/v1")
 	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		relayV1Router.POST("/completions", controller.RelayNotImplemented)
+		relayV1Router.POST("/completions", controller.Relay)
 		relayV1Router.POST("/chat/completions", controller.Relay)
 		relayV1Router.POST("/edits", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)