package gemini import ( "bytes" "fmt" "io" "net/http" "regexp" "strings" "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) // ============================ // Request / Response structures // ============================ // GeminiVideoGenerationConfig represents the video generation configuration // Based on: https://ai.google.dev/gemini-api/docs/video type GeminiVideoGenerationConfig struct { AspectRatio string `json:"aspectRatio,omitempty"` // "16:9" or "9:16" DurationSeconds float64 `json:"durationSeconds,omitempty"` // 4, 6, or 8 (as number) NegativePrompt string `json:"negativePrompt,omitempty"` // unwanted elements PersonGeneration string `json:"personGeneration,omitempty"` // "allow_all" for text-to-video, "allow_adult" for image-to-video Resolution string `json:"resolution,omitempty"` // video resolution } // GeminiVideoRequest represents a single video generation instance type GeminiVideoRequest struct { Prompt string `json:"prompt"` } // GeminiVideoPayload represents the complete video generation request payload type GeminiVideoPayload struct { Instances []GeminiVideoRequest `json:"instances"` Parameters GeminiVideoGenerationConfig `json:"parameters,omitempty"` } type submitResponse struct { Name string `json:"name"` } type operationVideo struct { MimeType string `json:"mimeType"` BytesBase64Encoded string `json:"bytesBase64Encoded"` Encoding string `json:"encoding"` } type operationResponse struct { Name string `json:"name"` Done bool `json:"done"` Response struct { Type string `json:"@type"` RaiMediaFilteredCount int `json:"raiMediaFilteredCount"` Videos []operationVideo `json:"videos"` BytesBase64Encoded string `json:"bytesBase64Encoded"` Encoding string `json:"encoding"` Video string `json:"video"` GenerateVideoResponse struct { GeneratedSamples []struct { Video struct { URI string `json:"uri"` } `json:"video"` } `json:"generatedSamples"` } `json:"generateVideoResponse"` } `json:"response"` Error struct { Message string `json:"message"` } `json:"error"` } // ============================ // Adaptor implementation // ============================ type TaskAdaptor struct { taskcommon.BaseBilling ChannelType int apiKey string baseURL string } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey } // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Use the standard validation method for TaskSubmitReq return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) } // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { modelName := info.OriginModelName version := model_setting.GetGeminiVersionSetting(modelName) return fmt.Sprintf( "%s/%s/models/%s:predictLongRunning", a.baseURL, version, modelName, ), nil } // BuildRequestHeader sets required headers. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("x-goog-api-key", a.apiKey) return nil } // BuildRequestBody converts request into Gemini specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") if !ok { return nil, fmt.Errorf("request not found in context") } req, ok := v.(relaycommon.TaskSubmitReq) if !ok { return nil, fmt.Errorf("unexpected task_request type") } // Create structured video generation request body := GeminiVideoPayload{ Instances: []GeminiVideoRequest{ {Prompt: req.Prompt}, }, Parameters: GeminiVideoGenerationConfig{}, } metadata := req.Metadata if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil { return nil, errors.Wrap(err, "unmarshal metadata failed") } data, err := common.Marshal(body) if err != nil { return nil, err } return bytes.NewReader(data), nil } // DoRequest delegates to common helper. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } // DoResponse handles upstream response, returns taskID etc. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } _ = resp.Body.Close() var s submitResponse if err := common.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } taskID = taskcommon.EncodeLocalTaskID(s.Name) ov := dto.NewOpenAIVideo() ov.ID = info.PublicTaskID ov.TaskID = info.PublicTaskID ov.CreatedAt = time.Now().Unix() ov.Model = info.OriginModelName c.JSON(http.StatusOK, ov) return taskID, responseBody, nil } func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001", "veo-3.1-generate-preview", "veo-3.1-fast-generate-preview"} } func (a *TaskAdaptor) GetChannelName() string { return "gemini" } // FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") } upstreamName, err := taskcommon.DecodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } // For Gemini API, we use GET request to the operations endpoint version := model_setting.GetGeminiVersionSetting("default") url := fmt.Sprintf("%s/%s/%s", baseUrl, version, upstreamName) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } req.Header.Set("Accept", "application/json") req.Header.Set("x-goog-api-key", key) client, err := service.GetHttpClientWithProxy(proxy) if err != nil { return nil, fmt.Errorf("new proxy http client failed: %w", err) } return client.Do(req) } func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse if err := common.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} if op.Error.Message != "" { ti.Status = model.TaskStatusFailure ti.Reason = op.Error.Message ti.Progress = "100%" return ti, nil } if !op.Done { ti.Status = model.TaskStatusInProgress ti.Progress = "50%" return ti, nil } ti.Status = model.TaskStatusSuccess ti.Progress = "100%" ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name) // Url intentionally left empty — the caller constructs the proxy URL using the public task ID // Extract URL from generateVideoResponse if available if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 { if uri := op.Response.GenerateVideoResponse.GeneratedSamples[0].Video.URI; uri != "" { ti.RemoteUrl = uri } } return ti, nil } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { // Use GetUpstreamTaskID() to get the real upstream operation name for model extraction. // task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name. upstreamTaskID := task.GetUpstreamTaskID() upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID) if err != nil { upstreamName = "" } modelName := extractModelFromOperationName(upstreamName) if strings.TrimSpace(modelName) == "" { modelName = "veo-3.0-generate-001" } video := dto.NewOpenAIVideo() video.ID = task.TaskID video.Model = modelName video.Status = task.Status.ToVideoStatus() video.SetProgressStr(task.Progress) video.CreatedAt = task.CreatedAt if task.FinishTime > 0 { video.CompletedAt = task.FinishTime } else if task.UpdatedAt > 0 { video.CompletedAt = task.UpdatedAt } return common.Marshal(video) } // ============================ // helpers // ============================ var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { if name == "" { return "" } if m := modelRe.FindStringSubmatch(name); len(m) == 2 { return m[1] } if idx := strings.Index(name, "models/"); idx >= 0 { s := name[idx+len("models/"):] if p := strings.Index(s, "/operations/"); p > 0 { return s[:p] } } return "" }