adaptor.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package jimeng
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "github.com/QuantumNous/new-api/dto"
  9. "github.com/QuantumNous/new-api/relay/channel"
  10. "github.com/QuantumNous/new-api/relay/channel/openai"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  13. "github.com/QuantumNous/new-api/types"
  14. "github.com/gin-gonic/gin"
  15. )
  // Adaptor is the relay adaptor for the jimeng channel. It holds no state:
// every method works purely from its arguments.
type Adaptor struct{}
  18. func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
  19. //TODO implement me
  20. return nil, errors.New("not implemented")
  21. }
  22. func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
  23. return nil, errors.New("not implemented")
  24. }
  25. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  26. }
  27. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  28. return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
  29. }
  30. func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
  31. return errors.New("not implemented")
  32. }
  33. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  34. if request == nil {
  35. return nil, errors.New("request is nil")
  36. }
  37. return request, nil
  38. }
  // LogoInfo holds the watermark settings encoded under "logo_info" in the
// image request. Every field carries omitempty, so zero values are left out
// of the JSON. The exact meaning of the numeric codes is defined by the
// upstream API (not visible here).
type LogoInfo struct {
	// Request that a watermark be added.
	AddLogo bool `json:"add_logo,omitempty"`
	// Watermark placement code.
	Position int `json:"position,omitempty"`
	// Watermark language code.
	Language int `json:"language,omitempty"`
	// Watermark opacity.
	Opacity float64 `json:"opacity,omitempty"`
	// Custom watermark text.
	LogoTextContent string `json:"logo_text_content,omitempty"`
}
  46. type imageRequestPayload struct {
  47. ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
  48. Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
  49. Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
  50. Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
  51. Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
  52. UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
  53. UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
  54. ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
  55. LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
  56. ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
  57. BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
  58. }
  59. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  60. payload := imageRequestPayload{
  61. ReqKey: request.Model,
  62. Prompt: request.Prompt,
  63. }
  64. if request.ResponseFormat == "" || request.ResponseFormat == "url" {
  65. payload.ReturnURL = true // Default to returning image URLs
  66. }
  67. if len(request.ExtraFields) > 0 {
  68. if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
  69. return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
  70. }
  71. }
  72. return payload, nil
  73. }
  74. func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  75. return nil, errors.New("not implemented")
  76. }
  77. func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
  78. return nil, errors.New("not implemented")
  79. }
  80. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
  81. return nil, errors.New("not implemented")
  82. }
  83. func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
  84. return nil, errors.New("not implemented")
  85. }
  86. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  87. fullRequestURL, err := a.GetRequestURL(info)
  88. if err != nil {
  89. return nil, fmt.Errorf("get request url failed: %w", err)
  90. }
  91. req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
  92. if err != nil {
  93. return nil, fmt.Errorf("new request failed: %w", err)
  94. }
  95. err = Sign(c, req, info.ApiKey)
  96. if err != nil {
  97. return nil, fmt.Errorf("setup request header failed: %w", err)
  98. }
  99. resp, err := channel.DoRequest(c, req, info)
  100. if err != nil {
  101. return nil, fmt.Errorf("do request failed: %w", err)
  102. }
  103. return resp, nil
  104. }
  105. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  106. if info.RelayMode == relayconstant.RelayModeImagesGenerations {
  107. usage, err = jimengImageHandler(c, resp, info)
  108. } else if info.IsStream {
  109. usage, err = openai.OaiStreamHandler(c, info, resp)
  110. } else {
  111. usage, err = openai.OpenaiHandler(c, info, resp)
  112. }
  113. return
  114. }
  115. func (a *Adaptor) GetModelList() []string {
  116. return ModelList
  117. }
  118. func (a *Adaptor) GetChannelName() string {
  119. return ChannelName
  120. }