@@ -19,6 +19,13 @@ type Message struct {
 	Name *string `json:"name,omitempty"`
 }
 
+const (
+	RelayModeUnknown = iota
+	RelayModeChatCompletions
+	RelayModeCompletions
+	RelayModeEmbeddings
+)
+
 // https://platform.openai.com/docs/api-reference/chat
 
 type GeneralOpenAIRequest struct {
@@ -69,7 +76,7 @@ type TextResponse struct {
 	Error OpenAIError `json:"error"`
 }
 
-type StreamResponse struct {
+type ChatCompletionsStreamResponse struct {
 	Choices []struct {
 		Delta struct {
 			Content string `json:"content"`
@@ -78,8 +85,23 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
+type CompletionsStreamResponse struct {
+	Choices []struct {
+		Text         string `json:"text"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+}
+
 func Relay(c *gin.Context) {
-	err := relayHelper(c)
+	relayMode := RelayModeUnknown
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
+		relayMode = RelayModeChatCompletions
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
+		relayMode = RelayModeCompletions
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
+		relayMode = RelayModeEmbeddings
+	}
+	err := relayHelper(c, relayMode)
 	if err != nil {
 		if err.StatusCode == http.StatusTooManyRequests {
 			err.OpenAIError.Message = "负载已满,请稍后再试,或升级账户以提升服务质量。"
@@ -110,7 +132,7 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 	}
 }
 
-func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
+func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
@@ -148,8 +170,13 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		err := relayPaLM(textRequest, c)
 		return err
 	}
-
-	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
+	var promptTokens int
+	switch relayMode {
+	case RelayModeChatCompletions:
+		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+	case RelayModeCompletions:
+		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
+	}
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -245,14 +272,27 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 				dataChan <- data
 				data = data[6:]
 				if !strings.HasPrefix(data, "[DONE]") {
-					var streamResponse StreamResponse
-					err = json.Unmarshal([]byte(data), &streamResponse)
-					if err != nil {
-						common.SysError("Error unmarshalling stream response: " + err.Error())
-						return
-					}
-					for _, choice := range streamResponse.Choices {
-						streamResponseText += choice.Delta.Content
+					switch relayMode {
+					case RelayModeChatCompletions:
+						var streamResponse ChatCompletionsStreamResponse
+						err = json.Unmarshal([]byte(data), &streamResponse)
+						if err != nil {
+							common.SysError("Error unmarshalling stream response: " + err.Error())
+							return
+						}
+						for _, choice := range streamResponse.Choices {
+							streamResponseText += choice.Delta.Content
+						}
+					case RelayModeCompletions:
+						var streamResponse CompletionsStreamResponse
+						err = json.Unmarshal([]byte(data), &streamResponse)
+						if err != nil {
+							common.SysError("Error unmarshalling stream response: " + err.Error())
+							return
+						}
+						for _, choice := range streamResponse.Choices {
+							streamResponseText += choice.Text
+						}
 					}
 				}
 			}
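
For reference, a minimal standalone sketch of the path-based dispatch the updated Relay handler performs before calling relayHelper. The relayModeFromPath helper and the main function below are hypothetical, added only for illustration; they are not part of this diff.

package main

import (
	"fmt"
	"strings"
)

// Relay mode constants, mirroring the ones added in this diff.
const (
	RelayModeUnknown = iota
	RelayModeChatCompletions
	RelayModeCompletions
	RelayModeEmbeddings
)

// relayModeFromPath (hypothetical helper) reproduces the prefix checks that
// Relay applies to c.Request.URL.Path when deciding how to count prompt
// tokens and how to parse streamed responses.
func relayModeFromPath(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return RelayModeChatCompletions
	case strings.HasPrefix(path, "/v1/completions"):
		return RelayModeCompletions
	case strings.HasPrefix(path, "/v1/embeddings"):
		return RelayModeEmbeddings
	default:
		return RelayModeUnknown
	}
}

func main() {
	paths := []string{
		"/v1/chat/completions",
		"/v1/completions",
		"/v1/embeddings",
		"/v1/images/generations", // unmatched prefix falls through to RelayModeUnknown
	}
	for _, p := range paths {
		fmt.Printf("%-24s -> %d\n", p, relayModeFromPath(p))
	}
}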