Просмотр исходного кода

fix(gemini): detect streaming from URL path :streamGenerateContent

Google's native Gemini API uses the URL action :streamGenerateContent
to indicate streaming intent, not just the ?alt=sse query parameter.
The current IsStream() only checks c.Query("alt") == "sse", causing
all :streamGenerateContent requests (without ?alt=sse) to be treated
as non-streaming.

This adds a strings.Contains check for "streamGenerateContent" in the
request URL path, so both streaming indicators are recognized.
D26FORWARD 1 месяц назад
Родитель
Сommit
23fde25b15
2 измененных файлов с 78 добавлено и 0 удалено
  1. 5 0
      dto/gemini.go
  2. 73 0
      dto/gemini_isstream_test.go

+ 5 - 0
dto/gemini.go

@@ -121,6 +121,11 @@ func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
 	if c.Query("alt") == "sse" {
 		return true
 	}
+	// Native Gemini API uses URL action to indicate streaming:
+	// /v1beta/models/{model}:streamGenerateContent
+	if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+		return true
+	}
 	return false
 }
 

+ 73 - 0
dto/gemini_isstream_test.go

@@ -0,0 +1,73 @@
+package dto
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGeminiChatRequest_IsStream(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	tests := []struct {
+		name     string
+		path     string
+		query    string
+		expected bool
+	}{
+		{
+			name:     "streamGenerateContent without alt=sse",
+			path:     "/v1beta/models/gemini-2.0-flash:streamGenerateContent",
+			query:    "key=sk-xxx",
+			expected: true,
+		},
+		{
+			name:     "streamGenerateContent with alt=sse",
+			path:     "/v1beta/models/gemini-2.0-flash:streamGenerateContent",
+			query:    "alt=sse&key=sk-xxx",
+			expected: true,
+		},
+		{
+			name:     "generateContent without alt=sse",
+			path:     "/v1beta/models/gemini-2.0-flash:generateContent",
+			query:    "key=sk-xxx",
+			expected: false,
+		},
+		{
+			name:     "generateContent with alt=sse",
+			path:     "/v1beta/models/gemini-2.0-flash:generateContent",
+			query:    "alt=sse",
+			expected: true,
+		},
+		{
+			name:     "GenerateContent capitalized",
+			path:     "/v1beta/models/gemini-2.0-flash:GenerateContent",
+			query:    "key=sk-xxx",
+			expected: false,
+		},
+		{
+			name:     "embedding path",
+			path:     "/v1beta/models/gemini-2.0-flash:embedContent",
+			query:    "",
+			expected: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			w := httptest.NewRecorder()
+			c, _ := gin.CreateTestContext(w)
+			url := tt.path
+			if tt.query != "" {
+				url += "?" + tt.query
+			}
+			c.Request, _ = http.NewRequest("POST", url, nil)
+
+			req := &GeminiChatRequest{}
+			assert.Equal(t, tt.expected, req.IsStream(c))
+		})
+	}
+}