Просмотр исходного кода

feat: 添加ollama聊天流处理和非流处理功能

somnifex 5 месяцев назад
Родитель
Сommit
69a88a0563

+ 210 - 0
.history/relay/channel/ollama/stream_20250916085416.go

@@ -0,0 +1,210 @@
+package ollama
+
+import (
+    "bufio"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "one-api/common"
+    "one-api/dto"
+    "one-api/logger"
+    relaycommon "one-api/relay/common"
+    "one-api/relay/helper"
+    "one-api/service"
+    "one-api/types"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+)
+
+type ollamaChatStreamChunk struct {
+    Model            string `json:"model"`
+    CreatedAt        string `json:"created_at"`
+    // chat
+    Message *struct {
+        Role      string `json:"role"`
+        Content   string `json:"content"`
+        Thinking  json.RawMessage `json:"thinking"`
+        ToolCalls []struct {
+            Function struct {
+                Name      string      `json:"name"`
+                Arguments interface{} `json:"arguments"`
+            } `json:"function"`
+        } `json:"tool_calls"`
+    } `json:"message"`
+    // generate
+    Response string `json:"response"`
+    Done         bool    `json:"done"`
+    DoneReason   string  `json:"done_reason"`
+    TotalDuration int64  `json:"total_duration"`
+    LoadDuration  int64  `json:"load_duration"`
+    PromptEvalCount int  `json:"prompt_eval_count"`
+    EvalCount       int  `json:"eval_count"`
+    PromptEvalDuration int64 `json:"prompt_eval_duration"`
+    EvalDuration       int64 `json:"eval_duration"`
+}
+
+func toUnix(ts string) int64 {
+    if ts == "" { return time.Now().Unix() }
+    // try time.RFC3339 or with nanoseconds
+    t, err := time.Parse(time.RFC3339Nano, ts)
+    if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
+    return t.Unix()
+}
+
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
+    defer service.CloseResponseBodyGracefully(resp)
+
+    helper.SetEventStreamHeaders(c)
+    scanner := bufio.NewScanner(resp.Body)
+    usage := &dto.Usage{}
+    var model = info.UpstreamModelName
+    var responseId = common.GetUUID()
+    var created = time.Now().Unix()
+    var toolCallIndex int
+    start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+    if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+    for scanner.Scan() {
+        line := scanner.Text()
+        line = strings.TrimSpace(line)
+        if line == "" { continue }
+        var chunk ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+            logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+            return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+        }
+        if chunk.Model != "" { model = chunk.Model }
+        created = toUnix(chunk.CreatedAt)
+
+        if !chunk.Done {
+            // delta content
+            var content string
+            if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
+            delta := dto.ChatCompletionsStreamResponse{
+                Id:      responseId,
+                Object:  "chat.completion.chunk",
+                Created: created,
+                Model:   model,
+                Choices: []dto.ChatCompletionsStreamResponseChoice{ {
+                    Index: 0,
+                    Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
+                } },
+            }
+            if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+            if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+                raw := strings.TrimSpace(string(chunk.Message.Thinking))
+                if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+            }
+            // tool calls
+            if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+                delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
+                for _, tc := range chunk.Message.ToolCalls {
+                    // arguments -> string
+                    argBytes, _ := json.Marshal(tc.Function.Arguments)
+                    toolId := fmt.Sprintf("call_%d", toolCallIndex)
+                    tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+                    tr.SetIndex(toolCallIndex)
+                    toolCallIndex++
+                    delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+                }
+            }
+            if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+            continue
+        }
+        // done frame
+        // finalize once and break loop
+        usage.PromptTokens = chunk.PromptEvalCount
+        usage.CompletionTokens = chunk.EvalCount
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    finishReason := chunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+        // emit stop delta
+        if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
+            if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // emit usage frame
+        if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
+            if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // send [DONE]
+        helper.Done(c)
+        break
+    }
+    if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+    return usage, nil
+}
+
+// non-stream handler for chat/generate
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    body, err := io.ReadAll(resp.Body)
+    if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
+    service.CloseResponseBodyGracefully(resp)
+    raw := string(body)
+    if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+    lines := strings.Split(raw, "\n")
+    var (
+        aggContent strings.Builder
+        reasoningBuilder strings.Builder
+        lastChunk ollamaChatStreamChunk
+        parsedAny bool
+    )
+    for _, ln := range lines {
+        ln = strings.TrimSpace(ln)
+        if ln == "" { continue }
+        var ck ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+            if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+            continue
+        }
+        parsedAny = true
+        lastChunk = ck
+        if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+            raw := strings.TrimSpace(string(ck.Message.Thinking))
+            if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
+        }
+        if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+    }
+
+    if !parsedAny {
+        var single ollamaChatStreamChunk
+        if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+        lastChunk = single
+        if single.Message != nil {
+            if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+            aggContent.WriteString(single.Message.Content)
+        } else { aggContent.WriteString(single.Response) }
+    }
+
+    model := lastChunk.Model
+    if model == "" { model = info.UpstreamModelName }
+    created := toUnix(lastChunk.CreatedAt)
+    usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+    content := aggContent.String()
+    finishReason := lastChunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+
+    msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+    if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
+    full := dto.OpenAITextResponse{
+        Id:      common.GetUUID(),
+        Model:   model,
+        Object:  "chat.completion",
+        Created: created,
+        Choices: []dto.OpenAITextResponseChoice{ {
+            Index: 0,
+            Message: msg,
+            FinishReason: finishReason,
+        } },
+        Usage: *usage,
+    }
+    out, _ := common.Marshal(full)
+    service.IOCopyBytesGracefully(c, resp, out)
+    return usage, nil
+}
+
+func contentPtr(s string) *string { if s=="" { return nil }; return &s }

+ 210 - 0
.history/relay/channel/ollama/stream_20250916085435.go

@@ -0,0 +1,210 @@
+package ollama
+
+import (
+    "bufio"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "one-api/common"
+    "one-api/dto"
+    "one-api/logger"
+    relaycommon "one-api/relay/common"
+    "one-api/relay/helper"
+    "one-api/service"
+    "one-api/types"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+)
+
+type ollamaChatStreamChunk struct {
+    Model            string `json:"model"`
+    CreatedAt        string `json:"created_at"`
+    // chat
+    Message *struct {
+        Role      string `json:"role"`
+        Content   string `json:"content"`
+        Thinking  json.RawMessage `json:"thinking"`
+        ToolCalls []struct {
+            Function struct {
+                Name      string      `json:"name"`
+                Arguments interface{} `json:"arguments"`
+            } `json:"function"`
+        } `json:"tool_calls"`
+    } `json:"message"`
+    // generate
+    Response string `json:"response"`
+    Done         bool    `json:"done"`
+    DoneReason   string  `json:"done_reason"`
+    TotalDuration int64  `json:"total_duration"`
+    LoadDuration  int64  `json:"load_duration"`
+    PromptEvalCount int  `json:"prompt_eval_count"`
+    EvalCount       int  `json:"eval_count"`
+    PromptEvalDuration int64 `json:"prompt_eval_duration"`
+    EvalDuration       int64 `json:"eval_duration"`
+}
+
+func toUnix(ts string) int64 {
+    if ts == "" { return time.Now().Unix() }
+    // try time.RFC3339 or with nanoseconds
+    t, err := time.Parse(time.RFC3339Nano, ts)
+    if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
+    return t.Unix()
+}
+
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
+    defer service.CloseResponseBodyGracefully(resp)
+
+    helper.SetEventStreamHeaders(c)
+    scanner := bufio.NewScanner(resp.Body)
+    usage := &dto.Usage{}
+    var model = info.UpstreamModelName
+    var responseId = common.GetUUID()
+    var created = time.Now().Unix()
+    var toolCallIndex int
+    start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+    if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+    for scanner.Scan() {
+        line := scanner.Text()
+        line = strings.TrimSpace(line)
+        if line == "" { continue }
+        var chunk ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+            logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+            return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+        }
+        if chunk.Model != "" { model = chunk.Model }
+        created = toUnix(chunk.CreatedAt)
+
+        if !chunk.Done {
+            // delta content
+            var content string
+            if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
+            delta := dto.ChatCompletionsStreamResponse{
+                Id:      responseId,
+                Object:  "chat.completion.chunk",
+                Created: created,
+                Model:   model,
+                Choices: []dto.ChatCompletionsStreamResponseChoice{ {
+                    Index: 0,
+                    Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
+                } },
+            }
+            if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+            if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+                raw := strings.TrimSpace(string(chunk.Message.Thinking))
+                if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+            }
+            // tool calls
+            if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+                delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
+                for _, tc := range chunk.Message.ToolCalls {
+                    // arguments -> string
+                    argBytes, _ := json.Marshal(tc.Function.Arguments)
+                    toolId := fmt.Sprintf("call_%d", toolCallIndex)
+                    tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+                    tr.SetIndex(toolCallIndex)
+                    toolCallIndex++
+                    delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+                }
+            }
+            if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+            continue
+        }
+        // done frame
+        // finalize once and break loop
+        usage.PromptTokens = chunk.PromptEvalCount
+        usage.CompletionTokens = chunk.EvalCount
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    finishReason := chunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+        // emit stop delta
+        if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
+            if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // emit usage frame
+        if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
+            if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // send [DONE]
+        helper.Done(c)
+        break
+    }
+    if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+    return usage, nil
+}
+
+// non-stream handler for chat/generate
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    body, err := io.ReadAll(resp.Body)
+    if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
+    service.CloseResponseBodyGracefully(resp)
+    raw := string(body)
+    if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+    lines := strings.Split(raw, "\n")
+    var (
+        aggContent strings.Builder
+        reasoningBuilder strings.Builder
+        lastChunk ollamaChatStreamChunk
+        parsedAny bool
+    )
+    for _, ln := range lines {
+        ln = strings.TrimSpace(ln)
+        if ln == "" { continue }
+        var ck ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+            if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+            continue
+        }
+        parsedAny = true
+        lastChunk = ck
+        if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+            raw := strings.TrimSpace(string(ck.Message.Thinking))
+            if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
+        }
+        if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+    }
+
+    if !parsedAny {
+        var single ollamaChatStreamChunk
+        if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+        lastChunk = single
+        if single.Message != nil {
+            if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+            aggContent.WriteString(single.Message.Content)
+        } else { aggContent.WriteString(single.Response) }
+    }
+
+    model := lastChunk.Model
+    if model == "" { model = info.UpstreamModelName }
+    created := toUnix(lastChunk.CreatedAt)
+    usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+    content := aggContent.String()
+    finishReason := lastChunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+
+    msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+    if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
+    full := dto.OpenAITextResponse{
+        Id:      common.GetUUID(),
+        Model:   model,
+        Object:  "chat.completion",
+        Created: created,
+        Choices: []dto.OpenAITextResponseChoice{ {
+            Index: 0,
+            Message: msg,
+            FinishReason: finishReason,
+        } },
+        Usage: *usage,
+    }
+    out, _ := common.Marshal(full)
+    service.IOCopyBytesGracefully(c, resp, out)
+    return usage, nil
+}
+
+func contentPtr(s string) *string { if s=="" { return nil }; return &s }