dto.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. package aws
  2. import (
  3. "io"
  4. "github.com/QuantumNous/new-api/common"
  5. "github.com/QuantumNous/new-api/dto"
  6. )
// AwsClaudeRequest is the JSON body sent to Anthropic Claude models on AWS
// Bedrock. It carries the subset of dto.ClaudeRequest fields that copyRequest
// copies over, plus the Bedrock-specific anthropic_version.
// NOTE(review): presumably restricted to fields Bedrock accepts — confirm
// before adding new fields here.
type AwsClaudeRequest struct {
	// AnthropicVersion should be "bedrock-2023-05-31"; copyRequest and
	// formatRequest both overwrite it with that value.
	AnthropicVersion string              `json:"anthropic_version"`
	System           any                 `json:"system,omitempty"`
	Messages         []dto.ClaudeMessage `json:"messages"`
	MaxTokens        uint                `json:"max_tokens,omitempty"`
	// Temperature is a pointer so an explicit 0 is still serialized despite
	// omitempty; TopP and TopK below are plain values, so a zero is omitted.
	Temperature   *float64      `json:"temperature,omitempty"`
	TopP          float64       `json:"top_p,omitempty"`
	TopK          int           `json:"top_k,omitempty"`
	StopSequences []string      `json:"stop_sequences,omitempty"`
	Tools         any           `json:"tools,omitempty"`
	ToolChoice    any           `json:"tool_choice,omitempty"`
	Thinking      *dto.Thinking `json:"thinking,omitempty"`
}
  21. func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
  22. return &AwsClaudeRequest{
  23. AnthropicVersion: "bedrock-2023-05-31",
  24. System: req.System,
  25. Messages: req.Messages,
  26. MaxTokens: req.MaxTokens,
  27. Temperature: req.Temperature,
  28. TopP: req.TopP,
  29. TopK: req.TopK,
  30. StopSequences: req.StopSequences,
  31. Tools: req.Tools,
  32. ToolChoice: req.ToolChoice,
  33. Thinking: req.Thinking,
  34. }
  35. }
  36. func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
  37. var awsClaudeRequest AwsClaudeRequest
  38. err := common.DecodeJson(requestBody, &awsClaudeRequest)
  39. if err != nil {
  40. return nil, err
  41. }
  42. awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
  43. return &awsClaudeRequest, nil
  44. }
// NovaMessage is one conversation turn in the messages-v1 format used by
// Amazon Nova models.
type NovaMessage struct {
	Role    string        `json:"role"`    // conversation role of this turn
	Content []NovaContent `json:"content"` // content blocks; convertToNovaRequest emits a single text block
}
// NovaContent is a text-only content block, serialized as {"text": "..."}.
type NovaContent struct {
	Text string `json:"text"`
}
  53. type NovaRequest struct {
  54. SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
  55. Messages []NovaMessage `json:"messages"` // 对话消息列表
  56. InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
  57. }
  58. type NovaInferenceConfig struct {
  59. MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
  60. Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
  61. TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
  62. TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
  63. StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
  64. }
  65. // 转换OpenAI请求为Nova格式
  66. func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
  67. novaMessages := make([]NovaMessage, len(req.Messages))
  68. for i, msg := range req.Messages {
  69. novaMessages[i] = NovaMessage{
  70. Role: msg.Role,
  71. Content: []NovaContent{{Text: msg.StringContent()}},
  72. }
  73. }
  74. novaReq := &NovaRequest{
  75. SchemaVersion: "messages-v1",
  76. Messages: novaMessages,
  77. }
  78. // 设置推理配置
  79. if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
  80. novaReq.InferenceConfig = &NovaInferenceConfig{}
  81. if req.MaxTokens != 0 {
  82. novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
  83. }
  84. if req.Temperature != nil && *req.Temperature != 0 {
  85. novaReq.InferenceConfig.Temperature = *req.Temperature
  86. }
  87. if req.TopP != 0 {
  88. novaReq.InferenceConfig.TopP = req.TopP
  89. }
  90. if req.TopK != 0 {
  91. novaReq.InferenceConfig.TopK = req.TopK
  92. }
  93. if req.Stop != nil {
  94. if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
  95. novaReq.InferenceConfig.StopSequences = stopSequences
  96. }
  97. }
  98. }
  99. return novaReq
  100. }
  101. // parseStopSequences 解析停止序列,支持字符串或字符串数组
  102. func parseStopSequences(stop any) []string {
  103. if stop == nil {
  104. return nil
  105. }
  106. switch v := stop.(type) {
  107. case string:
  108. if v != "" {
  109. return []string{v}
  110. }
  111. case []string:
  112. return v
  113. case []interface{}:
  114. var sequences []string
  115. for _, item := range v {
  116. if str, ok := item.(string); ok && str != "" {
  117. sequences = append(sequences, str)
  118. }
  119. }
  120. return sequences
  121. }
  122. return nil
  123. }