Просмотр исходного кода

feat: PaLM support is WIP (#105)

JustSong 2 года назад
Родитель
Commit
bcca0cc0bc
3 измененных файлов с 79 добавлено и 2 удалено
  1. 2 0
      common/constants.go
  2. 59 0
      controller/relay-palm.go
  3. 18 2
      controller/relay.go

+ 2 - 0
common/constants.go

@@ -129,6 +129,7 @@ const (
 	ChannelTypeCustom    = 8
 	ChannelTypeAILS      = 9
 	ChannelTypeAIProxy   = 10
+	ChannelTypePaLM      = 11
 )
 
 
 var ChannelBaseURLs = []string{
@@ -143,4 +144,5 @@ var ChannelBaseURLs = []string{
 	"",                            // 8
 	"https://api.caipacity.com",   // 9
 	"https://api.aiproxy.io",      // 10
+	"",                            // 11
 }

+ 59 - 0
controller/relay-palm.go

@@ -0,0 +1,59 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+)
+
+// PaLMChatMessage is a single conversation turn in the PaLM chat API.
+// Author is a speaker id string ("0" for the user, "1" for the model —
+// see the mapping in relayPaLM below).
+type PaLMChatMessage struct {
+	Author  string `json:"author"`
+	Content string `json:"content"`
+}
+
+// PaLMFilter is a content-filtering entry the PaLM API may return
+// alongside (or instead of) candidates: a reason code plus an optional
+// human-readable message.
+type PaLMFilter struct {
+	Reason  string `json:"reason"`
+	Message string `json:"message"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+
+// PaLMPrompt wraps the conversation history, mirroring the MessagePrompt
+// object of the PaLM generateMessage request body.
+type PaLMPrompt struct {
+	Messages []PaLMChatMessage `json:"messages"`
+}
+
+// PaLMChatRequest is the generateMessage request payload.
+// The prompt field carries PaLM-style messages, not OpenAI ones.
+type PaLMChatRequest struct {
+	Prompt         PaLMPrompt `json:"prompt"`
+	Temperature    float64    `json:"temperature"`
+	CandidateCount int        `json:"candidateCount"`
+	TopP           float64    `json:"topP"`
+	TopK           int        `json:"topK"`
+}
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+type PaLMChatResponse struct {
+	Candidates []PaLMChatMessage `json:"candidates"`
+	Messages   []PaLMChatMessage `json:"messages"`
+	Filters    []PaLMFilter      `json:"filters"`
+}
+
+// relayPaLM translates an OpenAI chat request into a PaLM generateMessage
+// request. Forwarding the request to PaLM and converting the response are
+// still TODO (WIP); for now it only builds and prints the request and
+// returns nil.
+func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
+	// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
+	messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
+	for _, message := range openAIRequest.Messages {
+		// PaLM identifies speakers by author id: "0" is the user, "1" the
+		// model ("assistant"/"system" roles are both mapped to the model).
+		author := "1"
+		if message.Role == "user" {
+			author = "0"
+		}
+		messages = append(messages, PaLMChatMessage{
+			Author:  author,
+			Content: message.Content,
+		})
+	}
+	request := PaLMChatRequest{
+		// Bug fix: the converted messages were previously discarded
+		// (Prompt was set to nil); wire them into the request.
+		Prompt:         PaLMPrompt{Messages: messages},
+		Temperature:    openAIRequest.Temperature,
+		CandidateCount: openAIRequest.N,
+		TopP:           openAIRequest.TopP,
+		// NOTE(review): topK is a sampling parameter, not an output-length
+		// cap — mapping MaxTokens onto it looks wrong; confirm before
+		// the actual forwarding is implemented.
+		TopK: openAIRequest.MaxTokens,
+	}
+	// TODO: forward request to PaLM & convert response
+	fmt.Print(request)
+	return nil
+}

+ 18 - 2
controller/relay.go

@@ -19,6 +19,19 @@ type Message struct {
 	Name    *string `json:"name,omitempty"`
 }
 
+// https://platform.openai.com/docs/api-reference/chat
+
+// GeneralOpenAIRequest holds the fields common to OpenAI chat and text
+// completion request bodies, so a single struct can decode either payload
+// (Messages for chat-style requests, Prompt for completion-style ones).
+type GeneralOpenAIRequest struct {
+	Model       string    `json:"model"`
+	Messages    []Message `json:"messages"`
+	Prompt      string    `json:"prompt"`
+	Stream      bool      `json:"stream"`
+	MaxTokens   int       `json:"max_tokens"`
+	Temperature float64   `json:"temperature"`
+	TopP        float64   `json:"top_p"`
+	N           int       `json:"n"`
+}
+
 type ChatRequest struct {
 	Model     string    `json:"model"`
 	Messages  []Message `json:"messages"`
@@ -101,8 +114,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
-	var textRequest TextRequest
-	if consumeQuota || channelType == common.ChannelTypeAzure {
+	var textRequest GeneralOpenAIRequest
+	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		requestBody, err := io.ReadAll(c.Request.Body)
 		if err != nil {
 			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
@@ -141,6 +154,9 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0301")
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
+	} else if channelType == common.ChannelTypePaLM {
+		err := relayPaLM(textRequest, c)
+		return err
 	}
 
 	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)