dto.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package aws
  2. import (
  3. "context"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. )
// AwsClaudeRequest is the Anthropic Messages request body sent to Claude
// models on Bedrock. formatRequest decodes the client's body into it and then
// overrides AnthropicVersion (and AnthropicBeta when the anthropic-beta header
// is present).
type AwsClaudeRequest struct {
	// AnthropicVersion should be "bedrock-2023-05-31"
	AnthropicVersion string `json:"anthropic_version"`
	// AnthropicBeta is kept as raw JSON; formatRequest may overwrite it with a
	// JSON array of beta names built from the anthropic-beta header.
	AnthropicBeta json.RawMessage `json:"anthropic_beta,omitempty"`
	// System is left untyped (any) so whatever was decoded is passed through.
	System   any                 `json:"system,omitempty"`
	Messages []dto.ClaudeMessage `json:"messages"`
	// MaxTokens is omitted from the output when zero.
	MaxTokens uint `json:"max_tokens,omitempty"`
	// Temperature is a pointer so an explicit 0 is still serialized.
	Temperature *float64 `json:"temperature,omitempty"`
	// TopP is a plain value with omitempty, so a top_p of 0 is never sent.
	TopP          float64       `json:"top_p,omitempty"`
	TopK          int           `json:"top_k,omitempty"`
	StopSequences []string      `json:"stop_sequences,omitempty"`
	Tools         any           `json:"tools,omitempty"`
	ToolChoice    any           `json:"tool_choice,omitempty"`
	Thinking      *dto.Thinking `json:"thinking,omitempty"`
}
  27. func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
  28. var awsClaudeRequest AwsClaudeRequest
  29. err := common.DecodeJson(requestBody, &awsClaudeRequest)
  30. if err != nil {
  31. return nil, err
  32. }
  33. awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
  34. // check header anthropic-beta
  35. anthropicBetaValues := requestHeader.Get("anthropic-beta")
  36. if len(anthropicBetaValues) > 0 {
  37. var tempArray []string
  38. tempArray = strings.Split(anthropicBetaValues, ",")
  39. if len(tempArray) > 0 {
  40. betaJson, err := json.Marshal(tempArray)
  41. if err != nil {
  42. return nil, err
  43. }
  44. awsClaudeRequest.AnthropicBeta = betaJson
  45. }
  46. }
  47. logger.LogJson(context.Background(), "json", awsClaudeRequest)
  48. return &awsClaudeRequest, nil
  49. }
// NovaMessage is one conversation turn in the Nova "messages-v1" request format.
type NovaMessage struct {
	Role    string        `json:"role"`    // role copied from the OpenAI message by convertToNovaRequest
	Content []NovaContent `json:"content"` // content blocks; convertToNovaRequest emits a single text block
}
// NovaContent is a text content block in a Nova request.
type NovaContent struct {
	Text string `json:"text"`
}
  58. type NovaRequest struct {
  59. SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
  60. Messages []NovaMessage `json:"messages"` // 对话消息列表
  61. InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
  62. }
  63. type NovaInferenceConfig struct {
  64. MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
  65. Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
  66. TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
  67. TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
  68. StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
  69. }
  70. // 转换OpenAI请求为Nova格式
  71. func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
  72. novaMessages := make([]NovaMessage, len(req.Messages))
  73. for i, msg := range req.Messages {
  74. novaMessages[i] = NovaMessage{
  75. Role: msg.Role,
  76. Content: []NovaContent{{Text: msg.StringContent()}},
  77. }
  78. }
  79. novaReq := &NovaRequest{
  80. SchemaVersion: "messages-v1",
  81. Messages: novaMessages,
  82. }
  83. // 设置推理配置
  84. if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
  85. novaReq.InferenceConfig = &NovaInferenceConfig{}
  86. if req.MaxTokens != 0 {
  87. novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
  88. }
  89. if req.Temperature != nil && *req.Temperature != 0 {
  90. novaReq.InferenceConfig.Temperature = *req.Temperature
  91. }
  92. if req.TopP != 0 {
  93. novaReq.InferenceConfig.TopP = req.TopP
  94. }
  95. if req.TopK != 0 {
  96. novaReq.InferenceConfig.TopK = req.TopK
  97. }
  98. if req.Stop != nil {
  99. if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
  100. novaReq.InferenceConfig.StopSequences = stopSequences
  101. }
  102. }
  103. }
  104. return novaReq
  105. }
  106. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  107. func parseStopSequences(stop any) []string {
  108. if stop == nil {
  109. return nil
  110. }
  111. switch v := stop.(type) {
  112. case string:
  113. if v != "" {
  114. return []string{v}
  115. }
  116. case []string:
  117. return v
  118. case []interface{}:
  119. var sequences []string
  120. for _, item := range v {
  121. if str, ok := item.(string); ok && str != "" {
  122. sequences = append(sequences, str)
  123. }
  124. }
  125. return sequences
  126. }
  127. return nil
  128. }