relay_gemini_file.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package controller
import (
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/relay/channel/gemini"
	"github.com/gin-gonic/gin"
)
  12. // RelayGeminiFileUpload handles file upload to Gemini File API
  13. func RelayGeminiFileUpload(c *gin.Context) {
  14. // Parse multipart form
  15. form, err := common.ParseMultipartFormReusable(c)
  16. if err != nil {
  17. logger.LogError(c, fmt.Sprintf("failed to parse multipart form: %s", err.Error()))
  18. c.JSON(http.StatusBadRequest, gin.H{
  19. "error": gin.H{
  20. "message": fmt.Sprintf("failed to parse multipart form: %s", err.Error()),
  21. "type": "invalid_request_error",
  22. "code": "invalid_multipart_form",
  23. },
  24. })
  25. return
  26. }
  27. defer form.RemoveAll()
  28. // Get API key from channel context (set by setupGeminiFileChannel)
  29. apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
  30. if apiKey == "" {
  31. logger.LogError(c, "Failed to get Gemini channel API key")
  32. c.JSON(http.StatusServiceUnavailable, gin.H{
  33. "error": gin.H{
  34. "message": "No available Gemini channel found",
  35. "type": "service_unavailable_error",
  36. "code": "no_available_channel",
  37. },
  38. })
  39. return
  40. }
  41. // Rebuild multipart form for upstream request
  42. body, contentType, err := gemini.RebuildMultipartForm(form)
  43. if err != nil {
  44. logger.LogError(c, fmt.Sprintf("failed to rebuild multipart form: %s", err.Error()))
  45. c.JSON(http.StatusInternalServerError, gin.H{
  46. "error": gin.H{
  47. "message": fmt.Sprintf("failed to rebuild multipart form: %s", err.Error()),
  48. "type": "internal_error",
  49. "code": "form_rebuild_error",
  50. },
  51. })
  52. return
  53. }
  54. // Build upstream URL
  55. url := gemini.BuildGeminiFileURL("/upload/v1beta/files")
  56. // Prepare headers
  57. headers := map[string]string{
  58. "Content-Type": contentType,
  59. "x-goog-api-key": apiKey,
  60. }
  61. // Forward request to Gemini
  62. err = gemini.ForwardGeminiFileRequest(c, http.MethodPost, url, body, headers)
  63. if err != nil {
  64. logger.LogError(c, fmt.Sprintf("failed to forward file upload request: %s", err.Error()))
  65. // Error response already sent by ForwardGeminiFileRequest
  66. return
  67. }
  68. }
  69. // RelayGeminiFileList lists files from Gemini File API
  70. func RelayGeminiFileList(c *gin.Context) {
  71. // Get API key from channel context
  72. apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
  73. if apiKey == "" {
  74. logger.LogError(c, "API key not found in context")
  75. c.JSON(http.StatusUnauthorized, gin.H{
  76. "error": gin.H{
  77. "message": "API key not found",
  78. "type": "authentication_error",
  79. "code": "invalid_api_key",
  80. },
  81. })
  82. return
  83. }
  84. // Build upstream URL with query parameters
  85. url := gemini.BuildGeminiFileURL("/v1beta/files")
  86. // Add query parameters if present
  87. queryParams := []string{}
  88. if pageSize := c.Query("pageSize"); pageSize != "" {
  89. queryParams = append(queryParams, fmt.Sprintf("pageSize=%s", pageSize))
  90. }
  91. if pageToken := c.Query("pageToken"); pageToken != "" {
  92. queryParams = append(queryParams, fmt.Sprintf("pageToken=%s", pageToken))
  93. }
  94. if len(queryParams) > 0 {
  95. url = fmt.Sprintf("%s?%s", url, strings.Join(queryParams, "&"))
  96. }
  97. // Prepare headers
  98. headers := map[string]string{
  99. "x-goog-api-key": apiKey,
  100. }
  101. // Forward request to Gemini
  102. err := gemini.ForwardGeminiFileRequest(c, http.MethodGet, url, nil, headers)
  103. if err != nil {
  104. logger.LogError(c, fmt.Sprintf("failed to forward file list request: %s", err.Error()))
  105. return
  106. }
  107. }