dto.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package aws
  2. import (
  3. "context"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. )
// AwsClaudeRequest is the request payload accepted by AWS Bedrock's
// Anthropic Claude models. It mirrors the native Anthropic Messages API
// shape, except that the model is addressed via the Bedrock endpoint and
// the version is carried in the body rather than a header.
type AwsClaudeRequest struct {
	// AnthropicVersion should be "bedrock-2023-05-31"
	AnthropicVersion string `json:"anthropic_version"`
	// AnthropicBeta carries beta feature flags as a pre-marshaled JSON
	// string array (populated from the "anthropic-beta" request header).
	AnthropicBeta json.RawMessage `json:"anthropic_beta,omitempty"`
	// System may be a plain string or a structured content array,
	// so it is kept as `any` and passed through as-is.
	System   any                 `json:"system,omitempty"`
	Messages []dto.ClaudeMessage `json:"messages"`
	MaxTokens uint `json:"max_tokens,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	TopK int     `json:"top_k,omitempty"`
	StopSequences []string `json:"stop_sequences,omitempty"`
	// Tools and ToolChoice are passed through untyped; their schema is
	// validated by the upstream service, not here.
	Tools      any `json:"tools,omitempty"`
	ToolChoice any `json:"tool_choice,omitempty"`
	Thinking *dto.Thinking `json:"thinking,omitempty"`
	// OutputConfig is forwarded opaquely without local interpretation.
	OutputConfig json.RawMessage `json:"output_config,omitempty"`
	//Metadata json.RawMessage `json:"metadata,omitempty"`
}
  29. func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
  30. var awsClaudeRequest AwsClaudeRequest
  31. err := common.DecodeJson(requestBody, &awsClaudeRequest)
  32. if err != nil {
  33. return nil, err
  34. }
  35. awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
  36. // check header anthropic-beta
  37. anthropicBetaValues := requestHeader.Get("anthropic-beta")
  38. if len(anthropicBetaValues) > 0 {
  39. var tempArray []string
  40. tempArray = strings.Split(anthropicBetaValues, ",")
  41. if len(tempArray) > 0 {
  42. betaJson, err := json.Marshal(tempArray)
  43. if err != nil {
  44. return nil, err
  45. }
  46. awsClaudeRequest.AnthropicBeta = betaJson
  47. }
  48. }
  49. logger.LogJson(context.Background(), "json", awsClaudeRequest)
  50. return &awsClaudeRequest, nil
  51. }
// NovaMessage is a single conversation turn in the Amazon Nova
// "messages-v1" request schema.
type NovaMessage struct {
	Role    string        `json:"role"`    // conversation role, copied from the OpenAI message
	Content []NovaContent `json:"content"` // one or more content parts (text-only here)
}
// NovaContent is a single text content part inside a NovaMessage.
type NovaContent struct {
	Text string `json:"text"`
}
// NovaRequest is the top-level invoke payload for Amazon Nova models.
type NovaRequest struct {
	// SchemaVersion identifies the payload format; this codebase sends
	// "messages-v1" (see convertToNovaRequest).
	SchemaVersion string `json:"schemaVersion"`
	// Messages is the ordered conversation history.
	Messages []NovaMessage `json:"messages"`
	// InferenceConfig holds optional sampling parameters; nil when the
	// caller supplied none.
	InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"`
}
// NovaInferenceConfig carries optional sampling parameters for Nova.
// Zero-valued fields are omitted from the JSON payload, letting the
// service apply its own defaults.
type NovaInferenceConfig struct {
	MaxTokens     int      `json:"maxTokens,omitempty"`     // maximum tokens to generate
	Temperature   float64  `json:"temperature,omitempty"`   // randomness (service default 0.7, range 0-1)
	TopP          float64  `json:"topP,omitempty"`          // nucleus sampling (service default 0.9, range 0-1)
	TopK          int      `json:"topK,omitempty"`          // candidate-token cutoff (service default 50, range 0-128)
	StopSequences []string `json:"stopSequences,omitempty"` // sequences that terminate generation
}
  72. // 转换OpenAI请求为Nova格式
  73. func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
  74. novaMessages := make([]NovaMessage, len(req.Messages))
  75. for i, msg := range req.Messages {
  76. novaMessages[i] = NovaMessage{
  77. Role: msg.Role,
  78. Content: []NovaContent{{Text: msg.StringContent()}},
  79. }
  80. }
  81. novaReq := &NovaRequest{
  82. SchemaVersion: "messages-v1",
  83. Messages: novaMessages,
  84. }
  85. // 设置推理配置
  86. if (req.MaxTokens != nil && *req.MaxTokens != 0) || (req.Temperature != nil && *req.Temperature != 0) || (req.TopP != nil && *req.TopP != 0) || (req.TopK != nil && *req.TopK != 0) || req.Stop != nil {
  87. novaReq.InferenceConfig = &NovaInferenceConfig{}
  88. if req.MaxTokens != nil && *req.MaxTokens != 0 {
  89. novaReq.InferenceConfig.MaxTokens = int(*req.MaxTokens)
  90. }
  91. if req.Temperature != nil && *req.Temperature != 0 {
  92. novaReq.InferenceConfig.Temperature = *req.Temperature
  93. }
  94. if req.TopP != nil && *req.TopP != 0 {
  95. novaReq.InferenceConfig.TopP = *req.TopP
  96. }
  97. if req.TopK != nil && *req.TopK != 0 {
  98. novaReq.InferenceConfig.TopK = *req.TopK
  99. }
  100. if req.Stop != nil {
  101. if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
  102. novaReq.InferenceConfig.StopSequences = stopSequences
  103. }
  104. }
  105. }
  106. return novaReq
  107. }
  108. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  109. func parseStopSequences(stop any) []string {
  110. if stop == nil {
  111. return nil
  112. }
  113. switch v := stop.(type) {
  114. case string:
  115. if v != "" {
  116. return []string{v}
  117. }
  118. case []string:
  119. return v
  120. case []interface{}:
  121. var sequences []string
  122. for _, item := range v {
  123. if str, ok := item.(string); ok && str != "" {
  124. sequences = append(sequences, str)
  125. }
  126. }
  127. return sequences
  128. }
  129. return nil
  130. }