video_proxy.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. package controller
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "net/url"
  7. "time"
  8. "github.com/QuantumNous/new-api/constant"
  9. "github.com/QuantumNous/new-api/logger"
  10. "github.com/QuantumNous/new-api/model"
  11. "github.com/gin-gonic/gin"
  12. )
  13. func VideoProxy(c *gin.Context) {
  14. taskID := c.Param("task_id")
  15. if taskID == "" {
  16. c.JSON(http.StatusBadRequest, gin.H{
  17. "error": gin.H{
  18. "message": "task_id is required",
  19. "type": "invalid_request_error",
  20. },
  21. })
  22. return
  23. }
  24. task, exists, err := model.GetByOnlyTaskId(taskID)
  25. if err != nil {
  26. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
  27. c.JSON(http.StatusInternalServerError, gin.H{
  28. "error": gin.H{
  29. "message": "Failed to query task",
  30. "type": "server_error",
  31. },
  32. })
  33. return
  34. }
  35. if !exists || task == nil {
  36. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
  37. c.JSON(http.StatusNotFound, gin.H{
  38. "error": gin.H{
  39. "message": "Task not found",
  40. "type": "invalid_request_error",
  41. },
  42. })
  43. return
  44. }
  45. if task.Status != model.TaskStatusSuccess {
  46. c.JSON(http.StatusBadRequest, gin.H{
  47. "error": gin.H{
  48. "message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
  49. "type": "invalid_request_error",
  50. },
  51. })
  52. return
  53. }
  54. channel, err := model.CacheGetChannel(task.ChannelId)
  55. if err != nil {
  56. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
  57. c.JSON(http.StatusInternalServerError, gin.H{
  58. "error": gin.H{
  59. "message": "Failed to retrieve channel information",
  60. "type": "server_error",
  61. },
  62. })
  63. return
  64. }
  65. baseURL := channel.GetBaseURL()
  66. if baseURL == "" {
  67. baseURL = "https://api.openai.com"
  68. }
  69. var videoURL string
  70. client := &http.Client{
  71. Timeout: 60 * time.Second,
  72. }
  73. req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, "", nil)
  74. if err != nil {
  75. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
  76. c.JSON(http.StatusInternalServerError, gin.H{
  77. "error": gin.H{
  78. "message": "Failed to create proxy request",
  79. "type": "server_error",
  80. },
  81. })
  82. return
  83. }
  84. if channel.Type == constant.ChannelTypeGemini {
  85. videoURL = fmt.Sprintf("%s&key=%s", c.Query("url"), channel.Key)
  86. req.Header.Set("x-goog-api-key", channel.Key)
  87. } else {
  88. // Default (Sora, etc.): Use original logic
  89. videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
  90. req.Header.Set("Authorization", "Bearer "+channel.Key)
  91. }
  92. req.URL, err = url.Parse(videoURL)
  93. if err != nil {
  94. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
  95. c.JSON(http.StatusInternalServerError, gin.H{
  96. "error": gin.H{
  97. "message": "Failed to create proxy request",
  98. "type": "server_error",
  99. },
  100. })
  101. return
  102. }
  103. resp, err := client.Do(req)
  104. if err != nil {
  105. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
  106. c.JSON(http.StatusBadGateway, gin.H{
  107. "error": gin.H{
  108. "message": "Failed to fetch video content",
  109. "type": "server_error",
  110. },
  111. })
  112. return
  113. }
  114. defer resp.Body.Close()
  115. if resp.StatusCode != http.StatusOK {
  116. logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
  117. c.JSON(http.StatusBadGateway, gin.H{
  118. "error": gin.H{
  119. "message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
  120. "type": "server_error",
  121. },
  122. })
  123. return
  124. }
  125. for key, values := range resp.Header {
  126. for _, value := range values {
  127. c.Writer.Header().Add(key, value)
  128. }
  129. }
  130. c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
  131. c.Writer.WriteHeader(resp.StatusCode)
  132. _, err = io.Copy(c.Writer, resp.Body)
  133. if err != nil {
  134. logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
  135. }
  136. }