video_proxy.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package controller
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "time"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/logger"
  11. "github.com/QuantumNous/new-api/model"
  12. "github.com/QuantumNous/new-api/service"
  13. "github.com/samber/lo"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func VideoProxy(c *gin.Context) {
  17. taskID := c.Param("task_id")
  18. if taskID == "" {
  19. c.JSON(http.StatusBadRequest, gin.H{
  20. "error": gin.H{
  21. "message": "task_id is required",
  22. "type": "invalid_request_error",
  23. },
  24. })
  25. return
  26. }
  27. task, exists, err := model.GetByOnlyTaskId(taskID)
  28. if err != nil {
  29. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  30. c.JSON(http.StatusInternalServerError, gin.H{
  31. "error": gin.H{
  32. "message": "Failed to query task",
  33. "type": "server_error",
  34. },
  35. })
  36. return
  37. }
  38. if !exists || task == nil {
  39. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
  40. c.JSON(http.StatusNotFound, gin.H{
  41. "error": gin.H{
  42. "message": "Task not found",
  43. "type": "invalid_request_error",
  44. },
  45. })
  46. return
  47. }
  48. if task.Status != model.TaskStatusSuccess {
  49. c.JSON(http.StatusBadRequest, gin.H{
  50. "error": gin.H{
  51. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  52. "type": "invalid_request_error",
  53. },
  54. })
  55. return
  56. }
  57. channel, err := model.CacheGetChannel(task.ChannelId)
  58. if err != nil {
  59. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
  60. c.JSON(http.StatusInternalServerError, gin.H{
  61. "error": gin.H{
  62. "message": "Failed to retrieve channel information",
  63. "type": "server_error",
  64. },
  65. })
  66. return
  67. }
  68. baseURL := channel.GetBaseURL()
  69. if baseURL == "" {
  70. baseURL = "https://api.openai.com"
  71. }
  72. var videoURL string
  73. proxy := channel.GetSetting().Proxy
  74. client, err := service.GetHttpClientWithProxy(proxy)
  75. if err != nil {
  76. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
  77. c.JSON(http.StatusInternalServerError, gin.H{
  78. "error": gin.H{
  79. "message": "Failed to create proxy client",
  80. "type": "server_error",
  81. },
  82. })
  83. return
  84. }
  85. ctx, cancel := context.WithTimeout(c.Request.Context(), 60*time.Second)
  86. defer cancel()
  87. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
  88. if err != nil {
  89. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  90. c.JSON(http.StatusInternalServerError, gin.H{
  91. "error": gin.H{
  92. "message": "Failed to create proxy request",
  93. "type": "server_error",
  94. },
  95. })
  96. return
  97. }
  98. switch channel.Type {
  99. case constant.ChannelTypeGemini:
  100. apiKey := task.PrivateData.Key
  101. if apiKey == "" {
  102. logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
  103. c.JSON(http.StatusInternalServerError, gin.H{
  104. "error": gin.H{
  105. "message": "API key not stored for task",
  106. "type": "server_error",
  107. },
  108. })
  109. return
  110. }
  111. videoURL, err = getGeminiVideoURL(channel, task, apiKey)
  112. if err != nil {
  113. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
  114. c.JSON(http.StatusBadGateway, gin.H{
  115. "error": gin.H{
  116. "message": "Failed to resolve Gemini video URL",
  117. "type": "server_error",
  118. },
  119. })
  120. return
  121. }
  122. req.Header.Set("x-goog-api-key", apiKey)
  123. case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
  124. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  125. req.Header.Set("Authorization", "Bearer "+channel.Key)
  126. default:
  127. videoURL = lo.Ternary(task.Url != "", task.Url, task.FailReason)
  128. }
  129. req.URL, err = url.Parse(videoURL)
  130. if err != nil {
  131. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  132. c.JSON(http.StatusInternalServerError, gin.H{
  133. "error": gin.H{
  134. "message": "Failed to create proxy request",
  135. "type": "server_error",
  136. },
  137. })
  138. return
  139. }
  140. resp, err := client.Do(req)
  141. if err != nil {
  142. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  143. c.JSON(http.StatusBadGateway, gin.H{
  144. "error": gin.H{
  145. "message": "Failed to fetch video content",
  146. "type": "server_error",
  147. },
  148. })
  149. return
  150. }
  151. defer resp.Body.Close()
  152. if resp.StatusCode != http.StatusOK {
  153. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  154. c.JSON(http.StatusBadGateway, gin.H{
  155. "error": gin.H{
  156. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  157. "type": "server_error",
  158. },
  159. })
  160. return
  161. }
  162. for key, values := range resp.Header {
  163. for _, value := range values {
  164. c.Writer.Header().Add(key, value)
  165. }
  166. }
  167. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  168. c.Writer.WriteHeader(resp.StatusCode)
  169. _, err = io.Copy(c.Writer, resp.Body)
  170. if err != nil {
  171. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  172. }
  173. }