Jelajahi Sumber

feat: add xai channel

feat: add xai channel

feat: add xai channel
HynoR 11 bulan lalu
induk
melakukan
f500eb17a8

+ 2 - 0
common/constants.go

@@ -235,6 +235,7 @@ const (
 	ChannelTypeVolcEngine     = 45
 	ChannelTypeBaiduV2        = 46
 	ChannelTypeXinference     = 47
+	ChannelTypeXai            = 48
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -288,4 +289,5 @@ var ChannelBaseURLs = []string{
 	"https://ark.cn-beijing.volces.com",         //45
 	"https://qianfan.baidubce.com",              //46
 	"",                                          //47
+	"https://api.x.ai",                          //48
 }

+ 87 - 0
relay/channel/xai/adaptor.go

@@ -0,0 +1,87 @@
+package xai
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+	//TODO implement me
+	panic("implement me")
+	return nil, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	request.StreamOptions = nil
+	return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = openai.OaiStreamHandler(c, resp, info)
+	} else {
+		err, usage = openai.OpenaiHandler(c, resp, info)
+	}
+	if _, ok := usage.(*dto.Usage); ok && usage != nil {
+		usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
+	}
+
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 9 - 0
relay/channel/xai/constants.go

@@ -0,0 +1,9 @@
+package xai
+
+var ModelList = []string{
+	// grok-2,grok-2-vision,grok-3-mini-beta,grok-3-beta
+	"grok-3-beta", "grok-3-mini-beta", "grok-2", "grok-2-vision", "grok-3", "grok-beta", "grok-vision-beta",
+	"grok-3-fast-beta", "grok-3-mini-fast-beta",
+}
+
+var ChannelName = "xai"

+ 3 - 0
relay/constant/api_type.go

@@ -32,6 +32,7 @@ const (
 	APITypeBaiduV2
 	APITypeOpenRouter
 	APITypeXinference
+	APITypeXai
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -92,6 +93,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeOpenRouter
 	case common.ChannelTypeXinference:
 		apiType = APITypeXinference
+	case common.ChannelTypeXai:
+		apiType = APITypeXai
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -25,6 +25,7 @@ import (
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/vertex"
 	"one-api/relay/channel/volcengine"
+	"one-api/relay/channel/xai"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
 	"one-api/relay/channel/zhipu_4v"
@@ -85,6 +86,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &openai.Adaptor{}
 	case constant.APITypeXinference:
 		return &openai.Adaptor{}
+	case constant.APITypeXai:
+		return &xai.Adaptor{}
 	}
 	return nil
 }

+ 9 - 0
setting/operation_setting/model-ratio.go

@@ -199,6 +199,15 @@ var defaultModelRatio = map[string]float64{
 	"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
 	"llama-3-sonar-large-32k-chat":   1 / 1000 * USD,
 	"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
+	// grok
+	"grok-3-beta":           1.5,
+	"grok-3-mini-beta":      0.15,
+	"grok-2":                1,
+	"grok-2-vision":         1,
+	"grok-beta":             2.5,
+	"grok-vision-beta":      2.5,
+	"grok-3-fast-beta":      2.5,
+	"grok-3-mini-fast-beta": 0.3,
 }
 
 var defaultModelPrice = map[string]float64{

+ 5 - 0
web/src/constants/channel.constants.js

@@ -115,4 +115,9 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: '字节火山方舟、豆包、DeepSeek通用'
   },
+  {
+    value: 48,
+    color: 'blue',
+    label: 'xAI'
+  }
 ];