video_proxy.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package controller
import (
	"fmt"
	"io"
	"net/http"
	"net/url"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"

	"github.com/gin-gonic/gin"
)
  13. func VideoProxy(c *gin.Context) {
  14. taskID := c.Param("task_id")
  15. if taskID == "" {
  16. c.JSON(http.StatusBadRequest, gin.H{
  17. "error": gin.H{
  18. "message": "task_id is required",
  19. "type": "invalid_request_error",
  20. },
  21. })
  22. return
  23. }
  24. task, exists, err := model.GetByOnlyTaskId(taskID)
  25. if err != nil {
  26. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  27. c.JSON(http.StatusInternalServerError, gin.H{
  28. "error": gin.H{
  29. "message": "Failed to query task",
  30. "type": "server_error",
  31. },
  32. })
  33. return
  34. }
  35. if !exists || task == nil {
  36. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
  37. c.JSON(http.StatusNotFound, gin.H{
  38. "error": gin.H{
  39. "message": "Task not found",
  40. "type": "invalid_request_error",
  41. },
  42. })
  43. return
  44. }
  45. if task.Status != model.TaskStatusSuccess {
  46. c.JSON(http.StatusBadRequest, gin.H{
  47. "error": gin.H{
  48. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  49. "type": "invalid_request_error",
  50. },
  51. })
  52. return
  53. }
  54. channel, err := model.CacheGetChannel(task.ChannelId)
  55. if err != nil {
  56. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
  57. c.JSON(http.StatusInternalServerError, gin.H{
  58. "error": gin.H{
  59. "message": "Failed to retrieve channel information",
  60. "type": "server_error",
  61. },
  62. })
  63. return
  64. }
  65. baseURL := channel.GetBaseURL()
  66. if baseURL == "" {
  67. baseURL = "https://api.openai.com"
  68. }
  69. var videoURL string
  70. client := &http.Client{
  71. Timeout: 60 * time.Second,
  72. }
  73. req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
  74. if err != nil {
  75. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  76. c.JSON(http.StatusInternalServerError, gin.H{
  77. "error": gin.H{
  78. "message": "Failed to create proxy request",
  79. "type": "server_error",
  80. },
  81. })
  82. return
  83. }
  84. if channel.Type == constant.ChannelTypeGemini {
  85. apiKey := task.PrivateData.Key
  86. if apiKey == "" {
  87. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  88. c.JSON(http.StatusInternalServerError, gin.H{
  89. "error": gin.H{
  90. "message": "API key not stored for task",
  91. "type": "server_error",
  92. },
  93. })
  94. return
  95. }
  96. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  97. if err != nil {
  98. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  99. c.JSON(http.StatusBadGateway, gin.H{
  100. "error": gin.H{
  101. "message": "Failed to resolve Gemini video URL",
  102. "type": "server_error",
  103. },
  104. })
  105. return
  106. }
  107. req.Header.Set("x-goog-api-key", apiKey)
  108. } else {
  109. // Default (Sora, etc.): Use original logic
  110. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  111. req.Header.Set("Authorization", "Bearer "+channel.Key)
  112. }
  113. req.URL, err = url.Parse(videoURL)
  114. if err != nil {
  115. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  116. c.JSON(http.StatusInternalServerError, gin.H{
  117. "error": gin.H{
  118. "message": "Failed to create proxy request",
  119. "type": "server_error",
  120. },
  121. })
  122. return
  123. }
  124. resp, err := client.Do(req)
  125. if err != nil {
  126. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  127. c.JSON(http.StatusBadGateway, gin.H{
  128. "error": gin.H{
  129. "message": "Failed to fetch video content",
  130. "type": "server_error",
  131. },
  132. })
  133. return
  134. }
  135. defer resp.Body.Close()
  136. if resp.StatusCode != http.StatusOK {
  137. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  138. c.JSON(http.StatusBadGateway, gin.H{
  139. "error": gin.H{
  140. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  141. "type": "server_error",
  142. },
  143. })
  144. return
  145. }
  146. for key, values := range resp.Header {
  147. for _, value := range values {
  148. c.Writer.Header().Add(key, value)
  149. }
  150. }
  151. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  152. c.Writer.WriteHeader(resp.StatusCode)
  153. _, err = io.Copy(c.Writer, resp.Body)
  154. if err != nil {
  155. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  156. }
  157. }