瀏覽代碼

update gemini file api

supeng 1 月之前
父節點
當前提交
248d82bb8d
共有 5 個文件被更改,包括 572 次插入0 次删除
  1. 212 0
      controller/relay_gemini_file.go
  2. 35 0
      dto/gemini.go
  3. 157 0
      relay/channel/gemini/file_helper.go
  4. 14 0
      router/relay-router.go
  5. 154 0
      test_gemini_file_api.sh

+ 212 - 0
controller/relay_gemini_file.go

@@ -0,0 +1,212 @@
+package controller
+
+import (
+	"fmt"
+	"net/http"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/constant"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/relay/channel/gemini"
+
+	"github.com/gin-gonic/gin"
+)
+
+// RelayGeminiFileUpload relays a multipart file upload to the Gemini File API.
+//
+// The incoming form is parsed, re-encoded with gemini.RebuildMultipartForm and
+// POSTed to the upstream "/upload/v1beta/files" path, authenticated with the
+// channel key read from the request context. A form that cannot be parsed is
+// answered with 400, a missing channel key with 401, and a re-encoding failure
+// with 500.
+func RelayGeminiFileUpload(c *gin.Context) {
+	// respondError writes the JSON error envelope shared by every failure path below.
+	respondError := func(status int, errType, code, message string) {
+		c.JSON(status, gin.H{
+			"error": gin.H{
+				"message": message,
+				"type":    errType,
+				"code":    code,
+			},
+		})
+	}
+
+	form, err := common.ParseMultipartFormReusable(c)
+	if err != nil {
+		reason := fmt.Sprintf("failed to parse multipart form: %s", err.Error())
+		logger.LogError(c, reason)
+		respondError(http.StatusBadRequest, "invalid_request_error", "invalid_multipart_form", reason)
+		return
+	}
+	// Drop any temporary files the multipart parser created once the relay is done.
+	defer form.RemoveAll()
+
+	// The channel key is expected to have been placed in the context earlier in the chain.
+	channelKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
+	if channelKey == "" {
+		logger.LogError(c, "API key not found in context")
+		respondError(http.StatusUnauthorized, "authentication_error", "invalid_api_key", "API key not found")
+		return
+	}
+
+	payload, payloadType, err := gemini.RebuildMultipartForm(form)
+	if err != nil {
+		reason := fmt.Sprintf("failed to rebuild multipart form: %s", err.Error())
+		logger.LogError(c, reason)
+		respondError(http.StatusInternalServerError, "internal_error", "form_rebuild_error", reason)
+		return
+	}
+
+	upstreamHeaders := map[string]string{
+		"Content-Type":   payloadType,
+		"x-goog-api-key": channelKey,
+	}
+	target := gemini.BuildGeminiFileURL("/upload/v1beta/files")
+
+	if err := gemini.ForwardGeminiFileRequest(c, http.MethodPost, target, payload, upstreamHeaders); err != nil {
+		// The failure is only logged; this handler writes nothing further to the client.
+		logger.LogError(c, fmt.Sprintf("failed to forward file upload request: %s", err.Error()))
+	}
+}
+
+// RelayGeminiFileGet retrieves file metadata from the Gemini File API by
+// relaying a GET for the file named in the request path.
+func RelayGeminiFileGet(c *gin.Context) {
+	relayGeminiFileByName(c, http.MethodGet, "get")
+}
+
+// RelayGeminiFileDelete deletes a file from the Gemini File API by relaying a
+// DELETE for the file named in the request path.
+func RelayGeminiFileDelete(c *gin.Context) {
+	relayGeminiFileByName(c, http.MethodDelete, "delete")
+}
+
+// relayGeminiFileByName forwards a single-file operation (method) to the
+// upstream "/v1beta/files/{id}" path. action is only used in the log message.
+//
+// The file ID comes from the "name" catch-all route parameter. A catch-all
+// value keeps its leading "/" and does not contain the literal "files" route
+// segment, so both are normalised here; an explicit "files/" prefix is also
+// accepted. The ID is restricted to ASCII letters, digits, '-' and '_' so that
+// a crafted path cannot reach other upstream endpoints with the channel key.
+func relayGeminiFileByName(c *gin.Context, method, action string) {
+	respondError := func(status int, errType, code, message string) {
+		c.JSON(status, gin.H{
+			"error": gin.H{
+				"message": message,
+				"type":    errType,
+				"code":    code,
+			},
+		})
+	}
+
+	fileID := strings.TrimPrefix(strings.Trim(c.Param("name"), "/"), "files/")
+	if fileID == "" {
+		respondError(http.StatusBadRequest, "invalid_request_error", "missing_file_name", "file name is required")
+		return
+	}
+	if !validGeminiFileID(fileID) {
+		respondError(http.StatusBadRequest, "invalid_request_error", "invalid_file_name", "file name is invalid")
+		return
+	}
+
+	// Get API key from channel context
+	apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
+	if apiKey == "" {
+		logger.LogError(c, "API key not found in context")
+		respondError(http.StatusUnauthorized, "authentication_error", "invalid_api_key", "API key not found")
+		return
+	}
+
+	target := gemini.BuildGeminiFileURL("/v1beta/files/" + fileID)
+	headers := map[string]string{
+		"x-goog-api-key": apiKey,
+	}
+
+	if err := gemini.ForwardGeminiFileRequest(c, method, target, nil, headers); err != nil {
+		logger.LogError(c, fmt.Sprintf("failed to forward file %s request: %s", action, err.Error()))
+	}
+}
+
+// validGeminiFileID reports whether id is non-empty and made up only of ASCII
+// letters, digits, '-' and '_'.
+func validGeminiFileID(id string) bool {
+	if id == "" {
+		return false
+	}
+	for _, r := range id {
+		switch {
+		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-', r == '_':
+			// allowed
+		default:
+			return false
+		}
+	}
+	return true
+}
+
+// RelayGeminiFileList lists files from the Gemini File API.
+//
+// Only the "pageSize" and "pageToken" query parameters are forwarded, each
+// with its first non-empty value. The values are re-encoded before being
+// appended to the upstream URL, so characters such as '+', '&', '=' or '%' in
+// a page token reach the upstream unchanged instead of altering the query.
+// A missing channel key is answered with 401.
+func RelayGeminiFileList(c *gin.Context) {
+	// Get API key from channel context
+	apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
+	if apiKey == "" {
+		logger.LogError(c, "API key not found in context")
+		c.JSON(http.StatusUnauthorized, gin.H{
+			"error": gin.H{
+				"message": "API key not found",
+				"type":    "authentication_error",
+				"code":    "invalid_api_key",
+			},
+		})
+		return
+	}
+
+	// URL.Query() returns a freshly parsed copy, so pruning it does not affect
+	// the inbound request. Everything except the two paging parameters is
+	// dropped, and each kept parameter is reduced to its first value.
+	forwarded := c.Request.URL.Query()
+	for name := range forwarded {
+		first := forwarded.Get(name)
+		if (name != "pageSize" && name != "pageToken") || first == "" {
+			forwarded.Del(name)
+			continue
+		}
+		forwarded.Set(name, first)
+	}
+
+	// Build upstream URL, appending the escaped query string only when non-empty.
+	parts := []string{gemini.BuildGeminiFileURL("/v1beta/files")}
+	if encoded := forwarded.Encode(); encoded != "" {
+		parts = append(parts, encoded)
+	}
+	target := strings.Join(parts, "?")
+
+	headers := map[string]string{
+		"x-goog-api-key": apiKey,
+	}
+
+	if err := gemini.ForwardGeminiFileRequest(c, http.MethodGet, target, nil, headers); err != nil {
+		logger.LogError(c, fmt.Sprintf("failed to forward file list request: %s", err.Error()))
+	}
+}

+ 35 - 0
dto/gemini.go

@@ -564,3 +564,38 @@ type GeminiBatchEmbeddingResponse struct {
 // ContentEmbedding holds the numeric values decoded from the "values" JSON key.
 type ContentEmbedding struct {
 	Values []float64 `json:"values"`
 }
+
+// File API related structs
+
+// GeminiFileUploadResponse wraps a single GeminiFile under the "file" JSON key.
+// NOTE(review): presumably the body returned by the upload endpoint; it is not
+// referenced by the handlers visible in this change — confirm against callers.
+type GeminiFileUploadResponse struct {
+	File GeminiFile `json:"file"`
+}
+
+// GeminiFile is the JSON shape of one file resource.
+//
+// Every scalar field, including SizeBytes and the *Time fields, is decoded as
+// a string. DisplayName, ExpirationTime, Sha256Hash, Error and VideoMetadata
+// are omitted from encoded output when empty or nil.
+type GeminiFile struct {
+	Name           string           `json:"name"`
+	DisplayName    string           `json:"displayName,omitempty"`
+	MimeType       string           `json:"mimeType"`
+	SizeBytes      string           `json:"sizeBytes"`
+	CreateTime     string           `json:"createTime"`
+	UpdateTime     string           `json:"updateTime"`
+	ExpirationTime string           `json:"expirationTime,omitempty"`
+	Sha256Hash     string           `json:"sha256Hash,omitempty"`
+	Uri            string           `json:"uri"`
+	State          string           `json:"state"`
+	Error          *GeminiFileError `json:"error,omitempty"`
+	VideoMetadata  *GeminiVideoMeta `json:"videoMetadata,omitempty"`
+}
+
+// GeminiVideoMeta is the optional "videoMetadata" object of a GeminiFile.
+// VideoDuration is kept as the raw string found under "videoDuration".
+type GeminiVideoMeta struct {
+	VideoDuration string `json:"videoDuration,omitempty"`
+}
+
+// GeminiFileListResponse is one page of files: the entries under "files" plus
+// an optional "nextPageToken".
+// NOTE(review): presumably the token is passed back as pageToken to fetch the
+// following page — confirm against the upstream API reference.
+type GeminiFileListResponse struct {
+	Files         []GeminiFile `json:"files"`
+	NextPageToken string       `json:"nextPageToken,omitempty"`
+}
+
+// GeminiFileError is the optional "error" object of a GeminiFile, carrying a
+// numeric code, a message and a status string.
+type GeminiFileError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+	Status  string `json:"status"`
+}

+ 157 - 0
relay/channel/gemini/file_helper.go

@@ -0,0 +1,157 @@
+package gemini
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"net/http"
+	"strings"
+
+	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/logger"
+	"github.com/QuantumNous/new-api/types"
+
+	"github.com/gin-gonic/gin"
+)
+
+const (
+	// GeminiFileAPIBaseURL is the scheme and host that BuildGeminiFileURL
+	// prepends to every File API path.
+	GeminiFileAPIBaseURL = "https://generativelanguage.googleapis.com"
+)
+
+// BuildGeminiFileURL returns GeminiFileAPIBaseURL with path appended verbatim.
+// The path is neither escaped nor normalised, so it is expected to start with
+// "/" and to already be URL-safe.
+func BuildGeminiFileURL(path string) string {
+	fullURL := GeminiFileAPIBaseURL + path
+	return fullURL
+}
+
+// ExtractAPIKey returns the API key supplied with the request: the
+// "x-goog-api-key" header when it is non-empty, otherwise the "key" query
+// parameter (which may itself be empty).
+func ExtractAPIKey(c *gin.Context) string {
+	if fromHeader := c.GetHeader("x-goog-api-key"); fromHeader != "" {
+		return fromHeader
+	}
+	return c.Query("key")
+}
+
+// ForwardGeminiFileRequest sends one request to the given upstream URL and
+// relays the upstream status code, headers and body to the client.
+//
+// headers are set on the outgoing request. The upstream call is bound to the
+// inbound request's context, so it is cancelled when that context is.
+//
+// When the upstream request cannot be built (500) or sent (502), a JSON error
+// is written to the client and the error is returned. When copying the body
+// fails, the response has already been started, so the error is only logged
+// and returned.
+func ForwardGeminiFileRequest(c *gin.Context, method, url string, body io.Reader, headers map[string]string) error {
+	// fail answers the client with a JSON error; it is only used before any
+	// part of the upstream response has been relayed.
+	fail := func(status int, code, message string) {
+		c.JSON(status, gin.H{
+			"error": gin.H{
+				"message": message,
+				"type":    "upstream_error",
+				"code":    code,
+			},
+		})
+	}
+
+	req, err := http.NewRequestWithContext(c.Request.Context(), method, url, body)
+	if err != nil {
+		logger.LogError(c, fmt.Sprintf("failed to create request: %s", err.Error()))
+		fail(http.StatusInternalServerError, "gemini_request_build_failed", "failed to create upstream request")
+		return err
+	}
+
+	for key, value := range headers {
+		req.Header.Set(key, value)
+	}
+
+	// The default client shares the default transport, which is what a
+	// zero-value http.Client would use as well.
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		logger.LogError(c, fmt.Sprintf("failed to send request: %s", err.Error()))
+		// Keep transport details in the log only; the client gets a generic message.
+		fail(http.StatusBadGateway, "gemini_request_failed", "failed to reach upstream")
+		return err
+	}
+	defer resp.Body.Close()
+
+	// Replace each header with the complete upstream value list so that
+	// multi-valued headers are not reduced to their last value.
+	outgoing := c.Writer.Header()
+	for key, values := range resp.Header {
+		outgoing[key] = append([]string(nil), values...)
+	}
+
+	c.Status(resp.StatusCode)
+
+	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
+		logger.LogError(c, fmt.Sprintf("failed to stream response: %s", err.Error()))
+		return err
+	}
+
+	return nil
+}
+
+// RebuildMultipartForm rebuilds a multipart form for forwarding to upstream
+func RebuildMultipartForm(form *multipart.Form) (io.Reader, string, error) {
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+
+	// Add form fields
+	for key, values := range form.Value {
+		for _, value := range values {
+			if err := writer.WriteField(key, value); err != nil {
+				return nil, "", err
+			}
+		}
+	}
+
+	// Add files
+	for key, files := range form.File {
+		for _, fileHeader := range files {
+			file, err := fileHeader.Open()
+			if err != nil {
+				return nil, "", err
+			}
+			defer file.Close()
+
+			part, err := writer.CreateFormFile(key, fileHeader.Filename)
+			if err != nil {
+				return nil, "", err
+			}
+
+			if _, err := io.Copy(part, file); err != nil {
+				return nil, "", err
+			}
+		}
+	}
+
+	contentType := writer.FormDataContentType()
+	if err := writer.Close(); err != nil {
+		return nil, "", err
+	}
+
+	return body, contentType, nil
+}
+
+// getAPIKeyFromToken returns the request's token key, with any "Bearer "
+// prefix and surrounding whitespace removed, for use as the Gemini API key.
+//
+// It fails with an access-denied error (marked skip-retry) when "token_id" or
+// "token_key" is absent from the context, or when "token_key" is not a string.
+// NOTE(review): the token key is used as-is rather than being mapped to a
+// stored Gemini key, and nothing in the visible code calls this function —
+// confirm whether it is still needed.
+func getAPIKeyFromToken(c *gin.Context) (string, error) {
+	accessDenied := func() error {
+		return types.NewError(nil, types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
+	}
+
+	tokenID, found := common.GetContextKey(c, "token_id")
+	if !found {
+		return "", accessDenied()
+	}
+
+	rawKey, found := common.GetContextKey(c, "token_key")
+	if !found {
+		return "", accessDenied()
+	}
+
+	// A nil value fails this assertion too, so no separate nil check is needed.
+	key, isString := rawKey.(string)
+	if !isString {
+		return "", accessDenied()
+	}
+
+	key = strings.TrimSpace(strings.TrimPrefix(key, "Bearer "))
+
+	logger.LogDebug(c, fmt.Sprintf("token_id: %v, using API key for Gemini", tokenID))
+
+	return key, nil
+}

+ 14 - 0
router/relay-router.go

@@ -174,11 +174,25 @@ func SetRelayRouter(router *gin.Engine) {
 	relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
 	relayGeminiRouter.Use(middleware.Distribute())
 	{
+		// Gemini File API routes
+		relayGeminiRouter.GET("/files", controller.RelayGeminiFileList)
+		relayGeminiRouter.GET("/files/*name", controller.RelayGeminiFileGet)
+		relayGeminiRouter.DELETE("/files/*name", controller.RelayGeminiFileDelete)
+
 		// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
 		relayGeminiRouter.POST("/models/*path", func(c *gin.Context) {
 			controller.Relay(c, types.RelayFormatGemini)
 		})
 	}
+
+	// Gemini File Upload route (separate group for different path prefix)
+	relayGeminiUploadRouter := router.Group("/upload/v1beta")
+	relayGeminiUploadRouter.Use(middleware.TokenAuth())
+	relayGeminiUploadRouter.Use(middleware.ModelRequestRateLimit())
+	relayGeminiUploadRouter.Use(middleware.Distribute())
+	{
+		relayGeminiUploadRouter.POST("/files", controller.RelayGeminiFileUpload)
+	}
 }
 
 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {

+ 154 - 0
test_gemini_file_api.sh

@@ -0,0 +1,154 @@
+#!/bin/bash
+
+# Test script for Gemini File API endpoints
+# This script tests all four file operations: upload, get, list, and delete
+
+set -e
+
+# Configuration
+BASE_URL="${BASE_URL:-http://localhost:3000}"
+API_TOKEN="${API_TOKEN:-your-api-token-here}"
+
+# Colors for output
+GREEN='\033[0;32m'
+RED='\033[0;31m'
+YELLOW='\033[1;33m'
+NC='\033[0m' # No Color
+
+# Helper functions
+# Print $1 in green, prefixed with a check mark.
+print_success() {
+    printf '%b\n' "${GREEN}✓ $1${NC}"
+}
+
+# Print $1 in red, prefixed with a cross.
+print_error() {
+    printf '%b\n' "${RED}✗ $1${NC}"
+}
+
+# Print $1 in yellow, prefixed with an info sign.
+print_info() {
+    printf '%b\n' "${YELLOW}ℹ $1${NC}"
+}
+
+# Refuse to run with the placeholder token: print usage and stop.
+if [[ "$API_TOKEN" == "your-api-token-here" ]]; then
+    print_error "Please set API_TOKEN environment variable"
+    printf '%s\n' "Usage: API_TOKEN=your-token BASE_URL=http://localhost:3000 ./test_gemini_file_api.sh"
+    exit 1
+fi
+
+echo "=========================================="
+echo "Testing Gemini File API Endpoints"
+echo "=========================================="
+echo ""
+
+# Create a test file
+TEST_FILE="/tmp/test_gemini_upload.txt"
+echo "This is a test file for Gemini File API upload" > "$TEST_FILE"
+print_info "Created test file: $TEST_FILE"
+echo ""
+
+# Test 1: Upload File
+# Uploads the test file and, on success, stores the returned resource name in
+# FILE_NAME for the later get/delete tests. Any status other than 200 aborts.
+echo "Test 1: Upload File"
+echo "--------------------"
+print_info "Uploading file to $BASE_URL/upload/v1beta/files"
+
+# -w appends the HTTP status on its own final line so it can be split off below.
+UPLOAD_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/upload/v1beta/files" \
+  -H "Authorization: Bearer $API_TOKEN" \
+  -F "file=@$TEST_FILE" \
+  -F "display_name=Test Document")
+
+HTTP_CODE=$(echo "$UPLOAD_RESPONSE" | tail -n1)
+RESPONSE_BODY=$(echo "$UPLOAD_RESPONSE" | sed '$d')
+
+if [ "$HTTP_CODE" = "200" ]; then
+    print_success "File uploaded successfully (HTTP $HTTP_CODE)"
+    echo "Response: $RESPONSE_BODY"
+
+    # Extract the first "name" value from the response. Whitespace around the
+    # colon is tolerated so pretty-printed JSON ("name": "files/...") matches too.
+    FILE_NAME=$(echo "$RESPONSE_BODY" \
+        | grep -o '"name"[[:space:]]*:[[:space:]]*"[^"]*"' \
+        | head -n1 \
+        | sed 's/.*"\([^"]*\)"$/\1/')
+    if [ -n "$FILE_NAME" ]; then
+        print_info "File name: $FILE_NAME"
+    else
+        print_error "Could not extract file name from upload response; get/delete tests will be skipped"
+    fi
+else
+    print_error "File upload failed (HTTP $HTTP_CODE)"
+    echo "Response: $RESPONSE_BODY"
+    exit 1
+fi
+echo ""
+
+# Test 2: List Files
+# Requests the file listing and reports the status; a failure here does not
+# stop the script.
+printf '%s\n' "Test 2: List Files" "------------------"
+print_info "Listing files from $BASE_URL/v1beta/files"
+
+LIST_RESPONSE=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1beta/files" \
+  -H "Authorization: Bearer $API_TOKEN")
+
+# The last line is the status code appended by -w; everything before it is the body.
+HTTP_CODE=$(tail -n1 <<< "$LIST_RESPONSE")
+RESPONSE_BODY=$(sed '$d' <<< "$LIST_RESPONSE")
+
+if [[ "$HTTP_CODE" == "200" ]]; then
+    print_success "Files listed successfully (HTTP $HTTP_CODE)"
+else
+    print_error "File listing failed (HTTP $HTTP_CODE)"
+fi
+printf '%s\n' "Response: $RESPONSE_BODY" ""
+
+# Test 3: Get File Metadata (if we have a file name)
+# Tests 3 and 4 only run when the upload step produced a FILE_NAME. Failures
+# are reported but do not stop the script.
+if [[ -n "$FILE_NAME" ]]; then
+    printf '%s\n' "Test 3: Get File Metadata" "-------------------------"
+    print_info "Getting metadata for $FILE_NAME"
+
+    GET_RESPONSE=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1beta/$FILE_NAME" \
+      -H "Authorization: Bearer $API_TOKEN")
+
+    # Split the status line appended by -w from the body.
+    HTTP_CODE=$(tail -n1 <<< "$GET_RESPONSE")
+    RESPONSE_BODY=$(sed '$d' <<< "$GET_RESPONSE")
+
+    if [[ "$HTTP_CODE" == "200" ]]; then
+        print_success "File metadata retrieved successfully (HTTP $HTTP_CODE)"
+    else
+        print_error "File metadata retrieval failed (HTTP $HTTP_CODE)"
+    fi
+    printf '%s\n' "Response: $RESPONSE_BODY" ""
+
+    # Test 4: Delete File
+    printf '%s\n' "Test 4: Delete File" "-------------------"
+    print_info "Deleting file $FILE_NAME"
+
+    DELETE_RESPONSE=$(curl -s -w "\n%{http_code}" -X DELETE "$BASE_URL/v1beta/$FILE_NAME" \
+      -H "Authorization: Bearer $API_TOKEN")
+
+    HTTP_CODE=$(tail -n1 <<< "$DELETE_RESPONSE")
+    RESPONSE_BODY=$(sed '$d' <<< "$DELETE_RESPONSE")
+
+    # Both 200 and 204 count as a successful delete.
+    case "$HTTP_CODE" in
+        200|204) print_success "File deleted successfully (HTTP $HTTP_CODE)" ;;
+        *)       print_error "File deletion failed (HTTP $HTTP_CODE)" ;;
+    esac
+    printf '%s\n' "Response: $RESPONSE_BODY" ""
+fi
+
+# Cleanup
+# Remove the uploaded test file from local disk, then print the closing banner
+# and troubleshooting hints.
+rm -f "$TEST_FILE"
+print_info "Cleaned up test file"
+
+printf '%s\n' \
+    "" \
+    "==========================================" \
+    "Test Summary" \
+    "=========================================="
+print_success "All tests completed!"
+printf '%s\n' \
+    "" \
+    "Note: If you see authentication errors, make sure:" \
+    "  1. The server is running" \
+    "  2. Your API token is valid" \
+    "  3. The token has access to a Gemini channel" \
+    "  4. The Gemini API key is properly configured in the channel"