Bladeren bron

Merge pull request #1811 from somnifex/main

refactor: 重构ollama渠道
Calcium-Ion 5 maanden geleden
bovenliggende
commit
bdefed7b0a
4 gewijzigde bestanden met toevoegingen van 442 en 171 verwijderingen
  1. 18 34
      relay/channel/ollama/adaptor.go
  2. 57 36
      relay/channel/ollama/dto.go
  3. 157 101
      relay/channel/ollama/relay-ollama.go
  4. 210 0
      relay/channel/ollama/stream.go

+ 18 - 34
relay/channel/ollama/adaptor.go

@@ -10,6 +10,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/types"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -17,10 +18,7 @@ import (
 type Adaptor struct {
 }
 
-func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
 	openaiAdaptor := openai.Adaptor{}
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
 		IncludeUsage: true,
 	}
-	return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
+	// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
+	return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
 }
 
-func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
 
-func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	if info.RelayFormat == types.RelayFormatClaude {
-		return info.ChannelBaseUrl + "/v1/chat/completions", nil
-	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeEmbeddings:
-		return info.ChannelBaseUrl + "/api/embed", nil
-	default:
-		return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
-	}
+    if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
+    if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
+    return info.ChannelBaseUrl + "/api/chat", nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 }
 
 func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
-	if request == nil {
-		return nil, errors.New("request is nil")
+	if request == nil { return nil, errors.New("request is nil") }
+	// decide generate or chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return openAIToGenerate(c, request)
 	}
-	return requestOpenAI2Ollama(c, request)
+	return openAIChatToOllamaChat(c, request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return requestOpenAI2Embeddings(request), nil
 }
 
-func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
-	// TODO implement me
-	return nil, errors.New("not implemented")
-}
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
-		usage, err = ollamaEmbeddingHandler(c, info, resp)
+		return ollamaEmbeddingHandler(c, info, resp)
 	default:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			return ollamaStreamHandler(c, info, resp)
 		}
+		return ollamaChatHandler(c, info, resp)
 	}
-	return
 }
 
 func (a *Adaptor) GetModelList() []string {

+ 57 - 36
relay/channel/ollama/dto.go

@@ -2,48 +2,69 @@ package ollama
 
 import (
 	"encoding/json"
-	"one-api/dto"
 )
 
-type OllamaRequest struct {
-	Model            string                `json:"model,omitempty"`
-	Messages         []dto.Message         `json:"messages,omitempty"`
-	Stream           bool                  `json:"stream,omitempty"`
-	Temperature      *float64              `json:"temperature,omitempty"`
-	Seed             float64               `json:"seed,omitempty"`
-	Topp             float64               `json:"top_p,omitempty"`
-	TopK             int                   `json:"top_k,omitempty"`
-	Stop             any                   `json:"stop,omitempty"`
-	MaxTokens        uint                  `json:"max_tokens,omitempty"`
-	Tools            []dto.ToolCallRequest `json:"tools,omitempty"`
-	ResponseFormat   any                   `json:"response_format,omitempty"`
-	FrequencyPenalty float64               `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64               `json:"presence_penalty,omitempty"`
-	Suffix           any                   `json:"suffix,omitempty"`
-	StreamOptions    *dto.StreamOptions    `json:"stream_options,omitempty"`
-	Prompt           any                   `json:"prompt,omitempty"`
-	Think            json.RawMessage       `json:"think,omitempty"`
-}
-
-type Options struct {
-	Seed             int      `json:"seed,omitempty"`
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopK             int      `json:"top_k,omitempty"`
-	TopP             float64  `json:"top_p,omitempty"`
-	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
-	NumPredict       int      `json:"num_predict,omitempty"`
-	NumCtx           int      `json:"num_ctx,omitempty"`
+type OllamaChatMessage struct {
+	Role      string            `json:"role"`
+	Content   string            `json:"content,omitempty"`
+	Images    []string          `json:"images,omitempty"`
+	ToolCalls []OllamaToolCall  `json:"tool_calls,omitempty"`
+	ToolName  string            `json:"tool_name,omitempty"`
+	Thinking  json.RawMessage   `json:"thinking,omitempty"`
+}
+
+type OllamaToolFunction struct {
+	Name        string      `json:"name"`
+	Description string      `json:"description,omitempty"`
+	Parameters  interface{} `json:"parameters,omitempty"`
+}
+
+type OllamaTool struct {
+	Type     string            `json:"type"`
+	Function OllamaToolFunction `json:"function"`
+}
+
+type OllamaToolCall struct {
+	Function struct {
+		Name      string      `json:"name"`
+		Arguments interface{} `json:"arguments"`
+	} `json:"function"`
+}
+
+type OllamaChatRequest struct {
+	Model     string              `json:"model"`
+	Messages  []OllamaChatMessage `json:"messages"`
+	Tools     interface{}         `json:"tools,omitempty"`
+	Format    interface{}         `json:"format,omitempty"`
+	Stream    bool                `json:"stream,omitempty"`
+	Options   map[string]any      `json:"options,omitempty"`
+	KeepAlive interface{}         `json:"keep_alive,omitempty"`
+	Think     json.RawMessage     `json:"think,omitempty"`
+}
+
+type OllamaGenerateRequest struct {
+	Model     string         `json:"model"`
+	Prompt    string         `json:"prompt,omitempty"`
+	Suffix    string         `json:"suffix,omitempty"`
+	Images    []string       `json:"images,omitempty"`
+	Format    interface{}    `json:"format,omitempty"`
+	Stream    bool           `json:"stream,omitempty"`
+	Options   map[string]any `json:"options,omitempty"`
+	KeepAlive interface{}    `json:"keep_alive,omitempty"`
+	Think     json.RawMessage `json:"think,omitempty"`
 }
 
 type OllamaEmbeddingRequest struct {
-	Model   string   `json:"model,omitempty"`
-	Input   []string `json:"input"`
-	Options *Options `json:"options,omitempty"`
+	Model     string         `json:"model"`
+	Input     interface{}    `json:"input"`
+	Options   map[string]any `json:"options,omitempty"`
+	Dimensions int            `json:"dimensions,omitempty"`
 }
 
 type OllamaEmbeddingResponse struct {
-	Error     string      `json:"error,omitempty"`
-	Model     string      `json:"model"`
-	Embedding [][]float64 `json:"embeddings,omitempty"`
+	Error           string        `json:"error,omitempty"`
+	Model           string        `json:"model"`
+	Embeddings      [][]float64   `json:"embeddings"`
+	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
 }
+

+ 157 - 101
relay/channel/ollama/relay-ollama.go

@@ -1,6 +1,7 @@
 package ollama
 
 import (
+	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,121 +15,176 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
-	messages := make([]dto.Message, 0, len(request.Messages))
-	for _, message := range request.Messages {
-		if !message.IsStringContent() {
-			mediaMessages := message.ParseContent()
-			for j, mediaMessage := range mediaMessages {
-				if mediaMessage.Type == dto.ContentTypeImageURL {
-					imageUrl := mediaMessage.GetImageMedia()
-					// check if not base64
-					if strings.HasPrefix(imageUrl.Url, "http") {
-						fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
-						if err != nil {
-							return nil, err
+func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
+	chatReq := &OllamaChatRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			chatReq.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			if len(r.ResponseFormat.JsonSchema) > 0 {
+				var schema any
+				_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+				chatReq.Format = schema
+			}
+		}
+	}
+
+	// options mapping
+	if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
+
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			chatReq.Options["stop"] = []string{v}
+		case []string:
+			chatReq.Options["stop"] = v
+		case []any:
+			arr := make([]string,0,len(v))
+			for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
+			if len(arr)>0 { chatReq.Options["stop"] = arr }
+		}
+	}
+
+	if len(r.Tools) > 0 {
+		tools := make([]OllamaTool,0,len(r.Tools))
+		for _, t := range r.Tools {
+			tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
+		}
+		chatReq.Tools = tools
+	}
+
+	chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
+	for _, m := range r.Messages {
+		var textBuilder strings.Builder
+		var images []string
+		if m.IsStringContent() {
+			textBuilder.WriteString(m.StringContent())
+		} else {
+			parts := m.ParseContent()
+			for _, part := range parts {
+				if part.Type == dto.ContentTypeImageURL {
+					img := part.GetImageMedia()
+					if img != nil && img.Url != "" {
+						var base64Data string
+						if strings.HasPrefix(img.Url, "http") {
+							fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
+							if err != nil { return nil, err }
+							base64Data = fileData.Base64Data
+						} else if strings.HasPrefix(img.Url, "data:") {
+							if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
+						} else {
+							base64Data = img.Url
 						}
-						imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+						if base64Data != "" { images = append(images, base64Data) }
 					}
-					mediaMessage.ImageUrl = imageUrl
-					mediaMessages[j] = mediaMessage
+				} else if part.Type == dto.ContentTypeText {
+					textBuilder.WriteString(part.Text)
+				}
+			}
+		}
+		cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
+		if len(images)>0 { cm.Images = images }
+		if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
+		if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
+			parsed := m.ParseToolCalls()
+			if len(parsed) > 0 {
+				calls := make([]OllamaToolCall,0,len(parsed))
+				for _, tc := range parsed {
+					var args interface{}
+					if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
+					if args==nil { args = map[string]any{} }
+					oc := OllamaToolCall{}
+					oc.Function.Name = tc.Function.Name
+					oc.Function.Arguments = args
+					calls = append(calls, oc)
 				}
+				cm.ToolCalls = calls
 			}
-			message.SetMediaContent(mediaMessages)
 		}
-		messages = append(messages, dto.Message{
-			Role:       message.Role,
-			Content:    message.Content,
-			ToolCalls:  message.ToolCalls,
-			ToolCallId: message.ToolCallId,
-		})
+		chatReq.Messages = append(chatReq.Messages, cm)
 	}
-	str, ok := request.Stop.(string)
-	var Stop []string
-	if ok {
-		Stop = []string{str}
-	} else {
-		Stop, _ = request.Stop.([]string)
+	return chatReq, nil
+}
+
+// openAIToGenerate converts OpenAI completions request to Ollama generate
+func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
+	gen := &OllamaGenerateRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// Prompt may be in r.Prompt (string or []any)
+	if r.Prompt != nil {
+		switch v := r.Prompt.(type) {
+		case string:
+			gen.Prompt = v
+		case []any:
+			var sb strings.Builder
+			for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
+			gen.Prompt = sb.String()
+		default:
+			gen.Prompt = fmt.Sprintf("%v", r.Prompt)
+		}
 	}
-	ollamaRequest := &OllamaRequest{
-		Model:            request.Model,
-		Messages:         messages,
-		Stream:           request.Stream,
-		Temperature:      request.Temperature,
-		Seed:             request.Seed,
-		Topp:             request.TopP,
-		TopK:             request.TopK,
-		Stop:             Stop,
-		Tools:            request.Tools,
-		MaxTokens:        request.GetMaxTokens(),
-		ResponseFormat:   request.ResponseFormat,
-		FrequencyPenalty: request.FrequencyPenalty,
-		PresencePenalty:  request.PresencePenalty,
-		Prompt:           request.Prompt,
-		StreamOptions:    request.StreamOptions,
-		Suffix:           request.Suffix,
+	if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
 	}
-	ollamaRequest.Think = request.Think
-	return ollamaRequest, nil
+	if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string: gen.Options["stop"] = []string{v}
+		case []string: gen.Options["stop"] = v
+		case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
+		}
+	}
+	return gen, nil
 }
 
-func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
-	return &OllamaEmbeddingRequest{
-		Model: request.Model,
-		Input: request.ParseInput(),
-		Options: &Options{
-			Seed:             int(request.Seed),
-			Temperature:      request.Temperature,
-			TopP:             request.TopP,
-			FrequencyPenalty: request.FrequencyPenalty,
-			PresencePenalty:  request.PresencePenalty,
-		},
-	}
+func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+	opts := map[string]any{}
+	if r.Temperature != nil { opts["temperature"] = r.Temperature }
+	if r.TopP != 0 { opts["top_p"] = r.TopP }
+	if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { opts["seed"] = int(r.Seed) }
+	if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
+	input := r.ParseInput()
+	if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
+	return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
 }
 
 func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	var ollamaEmbeddingResponse OllamaEmbeddingResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
+	var oResp OllamaEmbeddingResponse
+	body, err := io.ReadAll(resp.Body)
+	if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
 	service.CloseResponseBodyGracefully(resp)
-	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	if ollamaEmbeddingResponse.Error != "" {
-		return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
-	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
-	data = append(data, dto.OpenAIEmbeddingResponseItem{
-		Embedding: flattenedEmbeddings,
-		Object:    "embedding",
-	})
-	usage := &dto.Usage{
-		TotalTokens:      info.PromptTokens,
-		CompletionTokens: 0,
-		PromptTokens:     info.PromptTokens,
-	}
-	embeddingResponse := &dto.OpenAIEmbeddingResponse{
-		Object: "list",
-		Data:   data,
-		Model:  info.UpstreamModelName,
-		Usage:  *usage,
-	}
-	doResponseBody, err := common.Marshal(embeddingResponse)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	service.IOCopyBytesGracefully(c, resp, doResponseBody)
+	if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+	if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+	data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
+	for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
+	usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
+	embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
+	out, _ := common.Marshal(embResp)
+	service.IOCopyBytesGracefully(c, resp, out)
 	return usage, nil
 }
 
-func flattenEmbeddings(embeddings [][]float64) []float64 {
-	flattened := []float64{}
-	for _, row := range embeddings {
-		flattened = append(flattened, row...)
-	}
-	return flattened
-}

+ 210 - 0
relay/channel/ollama/stream.go

@@ -0,0 +1,210 @@
+package ollama
+
+import (
+    "bufio"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "one-api/common"
+    "one-api/dto"
+    "one-api/logger"
+    relaycommon "one-api/relay/common"
+    "one-api/relay/helper"
+    "one-api/service"
+    "one-api/types"
+    "strings"
+    "time"
+
+    "github.com/gin-gonic/gin"
+)
+
+type ollamaChatStreamChunk struct {
+    Model            string `json:"model"`
+    CreatedAt        string `json:"created_at"`
+    // chat
+    Message *struct {
+        Role      string `json:"role"`
+        Content   string `json:"content"`
+        Thinking  json.RawMessage `json:"thinking"`
+        ToolCalls []struct {
+            Function struct {
+                Name      string      `json:"name"`
+                Arguments interface{} `json:"arguments"`
+            } `json:"function"`
+        } `json:"tool_calls"`
+    } `json:"message"`
+    // generate
+    Response string `json:"response"`
+    Done         bool    `json:"done"`
+    DoneReason   string  `json:"done_reason"`
+    TotalDuration int64  `json:"total_duration"`
+    LoadDuration  int64  `json:"load_duration"`
+    PromptEvalCount int  `json:"prompt_eval_count"`
+    EvalCount       int  `json:"eval_count"`
+    PromptEvalDuration int64 `json:"prompt_eval_duration"`
+    EvalDuration       int64 `json:"eval_duration"`
+}
+
+func toUnix(ts string) int64 {
+    if ts == "" { return time.Now().Unix() }
+    // try time.RFC3339 or with nanoseconds
+    t, err := time.Parse(time.RFC3339Nano, ts)
+    if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
+    return t.Unix()
+}
+
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
+    defer service.CloseResponseBodyGracefully(resp)
+
+    helper.SetEventStreamHeaders(c)
+    scanner := bufio.NewScanner(resp.Body)
+    usage := &dto.Usage{}
+    var model = info.UpstreamModelName
+    var responseId = common.GetUUID()
+    var created = time.Now().Unix()
+    var toolCallIndex int
+    start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+    if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+    for scanner.Scan() {
+        line := scanner.Text()
+        line = strings.TrimSpace(line)
+        if line == "" { continue }
+        var chunk ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+            logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+            return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+        }
+        if chunk.Model != "" { model = chunk.Model }
+        created = toUnix(chunk.CreatedAt)
+
+        if !chunk.Done {
+            // delta content
+            var content string
+            if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
+            delta := dto.ChatCompletionsStreamResponse{
+                Id:      responseId,
+                Object:  "chat.completion.chunk",
+                Created: created,
+                Model:   model,
+                Choices: []dto.ChatCompletionsStreamResponseChoice{ {
+                    Index: 0,
+                    Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
+                } },
+            }
+            if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+            if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+                raw := strings.TrimSpace(string(chunk.Message.Thinking))
+                if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+            }
+            // tool calls
+            if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+                delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
+                for _, tc := range chunk.Message.ToolCalls {
+                    // arguments -> string
+                    argBytes, _ := json.Marshal(tc.Function.Arguments)
+                    toolId := fmt.Sprintf("call_%d", toolCallIndex)
+                    tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+                    tr.SetIndex(toolCallIndex)
+                    toolCallIndex++
+                    delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+                }
+            }
+            if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+            continue
+        }
+        // done frame
+        // finalize once and break loop
+        usage.PromptTokens = chunk.PromptEvalCount
+        usage.CompletionTokens = chunk.EvalCount
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+    finishReason := chunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+        // emit stop delta
+        if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
+            if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // emit usage frame
+        if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
+            if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+        }
+        // send [DONE]
+        helper.Done(c)
+        break
+    }
+    if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+    return usage, nil
+}
+
+// non-stream handler for chat/generate
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+    body, err := io.ReadAll(resp.Body)
+    if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
+    service.CloseResponseBodyGracefully(resp)
+    raw := string(body)
+    if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
+
+    lines := strings.Split(raw, "\n")
+    var (
+        aggContent strings.Builder
+        reasoningBuilder strings.Builder
+        lastChunk ollamaChatStreamChunk
+        parsedAny bool
+    )
+    for _, ln := range lines {
+        ln = strings.TrimSpace(ln)
+        if ln == "" { continue }
+        var ck ollamaChatStreamChunk
+        if err := json.Unmarshal([]byte(ln), &ck); err != nil {
+            if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+            continue
+        }
+        parsedAny = true
+        lastChunk = ck
+        if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+            raw := strings.TrimSpace(string(ck.Message.Thinking))
+            if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
+        }
+        if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+    }
+
+    if !parsedAny {
+        var single ollamaChatStreamChunk
+        if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
+        lastChunk = single
+        if single.Message != nil {
+            if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+            aggContent.WriteString(single.Message.Content)
+        } else { aggContent.WriteString(single.Response) }
+    }
+
+    model := lastChunk.Model
+    if model == "" { model = info.UpstreamModelName }
+    created := toUnix(lastChunk.CreatedAt)
+    usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
+    content := aggContent.String()
+    finishReason := lastChunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
+
+    msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+    if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
+    full := dto.OpenAITextResponse{
+        Id:      common.GetUUID(),
+        Model:   model,
+        Object:  "chat.completion",
+        Created: created,
+        Choices: []dto.OpenAITextResponseChoice{ {
+            Index: 0,
+            Message: msg,
+            FinishReason: finishReason,
+        } },
+        Usage: *usage,
+    }
+    out, _ := common.Marshal(full)
+    service.IOCopyBytesGracefully(c, resp, out)
+    return usage, nil
+}
+
+func contentPtr(s string) *string { if s=="" { return nil }; return &s }