Bläddra i källkod

feat: support amazon nova

huanghejian 5 månader sedan
förälder
incheckning
47aaa695b2
4 ändrade filer med 152 tillägg och 0 borttagningar
  1. 10 0
      relay/channel/aws/adaptor.go
  2. 11 0
      relay/channel/aws/constants.go
  3. 53 0
      relay/channel/aws/dto.go
  4. 78 0
      relay/channel/aws/relay-aws.go

+ 10 - 0
relay/channel/aws/adaptor.go

@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// 检查是否为Nova模型
+	if isNovaModel(request.Model) {
+		novaReq := convertToNovaRequest(request)
+		c.Set("request_model", request.Model)
+		c.Set("converted_request", novaReq)
+		c.Set("is_nova_model", true)
+		return novaReq, nil
+	}
 
+	// 原有的Claude模型处理逻辑
 	var claudeReq *dto.ClaudeRequest
 	var err error
 	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	}
 	c.Set("request_model", claudeReq.Model)
 	c.Set("converted_request", claudeReq)
+	c.Set("is_nova_model", false)
 	return claudeReq, err
 }
 

+ 11 - 0
relay/channel/aws/constants.go

@@ -1,5 +1,7 @@
 package aws
 
+import "strings"
+
 var awsModelIDMap = map[string]string{
 	"claude-instant-1.2":         "anthropic.claude-instant-v1",
 	"claude-2.0":                 "anthropic.claude-v2",
@@ -14,6 +16,10 @@ var awsModelIDMap = map[string]string{
 	"claude-sonnet-4-20250514":   "anthropic.claude-sonnet-4-20250514-v1:0",
 	"claude-opus-4-20250514":     "anthropic.claude-opus-4-20250514-v1:0",
 	"claude-opus-4-1-20250805":   "anthropic.claude-opus-4-1-20250805-v1:0",
+	// Nova models
+	"amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0",
+	"amazon.nova-lite-v1:0":  "us.amazon.nova-lite-v1:0",
+	"amazon.nova-pro-v1:0":   "us.amazon.nova-pro-v1:0",
 }
 
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -67,3 +73,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
 }
 
 var ChannelName = "aws"
+
+// 判断是否为Nova模型
+func isNovaModel(modelId string) bool {
+	return strings.HasPrefix(modelId, "amazon.nova-")
+}

+ 53 - 0
relay/channel/aws/dto.go

@@ -34,3 +34,56 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
 		Thinking:         req.Thinking,
 	}
 }
+
+// Nova模型使用messages-v1格式
+type NovaMessage struct {
+	Role    string        `json:"role"`
+	Content []NovaContent `json:"content"`
+}
+
+type NovaContent struct {
+	Text string `json:"text"`
+}
+
+type NovaRequest struct {
+	SchemaVersion   string              `json:"schemaVersion"`
+	Messages        []NovaMessage       `json:"messages"`
+	InferenceConfig NovaInferenceConfig `json:"inferenceConfig,omitempty"`
+}
+
+type NovaInferenceConfig struct {
+	MaxTokens   int     `json:"maxTokens,omitempty"`
+	Temperature float64 `json:"temperature,omitempty"`
+	TopP        float64 `json:"topP,omitempty"`
+}
+
+// 转换OpenAI请求为Nova格式
+func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
+	novaMessages := make([]NovaMessage, len(req.Messages))
+	for i, msg := range req.Messages {
+		novaMessages[i] = NovaMessage{
+			Role:    msg.Role,
+			Content: []NovaContent{{Text: msg.StringContent()}},
+		}
+	}
+
+	novaReq := &NovaRequest{
+		SchemaVersion: "messages-v1",
+		Messages:      novaMessages,
+	}
+
+	// 设置推理配置
+	if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 {
+		if req.MaxTokens != 0 {
+			novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
+		}
+		if req.Temperature != nil && *req.Temperature != 0 {
+			novaReq.InferenceConfig.Temperature = *req.Temperature
+		}
+		if req.TopP != 0 {
+			novaReq.InferenceConfig.TopP = req.TopP
+		}
+	}
+
+	return novaReq
+}

+ 78 - 0
relay/channel/aws/relay-aws.go

@@ -1,6 +1,7 @@
 package aws
 
 import (
+	"encoding/json"
 	"fmt"
 	"net/http"
 	"one-api/common"
@@ -93,7 +94,13 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	}
 
 	awsModelId := awsModelID(c.GetString("request_model"))
+	// 检查是否为Nova模型
+	isNova, _ := c.Get("is_nova_model")
+	if isNova == true {
+		return handleNovaRequest(c, awsCli, info, awsModelId)
+	}
 
+	// 原有的Claude处理逻辑
 	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
 	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
 	if canCrossRegion {
@@ -209,3 +216,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	return nil, claudeInfo.Usage
 }
+
+// Nova模型处理函数
+func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
+	novaReq_, ok := c.Get("converted_request")
+	if !ok {
+		return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
+	}
+	novaReq := novaReq_.(*NovaRequest)
+
+	// 使用InvokeModel API,但使用Nova格式的请求体
+	awsReq := &bedrockruntime.InvokeModelInput{
+		ModelId:     aws.String(awsModelId),
+		Accept:      aws.String("application/json"),
+		ContentType: aws.String("application/json"),
+	}
+
+	reqBody, err := json.Marshal(novaReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
+	}
+	awsReq.Body = reqBody
+
+	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+	}
+
+	// 解析Nova响应
+	var novaResp struct {
+		Output struct {
+			Message struct {
+				Content []struct {
+					Text string `json:"text"`
+				} `json:"content"`
+			} `json:"message"`
+		} `json:"output"`
+		Usage struct {
+			InputTokens  int `json:"inputTokens"`
+			OutputTokens int `json:"outputTokens"`
+			TotalTokens  int `json:"totalTokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
+		return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
+	}
+
+	// 构造OpenAI格式响应
+	response := dto.OpenAITextResponse{
+		Id:      helper.GetResponseID(c),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Model:   info.UpstreamModelName,
+		Choices: []dto.OpenAITextResponseChoice{{
+			Index: 0,
+			Message: dto.Message{
+				Role:    "assistant",
+				Content: novaResp.Output.Message.Content[0].Text,
+			},
+			FinishReason: "stop",
+		}},
+		Usage: dto.Usage{
+			PromptTokens:     novaResp.Usage.InputTokens,
+			CompletionTokens: novaResp.Usage.OutputTokens,
+			TotalTokens:      novaResp.Usage.TotalTokens,
+		},
+	}
+
+	c.JSON(http.StatusOK, response)
+	return nil, &response.Usage
+}