Bladeren bron

add coze request

creamlike1024 9 maanden geleden
bovenliggende
commit
b2cad22952

+ 2 - 0
common/constants.go

@@ -240,6 +240,7 @@ const (
 	ChannelTypeBaiduV2        = 46
 	ChannelTypeXinference     = 47
 	ChannelTypeXai            = 48
+	ChannelTypeCoze           = 49
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{
 	"https://qianfan.baidubce.com",              //46
 	"",                                          //47
 	"https://api.x.ai",                          //48
+	"https://api.coze.cn",                       //49
 }

+ 125 - 0
relay/channel/coze/adaptor.go

@@ -0,0 +1,125 @@
+package coze
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/common"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+// ConvertAudioRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertClaudeRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertEmbeddingRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertImageRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertOpenAIRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return convertCozeChatRequest(*request), nil
+}
+
+// ConvertOpenAIResponsesRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// ConvertRerankRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, errors.New("not implemented")
+}
+
+// DoRequest implements channel.Adaptor.
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	// 首先发送创建消息请求,成功后再发送获取消息请求
+	// 发送创建消息请求
+	resp, err := channel.DoApiRequest(a, c, info, requestBody)
+	if err != nil {
+		return nil, err
+	}
+	// 解析 resp
+	var cozeResponse CozeChatResponse
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	err = json.Unmarshal(respBody, &cozeResponse)
+	if cozeResponse.Code != 0 {
+		return nil, errors.New(cozeResponse.Msg)
+	}
+	c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
+	c.Set("coze_chat_id", cozeResponse.Data.Id)
+	// 轮询检查消息是否完成
+	for {
+		err, isComplete := checkIfChatComplete(a, c, info)
+		if err != nil {
+			return nil, err
+		} else {
+			if isComplete {
+				break
+			}
+		}
+		time.Sleep(time.Second * 1)
+	}
+	// 发送获取消息请求
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse implements channel.Adaptor.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	err, usage = cozeChatHandler(c, resp, info)
+	return
+}
+
+// GetChannelName implements channel.Adaptor.
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
+
+// GetModelList implements channel.Adaptor.
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+// GetRequestURL implements channel.Adaptor.
+func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl), nil
+}
+
+// Init implements channel.Adaptor.
+func (a *Adaptor) Init(info *common.RelayInfo) {
+
+}
+
+// SetupRequestHeader implements channel.Adaptor.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}

+ 8 - 0
relay/channel/coze/constants.go

@@ -0,0 +1,8 @@
+package coze
+
+var ModelList = []string{
+	// TODO: 完整列表
+	"deepseek-v3",
+}
+
+var ChannelName = "coze"

+ 81 - 0
relay/channel/coze/dto.go

@@ -0,0 +1,81 @@
+package coze
+
+import "encoding/json"
+
+// type CozeResponse struct {
+// 	Code    int                  `json:"code"`
+// 	Message string               `json:"message"`
+// 	Data    CozeConversationData `json:"data"`
+// 	Detail  CozeConversationData `json:"detail"`
+// }
+
+// type CozeConversationData struct {
+// 	Id            string          `json:"id"`
+// 	CreatedAt     int64           `json:"created_at"`
+// 	MetaData      json.RawMessage `json:"meta_data"`
+// 	LastSectionId string          `json:"last_section_id"`
+// }
+
+// type CozeResponseDetail struct {
+// 	Logid string `json:"logid"`
+// }
+
+type CozeError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+// type CozeErrorWithStatusCode struct {
+// 	Error      CozeError `json:"error"`
+// 	StatusCode int
+// 	LocalError bool
+// }
+
+type CozeRequest struct {
+	BotId    string             `json:"bot_id,omitempty"`
+	MetaData json.RawMessage    `json:"meta_data,omitempty"`
+	Messages []CozeEnterMessage `json:"messages,omitempty"`
+}
+
+type CozeEnterMessage struct {
+	Role        string          `json:"role"`
+	Type        string          `json:"type,omitempty"`
+	Content     json.RawMessage `json:"content,omitempty"`
+	MetaData    json.RawMessage `json:"meta_data,omitempty"`
+	ContentType string          `json:"content_type,omitempty"`
+}
+
+type CozeChatRequest struct {
+	BotId              string             `json:"bot_id"`
+	UserId             string             `json:"user_id"`
+	AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
+	Stream             bool               `json:"stream,omitempty"`
+	CustomVariables    json.RawMessage    `json:"custom_variables,omitempty"`
+	AutoSaveHistory    bool               `json:"auto_save_history,omitempty"`
+	MetaData           json.RawMessage    `json:"meta_data,omitempty"`
+	ExtraParams        json.RawMessage    `json:"extra_params,omitempty"`
+	ShortcutCommand    json.RawMessage    `json:"shortcut_command,omitempty"`
+	Parameters         json.RawMessage    `json:"parameters,omitempty"`
+}
+
+type CozeChatResponse struct {
+	Code int                  `json:"code"`
+	Msg  string               `json:"msg"`
+	Data CozeChatResponseData `json:"data"`
+}
+
+type CozeChatResponseData struct {
+	Id             string        `json:"id"`
+	ConversationId string        `json:"conversation_id"`
+	BotId          string        `json:"bot_id"`
+	CreatedAt      int64         `json:"created_at"`
+	LastError      CozeError     `json:"last_error"`
+	Status         string        `json:"status"`
+	Usage          CozeChatUsage `json:"usage"`
+}
+
+type CozeChatUsage struct {
+	TokenCount  int `json:"token_count"`
+	OutputCount int `json:"output_count"`
+	InputCount  int `json:"input_count"`
+}

+ 121 - 0
relay/channel/coze/relay-coze.go

@@ -0,0 +1,121 @@
+package coze
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/common"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func convertCozeChatRequest(request dto.GeneralOpenAIRequest) *CozeRequest {
+	var messages []CozeEnterMessage
+	// 将 request的messages的role为user的content转换为CozeMessage
+	for _, message := range request.Messages {
+		if message.Role == "user" {
+			messages = append(messages, CozeEnterMessage{
+				Role:    "user",
+				Content: message.Content,
+				// TODO: support more content type
+				ContentType: "text",
+			})
+		}
+	}
+	cozeRequest := &CozeRequest{
+		// TODO: model to botid
+		BotId:    "1",
+		Messages: messages,
+	}
+	return cozeRequest
+}
+
+func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	// convert coze response to openai response
+	var response dto.TextResponse
+	var cozeResponse CozeChatResponse
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	response.Model = info.UpstreamModelName
+	// TODO: 处理 cozeResponse
+	return nil, nil
+}
+
+func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
+	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+
+	requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+	// 将 conversationId和chatId作为参数发送get请求
+	req, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return err, false
+	}
+	err = a.SetupRequestHeader(c, &req.Header, info)
+	if err != nil {
+		return err, false
+	}
+
+	resp, err := doRequest(req, info) // 调用 doRequest
+	if err != nil {
+		return err, false
+	}
+	if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
+		return fmt.Errorf("resp is nil"), false
+	}
+	defer resp.Body.Close() // 确保响应体被关闭
+
+	// 解析 resp 到 CozeChatResponse
+	var cozeResponse CozeChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("read response body failed: %w", err), false
+	}
+	err = json.Unmarshal(responseBody, &cozeResponse)
+	if err != nil {
+		return fmt.Errorf("unmarshal response body failed: %w", err), false
+	}
+	if cozeResponse.Data.Status == "completed" {
+		// 在上下文设置 usage
+		c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
+		c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
+		c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
+		return nil, true
+	} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
+		return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
+	} else {
+		return nil, false
+	}
+}
+
+func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+	var client *http.Client
+	var err error // 声明 err 变量
+	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
+		client, err = service.NewProxyHttpClient(proxyURL.(string))
+		if err != nil {
+			return nil, fmt.Errorf("new proxy http client failed: %w", err)
+		}
+	} else {
+		client = service.GetHttpClient()
+	}
+	resp, err := client.Do(req)
+	if err != nil { // 增加对 client.Do(req) 返回错误的检查
+		return nil, fmt.Errorf("client.Do failed: %w", err)
+	}
+	_ = resp.Body.Close()
+	return resp, nil
+}

+ 3 - 0
relay/constant/api_type.go

@@ -33,6 +33,7 @@ const (
 	APITypeOpenRouter
 	APITypeXinference
 	APITypeXai
+	APITypeCoze
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeXinference
 	case common.ChannelTypeXai:
 		apiType = APITypeXai
+	case common.ChannelTypeCoze:
+		apiType = APITypeCoze
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
+	"one-api/relay/channel/coze"
 	"one-api/relay/channel/deepseek"
 	"one-api/relay/channel/dify"
 	"one-api/relay/channel/gemini"
@@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &openai.Adaptor{}
 	case constant.APITypeXai:
 		return &xai.Adaptor{}
+	case constant.APITypeCoze:
+		return &coze.Adaptor{}
 	}
 	return nil
 }