Преглед изворног кода

feat: support aiproxy's library

JustSong пре 2 године
родитељ
комит
04acdb1ccb

+ 23 - 21
common/constants.go

@@ -154,27 +154,28 @@ const (
 )
 
 const (
-	ChannelTypeUnknown    = 0
-	ChannelTypeOpenAI     = 1
-	ChannelTypeAPI2D      = 2
-	ChannelTypeAzure      = 3
-	ChannelTypeCloseAI    = 4
-	ChannelTypeOpenAISB   = 5
-	ChannelTypeOpenAIMax  = 6
-	ChannelTypeOhMyGPT    = 7
-	ChannelTypeCustom     = 8
-	ChannelTypeAILS       = 9
-	ChannelTypeAIProxy    = 10
-	ChannelTypePaLM       = 11
-	ChannelTypeAPI2GPT    = 12
-	ChannelTypeAIGC2D     = 13
-	ChannelTypeAnthropic  = 14
-	ChannelTypeBaidu      = 15
-	ChannelTypeZhipu      = 16
-	ChannelTypeAli        = 17
-	ChannelTypeXunfei     = 18
-	ChannelType360        = 19
-	ChannelTypeOpenRouter = 20
+	ChannelTypeUnknown        = 0
+	ChannelTypeOpenAI         = 1
+	ChannelTypeAPI2D          = 2
+	ChannelTypeAzure          = 3
+	ChannelTypeCloseAI        = 4
+	ChannelTypeOpenAISB       = 5
+	ChannelTypeOpenAIMax      = 6
+	ChannelTypeOhMyGPT        = 7
+	ChannelTypeCustom         = 8
+	ChannelTypeAILS           = 9
+	ChannelTypeAIProxy        = 10
+	ChannelTypePaLM           = 11
+	ChannelTypeAPI2GPT        = 12
+	ChannelTypeAIGC2D         = 13
+	ChannelTypeAnthropic      = 14
+	ChannelTypeBaidu          = 15
+	ChannelTypeZhipu          = 16
+	ChannelTypeAli            = 17
+	ChannelTypeXunfei         = 18
+	ChannelType360            = 19
+	ChannelTypeOpenRouter     = 20
+	ChannelTypeAIProxyLibrary = 21
 )
 
 var ChannelBaseURLs = []string{
@@ -199,4 +200,5 @@ var ChannelBaseURLs = []string{
 	"",                               // 18
 	"https://ai.360.cn",              // 19
 	"https://openrouter.ai/api",      // 20
+	"https://api.aiproxy.io",         // 21
 }

+ 220 - 0
controller/relay-aiproxy.go

@@ -0,0 +1,220 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strconv"
+	"strings"
+)
+
+// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
+
+// AIProxyLibraryRequest is the request body sent to AI Proxy's
+// knowledge-library Q&A endpoint (/api/library/ask).
+type AIProxyLibraryRequest struct {
+	Model     string `json:"model"`
+	Query     string `json:"query"`
+	LibraryId string `json:"libraryId"`
+	Stream    bool   `json:"stream"`
+}
+
+// AIProxyLibraryError carries the upstream error fields; an ErrCode of 0
+// is treated as success (see aiProxyLibraryHandler).
+type AIProxyLibraryError struct {
+	ErrCode int    `json:"errCode"`
+	Message string `json:"message"`
+}
+
+// AIProxyLibraryDocument is a single reference document attached to an answer.
+type AIProxyLibraryDocument struct {
+	Title string `json:"title"`
+	URL   string `json:"url"`
+}
+
+// AIProxyLibraryResponse is the non-streaming response. It embeds
+// AIProxyLibraryError so the error fields decode from the same JSON object.
+type AIProxyLibraryResponse struct {
+	Success   bool                     `json:"success"`
+	Answer    string                   `json:"answer"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+	AIProxyLibraryError
+}
+
+// AIProxyLibraryStreamResponse is one SSE chunk of a streaming answer.
+type AIProxyLibraryStreamResponse struct {
+	Content   string                   `json:"content"`
+	Finish    bool                     `json:"finish"`
+	Model     string                   `json:"model"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+}
+
+// requestOpenAI2AIProxyLibrary converts an OpenAI-style chat request into an
+// AI Proxy library request. Only the content of the LAST message is forwarded
+// as the query; earlier conversation turns are dropped. LibraryId is left
+// empty here — the caller fills it in from the channel configuration.
+func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
+	query := ""
+	if len(request.Messages) != 0 {
+		query = request.Messages[len(request.Messages)-1].Content
+	}
+	return &AIProxyLibraryRequest{
+		Model:  request.Model,
+		Stream: request.Stream,
+		Query:  query,
+	}
+}
+
+// aiProxyDocuments2Markdown renders the reference documents as a numbered
+// Markdown link list (preceded by a blank line and a Chinese heading meaning
+// "reference documents"), or an empty string when there are none.
+// NOTE(review): string += in a loop is quadratic; a strings.Builder would be
+// preferable if document lists can be large.
+func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
+	if len(documents) == 0 {
+		return ""
+	}
+	content := "\n\n参考文档:\n"
+	for i, document := range documents {
+		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
+	}
+	return content
+}
+
+// responseAIProxyLibrary2OpenAI converts a non-streaming library response into
+// an OpenAI chat-completion response with a single "assistant" choice whose
+// content is the answer followed by the Markdown document list. The response's
+// Usage field is left at its zero value — no token accounting happens here.
+func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
+	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: content,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+	}
+	return &fullTextResponse
+}
+
+// documentsAIProxyLibrary builds the FINAL stream chunk: a delta containing
+// the Markdown document list and finish_reason "stop". It is emitted once,
+// after the upstream stream ends (see aiProxyLibraryStreamHandler). Model is
+// deliberately empty here.
+func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
+	choice.FinishReason = &stopFinishReason
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+// streamResponseAIProxyLibrary2OpenAI converts one upstream SSE chunk into an
+// OpenAI chat.completion.chunk. No finish_reason is set on intermediate
+// chunks; the closing chunk comes from documentsAIProxyLibrary. Each chunk
+// gets a fresh UUID rather than sharing one id across the stream.
+func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = response.Content
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   response.Model,
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+// aiProxyLibraryStreamHandler relays an upstream SSE stream to the client in
+// OpenAI stream format. A goroutine scans the body line by line, forwarding
+// the payload of each "data:" line over dataChan; the render loop converts
+// each payload to an OpenAI chunk, then, once the stream ends, emits one
+// final chunk with the collected reference documents followed by [DONE].
+// NOTE(review): the returned usage is never populated (always the zero
+// value), so no token accounting happens for streamed library answers.
+// NOTE(review): scanner.Err() is not checked, so upstream read errors are
+// silently treated as end-of-stream.
+func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	// Custom split function: yield one token per newline-terminated line,
+	// and flush any trailing unterminated data at EOF.
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			// Strip the "data:" SSE prefix and forward the raw JSON payload.
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(c)
+	// Remember the most recent non-empty document list so it can be emitted
+	// as the closing chunk after the upstream stream finishes.
+	var documents []AIProxyLibraryDocument
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
+			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
+			if err != nil {
+				// Malformed chunk: log and keep streaming.
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			if len(AIProxyLibraryResponse.Documents) != 0 {
+				documents = AIProxyLibraryResponse.Documents
+			}
+			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			// Upstream done: send the documents chunk, then the OpenAI
+			// stream terminator, and stop the render loop.
+			response := documentsAIProxyLibrary(documents)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+// aiProxyLibraryHandler relays a non-streaming library response. It reads and
+// closes the upstream body, decodes it, maps an upstream error (ErrCode != 0)
+// to an OpenAI-style error with the upstream HTTP status, and otherwise
+// re-serializes the answer in OpenAI format to the client. The returned Usage
+// is the response's zero-valued Usage field — nothing sets token counts here.
+// NOTE(review): the error from c.Writer.Write is assigned but never checked.
+func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var AIProxyLibraryResponse AIProxyLibraryResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if AIProxyLibraryResponse.ErrCode != 0 {
+		// Upstream reported failure: surface its code/message to the caller.
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: AIProxyLibraryResponse.Message,
+				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
+				Code:    AIProxyLibraryResponse.ErrCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}

+ 35 - 0
controller/relay-text.go

@@ -22,6 +22,7 @@ const (
 	APITypeZhipu
 	APITypeAli
 	APITypeXunfei
+	APITypeAIProxyLibrary
 )
 
 var httpClient *http.Client
@@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAli
 	case common.ChannelTypeXunfei:
 		apiType = APITypeXunfei
+	case common.ChannelTypeAIProxyLibrary:
+		apiType = APITypeAIProxyLibrary
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -171,6 +174,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 	case APITypeAli:
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+	case APITypeAIProxyLibrary:
+		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 	}
 	var promptTokens int
 	var completionTokens int
@@ -263,6 +268,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeAIProxyLibrary:
+		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
+		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
+		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	}
 
 	var req *http.Request
@@ -302,6 +315,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			if textRequest.Stream {
 				req.Header.Set("X-DashScope-SSE", "enable")
 			}
+		default:
+			req.Header.Set("Authorization", "Bearer "+apiKey)
 		}
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -516,6 +531,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		} else {
 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 		}
+	case APITypeAIProxyLibrary:
+		if isStream {
+			err, usage := aiProxyLibraryStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		} else {
+			err, usage := aiProxyLibraryHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 6 - 1
middleware/distributor.go

@@ -115,8 +115,13 @@ func Distribute() func(c *gin.Context) {
 		c.Set("model_mapping", channel.ModelMapping)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.BaseURL)
-		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
+		switch channel.Type {
+		case common.ChannelTypeAzure:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeXunfei:
+			c.Set("api_version", channel.Other)
+		case common.ChannelTypeAIProxyLibrary:
+			c.Set("library_id", channel.Other)
 		}
 		c.Next()
 	}

+ 1 - 0
web/src/constants/channel.constants.js

@@ -9,6 +9,7 @@ export const CHANNEL_OPTIONS = [
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
+  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
   { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },

+ 14 - 0
web/src/pages/Channel/EditChannel.js

@@ -295,6 +295,20 @@ const EditChannel = () => {
               </Form.Field>
             )
           }
+          {
+            inputs.type === 21 && (
+              <Form.Field>
+                <Form.Input
+                  label='知识库 ID'
+                  name='other'
+                  placeholder={'请输入知识库 ID,例如:123456'}
+                  onChange={handleInputChange}
+                  value={inputs.other}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Dropdown
               label='模型'