dto.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package aws
  2. import (
  3. "context"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. )
  12. type AwsClaudeRequest struct {
  13. // AnthropicVersion should be "bedrock-2023-05-31"
  14. AnthropicVersion string `json:"anthropic_version"`
  15. AnthropicBeta json.RawMessage `json:"anthropic_beta,omitempty"`
  16. System any `json:"system,omitempty"`
  17. Messages []dto.ClaudeMessage `json:"messages"`
  18. MaxTokens uint `json:"max_tokens,omitempty"`
  19. Temperature *float64 `json:"temperature,omitempty"`
  20. TopP float64 `json:"top_p,omitempty"`
  21. TopK int `json:"top_k,omitempty"`
  22. StopSequences []string `json:"stop_sequences,omitempty"`
  23. Tools any `json:"tools,omitempty"`
  24. ToolChoice any `json:"tool_choice,omitempty"`
  25. Thinking *dto.Thinking `json:"thinking,omitempty"`
  26. OutputConfig json.RawMessage `json:"output_config,omitempty"`
  27. }
  28. func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
  29. var awsClaudeRequest AwsClaudeRequest
  30. err := common.DecodeJson(requestBody, &awsClaudeRequest)
  31. if err != nil {
  32. return nil, err
  33. }
  34. awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
  35. // check header anthropic-beta
  36. anthropicBetaValues := requestHeader.Get("anthropic-beta")
  37. if len(anthropicBetaValues) > 0 {
  38. var tempArray []string
  39. tempArray = strings.Split(anthropicBetaValues, ",")
  40. if len(tempArray) > 0 {
  41. betaJson, err := json.Marshal(tempArray)
  42. if err != nil {
  43. return nil, err
  44. }
  45. awsClaudeRequest.AnthropicBeta = betaJson
  46. }
  47. }
  48. logger.LogJson(context.Background(), "json", awsClaudeRequest)
  49. return &awsClaudeRequest, nil
  50. }
  51. // NovaMessage Nova模型使用messages-v1格式
  52. type NovaMessage struct {
  53. Role string `json:"role"`
  54. Content []NovaContent `json:"content"`
  55. }
  56. type NovaContent struct {
  57. Text string `json:"text"`
  58. }
  59. type NovaRequest struct {
  60. SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
  61. Messages []NovaMessage `json:"messages"` // 对话消息列表
  62. InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
  63. }
  64. type NovaInferenceConfig struct {
  65. MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
  66. Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
  67. TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
  68. TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
  69. StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
  70. }
  71. // 转换OpenAI请求为Nova格式
  72. func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
  73. novaMessages := make([]NovaMessage, len(req.Messages))
  74. for i, msg := range req.Messages {
  75. novaMessages[i] = NovaMessage{
  76. Role: msg.Role,
  77. Content: []NovaContent{{Text: msg.StringContent()}},
  78. }
  79. }
  80. novaReq := &NovaRequest{
  81. SchemaVersion: "messages-v1",
  82. Messages: novaMessages,
  83. }
  84. // 设置推理配置
  85. if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
  86. novaReq.InferenceConfig = &NovaInferenceConfig{}
  87. if req.MaxTokens != 0 {
  88. novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
  89. }
  90. if req.Temperature != nil && *req.Temperature != 0 {
  91. novaReq.InferenceConfig.Temperature = *req.Temperature
  92. }
  93. if req.TopP != 0 {
  94. novaReq.InferenceConfig.TopP = req.TopP
  95. }
  96. if req.TopK != 0 {
  97. novaReq.InferenceConfig.TopK = req.TopK
  98. }
  99. if req.Stop != nil {
  100. if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
  101. novaReq.InferenceConfig.StopSequences = stopSequences
  102. }
  103. }
  104. }
  105. return novaReq
  106. }
  107. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  108. func parseStopSequences(stop any) []string {
  109. if stop == nil {
  110. return nil
  111. }
  112. switch v := stop.(type) {
  113. case string:
  114. if v != "" {
  115. return []string{v}
  116. }
  117. case []string:
  118. return v
  119. case []interface{}:
  120. var sequences []string
  121. for _, item := range v {
  122. if str, ok := item.(string); ok && str != "" {
  123. sequences = append(sequences, str)
  124. }
  125. }
  126. return sequences
  127. }
  128. return nil
  129. }