|
|
@@ -19,6 +19,19 @@ type Message struct {
|
|
|
Name *string `json:"name,omitempty"`
|
|
|
}
|
|
|
|
|
|
+// https://platform.openai.com/docs/api-reference/chat
|
|
|
+
|
|
|
+type GeneralOpenAIRequest struct {
|
|
|
+ Model string `json:"model"`
|
|
|
+ Messages []Message `json:"messages"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ Stream bool `json:"stream"`
|
|
|
+ MaxTokens int `json:"max_tokens"`
|
|
|
+ Temperature float64 `json:"temperature"`
|
|
|
+ TopP float64 `json:"top_p"`
|
|
|
+ N int `json:"n"`
|
|
|
+}
|
|
|
+
|
|
|
type ChatRequest struct {
|
|
|
Model string `json:"model"`
|
|
|
Messages []Message `json:"messages"`
|
|
|
@@ -101,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|
|
channelType := c.GetInt("channel")
|
|
|
tokenId := c.GetInt("token_id")
|
|
|
consumeQuota := c.GetBool("consume_quota")
|
|
|
- var textRequest TextRequest
|
|
|
- if consumeQuota || channelType == common.ChannelTypeAzure {
|
|
|
+ var textRequest GeneralOpenAIRequest
|
|
|
+ if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
|
|
requestBody, err := io.ReadAll(c.Request.Body)
|
|
|
if err != nil {
|
|
|
return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
|
|
|
@@ -141,6 +154,9 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|
|
model_ = strings.TrimSuffix(model_, "-0301")
|
|
|
model_ = strings.TrimSuffix(model_, "-0314")
|
|
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
|
|
+ } else if channelType == common.ChannelTypePaLM {
|
|
|
+ err := relayPaLM(textRequest, c)
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
|