Browse Source

refactor: 简化请求转换函数和流处理逻辑

somnifex 5 months ago
parent
commit
f7d393fc72

+ 7 - 25
relay/channel/ollama/adaptor.go

@@ -18,10 +18,7 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
 	openaiAdaptor := openai.Adaptor{}
@@ -36,29 +33,17 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
 
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	// embeddings fixed endpoint
-	if info.RelayMode == relayconstant.RelayModeEmbeddings {
-		return info.ChannelBaseUrl + "/api/embed", nil
-	}
-	// For chat vs generate: if original path contains "/v1/completions" map to generate; otherwise chat
-	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
-		return info.ChannelBaseUrl + "/api/generate", nil
-	}
-	return info.ChannelBaseUrl + "/api/chat", nil
+    if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
+    if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
+    return info.ChannelBaseUrl + "/api/chat", nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -84,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return requestOpenAI2Embeddings(request), nil
 }
 
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
-	// TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)

+ 0 - 3
relay/channel/ollama/dto.go

@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 )
 
-// OllamaChatMessage represents a single chat message
 type OllamaChatMessage struct {
 	Role      string            `json:"role"`
 	Content   string            `json:"content,omitempty"`
@@ -32,7 +31,6 @@ type OllamaToolCall struct {
 	} `json:"function"`
 }
 
-// OllamaChatRequest -> /api/chat
 type OllamaChatRequest struct {
 	Model     string              `json:"model"`
 	Messages  []OllamaChatMessage `json:"messages"`
@@ -44,7 +42,6 @@ type OllamaChatRequest struct {
 	Think     json.RawMessage     `json:"think,omitempty"`
 }
 
-// OllamaGenerateRequest -> /api/generate
 type OllamaGenerateRequest struct {
 	Model     string         `json:"model"`
 	Prompt    string         `json:"prompt,omitempty"`

+ 0 - 9
relay/channel/ollama/relay-ollama.go

@@ -15,7 +15,6 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-// openAIChatToOllamaChat converts OpenAI-style chat request to Ollama chat
 func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
 	chatReq := &OllamaChatRequest{
 		Model:   r.Model,
@@ -23,12 +22,10 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 		Options: map[string]any{},
 		Think:   r.Think,
 	}
-	// format mapping
 	if r.ResponseFormat != nil {
 		if r.ResponseFormat.Type == "json" {
 			chatReq.Format = "json"
 		} else if r.ResponseFormat.Type == "json_schema" {
-			// supply schema object directly
 			if len(r.ResponseFormat.JsonSchema) > 0 {
 				var schema any
 				_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
@@ -46,7 +43,6 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 	if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
 	if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
 
-	// Stop -> options.stop (array)
 	if r.Stop != nil {
 		switch v := r.Stop.(type) {
 		case string:
@@ -60,7 +56,6 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 		}
 	}
 
-	// tools
 	if len(r.Tools) > 0 {
 		tools := make([]OllamaTool,0,len(r.Tools))
 		for _, t := range r.Tools {
@@ -69,10 +64,8 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 		chatReq.Tools = tools
 	}
 
-	// messages
 	chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
 	for _, m := range r.Messages {
-		// gather text parts & images
 		var textBuilder strings.Builder
 		var images []string
 		if m.IsStringContent() {
@@ -98,9 +91,7 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
 		}
 		cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
 		if len(images)>0 { cm.Images = images }
-		// history tool call result message
 		if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
-		// tool calls from assistant previous message
 		if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
 			parsed := m.ParseToolCalls()
 			if len(parsed) > 0 {

+ 42 - 14
relay/channel/ollama/stream.go

@@ -19,7 +19,6 @@ import (
     "github.com/gin-gonic/gin"
 )
 
-// Ollama streaming chunk (chat or generate)
 type ollamaChatStreamChunk struct {
     Model            string `json:"model"`
     CreatedAt        string `json:"created_at"`
@@ -47,7 +46,7 @@ type ollamaChatStreamChunk struct {
     EvalDuration       int64 `json:"eval_duration"`
 }
 
-func toUnix(ts string) int64 { // parse RFC3339 / variant; fallback time.Now
+func toUnix(ts string) int64 {
     if ts == "" { return time.Now().Unix() }
     // try time.RFC3339 or with nanoseconds
     t, err := time.Parse(time.RFC3339Nano, ts)
@@ -55,7 +54,6 @@ func toUnix(ts string) int64 { // parse RFC3339 / variant; fallback time.Now
     return t.Unix()
 }
 
-// streaming handler: convert Ollama stream -> OpenAI SSE
 func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
     if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
     defer service.CloseResponseBodyGracefully(resp)
@@ -67,7 +65,6 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
     var responseId = common.GetUUID()
     var created = time.Now().Unix()
     var toolCallIndex int
-    // send start event
     start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
     if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
 
@@ -141,16 +138,47 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
     body, err := io.ReadAll(resp.Body)
     if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
     service.CloseResponseBodyGracefully(resp)
-    if common.DebugEnabled { println("ollama non-stream resp:", string(body)) }
-    var chunk ollamaChatStreamChunk
-    if err = json.Unmarshal(body, &chunk); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
-    model := chunk.Model
+    raw := string(body)
+    if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+    lines := strings.Split(raw, "\n")
+    var (
+        aggContent strings.Builder
+        lastChunk ollamaChatStreamChunk
+        parsedAny bool
+    )
+    for _, ln := range lines {
+        ln = strings.TrimSpace(ln)
+        if ln == "" { continue }
+        var ck ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+            if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+            continue
+        }
+        parsedAny = true
+        lastChunk = ck
+        if !ck.Done { 
+            if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+        } else { 
+            if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+        }
+    }
+
+    if !parsedAny {
+        var single ollamaChatStreamChunk
+        if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+        lastChunk = single
+        if single.Message != nil { aggContent.WriteString(single.Message.Content) } else { aggContent.WriteString(single.Response) }
+    }
+
+    model := lastChunk.Model
     if model == "" { model = info.UpstreamModelName }
-    created := toUnix(chunk.CreatedAt)
-    content := ""
-    if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
-    usage := &dto.Usage{PromptTokens: chunk.PromptEvalCount, CompletionTokens: chunk.EvalCount, TotalTokens: chunk.PromptEvalCount + chunk.EvalCount}
-    // Build OpenAI style response
+    created := toUnix(lastChunk.CreatedAt)
+    usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+    content := aggContent.String()
+    finishReason := lastChunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+
     full := dto.OpenAITextResponse{
         Id:      common.GetUUID(),
         Model:   model,
@@ -159,7 +187,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
         Choices: []dto.OpenAITextResponseChoice{ {
             Index: 0,
             Message: dto.Message{Role: "assistant", Content: contentPtr(content)},
-            FinishReason: func() string { if chunk.DoneReason == "" { return "stop" } ; return chunk.DoneReason }(),
+            FinishReason: &finishReason,
         } },
         Usage: *usage,
     }