Просмотр исходного кода

feat: support Google PaLM2 (close #105)

JustSong 2 лет назад
Родитель
Commit
8f721d67a5
6 измененных файлов с 221 добавлено и 32 удалено
  1. 1 0
      README.md
  2. 1 0
      common/model-ratio.go
  3. 9 0
      controller/model.go
  4. 179 32
      controller/relay-palm.go
  5. 30 0
      controller/relay-text.go
  6. 1 0
      web/src/constants/channel.constants.js

+ 1 - 0
README.md

@@ -62,6 +62,7 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
    + [x] OpenAI 官方通道(支持配置镜像)
    + [x] **Azure OpenAI API**
    + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
    + [x] [OpenAI-SB](https://openai-sb.com)

+ 1 - 0
common/model-ratio.go

@@ -41,6 +41,7 @@ var ModelRatio = map[string]float64{
 	"claude-2":                30,
 	"ERNIE-Bot":               1,    // 0.012元/千tokens
 	"ERNIE-Bot-turbo":         0.67, // 0.008元/千tokens
+	"PaLM-2":                  1,
 }
 
 func ModelRatio2JSONString() string {

+ 9 - 0
controller/model.go

@@ -306,6 +306,15 @@ func init() {
 			Root:       "ERNIE-Bot-turbo",
 			Parent:     nil,
 		},
+		{
+			Id:         "PaLM-2",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "google",
+			Permission: permission,
+			Root:       "PaLM-2",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 179 - 32
controller/relay-palm.go

@@ -1,10 +1,17 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
 )
 
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+
// PaLMChatMessage is a single turn of a PaLM2 chat conversation.
// Author is "0" for the user and "1" for everything else (see the role
// mapping in requestOpenAI2PaLM).
type PaLMChatMessage struct {
	Author  string `json:"author"`
	Content string `json:"content"`
@@ -15,45 +22,185 @@ type PaLMFilter struct {
 	Message string `json:"message"`
 }
 
-// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// PaLMPrompt wraps the message history for a generateMessage request.
type PaLMPrompt struct {
	Messages []PaLMChatMessage `json:"messages"`
}
+
// PaLMChatRequest is the request body of the PaLM2 generateMessage endpoint.
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
type PaLMChatRequest struct {
	Prompt         PaLMPrompt `json:"prompt"`
	Temperature    float64    `json:"temperature,omitempty"`
	CandidateCount int        `json:"candidateCount,omitempty"`
	TopP           float64    `json:"topP,omitempty"`
	TopK           int        `json:"topK,omitempty"`
}
+
// PaLMError mirrors the standard Google API error object embedded in a
// failed generateMessage response. A zero Code means "no error".
type PaLMError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}
 
// PaLMChatResponse is the response body of the generateMessage endpoint.
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
type PaLMChatResponse struct {
	Candidates []PaLMChatMessage `json:"candidates"`
	Messages   []Message         `json:"messages"`
	Filters    []PaLMFilter      `json:"filters"`
	Error      PaLMError         `json:"error"`
}
 
// requestOpenAI2PaLM converts an OpenAI-style chat request into the PaLM2
// generateMessage request shape. OpenAI roles are mapped onto PaLM authors:
// "user" becomes author "0"; every other role (system/assistant) becomes "1".
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
	palmRequest := PaLMChatRequest{
		Prompt: PaLMPrompt{
			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
		},
		Temperature:    textRequest.Temperature,
		CandidateCount: textRequest.N,
		TopP:           textRequest.TopP,
		// NOTE(review): MaxTokens is an output-length limit, while PaLM's topK
		// controls sampling breadth — these are not equivalent. Kept as-is to
		// preserve behavior; confirm whether topK should be sent at all.
		TopK: textRequest.MaxTokens,
	}
	for _, message := range textRequest.Messages {
		palmMessage := PaLMChatMessage{
			Content: message.Content,
		}
		if message.Role == "user" {
			palmMessage.Author = "0"
		} else {
			palmMessage.Author = "1"
		}
		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
	}
	return &palmRequest
}
+
+func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
+	fullTextResponse := OpenAITextResponse{
+		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+	}
+	for i, candidate := range response.Candidates {
+		choice := OpenAITextResponseChoice{
+			Index: i,
+			Message: Message{
+				Role:    "assistant",
+				Content: candidate.Content,
+			},
+			FinishReason: "stop",
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+	return &fullTextResponse
+}
+
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	if len(palmResponse.Candidates) > 0 {
+		choice.Delta.Content = palmResponse.Candidates[0].Content
+	}
+	choice.FinishReason = "stop"
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "palm2"
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createdTime := common.GetTimestamp()
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		responseBody, err := io.ReadAll(resp.Body)
+		if err != nil {
+			common.SysError("error reading stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		err = resp.Body.Close()
+		if err != nil {
+			common.SysError("error closing stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		var palmResponse PaLMChatResponse
+		err = json.Unmarshal(responseBody, &palmResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+		fullTextResponse.Id = responseId
+		fullTextResponse.Created = createdTime
+		jsonResponse, err := json.Marshal(fullTextResponse)
+		if err != nil {
+			common.SysError("error marshalling stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		dataChan <- string(jsonResponse)
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			c.Render(-1, common.CustomEvent{Data: "data: " + data})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+	return nil, responseText
+}
+
+func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var palmResponse PaLMChatResponse
+	err = json.Unmarshal(responseBody, &palmResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: palmResponse.Error.Message,
+				Type:    palmResponse.Error.Status,
+				Param:   "",
+				Code:    palmResponse.Error.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
+	completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
 }

+ 30 - 0
controller/relay-text.go

@@ -82,6 +82,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeClaude
 	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
 		apiType = APITypeBaidu
+	} else if strings.HasPrefix(textRequest.Model, "PaLM") {
+		apiType = APITypePaLM
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -127,6 +129,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
+	case APITypePaLM:
+		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		fullRequestURL += "?key=" + apiKey
 	}
 	var promptTokens int
 	var completionTokens int
@@ -186,6 +193,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypePaLM:
+		palmRequest := requestOpenAI2PaLM(textRequest)
+		jsonStr, err := json.Marshal(palmRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
@@ -323,6 +337,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			textResponse.Usage = *usage
 			return nil
 		}
+	case APITypePaLM:
+		if textRequest.Stream { // PaLM2 API does not support stream
+			err, responseText := palmStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			streamResponseText = responseText
+			return nil
+		} else {
+			err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage = *usage
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 1 - 0
web/src/constants/channel.constants.js

@@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [
   { key: 14, text: 'Anthropic', value: 14, color: 'black' },
   { key: 8, text: '自定义', value: 8, color: 'pink' },
   { key: 3, text: 'Azure', value: 3, color: 'olive' },
+  { key: 11, text: 'PaLM', value: 11, color: 'orange' },
   { key: 15, text: 'Baidu', value: 15, color: 'blue' },
   { key: 2, text: 'API2D', value: 2, color: 'blue' },
   { key: 4, text: 'CloseAI', value: 4, color: 'teal' },