adaptor.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. package gemini
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "regexp"
  10. "strings"
  11. "time"
  12. "github.com/QuantumNous/new-api/common"
  13. "github.com/QuantumNous/new-api/constant"
  14. "github.com/QuantumNous/new-api/dto"
  15. "github.com/QuantumNous/new-api/model"
  16. "github.com/QuantumNous/new-api/relay/channel"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. "github.com/QuantumNous/new-api/service"
  19. "github.com/QuantumNous/new-api/setting/model_setting"
  20. "github.com/QuantumNous/new-api/setting/system_setting"
  21. "github.com/gin-gonic/gin"
  22. "github.com/pkg/errors"
  23. )
  // VideoGenerationConfig holds the optional "parameters" object of a Veo
  // predictLongRunning request.
  // Reference: https://ai.google.dev/gemini-api/docs/video
  //
  // Every field is tagged omitempty, so only values that were actually set
  // (for example from the request metadata) are sent upstream.
  type VideoGenerationConfig struct {
  	AspectRatio      string  `json:"aspectRatio,omitempty"`      // e.g. "16:9" or "9:16"
  	DurationSeconds  float64 `json:"durationSeconds,omitempty"`  // clip length, sent as a JSON number (e.g. 4, 6 or 8)
  	NegativePrompt   string  `json:"negativePrompt,omitempty"`   // elements the model should avoid
  	PersonGeneration string  `json:"personGeneration,omitempty"` // e.g. "allow_all" (text-to-video), "allow_adult" (image-to-video)
  	Resolution       string  `json:"resolution,omitempty"`       // output video resolution
  }

  // Image is an inline image: base64-encoded bytes together with their MIME type.
  type Image struct {
  	BytesBase64Encoded string `json:"bytesBase64Encoded,omitempty"`
  	MimeType           string `json:"mimeType,omitempty"`
  }

  // VideoRequest is one element of the "instances" array: the text prompt plus
  // an optional first-frame image and an optional last-frame image.
  type VideoRequest struct {
  	Prompt    string `json:"prompt"`
  	Image     *Image `json:"image,omitempty"`
  	LastFrame *Image `json:"lastFrame,omitempty"`
  }

  // VideoPayload is the complete JSON body posted to :predictLongRunning.
  //
  // Note: encoding/json ignores omitempty on non-pointer struct fields, so
  // Parameters is always serialized (as "{}" when no option is set).
  type VideoPayload struct {
  	Instances  []VideoRequest        `json:"instances"`
  	Parameters VideoGenerationConfig `json:"parameters,omitempty"`
  }
  // submitResponse is the body returned by :predictLongRunning; Name identifies
  // the long-running operation that is polled afterwards.
  type submitResponse struct {
  	Name string `json:"name"`
  }

  // operationVideo is one element of the "videos" array of an operation
  // response: inline base64 video bytes with their MIME type and encoding.
  type operationVideo struct {
  	MimeType           string `json:"mimeType"`
  	BytesBase64Encoded string `json:"bytesBase64Encoded"`
  	Encoding           string `json:"encoding"`
  }

  // operationResponse is the long-running operation resource returned when a
  // task is polled. Done reports completion and Error.Message is set when the
  // operation failed.
  //
  // Response models several alternative places where a finished video may
  // appear (Videos, BytesBase64Encoded, Video and
  // GenerateVideoResponse.GeneratedSamples[].Video.URI) — presumably so more
  // than one upstream response shape can be decoded; confirm against the
  // upstream API before relying on any particular one.
  type operationResponse struct {
  	Name     string `json:"name"`
  	Done     bool   `json:"done"`
  	Response struct {
  		Type                  string           `json:"@type"`
  		RaiMediaFilteredCount int              `json:"raiMediaFilteredCount"`
  		Videos                []operationVideo `json:"videos"`
  		BytesBase64Encoded    string           `json:"bytesBase64Encoded"`
  		Encoding              string           `json:"encoding"`
  		Video                 string           `json:"video"`
  		GenerateVideoResponse struct {
  			GeneratedSamples []struct {
  				Video struct {
  					URI string `json:"uri"`
  				} `json:"video"`
  			} `json:"generatedSamples"`
  		} `json:"generateVideoResponse"`
  	} `json:"response"`
  	Error struct {
  		Message string `json:"message"`
  	} `json:"error"`
  }
  77. // ============================
  78. // Adaptor implementation
  79. // ============================
  // TaskAdaptor relays asynchronous video-generation tasks to the Gemini (Veo)
  // API. Init fills its fields from the relay info of the current request.
  type TaskAdaptor struct {
  	ChannelType int    // channel type copied from the relay info
  	apiKey      string // sent upstream in the x-goog-api-key header
  	baseURL     string // channel base URL used to build the submit URL
  }
  85. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  86. a.ChannelType = info.ChannelType
  87. a.baseURL = info.ChannelBaseUrl
  88. a.apiKey = info.ApiKey
  89. }
  90. // ValidateRequestAndSetAction parses body, validates fields and sets default action.
  91. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  92. return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
  93. }
  94. // BuildRequestURL constructs the upstream URL.
  95. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  96. modelName := info.OriginModelName
  97. version := model_setting.GetGeminiVersionSetting(modelName)
  98. return fmt.Sprintf(
  99. "%s/%s/models/%s:predictLongRunning",
  100. a.baseURL,
  101. version,
  102. modelName,
  103. ), nil
  104. }
  105. // BuildRequestHeader sets required headers.
  106. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  107. req.Header.Set("Content-Type", "application/json")
  108. req.Header.Set("Accept", "application/json")
  109. req.Header.Set("x-goog-api-key", a.apiKey)
  110. return nil
  111. }
  112. // BuildRequestBody converts request into Gemini specific format.
  113. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  114. v, ok := c.Get("task_request")
  115. if !ok {
  116. return nil, fmt.Errorf("request not found in context")
  117. }
  118. req, ok := v.(relaycommon.TaskSubmitReq)
  119. if !ok {
  120. return nil, fmt.Errorf("unexpected task_request type")
  121. }
  122. // Create structured video generation request
  123. body := VideoPayload{
  124. Instances: []VideoRequest{
  125. {Prompt: req.Prompt},
  126. },
  127. Parameters: VideoGenerationConfig{},
  128. }
  129. if len(req.Images) > 0 {
  130. body.Instances[0].Image = a.convertImage(req.Images[0])
  131. }
  132. if len(req.Images) > 1 {
  133. body.Instances[0].LastFrame = a.convertImage(req.Images[1])
  134. }
  135. // Parse metadata for additional configuration
  136. metadata := req.Metadata
  137. medaBytes, err := json.Marshal(metadata)
  138. if err != nil {
  139. return nil, errors.Wrap(err, "metadata marshal metadata failed")
  140. }
  141. err = json.Unmarshal(medaBytes, &body.Parameters)
  142. if err != nil {
  143. return nil, errors.Wrap(err, "unmarshal metadata failed")
  144. }
  145. data, err := json.Marshal(body)
  146. if err != nil {
  147. return nil, err
  148. }
  149. return bytes.NewReader(data), nil
  150. }
  151. // DoRequest delegates to common helper.
  152. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  153. return channel.DoTaskApiRequest(a, c, info, requestBody)
  154. }
  155. // DoResponse handles upstream response, returns taskID etc.
  156. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  157. responseBody, err := io.ReadAll(resp.Body)
  158. if err != nil {
  159. return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  160. }
  161. _ = resp.Body.Close()
  162. var s submitResponse
  163. if err := json.Unmarshal(responseBody, &s); err != nil {
  164. return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
  165. }
  166. if strings.TrimSpace(s.Name) == "" {
  167. return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
  168. }
  169. taskID = encodeLocalTaskID(s.Name)
  170. ov := dto.NewOpenAIVideo()
  171. ov.ID = taskID
  172. ov.TaskID = taskID
  173. ov.CreatedAt = time.Now().Unix()
  174. ov.Model = info.OriginModelName
  175. c.JSON(http.StatusOK, ov)
  176. return taskID, responseBody, nil
  177. }
  178. func (a *TaskAdaptor) GetModelList() []string {
  179. return []string{"veo-3.0-generate-001", "veo-3.1-generate-preview", "veo-3.1-fast-generate-preview"}
  180. }
  181. func (a *TaskAdaptor) GetChannelName() string {
  182. return "gemini"
  183. }
  184. // FetchTask fetch task status
  185. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
  186. taskID, ok := body["task_id"].(string)
  187. if !ok {
  188. return nil, fmt.Errorf("invalid task_id")
  189. }
  190. upstreamName, err := decodeLocalTaskID(taskID)
  191. if err != nil {
  192. return nil, fmt.Errorf("decode task_id failed: %w", err)
  193. }
  194. // For Gemini API, we use GET request to the operations endpoint
  195. version := model_setting.GetGeminiVersionSetting("default")
  196. url := fmt.Sprintf("%s/%s/%s", baseUrl, version, upstreamName)
  197. req, err := http.NewRequest(http.MethodGet, url, nil)
  198. if err != nil {
  199. return nil, err
  200. }
  201. req.Header.Set("Accept", "application/json")
  202. req.Header.Set("x-goog-api-key", key)
  203. return service.GetHttpClient().Do(req)
  204. }
  205. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  206. var op operationResponse
  207. if err := json.Unmarshal(respBody, &op); err != nil {
  208. return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
  209. }
  210. ti := &relaycommon.TaskInfo{}
  211. if op.Error.Message != "" {
  212. ti.Status = model.TaskStatusFailure
  213. ti.Reason = op.Error.Message
  214. ti.Progress = "100%"
  215. return ti, nil
  216. }
  217. if !op.Done {
  218. ti.Status = model.TaskStatusInProgress
  219. ti.Progress = "50%"
  220. return ti, nil
  221. }
  222. ti.Status = model.TaskStatusSuccess
  223. ti.Progress = "100%"
  224. taskID := encodeLocalTaskID(op.Name)
  225. ti.TaskID = taskID
  226. ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
  227. // Extract URL from generateVideoResponse if available
  228. if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
  229. if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" {
  230. ti.RemoteUrl = uri
  231. }
  232. }
  233. return ti, nil
  234. }
  235. func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
  236. upstreamName, err := decodeLocalTaskID(task.TaskID)
  237. if err != nil {
  238. upstreamName = ""
  239. }
  240. modelName := extractModelFromOperationName(upstreamName)
  241. if strings.TrimSpace(modelName) == "" {
  242. modelName = "veo-3.0-generate-001"
  243. }
  244. video := dto.NewOpenAIVideo()
  245. video.ID = task.TaskID
  246. video.Model = modelName
  247. video.Status = task.Status.ToVideoStatus()
  248. video.SetProgressStr(task.Progress)
  249. video.CreatedAt = task.CreatedAt
  250. if task.FinishTime > 0 {
  251. video.CompletedAt = task.FinishTime
  252. } else if task.UpdatedAt > 0 {
  253. video.CompletedAt = task.UpdatedAt
  254. }
  255. return common.Marshal(video)
  256. }
  257. func (a *TaskAdaptor) convertImage(imageStr string) *Image {
  258. if strings.TrimSpace(imageStr) == "" {
  259. return nil
  260. }
  261. img := &Image{
  262. MimeType: "image/png",
  263. BytesBase64Encoded: imageStr,
  264. }
  265. if strings.HasPrefix(imageStr, "data:image/") {
  266. parts := strings.Split(imageStr, ";base64,")
  267. if len(parts) == 2 {
  268. img.MimeType = strings.TrimPrefix(parts[0], "data:")
  269. img.BytesBase64Encoded = parts[1]
  270. }
  271. } else if strings.HasPrefix(imageStr, "http") {
  272. mimeType, data, err := service.GetImageFromUrl(imageStr)
  273. if err == nil {
  274. img.MimeType = mimeType
  275. img.BytesBase64Encoded = data
  276. }
  277. }
  278. return img
  279. }
  280. // ============================
  281. // helpers
  282. // ============================
  // encodeLocalTaskID maps an upstream operation name to the task ID exposed to
  // clients: the name's bytes in unpadded URL-safe base64, so the ID can sit in
  // a URL path without escaping. decodeLocalTaskID performs the inverse.
  func encodeLocalTaskID(name string) string {
  	enc := base64.RawURLEncoding
  	return enc.EncodeToString([]byte(name))
  }
  // decodeLocalTaskID recovers the upstream operation name from a local task
  // ID produced by encodeLocalTaskID. Input that is not valid unpadded URL-safe
  // base64 yields "" and the decoding error.
  func decodeLocalTaskID(local string) (string, error) {
  	raw, err := base64.RawURLEncoding.DecodeString(local)
  	if err == nil {
  		return string(raw), nil
  	}
  	return "", err
  }
  // modelRe captures the model segment of an operation name such as
  // "models/veo-3.0-generate-001/operations/abc123".
  var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)

  // extractModelFromOperationName returns the model ID embedded in a
  // long-running operation name, or "" when none can be found.
  //
  // modelRe handles the usual single-segment model ID anywhere in the name. As
  // a fallback, the non-empty text between the first "models/" and the next
  // "/operations/" is returned, which also covers IDs that contain '/'.
  func extractModelFromOperationName(name string) string {
  	if name == "" {
  		return ""
  	}
  	if m := modelRe.FindStringSubmatch(name); m != nil {
  		return m[1]
  	}
  	_, afterModels, found := strings.Cut(name, "models/")
  	if !found {
  		return ""
  	}
  	modelID, _, found := strings.Cut(afterModels, "/operations/")
  	if !found {
  		return ""
  	}
  	return modelID
  }