Kaynağa Gözat

Merge pull request #1780 from ShibaInu64/feature/support-amazon-nova

feat: support amazon nova model
Calcium-Ion 5 ay önce
ebeveyn
işleme
5d76e16324

+ 10 - 0
relay/channel/aws/adaptor.go

@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 	if request == nil {
 		return nil, errors.New("request is nil")
 		return nil, errors.New("request is nil")
 	}
 	}
+	// 检查是否为Nova模型
+	if isNovaModel(request.Model) {
+		novaReq := convertToNovaRequest(request)
+		c.Set("request_model", request.Model)
+		c.Set("converted_request", novaReq)
+		c.Set("is_nova_model", true)
+		return novaReq, nil
+	}
 
 
+	// 原有的Claude模型处理逻辑
 	var claudeReq *dto.ClaudeRequest
 	var claudeReq *dto.ClaudeRequest
 	var err error
 	var err error
 	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
 	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	}
 	}
 	c.Set("request_model", claudeReq.Model)
 	c.Set("request_model", claudeReq.Model)
 	c.Set("converted_request", claudeReq)
 	c.Set("converted_request", claudeReq)
+	c.Set("is_nova_model", false)
 	return claudeReq, err
 	return claudeReq, err
 }
 }
 
 

+ 33 - 1
relay/channel/aws/constants.go

@@ -1,5 +1,7 @@
 package aws
 package aws
 
 
+import "strings"
+
 var awsModelIDMap = map[string]string{
 var awsModelIDMap = map[string]string{
 	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 	"claude-2.0":                 "anthropic.claude-v2",
 	"claude-2.0":                 "anthropic.claude-v2",
@@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
 	"claude-sonnet-4-20250514":   "anthropic.claude-sonnet-4-20250514-v1:0",
 	"claude-sonnet-4-20250514":   "anthropic.claude-sonnet-4-20250514-v1:0",
 	"claude-opus-4-20250514":     "anthropic.claude-opus-4-20250514-v1:0",
 	"claude-opus-4-20250514":     "anthropic.claude-opus-4-20250514-v1:0",
 	"claude-opus-4-1-20250805":   "anthropic.claude-opus-4-1-20250805-v1:0",
 	"claude-opus-4-1-20250805":   "anthropic.claude-opus-4-1-20250805-v1:0",
+	// Nova models
+	"nova-micro-v1:0":   "amazon.nova-micro-v1:0",
+	"nova-lite-v1:0":    "amazon.nova-lite-v1:0",
+	"nova-pro-v1:0":     "amazon.nova-pro-v1:0",
+	"nova-premier-v1:0": "amazon.nova-premier-v1:0",
 }
 }
 
 
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
 	"anthropic.claude-opus-4-1-20250805-v1:0": {
 	"anthropic.claude-opus-4-1-20250805-v1:0": {
 		"us": true,
 		"us": true,
 	},
 	},
-}
+	// Nova models - all support three major regions
+	"amazon.nova-micro-v1:0": {
+		"us":   true,
+		"eu":   true,
+		"apac": true,
+	},
+	"amazon.nova-lite-v1:0": {
+		"us":   true,
+		"eu":   true,
+		"apac": true,
+	},
+	"amazon.nova-pro-v1:0": {
+		"us":   true,
+		"eu":   true,
+		"apac": true,
+	},
+	"amazon.nova-premier-v1:0": {
+		"us":   true,
+		"eu":   true,
+		"apac": true,
+	}}
 
 
 var awsRegionCrossModelPrefixMap = map[string]string{
 var awsRegionCrossModelPrefixMap = map[string]string{
 	"us": "us",
 	"us": "us",
@@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
 }
 }
 
 
 var ChannelName = "aws"
 var ChannelName = "aws"
+
+// 判断是否为Nova模型
+func isNovaModel(modelId string) bool {
+	return strings.HasPrefix(modelId, "nova-")
+}

+ 89 - 0
relay/channel/aws/dto.go

@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
 		Thinking:         req.Thinking,
 		Thinking:         req.Thinking,
 	}
 	}
 }
 }
+
+// NovaMessage Nova模型使用messages-v1格式
+type NovaMessage struct {
+	Role    string        `json:"role"`
+	Content []NovaContent `json:"content"`
+}
+
+type NovaContent struct {
+	Text string `json:"text"`
+}
+
+type NovaRequest struct {
+	SchemaVersion   string               `json:"schemaVersion"`             // 请求版本,例如 "1.0"
+	Messages        []NovaMessage        `json:"messages"`                  // 对话消息列表
+	InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
+}
+
+type NovaInferenceConfig struct {
+	MaxTokens     int      `json:"maxTokens,omitempty"`     // 最大生成的 token 数
+	Temperature   float64  `json:"temperature,omitempty"`   // 随机性 (默认 0.7, 范围 0-1)
+	TopP          float64  `json:"topP,omitempty"`          // nucleus sampling (默认 0.9, 范围 0-1)
+	TopK          int      `json:"topK,omitempty"`          // 限制候选 token 数 (默认 50, 范围 0-128)
+	StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
+}
+
+// 转换OpenAI请求为Nova格式
+func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
+	novaMessages := make([]NovaMessage, len(req.Messages))
+	for i, msg := range req.Messages {
+		novaMessages[i] = NovaMessage{
+			Role:    msg.Role,
+			Content: []NovaContent{{Text: msg.StringContent()}},
+		}
+	}
+
+	novaReq := &NovaRequest{
+		SchemaVersion: "messages-v1",
+		Messages:      novaMessages,
+	}
+
+	// 设置推理配置
+	if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
+		novaReq.InferenceConfig = &NovaInferenceConfig{}
+		if req.MaxTokens != 0 {
+			novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
+		}
+		if req.Temperature != nil && *req.Temperature != 0 {
+			novaReq.InferenceConfig.Temperature = *req.Temperature
+		}
+		if req.TopP != 0 {
+			novaReq.InferenceConfig.TopP = req.TopP
+		}
+		if req.TopK != 0 {
+			novaReq.InferenceConfig.TopK = req.TopK
+		}
+		if req.Stop != nil {
+			if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
+				novaReq.InferenceConfig.StopSequences = stopSequences
+			}
+		}
+	}
+
+	return novaReq
+}
+
+// parseStopSequences 解析停止序列,支持字符串或字符串数组
+func parseStopSequences(stop any) []string {
+	if stop == nil {
+		return nil
+	}
+
+	switch v := stop.(type) {
+	case string:
+		if v != "" {
+			return []string{v}
+		}
+	case []string:
+		return v
+	case []interface{}:
+		var sequences []string
+		for _, item := range v {
+			if str, ok := item.(string); ok && str != "" {
+				sequences = append(sequences, str)
+			}
+		}
+		return sequences
+	}
+	return nil
+}

+ 84 - 0
relay/channel/aws/relay-aws.go

@@ -1,6 +1,7 @@
 package aws
 package aws
 
 
 import (
 import (
+	"encoding/json"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	}
 	}
 
 
 	awsModelId := awsModelID(c.GetString("request_model"))
 	awsModelId := awsModelID(c.GetString("request_model"))
+	// 检查是否为Nova模型
+	isNova, _ := c.Get("is_nova_model")
+	if isNova == true {
+		// Nova模型也支持跨区域
+		awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+		canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+		if canCrossRegion {
+			awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+		}
+		return handleNovaRequest(c, awsCli, info, awsModelId)
+	}
 
 
+	// 原有的Claude处理逻辑
 	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
 	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
 	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
 	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
 	if canCrossRegion {
 	if canCrossRegion {
@@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	return nil, claudeInfo.Usage
 	return nil, claudeInfo.Usage
 }
 }
+
+// Nova模型处理函数
+func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
+	novaReq_, ok := c.Get("converted_request")
+	if !ok {
+		return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
+	}
+	novaReq := novaReq_.(*NovaRequest)
+
+	// 使用InvokeModel API,但使用Nova格式的请求体
+	awsReq := &bedrockruntime.InvokeModelInput{
+		ModelId:     aws.String(awsModelId),
+		Accept:      aws.String("application/json"),
+		ContentType: aws.String("application/json"),
+	}
+
+	reqBody, err := json.Marshal(novaReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
+	}
+	awsReq.Body = reqBody
+
+	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+	}
+
+	// 解析Nova响应
+	var novaResp struct {
+		Output struct {
+			Message struct {
+				Content []struct {
+					Text string `json:"text"`
+				} `json:"content"`
+			} `json:"message"`
+		} `json:"output"`
+		Usage struct {
+			InputTokens  int `json:"inputTokens"`
+			OutputTokens int `json:"outputTokens"`
+			TotalTokens  int `json:"totalTokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
+		return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
+	}
+
+	// 构造OpenAI格式响应
+	response := dto.OpenAITextResponse{
+		Id:      helper.GetResponseID(c),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Model:   info.UpstreamModelName,
+		Choices: []dto.OpenAITextResponseChoice{{
+			Index: 0,
+			Message: dto.Message{
+				Role:    "assistant",
+				Content: novaResp.Output.Message.Content[0].Text,
+			},
+			FinishReason: "stop",
+		}},
+		Usage: dto.Usage{
+			PromptTokens:     novaResp.Usage.InputTokens,
+			CompletionTokens: novaResp.Usage.OutputTokens,
+			TotalTokens:      novaResp.Usage.TotalTokens,
+		},
+	}
+
+	c.JSON(http.StatusOK, response)
+	return nil, &response.Usage
+}