Jelajahi Sumber

feat: support cloudflare worker ai

CalciumIon 1 tahun lalu
induk
melakukan
7b36a2b885

+ 2 - 0
common/constants.go

@@ -212,6 +212,7 @@ const (
 	ChannelTypeSunoAPI        = 36
 	ChannelTypeDify           = 37
 	ChannelTypeJina           = 38
+	ChannelCloudflare         = 39
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 
@@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
 	"",                                          //36
 	"",                                          //37
 	"https://api.jina.ai",                       //38
+	"https://api.cloudflare.com",                //39
 }

+ 15 - 22
controller/channel-test.go

@@ -12,6 +12,7 @@ import (
 	"net/url"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/middleware"
 	"one-api/model"
 	"one-api/relay"
 	relaycommon "one-api/relay/common"
@@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		Body:   nil,
 		Header: make(http.Header),
 	}
-	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
-	c.Request.Header.Set("Content-Type", "application/json")
-	c.Set("channel", channel.Type)
-	c.Set("base_url", channel.GetBaseURL())
-	switch channel.Type {
-	case common.ChannelTypeAzure:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeXunfei:
-		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
-	case common.ChannelTypeGemini:
-		c.Set("api_version", channel.Other)
-	case common.ChannelTypeAli:
-		c.Set("plugin", channel.Other)
-	}
 
-	meta := relaycommon.GenRelayInfo(c)
-	apiType, _ := constant.ChannelType2APIType(channel.Type)
-	adaptor := relay.GetAdaptor(apiType)
-	if adaptor == nil {
-		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
-	}
 	if testModel == "" {
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
@@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 		}
 	}
 
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+
+	middleware.SetupContextForSelectedChannel(c, channel, testModel)
+
+	meta := relaycommon.GenRelayInfo(c)
+	apiType, _ := constant.ChannelType2APIType(channel.Type)
+	adaptor := relay.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+	}
+
 	request := buildTestRequest()
 	request.Model = testModel
 	meta.UpstreamModelName = testModel

+ 2 - 2
dto/text_request.go

@@ -48,8 +48,8 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }
 
-func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
-	return int64(r.MaxTokens)
+func (r GeneralOpenAIRequest) GetMaxTokens() int {
+	return int(r.MaxTokens)
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {

+ 2 - 2
middleware/distributor.go

@@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
-	//case common.ChannelTypeAIProxyLibrary:
-	//	c.Set("library_id", channel.Other)
 	case common.ChannelTypeGemini:
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeAli:
 		c.Set("plugin", channel.Other)
+	case common.ChannelCloudflare:
+		c.Set("api_version", channel.Other)
 	}
 }

+ 76 - 0
relay/channel/cloudflare/adaptor.go

@@ -0,0 +1,76 @@
+package cloudflare
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
+	default:
+		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
+	}
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	switch relayMode {
+	case constant.RelayModeCompletions:
+		return convertCf2CompletionsRequest(*request), nil
+	default:
+		return request, nil
+	}
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = cfStreamHandler(c, resp, info)
+	} else {
+		err, usage = cfHandler(c, resp, info)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 38 - 0
relay/channel/cloudflare/constant.go

@@ -0,0 +1,38 @@
+package cloudflare
+
+var ModelList = []string{
+	"@cf/meta/llama-2-7b-chat-fp16",
+	"@cf/meta/llama-2-7b-chat-int8",
+	"@cf/mistral/mistral-7b-instruct-v0.1",
+	"@hf/thebloke/deepseek-coder-6.7b-base-awq",
+	"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
+	"@cf/deepseek-ai/deepseek-math-7b-base",
+	"@cf/deepseek-ai/deepseek-math-7b-instruct",
+	"@cf/thebloke/discolm-german-7b-v1-awq",
+	"@cf/tiiuae/falcon-7b-instruct",
+	"@cf/google/gemma-2b-it-lora",
+	"@hf/google/gemma-7b-it",
+	"@cf/google/gemma-7b-it-lora",
+	"@hf/nousresearch/hermes-2-pro-mistral-7b",
+	"@hf/thebloke/llama-2-13b-chat-awq",
+	"@cf/meta-llama/llama-2-7b-chat-hf-lora",
+	"@cf/meta/llama-3-8b-instruct",
+	"@hf/thebloke/llamaguard-7b-awq",
+	"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
+	"@hf/mistralai/mistral-7b-instruct-v0.2",
+	"@cf/mistral/mistral-7b-instruct-v0.2-lora",
+	"@hf/thebloke/neural-chat-7b-v3-1-awq",
+	"@cf/openchat/openchat-3.5-0106",
+	"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
+	"@cf/microsoft/phi-2",
+	"@cf/qwen/qwen1.5-0.5b-chat",
+	"@cf/qwen/qwen1.5-1.8b-chat",
+	"@cf/qwen/qwen1.5-14b-chat-awq",
+	"@cf/qwen/qwen1.5-7b-chat-awq",
+	"@cf/defog/sqlcoder-7b-2",
+	"@hf/nexusflow/starling-lm-7b-beta",
+	"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
+	"@hf/thebloke/zephyr-7b-beta-awq",
+}
+
+var ChannelName = "cloudflare"

+ 13 - 0
relay/channel/cloudflare/model.go

@@ -0,0 +1,13 @@
+package cloudflare
+
+import "one-api/dto"
+
+type CfRequest struct {
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Lora        string        `json:"lora,omitempty"`
+	MaxTokens   int           `json:"max_tokens,omitempty"`
+	Prompt      string        `json:"prompt,omitempty"`
+	Raw         bool          `json:"raw,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+}

+ 115 - 0
relay/channel/cloudflare/relay_cloudflare.go

@@ -0,0 +1,115 @@
+package cloudflare
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
+	p, _ := textRequest.Prompt.(string)
+	return &CfRequest{
+		Prompt:      p,
+		MaxTokens:   textRequest.GetMaxTokens(),
+		Stream:      textRequest.Stream,
+		Temperature: textRequest.Temperature,
+	}
+}
+
+func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+
+	service.SetEventStreamHeaders(c)
+	id := service.GetResponseID(c)
+	var responseText string
+
+	for scanner.Scan() {
+		data := scanner.Text()
+		if len(data) < len("data: ") {
+			continue
+		}
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\r")
+
+		if data == "[DONE]" {
+			break
+		}
+
+		var response dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal([]byte(data), &response)
+		if err != nil {
+			common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+			continue
+		}
+		for _, choice := range response.Choices {
+			choice.Delta.Role = "assistant"
+			responseText += choice.Delta.GetContentString()
+		}
+		response.Id = id
+		response.Model = info.UpstreamModelName
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "error_rendering_stream_response: "+err.Error())
+		}
+	}
+
+	if err := scanner.Err(); err != nil {
+		common.LogError(c, "error_scanning_stream_response: "+err.Error())
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+		}
+	}
+	service.Done(c)
+
+	err := resp.Body.Close()
+	if err != nil {
+		common.LogError(c, "close_response_body_failed: "+err.Error())
+	}
+
+	return nil, usage
+}
+
+func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var response dto.TextResponse
+	err = json.Unmarshal(responseBody, &response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	var responseText string
+	for _, choice := range response.Choices {
+		responseText += choice.Message.StringContent()
+	}
+	usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	response.Usage = *usage
+	response.Id = service.GetResponseID(c)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, usage
+}

+ 1 - 1
relay/channel/cohere/dto.go

@@ -7,7 +7,7 @@ type CohereRequest struct {
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int64         `json:"max_tokens"`
+	MaxTokens   int           `json:"max_tokens"`
 }
 
 type ChatHistory struct {

+ 2 - 1
relay/common/relay_info.go

@@ -68,7 +68,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 		info.ApiVersion = GetAPIVersion(c)
 	}
 	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
-		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
+		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
+		info.ChannelType == common.ChannelCloudflare {
 		info.SupportStreamOptions = true
 	}
 	return info

+ 3 - 0
relay/constant/api_type.go

@@ -22,6 +22,7 @@ const (
 	APITypeCohere
 	APITypeDify
 	APITypeJina
+	APITypeCloudflare
 
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
@@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeDify
 	case common.ChannelTypeJina:
 		apiType = APITypeJina
+	case common.ChannelCloudflare:
+		apiType = APITypeCloudflare
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -7,6 +7,7 @@ import (
 	"one-api/relay/channel/aws"
 	"one-api/relay/channel/baidu"
 	"one-api/relay/channel/claude"
+	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
@@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &dify.Adaptor{}
 	case constant.APITypeJina:
 		return &jina.Adaptor{}
+	case constant.APITypeCloudflare:
+		return &cloudflare.Adaptor{}
 	}
 	return nil
 }

+ 5 - 0
service/sse.go → service/relay.go

@@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
 func Done(c *gin.Context) {
 	StringData(c, "[DONE]")
 }
+
+func GetResponseID(c *gin.Context) string {
+	logID := c.GetString("X-Oneapi-Request-Id")
+	return fmt.Sprintf("chatcmpl-%s", logID)
+}

+ 1 - 0
web/src/constants/channel.constants.js

@@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
     color: 'orange',
     label: 'Google PaLM2',
   },
+  { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
   { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
   { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
   { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },

+ 18 - 0
web/src/pages/Channel/EditChannel.js

@@ -605,6 +605,24 @@ const EditChannel = (props) => {
               />
             </>
           )}
+          {inputs.type === 39 && (
+            <>
+              <div style={{ marginTop: 10 }}>
+                <Typography.Text strong>Account ID:</Typography.Text>
+              </div>
+              <Input
+                name='other'
+                placeholder={
+                  '请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
+                }
+                onChange={(value) => {
+                  handleInputChange('other', value);
+                }}
+                value={inputs.other}
+                autoComplete='new-password'
+              />
+            </>
+          )}
           <div style={{ marginTop: 10 }}>
             <Typography.Text strong>模型:</Typography.Text>
           </div>