video_proxy.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package controller
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "strings"
  10. "time"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/logger"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/service"
  15. "github.com/gin-gonic/gin"
  16. )
  17. // videoProxyError returns a standardized OpenAI-style error response.
  18. func videoProxyError(c *gin.Context, status int, errType, message string) {
  19. c.JSON(status, gin.H{
  20. "error": gin.H{
  21. "message": message,
  22. "type": errType,
  23. },
  24. })
  25. }
  26. func VideoProxy(c *gin.Context) {
  27. taskID := c.Param("task_id")
  28. if taskID == "" {
  29. videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
  30. return
  31. }
  32. task, exists, err := model.GetByOnlyTaskId(taskID)
  33. if err != nil {
  34. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  35. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
  36. return
  37. }
  38. if !exists || task == nil {
  39. videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
  40. return
  41. }
  42. if task.Status != model.TaskStatusSuccess {
  43. videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
  44. fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
  45. return
  46. }
  47. channel, err := model.CacheGetChannel(task.ChannelId)
  48. if err != nil {
  49. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
  50. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
  51. return
  52. }
  53. baseURL := channel.GetBaseURL()
  54. if baseURL == "" {
  55. baseURL = "https://api.openai.com"
  56. }
  57. var videoURL string
  58. proxy := channel.GetSetting().Proxy
  59. client, err := service.GetHttpClientWithProxy(proxy)
  60. if err != nil {
  61. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
  62. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
  63. return
  64. }
  65. ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
  66. defer cancel()
  67. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
  68. if err != nil {
  69. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  70. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
  71. return
  72. }
  73. switch channel.Type {
  74. case constant.ChannelTypeGemini:
  75. apiKey := task.PrivateData.Key
  76. if apiKey == "" {
  77. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  78. videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
  79. return
  80. }
  81. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  82. if err != nil {
  83. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  84. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
  85. return
  86. }
  87. req.Header.Set("x-goog-api-key", apiKey)
  88. case constant.ChannelTypeVertexAi:
  89. videoURL, err = getVertexVideoURL(channel, task)
  90. if err != nil {
  91. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Vertex video URL for task %s: %s", taskID, err.Error()))
  92. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Vertex video URL")
  93. return
  94. }
  95. case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
  96. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
  97. req.Header.Set("Authorization", "Bearer "+channel.Key)
  98. default:
  99. // Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
  100. videoURL = task.GetResultURL()
  101. }
  102. videoURL = strings.TrimSpace(videoURL)
  103. if videoURL == "" {
  104. logger.LogError(c.Request.Context(), fmt.Sprintf("Video URL is empty for task %s", taskID))
  105. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  106. return
  107. }
  108. if strings.HasPrefix(videoURL, "data:") {
  109. if err := writeVideoDataURL(c, videoURL); err != nil {
  110. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to decode video data URL for task %s: %s", taskID, err.Error()))
  111. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  112. }
  113. return
  114. }
  115. req.URL, err = url.Parse(videoURL)
  116. if err != nil {
  117. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  118. videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
  119. return
  120. }
  121. resp, err := client.Do(req)
  122. if err != nil {
  123. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  124. videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
  125. return
  126. }
  127. defer resp.Body.Close()
  128. if resp.StatusCode != http.StatusOK {
  129. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  130. videoProxyError(c, http.StatusBadGateway, "server_error",
  131. fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
  132. return
  133. }
  134. for key, values := range resp.Header {
  135. for _, value := range values {
  136. c.Writer.Header().Add(key, value)
  137. }
  138. }
  139. c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
  140. c.Writer.WriteHeader(resp.StatusCode)
  141. if _, err = io.Copy(c.Writer, resp.Body); err != nil {
  142. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  143. }
  144. }
  145. func writeVideoDataURL(c *gin.Context, dataURL string) error {
  146. parts := strings.SplitN(dataURL, ",", 2)
  147. if len(parts) != 2 {
  148. return fmt.Errorf("invalid data url")
  149. }
  150. header := parts[0]
  151. payload := parts[1]
  152. if !strings.HasPrefix(header, "data:") || !strings.Contains(header, ";base64") {
  153. return fmt.Errorf("unsupported data url")
  154. }
  155. mimeType := strings.TrimPrefix(header, "data:")
  156. mimeType = strings.TrimSuffix(mimeType, ";base64")
  157. if mimeType == "" {
  158. mimeType = "video/mp4"
  159. }
  160. videoBytes, err := base64.StdEncoding.DecodeString(payload)
  161. if err != nil {
  162. videoBytes, err = base64.RawStdEncoding.DecodeString(payload)
  163. if err != nil {
  164. return err
  165. }
  166. }
  167. c.Writer.Header().Set("Content-Type", mimeType)
  168. c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
  169. c.Writer.WriteHeader(http.StatusOK)
  170. _, err = c.Writer.Write(videoBytes)
  171. return err
  172. }