Просмотр исходного кода

Merge pull request #1067 from QuantumNous/coze

Coze 渠道
Calcium-Ion 9 месяцев назад
Родитель
Сommit
7171a69512

+ 2 - 0
common/constants.go

@@ -240,6 +240,7 @@ const (
 	ChannelTypeBaiduV2        = 46
 	ChannelTypeXinference     = 47
 	ChannelTypeXai            = 48
+	ChannelTypeCoze           = 49
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{
 	"https://qianfan.baidubce.com",              //46
 	"",                                          //47
 	"https://api.x.ai",                          //48
+	"https://api.coze.cn",                       //49
 }

+ 2 - 0
middleware/distributor.go

@@ -240,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeMokaAI:
 		c.Set("api_version", channel.Other)
+	case common.ChannelTypeCoze:
+		c.Set("bot_id", channel.Other)
 	}
 }

+ 132 - 0
relay/channel/coze/adaptor.go

@@ -0,0 +1,132 @@
+package coze
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/common"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+// ConvertAudioRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertClaudeRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertEmbeddingRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertImageRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertOpenAIRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return convertCozeChatRequest(c, *request), nil
+}
+
+// ConvertOpenAIResponsesRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertRerankRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// DoRequest implements channel.Adaptor.
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	if info.IsStream {
+		return channel.DoApiRequest(a, c, info, requestBody)
+	}
+	// 首先发送创建消息请求,成功后再发送获取消息请求
+	// 发送创建消息请求
+	resp, err := channel.DoApiRequest(a, c, info, requestBody)
+	if err != nil {
+		return nil, err
+	}
+	// 解析 resp
+	var cozeResponse CozeChatResponse
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	err = json.Unmarshal(respBody, &cozeResponse)
+	if cozeResponse.Code != 0 {
+		return nil, errors.New(cozeResponse.Msg)
+	}
+	c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
+	c.Set("coze_chat_id", cozeResponse.Data.Id)
+	// 轮询检查消息是否完成
+	for {
+		err, isComplete := checkIfChatComplete(a, c, info)
+		if err != nil {
+			return nil, err
+		} else {
+			if isComplete {
+				break
+			}
+		}
+		time.Sleep(time.Second * 1)
+	}
+	// 发送获取消息请求
+	return getChatDetail(a, c, info)
+}
+
+// DoResponse implements channel.Adaptor.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = cozeChatStreamHandler(c, resp, info)
+	} else {
+		err, usage = cozeChatHandler(c, resp, info)
+	}
+	return
+}
+
+// GetChannelName implements channel.Adaptor.
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
+
+// GetModelList implements channel.Adaptor.
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+// GetRequestURL implements channel.Adaptor.
+func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
+}
+
+// Init implements channel.Adaptor.
+func (a *Adaptor) Init(info *common.RelayInfo) {
+
+}
+
+// SetupRequestHeader implements channel.Adaptor.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}

+ 30 - 0
relay/channel/coze/constants.go

@@ -0,0 +1,30 @@
+package coze
+
+var ModelList = []string{
+	"moonshot-v1-8k",
+	"moonshot-v1-32k",
+	"moonshot-v1-128k",
+	"Baichuan4",
+	"abab6.5s-chat-pro",
+	"glm-4-0520",
+	"qwen-max",
+	"deepseek-r1",
+	"deepseek-v3",
+	"deepseek-r1-distill-qwen-32b",
+	"deepseek-r1-distill-qwen-7b",
+	"step-1v-8k",
+	"step-1.5v-mini",
+	"Doubao-pro-32k",
+	"Doubao-pro-256k",
+	"Doubao-lite-128k",
+	"Doubao-lite-32k",
+	"Doubao-vision-lite-32k",
+	"Doubao-vision-pro-32k",
+	"Doubao-1.5-pro-vision-32k",
+	"Doubao-1.5-lite-32k",
+	"Doubao-1.5-pro-32k",
+	"Doubao-1.5-thinking-pro",
+	"Doubao-1.5-pro-256k",
+}
+
+var ChannelName = "coze"

+ 78 - 0
relay/channel/coze/dto.go

@@ -0,0 +1,78 @@
+package coze
+
+import "encoding/json"
+
+type CozeError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+type CozeEnterMessage struct {
+	Role        string          `json:"role"`
+	Type        string          `json:"type,omitempty"`
+	Content     json.RawMessage `json:"content,omitempty"`
+	MetaData    json.RawMessage `json:"meta_data,omitempty"`
+	ContentType string          `json:"content_type,omitempty"`
+}
+
+type CozeChatRequest struct {
+	BotId              string             `json:"bot_id"`
+	UserId             string             `json:"user_id"`
+	AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
+	Stream             bool               `json:"stream,omitempty"`
+	CustomVariables    json.RawMessage    `json:"custom_variables,omitempty"`
+	AutoSaveHistory    bool               `json:"auto_save_history,omitempty"`
+	MetaData           json.RawMessage    `json:"meta_data,omitempty"`
+	ExtraParams        json.RawMessage    `json:"extra_params,omitempty"`
+	ShortcutCommand    json.RawMessage    `json:"shortcut_command,omitempty"`
+	Parameters         json.RawMessage    `json:"parameters,omitempty"`
+}
+
+type CozeChatResponse struct {
+	Code int                  `json:"code"`
+	Msg  string               `json:"msg"`
+	Data CozeChatResponseData `json:"data"`
+}
+
+type CozeChatResponseData struct {
+	Id             string        `json:"id"`
+	ConversationId string        `json:"conversation_id"`
+	BotId          string        `json:"bot_id"`
+	CreatedAt      int64         `json:"created_at"`
+	LastError      CozeError     `json:"last_error"`
+	Status         string        `json:"status"`
+	Usage          CozeChatUsage `json:"usage"`
+}
+
+type CozeChatUsage struct {
+	TokenCount  int `json:"token_count"`
+	OutputCount int `json:"output_count"`
+	InputCount  int `json:"input_count"`
+}
+
+type CozeChatDetailResponse struct {
+	Data   []CozeChatV3MessageDetail `json:"data"`
+	Code   int                       `json:"code"`
+	Msg    string                    `json:"msg"`
+	Detail CozeResponseDetail        `json:"detail"`
+}
+
+type CozeChatV3MessageDetail struct {
+	Id               string          `json:"id"`
+	Role             string          `json:"role"`
+	Type             string          `json:"type"`
+	BotId            string          `json:"bot_id"`
+	ChatId           string          `json:"chat_id"`
+	Content          json.RawMessage `json:"content"`
+	MetaData         json.RawMessage `json:"meta_data"`
+	CreatedAt        int64           `json:"created_at"`
+	SectionId        string          `json:"section_id"`
+	UpdatedAt        int64           `json:"updated_at"`
+	ContentType      string          `json:"content_type"`
+	ConversationId   string          `json:"conversation_id"`
+	ReasoningContent string          `json:"reasoning_content"`
+}
+
+type CozeResponseDetail struct {
+	Logid string `json:"logid"`
+}

+ 300 - 0
relay/channel/coze/relay-coze.go

@@ -0,0 +1,300 @@
+package coze
+
+import (
+	"bufio"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
+	var messages []CozeEnterMessage
+	// 将 request的messages的role为user的content转换为CozeMessage
+	for _, message := range request.Messages {
+		if message.Role == "user" {
+			messages = append(messages, CozeEnterMessage{
+				Role:    "user",
+				Content: message.Content,
+				// TODO: support more content type
+				ContentType: "text",
+			})
+		}
+	}
+	user := request.User
+	if user == "" {
+		user = helper.GetResponseID(c)
+	}
+	cozeRequest := &CozeChatRequest{
+		BotId:              c.GetString("bot_id"),
+		UserId:             user,
+		AdditionalMessages: messages,
+		Stream:             request.Stream,
+	}
+	return cozeRequest
+}
+
+func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// convert coze response to openai response
+	var response dto.TextResponse
+	var cozeResponse CozeChatDetailResponse
+	response.Model = info.UpstreamModelName
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if cozeResponse.Code != 0 {
+		return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
+	}
+	// 从上下文获取 usage
+	var usage dto.Usage
+	usage.PromptTokens = c.GetInt("coze_input_count")
+	usage.CompletionTokens = c.GetInt("coze_output_count")
+	usage.TotalTokens = c.GetInt("coze_token_count")
+	response.Usage = usage
+	response.Id = helper.GetResponseID(c)
+
+	var responseContent json.RawMessage
+	for _, data := range cozeResponse.Data {
+		if data.Type == "answer" {
+			responseContent = data.Content
+			response.Created = data.CreatedAt
+		}
+	}
+	// 添加 response.Choices
+	response.Choices = []dto.OpenAITextResponseChoice{
+		{
+			Index:        0,
+			Message:      dto.Message{Role: "assistant", Content: responseContent},
+			FinishReason: "stop",
+		},
+	}
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, _ = c.Writer.Write(jsonResponse)
+
+	return nil, &usage
+}
+
+func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+	helper.SetEventStreamHeaders(c)
+	id := helper.GetResponseID(c)
+	var responseText string
+
+	var currentEvent string
+	var currentData string
+	var usage dto.Usage
+
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		if line == "" {
+			if currentEvent != "" && currentData != "" {
+				// handle last event
+				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				currentEvent = ""
+				currentData = ""
+			}
+			continue
+		}
+
+		if strings.HasPrefix(line, "event:") {
+			currentEvent = strings.TrimSpace(line[6:])
+			continue
+		}
+
+		if strings.HasPrefix(line, "data:") {
+			currentData = strings.TrimSpace(line[5:])
+			continue
+		}
+	}
+
+	// Last event
+	if currentEvent != "" && currentData != "" {
+		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+	}
+
+	if err := scanner.Err(); err != nil {
+		return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
+	}
+	helper.Done(c)
+
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
+	return nil, &usage
+}
+
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+	switch event {
+	case "conversation.chat.completed":
+		// 将 data 解析为 CozeChatResponseData
+		var chatData CozeChatResponseData
+		err := json.Unmarshal([]byte(data), &chatData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		usage.PromptTokens = chatData.Usage.InputCount
+		usage.CompletionTokens = chatData.Usage.OutputCount
+		usage.TotalTokens = chatData.Usage.TokenCount
+
+		finishReason := "stop"
+		stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
+		helper.ObjectData(c, stopResponse)
+
+	case "conversation.message.delta":
+		// 将 data 解析为 CozeChatV3MessageDetail
+		var messageData CozeChatV3MessageDetail
+		err := json.Unmarshal([]byte(data), &messageData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		var content string
+		err = json.Unmarshal(messageData.Content, &content)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		*responseText += content
+
+		openaiResponse := dto.ChatCompletionsStreamResponse{
+			Id:      id,
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   info.UpstreamModelName,
+		}
+
+		choice := dto.ChatCompletionsStreamResponseChoice{
+			Index: 0,
+		}
+		choice.Delta.SetContentString(content)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
+
+		helper.ObjectData(c, openaiResponse)
+
+	case "error":
+		var errorData CozeError
+		err := json.Unmarshal([]byte(data), &errorData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+	}
+}
+
+func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
+	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	// 将 conversationId和chatId作为参数发送get请求
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return err, false
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return err, false
+	}
+
+	resp, err := doRequest(req, info) // 调用 doRequest
+	if err != nil {
+		return err, false
+	}
+	if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
+		return fmt.Errorf("resp is nil"), false
+	}
+	defer resp.Body.Close() // 确保响应体被关闭
+
+	// 解析 resp 到 CozeChatResponse
+	var cozeResponse CozeChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("read response body failed: %w", err), false
+	}
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return fmt.Errorf("unmarshal response body failed: %w", err), false
+	}
+	if cozeResponse.Data.Status == "completed" {
+		// 在上下文设置 usage
+		c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
+		c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
+		c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
+		return nil, true
+	} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
+		return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
+	} else {
+		return nil, false
+	}
+}
+
+func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
+	requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(req, info)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
+
+func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
+	var client *http.Client
+	var err error // 声明 err 变量
+	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
+		client, err = service.NewProxyHttpClient(proxyURL.(string))
+		if err != nil {
+			return nil, fmt.Errorf("new proxy http client failed: %w", err)
+		}
+	} else {
+		client = service.GetHttpClient()
+	}
+	resp, err := client.Do(req)
+	if err != nil { // 增加对 client.Do(req) 返回错误的检查
+		return nil, fmt.Errorf("client.Do failed: %w", err)
+	}
+	// _ = resp.Body.Close()
+	return resp, nil
+}

+ 3 - 0
relay/constant/api_type.go

@@ -33,6 +33,7 @@ const (
 	APITypeOpenRouter
 	APITypeXinference
 	APITypeXai
+	APITypeCoze
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeXinference
 	case common.ChannelTypeXai:
 		apiType = APITypeXai
+	case common.ChannelTypeCoze:
+		apiType = APITypeCoze
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
+	"one-api/relay/channel/coze"
 	"one-api/relay/channel/deepseek"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
@@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &openai.Adaptor{}
 	case constant.APITypeXai:
 		return &xai.Adaptor{}
+	case constant.APITypeCoze:
+		return &coze.Adaptor{}
 	}
 	return nil
 }

+ 7 - 2
web/src/constants/channel.constants.js

@@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [
   {
     value: 48,
     color: 'blue',
-    label: 'xAI'
-  }
+    label: 'xAI',
+  },
+  {
+    value: 49,
+    color: 'blue',
+    label: 'Coze',
+  },
 ];

+ 16 - 0
web/src/pages/Channel/EditChannel.js

@@ -838,6 +838,22 @@ const EditChannel = (props) => {
               />
             </>
           )}
+          {inputs.type === 49 && (
+            <>
+              <div style={{ marginTop: 10 }}>
+                <Typography.Text strong>智能体ID:</Typography.Text>
+              </div>
+              <Input
+                name='other'
+                placeholder={'请输入智能体ID,例如:7342866812345'}
+                onChange={(value) => {
+                  handleInputChange('other', value);
+                }}
+                value={inputs.other}
+                autoComplete='new-password'
+              />
+            </>
+          )}
           <div style={{ marginTop: 10 }}>
             <Typography.Text strong>{t('模型')}:</Typography.Text>
           </div>