adaptor.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. package gemini
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/dto"
  8. "one-api/relay/channel"
  9. "one-api/relay/channel/openai"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/constant"
  12. "one-api/setting/model_setting"
  13. "one-api/types"
  14. "strings"
  15. "github.com/gin-gonic/gin"
  16. )
  17. type Adaptor struct {
  18. }
  19. func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
  20. adaptor := openai.Adaptor{}
  21. oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
  22. if err != nil {
  23. return nil, err
  24. }
  25. return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
  26. }
  27. func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
  28. //TODO implement me
  29. return nil, errors.New("not implemented")
  30. }
  31. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  32. if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
  33. return nil, errors.New("not supported model for image generation")
  34. }
  35. // convert size to aspect ratio
  36. aspectRatio := "1:1" // default aspect ratio
  37. switch request.Size {
  38. case "1024x1024":
  39. aspectRatio = "1:1"
  40. case "1024x1792":
  41. aspectRatio = "9:16"
  42. case "1792x1024":
  43. aspectRatio = "16:9"
  44. }
  45. // build gemini imagen request
  46. geminiRequest := GeminiImageRequest{
  47. Instances: []GeminiImageInstance{
  48. {
  49. Prompt: request.Prompt,
  50. },
  51. },
  52. Parameters: GeminiImageParameters{
  53. SampleCount: request.N,
  54. AspectRatio: aspectRatio,
  55. PersonGeneration: "allow_adult", // default allow adult
  56. },
  57. }
  58. return geminiRequest, nil
  59. }
  60. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  61. }
  62. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  63. if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
  64. // 新增逻辑:处理 -thinking-<budget> 格式
  65. if strings.Contains(info.UpstreamModelName, "-thinking-") {
  66. parts := strings.Split(info.UpstreamModelName, "-thinking-")
  67. info.UpstreamModelName = parts[0]
  68. } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
  69. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
  70. } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
  71. info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
  72. }
  73. }
  74. version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
  75. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  76. return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
  77. }
  78. if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
  79. strings.HasPrefix(info.UpstreamModelName, "embedding") ||
  80. strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
  81. return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
  82. }
  83. action := "generateContent"
  84. if info.IsStream {
  85. action = "streamGenerateContent?alt=sse"
  86. }
  87. return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
  88. }
  89. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  90. channel.SetupApiRequestHeader(info, c, req)
  91. req.Set("x-goog-api-key", info.ApiKey)
  92. return nil
  93. }
  94. func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
  95. if request == nil {
  96. return nil, errors.New("request is nil")
  97. }
  98. geminiRequest, err := CovertGemini2OpenAI(*request, info)
  99. if err != nil {
  100. return nil, err
  101. }
  102. return geminiRequest, nil
  103. }
  104. func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
  105. return nil, nil
  106. }
  107. func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
  108. if request.Input == nil {
  109. return nil, errors.New("input is required")
  110. }
  111. inputs := request.ParseInput()
  112. if len(inputs) == 0 {
  113. return nil, errors.New("input is empty")
  114. }
  115. // only process the first input
  116. geminiRequest := GeminiEmbeddingRequest{
  117. Content: GeminiChatContent{
  118. Parts: []GeminiPart{
  119. {
  120. Text: inputs[0],
  121. },
  122. },
  123. },
  124. }
  125. // set specific parameters for different models
  126. // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
  127. switch info.UpstreamModelName {
  128. case "text-embedding-004":
  129. // except embedding-001 supports setting `OutputDimensionality`
  130. if request.Dimensions > 0 {
  131. geminiRequest.OutputDimensionality = request.Dimensions
  132. }
  133. }
  134. return geminiRequest, nil
  135. }
  136. func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
  137. // TODO implement me
  138. return nil, errors.New("not implemented")
  139. }
  140. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  141. return channel.DoApiRequest(a, c, info, requestBody)
  142. }
  143. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
  144. if info.RelayMode == constant.RelayModeGemini {
  145. if info.IsStream {
  146. info.DisablePing = true
  147. return GeminiTextGenerationStreamHandler(c, info, resp)
  148. } else {
  149. return GeminiTextGenerationHandler(c, info, resp)
  150. }
  151. }
  152. if strings.HasPrefix(info.UpstreamModelName, "imagen") {
  153. return GeminiImageHandler(c, info, resp)
  154. }
  155. // check if the model is an embedding model
  156. if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
  157. strings.HasPrefix(info.UpstreamModelName, "embedding") ||
  158. strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
  159. return GeminiEmbeddingHandler(c, info, resp)
  160. }
  161. if info.IsStream {
  162. return GeminiChatStreamHandler(c, info, resp)
  163. } else {
  164. return GeminiChatHandler(c, info, resp)
  165. }
  166. //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
  167. // // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
  168. // if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
  169. // !strings.HasSuffix(info.OriginModelName, "-nothinking") {
  170. // thinkingModelName := info.OriginModelName + "-thinking"
  171. // if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
  172. // info.OriginModelName = thinkingModelName
  173. // }
  174. // }
  175. //}
  176. return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
  177. }
  178. func (a *Adaptor) GetModelList() []string {
  179. return ModelList
  180. }
  181. func (a *Adaptor) GetChannelName() string {
  182. return ChannelName
  183. }