Browse Source

feat: support SiliconFlow (close #437, close #403)

CalciumIon 1 year ago
parent
commit
7c4d9d225e

+ 2 - 0
common/constants.go

@@ -213,6 +213,7 @@ const (
 	ChannelTypeDify           = 37
 	ChannelTypeJina           = 38
 	ChannelCloudflare         = 39
+	ChannelTypeSiliconFlow    = 40
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 
@@ -259,4 +260,5 @@ var ChannelBaseURLs = []string{
 	"",                                          //37
 	"https://api.jina.ai",                       //38
 	"https://api.cloudflare.com",                //39
+	"https://api.siliconflow.cn",                //40
 }

+ 8 - 5
dto/rerank.go

@@ -1,14 +1,17 @@
 package dto
 
 type RerankRequest struct {
-	Documents []any  `json:"documents"`
-	Query     string `json:"query"`
-	Model     string `json:"model"`
-	TopN      int    `json:"top_n"`
+	Documents       []any  `json:"documents"`
+	Query           string `json:"query"`
+	Model           string `json:"model"`
+	TopN            int    `json:"top_n"`
+	ReturnDocuments bool   `json:"return_documents,omitempty"`
+	MaxChunkPerDoc  int    `json:"max_chunk_per_doc,omitempty"`
+	OverLapTokens   int    `json:"overlap_tokens,omitempty"`
 }
 
 type RerankResponseDocument struct {
-	Document       any     `json:"document"`
+	Document       any     `json:"document,omitempty"`
 	Index          int     `json:"index"`
 	RelevanceScore float64 `json:"relevance_score"`
 }

+ 80 - 0
relay/channel/siliconflow/adaptor.go

@@ -0,0 +1,80 @@
+package siliconflow
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+// ConvertAudioRequest is a placeholder: audio requests are not supported by
+// this adaptor yet, so it always returns a nil reader and a
+// "not implemented" error.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	// TODO: implement audio request conversion.
+	errNotImplemented := errors.New("not implemented")
+	return nil, errNotImplemented
+}
+
+// ConvertImageRequest is a placeholder: image requests are not supported by
+// this adaptor yet, so it always returns nil and a "not implemented" error.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	// TODO: implement image request conversion.
+	errNotImplemented := errors.New("not implemented")
+	return nil, errNotImplemented
+}
+
+// Init is a no-op: the adaptor keeps no per-request state to set up.
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {}
+
+// GetRequestURL returns the upstream endpoint for the relay mode carried by
+// info, built by appending the matching /v1 path to info.BaseUrl.
+//
+// Rerank, embeddings and chat-completions modes are supported; any other
+// mode yields an "invalid relay mode" error.
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	switch info.RelayMode {
+	case constant.RelayModeRerank:
+		return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+	case constant.RelayModeEmbeddings:
+		// The path must not carry trailing whitespace: a stray space would
+		// become part of the request path sent upstream.
+		return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+	default:
+		return "", errors.New("invalid relay mode")
+	}
+}
+
+// SetupRequestHeader applies the shared API request headers via
+// channel.SetupApiRequestHeader and then sets the Authorization header to a
+// Bearer token built from info.ApiKey. It always returns nil.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+
+	bearer := fmt.Sprintf("Bearer %s", info.ApiKey)
+	req.Header.Set("Authorization", bearer)
+
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	return request, nil
+}
+
+// DoRequest sends requestBody upstream by delegating to
+// channel.DoApiRequest with this adaptor, and returns its response and error
+// as-is.
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	resp, err := channel.DoApiRequest(a, c, info, requestBody)
+	return resp, err
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return request, nil
+}
+
+// DoResponse relays the upstream response to the client using the handler
+// that matches info.RelayMode, and returns the usage and error that handler
+// reports.
+//
+//   - Rerank: siliconflowRerankHandler.
+//   - Chat completions: openai.OaiStreamHandler when info.IsStream is set,
+//     otherwise openai.OpenaiHandler.
+//   - Embeddings: openai.OpenaiHandler.
+//
+// For any other mode nothing is written and both results are nil.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	switch info.RelayMode {
+	case constant.RelayModeRerank:
+		err, usage = siliconflowRerankHandler(c, resp)
+	case constant.RelayModeChatCompletions:
+		if info.IsStream {
+			err, usage = openai.OaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		}
+	case constant.RelayModeEmbeddings:
+		// GetRequestURL routes embeddings requests upstream, so their
+		// responses must be relayed as well; without this case the client
+		// would receive no body and no usage would be recorded.
+		// NOTE(review): assumes openai.OpenaiHandler accepts an
+		// OpenAI-style embeddings response body - confirm.
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+// GetModelList returns the package-level ModelList slice.
+func (a *Adaptor) GetModelList() []string {
+	models := ModelList
+	return models
+}
+
+// GetChannelName returns the package-level ChannelName.
+func (a *Adaptor) GetChannelName() string {
+	name := ChannelName
+	return name
+}

+ 51 - 0
relay/channel/siliconflow/constant.go

@@ -0,0 +1,51 @@
+package siliconflow
+
+// ModelList is the set of model identifiers this channel reports through
+// Adaptor.GetModelList. Entries that are commented out are deliberately
+// left unlisted.
+//
+// NOTE(review): some listed entries look like image or audio models (for
+// example "black-forest-labs/FLUX.1-schnell" and "iic/SenseVoiceSmall"),
+// while ConvertImageRequest and ConvertAudioRequest return
+// "not implemented" - confirm these should be advertised.
+var ModelList = []string{
+	"THUDM/glm-4-9b-chat",
+	//"stabilityai/stable-diffusion-xl-base-1.0",
+	//"TencentARC/PhotoMaker",
+	"InstantX/InstantID",
+	//"stabilityai/stable-diffusion-2-1",
+	//"stabilityai/sd-turbo",
+	//"stabilityai/sdxl-turbo",
+	"ByteDance/SDXL-Lightning",
+	"deepseek-ai/deepseek-llm-67b-chat",
+	"Qwen/Qwen1.5-14B-Chat",
+	"Qwen/Qwen1.5-7B-Chat",
+	"Qwen/Qwen1.5-110B-Chat",
+	"Qwen/Qwen1.5-32B-Chat",
+	"01-ai/Yi-1.5-6B-Chat",
+	"01-ai/Yi-1.5-9B-Chat-16K",
+	"01-ai/Yi-1.5-34B-Chat-16K",
+	"THUDM/chatglm3-6b",
+	"deepseek-ai/DeepSeek-V2-Chat",
+	"Qwen/Qwen2-72B-Instruct",
+	"Qwen/Qwen2-7B-Instruct",
+	"Qwen/Qwen2-57B-A14B-Instruct",
+	//"stabilityai/stable-diffusion-3-medium",
+	"deepseek-ai/DeepSeek-Coder-V2-Instruct",
+	"Qwen/Qwen2-1.5B-Instruct",
+	"internlm/internlm2_5-7b-chat",
+	"BAAI/bge-large-en-v1.5",
+	"BAAI/bge-large-zh-v1.5",
+	"Pro/Qwen/Qwen2-7B-Instruct",
+	"Pro/Qwen/Qwen2-1.5B-Instruct",
+	"Pro/Qwen/Qwen1.5-7B-Chat",
+	"Pro/THUDM/glm-4-9b-chat",
+	"Pro/THUDM/chatglm3-6b",
+	"Pro/01-ai/Yi-1.5-9B-Chat-16K",
+	"Pro/01-ai/Yi-1.5-6B-Chat",
+	"Pro/google/gemma-2-9b-it",
+	"Pro/internlm/internlm2_5-7b-chat",
+	"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
+	"Pro/mistralai/Mistral-7B-Instruct-v0.2",
+	"black-forest-labs/FLUX.1-schnell",
+	"iic/SenseVoiceSmall",
+	"netease-youdao/bce-embedding-base_v1",
+	"BAAI/bge-m3",
+	"internlm/internlm2_5-20b-chat",
+	"Qwen/Qwen2-Math-72B-Instruct",
+	"netease-youdao/bce-reranker-base_v1",
+	"BAAI/bge-reranker-v2-m3",
+}
+// ChannelName is the identifier returned by Adaptor.GetChannelName.
+var ChannelName = "siliconflow"

+ 17 - 0
relay/channel/siliconflow/dto.go

@@ -0,0 +1,17 @@
+package siliconflow
+
+import "one-api/dto"
+
+// SFTokens holds the token counts decoded from the "tokens" object of a
+// rerank response's metadata.
+type SFTokens struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+}
+
+// SFMeta is the "meta" object of a rerank response; only its token counts
+// are decoded.
+type SFMeta struct {
+	Tokens SFTokens `json:"tokens"`
+}
+
+// SFRerankResponse is the rerank response body as decoded by
+// siliconflowRerankHandler: the ranked results plus metadata carrying token
+// counts.
+type SFRerankResponse struct {
+	Results []dto.RerankResponseDocument `json:"results"`
+	Meta    SFMeta                       `json:"meta"`
+}

+ 44 - 0
relay/channel/siliconflow/relay-siliconflow.go

@@ -0,0 +1,44 @@
+package siliconflow
+
+import (
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/service"
+)
+
+// siliconflowRerankHandler relays an upstream rerank response to the client.
+//
+// It reads and closes resp.Body, decodes it as SFRerankResponse, derives a
+// dto.Usage from meta.tokens (total = input + output), and writes a
+// dto.RerankResponse (results plus usage) as JSON with the upstream status
+// code.
+//
+// It returns a wrapped error and nil usage when reading, closing, decoding
+// or encoding fails; otherwise it returns a nil error and the computed usage.
+func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, readErr := io.ReadAll(resp.Body)
+	// Close unconditionally so a failed read does not leak the upstream body.
+	closeErr := resp.Body.Close()
+	if readErr != nil {
+		return service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if closeErr != nil {
+		return service.OpenAIErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	var siliconflowResp SFRerankResponse
+	if err := json.Unmarshal(responseBody, &siliconflowResp); err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	tokens := siliconflowResp.Meta.Tokens
+	usage := &dto.Usage{
+		PromptTokens:     tokens.InputTokens,
+		CompletionTokens: tokens.OutputTokens,
+		TotalTokens:      tokens.InputTokens + tokens.OutputTokens,
+	}
+	rerankResp := &dto.RerankResponse{
+		Results: siliconflowResp.Results,
+		Usage:   *usage,
+	}
+
+	jsonResponse, err := json.Marshal(rerankResp)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	// The status line has already been sent, so a failed write cannot be
+	// turned into an error response; the write error is deliberately dropped
+	// and the usage is still reported to the caller.
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, usage
+}

+ 3 - 0
relay/constant/api_type.go

@@ -23,6 +23,7 @@ const (
 	APITypeDify
 	APITypeJina
 	APITypeCloudflare
+	APITypeSiliconFlow
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -66,6 +67,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeJina
 	case common.ChannelCloudflare:
 		apiType = APITypeCloudflare
+	case common.ChannelTypeSiliconFlow:
+		apiType = APITypeSiliconFlow
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 1 - 1
relay/relay-text.go

@@ -317,7 +317,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	totalTokens := promptTokens + completionTokens
 	var logContent string
 	if !usePrice {
-		logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
+		logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
 	}

+ 3 - 0
relay/relay_adaptor.go

@@ -16,6 +16,7 @@ import (
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
 	"one-api/relay/channel/perplexity"
+	"one-api/relay/channel/siliconflow"
 	"one-api/relay/channel/task/suno"
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/xunfei"
@@ -62,6 +63,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &jina.Adaptor{}
 	case constant.APITypeCloudflare:
 		return &cloudflare.Adaptor{}
+	case constant.APITypeSiliconFlow:
+		return &siliconflow.Adaptor{}
 	}
 	return nil
 }

+ 18 - 17
web/src/constants/channel.constants.js

@@ -5,21 +5,21 @@ export const CHANNEL_OPTIONS = [
     text: 'Midjourney Proxy',
     value: 2,
     color: 'light-blue',
-    label: 'Midjourney Proxy',
+    label: 'Midjourney Proxy'
   },
   {
     key: 5,
     text: 'Midjourney Proxy Plus',
     value: 5,
     color: 'blue',
-    label: 'Midjourney Proxy Plus',
+    label: 'Midjourney Proxy Plus'
   },
   {
     key: 36,
     text: 'Suno API',
     value: 36,
     color: 'purple',
-    label: 'Suno API',
+    label: 'Suno API'
   },
   { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
   {
@@ -27,77 +27,77 @@ export const CHANNEL_OPTIONS = [
     text: 'Anthropic Claude',
     value: 14,
     color: 'indigo',
-    label: 'Anthropic Claude',
+    label: 'Anthropic Claude'
   },
   {
     key: 33,
     text: 'AWS Claude',
     value: 33,
     color: 'indigo',
-    label: 'AWS Claude',
+    label: 'AWS Claude'
   },
   {
     key: 3,
     text: 'Azure OpenAI',
     value: 3,
     color: 'teal',
-    label: 'Azure OpenAI',
+    label: 'Azure OpenAI'
   },
   {
     key: 24,
     text: 'Google Gemini',
     value: 24,
     color: 'orange',
-    label: 'Google Gemini',
+    label: 'Google Gemini'
   },
   {
     key: 34,
     text: 'Cohere',
     value: 34,
     color: 'purple',
-    label: 'Cohere',
+    label: 'Cohere'
   },
   {
     key: 15,
     text: '百度文心千帆',
     value: 15,
     color: 'blue',
-    label: '百度文心千帆',
+    label: '百度文心千帆'
   },
   {
     key: 17,
     text: '阿里通义千问',
     value: 17,
     color: 'orange',
-    label: '阿里通义千问',
+    label: '阿里通义千问'
   },
   {
     key: 18,
     text: '讯飞星火认知',
     value: 18,
     color: 'blue',
-    label: '讯飞星火认知',
+    label: '讯飞星火认知'
   },
   {
     key: 16,
     text: '智谱 ChatGLM',
     value: 16,
     color: 'violet',
-    label: '智谱 ChatGLM',
+    label: '智谱 ChatGLM'
   },
   {
     key: 26,
     text: '智谱 GLM-4V',
     value: 26,
     color: 'purple',
-    label: '智谱 GLM-4V',
+    label: '智谱 GLM-4V'
   },
   {
     key: 11,
     text: 'Google PaLM2',
     value: 11,
     color: 'orange',
-    label: 'Google PaLM2',
+    label: 'Google PaLM2'
   },
   { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
   { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
@@ -107,19 +107,20 @@ export const CHANNEL_OPTIONS = [
   { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' },
   { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' },
   { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' },
+  { key: 40, text: 'SiliconCloud', value: 40, color: 'purple', label: 'SiliconCloud' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
   {
     key: 22,
     text: '知识库:FastGPT',
     value: 22,
     color: 'blue',
-    label: '知识库:FastGPT',
+    label: '知识库:FastGPT'
   },
   {
     key: 21,
     text: '知识库:AI Proxy',
     value: 21,
     color: 'purple',
-    label: '知识库:AI Proxy',
-  },
+    label: '知识库:AI Proxy'
+  }
 ];